Compare commits

..

32 Commits

Author SHA1 Message Date
bcmmbaga
8c44187900 Merge branch 'refs/heads/test/add-stopwatch' into deploy/peer-performance 2024-05-06 14:29:22 +03:00
Pascal Fischer
0dc050b354 update sqlite operations 2024-05-06 13:25:46 +02:00
Pascal Fischer
69eb6f17c6 remove stacktrace 2024-05-03 17:30:26 +02:00
bcmmbaga
ccbe2e0100 Merge branch 'refs/heads/migrate-with-json-serializer' into deploy/peer-performance 2024-05-03 14:11:23 +03:00
bcmmbaga
5b042eee74 migrate null blob values 2024-05-03 14:07:54 +03:00
Pascal Fischer
125c418d85 add debug to GetAccount sqlite store 2024-04-30 13:52:19 +02:00
Pascal Fischer
9f607029c0 add debug to GetAccount sqlite store 2024-04-30 13:35:04 +02:00
bcmmbaga
4df1d556c8 fix tests 2024-04-30 13:35:49 +03:00
bcmmbaga
acaf315152 Merge branch 'refs/heads/migrate-with-json-serializer' into deploy/peer-performance 2024-04-30 13:32:48 +03:00
bcmmbaga
81b6f7dc56 Merge remote-tracking branch 'refs/remotes/origin/test/add-stopwatch' into deploy/peer-performance 2024-04-30 13:32:37 +03:00
bcmmbaga
cb9a8e9fa4 Add tests 2024-04-30 13:27:14 +03:00
Pascal Fischer
f68d9ce042 fix linter 2024-04-30 12:03:59 +02:00
Pascal Fischer
1f182411c6 update store 2024-04-30 11:43:48 +02:00
Pascal Fischer
68e971cb82 Merge remote-tracking branch 'origin/test/add-stopwatch' into test/add-stopwatch 2024-04-30 11:29:08 +02:00
Pascal Fischer
f4a464fb69 update account manager mock 2024-04-30 11:28:57 +02:00
bcmmbaga
7aff0e828b Refactor 2024-04-30 10:07:40 +03:00
bcmmbaga
90ebf0017b remove duplicate index 2024-04-29 23:21:00 +03:00
bcmmbaga
bf8c6b95de run net ip migration 2024-04-29 23:20:43 +03:00
bcmmbaga
8415064758 migrate net ip field from blob to json 2024-04-29 23:20:22 +03:00
Pascal Fischer
809f6cdbc7 add stacktrace for get network map 2024-04-29 18:56:58 +02:00
Pascal Fischer
7474876d6a add stacktrace for get network map 2024-04-29 18:12:22 +02:00
Pascal Fischer
2095e35217 remove stopwatches + logs 2024-04-29 16:26:44 +02:00
Pascal Fischer
02537e5536 remove stopwatches 2024-04-29 14:57:57 +02:00
bcmmbaga
f99477d3db serialize net.IP as json 2024-04-29 15:45:46 +03:00
Pascal Fischer
3085383729 introduce RWLocks 2024-04-29 14:03:33 +02:00
Pascal Fischer
07bdf977d0 update db access 2024-04-26 18:05:41 +02:00
Pascal Fischer
4a147006d2 add table 2024-04-26 17:45:10 +02:00
Pascal Fischer
480192028e try improving Sync method 2024-04-26 17:18:24 +02:00
Pascal Fischer
027c0e5c63 update 2024-04-25 15:47:32 +02:00
Pascal Fischer
fbb1181b64 update 2024-04-25 15:44:06 +02:00
Pascal Fischer
e50ed290d9 update 2024-04-25 15:17:57 +02:00
Pascal Fischer
8f8cfcbc20 add method timing logs 2024-04-25 14:15:36 +02:00
514 changed files with 11976 additions and 35287 deletions

View File

@@ -1,8 +0,0 @@
root = true
[*]
end_of_line = lf
insert_final_newline = true
[*.go]
indent_style = tab

View File

@@ -31,14 +31,9 @@ Please specify whether you use NetBird Cloud or self-host NetBird's control plan
`netbird version` `netbird version`
**NetBird status -dA output:** **NetBird status -d output:**
If applicable, add the `netbird status -dA' command output. If applicable, add the `netbird status -d' command output.
**Do you face any (non-mobile) client issues?**
Please provide the file created by `netbird debug for 1m -AS`.
We advise reviewing the anonymized files for any remaining PII.
**Screenshots** **Screenshots**

View File

@@ -14,18 +14,18 @@ jobs:
test: test:
strategy: strategy:
matrix: matrix:
store: ['sqlite'] store: ['jsonfile', 'sqlite']
runs-on: macos-latest runs-on: macos-latest
steps: steps:
- name: Install Go - name: Install Go
uses: actions/setup-go@v5 uses: actions/setup-go@v4
with: with:
go-version: "1.23.x" go-version: "1.21.x"
- name: Checkout code - name: Checkout code
uses: actions/checkout@v4 uses: actions/checkout@v3
- name: Cache Go modules - name: Cache Go modules
uses: actions/cache@v4 uses: actions/cache@v3
with: with:
path: ~/go/pkg/mod path: ~/go/pkg/mod
key: macos-go-${{ hashFiles('**/go.sum') }} key: macos-go-${{ hashFiles('**/go.sum') }}

View File

@@ -1,46 +0,0 @@
name: Test Code FreeBSD
on:
push:
branches:
- main
pull_request:
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.actor_id }}
cancel-in-progress: true
jobs:
test:
runs-on: ubuntu-22.04
steps:
- uses: actions/checkout@v4
- name: Test in FreeBSD
id: test
uses: vmactions/freebsd-vm@v1
with:
usesh: true
copyback: false
release: "14.1"
prepare: |
pkg install -y go
# -x - to print all executed commands
# -e - to faile on first error
run: |
set -e -x
time go build -o netbird client/main.go
# check all component except management, since we do not support management server on freebsd
time go test -timeout 1m -failfast ./base62/...
# NOTE: without -p1 `client/internal/dns` will fail becasue of `listen udp4 :33100: bind: address already in use`
time go test -timeout 8m -failfast -p 1 ./client/...
time go test -timeout 1m -failfast ./dns/...
time go test -timeout 1m -failfast ./encryption/...
time go test -timeout 1m -failfast ./formatter/...
time go test -timeout 1m -failfast ./iface/...
time go test -timeout 1m -failfast ./route/...
time go test -timeout 1m -failfast ./sharedsock/...
time go test -timeout 1m -failfast ./signal/...
time go test -timeout 1m -failfast ./util/...
time go test -timeout 1m -failfast ./version/...

View File

@@ -15,17 +15,17 @@ jobs:
strategy: strategy:
matrix: matrix:
arch: [ '386','amd64' ] arch: [ '386','amd64' ]
store: [ 'sqlite', 'postgres'] store: [ 'jsonfile', 'sqlite' ]
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- name: Install Go - name: Install Go
uses: actions/setup-go@v5 uses: actions/setup-go@v4
with: with:
go-version: "1.23.x" go-version: "1.21.x"
- name: Cache Go modules - name: Cache Go modules
uses: actions/cache@v4 uses: actions/cache@v3
with: with:
path: ~/go/pkg/mod path: ~/go/pkg/mod
key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }} key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }}
@@ -33,7 +33,7 @@ jobs:
${{ runner.os }}-go- ${{ runner.os }}-go-
- name: Checkout code - name: Checkout code
uses: actions/checkout@v4 uses: actions/checkout@v3
- name: Install dependencies - name: Install dependencies
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev
@@ -49,18 +49,18 @@ jobs:
run: git --no-pager diff --exit-code run: git --no-pager diff --exit-code
- name: Test - name: Test
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} go test -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 6m -p 1 ./... run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} go test -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 5m -p 1 ./...
test_client_on_docker: test_client_on_docker:
runs-on: ubuntu-20.04 runs-on: ubuntu-20.04
steps: steps:
- name: Install Go - name: Install Go
uses: actions/setup-go@v5 uses: actions/setup-go@v4
with: with:
go-version: "1.23.x" go-version: "1.21.x"
- name: Cache Go modules - name: Cache Go modules
uses: actions/cache@v4 uses: actions/cache@v3
with: with:
path: ~/go/pkg/mod path: ~/go/pkg/mod
key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }} key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }}
@@ -68,7 +68,7 @@ jobs:
${{ runner.os }}-go- ${{ runner.os }}-go-
- name: Checkout code - name: Checkout code
uses: actions/checkout@v4 uses: actions/checkout@v3
- name: Install dependencies - name: Install dependencies
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev
@@ -86,10 +86,7 @@ jobs:
run: CGO_ENABLED=0 go test -c -o sharedsock-testing.bin ./sharedsock run: CGO_ENABLED=0 go test -c -o sharedsock-testing.bin ./sharedsock
- name: Generate RouteManager Test bin - name: Generate RouteManager Test bin
run: CGO_ENABLED=0 go test -c -o routemanager-testing.bin ./client/internal/routemanager run: CGO_ENABLED=1 go test -c -o routemanager-testing.bin -tags netgo -ldflags '-w -extldflags "-static -ldbus-1 -lpcap"' ./client/internal/routemanager/...
- name: Generate SystemOps Test bin
run: CGO_ENABLED=1 go test -c -o systemops-testing.bin -tags netgo -ldflags '-w -extldflags "-static -ldbus-1 -lpcap"' ./client/internal/routemanager/systemops
- name: Generate nftables Manager Test bin - name: Generate nftables Manager Test bin
run: CGO_ENABLED=0 go test -c -o nftablesmanager-testing.bin ./client/firewall/nftables/... run: CGO_ENABLED=0 go test -c -o nftablesmanager-testing.bin ./client/firewall/nftables/...
@@ -111,9 +108,6 @@ jobs:
- name: Run RouteManager tests in docker - name: Run RouteManager tests in docker
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal/routemanager --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/routemanager-testing.bin -test.timeout 5m -test.parallel 1 run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal/routemanager --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/routemanager-testing.bin -test.timeout 5m -test.parallel 1
- name: Run SystemOps tests in docker
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal/routemanager/systemops --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/systemops-testing.bin -test.timeout 5m -test.parallel 1
- name: Run nftables Manager tests in docker - name: Run nftables Manager tests in docker
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/firewall --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/nftablesmanager-testing.bin -test.timeout 5m -test.parallel 1 run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/firewall --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/nftablesmanager-testing.bin -test.timeout 5m -test.parallel 1

View File

@@ -17,13 +17,13 @@ jobs:
runs-on: windows-latest runs-on: windows-latest
steps: steps:
- name: Checkout code - name: Checkout code
uses: actions/checkout@v4 uses: actions/checkout@v3
- name: Install Go - name: Install Go
uses: actions/setup-go@v5 uses: actions/setup-go@v4
id: go id: go
with: with:
go-version: "1.23.x" go-version: "1.21.x"
- name: Download wintun - name: Download wintun
uses: carlosperate/download-file-action@v2 uses: carlosperate/download-file-action@v2

View File

@@ -15,7 +15,7 @@ jobs:
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- name: Checkout code - name: Checkout code
uses: actions/checkout@v4 uses: actions/checkout@v3
- name: codespell - name: codespell
uses: codespell-project/actions-codespell@v2 uses: codespell-project/actions-codespell@v2
with: with:
@@ -32,15 +32,15 @@ jobs:
timeout-minutes: 15 timeout-minutes: 15
steps: steps:
- name: Checkout code - name: Checkout code
uses: actions/checkout@v4 uses: actions/checkout@v3
- name: Check for duplicate constants - name: Check for duplicate constants
if: matrix.os == 'ubuntu-latest' if: matrix.os == 'ubuntu-latest'
run: | run: |
! awk '/const \(/,/)/{print $0}' management/server/activity/codes.go | grep -o '= [0-9]*' | sort | uniq -d | grep . ! awk '/const \(/,/)/{print $0}' management/server/activity/codes.go | grep -o '= [0-9]*' | sort | uniq -d | grep .
- name: Install Go - name: Install Go
uses: actions/setup-go@v5 uses: actions/setup-go@v4
with: with:
go-version: "1.23.x" go-version: "1.21.x"
cache: false cache: false
- name: Install dependencies - name: Install dependencies
if: matrix.os == 'ubuntu-latest' if: matrix.os == 'ubuntu-latest'

View File

@@ -21,7 +21,7 @@ jobs:
runs-on: ${{ matrix.os }} runs-on: ${{ matrix.os }}
steps: steps:
- name: Checkout code - name: Checkout code
uses: actions/checkout@v4 uses: actions/checkout@v3
- name: run install script - name: run install script
env: env:

View File

@@ -15,23 +15,23 @@ jobs:
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- name: Checkout repository - name: Checkout repository
uses: actions/checkout@v4 uses: actions/checkout@v3
- name: Install Go - name: Install Go
uses: actions/setup-go@v5 uses: actions/setup-go@v4
with: with:
go-version: "1.23.x" go-version: "1.21.x"
- name: Setup Android SDK - name: Setup Android SDK
uses: android-actions/setup-android@v3 uses: android-actions/setup-android@v3
with: with:
cmdline-tools-version: 8512546 cmdline-tools-version: 8512546
- name: Setup Java - name: Setup Java
uses: actions/setup-java@v4 uses: actions/setup-java@v3
with: with:
java-version: "11" java-version: "11"
distribution: "adopt" distribution: "adopt"
- name: NDK Cache - name: NDK Cache
id: ndk-cache id: ndk-cache
uses: actions/cache@v4 uses: actions/cache@v3
with: with:
path: /usr/local/lib/android/sdk/ndk path: /usr/local/lib/android/sdk/ndk
key: ndk-cache-23.1.7779620 key: ndk-cache-23.1.7779620
@@ -50,11 +50,11 @@ jobs:
runs-on: macos-latest runs-on: macos-latest
steps: steps:
- name: Checkout repository - name: Checkout repository
uses: actions/checkout@v4 uses: actions/checkout@v3
- name: Install Go - name: Install Go
uses: actions/setup-go@v5 uses: actions/setup-go@v4
with: with:
go-version: "1.23.x" go-version: "1.21.x"
- name: install gomobile - name: install gomobile
run: go install golang.org/x/mobile/cmd/gomobile@v0.0.0-20240404231514-09dbf07665ed run: go install golang.org/x/mobile/cmd/gomobile@v0.0.0-20240404231514-09dbf07665ed
- name: gomobile init - name: gomobile init

View File

@@ -7,13 +7,21 @@ on:
branches: branches:
- main - main
pull_request: pull_request:
paths:
- 'go.mod'
- 'go.sum'
- '.goreleaser.yml'
- '.goreleaser_ui.yaml'
- '.goreleaser_ui_darwin.yaml'
- '.github/workflows/release.yml'
- 'release_files/**'
- '**/Dockerfile'
- '**/Dockerfile.*'
- 'client/ui/**'
env: env:
SIGN_PIPE_VER: "v0.0.14" SIGN_PIPE_VER: "v0.0.11"
GORELEASER_VER: "v1.14.1" GORELEASER_VER: "v1.14.1"
PRODUCT_NAME: "NetBird"
COPYRIGHT: "Wiretrustee UG (haftungsbeschreankt)"
concurrency: concurrency:
group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.actor_id }} group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.actor_id }}
@@ -25,29 +33,22 @@ jobs:
env: env:
flags: "" flags: ""
steps: steps:
- name: Parse semver string
id: semver_parser
uses: booxmedialtd/ws-action-parse-semver@v1
with:
input_string: ${{ (startsWith(github.ref, 'refs/tags/v') && github.ref) || 'refs/tags/v0.0.0' }}
version_extractor_regex: '\/v(.*)$'
- if: ${{ !startsWith(github.ref, 'refs/tags/v') }} - if: ${{ !startsWith(github.ref, 'refs/tags/v') }}
run: echo "flags=--snapshot" >> $GITHUB_ENV run: echo "flags=--snapshot" >> $GITHUB_ENV
- -
name: Checkout name: Checkout
uses: actions/checkout@v4 uses: actions/checkout@v3
with: with:
fetch-depth: 0 # It is required for GoReleaser to work properly fetch-depth: 0 # It is required for GoReleaser to work properly
- -
name: Set up Go name: Set up Go
uses: actions/setup-go@v5 uses: actions/setup-go@v4
with: with:
go-version: "1.23" go-version: "1.21"
cache: false cache: false
- -
name: Cache Go modules name: Cache Go modules
uses: actions/cache@v4 uses: actions/cache@v3
with: with:
path: | path: |
~/go/pkg/mod ~/go/pkg/mod
@@ -77,11 +78,18 @@ jobs:
- name: Install OS build dependencies - name: Install OS build dependencies
run: sudo apt update && sudo apt install -y -q gcc-arm-linux-gnueabihf gcc-aarch64-linux-gnu run: sudo apt update && sudo apt install -y -q gcc-arm-linux-gnueabihf gcc-aarch64-linux-gnu
- name: Install goversioninfo - name: Install rsrc
run: go install github.com/josephspurrier/goversioninfo/cmd/goversioninfo@233067e run: go install github.com/akavel/rsrc@v0.10.2
- name: Generate windows syso amd64 - name: Generate windows rsrc amd64
run: goversioninfo -icon client/ui/netbird.ico -manifest client/manifest.xml -product-name ${{ env.PRODUCT_NAME }} -copyright "${{ env.COPYRIGHT }}" -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -ver-build 0 -file-version ${{ steps.semver_parser.outputs.fullversion }}.0 -product-version ${{ steps.semver_parser.outputs.fullversion }}.0 -o client/resources_windows_amd64.syso run: rsrc -arch amd64 -ico client/ui/netbird.ico -manifest client/manifest.xml -o client/resources_windows_amd64.syso
- name: Run GoReleaser - name: Generate windows rsrc arm64
run: rsrc -arch arm64 -ico client/ui/netbird.ico -manifest client/manifest.xml -o client/resources_windows_arm64.syso
- name: Generate windows rsrc arm
run: rsrc -arch arm -ico client/ui/netbird.ico -manifest client/manifest.xml -o client/resources_windows_arm.syso
- name: Generate windows rsrc 386
run: rsrc -arch 386 -ico client/ui/netbird.ico -manifest client/manifest.xml -o client/resources_windows_386.syso
-
name: Run GoReleaser
uses: goreleaser/goreleaser-action@v4 uses: goreleaser/goreleaser-action@v4
with: with:
version: ${{ env.GORELEASER_VER }} version: ${{ env.GORELEASER_VER }}
@@ -93,57 +101,29 @@ jobs:
UPLOAD_YUM_SECRET: ${{ secrets.PKG_UPLOAD_SECRET }} UPLOAD_YUM_SECRET: ${{ secrets.PKG_UPLOAD_SECRET }}
- -
name: upload non tags for debug purposes name: upload non tags for debug purposes
uses: actions/upload-artifact@v4 uses: actions/upload-artifact@v3
with: with:
name: release name: release
path: dist/ path: dist/
retention-days: 3 retention-days: 3
-
name: upload linux packages
uses: actions/upload-artifact@v4
with:
name: linux-packages
path: dist/netbird_linux**
retention-days: 3
-
name: upload windows packages
uses: actions/upload-artifact@v4
with:
name: windows-packages
path: dist/netbird_windows**
retention-days: 3
-
name: upload macos packages
uses: actions/upload-artifact@v4
with:
name: macos-packages
path: dist/netbird_darwin**
retention-days: 3
release_ui: release_ui:
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- name: Parse semver string
id: semver_parser
uses: booxmedialtd/ws-action-parse-semver@v1
with:
input_string: ${{ (startsWith(github.ref, 'refs/tags/v') && github.ref) || 'refs/tags/v0.0.0' }}
version_extractor_regex: '\/v(.*)$'
- if: ${{ !startsWith(github.ref, 'refs/tags/v') }} - if: ${{ !startsWith(github.ref, 'refs/tags/v') }}
run: echo "flags=--snapshot" >> $GITHUB_ENV run: echo "flags=--snapshot" >> $GITHUB_ENV
- name: Checkout - name: Checkout
uses: actions/checkout@v4 uses: actions/checkout@v3
with: with:
fetch-depth: 0 # It is required for GoReleaser to work properly fetch-depth: 0 # It is required for GoReleaser to work properly
- name: Set up Go - name: Set up Go
uses: actions/setup-go@v5 uses: actions/setup-go@v4
with: with:
go-version: "1.23" go-version: "1.21"
cache: false cache: false
- name: Cache Go modules - name: Cache Go modules
uses: actions/cache@v4 uses: actions/cache@v3
with: with:
path: | path: |
~/go/pkg/mod ~/go/pkg/mod
@@ -160,11 +140,10 @@ jobs:
- name: Install dependencies - name: Install dependencies
run: sudo apt update && sudo apt install -y -q libappindicator3-dev gir1.2-appindicator3-0.1 libxxf86vm-dev gcc-mingw-w64-x86-64 run: sudo apt update && sudo apt install -y -q libappindicator3-dev gir1.2-appindicator3-0.1 libxxf86vm-dev gcc-mingw-w64-x86-64
- name: Install goversioninfo - name: Install rsrc
run: go install github.com/josephspurrier/goversioninfo/cmd/goversioninfo@233067e run: go install github.com/akavel/rsrc@v0.10.2
- name: Generate windows syso amd64 - name: Generate windows rsrc
run: goversioninfo -64 -icon client/ui/netbird.ico -manifest client/ui/manifest.xml -product-name ${{ env.PRODUCT_NAME }}-"UI" -copyright "${{ env.COPYRIGHT }}" -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -ver-build 0 -file-version ${{ steps.semver_parser.outputs.fullversion }}.0 -product-version ${{ steps.semver_parser.outputs.fullversion }}.0 -o client/ui/resources_windows_amd64.syso run: rsrc -arch amd64 -ico client/ui/netbird.ico -manifest client/ui/manifest.xml -o client/ui/resources_windows_amd64.syso
- name: Run GoReleaser - name: Run GoReleaser
uses: goreleaser/goreleaser-action@v4 uses: goreleaser/goreleaser-action@v4
with: with:
@@ -176,31 +155,31 @@ jobs:
UPLOAD_DEBIAN_SECRET: ${{ secrets.PKG_UPLOAD_SECRET }} UPLOAD_DEBIAN_SECRET: ${{ secrets.PKG_UPLOAD_SECRET }}
UPLOAD_YUM_SECRET: ${{ secrets.PKG_UPLOAD_SECRET }} UPLOAD_YUM_SECRET: ${{ secrets.PKG_UPLOAD_SECRET }}
- name: upload non tags for debug purposes - name: upload non tags for debug purposes
uses: actions/upload-artifact@v4 uses: actions/upload-artifact@v3
with: with:
name: release-ui name: release-ui
path: dist/ path: dist/
retention-days: 3 retention-days: 3
release_ui_darwin: release_ui_darwin:
runs-on: macos-latest runs-on: macos-11
steps: steps:
- if: ${{ !startsWith(github.ref, 'refs/tags/v') }} - if: ${{ !startsWith(github.ref, 'refs/tags/v') }}
run: echo "flags=--snapshot" >> $GITHUB_ENV run: echo "flags=--snapshot" >> $GITHUB_ENV
- -
name: Checkout name: Checkout
uses: actions/checkout@v4 uses: actions/checkout@v3
with: with:
fetch-depth: 0 # It is required for GoReleaser to work properly fetch-depth: 0 # It is required for GoReleaser to work properly
- -
name: Set up Go name: Set up Go
uses: actions/setup-go@v5 uses: actions/setup-go@v4
with: with:
go-version: "1.23" go-version: "1.21"
cache: false cache: false
- -
name: Cache Go modules name: Cache Go modules
uses: actions/cache@v4 uses: actions/cache@v3
with: with:
path: | path: |
~/go/pkg/mod ~/go/pkg/mod
@@ -225,21 +204,35 @@ jobs:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
- -
name: upload non tags for debug purposes name: upload non tags for debug purposes
uses: actions/upload-artifact@v4 uses: actions/upload-artifact@v3
with: with:
name: release-ui-darwin name: release-ui-darwin
path: dist/ path: dist/
retention-days: 3 retention-days: 3
trigger_signer: trigger_windows_signer:
runs-on: ubuntu-latest runs-on: ubuntu-latest
needs: [release,release_ui,release_ui_darwin] needs: [release,release_ui]
if: startsWith(github.ref, 'refs/tags/') if: startsWith(github.ref, 'refs/tags/')
steps: steps:
- name: Trigger binaries sign pipelines - name: Trigger Windows binaries sign pipeline
uses: benc-uk/workflow-dispatch@v1 uses: benc-uk/workflow-dispatch@v1
with: with:
workflow: Sign bin and installer workflow: Sign windows bin and installer
repo: netbirdio/sign-pipelines
ref: ${{ env.SIGN_PIPE_VER }}
token: ${{ secrets.SIGN_GITHUB_TOKEN }}
inputs: '{ "tag": "${{ github.ref }}" }'
trigger_darwin_signer:
runs-on: ubuntu-latest
needs: [release,release_ui_darwin]
if: startsWith(github.ref, 'refs/tags/')
steps:
- name: Trigger Darwin App binaries sign pipeline
uses: benc-uk/workflow-dispatch@v1
with:
workflow: Sign darwin ui app with dispatch
repo: netbirdio/sign-pipelines repo: netbirdio/sign-pipelines
ref: ${{ env.SIGN_PIPE_VER }} ref: ${{ env.SIGN_PIPE_VER }}
token: ${{ secrets.SIGN_GITHUB_TOKEN }} token: ${{ secrets.SIGN_GITHUB_TOKEN }}

View File

@@ -18,31 +18,7 @@ concurrency:
jobs: jobs:
test-docker-compose: test-docker-compose:
runs-on: ubuntu-latest runs-on: ubuntu-latest
strategy:
matrix:
store: [ 'sqlite', 'postgres' ]
services:
postgres:
image: ${{ (matrix.store == 'postgres') && 'postgres' || '' }}
env:
POSTGRES_USER: netbird
POSTGRES_PASSWORD: postgres
POSTGRES_DB: netbird
options: >-
--health-cmd pg_isready
--health-interval 10s
--health-timeout 5s
ports:
- 5432:5432
steps: steps:
- name: Set Database Connection String
run: |
if [ "${{ matrix.store }}" == "postgres" ]; then
echo "NETBIRD_STORE_ENGINE_POSTGRES_DSN=host=$(hostname -I | awk '{print $1}') user=netbird password=postgres dbname=netbird port=5432" >> $GITHUB_ENV
else
echo "NETBIRD_STORE_ENGINE_POSTGRES_DSN==" >> $GITHUB_ENV
fi
- name: Install jq - name: Install jq
run: sudo apt-get install -y jq run: sudo apt-get install -y jq
@@ -50,12 +26,12 @@ jobs:
run: sudo apt-get install -y curl run: sudo apt-get install -y curl
- name: Install Go - name: Install Go
uses: actions/setup-go@v5 uses: actions/setup-go@v4
with: with:
go-version: "1.23.x" go-version: "1.21.x"
- name: Cache Go modules - name: Cache Go modules
uses: actions/cache@v4 uses: actions/cache@v3
with: with:
path: ~/go/pkg/mod path: ~/go/pkg/mod
key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }} key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }}
@@ -63,7 +39,7 @@ jobs:
${{ runner.os }}-go- ${{ runner.os }}-go-
- name: Checkout code - name: Checkout code
uses: actions/checkout@v4 uses: actions/checkout@v3
- name: cp setup.env - name: cp setup.env
run: cp infrastructure_files/tests/setup.env infrastructure_files/ run: cp infrastructure_files/tests/setup.env infrastructure_files/
@@ -82,8 +58,7 @@ jobs:
CI_NETBIRD_IDP_MGMT_CLIENT_ID: testing.client.id CI_NETBIRD_IDP_MGMT_CLIENT_ID: testing.client.id
CI_NETBIRD_IDP_MGMT_CLIENT_SECRET: testing.client.secret CI_NETBIRD_IDP_MGMT_CLIENT_SECRET: testing.client.secret
CI_NETBIRD_AUTH_SUPPORTED_SCOPES: "openid profile email offline_access api email_verified" CI_NETBIRD_AUTH_SUPPORTED_SCOPES: "openid profile email offline_access api email_verified"
CI_NETBIRD_STORE_CONFIG_ENGINE: ${{ matrix.store }} CI_NETBIRD_STORE_CONFIG_ENGINE: "sqlite"
NETBIRD_STORE_ENGINE_POSTGRES_DSN: ${{ env.NETBIRD_STORE_ENGINE_POSTGRES_DSN }}
CI_NETBIRD_MGMT_IDP_SIGNKEY_REFRESH: false CI_NETBIRD_MGMT_IDP_SIGNKEY_REFRESH: false
- name: check values - name: check values
@@ -110,8 +85,7 @@ jobs:
CI_NETBIRD_IDP_MGMT_CLIENT_ID: testing.client.id CI_NETBIRD_IDP_MGMT_CLIENT_ID: testing.client.id
CI_NETBIRD_IDP_MGMT_CLIENT_SECRET: testing.client.secret CI_NETBIRD_IDP_MGMT_CLIENT_SECRET: testing.client.secret
CI_NETBIRD_SIGNAL_PORT: 12345 CI_NETBIRD_SIGNAL_PORT: 12345
CI_NETBIRD_STORE_CONFIG_ENGINE: ${{ matrix.store }} CI_NETBIRD_STORE_CONFIG_ENGINE: "sqlite"
NETBIRD_STORE_ENGINE_POSTGRES_DSN: '${{ env.NETBIRD_STORE_ENGINE_POSTGRES_DSN }}$'
CI_NETBIRD_MGMT_IDP_SIGNKEY_REFRESH: false CI_NETBIRD_MGMT_IDP_SIGNKEY_REFRESH: false
CI_NETBIRD_TURN_EXTERNAL_IP: "1.2.3.4" CI_NETBIRD_TURN_EXTERNAL_IP: "1.2.3.4"
@@ -149,14 +123,6 @@ jobs:
grep -A 10 PKCEAuthorizationFlow management.json | grep -A 10 ProviderConfig | grep Scope | grep "$CI_NETBIRD_AUTH_SUPPORTED_SCOPES" grep -A 10 PKCEAuthorizationFlow management.json | grep -A 10 ProviderConfig | grep Scope | grep "$CI_NETBIRD_AUTH_SUPPORTED_SCOPES"
grep -A 10 PKCEAuthorizationFlow management.json | grep -A 10 ProviderConfig | grep -A 3 RedirectURLs | grep "http://localhost:53000" grep -A 10 PKCEAuthorizationFlow management.json | grep -A 10 ProviderConfig | grep -A 3 RedirectURLs | grep "http://localhost:53000"
grep "external-ip" turnserver.conf | grep $CI_NETBIRD_TURN_EXTERNAL_IP grep "external-ip" turnserver.conf | grep $CI_NETBIRD_TURN_EXTERNAL_IP
grep NETBIRD_STORE_ENGINE_POSTGRES_DSN docker-compose.yml | egrep "$NETBIRD_STORE_ENGINE_POSTGRES_DSN"
# check relay values
grep "NB_EXPOSED_ADDRESS=$CI_NETBIRD_DOMAIN:33445" docker-compose.yml
grep "NB_LISTEN_ADDRESS=:33445" docker-compose.yml
grep '33445:33445' docker-compose.yml
grep -A 10 'relay:' docker-compose.yml | egrep 'NB_AUTH_SECRET=.+$'
grep -A 7 Relay management.json | grep "rel://$CI_NETBIRD_DOMAIN:33445"
grep -A 7 Relay management.json | egrep '"Secret": ".+"'
- name: Install modules - name: Install modules
run: go mod tidy run: go mod tidy
@@ -182,35 +148,26 @@ jobs:
run: | run: |
docker build -t netbirdio/signal:latest . docker build -t netbirdio/signal:latest .
- name: Build relay binary
working-directory: relay
run: CGO_ENABLED=0 go build -o netbird-relay main.go
- name: Build relay docker image
working-directory: relay
run: |
docker build -t netbirdio/relay:latest .
- name: run docker compose up - name: run docker compose up
working-directory: infrastructure_files/artifacts working-directory: infrastructure_files/artifacts
run: | run: |
docker compose up -d docker-compose up -d
sleep 5 sleep 5
docker compose ps docker-compose ps
docker compose logs --tail=20 docker-compose logs --tail=20
- name: test running containers - name: test running containers
run: | run: |
count=$(docker compose ps --format json | jq '. | select(.Name | contains("artifacts")) | .State' | grep -c running) count=$(docker compose ps --format json | jq '. | select(.Name | contains("artifacts")) | .State' | grep -c running)
test $count -eq 5 || docker compose logs test $count -eq 4
working-directory: infrastructure_files/artifacts working-directory: infrastructure_files/artifacts
- name: test geolocation databases - name: test geolocation databases
working-directory: infrastructure_files/artifacts working-directory: infrastructure_files/artifacts
run: | run: |
sleep 30 sleep 30
docker compose exec management ls -l /var/lib/netbird/ | grep -i GeoLite2-City_[0-9]*.mmdb docker compose exec management ls -l /var/lib/netbird/ | grep -i GeoLite2-City.mmdb
docker compose exec management ls -l /var/lib/netbird/ | grep -i geonames_[0-9]*.db docker compose exec management ls -l /var/lib/netbird/ | grep -i geonames.db
test-getting-started-script: test-getting-started-script:
runs-on: ubuntu-latest runs-on: ubuntu-latest
@@ -219,69 +176,36 @@ jobs:
run: sudo apt-get install -y jq run: sudo apt-get install -y jq
- name: Checkout code - name: Checkout code
uses: actions/checkout@v4 uses: actions/checkout@v3
- name: run script with Zitadel PostgreSQL - name: run script
run: NETBIRD_DOMAIN=use-ip bash -x infrastructure_files/getting-started-with-zitadel.sh run: NETBIRD_DOMAIN=use-ip bash -x infrastructure_files/getting-started-with-zitadel.sh
- name: test Caddy file gen postgres - name: test Caddy file gen
run: test -f Caddyfile run: test -f Caddyfile
- name: test docker-compose file gen
- name: test docker-compose file gen postgres
run: test -f docker-compose.yml run: test -f docker-compose.yml
- name: test management.json file gen
- name: test management.json file gen postgres
run: test -f management.json run: test -f management.json
- name: test turnserver.conf file gen
- name: test turnserver.conf file gen postgres
run: | run: |
set -x set -x
test -f turnserver.conf test -f turnserver.conf
grep external-ip turnserver.conf grep external-ip turnserver.conf
- name: test zitadel.env file gen
- name: test zitadel.env file gen postgres
run: test -f zitadel.env run: test -f zitadel.env
- name: test dashboard.env file gen
- name: test dashboard.env file gen postgres
run: test -f dashboard.env run: test -f dashboard.env
test-download-geolite2-script:
- name: test relay.env file gen postgres runs-on: ubuntu-latest
run: test -f relay.env steps:
- name: Install jq
- name: test zdb.env file gen postgres run: sudo apt-get update && sudo apt-get install -y unzip sqlite3
run: test -f zdb.env - name: Checkout code
uses: actions/checkout@v3
- name: Postgres run cleanup - name: test script
run: | run: bash -x infrastructure_files/download-geolite2.sh
docker compose down --volumes --rmi all - name: test mmdb file exists
rm -rf docker-compose.yml Caddyfile zitadel.env dashboard.env machinekey/zitadel-admin-sa.token turnserver.conf management.json zdb.env run: test -f GeoLite2-City.mmdb
- name: test geonames file exists
- name: run script with Zitadel CockroachDB run: test -f geonames.db
run: bash -x infrastructure_files/getting-started-with-zitadel.sh
env:
NETBIRD_DOMAIN: use-ip
ZITADEL_DATABASE: cockroach
- name: test Caddy file gen CockroachDB
run: test -f Caddyfile
- name: test docker-compose file gen CockroachDB
run: test -f docker-compose.yml
- name: test management.json file gen CockroachDB
run: test -f management.json
- name: test turnserver.conf file gen CockroachDB
run: |
set -x
test -f turnserver.conf
grep external-ip turnserver.conf
- name: test zitadel.env file gen CockroachDB
run: test -f zitadel.env
- name: test dashboard.env file gen CockroachDB
run: test -f dashboard.env
- name: test relay.env file gen CockroachDB
run: test -f relay.env

1
.gitignore vendored
View File

@@ -29,3 +29,4 @@ infrastructure_files/setup.env
infrastructure_files/setup-*.env infrastructure_files/setup-*.env
.vscode .vscode
.DS_Store .DS_Store
GeoLite2-City*

View File

@@ -130,10 +130,3 @@ issues:
- path: mock\.go - path: mock\.go
linters: linters:
- nilnil - nilnil
# Exclude specific deprecation warnings for grpc methods
- linters:
- staticcheck
text: "grpc.DialContext is deprecated"
- linters:
- staticcheck
text: "grpc.WithBlock is deprecated"

View File

@@ -80,20 +80,6 @@ builds:
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser - -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
mod_timestamp: '{{ .CommitTimestamp }}' mod_timestamp: '{{ .CommitTimestamp }}'
- id: netbird-relay
dir: relay
env: [CGO_ENABLED=0]
binary: netbird-relay
goos:
- linux
goarch:
- amd64
- arm64
- arm
ldflags:
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
mod_timestamp: '{{ .CommitTimestamp }}'
archives: archives:
- builds: - builds:
- netbird - netbird
@@ -175,52 +161,6 @@ dockers:
- "--label=org.opencontainers.image.revision={{.FullCommit}}" - "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=org.opencontainers.image.version={{.Version}}" - "--label=org.opencontainers.image.version={{.Version}}"
- "--label=maintainer=dev@netbird.io" - "--label=maintainer=dev@netbird.io"
- image_templates:
- netbirdio/relay:{{ .Version }}-amd64
ids:
- netbird-relay
goarch: amd64
use: buildx
dockerfile: relay/Dockerfile
build_flag_templates:
- "--platform=linux/amd64"
- "--label=org.opencontainers.image.created={{.Date}}"
- "--label=org.opencontainers.image.title={{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=maintainer=dev@netbird.io"
- image_templates:
- netbirdio/relay:{{ .Version }}-arm64v8
ids:
- netbird-relay
goarch: arm64
use: buildx
dockerfile: relay/Dockerfile
build_flag_templates:
- "--platform=linux/arm64"
- "--label=org.opencontainers.image.created={{.Date}}"
- "--label=org.opencontainers.image.title={{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=maintainer=dev@netbird.io"
- image_templates:
- netbirdio/relay:{{ .Version }}-arm
ids:
- netbird-relay
goarch: arm
goarm: 6
use: buildx
dockerfile: relay/Dockerfile
build_flag_templates:
- "--platform=linux/arm"
- "--label=org.opencontainers.image.created={{.Date}}"
- "--label=org.opencontainers.image.title={{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=maintainer=dev@netbird.io"
- image_templates: - image_templates:
- netbirdio/signal:{{ .Version }}-amd64 - netbirdio/signal:{{ .Version }}-amd64
ids: ids:
@@ -373,18 +313,6 @@ docker_manifests:
- netbirdio/netbird:{{ .Version }}-arm - netbirdio/netbird:{{ .Version }}-arm
- netbirdio/netbird:{{ .Version }}-amd64 - netbirdio/netbird:{{ .Version }}-amd64
- name_template: netbirdio/relay:{{ .Version }}
image_templates:
- netbirdio/relay:{{ .Version }}-arm64v8
- netbirdio/relay:{{ .Version }}-arm
- netbirdio/relay:{{ .Version }}-amd64
- name_template: netbirdio/relay:latest
image_templates:
- netbirdio/relay:{{ .Version }}-arm64v8
- netbirdio/relay:{{ .Version }}-arm
- netbirdio/relay:{{ .Version }}-amd64
- name_template: netbirdio/signal:{{ .Version }} - name_template: netbirdio/signal:{{ .Version }}
image_templates: image_templates:
- netbirdio/signal:{{ .Version }}-arm64v8 - netbirdio/signal:{{ .Version }}-arm64v8

View File

@@ -11,6 +11,8 @@ builds:
- amd64 - amd64
ldflags: ldflags:
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser - -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
tags:
- legacy_appindicator
mod_timestamp: '{{ .CommitTimestamp }}' mod_timestamp: '{{ .CommitTimestamp }}'
- id: netbird-ui-windows - id: netbird-ui-windows

View File

@@ -3,10 +3,8 @@ builds:
- id: netbird-ui-darwin - id: netbird-ui-darwin
dir: client/ui dir: client/ui
binary: netbird-ui binary: netbird-ui
env: env: [CGO_ENABLED=1]
- CGO_ENABLED=1
- MACOSX_DEPLOYMENT_TARGET=11.0
- MACOS_DEPLOYMENT_TARGET=11.0
goos: goos:
- darwin - darwin
goarch: goarch:

View File

@@ -5,7 +5,7 @@
We as members, contributors, and leaders pledge to make participation in our We as members, contributors, and leaders pledge to make participation in our
community a harassment-free experience for everyone, regardless of age, body community a harassment-free experience for everyone, regardless of age, body
size, visible or invisible disability, ethnicity, sex characteristics, gender size, visible or invisible disability, ethnicity, sex characteristics, gender
identity and expression, level of experience, education, socioeconomic status, identity and expression, level of experience, education, socio-economic status,
nationality, personal appearance, race, caste, color, religion, or sexual nationality, personal appearance, race, caste, color, religion, or sexual
identity and orientation. identity and orientation.

View File

@@ -10,14 +10,12 @@
<img width="234" src="docs/media/logo-full.png"/> <img width="234" src="docs/media/logo-full.png"/>
</p> </p>
<p> <p>
<a href="https://img.shields.io/badge/license-BSD--3-blue)">
<img src="https://sonarcloud.io/api/project_badges/measure?project=netbirdio_netbird&metric=alert_status" />
</a>
<a href="https://github.com/netbirdio/netbird/blob/main/LICENSE"> <a href="https://github.com/netbirdio/netbird/blob/main/LICENSE">
<img src="https://img.shields.io/badge/license-BSD--3-blue" /> <img src="https://img.shields.io/badge/license-BSD--3-blue" />
</a> </a>
<a href="https://www.codacy.com/gh/netbirdio/netbird/dashboard?utm_source=github.com&amp;utm_medium=referral&amp;utm_content=netbirdio/netbird&amp;utm_campaign=Badge_Grade"><img src="https://app.codacy.com/project/badge/Grade/e3013d046aec44cdb7462c8673b00976"/></a>
<br> <br>
<a href="https://join.slack.com/t/netbirdio/shared_invite/zt-2p5zwhm4g-8fHollzrQa5y4PZF5AEpvQ"> <a href="https://join.slack.com/t/netbirdio/shared_invite/zt-vrahf41g-ik1v7fV8du6t0RwxSrJ96A">
<img src="https://img.shields.io/badge/slack-@netbird-red.svg?logo=slack"/> <img src="https://img.shields.io/badge/slack-@netbird-red.svg?logo=slack"/>
</a> </a>
</p> </p>
@@ -30,7 +28,7 @@
<br/> <br/>
See <a href="https://netbird.io/docs/">Documentation</a> See <a href="https://netbird.io/docs/">Documentation</a>
<br/> <br/>
Join our <a href="https://join.slack.com/t/netbirdio/shared_invite/zt-2p5zwhm4g-8fHollzrQa5y4PZF5AEpvQ">Slack channel</a> Join our <a href="https://join.slack.com/t/netbirdio/shared_invite/zt-vrahf41g-ik1v7fV8du6t0RwxSrJ96A">Slack channel</a>
<br/> <br/>
</strong> </strong>

View File

@@ -1,4 +1,4 @@
FROM alpine:3.20 FROM alpine:3.18.5
RUN apk add --no-cache ca-certificates iptables ip6tables RUN apk add --no-cache ca-certificates iptables ip6tables
ENV NB_FOREGROUND_MODE=true ENV NB_FOREGROUND_MODE=true
ENTRYPOINT [ "/usr/local/bin/netbird","up"] ENTRYPOINT [ "/usr/local/bin/netbird","up"]

View File

@@ -1,5 +1,3 @@
//go:build android
package android package android
import ( import (
@@ -16,7 +14,6 @@ import (
"github.com/netbirdio/netbird/client/system" "github.com/netbirdio/netbird/client/system"
"github.com/netbirdio/netbird/formatter" "github.com/netbirdio/netbird/formatter"
"github.com/netbirdio/netbird/iface" "github.com/netbirdio/netbird/iface"
"github.com/netbirdio/netbird/util/net"
) )
// ConnectionListener export internal Listener for mobile // ConnectionListener export internal Listener for mobile
@@ -57,17 +54,14 @@ type Client struct {
ctxCancel context.CancelFunc ctxCancel context.CancelFunc
ctxCancelLock *sync.Mutex ctxCancelLock *sync.Mutex
deviceName string deviceName string
uiVersion string
networkChangeListener listener.NetworkChangeListener networkChangeListener listener.NetworkChangeListener
} }
// NewClient instantiate a new Client // NewClient instantiate a new Client
func NewClient(cfgFile, deviceName string, uiVersion string, tunAdapter TunAdapter, iFaceDiscover IFaceDiscover, networkChangeListener NetworkChangeListener) *Client { func NewClient(cfgFile, deviceName string, tunAdapter TunAdapter, iFaceDiscover IFaceDiscover, networkChangeListener NetworkChangeListener) *Client {
net.SetAndroidProtectSocketFn(tunAdapter.ProtectSocket)
return &Client{ return &Client{
cfgFile: cfgFile, cfgFile: cfgFile,
deviceName: deviceName, deviceName: deviceName,
uiVersion: uiVersion,
tunAdapter: tunAdapter, tunAdapter: tunAdapter,
iFaceDiscover: iFaceDiscover, iFaceDiscover: iFaceDiscover,
recorder: peer.NewRecorder(""), recorder: peer.NewRecorder(""),
@@ -90,9 +84,6 @@ func (c *Client) Run(urlOpener URLOpener, dns *DNSList, dnsReadyListener DnsRead
var ctx context.Context var ctx context.Context
//nolint //nolint
ctxWithValues := context.WithValue(context.Background(), system.DeviceNameCtxKey, c.deviceName) ctxWithValues := context.WithValue(context.Background(), system.DeviceNameCtxKey, c.deviceName)
//nolint
ctxWithValues = context.WithValue(ctxWithValues, system.UiVersionCtxKey, c.uiVersion)
c.ctxCancelLock.Lock() c.ctxCancelLock.Lock()
ctx, c.ctxCancel = context.WithCancel(ctxWithValues) ctx, c.ctxCancel = context.WithCancel(ctxWithValues)
defer c.ctxCancel() defer c.ctxCancel()
@@ -106,8 +97,7 @@ func (c *Client) Run(urlOpener URLOpener, dns *DNSList, dnsReadyListener DnsRead
// todo do not throw error in case of cancelled context // todo do not throw error in case of cancelled context
ctx = internal.CtxInitState(ctx) ctx = internal.CtxInitState(ctx)
connectClient := internal.NewConnectClient(ctx, cfg, c.recorder) return internal.RunClientMobile(ctx, cfg, c.recorder, c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, dns.items, dnsReadyListener)
return connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, dns.items, dnsReadyListener)
} }
// RunWithoutLogin we apply this type of run function when the backed has been started without UI (i.e. after reboot). // RunWithoutLogin we apply this type of run function when the backed has been started without UI (i.e. after reboot).
@@ -132,8 +122,7 @@ func (c *Client) RunWithoutLogin(dns *DNSList, dnsReadyListener DnsReadyListener
// todo do not throw error in case of cancelled context // todo do not throw error in case of cancelled context
ctx = internal.CtxInitState(ctx) ctx = internal.CtxInitState(ctx)
connectClient := internal.NewConnectClient(ctx, cfg, c.recorder) return internal.RunClientMobile(ctx, cfg, c.recorder, c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, dns.items, dnsReadyListener)
return connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, dns.items, dnsReadyListener)
} }
// Stop the internal client and free the resources // Stop the internal client and free the resources

View File

@@ -84,7 +84,7 @@ func (a *Auth) SaveConfigIfSSOSupported(listener SSOListener) {
func (a *Auth) saveConfigIfSSOSupported() (bool, error) { func (a *Auth) saveConfigIfSSOSupported() (bool, error) {
supportsSSO := true supportsSSO := true
err := a.withBackOff(a.ctx, func() (err error) { err := a.withBackOff(a.ctx, func() (err error) {
_, err = internal.GetPKCEAuthorizationFlowInfo(a.ctx, a.config.PrivateKey, a.config.ManagementURL, nil) _, err = internal.GetPKCEAuthorizationFlowInfo(a.ctx, a.config.PrivateKey, a.config.ManagementURL)
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.NotFound || s.Code() == codes.Unimplemented) { if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.NotFound || s.Code() == codes.Unimplemented) {
_, err = internal.GetDeviceAuthorizationFlowInfo(a.ctx, a.config.PrivateKey, a.config.ManagementURL) _, err = internal.GetDeviceAuthorizationFlowInfo(a.ctx, a.config.PrivateKey, a.config.ManagementURL)
s, ok := gstatus.FromError(err) s, ok := gstatus.FromError(err)

View File

@@ -178,21 +178,6 @@ func (a *Anonymizer) AnonymizeDNSLogLine(logEntry string) string {
}) })
} }
// AnonymizeRoute anonymizes a route string by replacing IP addresses with anonymized versions and
// domain names with random strings.
func (a *Anonymizer) AnonymizeRoute(route string) string {
prefix, err := netip.ParsePrefix(route)
if err == nil {
ip := a.AnonymizeIPString(prefix.Addr().String())
return fmt.Sprintf("%s/%d", ip, prefix.Bits())
}
domains := strings.Split(route, ", ")
for i, domain := range domains {
domains[i] = a.AnonymizeDomain(domain)
}
return strings.Join(domains, ", ")
}
func isWellKnown(addr netip.Addr) bool { func isWellKnown(addr netip.Addr) bool {
wellKnown := []string{ wellKnown := []string{
"8.8.8.8", "8.8.4.4", // Google DNS IPv4 "8.8.8.8", "8.8.4.4", // Google DNS IPv4

View File

@@ -3,19 +3,15 @@ package cmd
import ( import (
"context" "context"
"fmt" "fmt"
"strings"
"time" "time"
log "github.com/sirupsen/logrus"
"github.com/spf13/cobra" "github.com/spf13/cobra"
"google.golang.org/grpc/status" "google.golang.org/grpc/status"
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/proto" "github.com/netbirdio/netbird/client/proto"
"github.com/netbirdio/netbird/client/server"
) )
const errCloseConnection = "Failed to close connection: %v"
var debugCmd = &cobra.Command{ var debugCmd = &cobra.Command{
Use: "debug", Use: "debug",
Short: "Debugging commands", Short: "Debugging commands",
@@ -62,21 +58,16 @@ var forCmd = &cobra.Command{
} }
func debugBundle(cmd *cobra.Command, _ []string) error { func debugBundle(cmd *cobra.Command, _ []string) error {
conn, err := getClient(cmd) conn, err := getClient(cmd.Context())
if err != nil { if err != nil {
return err return err
} }
defer func() { defer conn.Close()
if err := conn.Close(); err != nil {
log.Errorf(errCloseConnection, err)
}
}()
client := proto.NewDaemonServiceClient(conn) client := proto.NewDaemonServiceClient(conn)
resp, err := client.DebugBundle(cmd.Context(), &proto.DebugBundleRequest{ resp, err := client.DebugBundle(cmd.Context(), &proto.DebugBundleRequest{
Anonymize: anonymizeFlag, Anonymize: anonymizeFlag,
Status: getStatusOutput(cmd), Status: getStatusOutput(cmd),
SystemInfo: debugSystemInfoFlag,
}) })
if err != nil { if err != nil {
return fmt.Errorf("failed to bundle debug: %v", status.Convert(err).Message()) return fmt.Errorf("failed to bundle debug: %v", status.Convert(err).Message())
@@ -88,18 +79,14 @@ func debugBundle(cmd *cobra.Command, _ []string) error {
} }
func setLogLevel(cmd *cobra.Command, args []string) error { func setLogLevel(cmd *cobra.Command, args []string) error {
conn, err := getClient(cmd) conn, err := getClient(cmd.Context())
if err != nil { if err != nil {
return err return err
} }
defer func() { defer conn.Close()
if err := conn.Close(); err != nil {
log.Errorf(errCloseConnection, err)
}
}()
client := proto.NewDaemonServiceClient(conn) client := proto.NewDaemonServiceClient(conn)
level := server.ParseLogLevel(args[0]) level := parseLogLevel(args[0])
if level == proto.LogLevel_UNKNOWN { if level == proto.LogLevel_UNKNOWN {
return fmt.Errorf("unknown log level: %s. Available levels are: panic, fatal, error, warn, info, debug, trace\n", args[0]) return fmt.Errorf("unknown log level: %s. Available levels are: panic, fatal, error, warn, info, debug, trace\n", args[0])
} }
@@ -115,60 +102,54 @@ func setLogLevel(cmd *cobra.Command, args []string) error {
return nil return nil
} }
func parseLogLevel(level string) proto.LogLevel {
switch strings.ToLower(level) {
case "panic":
return proto.LogLevel_PANIC
case "fatal":
return proto.LogLevel_FATAL
case "error":
return proto.LogLevel_ERROR
case "warn":
return proto.LogLevel_WARN
case "info":
return proto.LogLevel_INFO
case "debug":
return proto.LogLevel_DEBUG
case "trace":
return proto.LogLevel_TRACE
default:
return proto.LogLevel_UNKNOWN
}
}
func runForDuration(cmd *cobra.Command, args []string) error { func runForDuration(cmd *cobra.Command, args []string) error {
duration, err := time.ParseDuration(args[0]) duration, err := time.ParseDuration(args[0])
if err != nil { if err != nil {
return fmt.Errorf("invalid duration format: %v", err) return fmt.Errorf("invalid duration format: %v", err)
} }
conn, err := getClient(cmd) conn, err := getClient(cmd.Context())
if err != nil { if err != nil {
return err return err
} }
defer func() { defer conn.Close()
if err := conn.Close(); err != nil {
log.Errorf(errCloseConnection, err)
}
}()
client := proto.NewDaemonServiceClient(conn) client := proto.NewDaemonServiceClient(conn)
stat, err := client.Status(cmd.Context(), &proto.StatusRequest{})
if err != nil {
return fmt.Errorf("failed to get status: %v", status.Convert(err).Message())
}
stateWasDown := stat.Status != string(internal.StatusConnected) && stat.Status != string(internal.StatusConnecting)
initialLogLevel, err := client.GetLogLevel(cmd.Context(), &proto.GetLogLevelRequest{})
if err != nil {
return fmt.Errorf("failed to get log level: %v", status.Convert(err).Message())
}
if stateWasDown {
if _, err := client.Up(cmd.Context(), &proto.UpRequest{}); err != nil {
return fmt.Errorf("failed to up: %v", status.Convert(err).Message())
}
cmd.Println("Netbird up")
time.Sleep(time.Second * 10)
}
initialLevelTrace := initialLogLevel.GetLevel() >= proto.LogLevel_TRACE
if !initialLevelTrace {
_, err = client.SetLogLevel(cmd.Context(), &proto.SetLogLevelRequest{
Level: proto.LogLevel_TRACE,
})
if err != nil {
return fmt.Errorf("failed to set log level to TRACE: %v", status.Convert(err).Message())
}
cmd.Println("Log level set to trace.")
}
if _, err := client.Down(cmd.Context(), &proto.DownRequest{}); err != nil { if _, err := client.Down(cmd.Context(), &proto.DownRequest{}); err != nil {
return fmt.Errorf("failed to down: %v", status.Convert(err).Message()) return fmt.Errorf("failed to down: %v", status.Convert(err).Message())
} }
cmd.Println("Netbird down") cmd.Println("Netbird down")
_, err = client.SetLogLevel(cmd.Context(), &proto.SetLogLevelRequest{
Level: proto.LogLevel_TRACE,
})
if err != nil {
return fmt.Errorf("failed to set log level to trace: %v", status.Convert(err).Message())
}
cmd.Println("Log level set to trace.")
time.Sleep(1 * time.Second) time.Sleep(1 * time.Second)
if _, err := client.Up(cmd.Context(), &proto.UpRequest{}); err != nil { if _, err := client.Up(cmd.Context(), &proto.UpRequest{}); err != nil {
@@ -186,34 +167,28 @@ func runForDuration(cmd *cobra.Command, args []string) error {
} }
cmd.Println("\nDuration completed") cmd.Println("\nDuration completed")
cmd.Println("Creating debug bundle...")
headerPreDown := fmt.Sprintf("----- Netbird pre-down - Timestamp: %s - Duration: %s", time.Now().Format(time.RFC3339), duration) headerPreDown := fmt.Sprintf("----- Netbird pre-down - Timestamp: %s - Duration: %s", time.Now().Format(time.RFC3339), duration)
statusOutput = fmt.Sprintf("%s\n%s\n%s", statusOutput, headerPreDown, getStatusOutput(cmd)) statusOutput = fmt.Sprintf("%s\n%s\n%s", statusOutput, headerPreDown, getStatusOutput(cmd))
if _, err := client.Down(cmd.Context(), &proto.DownRequest{}); err != nil {
return fmt.Errorf("failed to down: %v", status.Convert(err).Message())
}
cmd.Println("Netbird down")
// TODO reset log level
time.Sleep(1 * time.Second)
cmd.Println("Creating debug bundle...")
resp, err := client.DebugBundle(cmd.Context(), &proto.DebugBundleRequest{ resp, err := client.DebugBundle(cmd.Context(), &proto.DebugBundleRequest{
Anonymize: anonymizeFlag, Anonymize: anonymizeFlag,
Status: statusOutput, Status: statusOutput,
SystemInfo: debugSystemInfoFlag,
}) })
if err != nil { if err != nil {
return fmt.Errorf("failed to bundle debug: %v", status.Convert(err).Message()) return fmt.Errorf("failed to bundle debug: %v", status.Convert(err).Message())
} }
if stateWasDown {
if _, err := client.Down(cmd.Context(), &proto.DownRequest{}); err != nil {
return fmt.Errorf("failed to down: %v", status.Convert(err).Message())
}
cmd.Println("Netbird down")
}
if !initialLevelTrace {
if _, err := client.SetLogLevel(cmd.Context(), &proto.SetLogLevelRequest{Level: initialLogLevel.GetLevel()}); err != nil {
return fmt.Errorf("failed to restore log level: %v", status.Convert(err).Message())
}
cmd.Println("Log level restored to", initialLogLevel.GetLevel())
}
cmd.Println(resp.GetPath()) cmd.Println(resp.GetPath())
return nil return nil

View File

@@ -2,9 +2,8 @@ package cmd
import ( import (
"context" "context"
"time"
"github.com/netbirdio/netbird/util" "github.com/netbirdio/netbird/util"
"time"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/spf13/cobra" "github.com/spf13/cobra"
@@ -26,7 +25,7 @@ var downCmd = &cobra.Command{
return err return err
} }
ctx, cancel := context.WithTimeout(context.Background(), time.Second*7) ctx, cancel := context.WithTimeout(context.Background(), time.Second*3)
defer cancel() defer cancel()
conn, err := DialClientGRPCServer(ctx, daemonAddr) conn, err := DialClientGRPCServer(ctx, daemonAddr)
@@ -42,8 +41,6 @@ var downCmd = &cobra.Command{
log.Errorf("call service down method: %v", err) log.Errorf("call service down method: %v", err)
return err return err
} }
cmd.Println("Disconnected")
return nil return nil
}, },
} }

View File

@@ -39,11 +39,6 @@ var loginCmd = &cobra.Command{
ctx = context.WithValue(ctx, system.DeviceNameCtxKey, hostName) ctx = context.WithValue(ctx, system.DeviceNameCtxKey, hostName)
} }
providedSetupKey, err := getSetupKey()
if err != nil {
return err
}
// workaround to run without service // workaround to run without service
if logFile == "console" { if logFile == "console" {
err = handleRebrand(cmd) err = handleRebrand(cmd)
@@ -67,7 +62,7 @@ var loginCmd = &cobra.Command{
config, _ = internal.UpdateOldManagementURL(ctx, config, configPath) config, _ = internal.UpdateOldManagementURL(ctx, config, configPath)
err = foregroundLogin(ctx, cmd, config, providedSetupKey) err = foregroundLogin(ctx, cmd, config, setupKey)
if err != nil { if err != nil {
return fmt.Errorf("foreground login failed: %v", err) return fmt.Errorf("foreground login failed: %v", err)
} }
@@ -86,7 +81,7 @@ var loginCmd = &cobra.Command{
client := proto.NewDaemonServiceClient(conn) client := proto.NewDaemonServiceClient(conn)
loginRequest := proto.LoginRequest{ loginRequest := proto.LoginRequest{
SetupKey: providedSetupKey, SetupKey: setupKey,
ManagementUrl: managementURL, ManagementUrl: managementURL,
IsLinuxDesktopClient: isLinuxRunningDesktop(), IsLinuxDesktopClient: isLinuxRunningDesktop(),
Hostname: hostName, Hostname: hostName,

View File

@@ -32,12 +32,9 @@ const (
preSharedKeyFlag = "preshared-key" preSharedKeyFlag = "preshared-key"
interfaceNameFlag = "interface-name" interfaceNameFlag = "interface-name"
wireguardPortFlag = "wireguard-port" wireguardPortFlag = "wireguard-port"
networkMonitorFlag = "network-monitor"
disableAutoConnectFlag = "disable-auto-connect" disableAutoConnectFlag = "disable-auto-connect"
serverSSHAllowedFlag = "allow-server-ssh" serverSSHAllowedFlag = "allow-server-ssh"
extraIFaceBlackListFlag = "extra-iface-blacklist" extraIFaceBlackListFlag = "extra-iface-blacklist"
dnsRouteIntervalFlag = "dns-router-interval"
systemInfoFlag = "system-info"
) )
var ( var (
@@ -56,7 +53,6 @@ var (
managementURL string managementURL string
adminURL string adminURL string
setupKey string setupKey string
setupKeyPath string
hostName string hostName string
preSharedKey string preSharedKey string
natExternalIPs []string natExternalIPs []string
@@ -66,15 +62,11 @@ var (
serverSSHAllowed bool serverSSHAllowed bool
interfaceName string interfaceName string
wireguardPort uint16 wireguardPort uint16
networkMonitor bool
serviceName string serviceName string
autoConnectDisabled bool autoConnectDisabled bool
extraIFaceBlackList []string extraIFaceBlackList []string
anonymizeFlag bool anonymizeFlag bool
debugSystemInfoFlag bool rootCmd = &cobra.Command{
dnsRouteInterval time.Duration
rootCmd = &cobra.Command{
Use: "netbird", Use: "netbird",
Short: "", Short: "",
Long: "", Long: "",
@@ -94,15 +86,12 @@ func init() {
oldDefaultConfigPathDir = "/etc/wiretrustee/" oldDefaultConfigPathDir = "/etc/wiretrustee/"
oldDefaultLogFileDir = "/var/log/wiretrustee/" oldDefaultLogFileDir = "/var/log/wiretrustee/"
switch runtime.GOOS { if runtime.GOOS == "windows" {
case "windows":
defaultConfigPathDir = os.Getenv("PROGRAMDATA") + "\\Netbird\\" defaultConfigPathDir = os.Getenv("PROGRAMDATA") + "\\Netbird\\"
defaultLogFileDir = os.Getenv("PROGRAMDATA") + "\\Netbird\\" defaultLogFileDir = os.Getenv("PROGRAMDATA") + "\\Netbird\\"
oldDefaultConfigPathDir = os.Getenv("PROGRAMDATA") + "\\Wiretrustee\\" oldDefaultConfigPathDir = os.Getenv("PROGRAMDATA") + "\\Wiretrustee\\"
oldDefaultLogFileDir = os.Getenv("PROGRAMDATA") + "\\Wiretrustee\\" oldDefaultLogFileDir = os.Getenv("PROGRAMDATA") + "\\Wiretrustee\\"
case "freebsd":
defaultConfigPathDir = "/var/db/netbird/"
} }
defaultConfigPath = defaultConfigPathDir + "config.json" defaultConfigPath = defaultConfigPathDir + "config.json"
@@ -127,10 +116,8 @@ func init() {
rootCmd.PersistentFlags().StringVarP(&serviceName, "service", "s", defaultServiceName, "Netbird system service name") rootCmd.PersistentFlags().StringVarP(&serviceName, "service", "s", defaultServiceName, "Netbird system service name")
rootCmd.PersistentFlags().StringVarP(&configPath, "config", "c", defaultConfigPath, "Netbird config file location") rootCmd.PersistentFlags().StringVarP(&configPath, "config", "c", defaultConfigPath, "Netbird config file location")
rootCmd.PersistentFlags().StringVarP(&logLevel, "log-level", "l", "info", "sets Netbird log level") rootCmd.PersistentFlags().StringVarP(&logLevel, "log-level", "l", "info", "sets Netbird log level")
rootCmd.PersistentFlags().StringVar(&logFile, "log-file", defaultLogFile, "sets Netbird log path. If console is specified the log will be output to stdout. If syslog is specified the log will be sent to syslog daemon.") rootCmd.PersistentFlags().StringVar(&logFile, "log-file", defaultLogFile, "sets Netbird log path. If console is specified the log will be output to stdout")
rootCmd.PersistentFlags().StringVarP(&setupKey, "setup-key", "k", "", "Setup key obtained from the Management Service Dashboard (used to register peer)") rootCmd.PersistentFlags().StringVarP(&setupKey, "setup-key", "k", "", "Setup key obtained from the Management Service Dashboard (used to register peer)")
rootCmd.PersistentFlags().StringVar(&setupKeyPath, "setup-key-file", "", "The path to a setup key obtained from the Management Service Dashboard (used to register peer) This is ignored if the setup-key flag is provided.")
rootCmd.MarkFlagsMutuallyExclusive("setup-key", "setup-key-file")
rootCmd.PersistentFlags().StringVar(&preSharedKey, preSharedKeyFlag, "", "Sets Wireguard PreSharedKey property. If set, then only peers that have the same key can communicate.") rootCmd.PersistentFlags().StringVar(&preSharedKey, preSharedKeyFlag, "", "Sets Wireguard PreSharedKey property. If set, then only peers that have the same key can communicate.")
rootCmd.PersistentFlags().StringVarP(&hostName, "hostname", "n", "", "Sets a custom hostname for the device") rootCmd.PersistentFlags().StringVarP(&hostName, "hostname", "n", "", "Sets a custom hostname for the device")
rootCmd.PersistentFlags().BoolVarP(&anonymizeFlag, "anonymize", "A", false, "anonymize IP addresses and non-netbird.io domains in logs and status output") rootCmd.PersistentFlags().BoolVarP(&anonymizeFlag, "anonymize", "A", false, "anonymize IP addresses and non-netbird.io domains in logs and status output")
@@ -173,8 +160,6 @@ func init() {
upCmd.PersistentFlags().BoolVar(&rosenpassPermissive, rosenpassPermissiveFlag, false, "[Experimental] Enable Rosenpass in permissive mode to allow this peer to accept WireGuard connections without requiring Rosenpass functionality from peers that do not have Rosenpass enabled.") upCmd.PersistentFlags().BoolVar(&rosenpassPermissive, rosenpassPermissiveFlag, false, "[Experimental] Enable Rosenpass in permissive mode to allow this peer to accept WireGuard connections without requiring Rosenpass functionality from peers that do not have Rosenpass enabled.")
upCmd.PersistentFlags().BoolVar(&serverSSHAllowed, serverSSHAllowedFlag, false, "Allow SSH server on peer. If enabled, the SSH server will be permitted") upCmd.PersistentFlags().BoolVar(&serverSSHAllowed, serverSSHAllowedFlag, false, "Allow SSH server on peer. If enabled, the SSH server will be permitted")
upCmd.PersistentFlags().BoolVar(&autoConnectDisabled, disableAutoConnectFlag, false, "Disables auto-connect feature. If enabled, then the client won't connect automatically when the service starts.") upCmd.PersistentFlags().BoolVar(&autoConnectDisabled, disableAutoConnectFlag, false, "Disables auto-connect feature. If enabled, then the client won't connect automatically when the service starts.")
debugCmd.PersistentFlags().BoolVarP(&debugSystemInfoFlag, systemInfoFlag, "S", false, "Adds system information to the debug bundle")
} }
// SetupCloseHandler handles SIGTERM signal and exits with success // SetupCloseHandler handles SIGTERM signal and exits with success
@@ -256,21 +241,6 @@ var CLIBackOffSettings = &backoff.ExponentialBackOff{
Clock: backoff.SystemClock, Clock: backoff.SystemClock,
} }
func getSetupKey() (string, error) {
if setupKeyPath != "" && setupKey == "" {
return getSetupKeyFromFile(setupKeyPath)
}
return setupKey, nil
}
func getSetupKeyFromFile(setupKeyPath string) (string, error) {
data, err := os.ReadFile(setupKeyPath)
if err != nil {
return "", fmt.Errorf("failed to read setup key file: %v", err)
}
return strings.TrimSpace(string(data)), nil
}
func handleRebrand(cmd *cobra.Command) error { func handleRebrand(cmd *cobra.Command) error {
var err error var err error
if logFile == defaultLogFile { if logFile == defaultLogFile {
@@ -381,11 +351,8 @@ func migrateToNetbird(oldPath, newPath string) bool {
return true return true
} }
func getClient(cmd *cobra.Command) (*grpc.ClientConn, error) { func getClient(ctx context.Context) (*grpc.ClientConn, error) {
SetFlagsFromEnvVars(rootCmd) conn, err := DialClientGRPCServer(ctx, daemonAddr)
cmd.SetOut(cmd.OutOrStdout())
conn, err := DialClientGRPCServer(cmd.Context(), daemonAddr)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to connect to daemon error: %v\n"+ return nil, fmt.Errorf("failed to connect to daemon error: %v\n"+
"If the daemon is not running please run: "+ "If the daemon is not running please run: "+

View File

@@ -4,10 +4,6 @@ import (
"fmt" "fmt"
"io" "io"
"testing" "testing"
"github.com/spf13/cobra"
"github.com/netbirdio/netbird/iface"
) )
func TestInitCommands(t *testing.T) { func TestInitCommands(t *testing.T) {
@@ -38,44 +34,3 @@ func TestInitCommands(t *testing.T) {
}) })
} }
} }
func TestSetFlagsFromEnvVars(t *testing.T) {
var cmd = &cobra.Command{
Use: "netbird",
Long: "test",
SilenceUsage: true,
Run: func(cmd *cobra.Command, args []string) {
SetFlagsFromEnvVars(cmd)
},
}
cmd.PersistentFlags().StringSliceVar(&natExternalIPs, externalIPMapFlag, nil,
`comma separated list of external IPs to map to the Wireguard interface`)
cmd.PersistentFlags().StringVar(&interfaceName, interfaceNameFlag, iface.WgInterfaceDefault, "Wireguard interface name")
cmd.PersistentFlags().BoolVar(&rosenpassEnabled, enableRosenpassFlag, false, "Enable Rosenpass feature Rosenpass.")
cmd.PersistentFlags().Uint16Var(&wireguardPort, wireguardPortFlag, iface.DefaultWgPort, "Wireguard interface listening port")
t.Setenv("NB_EXTERNAL_IP_MAP", "abc,dec")
t.Setenv("NB_INTERFACE_NAME", "test-name")
t.Setenv("NB_ENABLE_ROSENPASS", "true")
t.Setenv("NB_WIREGUARD_PORT", "10000")
err := cmd.Execute()
if err != nil {
t.Fatalf("expected no error while running netbird command, got %v", err)
}
if len(natExternalIPs) != 2 {
t.Errorf("expected 2 external ips, got %d", len(natExternalIPs))
}
if natExternalIPs[0] != "abc" || natExternalIPs[1] != "dec" {
t.Errorf("expected abc,dec, got %s,%s", natExternalIPs[0], natExternalIPs[1])
}
if interfaceName != "test-name" {
t.Errorf("expected test-name, got %s", interfaceName)
}
if !rosenpassEnabled {
t.Errorf("expected rosenpassEnabled to be true, got false")
}
if wireguardPort != 10000 {
t.Errorf("expected wireguardPort to be 10000, got %d", wireguardPort)
}
}

View File

@@ -2,7 +2,6 @@ package cmd
import ( import (
"fmt" "fmt"
"strings"
"github.com/spf13/cobra" "github.com/spf13/cobra"
"google.golang.org/grpc/status" "google.golang.org/grpc/status"
@@ -50,7 +49,7 @@ func init() {
} }
func routesList(cmd *cobra.Command, _ []string) error { func routesList(cmd *cobra.Command, _ []string) error {
conn, err := getClient(cmd) conn, err := getClient(cmd.Context())
if err != nil { if err != nil {
return err return err
} }
@@ -67,62 +66,20 @@ func routesList(cmd *cobra.Command, _ []string) error {
return nil return nil
} }
printRoutes(cmd, resp) cmd.Println("Available Routes:")
for _, route := range resp.Routes {
selectedStatus := "Not Selected"
if route.GetSelected() {
selectedStatus = "Selected"
}
cmd.Printf("\n - ID: %s\n Network: %s\n Status: %s\n", route.GetID(), route.GetNetwork(), selectedStatus)
}
return nil return nil
} }
func printRoutes(cmd *cobra.Command, resp *proto.ListRoutesResponse) {
cmd.Println("Available Routes:")
for _, route := range resp.Routes {
printRoute(cmd, route)
}
}
func printRoute(cmd *cobra.Command, route *proto.Route) {
selectedStatus := getSelectedStatus(route)
domains := route.GetDomains()
if len(domains) > 0 {
printDomainRoute(cmd, route, domains, selectedStatus)
} else {
printNetworkRoute(cmd, route, selectedStatus)
}
}
func getSelectedStatus(route *proto.Route) string {
if route.GetSelected() {
return "Selected"
}
return "Not Selected"
}
func printDomainRoute(cmd *cobra.Command, route *proto.Route, domains []string, selectedStatus string) {
cmd.Printf("\n - ID: %s\n Domains: %s\n Status: %s\n", route.GetID(), strings.Join(domains, ", "), selectedStatus)
resolvedIPs := route.GetResolvedIPs()
if len(resolvedIPs) > 0 {
printResolvedIPs(cmd, domains, resolvedIPs)
} else {
cmd.Printf(" Resolved IPs: -\n")
}
}
func printNetworkRoute(cmd *cobra.Command, route *proto.Route, selectedStatus string) {
cmd.Printf("\n - ID: %s\n Network: %s\n Status: %s\n", route.GetID(), route.GetNetwork(), selectedStatus)
}
func printResolvedIPs(cmd *cobra.Command, domains []string, resolvedIPs map[string]*proto.IPList) {
cmd.Printf(" Resolved IPs:\n")
for _, domain := range domains {
if ipList, exists := resolvedIPs[domain]; exists {
cmd.Printf(" [%s]: %s\n", domain, strings.Join(ipList.GetIps(), ", "))
}
}
}
func routesSelect(cmd *cobra.Command, args []string) error { func routesSelect(cmd *cobra.Command, args []string) error {
conn, err := getClient(cmd) conn, err := getClient(cmd.Context())
if err != nil { if err != nil {
return err return err
} }
@@ -149,7 +106,7 @@ func routesSelect(cmd *cobra.Command, args []string) error {
} }
func routesDeselect(cmd *cobra.Command, args []string) error { func routesDeselect(cmd *cobra.Command, args []string) error {
conn, err := getClient(cmd) conn, err := getClient(cmd.Context())
if err != nil { if err != nil {
return err return err
} }

View File

@@ -2,21 +2,18 @@ package cmd
import ( import (
"context" "context"
"github.com/kardianos/service" "github.com/kardianos/service"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/spf13/cobra" "github.com/spf13/cobra"
"google.golang.org/grpc" "google.golang.org/grpc"
"github.com/netbirdio/netbird/client/internal" "github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/server"
) )
type program struct { type program struct {
ctx context.Context ctx context.Context
cancel context.CancelFunc cancel context.CancelFunc
serv *grpc.Server serv *grpc.Server
serverInstance *server.Server
} }
func newProgram(ctx context.Context, cancel context.CancelFunc) *program { func newProgram(ctx context.Context, cancel context.CancelFunc) *program {

View File

@@ -61,8 +61,6 @@ func (p *program) Start(svc service.Service) error {
} }
proto.RegisterDaemonServiceServer(p.serv, serverInstance) proto.RegisterDaemonServiceServer(p.serv, serverInstance)
p.serverInstance = serverInstance
log.Printf("started daemon server: %v", split[1]) log.Printf("started daemon server: %v", split[1])
if err := p.serv.Serve(listen); err != nil { if err := p.serv.Serve(listen); err != nil {
log.Errorf("failed to serve daemon requests: %v", err) log.Errorf("failed to serve daemon requests: %v", err)
@@ -72,14 +70,6 @@ func (p *program) Start(svc service.Service) error {
} }
func (p *program) Stop(srv service.Service) error { func (p *program) Stop(srv service.Service) error {
if p.serverInstance != nil {
in := new(proto.DownRequest)
_, err := p.serverInstance.Down(p.ctx, in)
if err != nil {
log.Errorf("failed to stop daemon: %v", err)
}
}
p.cancel() p.cancel()
if p.serv != nil { if p.serv != nil {

View File

@@ -31,8 +31,6 @@ var installCmd = &cobra.Command{
configPath, configPath,
"--log-level", "--log-level",
logLevel, logLevel,
"--daemon-addr",
daemonAddr,
} }
if managementURL != "" { if managementURL != "" {

View File

@@ -31,9 +31,9 @@ type peerStateDetailOutput struct {
Status string `json:"status" yaml:"status"` Status string `json:"status" yaml:"status"`
LastStatusUpdate time.Time `json:"lastStatusUpdate" yaml:"lastStatusUpdate"` LastStatusUpdate time.Time `json:"lastStatusUpdate" yaml:"lastStatusUpdate"`
ConnType string `json:"connectionType" yaml:"connectionType"` ConnType string `json:"connectionType" yaml:"connectionType"`
Direct bool `json:"direct" yaml:"direct"`
IceCandidateType iceCandidateType `json:"iceCandidateType" yaml:"iceCandidateType"` IceCandidateType iceCandidateType `json:"iceCandidateType" yaml:"iceCandidateType"`
IceCandidateEndpoint iceCandidateType `json:"iceCandidateEndpoint" yaml:"iceCandidateEndpoint"` IceCandidateEndpoint iceCandidateType `json:"iceCandidateEndpoint" yaml:"iceCandidateEndpoint"`
RelayAddress string `json:"relayAddress" yaml:"relayAddress"`
LastWireguardHandshake time.Time `json:"lastWireguardHandshake" yaml:"lastWireguardHandshake"` LastWireguardHandshake time.Time `json:"lastWireguardHandshake" yaml:"lastWireguardHandshake"`
TransferReceived int64 `json:"transferReceived" yaml:"transferReceived"` TransferReceived int64 `json:"transferReceived" yaml:"transferReceived"`
TransferSent int64 `json:"transferSent" yaml:"transferSent"` TransferSent int64 `json:"transferSent" yaml:"transferSent"`
@@ -335,18 +335,16 @@ func mapNSGroups(servers []*proto.NSGroupState) []nsServerGroupStateOutput {
func mapPeers(peers []*proto.PeerState) peersStateOutput { func mapPeers(peers []*proto.PeerState) peersStateOutput {
var peersStateDetail []peerStateDetailOutput var peersStateDetail []peerStateDetailOutput
localICE := ""
remoteICE := ""
localICEEndpoint := ""
remoteICEEndpoint := ""
connType := ""
peersConnected := 0 peersConnected := 0
lastHandshake := time.Time{}
transferReceived := int64(0)
transferSent := int64(0)
for _, pbPeerState := range peers { for _, pbPeerState := range peers {
localICE := ""
remoteICE := ""
localICEEndpoint := ""
remoteICEEndpoint := ""
relayServerAddress := ""
connType := ""
lastHandshake := time.Time{}
transferReceived := int64(0)
transferSent := int64(0)
isPeerConnected := pbPeerState.ConnStatus == peer.StatusConnected.String() isPeerConnected := pbPeerState.ConnStatus == peer.StatusConnected.String()
if skipDetailByFilters(pbPeerState, isPeerConnected) { if skipDetailByFilters(pbPeerState, isPeerConnected) {
continue continue
@@ -362,7 +360,6 @@ func mapPeers(peers []*proto.PeerState) peersStateOutput {
if pbPeerState.Relayed { if pbPeerState.Relayed {
connType = "Relayed" connType = "Relayed"
} }
relayServerAddress = pbPeerState.GetRelayAddress()
lastHandshake = pbPeerState.GetLastWireguardHandshake().AsTime().Local() lastHandshake = pbPeerState.GetLastWireguardHandshake().AsTime().Local()
transferReceived = pbPeerState.GetBytesRx() transferReceived = pbPeerState.GetBytesRx()
transferSent = pbPeerState.GetBytesTx() transferSent = pbPeerState.GetBytesTx()
@@ -375,6 +372,7 @@ func mapPeers(peers []*proto.PeerState) peersStateOutput {
Status: pbPeerState.GetConnStatus(), Status: pbPeerState.GetConnStatus(),
LastStatusUpdate: timeLocal, LastStatusUpdate: timeLocal,
ConnType: connType, ConnType: connType,
Direct: pbPeerState.GetDirect(),
IceCandidateType: iceCandidateType{ IceCandidateType: iceCandidateType{
Local: localICE, Local: localICE,
Remote: remoteICE, Remote: remoteICE,
@@ -383,7 +381,6 @@ func mapPeers(peers []*proto.PeerState) peersStateOutput {
Local: localICEEndpoint, Local: localICEEndpoint,
Remote: remoteICEEndpoint, Remote: remoteICEEndpoint,
}, },
RelayAddress: relayServerAddress,
FQDN: pbPeerState.GetFqdn(), FQDN: pbPeerState.GetFqdn(),
LastWireguardHandshake: lastHandshake, LastWireguardHandshake: lastHandshake,
TransferReceived: transferReceived, TransferReceived: transferReceived,
@@ -644,9 +641,9 @@ func parsePeers(peers peersStateOutput, rosenpassEnabled, rosenpassPermissive bo
" Status: %s\n"+ " Status: %s\n"+
" -- detail --\n"+ " -- detail --\n"+
" Connection type: %s\n"+ " Connection type: %s\n"+
" Direct: %t\n"+
" ICE candidate (Local/Remote): %s/%s\n"+ " ICE candidate (Local/Remote): %s/%s\n"+
" ICE candidate endpoints (Local/Remote): %s/%s\n"+ " ICE candidate endpoints (Local/Remote): %s/%s\n"+
" Relay server address: %s\n"+
" Last connection update: %s\n"+ " Last connection update: %s\n"+
" Last WireGuard handshake: %s\n"+ " Last WireGuard handshake: %s\n"+
" Transfer status (received/sent) %s/%s\n"+ " Transfer status (received/sent) %s/%s\n"+
@@ -658,11 +655,11 @@ func parsePeers(peers peersStateOutput, rosenpassEnabled, rosenpassPermissive bo
peerState.PubKey, peerState.PubKey,
peerState.Status, peerState.Status,
peerState.ConnType, peerState.ConnType,
peerState.Direct,
localICE, localICE,
remoteICE, remoteICE,
localICEEndpoint, localICEEndpoint,
remoteICEEndpoint, remoteICEEndpoint,
peerState.RelayAddress,
timeAgo(peerState.LastStatusUpdate), timeAgo(peerState.LastStatusUpdate),
timeAgo(peerState.LastWireguardHandshake), timeAgo(peerState.LastWireguardHandshake),
toIEC(peerState.TransferReceived), toIEC(peerState.TransferReceived),
@@ -810,7 +807,11 @@ func anonymizePeerDetail(a *anonymize.Anonymizer, peer *peerStateDetailOutput) {
} }
for i, route := range peer.Routes { for i, route := range peer.Routes {
peer.Routes[i] = a.AnonymizeRoute(route) prefix, err := netip.ParsePrefix(route)
if err == nil {
ip := a.AnonymizeIPString(prefix.Addr().String())
peer.Routes[i] = fmt.Sprintf("%s/%d", ip, prefix.Bits())
}
} }
} }
@@ -846,7 +847,11 @@ func anonymizeOverview(a *anonymize.Anonymizer, overview *statusOutputOverview)
} }
for i, route := range overview.Routes { for i, route := range overview.Routes {
overview.Routes[i] = a.AnonymizeRoute(route) prefix, err := netip.ParsePrefix(route)
if err == nil {
ip := a.AnonymizeIPString(prefix.Addr().String())
overview.Routes[i] = fmt.Sprintf("%s/%d", ip, prefix.Bits())
}
} }
overview.FQDN = a.AnonymizeDomain(overview.FQDN) overview.FQDN = a.AnonymizeDomain(overview.FQDN)

View File

@@ -37,6 +37,7 @@ var resp = &proto.StatusResponse{
ConnStatus: "Connected", ConnStatus: "Connected",
ConnStatusUpdate: timestamppb.New(time.Date(2001, time.Month(1), 1, 1, 1, 1, 0, time.UTC)), ConnStatusUpdate: timestamppb.New(time.Date(2001, time.Month(1), 1, 1, 1, 1, 0, time.UTC)),
Relayed: false, Relayed: false,
Direct: true,
LocalIceCandidateType: "", LocalIceCandidateType: "",
RemoteIceCandidateType: "", RemoteIceCandidateType: "",
LocalIceCandidateEndpoint: "", LocalIceCandidateEndpoint: "",
@@ -56,6 +57,7 @@ var resp = &proto.StatusResponse{
ConnStatus: "Connected", ConnStatus: "Connected",
ConnStatusUpdate: timestamppb.New(time.Date(2002, time.Month(2), 2, 2, 2, 2, 0, time.UTC)), ConnStatusUpdate: timestamppb.New(time.Date(2002, time.Month(2), 2, 2, 2, 2, 0, time.UTC)),
Relayed: true, Relayed: true,
Direct: false,
LocalIceCandidateType: "relay", LocalIceCandidateType: "relay",
RemoteIceCandidateType: "prflx", RemoteIceCandidateType: "prflx",
LocalIceCandidateEndpoint: "10.0.0.1:10001", LocalIceCandidateEndpoint: "10.0.0.1:10001",
@@ -135,6 +137,7 @@ var overview = statusOutputOverview{
Status: "Connected", Status: "Connected",
LastStatusUpdate: time.Date(2001, 1, 1, 1, 1, 1, 0, time.UTC), LastStatusUpdate: time.Date(2001, 1, 1, 1, 1, 1, 0, time.UTC),
ConnType: "P2P", ConnType: "P2P",
Direct: true,
IceCandidateType: iceCandidateType{ IceCandidateType: iceCandidateType{
Local: "", Local: "",
Remote: "", Remote: "",
@@ -158,6 +161,7 @@ var overview = statusOutputOverview{
Status: "Connected", Status: "Connected",
LastStatusUpdate: time.Date(2002, 2, 2, 2, 2, 2, 0, time.UTC), LastStatusUpdate: time.Date(2002, 2, 2, 2, 2, 2, 0, time.UTC),
ConnType: "Relayed", ConnType: "Relayed",
Direct: false,
IceCandidateType: iceCandidateType{ IceCandidateType: iceCandidateType{
Local: "relay", Local: "relay",
Remote: "prflx", Remote: "prflx",
@@ -279,6 +283,7 @@ func TestParsingToJSON(t *testing.T) {
"status": "Connected", "status": "Connected",
"lastStatusUpdate": "2001-01-01T01:01:01Z", "lastStatusUpdate": "2001-01-01T01:01:01Z",
"connectionType": "P2P", "connectionType": "P2P",
"direct": true,
"iceCandidateType": { "iceCandidateType": {
"local": "", "local": "",
"remote": "" "remote": ""
@@ -287,7 +292,6 @@ func TestParsingToJSON(t *testing.T) {
"local": "", "local": "",
"remote": "" "remote": ""
}, },
"relayAddress": "",
"lastWireguardHandshake": "2001-01-01T01:01:02Z", "lastWireguardHandshake": "2001-01-01T01:01:02Z",
"transferReceived": 200, "transferReceived": 200,
"transferSent": 100, "transferSent": 100,
@@ -304,6 +308,7 @@ func TestParsingToJSON(t *testing.T) {
"status": "Connected", "status": "Connected",
"lastStatusUpdate": "2002-02-02T02:02:02Z", "lastStatusUpdate": "2002-02-02T02:02:02Z",
"connectionType": "Relayed", "connectionType": "Relayed",
"direct": false,
"iceCandidateType": { "iceCandidateType": {
"local": "relay", "local": "relay",
"remote": "prflx" "remote": "prflx"
@@ -312,7 +317,6 @@ func TestParsingToJSON(t *testing.T) {
"local": "10.0.0.1:10001", "local": "10.0.0.1:10001",
"remote": "10.0.10.1:10002" "remote": "10.0.10.1:10002"
}, },
"relayAddress": "",
"lastWireguardHandshake": "2002-02-02T02:02:03Z", "lastWireguardHandshake": "2002-02-02T02:02:03Z",
"transferReceived": 2000, "transferReceived": 2000,
"transferSent": 1000, "transferSent": 1000,
@@ -404,13 +408,13 @@ func TestParsingToYAML(t *testing.T) {
status: Connected status: Connected
lastStatusUpdate: 2001-01-01T01:01:01Z lastStatusUpdate: 2001-01-01T01:01:01Z
connectionType: P2P connectionType: P2P
direct: true
iceCandidateType: iceCandidateType:
local: "" local: ""
remote: "" remote: ""
iceCandidateEndpoint: iceCandidateEndpoint:
local: "" local: ""
remote: "" remote: ""
relayAddress: ""
lastWireguardHandshake: 2001-01-01T01:01:02Z lastWireguardHandshake: 2001-01-01T01:01:02Z
transferReceived: 200 transferReceived: 200
transferSent: 100 transferSent: 100
@@ -424,13 +428,13 @@ func TestParsingToYAML(t *testing.T) {
status: Connected status: Connected
lastStatusUpdate: 2002-02-02T02:02:02Z lastStatusUpdate: 2002-02-02T02:02:02Z
connectionType: Relayed connectionType: Relayed
direct: false
iceCandidateType: iceCandidateType:
local: relay local: relay
remote: prflx remote: prflx
iceCandidateEndpoint: iceCandidateEndpoint:
local: 10.0.0.1:10001 local: 10.0.0.1:10001
remote: 10.0.10.1:10002 remote: 10.0.10.1:10002
relayAddress: ""
lastWireguardHandshake: 2002-02-02T02:02:03Z lastWireguardHandshake: 2002-02-02T02:02:03Z
transferReceived: 2000 transferReceived: 2000
transferSent: 1000 transferSent: 1000
@@ -501,9 +505,9 @@ func TestParsingToDetail(t *testing.T) {
Status: Connected Status: Connected
-- detail -- -- detail --
Connection type: P2P Connection type: P2P
Direct: true
ICE candidate (Local/Remote): -/- ICE candidate (Local/Remote): -/-
ICE candidate endpoints (Local/Remote): -/- ICE candidate endpoints (Local/Remote): -/-
Relay server address:
Last connection update: %s Last connection update: %s
Last WireGuard handshake: %s Last WireGuard handshake: %s
Transfer status (received/sent) 200 B/100 B Transfer status (received/sent) 200 B/100 B
@@ -517,9 +521,9 @@ func TestParsingToDetail(t *testing.T) {
Status: Connected Status: Connected
-- detail -- -- detail --
Connection type: Relayed Connection type: Relayed
Direct: false
ICE candidate (Local/Remote): relay/prflx ICE candidate (Local/Remote): relay/prflx
ICE candidate endpoints (Local/Remote): 10.0.0.1:10001/10.0.10.1:10002 ICE candidate endpoints (Local/Remote): 10.0.0.1:10001/10.0.10.1:10002
Relay server address:
Last connection update: %s Last connection update: %s
Last WireGuard handshake: %s Last WireGuard handshake: %s
Transfer status (received/sent) 2.0 KiB/1000 B Transfer status (received/sent) 2.0 KiB/1000 B

View File

@@ -7,18 +7,13 @@ import (
"testing" "testing"
"time" "time"
"github.com/stretchr/testify/require"
"go.opentelemetry.io/otel"
"github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/telemetry"
"github.com/netbirdio/netbird/util" "github.com/netbirdio/netbird/util"
"google.golang.org/grpc" "google.golang.org/grpc"
"github.com/netbirdio/management-integrations/integrations" "github.com/netbirdio/management-integrations/integrations"
clientProto "github.com/netbirdio/netbird/client/proto" clientProto "github.com/netbirdio/netbird/client/proto"
client "github.com/netbirdio/netbird/client/server" client "github.com/netbirdio/netbird/client/server"
mgmtProto "github.com/netbirdio/netbird/management/proto" mgmtProto "github.com/netbirdio/netbird/management/proto"
@@ -57,10 +52,7 @@ func startSignal(t *testing.T) (*grpc.Server, net.Listener) {
t.Fatal(err) t.Fatal(err)
} }
s := grpc.NewServer() s := grpc.NewServer()
srv, err := sig.NewServer(otel.Meter("")) sigProto.RegisterSignalExchangeServer(s, sig.NewServer())
require.NoError(t, err)
sigProto.RegisterSignalExchangeServer(s, srv)
go func() { go func() {
if err := s.Serve(lis); err != nil { if err := s.Serve(lis); err != nil {
panic(err) panic(err)
@@ -72,35 +64,28 @@ func startSignal(t *testing.T) (*grpc.Server, net.Listener) {
func startManagement(t *testing.T, config *mgmt.Config) (*grpc.Server, net.Listener) { func startManagement(t *testing.T, config *mgmt.Config) (*grpc.Server, net.Listener) {
t.Helper() t.Helper()
lis, err := net.Listen("tcp", ":0") lis, err := net.Listen("tcp", ":0")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
s := grpc.NewServer() s := grpc.NewServer()
store, cleanUp, err := mgmt.NewTestStoreFromJson(context.Background(), config.Datadir) store, err := mgmt.NewStoreFromJson(config.Datadir, nil)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
t.Cleanup(cleanUp)
peersUpdateManager := mgmt.NewPeersUpdateManager(nil) peersUpdateManager := mgmt.NewPeersUpdateManager(nil)
eventStore := &activity.InMemoryEventStore{} eventStore := &activity.InMemoryEventStore{}
if err != nil { if err != nil {
return nil, nil return nil, nil
} }
iv, _ := integrations.NewIntegratedValidator(context.Background(), eventStore) iv, _ := integrations.NewIntegratedValidator(eventStore)
accountManager, err := mgmt.BuildManager(store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, iv)
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
require.NoError(t, err)
accountManager, err := mgmt.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, iv, metrics)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
turnManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig)
secretsManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay) mgmtServer, err := mgmt.NewServer(config, accountManager, peersUpdateManager, turnManager, nil, nil)
mgmtServer, err := mgmt.NewServer(context.Background(), config, accountManager, peersUpdateManager, secretsManager, nil, nil)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -115,7 +100,7 @@ func startManagement(t *testing.T, config *mgmt.Config) (*grpc.Server, net.Liste
} }
func startClientDaemon( func startClientDaemon(
t *testing.T, ctx context.Context, _, configPath string, t *testing.T, ctx context.Context, managementURL, configPath string,
) (*grpc.Server, net.Listener) { ) (*grpc.Server, net.Listener) {
t.Helper() t.Helper()
lis, err := net.Listen("tcp", "127.0.0.1:0") lis, err := net.Listen("tcp", "127.0.0.1:0")

View File

@@ -7,13 +7,11 @@ import (
"net/netip" "net/netip"
"runtime" "runtime"
"strings" "strings"
"time"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/spf13/cobra" "github.com/spf13/cobra"
"google.golang.org/grpc/codes" "google.golang.org/grpc/codes"
gstatus "google.golang.org/grpc/status" gstatus "google.golang.org/grpc/status"
"google.golang.org/protobuf/types/known/durationpb"
"github.com/netbirdio/netbird/client/internal" "github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/peer"
@@ -42,12 +40,7 @@ func init() {
upCmd.PersistentFlags().BoolVarP(&foregroundMode, "foreground-mode", "F", false, "start service in foreground") upCmd.PersistentFlags().BoolVarP(&foregroundMode, "foreground-mode", "F", false, "start service in foreground")
upCmd.PersistentFlags().StringVar(&interfaceName, interfaceNameFlag, iface.WgInterfaceDefault, "Wireguard interface name") upCmd.PersistentFlags().StringVar(&interfaceName, interfaceNameFlag, iface.WgInterfaceDefault, "Wireguard interface name")
upCmd.PersistentFlags().Uint16Var(&wireguardPort, wireguardPortFlag, iface.DefaultWgPort, "Wireguard interface listening port") upCmd.PersistentFlags().Uint16Var(&wireguardPort, wireguardPortFlag, iface.DefaultWgPort, "Wireguard interface listening port")
upCmd.PersistentFlags().BoolVarP(&networkMonitor, networkMonitorFlag, "N", networkMonitor,
`Manage network monitoring. Defaults to true on Windows and macOS, false on Linux. `+
`E.g. --network-monitor=false to disable or --network-monitor=true to enable.`,
)
upCmd.PersistentFlags().StringSliceVar(&extraIFaceBlackList, extraIFaceBlackListFlag, nil, "Extra list of default interfaces to ignore for listening") upCmd.PersistentFlags().StringSliceVar(&extraIFaceBlackList, extraIFaceBlackListFlag, nil, "Extra list of default interfaces to ignore for listening")
upCmd.PersistentFlags().DurationVar(&dnsRouteInterval, dnsRouteIntervalFlag, time.Minute, "DNS route update interval")
} }
func upFunc(cmd *cobra.Command, args []string) error { func upFunc(cmd *cobra.Command, args []string) error {
@@ -123,10 +116,6 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command) error {
ic.WireguardPort = &p ic.WireguardPort = &p
} }
if cmd.Flag(networkMonitorFlag).Changed {
ic.NetworkMonitor = &networkMonitor
}
if rootCmd.PersistentFlags().Changed(preSharedKeyFlag) { if rootCmd.PersistentFlags().Changed(preSharedKeyFlag) {
ic.PreSharedKey = &preSharedKey ic.PreSharedKey = &preSharedKey
} }
@@ -143,15 +132,6 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command) error {
} }
} }
if cmd.Flag(dnsRouteIntervalFlag).Changed {
ic.DNSRouteInterval = &dnsRouteInterval
}
providedSetupKey, err := getSetupKey()
if err != nil {
return err
}
config, err := internal.UpdateOrCreateConfig(ic) config, err := internal.UpdateOrCreateConfig(ic)
if err != nil { if err != nil {
return fmt.Errorf("get config file: %v", err) return fmt.Errorf("get config file: %v", err)
@@ -159,7 +139,7 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command) error {
config, _ = internal.UpdateOldManagementURL(ctx, config, configPath) config, _ = internal.UpdateOldManagementURL(ctx, config, configPath)
err = foregroundLogin(ctx, cmd, config, providedSetupKey) err = foregroundLogin(ctx, cmd, config, setupKey)
if err != nil { if err != nil {
return fmt.Errorf("foreground login failed: %v", err) return fmt.Errorf("foreground login failed: %v", err)
} }
@@ -167,12 +147,7 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command) error {
var cancel context.CancelFunc var cancel context.CancelFunc
ctx, cancel = context.WithCancel(ctx) ctx, cancel = context.WithCancel(ctx)
SetupCloseHandler(ctx, cancel) SetupCloseHandler(ctx, cancel)
return internal.RunClient(ctx, config, peer.NewRecorder(config.ManagementURL.String()))
r := peer.NewRecorder(config.ManagementURL.String())
r.GetFullStatus()
connectClient := internal.NewConnectClient(ctx, config, r)
return connectClient.Run()
} }
func runInDaemonMode(ctx context.Context, cmd *cobra.Command) error { func runInDaemonMode(ctx context.Context, cmd *cobra.Command) error {
@@ -207,13 +182,8 @@ func runInDaemonMode(ctx context.Context, cmd *cobra.Command) error {
return nil return nil
} }
providedSetupKey, err := getSetupKey()
if err != nil {
return err
}
loginRequest := proto.LoginRequest{ loginRequest := proto.LoginRequest{
SetupKey: providedSetupKey, SetupKey: setupKey,
ManagementUrl: managementURL, ManagementUrl: managementURL,
AdminURL: adminURL, AdminURL: adminURL,
NatExternalIPs: natExternalIPs, NatExternalIPs: natExternalIPs,
@@ -256,14 +226,6 @@ func runInDaemonMode(ctx context.Context, cmd *cobra.Command) error {
loginRequest.WireguardPort = &wp loginRequest.WireguardPort = &wp
} }
if cmd.Flag(networkMonitorFlag).Changed {
loginRequest.NetworkMonitor = &networkMonitor
}
if cmd.Flag(dnsRouteIntervalFlag).Changed {
loginRequest.DnsRouteInterval = durationpb.New(dnsRouteInterval)
}
var loginErr error var loginErr error
var loginResp *proto.LoginResponse var loginResp *proto.LoginResponse

View File

@@ -2,7 +2,6 @@ package cmd
import ( import (
"context" "context"
"os"
"testing" "testing"
"time" "time"
@@ -41,36 +40,6 @@ func TestUpDaemon(t *testing.T) {
return return
} }
// Test the setup-key-file flag.
tempFile, err := os.CreateTemp("", "setup-key")
if err != nil {
t.Errorf("could not create temp file, got error %v", err)
return
}
defer os.Remove(tempFile.Name())
if _, err := tempFile.Write([]byte("A2C8E62B-38F5-4553-B31E-DD66C696CEBB")); err != nil {
t.Errorf("could not write to temp file, got error %v", err)
return
}
if err := tempFile.Close(); err != nil {
t.Errorf("unable to close file, got error %v", err)
}
rootCmd.SetArgs([]string{
"login",
"--daemon-addr", "tcp://" + cliAddr,
"--setup-key-file", tempFile.Name(),
"--log-file", "",
})
if err := rootCmd.Execute(); err != nil {
t.Errorf("expected no error while running up command, got %v", err)
return
}
time.Sleep(time.Second * 3)
if status, err := state.Status(); err != nil && status != internal.StatusIdle {
t.Errorf("wrong status after login: %s, %v", internal.StatusIdle, err)
return
}
rootCmd.SetArgs([]string{ rootCmd.SetArgs([]string{
"up", "up",
"--daemon-addr", "tcp://" + cliAddr, "--daemon-addr", "tcp://" + cliAddr,

View File

@@ -1,30 +0,0 @@
package errors
import (
"fmt"
"strings"
"github.com/hashicorp/go-multierror"
)
func formatError(es []error) string {
if len(es) == 0 {
return fmt.Sprintf("0 error occurred:\n\t* %s", es[0])
}
points := make([]string, len(es))
for i, err := range es {
points[i] = fmt.Sprintf("* %s", err)
}
return fmt.Sprintf(
"%d errors occurred:\n\t%s",
len(es), strings.Join(points, "\n\t"))
}
func FormatErrorOrNil(err *multierror.Error) error {
if err != nil {
err.ErrorFormat = formatError
}
return err.ErrorOrNil()
}

View File

@@ -42,20 +42,20 @@ func NewFirewall(context context.Context, iface IFaceMapper) (firewall.Manager,
switch check() { switch check() {
case IPTABLES: case IPTABLES:
log.Info("creating an iptables firewall manager") log.Debug("creating an iptables firewall manager")
fm, errFw = nbiptables.Create(context, iface) fm, errFw = nbiptables.Create(context, iface)
if errFw != nil { if errFw != nil {
log.Errorf("failed to create iptables manager: %s", errFw) log.Errorf("failed to create iptables manager: %s", errFw)
} }
case NFTABLES: case NFTABLES:
log.Info("creating an nftables firewall manager") log.Debug("creating an nftables firewall manager")
fm, errFw = nbnftables.Create(context, iface) fm, errFw = nbnftables.Create(context, iface)
if errFw != nil { if errFw != nil {
log.Errorf("failed to create nftables manager: %s", errFw) log.Errorf("failed to create nftables manager: %s", errFw)
} }
default: default:
errFw = fmt.Errorf("no firewall manager found") errFw = fmt.Errorf("no firewall manager found")
log.Info("no firewall manager found, trying to use userspace packet filtering firewall") log.Debug("no firewall manager found, try to use userspace packet filtering firewall")
} }
if iface.IsUserspaceBind() { if iface.IsUserspaceBind() {
@@ -85,58 +85,16 @@ func NewFirewall(context context.Context, iface IFaceMapper) (firewall.Manager,
// check returns the firewall type based on common lib checks. It returns UNKNOWN if no firewall is found. // check returns the firewall type based on common lib checks. It returns UNKNOWN if no firewall is found.
func check() FWType { func check() FWType {
useIPTABLES := false
var iptablesChains []string
ip, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
if err == nil && isIptablesClientAvailable(ip) {
major, minor, _ := ip.GetIptablesVersion()
// use iptables when its version is lower than 1.8.0 which doesn't work well with our nftables manager
if major < 1 || (major == 1 && minor < 8) {
return IPTABLES
}
useIPTABLES = true
iptablesChains, err = ip.ListChains("filter")
if err != nil {
log.Errorf("failed to list iptables chains: %s", err)
useIPTABLES = false
}
}
nf := nftables.Conn{} nf := nftables.Conn{}
if chains, err := nf.ListChains(); err == nil && os.Getenv(SKIP_NFTABLES_ENV) != "true" { if _, err := nf.ListChains(); err == nil && os.Getenv(SKIP_NFTABLES_ENV) != "true" {
if !useIPTABLES { return NFTABLES
return NFTABLES
}
// search for chains where table is filter
// if we find one, we assume that nftables manager can be used with iptables
for _, chain := range chains {
if chain.Table.Name == "filter" {
return NFTABLES
}
}
// check tables for the following constraints:
// 1. there is no chain in nftables for the filter table and there is at least one chain in iptables, we assume that nftables manager can not be used
// 2. there is no tables or more than one table, we assume that nftables manager can be used
// 3. there is only one table and its name is filter, we assume that nftables manager can not be used, since there was no chain in it
// 4. if we find an error we log and continue with iptables check
nbTablesList, err := nf.ListTables()
switch {
case err == nil && len(iptablesChains) > 0:
return IPTABLES
case err == nil && len(nbTablesList) != 1:
return NFTABLES
case err == nil && len(nbTablesList) == 1 && nbTablesList[0].Name == "filter":
return IPTABLES
case err != nil:
log.Errorf("failed to list nftables tables on fw manager discovery: %s", err)
}
} }
if useIPTABLES { ip, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
if err != nil {
return UNKNOWN
}
if isIptablesClientAvailable(ip) {
return IPTABLES return IPTABLES
} }

View File

@@ -74,12 +74,12 @@ func (i *routerManager) InsertRoutingRules(pair firewall.RouterPair) error {
return nil return nil
} }
err = i.addNATRule(firewall.NatFormat, tableNat, chainRTNAT, routingFinalNatJump, pair) err = i.insertRoutingRule(firewall.NatFormat, tableNat, chainRTNAT, routingFinalNatJump, pair)
if err != nil { if err != nil {
return err return err
} }
err = i.addNATRule(firewall.InNatFormat, tableNat, chainRTNAT, routingFinalNatJump, firewall.GetInPair(pair)) err = i.insertRoutingRule(firewall.InNatFormat, tableNat, chainRTNAT, routingFinalNatJump, firewall.GetInPair(pair))
if err != nil { if err != nil {
return err return err
} }
@@ -87,12 +87,12 @@ func (i *routerManager) InsertRoutingRules(pair firewall.RouterPair) error {
return nil return nil
} }
// insertRoutingRule inserts an iptables rule // insertRoutingRule inserts an iptable rule
func (i *routerManager) insertRoutingRule(keyFormat, table, chain, jump string, pair firewall.RouterPair) error { func (i *routerManager) insertRoutingRule(keyFormat, table, chain, jump string, pair firewall.RouterPair) error {
var err error var err error
ruleKey := firewall.GenKey(keyFormat, pair.ID) ruleKey := firewall.GenKey(keyFormat, pair.ID)
rule := genRuleSpec(jump, pair.Source, pair.Destination) rule := genRuleSpec(jump, ruleKey, pair.Source, pair.Destination)
existingRule, found := i.rules[ruleKey] existingRule, found := i.rules[ruleKey]
if found { if found {
err = i.iptablesClient.DeleteIfExists(table, chain, existingRule...) err = i.iptablesClient.DeleteIfExists(table, chain, existingRule...)
@@ -101,7 +101,6 @@ func (i *routerManager) insertRoutingRule(keyFormat, table, chain, jump string,
} }
delete(i.rules, ruleKey) delete(i.rules, ruleKey)
} }
err = i.iptablesClient.Insert(table, chain, 1, rule...) err = i.iptablesClient.Insert(table, chain, 1, rule...)
if err != nil { if err != nil {
return fmt.Errorf("error while adding new %s rule for %s: %v", getIptablesRuleType(table), pair.Destination, err) return fmt.Errorf("error while adding new %s rule for %s: %v", getIptablesRuleType(table), pair.Destination, err)
@@ -318,13 +317,6 @@ func (i *routerManager) createChain(table, newChain string) error {
return fmt.Errorf("couldn't create chain %s in %s table, error: %v", newChain, table, err) return fmt.Errorf("couldn't create chain %s in %s table, error: %v", newChain, table, err)
} }
// Add the loopback return rule to the NAT chain
loopbackRule := []string{"-o", "lo", "-j", "RETURN"}
err = i.iptablesClient.Insert(table, newChain, 1, loopbackRule...)
if err != nil {
return fmt.Errorf("failed to add loopback return rule to %s: %v", chainRTNAT, err)
}
err = i.iptablesClient.Append(table, newChain, "-j", "RETURN") err = i.iptablesClient.Append(table, newChain, "-j", "RETURN")
if err != nil { if err != nil {
return fmt.Errorf("couldn't create chain %s default rule, error: %v", newChain, err) return fmt.Errorf("couldn't create chain %s default rule, error: %v", newChain, err)
@@ -334,33 +326,9 @@ func (i *routerManager) createChain(table, newChain string) error {
return nil return nil
} }
// addNATRule appends an iptables rule pair to the nat chain // genRuleSpec generates rule specification with comment identifier
func (i *routerManager) addNATRule(keyFormat, table, chain, jump string, pair firewall.RouterPair) error { func genRuleSpec(jump, id, source, destination string) []string {
ruleKey := firewall.GenKey(keyFormat, pair.ID) return []string{"-s", source, "-d", destination, "-j", jump, "-m", "comment", "--comment", id}
rule := genRuleSpec(jump, pair.Source, pair.Destination)
existingRule, found := i.rules[ruleKey]
if found {
err := i.iptablesClient.DeleteIfExists(table, chain, existingRule...)
if err != nil {
return fmt.Errorf("error while removing existing NAT rule for %s: %v", pair.Destination, err)
}
delete(i.rules, ruleKey)
}
// inserting after loopback ignore rule
err := i.iptablesClient.Insert(table, chain, 2, rule...)
if err != nil {
return fmt.Errorf("error while appending new NAT rule for %s: %v", pair.Destination, err)
}
i.rules[ruleKey] = rule
return nil
}
// genRuleSpec generates rule specification
func genRuleSpec(jump, source, destination string) []string {
return []string{"-s", source, "-d", destination, "-j", jump}
} }
func getIptablesRuleType(table string) string { func getIptablesRuleType(table string) string {

View File

@@ -51,12 +51,14 @@ func TestIptablesManager_RestoreOrCreateContainers(t *testing.T) {
Destination: "100.100.100.0/24", Destination: "100.100.100.0/24",
Masquerade: true, Masquerade: true,
} }
forward4Rule := genRuleSpec(routingFinalForwardJump, pair.Source, pair.Destination) forward4RuleKey := firewall.GenKey(firewall.ForwardingFormat, pair.ID)
forward4Rule := genRuleSpec(routingFinalForwardJump, forward4RuleKey, pair.Source, pair.Destination)
err = manager.iptablesClient.Insert(tableFilter, chainRTFWD, 1, forward4Rule...) err = manager.iptablesClient.Insert(tableFilter, chainRTFWD, 1, forward4Rule...)
require.NoError(t, err, "inserting rule should not return error") require.NoError(t, err, "inserting rule should not return error")
nat4Rule := genRuleSpec(routingFinalNatJump, pair.Source, pair.Destination) nat4RuleKey := firewall.GenKey(firewall.NatFormat, pair.ID)
nat4Rule := genRuleSpec(routingFinalNatJump, nat4RuleKey, pair.Source, pair.Destination)
err = manager.iptablesClient.Insert(tableNat, chainRTNAT, 1, nat4Rule...) err = manager.iptablesClient.Insert(tableNat, chainRTNAT, 1, nat4Rule...)
require.NoError(t, err, "inserting rule should not return error") require.NoError(t, err, "inserting rule should not return error")
@@ -90,7 +92,7 @@ func TestIptablesManager_InsertRoutingRules(t *testing.T) {
require.NoError(t, err, "forwarding pair should be inserted") require.NoError(t, err, "forwarding pair should be inserted")
forwardRuleKey := firewall.GenKey(firewall.ForwardingFormat, testCase.InputPair.ID) forwardRuleKey := firewall.GenKey(firewall.ForwardingFormat, testCase.InputPair.ID)
forwardRule := genRuleSpec(routingFinalForwardJump, testCase.InputPair.Source, testCase.InputPair.Destination) forwardRule := genRuleSpec(routingFinalForwardJump, forwardRuleKey, testCase.InputPair.Source, testCase.InputPair.Destination)
exists, err := iptablesClient.Exists(tableFilter, chainRTFWD, forwardRule...) exists, err := iptablesClient.Exists(tableFilter, chainRTFWD, forwardRule...)
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableFilter, chainRTFWD) require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableFilter, chainRTFWD)
@@ -101,7 +103,7 @@ func TestIptablesManager_InsertRoutingRules(t *testing.T) {
require.Equal(t, forwardRule[:4], foundRule[:4], "stored forwarding rule should match") require.Equal(t, forwardRule[:4], foundRule[:4], "stored forwarding rule should match")
inForwardRuleKey := firewall.GenKey(firewall.InForwardingFormat, testCase.InputPair.ID) inForwardRuleKey := firewall.GenKey(firewall.InForwardingFormat, testCase.InputPair.ID)
inForwardRule := genRuleSpec(routingFinalForwardJump, firewall.GetInPair(testCase.InputPair).Source, firewall.GetInPair(testCase.InputPair).Destination) inForwardRule := genRuleSpec(routingFinalForwardJump, inForwardRuleKey, firewall.GetInPair(testCase.InputPair).Source, firewall.GetInPair(testCase.InputPair).Destination)
exists, err = iptablesClient.Exists(tableFilter, chainRTFWD, inForwardRule...) exists, err = iptablesClient.Exists(tableFilter, chainRTFWD, inForwardRule...)
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableFilter, chainRTFWD) require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableFilter, chainRTFWD)
@@ -112,7 +114,7 @@ func TestIptablesManager_InsertRoutingRules(t *testing.T) {
require.Equal(t, inForwardRule[:4], foundRule[:4], "stored income forwarding rule should match") require.Equal(t, inForwardRule[:4], foundRule[:4], "stored income forwarding rule should match")
natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair.ID) natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair.ID)
natRule := genRuleSpec(routingFinalNatJump, testCase.InputPair.Source, testCase.InputPair.Destination) natRule := genRuleSpec(routingFinalNatJump, natRuleKey, testCase.InputPair.Source, testCase.InputPair.Destination)
exists, err = iptablesClient.Exists(tableNat, chainRTNAT, natRule...) exists, err = iptablesClient.Exists(tableNat, chainRTNAT, natRule...)
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableNat, chainRTNAT) require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableNat, chainRTNAT)
@@ -128,7 +130,7 @@ func TestIptablesManager_InsertRoutingRules(t *testing.T) {
} }
inNatRuleKey := firewall.GenKey(firewall.InNatFormat, testCase.InputPair.ID) inNatRuleKey := firewall.GenKey(firewall.InNatFormat, testCase.InputPair.ID)
inNatRule := genRuleSpec(routingFinalNatJump, firewall.GetInPair(testCase.InputPair).Source, firewall.GetInPair(testCase.InputPair).Destination) inNatRule := genRuleSpec(routingFinalNatJump, inNatRuleKey, firewall.GetInPair(testCase.InputPair).Source, firewall.GetInPair(testCase.InputPair).Destination)
exists, err = iptablesClient.Exists(tableNat, chainRTNAT, inNatRule...) exists, err = iptablesClient.Exists(tableNat, chainRTNAT, inNatRule...)
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableNat, chainRTNAT) require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableNat, chainRTNAT)
@@ -165,25 +167,25 @@ func TestIptablesManager_RemoveRoutingRules(t *testing.T) {
require.NoError(t, err, "shouldn't return error") require.NoError(t, err, "shouldn't return error")
forwardRuleKey := firewall.GenKey(firewall.ForwardingFormat, testCase.InputPair.ID) forwardRuleKey := firewall.GenKey(firewall.ForwardingFormat, testCase.InputPair.ID)
forwardRule := genRuleSpec(routingFinalForwardJump, testCase.InputPair.Source, testCase.InputPair.Destination) forwardRule := genRuleSpec(routingFinalForwardJump, forwardRuleKey, testCase.InputPair.Source, testCase.InputPair.Destination)
err = iptablesClient.Insert(tableFilter, chainRTFWD, 1, forwardRule...) err = iptablesClient.Insert(tableFilter, chainRTFWD, 1, forwardRule...)
require.NoError(t, err, "inserting rule should not return error") require.NoError(t, err, "inserting rule should not return error")
inForwardRuleKey := firewall.GenKey(firewall.InForwardingFormat, testCase.InputPair.ID) inForwardRuleKey := firewall.GenKey(firewall.InForwardingFormat, testCase.InputPair.ID)
inForwardRule := genRuleSpec(routingFinalForwardJump, firewall.GetInPair(testCase.InputPair).Source, firewall.GetInPair(testCase.InputPair).Destination) inForwardRule := genRuleSpec(routingFinalForwardJump, inForwardRuleKey, firewall.GetInPair(testCase.InputPair).Source, firewall.GetInPair(testCase.InputPair).Destination)
err = iptablesClient.Insert(tableFilter, chainRTFWD, 1, inForwardRule...) err = iptablesClient.Insert(tableFilter, chainRTFWD, 1, inForwardRule...)
require.NoError(t, err, "inserting rule should not return error") require.NoError(t, err, "inserting rule should not return error")
natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair.ID) natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair.ID)
natRule := genRuleSpec(routingFinalNatJump, testCase.InputPair.Source, testCase.InputPair.Destination) natRule := genRuleSpec(routingFinalNatJump, natRuleKey, testCase.InputPair.Source, testCase.InputPair.Destination)
err = iptablesClient.Insert(tableNat, chainRTNAT, 1, natRule...) err = iptablesClient.Insert(tableNat, chainRTNAT, 1, natRule...)
require.NoError(t, err, "inserting rule should not return error") require.NoError(t, err, "inserting rule should not return error")
inNatRuleKey := firewall.GenKey(firewall.InNatFormat, testCase.InputPair.ID) inNatRuleKey := firewall.GenKey(firewall.InNatFormat, testCase.InputPair.ID)
inNatRule := genRuleSpec(routingFinalNatJump, firewall.GetInPair(testCase.InputPair).Source, firewall.GetInPair(testCase.InputPair).Destination) inNatRule := genRuleSpec(routingFinalNatJump, inNatRuleKey, firewall.GetInPair(testCase.InputPair).Source, firewall.GetInPair(testCase.InputPair).Destination)
err = iptablesClient.Insert(tableNat, chainRTNAT, 1, inNatRule...) err = iptablesClient.Insert(tableNat, chainRTNAT, 1, inNatRule...)
require.NoError(t, err, "inserting rule should not return error") require.NoError(t, err, "inserting rule should not return error")

View File

@@ -95,7 +95,7 @@ func (m *Manager) InsertRoutingRules(pair firewall.RouterPair) error {
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
return m.router.AddRoutingRules(pair) return m.router.InsertRoutingRules(pair)
} }
func (m *Manager) RemoveRoutingRules(pair firewall.RouterPair) error { func (m *Manager) RemoveRoutingRules(pair firewall.RouterPair) error {

View File

@@ -22,8 +22,6 @@ const (
userDataAcceptForwardRuleSrc = "frwacceptsrc" userDataAcceptForwardRuleSrc = "frwacceptsrc"
userDataAcceptForwardRuleDst = "frwacceptdst" userDataAcceptForwardRuleDst = "frwacceptdst"
loopbackInterface = "lo\x00"
) )
// some presets for building nftable rules // some presets for building nftable rules
@@ -128,22 +126,6 @@ func (r *router) createContainers() error {
Type: nftables.ChainTypeNAT, Type: nftables.ChainTypeNAT,
}) })
// Add RETURN rule for loopback interface
loRule := &nftables.Rule{
Table: r.workTable,
Chain: r.chains[chainNameRoutingNat],
Exprs: []expr.Any{
&expr.Meta{Key: expr.MetaKeyOIFNAME, Register: 1},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: []byte(loopbackInterface),
},
&expr.Verdict{Kind: expr.VerdictReturn},
},
}
r.conn.InsertRule(loRule)
err := r.refreshRulesMap() err := r.refreshRulesMap()
if err != nil { if err != nil {
log.Errorf("failed to clean up rules from FORWARD chain: %s", err) log.Errorf("failed to clean up rules from FORWARD chain: %s", err)
@@ -156,28 +138,28 @@ func (r *router) createContainers() error {
return nil return nil
} }
// AddRoutingRules appends a nftable rule pair to the forwarding chain and if enabled, to the nat chain // InsertRoutingRules inserts a nftable rule pair to the forwarding chain and if enabled, to the nat chain
func (r *router) AddRoutingRules(pair manager.RouterPair) error { func (r *router) InsertRoutingRules(pair manager.RouterPair) error {
err := r.refreshRulesMap() err := r.refreshRulesMap()
if err != nil { if err != nil {
return err return err
} }
err = r.addRoutingRule(manager.ForwardingFormat, chainNameRouteingFw, pair, false) err = r.insertRoutingRule(manager.ForwardingFormat, chainNameRouteingFw, pair, false)
if err != nil { if err != nil {
return err return err
} }
err = r.addRoutingRule(manager.InForwardingFormat, chainNameRouteingFw, manager.GetInPair(pair), false) err = r.insertRoutingRule(manager.InForwardingFormat, chainNameRouteingFw, manager.GetInPair(pair), false)
if err != nil { if err != nil {
return err return err
} }
if pair.Masquerade { if pair.Masquerade {
err = r.addRoutingRule(manager.NatFormat, chainNameRoutingNat, pair, true) err = r.insertRoutingRule(manager.NatFormat, chainNameRoutingNat, pair, true)
if err != nil { if err != nil {
return err return err
} }
err = r.addRoutingRule(manager.InNatFormat, chainNameRoutingNat, manager.GetInPair(pair), true) err = r.insertRoutingRule(manager.InNatFormat, chainNameRoutingNat, manager.GetInPair(pair), true)
if err != nil { if err != nil {
return err return err
} }
@@ -195,8 +177,8 @@ func (r *router) AddRoutingRules(pair manager.RouterPair) error {
return nil return nil
} }
// addRoutingRule inserts a nftable rule to the conn client flush queue // insertRoutingRule inserts a nftable rule to the conn client flush queue
func (r *router) addRoutingRule(format, chainName string, pair manager.RouterPair, isNat bool) error { func (r *router) insertRoutingRule(format, chainName string, pair manager.RouterPair, isNat bool) error {
sourceExp := generateCIDRMatcherExpressions(true, pair.Source) sourceExp := generateCIDRMatcherExpressions(true, pair.Source)
destExp := generateCIDRMatcherExpressions(false, pair.Destination) destExp := generateCIDRMatcherExpressions(false, pair.Destination)
@@ -217,7 +199,7 @@ func (r *router) addRoutingRule(format, chainName string, pair manager.RouterPai
} }
} }
r.rules[ruleKey] = r.conn.AddRule(&nftables.Rule{ r.rules[ruleKey] = r.conn.InsertRule(&nftables.Rule{
Table: r.workTable, Table: r.workTable,
Chain: r.chains[chainName], Chain: r.chains[chainName],
Exprs: expression, Exprs: expression,

View File

@@ -47,7 +47,7 @@ func TestNftablesManager_InsertRoutingRules(t *testing.T) {
require.NoError(t, err, "shouldn't return error") require.NoError(t, err, "shouldn't return error")
err = manager.AddRoutingRules(testCase.InputPair) err = manager.InsertRoutingRules(testCase.InputPair)
defer func() { defer func() {
_ = manager.RemoveRoutingRules(testCase.InputPair) _ = manager.RemoveRoutingRules(testCase.InputPair)
}() }()

View File

@@ -64,18 +64,15 @@ func manageFirewallRule(ruleName string, action action, extraArgs ...string) err
if action == addRule { if action == addRule {
args = append(args, extraArgs...) args = append(args, extraArgs...)
} }
netshCmd := GetSystem32Command("netsh")
cmd := exec.Command(netshCmd, args...) cmd := exec.Command("netsh", args...)
cmd.SysProcAttr = &syscall.SysProcAttr{HideWindow: true} cmd.SysProcAttr = &syscall.SysProcAttr{HideWindow: true}
return cmd.Run() return cmd.Run()
} }
func isWindowsFirewallReachable() bool { func isWindowsFirewallReachable() bool {
args := []string{"advfirewall", "show", "allprofiles", "state"} args := []string{"advfirewall", "show", "allprofiles", "state"}
cmd := exec.Command("netsh", args...)
netshCmd := GetSystem32Command("netsh")
cmd := exec.Command(netshCmd, args...)
cmd.SysProcAttr = &syscall.SysProcAttr{HideWindow: true} cmd.SysProcAttr = &syscall.SysProcAttr{HideWindow: true}
_, err := cmd.Output() _, err := cmd.Output()
@@ -90,23 +87,8 @@ func isWindowsFirewallReachable() bool {
func isFirewallRuleActive(ruleName string) bool { func isFirewallRuleActive(ruleName string) bool {
args := []string{"advfirewall", "firewall", "show", "rule", "name=" + ruleName} args := []string{"advfirewall", "firewall", "show", "rule", "name=" + ruleName}
netshCmd := GetSystem32Command("netsh") cmd := exec.Command("netsh", args...)
cmd := exec.Command(netshCmd, args...)
cmd.SysProcAttr = &syscall.SysProcAttr{HideWindow: true} cmd.SysProcAttr = &syscall.SysProcAttr{HideWindow: true}
_, err := cmd.Output() _, err := cmd.Output()
return err == nil return err == nil
} }
// GetSystem32Command checks if a command can be found in the system path and returns it. In case it can't find it
// in the path it will return the full path of a command assuming C:\windows\system32 as the base path.
func GetSystem32Command(command string) string {
_, err := exec.LookPath(command)
if err == nil {
return command
}
log.Tracef("Command %s not found in PATH, using C:\\windows\\system32\\%s.exe path", command, command)
return "C:\\windows\\system32\\" + command + ".exe"
}

View File

@@ -337,6 +337,7 @@ func validateRule(ip net.IP, packetData []byte, rules map[string]Rule, d *decode
if rule.dPort != 0 && rule.dPort == uint16(d.udp.DstPort) { if rule.dPort != 0 && rule.dPort == uint16(d.udp.DstPort) {
return rule.drop, true return rule.drop, true
} }
return rule.drop, true
case layers.LayerTypeICMPv4, layers.LayerTypeICMPv6: case layers.LayerTypeICMPv4, layers.LayerTypeICMPv6:
return rule.drop, true return rule.drop, true
} }

View File

@@ -3,7 +3,6 @@ package auth
import ( import (
"context" "context"
"encoding/json" "encoding/json"
"errors"
"fmt" "fmt"
"io" "io"
"net/http" "net/http"
@@ -181,7 +180,7 @@ func (d *DeviceAuthorizationFlow) WaitToken(ctx context.Context, info AuthFlowIn
continue continue
} }
return TokenInfo{}, errors.New(tokenResponse.ErrorDescription) return TokenInfo{}, fmt.Errorf(tokenResponse.ErrorDescription)
} }
tokenInfo := TokenInfo{ tokenInfo := TokenInfo{

View File

@@ -69,11 +69,6 @@ func NewOAuthFlow(ctx context.Context, config *internal.Config, isLinuxDesktopCl
return authenticateWithDeviceCodeFlow(ctx, config) return authenticateWithDeviceCodeFlow(ctx, config)
} }
// On FreeBSD we currently do not support desktop environments and offer only Device Code Flow (#2384)
if runtime.GOOS == "freebsd" {
return authenticateWithDeviceCodeFlow(ctx, config)
}
pkceFlow, err := authenticateWithPKCEFlow(ctx, config) pkceFlow, err := authenticateWithPKCEFlow(ctx, config)
if err != nil { if err != nil {
// fallback to device code flow // fallback to device code flow
@@ -86,7 +81,7 @@ func NewOAuthFlow(ctx context.Context, config *internal.Config, isLinuxDesktopCl
// authenticateWithPKCEFlow initializes the Proof Key for Code Exchange flow auth flow // authenticateWithPKCEFlow initializes the Proof Key for Code Exchange flow auth flow
func authenticateWithPKCEFlow(ctx context.Context, config *internal.Config) (OAuthFlow, error) { func authenticateWithPKCEFlow(ctx context.Context, config *internal.Config) (OAuthFlow, error) {
pkceFlowInfo, err := internal.GetPKCEAuthorizationFlowInfo(ctx, config.PrivateKey, config.ManagementURL, config.ClientCertKeyPair) pkceFlowInfo, err := internal.GetPKCEAuthorizationFlowInfo(ctx, config.PrivateKey, config.ManagementURL)
if err != nil { if err != nil {
return nil, fmt.Errorf("getting pkce authorization flow info failed with error: %v", err) return nil, fmt.Errorf("getting pkce authorization flow info failed with error: %v", err)
} }

View File

@@ -4,7 +4,6 @@ import (
"context" "context"
"crypto/sha256" "crypto/sha256"
"crypto/subtle" "crypto/subtle"
"crypto/tls"
"encoding/base64" "encoding/base64"
"errors" "errors"
"fmt" "fmt"
@@ -144,18 +143,6 @@ func (p *PKCEAuthorizationFlow) WaitToken(ctx context.Context, _ AuthFlowInfo) (
func (p *PKCEAuthorizationFlow) startServer(server *http.Server, tokenChan chan<- *oauth2.Token, errChan chan<- error) { func (p *PKCEAuthorizationFlow) startServer(server *http.Server, tokenChan chan<- *oauth2.Token, errChan chan<- error) {
mux := http.NewServeMux() mux := http.NewServeMux()
mux.HandleFunc("/", func(w http.ResponseWriter, req *http.Request) { mux.HandleFunc("/", func(w http.ResponseWriter, req *http.Request) {
cert := p.providerConfig.ClientCertPair
if cert != nil {
tr := &http.Transport{
TLSClientConfig: &tls.Config{
Certificates: []tls.Certificate{*cert},
},
}
sslClient := &http.Client{Transport: tr}
ctx := context.WithValue(req.Context(), oauth2.HTTPClient, sslClient)
req = req.WithContext(ctx)
}
token, err := p.handleRequest(req) token, err := p.handleRequest(req)
if err != nil { if err != nil {
renderPKCEFlowTmpl(w, err) renderPKCEFlowTmpl(w, err)

View File

@@ -2,21 +2,15 @@ package internal
import ( import (
"context" "context"
"crypto/tls"
"fmt" "fmt"
"net/url" "net/url"
"os" "os"
"reflect"
"runtime"
"strings"
"time"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes" "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"google.golang.org/grpc/codes" "google.golang.org/grpc/codes"
"google.golang.org/grpc/status" "google.golang.org/grpc/status"
"github.com/netbirdio/netbird/client/internal/routemanager/dynamic"
"github.com/netbirdio/netbird/client/ssh" "github.com/netbirdio/netbird/client/ssh"
"github.com/netbirdio/netbird/iface" "github.com/netbirdio/netbird/iface"
mgm "github.com/netbirdio/netbird/management/client" mgm "github.com/netbirdio/netbird/management/client"
@@ -54,12 +48,8 @@ type ConfigInput struct {
RosenpassPermissive *bool RosenpassPermissive *bool
InterfaceName *string InterfaceName *string
WireguardPort *int WireguardPort *int
NetworkMonitor *bool
DisableAutoConnect *bool DisableAutoConnect *bool
ExtraIFaceBlackList []string ExtraIFaceBlackList []string
DNSRouteInterval *time.Duration
ClientCertPath string
ClientCertKeyPath string
} }
// Config Configuration type // Config Configuration type
@@ -71,7 +61,6 @@ type Config struct {
AdminURL *url.URL AdminURL *url.URL
WgIface string WgIface string
WgPort int WgPort int
NetworkMonitor *bool
IFaceBlackList []string IFaceBlackList []string
DisableIPv6Discovery bool DisableIPv6Discovery bool
RosenpassEnabled bool RosenpassEnabled bool
@@ -102,38 +91,15 @@ type Config struct {
// DisableAutoConnect determines whether the client should not start with the service // DisableAutoConnect determines whether the client should not start with the service
// it's set to false by default due to backwards compatibility // it's set to false by default due to backwards compatibility
DisableAutoConnect bool DisableAutoConnect bool
// DNSRouteInterval is the interval in which the DNS routes are updated
DNSRouteInterval time.Duration
//Path to a certificate used for mTLS authentication
ClientCertPath string
//Path to corresponding private key of ClientCertPath
ClientCertKeyPath string
ClientCertKeyPair *tls.Certificate `json:"-"`
} }
// ReadConfig read config file and return with Config. If it is not exists create a new with default values // ReadConfig read config file and return with Config. If it is not exists create a new with default values
func ReadConfig(configPath string) (*Config, error) { func ReadConfig(configPath string) (*Config, error) {
if configFileIsExists(configPath) { if configFileIsExists(configPath) {
err := util.EnforcePermission(configPath)
if err != nil {
log.Errorf("failed to enforce permission on config dir: %v", err)
}
config := &Config{} config := &Config{}
if _, err := util.ReadJson(configPath, config); err != nil { if _, err := util.ReadJson(configPath, config); err != nil {
return nil, err return nil, err
} }
// initialize through apply() without changes
if changed, err := config.apply(ConfigInput{}); err != nil {
return nil, err
} else if changed {
if err = WriteOutConfig(configPath, config); err != nil {
return nil, err
}
}
return config, nil return config, nil
} }
@@ -164,17 +130,13 @@ func UpdateOrCreateConfig(input ConfigInput) (*Config, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
err = util.WriteJsonWithRestrictedPermission(input.ConfigPath, cfg) err = WriteOutConfig(input.ConfigPath, cfg)
return cfg, err return cfg, err
} }
if isPreSharedKeyHidden(input.PreSharedKey) { if isPreSharedKeyHidden(input.PreSharedKey) {
input.PreSharedKey = nil input.PreSharedKey = nil
} }
err := util.EnforcePermission(input.ConfigPath)
if err != nil {
log.Errorf("failed to enforce permission on config dir: %v", err)
}
return update(input) return update(input)
} }
@@ -190,15 +152,79 @@ func WriteOutConfig(path string, config *Config) error {
// createNewConfig creates a new config generating a new Wireguard key and saving to file // createNewConfig creates a new config generating a new Wireguard key and saving to file
func createNewConfig(input ConfigInput) (*Config, error) { func createNewConfig(input ConfigInput) (*Config, error) {
config := &Config{ wgKey := generateKey()
// defaults to false only for new (post 0.26) configurations pem, err := ssh.GeneratePrivateKey(ssh.ED25519)
ServerSSHAllowed: util.False(), if err != nil {
}
if _, err := config.apply(input); err != nil {
return nil, err return nil, err
} }
config := &Config{
SSHKey: string(pem),
PrivateKey: wgKey,
IFaceBlackList: []string{},
DisableIPv6Discovery: false,
NATExternalIPs: input.NATExternalIPs,
CustomDNSAddress: string(input.CustomDNSAddress),
ServerSSHAllowed: util.False(),
DisableAutoConnect: false,
}
defaultManagementURL, err := parseURL("Management URL", DefaultManagementURL)
if err != nil {
return nil, err
}
config.ManagementURL = defaultManagementURL
if input.ManagementURL != "" {
URL, err := parseURL("Management URL", input.ManagementURL)
if err != nil {
return nil, err
}
config.ManagementURL = URL
}
config.WgPort = iface.DefaultWgPort
if input.WireguardPort != nil {
config.WgPort = *input.WireguardPort
}
config.WgIface = iface.WgInterfaceDefault
if input.InterfaceName != nil {
config.WgIface = *input.InterfaceName
}
if input.PreSharedKey != nil {
config.PreSharedKey = *input.PreSharedKey
}
if input.RosenpassEnabled != nil {
config.RosenpassEnabled = *input.RosenpassEnabled
}
if input.RosenpassPermissive != nil {
config.RosenpassPermissive = *input.RosenpassPermissive
}
if input.ServerSSHAllowed != nil {
config.ServerSSHAllowed = input.ServerSSHAllowed
}
defaultAdminURL, err := parseURL("Admin URL", DefaultAdminURL)
if err != nil {
return nil, err
}
config.AdminURL = defaultAdminURL
if input.AdminURL != "" {
newURL, err := parseURL("Admin Panel URL", input.AdminURL)
if err != nil {
return nil, err
}
config.AdminURL = newURL
}
// nolint:gocritic
config.IFaceBlackList = append(defaultInterfaceBlacklist, input.ExtraIFaceBlackList...)
return config, nil return config, nil
} }
@@ -209,12 +235,104 @@ func update(input ConfigInput) (*Config, error) {
return nil, err return nil, err
} }
updated, err := config.apply(input) refresh := false
if err != nil {
return nil, err if input.ManagementURL != "" && config.ManagementURL.String() != input.ManagementURL {
log.Infof("new Management URL provided, updated to %s (old value %s)",
input.ManagementURL, config.ManagementURL)
newURL, err := parseURL("Management URL", input.ManagementURL)
if err != nil {
return nil, err
}
config.ManagementURL = newURL
refresh = true
} }
if updated { if input.AdminURL != "" && (config.AdminURL == nil || config.AdminURL.String() != input.AdminURL) {
log.Infof("new Admin Panel URL provided, updated to %s (old value %s)",
input.AdminURL, config.AdminURL)
newURL, err := parseURL("Admin Panel URL", input.AdminURL)
if err != nil {
return nil, err
}
config.AdminURL = newURL
refresh = true
}
if input.PreSharedKey != nil && config.PreSharedKey != *input.PreSharedKey {
log.Infof("new pre-shared key provided, replacing old key")
config.PreSharedKey = *input.PreSharedKey
refresh = true
}
if config.SSHKey == "" {
pem, err := ssh.GeneratePrivateKey(ssh.ED25519)
if err != nil {
return nil, err
}
config.SSHKey = string(pem)
refresh = true
}
if config.WgPort == 0 {
config.WgPort = iface.DefaultWgPort
refresh = true
}
if input.WireguardPort != nil {
config.WgPort = *input.WireguardPort
refresh = true
}
if input.InterfaceName != nil {
config.WgIface = *input.InterfaceName
refresh = true
}
if input.NATExternalIPs != nil && len(config.NATExternalIPs) != len(input.NATExternalIPs) {
config.NATExternalIPs = input.NATExternalIPs
refresh = true
}
if input.CustomDNSAddress != nil {
config.CustomDNSAddress = string(input.CustomDNSAddress)
refresh = true
}
if input.RosenpassEnabled != nil {
config.RosenpassEnabled = *input.RosenpassEnabled
refresh = true
}
if input.RosenpassPermissive != nil {
config.RosenpassPermissive = *input.RosenpassPermissive
refresh = true
}
if input.DisableAutoConnect != nil {
config.DisableAutoConnect = *input.DisableAutoConnect
refresh = true
}
if input.ServerSSHAllowed != nil {
config.ServerSSHAllowed = input.ServerSSHAllowed
refresh = true
}
if config.ServerSSHAllowed == nil {
config.ServerSSHAllowed = util.True()
refresh = true
}
if len(input.ExtraIFaceBlackList) > 0 {
for _, iFace := range util.SliceDiff(input.ExtraIFaceBlackList, config.IFaceBlackList) {
config.IFaceBlackList = append(config.IFaceBlackList, iFace)
refresh = true
}
}
if refresh {
// since we have new management URL, we need to update config file
if err := util.WriteJson(input.ConfigPath, config); err != nil { if err := util.WriteJson(input.ConfigPath, config); err != nil {
return nil, err return nil, err
} }
@@ -223,210 +341,6 @@ func update(input ConfigInput) (*Config, error) {
return config, nil return config, nil
} }
func (config *Config) apply(input ConfigInput) (updated bool, err error) {
if config.ManagementURL == nil {
log.Infof("using default Management URL %s", DefaultManagementURL)
config.ManagementURL, err = parseURL("Management URL", DefaultManagementURL)
if err != nil {
return false, err
}
}
if input.ManagementURL != "" && input.ManagementURL != config.ManagementURL.String() {
log.Infof("new Management URL provided, updated to %#v (old value %#v)",
input.ManagementURL, config.ManagementURL.String())
URL, err := parseURL("Management URL", input.ManagementURL)
if err != nil {
return false, err
}
config.ManagementURL = URL
updated = true
} else if config.ManagementURL == nil {
log.Infof("using default Management URL %s", DefaultManagementURL)
config.ManagementURL, err = parseURL("Management URL", DefaultManagementURL)
if err != nil {
return false, err
}
}
if config.AdminURL == nil {
log.Infof("using default Admin URL %s", DefaultManagementURL)
config.AdminURL, err = parseURL("Admin URL", DefaultAdminURL)
if err != nil {
return false, err
}
}
if input.AdminURL != "" && input.AdminURL != config.AdminURL.String() {
log.Infof("new Admin Panel URL provided, updated to %#v (old value %#v)",
input.AdminURL, config.AdminURL.String())
newURL, err := parseURL("Admin Panel URL", input.AdminURL)
if err != nil {
return updated, err
}
config.AdminURL = newURL
updated = true
}
if config.PrivateKey == "" {
log.Infof("generated new Wireguard key")
config.PrivateKey = generateKey()
updated = true
}
if config.SSHKey == "" {
log.Infof("generated new SSH key")
pem, err := ssh.GeneratePrivateKey(ssh.ED25519)
if err != nil {
return false, err
}
config.SSHKey = string(pem)
updated = true
}
if input.WireguardPort != nil && *input.WireguardPort != config.WgPort {
log.Infof("updating Wireguard port %d (old value %d)",
*input.WireguardPort, config.WgPort)
config.WgPort = *input.WireguardPort
updated = true
} else if config.WgPort == 0 {
config.WgPort = iface.DefaultWgPort
log.Infof("using default Wireguard port %d", config.WgPort)
updated = true
}
if input.InterfaceName != nil && *input.InterfaceName != config.WgIface {
log.Infof("updating Wireguard interface %#v (old value %#v)",
*input.InterfaceName, config.WgIface)
config.WgIface = *input.InterfaceName
updated = true
} else if config.WgIface == "" {
config.WgIface = iface.WgInterfaceDefault
log.Infof("using default Wireguard interface %s", config.WgIface)
updated = true
}
if input.NATExternalIPs != nil && !reflect.DeepEqual(config.NATExternalIPs, input.NATExternalIPs) {
log.Infof("updating NAT External IP [ %s ] (old value: [ %s ])",
strings.Join(input.NATExternalIPs, " "),
strings.Join(config.NATExternalIPs, " "))
config.NATExternalIPs = input.NATExternalIPs
updated = true
}
if input.PreSharedKey != nil && *input.PreSharedKey != config.PreSharedKey {
log.Infof("new pre-shared key provided, replacing old key")
config.PreSharedKey = *input.PreSharedKey
updated = true
}
if input.RosenpassEnabled != nil && *input.RosenpassEnabled != config.RosenpassEnabled {
log.Infof("switching Rosenpass to %t", *input.RosenpassEnabled)
config.RosenpassEnabled = *input.RosenpassEnabled
updated = true
}
if input.RosenpassPermissive != nil && *input.RosenpassPermissive != config.RosenpassPermissive {
log.Infof("switching Rosenpass permissive to %t", *input.RosenpassPermissive)
config.RosenpassPermissive = *input.RosenpassPermissive
updated = true
}
if input.NetworkMonitor != nil && input.NetworkMonitor != config.NetworkMonitor {
log.Infof("switching Network Monitor to %t", *input.NetworkMonitor)
config.NetworkMonitor = input.NetworkMonitor
updated = true
}
if config.NetworkMonitor == nil {
// enable network monitoring by default on windows and darwin clients
if runtime.GOOS == "windows" || runtime.GOOS == "darwin" {
enabled := true
config.NetworkMonitor = &enabled
updated = true
}
}
if input.CustomDNSAddress != nil && string(input.CustomDNSAddress) != config.CustomDNSAddress {
log.Infof("updating custom DNS address %#v (old value %#v)",
string(input.CustomDNSAddress), config.CustomDNSAddress)
config.CustomDNSAddress = string(input.CustomDNSAddress)
updated = true
}
if len(config.IFaceBlackList) == 0 {
log.Infof("filling in interface blacklist with defaults: [ %s ]",
strings.Join(defaultInterfaceBlacklist, " "))
config.IFaceBlackList = append(config.IFaceBlackList, defaultInterfaceBlacklist...)
updated = true
}
if len(input.ExtraIFaceBlackList) > 0 {
for _, iFace := range util.SliceDiff(input.ExtraIFaceBlackList, config.IFaceBlackList) {
log.Infof("adding new entry to interface blacklist: %s", iFace)
config.IFaceBlackList = append(config.IFaceBlackList, iFace)
updated = true
}
}
if input.DisableAutoConnect != nil && *input.DisableAutoConnect != config.DisableAutoConnect {
if *input.DisableAutoConnect {
log.Infof("turning off automatic connection on startup")
} else {
log.Infof("enabling automatic connection on startup")
}
config.DisableAutoConnect = *input.DisableAutoConnect
updated = true
}
if input.ServerSSHAllowed != nil && *input.ServerSSHAllowed != *config.ServerSSHAllowed {
if *input.ServerSSHAllowed {
log.Infof("enabling SSH server")
} else {
log.Infof("disabling SSH server")
}
config.ServerSSHAllowed = input.ServerSSHAllowed
updated = true
} else if config.ServerSSHAllowed == nil {
// enables SSH for configs from old versions to preserve backwards compatibility
log.Infof("falling back to enabled SSH server for pre-existing configuration")
config.ServerSSHAllowed = util.True()
updated = true
}
if input.DNSRouteInterval != nil && *input.DNSRouteInterval != config.DNSRouteInterval {
log.Infof("updating DNS route interval to %s (old value %s)",
input.DNSRouteInterval.String(), config.DNSRouteInterval.String())
config.DNSRouteInterval = *input.DNSRouteInterval
updated = true
} else if config.DNSRouteInterval == 0 {
config.DNSRouteInterval = dynamic.DefaultInterval
log.Infof("using default DNS route interval %s", config.DNSRouteInterval)
updated = true
}
if input.ClientCertKeyPath != "" {
config.ClientCertKeyPath = input.ClientCertKeyPath
updated = true
}
if input.ClientCertPath != "" {
config.ClientCertPath = input.ClientCertPath
updated = true
}
if config.ClientCertPath != "" && config.ClientCertKeyPath != "" {
cert, err := tls.LoadX509KeyPair(config.ClientCertPath, config.ClientCertKeyPath)
if err != nil {
log.Error("Failed to load mTLS cert/key pair: ", err)
} else {
config.ClientCertKeyPair = &cert
log.Info("Loaded client mTLS cert/key pair")
}
}
return updated, nil
}
// parseURL parses and validates a service URL // parseURL parses and validates a service URL
func parseURL(serviceName, serviceURL string) (*url.URL, error) { func parseURL(serviceName, serviceURL string) (*url.URL, error) {
parsedMgmtURL, err := url.ParseRequestURI(serviceURL) parsedMgmtURL, err := url.ParseRequestURI(serviceURL)

View File

@@ -4,11 +4,9 @@ import (
"context" "context"
"errors" "errors"
"fmt" "fmt"
"net"
"runtime" "runtime"
"runtime/debug" "runtime/debug"
"strings" "strings"
"sync"
"time" "time"
"github.com/cenkalti/backoff/v4" "github.com/cenkalti/backoff/v4"
@@ -26,50 +24,35 @@ import (
"github.com/netbirdio/netbird/iface" "github.com/netbirdio/netbird/iface"
mgm "github.com/netbirdio/netbird/management/client" mgm "github.com/netbirdio/netbird/management/client"
mgmProto "github.com/netbirdio/netbird/management/proto" mgmProto "github.com/netbirdio/netbird/management/proto"
"github.com/netbirdio/netbird/relay/auth/hmac"
relayClient "github.com/netbirdio/netbird/relay/client"
signal "github.com/netbirdio/netbird/signal/client" signal "github.com/netbirdio/netbird/signal/client"
"github.com/netbirdio/netbird/util" "github.com/netbirdio/netbird/util"
"github.com/netbirdio/netbird/version" "github.com/netbirdio/netbird/version"
) )
type ConnectClient struct { // RunClient with main logic.
ctx context.Context func RunClient(ctx context.Context, config *Config, statusRecorder *peer.Status) error {
config *Config return runClient(ctx, config, statusRecorder, MobileDependency{}, nil, nil, nil, nil, nil)
statusRecorder *peer.Status
engine *Engine
engineMutex sync.Mutex
} }
func NewConnectClient( // RunClientWithProbes runs the client's main logic with probes attached
func RunClientWithProbes(
ctx context.Context, ctx context.Context,
config *Config, config *Config,
statusRecorder *peer.Status, statusRecorder *peer.Status,
mgmProbe *Probe,
) *ConnectClient { signalProbe *Probe,
return &ConnectClient{ relayProbe *Probe,
ctx: ctx, wgProbe *Probe,
config: config, engineChan chan<- *Engine,
statusRecorder: statusRecorder,
engineMutex: sync.Mutex{},
}
}
// Run with main logic.
func (c *ConnectClient) Run() error {
return c.run(MobileDependency{}, nil, nil)
}
// RunWithProbes runs the client's main logic with probes attached
func (c *ConnectClient) RunWithProbes(
probes *ProbeHolder,
runningChan chan error,
) error { ) error {
return c.run(MobileDependency{}, probes, runningChan) return runClient(ctx, config, statusRecorder, MobileDependency{}, mgmProbe, signalProbe, relayProbe, wgProbe, engineChan)
} }
// RunOnAndroid with main logic on mobile system // RunClientMobile with main logic on mobile system
func (c *ConnectClient) RunOnAndroid( func RunClientMobile(
ctx context.Context,
config *Config,
statusRecorder *peer.Status,
tunAdapter iface.TunAdapter, tunAdapter iface.TunAdapter,
iFaceDiscover stdnet.ExternalIFaceDiscover, iFaceDiscover stdnet.ExternalIFaceDiscover,
networkChangeListener listener.NetworkChangeListener, networkChangeListener listener.NetworkChangeListener,
@@ -84,29 +67,35 @@ func (c *ConnectClient) RunOnAndroid(
HostDNSAddresses: dnsAddresses, HostDNSAddresses: dnsAddresses,
DnsReadyListener: dnsReadyListener, DnsReadyListener: dnsReadyListener,
} }
return c.run(mobileDependency, nil, nil) return runClient(ctx, config, statusRecorder, mobileDependency, nil, nil, nil, nil, nil)
} }
func (c *ConnectClient) RunOniOS( func RunClientiOS(
ctx context.Context,
config *Config,
statusRecorder *peer.Status,
fileDescriptor int32, fileDescriptor int32,
networkChangeListener listener.NetworkChangeListener, networkChangeListener listener.NetworkChangeListener,
dnsManager dns.IosDnsManager, dnsManager dns.IosDnsManager,
) error { ) error {
// Set GC percent to 5% to reduce memory usage as iOS only allows 50MB of memory for the extension.
debug.SetGCPercent(5)
mobileDependency := MobileDependency{ mobileDependency := MobileDependency{
FileDescriptor: fileDescriptor, FileDescriptor: fileDescriptor,
NetworkChangeListener: networkChangeListener, NetworkChangeListener: networkChangeListener,
DnsManager: dnsManager, DnsManager: dnsManager,
} }
return c.run(mobileDependency, nil, nil) return runClient(ctx, config, statusRecorder, mobileDependency, nil, nil, nil, nil, nil)
} }
func (c *ConnectClient) run( func runClient(
ctx context.Context,
config *Config,
statusRecorder *peer.Status,
mobileDependency MobileDependency, mobileDependency MobileDependency,
probes *ProbeHolder, mgmProbe *Probe,
runningChan chan error, signalProbe *Probe,
relayProbe *Probe,
wgProbe *Probe,
engineChan chan<- *Engine,
) error { ) error {
defer func() { defer func() {
if r := recover(); r != nil { if r := recover(); r != nil {
@@ -118,7 +107,7 @@ func (c *ConnectClient) run(
// Check if client was not shut down in a clean way and restore DNS config if required. // Check if client was not shut down in a clean way and restore DNS config if required.
// Otherwise, we might not be able to connect to the management server to retrieve new config. // Otherwise, we might not be able to connect to the management server to retrieve new config.
if err := dns.CheckUncleanShutdown(c.config.WgIface); err != nil { if err := dns.CheckUncleanShutdown(config.WgIface); err != nil {
log.Errorf("checking unclean shutdown error: %s", err) log.Errorf("checking unclean shutdown error: %s", err)
} }
@@ -132,7 +121,7 @@ func (c *ConnectClient) run(
Clock: backoff.SystemClock, Clock: backoff.SystemClock,
} }
state := CtxGetState(c.ctx) state := CtxGetState(ctx)
defer func() { defer func() {
s, err := state.Status() s, err := state.Status()
if err != nil || s != StatusNeedsLogin { if err != nil || s != StatusNeedsLogin {
@@ -141,50 +130,52 @@ func (c *ConnectClient) run(
}() }()
wrapErr := state.Wrap wrapErr := state.Wrap
myPrivateKey, err := wgtypes.ParseKey(c.config.PrivateKey) myPrivateKey, err := wgtypes.ParseKey(config.PrivateKey)
if err != nil { if err != nil {
log.Errorf("failed parsing Wireguard key %s: [%s]", c.config.PrivateKey, err.Error()) log.Errorf("failed parsing Wireguard key %s: [%s]", config.PrivateKey, err.Error())
return wrapErr(err) return wrapErr(err)
} }
var mgmTlsEnabled bool var mgmTlsEnabled bool
if c.config.ManagementURL.Scheme == "https" { if config.ManagementURL.Scheme == "https" {
mgmTlsEnabled = true mgmTlsEnabled = true
} }
publicSSHKey, err := ssh.GeneratePublicKey([]byte(c.config.SSHKey)) publicSSHKey, err := ssh.GeneratePublicKey([]byte(config.SSHKey))
if err != nil { if err != nil {
return err return err
} }
defer c.statusRecorder.ClientStop() defer statusRecorder.ClientStop()
runningChanOpen := true
operation := func() error { operation := func() error {
// if context cancelled we not start new backoff cycle // if context cancelled we not start new backoff cycle
if c.isContextCancelled() { select {
case <-ctx.Done():
return nil return nil
default:
} }
state.Set(StatusConnecting) state.Set(StatusConnecting)
engineCtx, cancel := context.WithCancel(c.ctx) engineCtx, cancel := context.WithCancel(ctx)
defer func() { defer func() {
c.statusRecorder.MarkManagementDisconnected(state.err) statusRecorder.MarkManagementDisconnected(state.err)
c.statusRecorder.CleanLocalPeerState() statusRecorder.CleanLocalPeerState()
cancel() cancel()
}() }()
log.Debugf("connecting to the Management service %s", c.config.ManagementURL.Host) log.Debugf("connecting to the Management service %s", config.ManagementURL.Host)
mgmClient, err := mgm.NewClient(engineCtx, c.config.ManagementURL.Host, myPrivateKey, mgmTlsEnabled) mgmClient, err := mgm.NewClient(engineCtx, config.ManagementURL.Host, myPrivateKey, mgmTlsEnabled)
if err != nil { if err != nil {
return wrapErr(gstatus.Errorf(codes.FailedPrecondition, "failed connecting to Management Service : %s", err)) return wrapErr(gstatus.Errorf(codes.FailedPrecondition, "failed connecting to Management Service : %s", err))
} }
mgmNotifier := statusRecorderToMgmConnStateNotifier(c.statusRecorder) mgmNotifier := statusRecorderToMgmConnStateNotifier(statusRecorder)
mgmClient.SetConnStateListener(mgmNotifier) mgmClient.SetConnStateListener(mgmNotifier)
log.Debugf("connected to the Management service %s", c.config.ManagementURL.Host) log.Debugf("connected to the Management service %s", config.ManagementURL.Host)
defer func() { defer func() {
if err = mgmClient.Close(); err != nil { err = mgmClient.Close()
if err != nil {
log.Warnf("failed to close the Management service client %v", err) log.Warnf("failed to close the Management service client %v", err)
} }
}() }()
@@ -195,12 +186,11 @@ func (c *ConnectClient) run(
log.Debug(err) log.Debug(err)
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.PermissionDenied) { if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.PermissionDenied) {
state.Set(StatusNeedsLogin) state.Set(StatusNeedsLogin)
_ = c.Stop()
return backoff.Permanent(wrapErr(err)) // unrecoverable error return backoff.Permanent(wrapErr(err)) // unrecoverable error
} }
return wrapErr(err) return wrapErr(err)
} }
c.statusRecorder.MarkManagementConnected() statusRecorder.MarkManagementConnected()
localPeerState := peer.LocalPeerState{ localPeerState := peer.LocalPeerState{
IP: loginResp.GetPeerConfig().GetAddress(), IP: loginResp.GetPeerConfig().GetAddress(),
@@ -208,18 +198,19 @@ func (c *ConnectClient) run(
KernelInterface: iface.WireGuardModuleIsLoaded(), KernelInterface: iface.WireGuardModuleIsLoaded(),
FQDN: loginResp.GetPeerConfig().GetFqdn(), FQDN: loginResp.GetPeerConfig().GetFqdn(),
} }
c.statusRecorder.UpdateLocalPeerState(localPeerState)
statusRecorder.UpdateLocalPeerState(localPeerState)
signalURL := fmt.Sprintf("%s://%s", signalURL := fmt.Sprintf("%s://%s",
strings.ToLower(loginResp.GetWiretrusteeConfig().GetSignal().GetProtocol().String()), strings.ToLower(loginResp.GetWiretrusteeConfig().GetSignal().GetProtocol().String()),
loginResp.GetWiretrusteeConfig().GetSignal().GetUri(), loginResp.GetWiretrusteeConfig().GetSignal().GetUri(),
) )
c.statusRecorder.UpdateSignalAddress(signalURL) statusRecorder.UpdateSignalAddress(signalURL)
c.statusRecorder.MarkSignalDisconnected(nil) statusRecorder.MarkSignalDisconnected(nil)
defer func() { defer func() {
c.statusRecorder.MarkSignalDisconnected(state.err) statusRecorder.MarkSignalDisconnected(state.err)
}() }()
// with the global Wiretrustee config in hand connect (just a connection, no stream yet) Signal // with the global Wiretrustee config in hand connect (just a connection, no stream yet) Signal
@@ -235,68 +226,47 @@ func (c *ConnectClient) run(
} }
}() }()
signalNotifier := statusRecorderToSignalConnStateNotifier(c.statusRecorder) signalNotifier := statusRecorderToSignalConnStateNotifier(statusRecorder)
signalClient.SetConnStateListener(signalNotifier) signalClient.SetConnStateListener(signalNotifier)
c.statusRecorder.MarkSignalConnected() statusRecorder.MarkSignalConnected()
relayURLs, token := parseRelayInfo(loginResp)
relayManager := relayClient.NewManager(engineCtx, relayURLs, myPrivateKey.PublicKey().String())
if len(relayURLs) > 0 {
if token != nil {
if err := relayManager.UpdateToken(token); err != nil {
log.Errorf("failed to update token: %s", err)
return wrapErr(err)
}
}
log.Infof("connecting to the Relay service(s): %s", strings.Join(relayURLs, ", "))
if err = relayManager.Serve(); err != nil {
log.Error(err)
return wrapErr(err)
}
c.statusRecorder.SetRelayMgr(relayManager)
}
peerConfig := loginResp.GetPeerConfig() peerConfig := loginResp.GetPeerConfig()
engineConfig, err := createEngineConfig(myPrivateKey, c.config, peerConfig) engineConfig, err := createEngineConfig(myPrivateKey, config, peerConfig)
if err != nil { if err != nil {
log.Error(err) log.Error(err)
return wrapErr(err) return wrapErr(err)
} }
checks := loginResp.GetChecks() engine := NewEngineWithProbes(engineCtx, cancel, signalClient, mgmClient, engineConfig, mobileDependency, statusRecorder, mgmProbe, signalProbe, relayProbe, wgProbe)
err = engine.Start()
c.engineMutex.Lock() if err != nil {
if c.engine != nil && c.engine.ctx.Err() != nil {
log.Info("Stopping Netbird Engine")
if err := c.engine.Stop(); err != nil {
log.Errorf("Failed to stop engine: %v", err)
}
}
c.engine = NewEngineWithProbes(engineCtx, cancel, signalClient, mgmClient, relayManager, engineConfig, mobileDependency, c.statusRecorder, probes, checks)
c.engineMutex.Unlock()
if err := c.engine.Start(); err != nil {
log.Errorf("error while starting Netbird Connection Engine: %s", err) log.Errorf("error while starting Netbird Connection Engine: %s", err)
return wrapErr(err) return wrapErr(err)
} }
if engineChan != nil {
log.Infof("Netbird engine started, the IP is: %s", peerConfig.GetAddress()) engineChan <- engine
state.Set(StatusConnected)
if runningChan != nil && runningChanOpen {
runningChan <- nil
close(runningChan)
runningChanOpen = false
} }
log.Print("Netbird engine started, my IP is: ", peerConfig.Address)
state.Set(StatusConnected)
<-engineCtx.Done() <-engineCtx.Done()
c.statusRecorder.ClientTeardown() statusRecorder.ClientTeardown()
backOff.Reset() backOff.Reset()
if engineChan != nil {
engineChan <- nil
}
err = engine.Stop()
if err != nil {
log.Errorf("failed stopping engine %v", err)
return wrapErr(err)
}
log.Info("stopped NetBird client") log.Info("stopped NetBird client")
if _, err := state.Status(); errors.Is(err, ErrResetConnection) { if _, err := state.Status(); errors.Is(err, ErrResetConnection) {
@@ -306,72 +276,20 @@ func (c *ConnectClient) run(
return nil return nil
} }
c.statusRecorder.ClientStart() statusRecorder.ClientStart()
err = backoff.Retry(operation, backOff) err = backoff.Retry(operation, backOff)
if err != nil { if err != nil {
log.Debugf("exiting client retry loop due to unrecoverable error: %s", err) log.Debugf("exiting client retry loop due to unrecoverable error: %s", err)
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.PermissionDenied) { if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.PermissionDenied) {
state.Set(StatusNeedsLogin) state.Set(StatusNeedsLogin)
_ = c.Stop()
} }
return err return err
} }
return nil return nil
} }
func parseRelayInfo(loginResp *mgmProto.LoginResponse) ([]string, *hmac.Token) {
relayCfg := loginResp.GetWiretrusteeConfig().GetRelay()
if relayCfg == nil {
return nil, nil
}
token := &hmac.Token{
Payload: relayCfg.GetTokenPayload(),
Signature: relayCfg.GetTokenSignature(),
}
return relayCfg.GetUrls(), token
}
func (c *ConnectClient) Engine() *Engine {
if c == nil {
return nil
}
var e *Engine
c.engineMutex.Lock()
e = c.engine
c.engineMutex.Unlock()
return e
}
func (c *ConnectClient) Stop() error {
if c == nil {
return nil
}
c.engineMutex.Lock()
defer c.engineMutex.Unlock()
if c.engine == nil {
return nil
}
return c.engine.Stop()
}
func (c *ConnectClient) isContextCancelled() bool {
select {
case <-c.ctx.Done():
return true
default:
return false
}
}
// createEngineConfig converts configuration received from Management Service to EngineConfig // createEngineConfig converts configuration received from Management Service to EngineConfig
func createEngineConfig(key wgtypes.Key, config *Config, peerConfig *mgmProto.PeerConfig) (*EngineConfig, error) { func createEngineConfig(key wgtypes.Key, config *Config, peerConfig *mgmProto.PeerConfig) (*EngineConfig, error) {
nm := false
if config.NetworkMonitor != nil {
nm = *config.NetworkMonitor
}
engineConf := &EngineConfig{ engineConf := &EngineConfig{
WgIfaceName: config.WgIface, WgIfaceName: config.WgIface,
WgAddr: peerConfig.Address, WgAddr: peerConfig.Address,
@@ -379,14 +297,12 @@ func createEngineConfig(key wgtypes.Key, config *Config, peerConfig *mgmProto.Pe
DisableIPv6Discovery: config.DisableIPv6Discovery, DisableIPv6Discovery: config.DisableIPv6Discovery,
WgPrivateKey: key, WgPrivateKey: key,
WgPort: config.WgPort, WgPort: config.WgPort,
NetworkMonitor: nm,
SSHKey: []byte(config.SSHKey), SSHKey: []byte(config.SSHKey),
NATExternalIPs: config.NATExternalIPs, NATExternalIPs: config.NATExternalIPs,
CustomDNSAddress: config.CustomDNSAddress, CustomDNSAddress: config.CustomDNSAddress,
RosenpassEnabled: config.RosenpassEnabled, RosenpassEnabled: config.RosenpassEnabled,
RosenpassPermissive: config.RosenpassPermissive, RosenpassPermissive: config.RosenpassPermissive,
ServerSSHAllowed: util.ReturnBoolWithDefaultTrue(config.ServerSSHAllowed), ServerSSHAllowed: util.ReturnBoolWithDefaultTrue(config.ServerSSHAllowed),
DNSRouteInterval: config.DNSRouteInterval,
} }
if config.PreSharedKey != "" { if config.PreSharedKey != "" {
@@ -397,15 +313,6 @@ func createEngineConfig(key wgtypes.Key, config *Config, peerConfig *mgmProto.Pe
engineConf.PreSharedKey = &preSharedKey engineConf.PreSharedKey = &preSharedKey
} }
port, err := freePort(config.WgPort)
if err != nil {
return nil, err
}
if port != config.WgPort {
log.Infof("using %d as wireguard port: %d is in use", port, config.WgPort)
}
engineConf.WgPort = port
return engineConf, nil return engineConf, nil
} }
@@ -455,44 +362,3 @@ func statusRecorderToSignalConnStateNotifier(statusRecorder *peer.Status) signal
notifier, _ := sri.(signal.ConnStateNotifier) notifier, _ := sri.(signal.ConnStateNotifier)
return notifier return notifier
} }
// freePort attempts to determine if the provided port is available, if not it will ask the system for a free port.
func freePort(initPort int) (int, error) {
addr := net.UDPAddr{}
if initPort == 0 {
initPort = iface.DefaultWgPort
}
addr.Port = initPort
conn, err := net.ListenUDP("udp", &addr)
if err == nil {
closeConnWithLog(conn)
return initPort, nil
}
// if the port is already in use, ask the system for a free port
addr.Port = 0
conn, err = net.ListenUDP("udp", &addr)
if err != nil {
return 0, fmt.Errorf("unable to get a free port: %v", err)
}
udpAddr, ok := conn.LocalAddr().(*net.UDPAddr)
if !ok {
return 0, errors.New("wrong address type when getting a free port")
}
closeConnWithLog(conn)
return udpAddr.Port, nil
}
func closeConnWithLog(conn *net.UDPConn) {
startClosing := time.Now()
err := conn.Close()
if err != nil {
log.Warnf("closing probe port %d failed: %v. NetBird will still attempt to use this port for connection.", conn.LocalAddr().(*net.UDPAddr).Port, err)
}
if time.Since(startClosing) > time.Second {
log.Warnf("closing the testing port %d took %s. Usually it is safe to ignore, but continuous warnings may indicate a problem.", conn.LocalAddr().(*net.UDPAddr).Port, time.Since(startClosing))
}
}

View File

@@ -1,61 +0,0 @@
package internal
import (
"net"
"testing"
)
func Test_freePort(t *testing.T) {
tests := []struct {
name string
port int
want int
shouldMatch bool
}{
{
name: "not provided, fallback to default",
port: 0,
want: 51820,
shouldMatch: true,
},
{
name: "provided and available",
port: 51821,
want: 51821,
shouldMatch: true,
},
{
name: "provided and not available",
port: 51830,
want: 51830,
shouldMatch: false,
},
}
c1, err := net.ListenUDP("udp", &net.UDPAddr{Port: 51830})
if err != nil {
t.Errorf("freePort error = %v", err)
}
defer func(c1 *net.UDPConn) {
_ = c1.Close()
}(c1)
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := freePort(tt.port)
if err != nil {
t.Errorf("got an error while getting free port: %v", err)
}
if tt.shouldMatch && got != tt.want {
t.Errorf("got a different port %v, want %v", got, tt.want)
}
if !tt.shouldMatch && got == tt.want {
t.Errorf("got the same port %v, want a different port", tt.want)
}
})
}
}

View File

@@ -1,6 +0,0 @@
package dns
const (
fileUncleanShutdownResolvConfLocation = "/var/db/netbird/resolv.conf"
fileUncleanShutdownManagerTypeLocation = "/var/db/netbird/manager"
)

View File

@@ -1,8 +0,0 @@
//go:build !android
package dns
const (
fileUncleanShutdownResolvConfLocation = "/var/lib/netbird/resolv.conf"
fileUncleanShutdownManagerTypeLocation = "/var/lib/netbird/manager"
)

View File

@@ -1,4 +1,4 @@
//go:build (linux && !android) || freebsd //go:build !android
package dns package dns

View File

@@ -1,4 +1,4 @@
//go:build (linux && !android) || freebsd //go:build !android
package dns package dns
@@ -47,20 +47,24 @@ func (f *fileConfigurator) supportCustomPort() bool {
} }
func (f *fileConfigurator) applyDNSConfig(config HostDNSConfig) error { func (f *fileConfigurator) applyDNSConfig(config HostDNSConfig) error {
backupFileExist := f.isBackupFileExist() backupFileExist := false
_, err := os.Stat(fileDefaultResolvConfBackupLocation)
if err == nil {
backupFileExist = true
}
if !config.RouteAll { if !config.RouteAll {
if backupFileExist { if backupFileExist {
f.repair.stopWatchFileChanges() err = f.restore()
err := f.restore()
if err != nil { if err != nil {
return fmt.Errorf("restoring the original resolv.conf file return err: %w", err) return fmt.Errorf("unable to configure DNS for this peer using file manager without a Primary nameserver group. Restoring the original file return err: %w", err)
} }
} }
return fmt.Errorf("unable to configure DNS for this peer using file manager without a nameserver group with all domains configured") return fmt.Errorf("unable to configure DNS for this peer using file manager without a nameserver group with all domains configured")
} }
if !backupFileExist { if !backupFileExist {
err := f.backup() err = f.backup()
if err != nil { if err != nil {
return fmt.Errorf("unable to backup the resolv.conf file: %w", err) return fmt.Errorf("unable to backup the resolv.conf file: %w", err)
} }
@@ -180,11 +184,6 @@ func (f *fileConfigurator) restoreUncleanShutdownDNS(storedDNSAddress *netip.Add
return nil return nil
} }
func (f *fileConfigurator) isBackupFileExist() bool {
_, err := os.Stat(fileDefaultResolvConfBackupLocation)
return err == nil
}
func restoreResolvConfFile() error { func restoreResolvConfFile() error {
log.Debugf("restoring unclean shutdown: restoring %s from %s", defaultResolvConfPath, fileUncleanShutdownResolvConfLocation) log.Debugf("restoring unclean shutdown: restoring %s from %s", defaultResolvConfPath, fileUncleanShutdownResolvConfLocation)

View File

@@ -1,4 +1,4 @@
//go:build (linux && !android) || freebsd //go:build !android
package dns package dns

View File

@@ -1,4 +1,4 @@
//go:build (linux && !android) || freebsd //go:build !android
package dns package dns

View File

@@ -1,4 +1,4 @@
//go:build (linux && !android) || freebsd //go:build !android
package dns package dns

View File

@@ -1,4 +1,4 @@
//go:build (linux && !android) || freebsd //go:build !android
package dns package dns

View File

@@ -1,4 +1,4 @@
//go:build (linux && !android) || freebsd //go:build !android
package dns package dns

View File

@@ -15,12 +15,6 @@ type hostManager interface {
restoreUncleanShutdownDNS(storedDNSAddress *netip.Addr) error restoreUncleanShutdownDNS(storedDNSAddress *netip.Addr) error
} }
type SystemDNSSettings struct {
Domains []string
ServerIP string
ServerPort int
}
type HostDNSConfig struct { type HostDNSConfig struct {
Domains []DomainConfig `json:"domains"` Domains []DomainConfig `json:"domains"`
RouteAll bool `json:"routeAll"` RouteAll bool `json:"routeAll"`

View File

@@ -7,7 +7,6 @@ import (
"bytes" "bytes"
"fmt" "fmt"
"io" "io"
"net"
"net/netip" "net/netip"
"os/exec" "os/exec"
"strconv" "strconv"
@@ -19,7 +18,7 @@ import (
const ( const (
netbirdDNSStateKeyFormat = "State:/Network/Service/NetBird-%s/DNS" netbirdDNSStateKeyFormat = "State:/Network/Service/NetBird-%s/DNS"
globalIPv4State = "State:/Network/Global/IPv4" globalIPv4State = "State:/Network/Global/IPv4"
primaryServiceStateKeyFormat = "State:/Network/Service/%s/DNS" primaryServiceSetupKeyFormat = "Setup:/Network/Service/%s/DNS"
keySupplementalMatchDomains = "SupplementalMatchDomains" keySupplementalMatchDomains = "SupplementalMatchDomains"
keySupplementalMatchDomainsNoSearch = "SupplementalMatchDomainsNoSearch" keySupplementalMatchDomainsNoSearch = "SupplementalMatchDomainsNoSearch"
keyServerAddresses = "ServerAddresses" keyServerAddresses = "ServerAddresses"
@@ -29,12 +28,12 @@ const (
scutilPath = "/usr/sbin/scutil" scutilPath = "/usr/sbin/scutil"
searchSuffix = "Search" searchSuffix = "Search"
matchSuffix = "Match" matchSuffix = "Match"
localSuffix = "Local"
) )
type systemConfigurator struct { type systemConfigurator struct {
createdKeys map[string]struct{} // primaryServiceID primary interface in the system. AKA the interface with the default route
systemDNSSettings SystemDNSSettings primaryServiceID string
createdKeys map[string]struct{}
} }
func newHostManager() (hostManager, error) { func newHostManager() (hostManager, error) {
@@ -50,6 +49,20 @@ func (s *systemConfigurator) supportCustomPort() bool {
func (s *systemConfigurator) applyDNSConfig(config HostDNSConfig) error { func (s *systemConfigurator) applyDNSConfig(config HostDNSConfig) error {
var err error var err error
if config.RouteAll {
err = s.addDNSSetupForAll(config.ServerIP, config.ServerPort)
if err != nil {
return fmt.Errorf("add dns setup for all: %w", err)
}
} else if s.primaryServiceID != "" {
err = s.removeKeyFromSystemConfig(getKeyWithInput(primaryServiceSetupKeyFormat, s.primaryServiceID))
if err != nil {
return fmt.Errorf("remote key from system config: %w", err)
}
s.primaryServiceID = ""
log.Infof("removed %s:%d as main DNS resolver for this peer", config.ServerIP, config.ServerPort)
}
// create a file for unclean shutdown detection // create a file for unclean shutdown detection
if err := createUncleanShutdownIndicator(); err != nil { if err := createUncleanShutdownIndicator(); err != nil {
log.Errorf("failed to create unclean shutdown file: %s", err) log.Errorf("failed to create unclean shutdown file: %s", err)
@@ -60,19 +73,6 @@ func (s *systemConfigurator) applyDNSConfig(config HostDNSConfig) error {
matchDomains []string matchDomains []string
) )
err = s.recordSystemDNSSettings(true)
if err != nil {
log.Errorf("unable to update record of System's DNS config: %s", err.Error())
}
if config.RouteAll {
searchDomains = append(searchDomains, "\"\"")
err = s.addLocalDNS()
if err != nil {
log.Infof("failed to enable split DNS")
}
}
for _, dConf := range config.Domains { for _, dConf := range config.Domains {
if dConf.Disabled { if dConf.Disabled {
continue continue
@@ -110,17 +110,23 @@ func (s *systemConfigurator) applyDNSConfig(config HostDNSConfig) error {
} }
func (s *systemConfigurator) restoreHostDNS() error { func (s *systemConfigurator) restoreHostDNS() error {
keys := s.getRemovableKeysWithDefaults() lines := ""
for _, key := range keys { for key := range s.createdKeys {
lines += buildRemoveKeyOperation(key)
keyType := "search" keyType := "search"
if strings.Contains(key, matchSuffix) { if strings.Contains(key, matchSuffix) {
keyType = "match" keyType = "match"
} }
log.Infof("removing %s domains from system", keyType) log.Infof("removing %s domains from system", keyType)
err := s.removeKeyFromSystemConfig(key) }
if err != nil { if s.primaryServiceID != "" {
log.Errorf("failed to remove %s domains from system: %s", keyType, err) lines += buildRemoveKeyOperation(getKeyWithInput(primaryServiceSetupKeyFormat, s.primaryServiceID))
} log.Infof("restoring DNS resolver configuration for system")
}
_, err := runSystemConfigCommand(wrapCommand(lines))
if err != nil {
log.Errorf("got an error while cleaning the system configuration: %s", err)
return fmt.Errorf("clean system: %w", err)
} }
if err := removeUncleanShutdownIndicator(); err != nil { if err := removeUncleanShutdownIndicator(); err != nil {
@@ -130,19 +136,6 @@ func (s *systemConfigurator) restoreHostDNS() error {
return nil return nil
} }
func (s *systemConfigurator) getRemovableKeysWithDefaults() []string {
if len(s.createdKeys) == 0 {
// return defaults for startup calls
return []string{getKeyWithInput(netbirdDNSStateKeyFormat, searchSuffix), getKeyWithInput(netbirdDNSStateKeyFormat, matchSuffix)}
}
keys := make([]string, 0, len(s.createdKeys))
for key := range s.createdKeys {
keys = append(keys, key)
}
return keys
}
func (s *systemConfigurator) removeKeyFromSystemConfig(key string) error { func (s *systemConfigurator) removeKeyFromSystemConfig(key string) error {
line := buildRemoveKeyOperation(key) line := buildRemoveKeyOperation(key)
_, err := runSystemConfigCommand(wrapCommand(line)) _, err := runSystemConfigCommand(wrapCommand(line))
@@ -155,97 +148,6 @@ func (s *systemConfigurator) removeKeyFromSystemConfig(key string) error {
return nil return nil
} }
func (s *systemConfigurator) addLocalDNS() error {
if s.systemDNSSettings.ServerIP == "" || len(s.systemDNSSettings.Domains) == 0 {
err := s.recordSystemDNSSettings(true)
log.Errorf("Unable to get system DNS configuration")
return err
}
localKey := getKeyWithInput(netbirdDNSStateKeyFormat, localSuffix)
if s.systemDNSSettings.ServerIP != "" && len(s.systemDNSSettings.Domains) != 0 {
err := s.addSearchDomains(localKey, strings.Join(s.systemDNSSettings.Domains, " "), s.systemDNSSettings.ServerIP, s.systemDNSSettings.ServerPort)
if err != nil {
return fmt.Errorf("couldn't add local network DNS conf: %w", err)
}
} else {
log.Info("Not enabling local DNS server")
}
return nil
}
func (s *systemConfigurator) recordSystemDNSSettings(force bool) error {
if s.systemDNSSettings.ServerIP != "" && len(s.systemDNSSettings.Domains) != 0 && !force {
return nil
}
systemDNSSettings, err := s.getSystemDNSSettings()
if err != nil {
return fmt.Errorf("couldn't get current DNS config: %w", err)
}
s.systemDNSSettings = systemDNSSettings
return nil
}
func (s *systemConfigurator) getSystemDNSSettings() (SystemDNSSettings, error) {
primaryServiceKey, _, err := s.getPrimaryService()
if err != nil || primaryServiceKey == "" {
return SystemDNSSettings{}, fmt.Errorf("couldn't find the primary service key: %w", err)
}
dnsServiceKey := getKeyWithInput(primaryServiceStateKeyFormat, primaryServiceKey)
line := buildCommandLine("show", dnsServiceKey, "")
stdinCommands := wrapCommand(line)
b, err := runSystemConfigCommand(stdinCommands)
if err != nil {
return SystemDNSSettings{}, fmt.Errorf("sending the command: %w", err)
}
var dnsSettings SystemDNSSettings
inSearchDomainsArray := false
inServerAddressesArray := false
scanner := bufio.NewScanner(bytes.NewReader(b))
for scanner.Scan() {
line := strings.TrimSpace(scanner.Text())
switch {
case strings.HasPrefix(line, "DomainName :"):
domainName := strings.TrimSpace(strings.Split(line, ":")[1])
dnsSettings.Domains = append(dnsSettings.Domains, domainName)
case line == "SearchDomains : <array> {":
inSearchDomainsArray = true
continue
case line == "ServerAddresses : <array> {":
inServerAddressesArray = true
continue
case line == "}":
inSearchDomainsArray = false
inServerAddressesArray = false
}
if inSearchDomainsArray {
searchDomain := strings.Split(line, " : ")[1]
dnsSettings.Domains = append(dnsSettings.Domains, searchDomain)
} else if inServerAddressesArray {
address := strings.Split(line, " : ")[1]
if ip := net.ParseIP(address); ip != nil && ip.To4() != nil {
dnsSettings.ServerIP = address
inServerAddressesArray = false // Stop reading after finding the first IPv4 address
}
}
}
if err := scanner.Err(); err != nil {
return dnsSettings, err
}
// default to 53 port
dnsSettings.ServerPort = 53
return dnsSettings, nil
}
func (s *systemConfigurator) addSearchDomains(key, domains string, ip string, port int) error { func (s *systemConfigurator) addSearchDomains(key, domains string, ip string, port int) error {
err := s.addDNSState(key, domains, ip, port, true) err := s.addDNSState(key, domains, ip, port, true)
if err != nil { if err != nil {
@@ -292,6 +194,23 @@ func (s *systemConfigurator) addDNSState(state, domains, dnsServer string, port
return nil return nil
} }
func (s *systemConfigurator) addDNSSetupForAll(dnsServer string, port int) error {
primaryServiceKey, existingNameserver, err := s.getPrimaryService()
if err != nil || primaryServiceKey == "" {
return fmt.Errorf("couldn't find the primary service key: %w", err)
}
err = s.addDNSSetup(getKeyWithInput(primaryServiceSetupKeyFormat, primaryServiceKey), dnsServer, port, existingNameserver)
if err != nil {
return fmt.Errorf("add dns setup: %w", err)
}
log.Infof("configured %s:%d as main DNS resolver for this peer", dnsServer, port)
s.primaryServiceID = primaryServiceKey
return nil
}
func (s *systemConfigurator) getPrimaryService() (string, string, error) { func (s *systemConfigurator) getPrimaryService() (string, string, error) {
line := buildCommandLine("show", globalIPv4State, "") line := buildCommandLine("show", globalIPv4State, "")
stdinCommands := wrapCommand(line) stdinCommands := wrapCommand(line)
@@ -320,6 +239,19 @@ func (s *systemConfigurator) getPrimaryService() (string, string, error) {
return primaryService, router, nil return primaryService, router, nil
} }
func (s *systemConfigurator) addDNSSetup(setupKey, dnsServer string, port int, existingDNSServer string) error {
lines := buildAddCommandLine(keySupplementalMatchDomainsNoSearch, digitSymbol+strconv.Itoa(0))
lines += buildAddCommandLine(keyServerAddresses, arraySymbol+dnsServer+" "+existingDNSServer)
lines += buildAddCommandLine(keyServerPort, digitSymbol+strconv.Itoa(port))
addDomainCommand := buildCreateStateWithOperation(setupKey, lines)
stdinCommands := wrapCommand(addDomainCommand)
_, err := runSystemConfigCommand(stdinCommands)
if err != nil {
return fmt.Errorf("applying dns setup, error: %w", err)
}
return nil
}
func (s *systemConfigurator) restoreUncleanShutdownDNS(*netip.Addr) error { func (s *systemConfigurator) restoreUncleanShutdownDNS(*netip.Addr) error {
if err := s.restoreHostDNS(); err != nil { if err := s.restoreHostDNS(); err != nil {
return fmt.Errorf("restoring dns via scutil: %w", err) return fmt.Errorf("restoring dns via scutil: %w", err)

View File

@@ -1,4 +1,4 @@
//go:build (linux && !android) || freebsd //go:build !android
package dns package dns
@@ -108,7 +108,7 @@ func getOSDNSManagerType() (osManagerType, error) {
if strings.Contains(text, "NetworkManager") && isDbusListenerRunning(networkManagerDest, networkManagerDbusObjectNode) && isNetworkManagerSupported() { if strings.Contains(text, "NetworkManager") && isDbusListenerRunning(networkManagerDest, networkManagerDbusObjectNode) && isNetworkManagerSupported() {
return networkManager, nil return networkManager, nil
} }
if strings.Contains(text, "systemd-resolved") && isSystemdResolvedRunning() { if strings.Contains(text, "systemd-resolved") && isDbusListenerRunning(systemdResolvedDest, systemdDbusObjectNode) {
if checkStub() { if checkStub() {
return systemdManager, nil return systemdManager, nil
} else { } else {
@@ -116,10 +116,16 @@ func getOSDNSManagerType() (osManagerType, error) {
} }
} }
if strings.Contains(text, "resolvconf") { if strings.Contains(text, "resolvconf") {
if isSystemdResolveConfMode() { if isDbusListenerRunning(systemdResolvedDest, systemdDbusObjectNode) {
return systemdManager, nil var value string
err = getSystemdDbusProperty(systemdDbusResolvConfModeProperty, &value)
if err == nil {
if value == systemdDbusResolvConfModeForeign {
return systemdManager, nil
}
}
log.Errorf("got an error while checking systemd resolv conf mode, error: %s", err)
} }
return resolvConfManager, nil return resolvConfManager, nil
} }
} }

View File

@@ -1,63 +0,0 @@
package dns
import (
"fmt"
"net/netip"
"sync"
log "github.com/sirupsen/logrus"
)
type hostsDNSHolder struct {
unprotectedDNSList map[string]struct{}
mutex sync.RWMutex
}
func newHostsDNSHolder() *hostsDNSHolder {
return &hostsDNSHolder{
unprotectedDNSList: make(map[string]struct{}),
}
}
func (h *hostsDNSHolder) set(list []string) {
h.mutex.Lock()
h.unprotectedDNSList = make(map[string]struct{})
for _, dns := range list {
dnsAddr, err := h.normalizeAddress(dns)
if err != nil {
continue
}
h.unprotectedDNSList[dnsAddr] = struct{}{}
}
h.mutex.Unlock()
}
func (h *hostsDNSHolder) get() map[string]struct{} {
h.mutex.RLock()
l := h.unprotectedDNSList
h.mutex.RUnlock()
return l
}
//nolint:unused
func (h *hostsDNSHolder) isContain(upstream string) bool {
h.mutex.RLock()
defer h.mutex.RUnlock()
_, ok := h.unprotectedDNSList[upstream]
return ok
}
func (h *hostsDNSHolder) normalizeAddress(addr string) (string, error) {
a, err := netip.ParseAddr(addr)
if err != nil {
log.Errorf("invalid upstream IP address: %s, error: %s", addr, err)
return "", err
}
if a.Is4() {
return fmt.Sprintf("%s:53", addr), nil
} else {
return fmt.Sprintf("[%s]:53", addr), nil
}
}

View File

@@ -1,4 +1,4 @@
//go:build (linux && !android) || freebsd //go:build !android
package dns package dns

View File

@@ -1,4 +1,4 @@
//go:build (linux && !android) || freebsd //go:build !android
package dns package dns

View File

@@ -4,7 +4,6 @@ import (
"context" "context"
"fmt" "fmt"
"net/netip" "net/netip"
"runtime"
"strings" "strings"
"sync" "sync"
@@ -55,8 +54,9 @@ type DefaultServer struct {
currentConfig HostDNSConfig currentConfig HostDNSConfig
// permanent related properties // permanent related properties
permanent bool permanent bool
hostsDNSHolder *hostsDNSHolder hostsDnsList []string
hostsDnsListLock sync.Mutex
// make sense on mobile only // make sense on mobile only
searchDomainNotifier *notifier searchDomainNotifier *notifier
@@ -94,7 +94,7 @@ func NewDefaultServer(
var dnsService service var dnsService service
if wgInterface.IsUserspaceBind() { if wgInterface.IsUserspaceBind() {
dnsService = NewServiceViaMemory(wgInterface) dnsService = newServiceViaMemory(wgInterface)
} else { } else {
dnsService = newServiceViaListener(wgInterface, addrPort) dnsService = newServiceViaListener(wgInterface, addrPort)
} }
@@ -112,9 +112,9 @@ func NewDefaultServerPermanentUpstream(
statusRecorder *peer.Status, statusRecorder *peer.Status,
) *DefaultServer { ) *DefaultServer {
log.Debugf("host dns address list is: %v", hostsDnsList) log.Debugf("host dns address list is: %v", hostsDnsList)
ds := newDefaultServer(ctx, wgInterface, NewServiceViaMemory(wgInterface), statusRecorder) ds := newDefaultServer(ctx, wgInterface, newServiceViaMemory(wgInterface), statusRecorder)
ds.hostsDNSHolder.set(hostsDnsList)
ds.permanent = true ds.permanent = true
ds.hostsDnsList = hostsDnsList
ds.addHostRootZone() ds.addHostRootZone()
ds.currentConfig = dnsConfigToHostDNSConfig(config, ds.service.RuntimeIP(), ds.service.RuntimePort()) ds.currentConfig = dnsConfigToHostDNSConfig(config, ds.service.RuntimeIP(), ds.service.RuntimePort())
ds.searchDomainNotifier = newNotifier(ds.SearchDomains()) ds.searchDomainNotifier = newNotifier(ds.SearchDomains())
@@ -130,7 +130,7 @@ func NewDefaultServerIos(
iosDnsManager IosDnsManager, iosDnsManager IosDnsManager,
statusRecorder *peer.Status, statusRecorder *peer.Status,
) *DefaultServer { ) *DefaultServer {
ds := newDefaultServer(ctx, wgInterface, NewServiceViaMemory(wgInterface), statusRecorder) ds := newDefaultServer(ctx, wgInterface, newServiceViaMemory(wgInterface), statusRecorder)
ds.iosDnsManager = iosDnsManager ds.iosDnsManager = iosDnsManager
return ds return ds
} }
@@ -147,7 +147,6 @@ func newDefaultServer(ctx context.Context, wgInterface WGIface, dnsService servi
}, },
wgInterface: wgInterface, wgInterface: wgInterface,
statusRecorder: statusRecorder, statusRecorder: statusRecorder,
hostsDNSHolder: newHostsDNSHolder(),
} }
return defaultServer return defaultServer
@@ -203,8 +202,10 @@ func (s *DefaultServer) Stop() {
// OnUpdatedHostDNSServer update the DNS servers addresses for root zones // OnUpdatedHostDNSServer update the DNS servers addresses for root zones
// It will be applied if the mgm server do not enforce DNS settings for root zone // It will be applied if the mgm server do not enforce DNS settings for root zone
func (s *DefaultServer) OnUpdatedHostDNSServer(hostsDnsList []string) { func (s *DefaultServer) OnUpdatedHostDNSServer(hostsDnsList []string) {
s.hostsDNSHolder.set(hostsDnsList) s.hostsDnsListLock.Lock()
defer s.hostsDnsListLock.Unlock()
s.hostsDnsList = hostsDnsList
_, ok := s.dnsMuxMap[nbdns.RootZone] _, ok := s.dnsMuxMap[nbdns.RootZone]
if ok { if ok {
log.Debugf("on new host DNS config but skip to apply it") log.Debugf("on new host DNS config but skip to apply it")
@@ -373,7 +374,6 @@ func (s *DefaultServer) buildUpstreamHandlerUpdate(nameServerGroups []*nbdns.Nam
s.wgInterface.Address().IP, s.wgInterface.Address().IP,
s.wgInterface.Address().Network, s.wgInterface.Address().Network,
s.statusRecorder, s.statusRecorder,
s.hostsDNSHolder,
) )
if err != nil { if err != nil {
return nil, fmt.Errorf("unable to create a new upstream resolver, error: %v", err) return nil, fmt.Errorf("unable to create a new upstream resolver, error: %v", err)
@@ -452,7 +452,9 @@ func (s *DefaultServer) updateMux(muxUpdates []muxUpdate) {
_, found := muxUpdateMap[key] _, found := muxUpdateMap[key]
if !found { if !found {
if !isContainRootUpdate && key == nbdns.RootZone { if !isContainRootUpdate && key == nbdns.RootZone {
s.hostsDnsListLock.Lock()
s.addHostRootZone() s.addHostRootZone()
s.hostsDnsListLock.Unlock()
existingHandler.stop() existingHandler.stop()
} else { } else {
existingHandler.stop() existingHandler.stop()
@@ -510,7 +512,6 @@ func (s *DefaultServer) upstreamCallbacks(
if nsGroup.Primary { if nsGroup.Primary {
removeIndex[nbdns.RootZone] = -1 removeIndex[nbdns.RootZone] = -1
s.currentConfig.RouteAll = false s.currentConfig.RouteAll = false
s.service.DeregisterMux(nbdns.RootZone)
} }
for i, item := range s.currentConfig.Domains { for i, item := range s.currentConfig.Domains {
@@ -520,15 +521,10 @@ func (s *DefaultServer) upstreamCallbacks(
removeIndex[item.Domain] = i removeIndex[item.Domain] = i
} }
} }
if err := s.hostManager.applyDNSConfig(s.currentConfig); err != nil { if err := s.hostManager.applyDNSConfig(s.currentConfig); err != nil {
l.Errorf("Failed to apply nameserver deactivation on the host: %v", err) l.Errorf("Failed to apply nameserver deactivation on the host: %v", err)
} }
if runtime.GOOS == "android" && nsGroup.Primary && len(s.hostsDNSHolder.get()) > 0 {
s.addHostRootZone()
}
s.updateNSState(nsGroup, err, false) s.updateNSState(nsGroup, err, false)
} }
@@ -549,7 +545,6 @@ func (s *DefaultServer) upstreamCallbacks(
if nsGroup.Primary { if nsGroup.Primary {
s.currentConfig.RouteAll = true s.currentConfig.RouteAll = true
s.service.RegisterMux(nbdns.RootZone, handler)
} }
if err := s.hostManager.applyDNSConfig(s.currentConfig); err != nil { if err := s.hostManager.applyDNSConfig(s.currentConfig); err != nil {
l.WithError(err).Error("reactivate temporary disabled nameserver group, DNS update apply") l.WithError(err).Error("reactivate temporary disabled nameserver group, DNS update apply")
@@ -567,16 +562,25 @@ func (s *DefaultServer) addHostRootZone() {
s.wgInterface.Address().IP, s.wgInterface.Address().IP,
s.wgInterface.Address().Network, s.wgInterface.Address().Network,
s.statusRecorder, s.statusRecorder,
s.hostsDNSHolder,
) )
if err != nil { if err != nil {
log.Errorf("unable to create a new upstream resolver, error: %v", err) log.Errorf("unable to create a new upstream resolver, error: %v", err)
return return
} }
handler.upstreamServers = make([]string, len(s.hostsDnsList))
for n, ua := range s.hostsDnsList {
a, err := netip.ParseAddr(ua)
if err != nil {
log.Errorf("invalid upstream IP address: %s, error: %s", ua, err)
continue
}
handler.upstreamServers = make([]string, 0) ipString := ua
for k := range s.hostsDNSHolder.get() { if !a.Is4() {
handler.upstreamServers = append(handler.upstreamServers, k) ipString = fmt.Sprintf("[%s]", ua)
}
handler.upstreamServers[n] = fmt.Sprintf("%s:53", ipString)
} }
handler.deactivate = func(error) {} handler.deactivate = func(error) {}
handler.reactivate = func() {} handler.reactivate = func() {}

View File

@@ -1,4 +1,4 @@
//go:build (linux && !android) || freebsd //go:build !android
package dns package dns

View File

@@ -39,10 +39,6 @@ func (w *mocWGIface) Address() iface.WGAddress {
} }
} }
func (w *mocWGIface) ToInterface() *net.Interface {
panic("implement me")
}
func (w *mocWGIface) GetFilter() iface.PacketFilter { func (w *mocWGIface) GetFilter() iface.PacketFilter {
return w.filter return w.filter
} }
@@ -265,7 +261,7 @@ func TestUpdateDNSServer(t *testing.T) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
wgIface, err := iface.NewWGIFace(fmt.Sprintf("utun230%d", n), fmt.Sprintf("100.66.100.%d/32", n+1), 33100, privKey.String(), iface.DefaultMTU, newNet, nil, nil) wgIface, err := iface.NewWGIFace(fmt.Sprintf("utun230%d", n), fmt.Sprintf("100.66.100.%d/32", n+1), 33100, privKey.String(), iface.DefaultMTU, newNet, nil)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -343,7 +339,7 @@ func TestDNSFakeResolverHandleUpdates(t *testing.T) {
} }
privKey, _ := wgtypes.GeneratePrivateKey() privKey, _ := wgtypes.GeneratePrivateKey()
wgIface, err := iface.NewWGIFace("utun2301", "100.66.100.1/32", 33100, privKey.String(), iface.DefaultMTU, newNet, nil, nil) wgIface, err := iface.NewWGIFace("utun2301", "100.66.100.1/32", 33100, privKey.String(), iface.DefaultMTU, newNet, nil)
if err != nil { if err != nil {
t.Errorf("build interface wireguard: %v", err) t.Errorf("build interface wireguard: %v", err)
return return
@@ -534,7 +530,7 @@ func TestDNSServerStartStop(t *testing.T) {
func TestDNSServerUpstreamDeactivateCallback(t *testing.T) { func TestDNSServerUpstreamDeactivateCallback(t *testing.T) {
hostManager := &mockHostConfigurator{} hostManager := &mockHostConfigurator{}
server := DefaultServer{ server := DefaultServer{
service: NewServiceViaMemory(&mocWGIface{}), service: newServiceViaMemory(&mocWGIface{}),
localResolver: &localResolver{ localResolver: &localResolver{
registeredMap: make(registrationMap), registeredMap: make(registrationMap),
}, },
@@ -801,7 +797,7 @@ func createWgInterfaceWithBind(t *testing.T) (*iface.WGIface, error) {
} }
privKey, _ := wgtypes.GeneratePrivateKey() privKey, _ := wgtypes.GeneratePrivateKey()
wgIface, err := iface.NewWGIFace("utun2301", "100.66.100.2/24", 33100, privKey.String(), iface.DefaultMTU, newNet, nil, nil) wgIface, err := iface.NewWGIFace("utun2301", "100.66.100.2/24", 33100, privKey.String(), iface.DefaultMTU, newNet, nil)
if err != nil { if err != nil {
t.Fatalf("build interface wireguard: %v", err) t.Fatalf("build interface wireguard: %v", err)
return nil, err return nil, err

View File

@@ -128,9 +128,6 @@ func (s *serviceViaListener) RuntimeIP() string {
} }
func (s *serviceViaListener) setListenerStatus(running bool) { func (s *serviceViaListener) setListenerStatus(running bool) {
s.listenerFlagLock.Lock()
defer s.listenerFlagLock.Unlock()
s.listenerIsRunning = running s.listenerIsRunning = running
} }

View File

@@ -12,7 +12,7 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
) )
type ServiceViaMemory struct { type serviceViaMemory struct {
wgInterface WGIface wgInterface WGIface
dnsMux *dns.ServeMux dnsMux *dns.ServeMux
runtimeIP string runtimeIP string
@@ -22,8 +22,8 @@ type ServiceViaMemory struct {
listenerFlagLock sync.Mutex listenerFlagLock sync.Mutex
} }
func NewServiceViaMemory(wgIface WGIface) *ServiceViaMemory { func newServiceViaMemory(wgIface WGIface) *serviceViaMemory {
s := &ServiceViaMemory{ s := &serviceViaMemory{
wgInterface: wgIface, wgInterface: wgIface,
dnsMux: dns.NewServeMux(), dnsMux: dns.NewServeMux(),
@@ -33,7 +33,7 @@ func NewServiceViaMemory(wgIface WGIface) *ServiceViaMemory {
return s return s
} }
func (s *ServiceViaMemory) Listen() error { func (s *serviceViaMemory) Listen() error {
s.listenerFlagLock.Lock() s.listenerFlagLock.Lock()
defer s.listenerFlagLock.Unlock() defer s.listenerFlagLock.Unlock()
@@ -52,7 +52,7 @@ func (s *ServiceViaMemory) Listen() error {
return nil return nil
} }
func (s *ServiceViaMemory) Stop() { func (s *serviceViaMemory) Stop() {
s.listenerFlagLock.Lock() s.listenerFlagLock.Lock()
defer s.listenerFlagLock.Unlock() defer s.listenerFlagLock.Unlock()
@@ -67,23 +67,23 @@ func (s *ServiceViaMemory) Stop() {
s.listenerIsRunning = false s.listenerIsRunning = false
} }
func (s *ServiceViaMemory) RegisterMux(pattern string, handler dns.Handler) { func (s *serviceViaMemory) RegisterMux(pattern string, handler dns.Handler) {
s.dnsMux.Handle(pattern, handler) s.dnsMux.Handle(pattern, handler)
} }
func (s *ServiceViaMemory) DeregisterMux(pattern string) { func (s *serviceViaMemory) DeregisterMux(pattern string) {
s.dnsMux.HandleRemove(pattern) s.dnsMux.HandleRemove(pattern)
} }
func (s *ServiceViaMemory) RuntimePort() int { func (s *serviceViaMemory) RuntimePort() int {
return s.runtimePort return s.runtimePort
} }
func (s *ServiceViaMemory) RuntimeIP() string { func (s *serviceViaMemory) RuntimeIP() string {
return s.runtimeIP return s.runtimeIP
} }
func (s *ServiceViaMemory) filterDNSTraffic() (string, error) { func (s *serviceViaMemory) filterDNSTraffic() (string, error) {
filter := s.wgInterface.GetFilter() filter := s.wgInterface.GetFilter()
if filter == nil { if filter == nil {
return "", fmt.Errorf("can't set DNS filter, filter not initialized") return "", fmt.Errorf("can't set DNS filter, filter not initialized")

View File

@@ -1,20 +0,0 @@
package dns
import (
"errors"
"fmt"
)
var errNotImplemented = errors.New("not implemented")
func newSystemdDbusConfigurator(wgInterface string) (hostManager, error) {
return nil, fmt.Errorf("systemd dns management: %w on freebsd", errNotImplemented)
}
func isSystemdResolvedRunning() bool {
return false
}
func isSystemdResolveConfMode() bool {
return false
}

View File

@@ -242,25 +242,3 @@ func getSystemdDbusProperty(property string, store any) error {
return v.Store(store) return v.Store(store)
} }
func isSystemdResolvedRunning() bool {
return isDbusListenerRunning(systemdResolvedDest, systemdDbusObjectNode)
}
func isSystemdResolveConfMode() bool {
if !isDbusListenerRunning(systemdResolvedDest, systemdDbusObjectNode) {
return false
}
var value string
if err := getSystemdDbusProperty(systemdDbusResolvConfModeProperty, &value); err != nil {
log.Errorf("got an error while checking systemd resolv conf mode, error: %s", err)
return false
}
if value == systemdDbusResolvConfModeForeign {
return true
}
return false
}

View File

@@ -1,4 +1,4 @@
//go:build (linux && !android) || freebsd //go:build !android
package dns package dns
@@ -14,6 +14,11 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
) )
const (
fileUncleanShutdownResolvConfLocation = "/var/lib/netbird/resolv.conf"
fileUncleanShutdownManagerTypeLocation = "/var/lib/netbird/manager"
)
func CheckUncleanShutdown(wgIface string) error { func CheckUncleanShutdown(wgIface string) error {
if _, err := os.Stat(fileUncleanShutdownResolvConfLocation); err != nil { if _, err := os.Stat(fileUncleanShutdownResolvConfLocation); err != nil {
if errors.Is(err, fs.ErrNotExist) { if errors.Is(err, fs.ErrNotExist) {

View File

@@ -5,6 +5,7 @@ import (
"errors" "errors"
"fmt" "fmt"
"net" "net"
"runtime"
"sync" "sync"
"sync/atomic" "sync/atomic"
"time" "time"
@@ -24,7 +25,7 @@ const (
probeTimeout = 2 * time.Second probeTimeout = 2 * time.Second
) )
const testRecord = "com." const testRecord = "."
type upstreamClient interface { type upstreamClient interface {
exchange(ctx context.Context, upstream string, r *dns.Msg) (*dns.Msg, time.Duration, error) exchange(ctx context.Context, upstream string, r *dns.Msg) (*dns.Msg, time.Duration, error)
@@ -42,7 +43,6 @@ type upstreamResolverBase struct {
upstreamServers []string upstreamServers []string
disabled bool disabled bool
failsCount atomic.Int32 failsCount atomic.Int32
successCount atomic.Int32
failsTillDeact int32 failsTillDeact int32
mutex sync.Mutex mutex sync.Mutex
reactivatePeriod time.Duration reactivatePeriod time.Duration
@@ -79,11 +79,6 @@ func (u *upstreamResolverBase) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
}() }()
log.WithField("question", r.Question[0]).Trace("received an upstream question") log.WithField("question", r.Question[0]).Trace("received an upstream question")
// set the AuthenticatedData flag and the EDNS0 buffer size to 4096 bytes to support larger dns records
if r.Extra == nil {
r.SetEdns0(4096, false)
r.MsgHdr.AuthenticatedData = true
}
select { select {
case <-u.ctx.Done(): case <-u.ctx.Done():
@@ -125,7 +120,6 @@ func (u *upstreamResolverBase) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
return return
} }
u.successCount.Add(1)
log.Tracef("took %s to query the upstream %s", t, upstream) log.Tracef("took %s to query the upstream %s", t, upstream)
err = w.WriteMsg(rm) err = w.WriteMsg(rm)
@@ -174,11 +168,6 @@ func (u *upstreamResolverBase) probeAvailability() {
default: default:
} }
// avoid probe if upstreams could resolve at least one query and fails count is less than failsTillDeact
if u.successCount.Load() > 0 && u.failsCount.Load() < u.failsTillDeact {
return
}
var success bool var success bool
var mu sync.Mutex var mu sync.Mutex
var wg sync.WaitGroup var wg sync.WaitGroup
@@ -190,7 +179,7 @@ func (u *upstreamResolverBase) probeAvailability() {
wg.Add(1) wg.Add(1)
go func() { go func() {
defer wg.Done() defer wg.Done()
err := u.testNameserver(upstream, 500*time.Millisecond) err := u.testNameserver(upstream)
if err != nil { if err != nil {
errors = multierror.Append(errors, err) errors = multierror.Append(errors, err)
log.Warnf("probing upstream nameserver %s: %s", upstream, err) log.Warnf("probing upstream nameserver %s: %s", upstream, err)
@@ -231,7 +220,7 @@ func (u *upstreamResolverBase) waitUntilResponse() {
} }
for _, upstream := range u.upstreamServers { for _, upstream := range u.upstreamServers {
if err := u.testNameserver(upstream, probeTimeout); err != nil { if err := u.testNameserver(upstream); err != nil {
log.Tracef("upstream check for %s: %s", upstream, err) log.Tracef("upstream check for %s: %s", upstream, err)
} else { } else {
// at least one upstream server is available, stop probing // at least one upstream server is available, stop probing
@@ -251,7 +240,6 @@ func (u *upstreamResolverBase) waitUntilResponse() {
log.Infof("upstreams %s are responsive again. Adding them back to system", u.upstreamServers) log.Infof("upstreams %s are responsive again. Adding them back to system", u.upstreamServers)
u.failsCount.Store(0) u.failsCount.Store(0)
u.successCount.Add(1)
u.reactivate() u.reactivate()
u.disabled = false u.disabled = false
} }
@@ -272,15 +260,17 @@ func (u *upstreamResolverBase) disable(err error) {
return return
} }
log.Warnf("Upstream resolving is Disabled for %v", reactivatePeriod) // todo test the deactivation logic, it seems to affect the client
u.successCount.Store(0) if runtime.GOOS != "ios" {
u.deactivate(err) log.Warnf("Upstream resolving is Disabled for %v", reactivatePeriod)
u.disabled = true u.deactivate(err)
go u.waitUntilResponse() u.disabled = true
go u.waitUntilResponse()
}
} }
func (u *upstreamResolverBase) testNameserver(server string, timeout time.Duration) error { func (u *upstreamResolverBase) testNameserver(server string) error {
ctx, cancel := context.WithTimeout(u.ctx, timeout) ctx, cancel := context.WithTimeout(u.ctx, probeTimeout)
defer cancel() defer cancel()
r := new(dns.Msg).SetQuestion(testRecord, dns.TypeSOA) r := new(dns.Msg).SetQuestion(testRecord, dns.TypeSOA)

View File

@@ -1,84 +0,0 @@
package dns
import (
"context"
"net"
"syscall"
"time"
"github.com/miekg/dns"
"github.com/netbirdio/netbird/client/internal/peer"
nbnet "github.com/netbirdio/netbird/util/net"
)
type upstreamResolver struct {
*upstreamResolverBase
hostsDNSHolder *hostsDNSHolder
}
// newUpstreamResolver in Android we need to distinguish the DNS servers to available through VPN or outside of VPN
// In case if the assigned DNS address is available only in the protected network then the resolver will time out at the
// first time, and we need to wait for a while to start to use again the proper DNS resolver.
func newUpstreamResolver(
ctx context.Context,
_ string,
_ net.IP,
_ *net.IPNet,
statusRecorder *peer.Status,
hostsDNSHolder *hostsDNSHolder,
) (*upstreamResolver, error) {
upstreamResolverBase := newUpstreamResolverBase(ctx, statusRecorder)
c := &upstreamResolver{
upstreamResolverBase: upstreamResolverBase,
hostsDNSHolder: hostsDNSHolder,
}
upstreamResolverBase.upstreamClient = c
return c, nil
}
// exchange in case of Android if the upstream is a local resolver then we do not need to mark the socket as protected.
// In other case the DNS resolvation goes through the VPN, so we need to force to use the
func (u *upstreamResolver) exchange(ctx context.Context, upstream string, r *dns.Msg) (rm *dns.Msg, t time.Duration, err error) {
if u.isLocalResolver(upstream) {
return u.exchangeWithoutVPN(ctx, upstream, r)
} else {
return u.exchangeWithinVPN(ctx, upstream, r)
}
}
func (u *upstreamResolver) exchangeWithinVPN(ctx context.Context, upstream string, r *dns.Msg) (rm *dns.Msg, t time.Duration, err error) {
upstreamExchangeClient := &dns.Client{}
return upstreamExchangeClient.ExchangeContext(ctx, r, upstream)
}
// exchangeWithoutVPN protect the UDP socket by Android SDK to avoid to goes through the VPN
func (u *upstreamResolver) exchangeWithoutVPN(ctx context.Context, upstream string, r *dns.Msg) (rm *dns.Msg, t time.Duration, err error) {
timeout := upstreamTimeout
if deadline, ok := ctx.Deadline(); ok {
timeout = time.Until(deadline)
}
dialTimeout := timeout
nbDialer := nbnet.NewDialer()
dialer := &net.Dialer{
Control: func(network, address string, c syscall.RawConn) error {
return nbDialer.Control(network, address, c)
},
Timeout: dialTimeout,
}
upstreamExchangeClient := &dns.Client{
Dialer: dialer,
}
return upstreamExchangeClient.Exchange(r, upstream)
}
func (u *upstreamResolver) isLocalResolver(upstream string) bool {
if u.hostsDNSHolder.isContain(upstream) {
return true
}
return false
}

View File

@@ -4,7 +4,6 @@ package dns
import ( import (
"context" "context"
"fmt"
"net" "net"
"syscall" "syscall"
"time" "time"
@@ -18,9 +17,9 @@ import (
type upstreamResolverIOS struct { type upstreamResolverIOS struct {
*upstreamResolverBase *upstreamResolverBase
lIP net.IP lIP net.IP
lNet *net.IPNet lNet *net.IPNet
interfaceName string iIndex int
} }
func newUpstreamResolver( func newUpstreamResolver(
@@ -29,15 +28,20 @@ func newUpstreamResolver(
ip net.IP, ip net.IP,
net *net.IPNet, net *net.IPNet,
statusRecorder *peer.Status, statusRecorder *peer.Status,
_ *hostsDNSHolder,
) (*upstreamResolverIOS, error) { ) (*upstreamResolverIOS, error) {
upstreamResolverBase := newUpstreamResolverBase(ctx, statusRecorder) upstreamResolverBase := newUpstreamResolverBase(ctx, statusRecorder)
index, err := getInterfaceIndex(interfaceName)
if err != nil {
log.Debugf("unable to get interface index for %s: %s", interfaceName, err)
return nil, err
}
ios := &upstreamResolverIOS{ ios := &upstreamResolverIOS{
upstreamResolverBase: upstreamResolverBase, upstreamResolverBase: upstreamResolverBase,
lIP: ip, lIP: ip,
lNet: net, lNet: net,
interfaceName: interfaceName, iIndex: index,
} }
ios.upstreamClient = ios ios.upstreamClient = ios
@@ -48,7 +52,7 @@ func (u *upstreamResolverIOS) exchange(ctx context.Context, upstream string, r *
client := &dns.Client{} client := &dns.Client{}
upstreamHost, _, err := net.SplitHostPort(upstream) upstreamHost, _, err := net.SplitHostPort(upstream)
if err != nil { if err != nil {
return nil, 0, fmt.Errorf("error while parsing upstream host: %s", err) log.Errorf("error while parsing upstream host: %s", err)
} }
timeout := upstreamTimeout timeout := upstreamTimeout
@@ -60,35 +64,26 @@ func (u *upstreamResolverIOS) exchange(ctx context.Context, upstream string, r *
upstreamIP := net.ParseIP(upstreamHost) upstreamIP := net.ParseIP(upstreamHost)
if u.lNet.Contains(upstreamIP) || net.IP.IsPrivate(upstreamIP) { if u.lNet.Contains(upstreamIP) || net.IP.IsPrivate(upstreamIP) {
log.Debugf("using private client to query upstream: %s", upstream) log.Debugf("using private client to query upstream: %s", upstream)
client, err = GetClientPrivate(u.lIP, u.interfaceName, timeout) client = u.getClientPrivate(timeout)
if err != nil {
return nil, 0, fmt.Errorf("error while creating private client: %s", err)
}
} }
// Cannot use client.ExchangeContext because it overwrites our Dialer // Cannot use client.ExchangeContext because it overwrites our Dialer
return client.Exchange(r, upstream) return client.Exchange(r, upstream)
} }
// GetClientPrivate returns a new DNS client bound to the local IP address of the Netbird interface // getClientPrivate returns a new DNS client bound to the local IP address of the Netbird interface
// This method is needed for iOS // This method is needed for iOS
func GetClientPrivate(ip net.IP, interfaceName string, dialTimeout time.Duration) (*dns.Client, error) { func (u *upstreamResolverIOS) getClientPrivate(dialTimeout time.Duration) *dns.Client {
index, err := getInterfaceIndex(interfaceName)
if err != nil {
log.Debugf("unable to get interface index for %s: %s", interfaceName, err)
return nil, err
}
dialer := &net.Dialer{ dialer := &net.Dialer{
LocalAddr: &net.UDPAddr{ LocalAddr: &net.UDPAddr{
IP: ip, IP: u.lIP,
Port: 0, // Let the OS pick a free port Port: 0, // Let the OS pick a free port
}, },
Timeout: dialTimeout, Timeout: dialTimeout,
Control: func(network, address string, c syscall.RawConn) error { Control: func(network, address string, c syscall.RawConn) error {
var operr error var operr error
fn := func(s uintptr) { fn := func(s uintptr) {
operr = unix.SetsockoptInt(int(s), unix.IPPROTO_IP, unix.IP_BOUND_IF, index) operr = unix.SetsockoptInt(int(s), unix.IPPROTO_IP, unix.IP_BOUND_IF, u.iIndex)
} }
if err := c.Control(fn); err != nil { if err := c.Control(fn); err != nil {
@@ -105,7 +100,7 @@ func GetClientPrivate(ip net.IP, interfaceName string, dialTimeout time.Duration
client := &dns.Client{ client := &dns.Client{
Dialer: dialer, Dialer: dialer,
} }
return client, nil return client
} }
func getInterfaceIndex(interfaceName string) (int, error) { func getInterfaceIndex(interfaceName string) (int, error) {

View File

@@ -1,4 +1,4 @@
//go:build !android && !ios //go:build !ios
package dns package dns
@@ -12,7 +12,7 @@ import (
"github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/peer"
) )
type upstreamResolver struct { type upstreamResolverNonIOS struct {
*upstreamResolverBase *upstreamResolverBase
} }
@@ -22,17 +22,16 @@ func newUpstreamResolver(
_ net.IP, _ net.IP,
_ *net.IPNet, _ *net.IPNet,
statusRecorder *peer.Status, statusRecorder *peer.Status,
_ *hostsDNSHolder, ) (*upstreamResolverNonIOS, error) {
) (*upstreamResolver, error) {
upstreamResolverBase := newUpstreamResolverBase(ctx, statusRecorder) upstreamResolverBase := newUpstreamResolverBase(ctx, statusRecorder)
nonIOS := &upstreamResolver{ nonIOS := &upstreamResolverNonIOS{
upstreamResolverBase: upstreamResolverBase, upstreamResolverBase: upstreamResolverBase,
} }
upstreamResolverBase.upstreamClient = nonIOS upstreamResolverBase.upstreamClient = nonIOS
return nonIOS, nil return nonIOS, nil
} }
func (u *upstreamResolver) exchange(ctx context.Context, upstream string, r *dns.Msg) (rm *dns.Msg, t time.Duration, err error) { func (u *upstreamResolverNonIOS) exchange(ctx context.Context, upstream string, r *dns.Msg) (rm *dns.Msg, t time.Duration, err error) {
upstreamExchangeClient := &dns.Client{} upstreamExchangeClient := &dns.Client{}
return upstreamExchangeClient.ExchangeContext(ctx, r, upstream) return upstreamExchangeClient.ExchangeContext(ctx, r, upstream)
} }

View File

@@ -58,7 +58,7 @@ func TestUpstreamResolver_ServeDNS(t *testing.T) {
for _, testCase := range testCases { for _, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) { t.Run(testCase.name, func(t *testing.T) {
ctx, cancel := context.WithCancel(context.TODO()) ctx, cancel := context.WithCancel(context.TODO())
resolver, _ := newUpstreamResolver(ctx, "", net.IP{}, &net.IPNet{}, nil, nil) resolver, _ := newUpstreamResolver(ctx, "", net.IP{}, &net.IPNet{}, nil)
resolver.upstreamServers = testCase.InputServers resolver.upstreamServers = testCase.InputServers
resolver.upstreamTimeout = testCase.timeout resolver.upstreamTimeout = testCase.timeout
if testCase.cancelCTX { if testCase.cancelCTX {

View File

@@ -2,17 +2,12 @@
package dns package dns
import ( import "github.com/netbirdio/netbird/iface"
"net"
"github.com/netbirdio/netbird/iface"
)
// WGIface defines subset methods of interface required for manager // WGIface defines subset methods of interface required for manager
type WGIface interface { type WGIface interface {
Name() string Name() string
Address() iface.WGAddress Address() iface.WGAddress
ToInterface() *net.Interface
IsUserspaceBind() bool IsUserspaceBind() bool
GetFilter() iface.PacketFilter GetFilter() iface.PacketFilter
GetDevice() *iface.DeviceWrapper GetDevice() *iface.DeviceWrapper

View File

@@ -2,18 +2,14 @@ package internal
import ( import (
"context" "context"
"errors"
"fmt" "fmt"
"maps"
"math/rand" "math/rand"
"net" "net"
"net/netip" "net/netip"
"reflect" "reflect"
"runtime" "runtime"
"slices"
"strings" "strings"
"sync" "sync"
"sync/atomic"
"time" "time"
"github.com/pion/ice/v3" "github.com/pion/ice/v3"
@@ -25,29 +21,21 @@ import (
"github.com/netbirdio/netbird/client/firewall/manager" "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/internal/acl" "github.com/netbirdio/netbird/client/internal/acl"
"github.com/netbirdio/netbird/client/internal/dns" "github.com/netbirdio/netbird/client/internal/dns"
"github.com/netbirdio/netbird/client/internal/networkmonitor"
"github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/relay" "github.com/netbirdio/netbird/client/internal/relay"
"github.com/netbirdio/netbird/client/internal/rosenpass" "github.com/netbirdio/netbird/client/internal/rosenpass"
"github.com/netbirdio/netbird/client/internal/routemanager" "github.com/netbirdio/netbird/client/internal/routemanager"
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
"github.com/netbirdio/netbird/client/internal/wgproxy" "github.com/netbirdio/netbird/client/internal/wgproxy"
nbssh "github.com/netbirdio/netbird/client/ssh" nbssh "github.com/netbirdio/netbird/client/ssh"
"github.com/netbirdio/netbird/client/system"
nbdns "github.com/netbirdio/netbird/dns" nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/iface" "github.com/netbirdio/netbird/iface"
"github.com/netbirdio/netbird/iface/bind" "github.com/netbirdio/netbird/iface/bind"
mgm "github.com/netbirdio/netbird/management/client" mgm "github.com/netbirdio/netbird/management/client"
"github.com/netbirdio/netbird/management/domain"
mgmProto "github.com/netbirdio/netbird/management/proto" mgmProto "github.com/netbirdio/netbird/management/proto"
auth "github.com/netbirdio/netbird/relay/auth/hmac"
relayClient "github.com/netbirdio/netbird/relay/client"
"github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/route"
signal "github.com/netbirdio/netbird/signal/client" signal "github.com/netbirdio/netbird/signal/client"
sProto "github.com/netbirdio/netbird/signal/proto" sProto "github.com/netbirdio/netbird/signal/proto"
"github.com/netbirdio/netbird/util" "github.com/netbirdio/netbird/util"
nbnet "github.com/netbirdio/netbird/util/net"
) )
// PeerConnectionTimeoutMax is a timeout of an initial connection attempt to a remote peer. // PeerConnectionTimeoutMax is a timeout of an initial connection attempt to a remote peer.
@@ -72,9 +60,6 @@ type EngineConfig struct {
// WgPrivateKey is a Wireguard private key of our peer (it MUST never leave the machine) // WgPrivateKey is a Wireguard private key of our peer (it MUST never leave the machine)
WgPrivateKey wgtypes.Key WgPrivateKey wgtypes.Key
// NetworkMonitor is a flag to enable network monitoring
NetworkMonitor bool
// IFaceBlackList is a list of network interfaces to ignore when discovering connection candidates (ICE related) // IFaceBlackList is a list of network interfaces to ignore when discovering connection candidates (ICE related)
IFaceBlackList []string IFaceBlackList []string
DisableIPv6Discovery bool DisableIPv6Discovery bool
@@ -98,22 +83,19 @@ type EngineConfig struct {
RosenpassPermissive bool RosenpassPermissive bool
ServerSSHAllowed bool ServerSSHAllowed bool
DNSRouteInterval time.Duration
} }
// Engine is a mechanism responsible for reacting on Signal and Management stream events and managing connections to the remote peers. // Engine is a mechanism responsible for reacting on Signal and Management stream events and managing connections to the remote peers.
type Engine struct { type Engine struct {
// signal is a Signal Service client // signal is a Signal Service client
signal signal.Client signal signal.Client
signaler *peer.Signaler
// mgmClient is a Management Service client // mgmClient is a Management Service client
mgmClient mgm.Client mgmClient mgm.Client
// peerConns is a map that holds all the peers that are known to this peer // peerConns is a map that holds all the peers that are known to this peer
peerConns map[string]*peer.Conn peerConns map[string]*peer.Conn
beforePeerHook nbnet.AddHookFunc beforePeerHook peer.BeforeAddPeerHookFunc
afterPeerHook nbnet.RemoveHookFunc afterPeerHook peer.AfterRemovePeerHookFunc
// rpManager is a Rosenpass manager // rpManager is a Rosenpass manager
rpManager *rosenpass.Manager rpManager *rosenpass.Manager
@@ -127,20 +109,16 @@ type Engine struct {
// STUNs is a list of STUN servers used by ICE // STUNs is a list of STUN servers used by ICE
STUNs []*stun.URI STUNs []*stun.URI
// TURNs is a list of STUN servers used by ICE // TURNs is a list of STUN servers used by ICE
TURNs []*stun.URI TURNs []*stun.URI
stunTurn atomic.Value
// clientRoutes is the most recent list of clientRoutes received from the Management Service // clientRoutes is the most recent list of clientRoutes received from the Management Service
clientRoutes route.HAMap clientRoutes map[string][]*route.Route
clientRoutesMu sync.RWMutex
clientCtx context.Context
clientCancel context.CancelFunc
ctx context.Context
cancel context.CancelFunc cancel context.CancelFunc
wgInterface iface.IWGIface ctx context.Context
wgInterface *iface.WGIface
wgProxyFactory *wgproxy.Factory wgProxyFactory *wgproxy.Factory
udpMux *bind.UniversalUDPMuxDefault udpMux *bind.UniversalUDPMuxDefault
@@ -148,8 +126,6 @@ type Engine struct {
// networkSerial is the latest CurrentSerial (state ID) of the network sent by the Management service // networkSerial is the latest CurrentSerial (state ID) of the network sent by the Management service
networkSerial uint64 networkSerial uint64
networkMonitor *networkmonitor.NetworkMonitor
sshServerFunc func(hostKeyPEM []byte, addr string) (nbssh.Server, error) sshServerFunc func(hostKeyPEM []byte, addr string) (nbssh.Server, error)
sshServer nbssh.Server sshServer nbssh.Server
@@ -161,12 +137,10 @@ type Engine struct {
dnsServer dns.Server dnsServer dns.Server
probes *ProbeHolder mgmProbe *Probe
signalProbe *Probe
// checks are the client-applied posture checks that need to be evaluated on the client relayProbe *Probe
checks []*mgmProto.Checks wgProbe *Probe
relayManager *relayClient.Manager
} }
// Peer is an instance of the Connection Peer // Peer is an instance of the Connection Peer
@@ -177,50 +151,48 @@ type Peer struct {
// NewEngine creates a new Connection Engine // NewEngine creates a new Connection Engine
func NewEngine( func NewEngine(
clientCtx context.Context, ctx context.Context,
clientCancel context.CancelFunc, cancel context.CancelFunc,
signalClient signal.Client, signalClient signal.Client,
mgmClient mgm.Client, mgmClient mgm.Client,
relayManager *relayClient.Manager,
config *EngineConfig, config *EngineConfig,
mobileDep MobileDependency, mobileDep MobileDependency,
statusRecorder *peer.Status, statusRecorder *peer.Status,
checks []*mgmProto.Checks,
) *Engine { ) *Engine {
return NewEngineWithProbes( return NewEngineWithProbes(
clientCtx, ctx,
clientCancel, cancel,
signalClient, signalClient,
mgmClient, mgmClient,
relayManager,
config, config,
mobileDep, mobileDep,
statusRecorder, statusRecorder,
nil, nil,
checks, nil,
nil,
nil,
) )
} }
// NewEngineWithProbes creates a new Connection Engine with probes attached // NewEngineWithProbes creates a new Connection Engine with probes attached
func NewEngineWithProbes( func NewEngineWithProbes(
clientCtx context.Context, ctx context.Context,
clientCancel context.CancelFunc, cancel context.CancelFunc,
signalClient signal.Client, signalClient signal.Client,
mgmClient mgm.Client, mgmClient mgm.Client,
relayManager *relayClient.Manager,
config *EngineConfig, config *EngineConfig,
mobileDep MobileDependency, mobileDep MobileDependency,
statusRecorder *peer.Status, statusRecorder *peer.Status,
probes *ProbeHolder, mgmProbe *Probe,
checks []*mgmProto.Checks, signalProbe *Probe,
relayProbe *Probe,
wgProbe *Probe,
) *Engine { ) *Engine {
return &Engine{ return &Engine{
clientCtx: clientCtx, ctx: ctx,
clientCancel: clientCancel, cancel: cancel,
signal: signalClient, signal: signalClient,
signaler: peer.NewSignaler(signalClient, config.WgPrivateKey),
mgmClient: mgmClient, mgmClient: mgmClient,
relayManager: relayManager,
peerConns: make(map[string]*peer.Conn), peerConns: make(map[string]*peer.Conn),
syncMsgMux: &sync.Mutex{}, syncMsgMux: &sync.Mutex{},
config: config, config: config,
@@ -230,41 +202,27 @@ func NewEngineWithProbes(
networkSerial: 0, networkSerial: 0,
sshServerFunc: nbssh.DefaultSSHServer, sshServerFunc: nbssh.DefaultSSHServer,
statusRecorder: statusRecorder, statusRecorder: statusRecorder,
probes: probes, wgProxyFactory: wgproxy.NewFactory(config.WgPort),
checks: checks, mgmProbe: mgmProbe,
signalProbe: signalProbe,
relayProbe: relayProbe,
wgProbe: wgProbe,
} }
} }
func (e *Engine) Stop() error { func (e *Engine) Stop() error {
if e == nil {
// this seems to be a very odd case but there was the possibility if the netbird down command comes before the engine is fully started
log.Debugf("tried stopping engine that is nil")
return nil
}
e.syncMsgMux.Lock() e.syncMsgMux.Lock()
defer e.syncMsgMux.Unlock() defer e.syncMsgMux.Unlock()
// stopping network monitor first to avoid starting the engine again
if e.networkMonitor != nil {
e.networkMonitor.Stop()
}
log.Info("Network monitor: stopped")
err := e.removeAllPeers() err := e.removeAllPeers()
if err != nil { if err != nil {
return fmt.Errorf("failed to remove all peers: %s", err) return err
} }
e.clientRoutesMu.Lock()
e.clientRoutes = nil e.clientRoutes = nil
e.clientRoutesMu.Unlock()
if e.cancel != nil {
e.cancel()
}
// very ugly but we want to remove peers from the WireGuard interface first before removing interface. // very ugly but we want to remove peers from the WireGuard interface first before removing interface.
// Removing peers happens in the conn.Close() asynchronously // Removing peers happens in the conn.CLose() asynchronously
time.Sleep(500 * time.Millisecond) time.Sleep(500 * time.Millisecond)
e.close() e.close()
@@ -279,11 +237,6 @@ func (e *Engine) Start() error {
e.syncMsgMux.Lock() e.syncMsgMux.Lock()
defer e.syncMsgMux.Unlock() defer e.syncMsgMux.Unlock()
if e.cancel != nil {
e.cancel()
}
e.ctx, e.cancel = context.WithCancel(e.clientCtx)
wgIface, err := e.newWgIface() wgIface, err := e.newWgIface()
if err != nil { if err != nil {
log.Errorf("failed creating wireguard interface instance %s: [%s]", e.config.WgIfaceName, err) log.Errorf("failed creating wireguard interface instance %s: [%s]", e.config.WgIfaceName, err)
@@ -291,9 +244,6 @@ func (e *Engine) Start() error {
} }
e.wgInterface = wgIface e.wgInterface = wgIface
userspace := e.wgInterface.IsUserspaceBind()
e.wgProxyFactory = wgproxy.NewFactory(e.ctx, userspace, e.config.WgPort)
if e.config.RosenpassEnabled { if e.config.RosenpassEnabled {
log.Infof("rosenpass is enabled") log.Infof("rosenpass is enabled")
if e.config.RosenpassPermissive { if e.config.RosenpassPermissive {
@@ -318,7 +268,7 @@ func (e *Engine) Start() error {
} }
e.dnsServer = dnsServer e.dnsServer = dnsServer
e.routeManager = routemanager.NewManager(e.ctx, e.config.WgPrivateKey.PublicKey().String(), e.config.DNSRouteInterval, e.wgInterface, e.statusRecorder, e.relayManager, initialRoutes) e.routeManager = routemanager.NewManager(e.ctx, e.config.WgPrivateKey.PublicKey().String(), e.wgInterface, e.statusRecorder, initialRoutes)
beforePeerHook, afterPeerHook, err := e.routeManager.Init() beforePeerHook, afterPeerHook, err := e.routeManager.Init()
if err != nil { if err != nil {
log.Errorf("Failed to initialize route manager: %s", err) log.Errorf("Failed to initialize route manager: %s", err)
@@ -370,9 +320,6 @@ func (e *Engine) Start() error {
e.receiveManagementEvents() e.receiveManagementEvents()
e.receiveProbeEvents() e.receiveProbeEvents()
// starting network monitor at the very last to avoid disruptions
e.startNetworkMonitor()
return nil return nil
} }
@@ -467,52 +414,83 @@ func (e *Engine) removePeer(peerKey string) error {
conn, exists := e.peerConns[peerKey] conn, exists := e.peerConns[peerKey]
if exists { if exists {
delete(e.peerConns, peerKey) delete(e.peerConns, peerKey)
conn.Close() err := conn.Close()
if err != nil {
switch err.(type) {
case *peer.ConnectionAlreadyClosedError:
return nil
default:
return err
}
}
} }
return nil return nil
} }
func signalCandidate(candidate ice.Candidate, myKey wgtypes.Key, remoteKey wgtypes.Key, s signal.Client) error {
err := s.Send(&sProto.Message{
Key: myKey.PublicKey().String(),
RemoteKey: remoteKey.String(),
Body: &sProto.Body{
Type: sProto.Body_CANDIDATE,
Payload: candidate.Marshal(),
},
})
if err != nil {
return err
}
return nil
}
func sendSignal(message *sProto.Message, s signal.Client) error {
return s.Send(message)
}
// SignalOfferAnswer signals either an offer or an answer to remote peer
func SignalOfferAnswer(offerAnswer peer.OfferAnswer, myKey wgtypes.Key, remoteKey wgtypes.Key, s signal.Client,
isAnswer bool) error {
var t sProto.Body_Type
if isAnswer {
t = sProto.Body_ANSWER
} else {
t = sProto.Body_OFFER
}
msg, err := signal.MarshalCredential(myKey, offerAnswer.WgListenPort, remoteKey, &signal.Credential{
UFrag: offerAnswer.IceCredentials.UFrag,
Pwd: offerAnswer.IceCredentials.Pwd,
}, t, offerAnswer.RosenpassPubKey, offerAnswer.RosenpassAddr)
if err != nil {
return err
}
err = s.Send(msg)
if err != nil {
return err
}
return nil
}
func (e *Engine) handleSync(update *mgmProto.SyncResponse) error { func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
e.syncMsgMux.Lock() e.syncMsgMux.Lock()
defer e.syncMsgMux.Unlock() defer e.syncMsgMux.Unlock()
if update.GetWiretrusteeConfig() != nil { if update.GetWiretrusteeConfig() != nil {
wCfg := update.GetWiretrusteeConfig() err := e.updateTURNs(update.GetWiretrusteeConfig().GetTurns())
err := e.updateTURNs(wCfg.GetTurns())
if err != nil { if err != nil {
return fmt.Errorf("update TURNs: %w", err) return err
} }
err = e.updateSTUNs(wCfg.GetStuns()) err = e.updateSTUNs(update.GetWiretrusteeConfig().GetStuns())
if err != nil { if err != nil {
return fmt.Errorf("update STUNs: %w", err) return err
} }
var stunTurn []*stun.URI
stunTurn = append(stunTurn, e.STUNs...)
stunTurn = append(stunTurn, e.TURNs...)
e.stunTurn.Store(stunTurn)
relayMsg := wCfg.GetRelay()
if relayMsg != nil {
c := &auth.Token{
Payload: relayMsg.GetTokenPayload(),
Signature: relayMsg.GetTokenSignature(),
}
if err := e.relayManager.UpdateToken(c); err != nil {
log.Errorf("failed to update relay token: %v", err)
return fmt.Errorf("update relay token: %w", err)
}
}
// todo update relay address in the relay manager
// todo update signal // todo update signal
} }
if err := e.updateChecksIfNew(update.Checks); err != nil {
return err
}
if update.GetNetworkMap() != nil { if update.GetNetworkMap() != nil {
// only apply new changes and ignore old ones // only apply new changes and ignore old ones
err := e.updateNetworkMap(update.GetNetworkMap()) err := e.updateNetworkMap(update.GetNetworkMap())
@@ -520,27 +498,7 @@ func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
return err return err
} }
} }
return nil
}
// updateChecksIfNew updates checks if there are changes and sync new meta with management
func (e *Engine) updateChecksIfNew(checks []*mgmProto.Checks) error {
// if checks are equal, we skip the update
if isChecksEqual(e.checks, checks) {
return nil
}
e.checks = checks
info, err := system.GetInfoWithChecks(e.ctx, checks)
if err != nil {
log.Warnf("failed to get system info with checks: %v", err)
info = system.GetInfo(e.ctx)
}
if err := e.mgmClient.SyncMeta(info); err != nil {
log.Errorf("could not sync meta: error %s", err)
return err
}
return nil return nil
} }
@@ -556,8 +514,8 @@ func (e *Engine) updateSSH(sshConf *mgmProto.SSHConfig) error {
} else { } else {
if sshConf.GetSshEnabled() { if sshConf.GetSshEnabled() {
if runtime.GOOS == "windows" || runtime.GOOS == "freebsd" { if runtime.GOOS == "windows" {
log.Warnf("running SSH server on %s is not supported", runtime.GOOS) log.Warnf("running SSH server on Windows is not supported")
return nil return nil
} }
// start SSH server if it wasn't running // start SSH server if it wasn't running
@@ -630,19 +588,12 @@ func (e *Engine) updateConfig(conf *mgmProto.PeerConfig) error {
// E.g. when a new peer has been registered and we are allowed to connect to it. // E.g. when a new peer has been registered and we are allowed to connect to it.
func (e *Engine) receiveManagementEvents() { func (e *Engine) receiveManagementEvents() {
go func() { go func() {
info, err := system.GetInfoWithChecks(e.ctx, e.checks) err := e.mgmClient.Sync(e.handleSync)
if err != nil {
log.Warnf("failed to get system info with checks: %v", err)
info = system.GetInfo(e.ctx)
}
// err = e.mgmClient.Sync(info, e.handleSync)
err = e.mgmClient.Sync(e.ctx, info, e.handleSync)
if err != nil { if err != nil {
// happens if management is unavailable for a long time. // happens if management is unavailable for a long time.
// We want to cancel the operation of the whole client // We want to cancel the operation of the whole client
_ = CtxGetState(e.ctx).Wrap(ErrResetConnection) _ = CtxGetState(e.ctx).Wrap(ErrResetConnection)
e.clientCancel() e.cancel()
return return
} }
log.Debugf("stopped receiving updates from Management Service") log.Debugf("stopped receiving updates from Management Service")
@@ -704,20 +655,6 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
return nil return nil
} }
protoRoutes := networkMap.GetRoutes()
if protoRoutes == nil {
protoRoutes = []*mgmProto.Route{}
}
_, clientRoutes, err := e.routeManager.UpdateRoutes(serial, toRoutes(protoRoutes))
if err != nil {
log.Errorf("failed to update clientRoutes, err: %v", err)
}
e.clientRoutesMu.Lock()
e.clientRoutes = clientRoutes
e.clientRoutesMu.Unlock()
log.Debugf("got peers update from Management Service, total peers to connect to = %d", len(networkMap.GetRemotePeers())) log.Debugf("got peers update from Management Service, total peers to connect to = %d", len(networkMap.GetRemotePeers()))
e.updateOfflinePeers(networkMap.GetOfflinePeers()) e.updateOfflinePeers(networkMap.GetOfflinePeers())
@@ -759,6 +696,17 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
} }
} }
} }
protoRoutes := networkMap.GetRoutes()
if protoRoutes == nil {
protoRoutes = []*mgmProto.Route{}
}
_, clientRoutes, err := e.routeManager.UpdateRoutes(serial, toRoutes(protoRoutes))
if err != nil {
log.Errorf("failed to update clientRoutes, err: %v", err)
}
e.clientRoutes = clientRoutes
protoDNSConfig := networkMap.GetDNSConfig() protoDNSConfig := networkMap.GetDNSConfig()
if protoDNSConfig == nil { if protoDNSConfig == nil {
@@ -786,24 +734,15 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
func toRoutes(protoRoutes []*mgmProto.Route) []*route.Route { func toRoutes(protoRoutes []*mgmProto.Route) []*route.Route {
routes := make([]*route.Route, 0) routes := make([]*route.Route, 0)
for _, protoRoute := range protoRoutes { for _, protoRoute := range protoRoutes {
var prefix netip.Prefix _, prefix, _ := route.ParseNetwork(protoRoute.Network)
if len(protoRoute.Domains) == 0 {
var err error
if prefix, err = netip.ParsePrefix(protoRoute.Network); err != nil {
log.Errorf("Failed to parse prefix %s: %v", protoRoute.Network, err)
continue
}
}
convertedRoute := &route.Route{ convertedRoute := &route.Route{
ID: route.ID(protoRoute.ID), ID: protoRoute.ID,
Network: prefix, Network: prefix,
Domains: domain.FromPunycodeList(protoRoute.Domains), NetID: protoRoute.NetID,
NetID: route.NetID(protoRoute.NetID),
NetworkType: route.NetworkType(protoRoute.NetworkType), NetworkType: route.NetworkType(protoRoute.NetworkType),
Peer: protoRoute.Peer, Peer: protoRoute.Peer,
Metric: int(protoRoute.Metric), Metric: int(protoRoute.Metric),
Masquerade: protoRoute.Masquerade, Masquerade: protoRoute.Masquerade,
KeepRoute: protoRoute.KeepRoute,
} }
routes = append(routes, convertedRoute) routes = append(routes, convertedRoute)
} }
@@ -901,13 +840,60 @@ func (e *Engine) addNewPeer(peerConfig *mgmProto.RemotePeerConfig) error {
log.Warnf("error adding peer %s to status recorder, got error: %v", peerKey, err) log.Warnf("error adding peer %s to status recorder, got error: %v", peerKey, err)
} }
conn.Open() go e.connWorker(conn, peerKey)
} }
return nil return nil
} }
func (e *Engine) connWorker(conn *peer.Conn, peerKey string) {
for {
// randomize starting time a bit
min := 500
max := 2000
time.Sleep(time.Duration(rand.Intn(max-min)+min) * time.Millisecond)
// if peer has been removed -> give up
if !e.peerExists(peerKey) {
log.Debugf("peer %s doesn't exist anymore, won't retry connection", peerKey)
return
}
if !e.signal.Ready() {
log.Infof("signal client isn't ready, skipping connection attempt %s", peerKey)
continue
}
// we might have received new STUN and TURN servers meanwhile, so update them
e.syncMsgMux.Lock()
conn.UpdateStunTurn(append(e.STUNs, e.TURNs...))
e.syncMsgMux.Unlock()
err := conn.Open()
if err != nil {
log.Debugf("connection to peer %s failed: %v", peerKey, err)
switch err.(type) {
case *peer.ConnectionClosedError:
// conn has been forced to close, so we exit the loop
return
default:
}
}
}
}
func (e *Engine) peerExists(peerKey string) bool {
e.syncMsgMux.Lock()
defer e.syncMsgMux.Unlock()
_, ok := e.peerConns[peerKey]
return ok
}
func (e *Engine) createPeerConn(pubKey string, allowedIPs string) (*peer.Conn, error) { func (e *Engine) createPeerConn(pubKey string, allowedIPs string) (*peer.Conn, error) {
log.Debugf("creating peer connection %s", pubKey) log.Debugf("creating peer connection %s", pubKey)
var stunTurn []*stun.URI
stunTurn = append(stunTurn, e.STUNs...)
stunTurn = append(stunTurn, e.TURNs...)
wgConfig := peer.WgConfig{ wgConfig := peer.WgConfig{
RemoteKey: pubKey, RemoteKey: pubKey,
@@ -940,29 +926,53 @@ func (e *Engine) createPeerConn(pubKey string, allowedIPs string) (*peer.Conn, e
// randomize connection timeout // randomize connection timeout
timeout := time.Duration(rand.Intn(PeerConnectionTimeoutMax-PeerConnectionTimeoutMin)+PeerConnectionTimeoutMin) * time.Millisecond timeout := time.Duration(rand.Intn(PeerConnectionTimeoutMax-PeerConnectionTimeoutMin)+PeerConnectionTimeoutMin) * time.Millisecond
config := peer.ConnConfig{ config := peer.ConnConfig{
Key: pubKey, Key: pubKey,
LocalKey: e.config.WgPrivateKey.PublicKey().String(), LocalKey: e.config.WgPrivateKey.PublicKey().String(),
Timeout: timeout, StunTurn: stunTurn,
WgConfig: wgConfig, InterfaceBlackList: e.config.IFaceBlackList,
LocalWgPort: e.config.WgPort, DisableIPv6Discovery: e.config.DisableIPv6Discovery,
RosenpassPubKey: e.getRosenpassPubKey(), Timeout: timeout,
RosenpassAddr: e.getRosenpassAddr(), UDPMux: e.udpMux.UDPMuxDefault,
ICEConfig: peer.ICEConfig{ UDPMuxSrflx: e.udpMux,
StunTurn: &e.stunTurn, WgConfig: wgConfig,
InterfaceBlackList: e.config.IFaceBlackList, LocalWgPort: e.config.WgPort,
DisableIPv6Discovery: e.config.DisableIPv6Discovery, NATExternalIPs: e.parseNATExternalIPMappings(),
UDPMux: e.udpMux.UDPMuxDefault, UserspaceBind: e.wgInterface.IsUserspaceBind(),
UDPMuxSrflx: e.udpMux, RosenpassPubKey: e.getRosenpassPubKey(),
NATExternalIPs: e.parseNATExternalIPMappings(), RosenpassAddr: e.getRosenpassAddr(),
},
} }
peerConn, err := peer.NewConn(e.ctx, config, e.statusRecorder, e.wgProxyFactory, e.signaler, e.mobileDep.IFaceDiscover, e.relayManager) peerConn, err := peer.NewConn(config, e.statusRecorder, e.wgProxyFactory, e.mobileDep.TunAdapter, e.mobileDep.IFaceDiscover)
if err != nil { if err != nil {
return nil, err return nil, err
} }
wgPubKey, err := wgtypes.ParseKey(pubKey)
if err != nil {
return nil, err
}
signalOffer := func(offerAnswer peer.OfferAnswer) error {
return SignalOfferAnswer(offerAnswer, e.config.WgPrivateKey, wgPubKey, e.signal, false)
}
signalCandidate := func(candidate ice.Candidate) error {
return signalCandidate(candidate, e.config.WgPrivateKey, wgPubKey, e.signal)
}
signalAnswer := func(offerAnswer peer.OfferAnswer) error {
return SignalOfferAnswer(offerAnswer, e.config.WgPrivateKey, wgPubKey, e.signal, true)
}
peerConn.SetSignalCandidate(signalCandidate)
peerConn.SetSignalOffer(signalOffer)
peerConn.SetSignalAnswer(signalAnswer)
peerConn.SetSendSignalMessage(func(message *sProto.Message) error {
return sendSignal(message, e.signal)
})
if e.rpManager != nil { if e.rpManager != nil {
peerConn.SetOnConnected(e.rpManager.OnConnected) peerConn.SetOnConnected(e.rpManager.OnConnected)
peerConn.SetOnDisconnected(e.rpManager.OnDisconnected) peerConn.SetOnDisconnected(e.rpManager.OnDisconnected)
} }
@@ -974,7 +984,7 @@ func (e *Engine) createPeerConn(pubKey string, allowedIPs string) (*peer.Conn, e
func (e *Engine) receiveSignalEvents() { func (e *Engine) receiveSignalEvents() {
go func() { go func() {
// connect to a stream of messages coming from the signal server // connect to a stream of messages coming from the signal server
err := e.signal.Receive(e.ctx, func(msg *sProto.Message) error { err := e.signal.Receive(func(msg *sProto.Message) error {
e.syncMsgMux.Lock() e.syncMsgMux.Lock()
defer e.syncMsgMux.Unlock() defer e.syncMsgMux.Unlock()
@@ -990,6 +1000,8 @@ func (e *Engine) receiveSignalEvents() {
return err return err
} }
conn.RegisterProtoSupportMeta(msg.Body.GetFeaturesSupported())
var rosenpassPubKey []byte var rosenpassPubKey []byte
rosenpassAddr := "" rosenpassAddr := ""
if msg.GetBody().GetRosenpassConfig() != nil { if msg.GetBody().GetRosenpassConfig() != nil {
@@ -1005,7 +1017,6 @@ func (e *Engine) receiveSignalEvents() {
Version: msg.GetBody().GetNetBirdVersion(), Version: msg.GetBody().GetNetBirdVersion(),
RosenpassPubKey: rosenpassPubKey, RosenpassPubKey: rosenpassPubKey,
RosenpassAddr: rosenpassAddr, RosenpassAddr: rosenpassAddr,
RelaySrvAddress: msg.GetBody().GetRelayServerAddress(),
}) })
case sProto.Body_ANSWER: case sProto.Body_ANSWER:
remoteCred, err := signal.UnMarshalCredential(msg) remoteCred, err := signal.UnMarshalCredential(msg)
@@ -1013,6 +1024,8 @@ func (e *Engine) receiveSignalEvents() {
return err return err
} }
conn.RegisterProtoSupportMeta(msg.GetBody().GetFeaturesSupported())
var rosenpassPubKey []byte var rosenpassPubKey []byte
rosenpassAddr := "" rosenpassAddr := ""
if msg.GetBody().GetRosenpassConfig() != nil { if msg.GetBody().GetRosenpassConfig() != nil {
@@ -1028,7 +1041,6 @@ func (e *Engine) receiveSignalEvents() {
Version: msg.GetBody().GetNetBirdVersion(), Version: msg.GetBody().GetNetBirdVersion(),
RosenpassPubKey: rosenpassPubKey, RosenpassPubKey: rosenpassPubKey,
RosenpassAddr: rosenpassAddr, RosenpassAddr: rosenpassAddr,
RelaySrvAddress: msg.GetBody().GetRelayServerAddress(),
}) })
case sProto.Body_CANDIDATE: case sProto.Body_CANDIDATE:
candidate, err := ice.UnmarshalCandidate(msg.GetBody().Payload) candidate, err := ice.UnmarshalCandidate(msg.GetBody().Payload)
@@ -1036,8 +1048,7 @@ func (e *Engine) receiveSignalEvents() {
log.Errorf("failed on parsing remote candidate %s -> %s", candidate, err) log.Errorf("failed on parsing remote candidate %s -> %s", candidate, err)
return err return err
} }
conn.OnRemoteCandidate(candidate)
go conn.OnRemoteCandidate(candidate, e.GetClientRoutes())
case sProto.Body_MODE: case sProto.Body_MODE:
} }
@@ -1047,7 +1058,7 @@ func (e *Engine) receiveSignalEvents() {
// happens if signal is unavailable for a long time. // happens if signal is unavailable for a long time.
// We want to cancel the operation of the whole client // We want to cancel the operation of the whole client
_ = CtxGetState(e.ctx).Wrap(ErrResetConnection) _ = CtxGetState(e.ctx).Wrap(ErrResetConnection)
e.clientCancel() e.cancel()
return return
} }
}() }()
@@ -1108,14 +1119,14 @@ func (e *Engine) parseNATExternalIPMappings() []string {
} }
func (e *Engine) close() { func (e *Engine) close() {
if e.wgProxyFactory != nil { if err := e.wgProxyFactory.Free(); err != nil {
if err := e.wgProxyFactory.Free(); err != nil { log.Errorf("failed closing ebpf proxy: %s", err)
log.Errorf("failed closing ebpf proxy: %s", err)
}
} }
// stop/restore DNS first so dbus and friends don't complain because of a missing interface // stop/restore DNS first so dbus and friends don't complain because of a missing interface
e.stopDNSServer() if e.dnsServer != nil {
e.dnsServer.Stop()
}
if e.routeManager != nil { if e.routeManager != nil {
e.routeManager.Stop() e.routeManager.Stop()
@@ -1148,8 +1159,7 @@ func (e *Engine) close() {
} }
func (e *Engine) readInitialSettings() ([]*route.Route, *nbdns.Config, error) { func (e *Engine) readInitialSettings() ([]*route.Route, *nbdns.Config, error) {
info := system.GetInfo(e.ctx) netMap, err := e.mgmClient.GetNetworkMap()
netMap, err := e.mgmClient.GetNetworkMap(info)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
@@ -1178,7 +1188,7 @@ func (e *Engine) newWgIface() (*iface.WGIface, error) {
default: default:
} }
return iface.NewWGIFace(e.config.WgIfaceName, e.config.WgAddr, e.config.WgPort, e.config.WgPrivateKey.String(), iface.DefaultMTU, transportNet, mArgs, e.addrViaRoutes) return iface.NewWGIFace(e.config.WgIfaceName, e.config.WgAddr, e.config.WgPort, e.config.WgPrivateKey.String(), iface.DefaultMTU, transportNet, mArgs)
} }
func (e *Engine) wgInterfaceCreate() (err error) { func (e *Engine) wgInterfaceCreate() (err error) {
@@ -1228,21 +1238,18 @@ func (e *Engine) newDnsServer() ([]*route.Route, dns.Server, error) {
} }
// GetClientRoutes returns the current routes from the route map // GetClientRoutes returns the current routes from the route map
func (e *Engine) GetClientRoutes() route.HAMap { func (e *Engine) GetClientRoutes() map[string][]*route.Route {
e.clientRoutesMu.RLock() return e.clientRoutes
defer e.clientRoutesMu.RUnlock()
return maps.Clone(e.clientRoutes)
} }
// GetClientRoutesWithNetID returns the current routes from the route map, but the keys consist of the network ID only // GetClientRoutesWithNetID returns the current routes from the route map, but the keys consist of the network ID only
func (e *Engine) GetClientRoutesWithNetID() map[route.NetID][]*route.Route { func (e *Engine) GetClientRoutesWithNetID() map[string][]*route.Route {
e.clientRoutesMu.RLock() routes := make(map[string][]*route.Route, len(e.clientRoutes))
defer e.clientRoutesMu.RUnlock()
routes := make(map[route.NetID][]*route.Route, len(e.clientRoutes))
for id, v := range e.clientRoutes { for id, v := range e.clientRoutes {
routes[id.NetID()] = v if i := strings.LastIndex(id, "-"); i != -1 {
id = id[:i]
}
routes[id] = v
} }
return routes return routes
} }
@@ -1288,27 +1295,24 @@ func (e *Engine) getRosenpassAddr() string {
} }
func (e *Engine) receiveProbeEvents() { func (e *Engine) receiveProbeEvents() {
if e.probes == nil { if e.signalProbe != nil {
return go e.signalProbe.Receive(e.ctx, func() bool {
}
if e.probes.SignalProbe != nil {
go e.probes.SignalProbe.Receive(e.ctx, func() bool {
healthy := e.signal.IsHealthy() healthy := e.signal.IsHealthy()
log.Debugf("received signal probe request, healthy: %t", healthy) log.Debugf("received signal probe request, healthy: %t", healthy)
return healthy return healthy
}) })
} }
if e.probes.MgmProbe != nil { if e.mgmProbe != nil {
go e.probes.MgmProbe.Receive(e.ctx, func() bool { go e.mgmProbe.Receive(e.ctx, func() bool {
healthy := e.mgmClient.IsHealthy() healthy := e.mgmClient.IsHealthy()
log.Debugf("received management probe request, healthy: %t", healthy) log.Debugf("received management probe request, healthy: %t", healthy)
return healthy return healthy
}) })
} }
if e.probes.RelayProbe != nil { if e.relayProbe != nil {
go e.probes.RelayProbe.Receive(e.ctx, func() bool { go e.relayProbe.Receive(e.ctx, func() bool {
healthy := true healthy := true
results := append(e.probeSTUNs(), e.probeTURNs()...) results := append(e.probeSTUNs(), e.probeTURNs()...)
@@ -1327,13 +1331,13 @@ func (e *Engine) receiveProbeEvents() {
}) })
} }
if e.probes.WgProbe != nil { if e.wgProbe != nil {
go e.probes.WgProbe.Receive(e.ctx, func() bool { go e.wgProbe.Receive(e.ctx, func() bool {
log.Debug("received wg probe request") log.Debug("received wg probe request")
for _, peer := range e.peerConns { for _, peer := range e.peerConns {
key := peer.GetKey() key := peer.GetKey()
wgStats, err := peer.WgConfig().WgInterface.GetStats(key) wgStats, err := peer.GetConf().WgConfig.WgInterface.GetStats(key)
if err != nil { if err != nil {
log.Debugf("failed to get wg stats for peer %s: %s", key, err) log.Debugf("failed to get wg stats for peer %s: %s", key, err)
} }
@@ -1355,91 +1359,3 @@ func (e *Engine) probeSTUNs() []relay.ProbeResult {
func (e *Engine) probeTURNs() []relay.ProbeResult { func (e *Engine) probeTURNs() []relay.ProbeResult {
return relay.ProbeAll(e.ctx, relay.ProbeTURN, e.TURNs) return relay.ProbeAll(e.ctx, relay.ProbeTURN, e.TURNs)
} }
func (e *Engine) restartEngine() {
log.Info("restarting engine")
CtxGetState(e.ctx).Set(StatusConnecting)
if err := e.Stop(); err != nil {
log.Errorf("Failed to stop engine: %v", err)
}
_ = CtxGetState(e.ctx).Wrap(ErrResetConnection)
log.Infof("cancelling client, engine will be recreated")
e.clientCancel()
}
func (e *Engine) startNetworkMonitor() {
if !e.config.NetworkMonitor {
log.Infof("Network monitor is disabled, not starting")
return
}
e.networkMonitor = networkmonitor.New()
go func() {
var mu sync.Mutex
var debounceTimer *time.Timer
// Start the network monitor with a callback, Start will block until the monitor is stopped,
// a network change is detected, or an error occurs on start up
err := e.networkMonitor.Start(e.ctx, func() {
// This function is called when a network change is detected
mu.Lock()
defer mu.Unlock()
if debounceTimer != nil {
log.Infof("Network monitor: detected network change, reset debounceTimer")
debounceTimer.Stop()
}
// Set a new timer to debounce rapid network changes
debounceTimer = time.AfterFunc(1*time.Second, func() {
// This function is called after the debounce period
mu.Lock()
defer mu.Unlock()
log.Infof("Network monitor: detected network change, restarting engine")
e.restartEngine()
})
})
if err != nil && !errors.Is(err, networkmonitor.ErrStopped) {
log.Errorf("Network monitor: %v", err)
}
}()
}
func (e *Engine) addrViaRoutes(addr netip.Addr) (bool, netip.Prefix, error) {
var vpnRoutes []netip.Prefix
for _, routes := range e.GetClientRoutes() {
if len(routes) > 0 && routes[0] != nil {
vpnRoutes = append(vpnRoutes, routes[0].Network)
}
}
if isVpn, prefix := systemops.IsAddrRouted(addr, vpnRoutes); isVpn {
return true, prefix, nil
}
return false, netip.Prefix{}, nil
}
func (e *Engine) stopDNSServer() {
err := fmt.Errorf("DNS server stopped")
nsGroupStates := e.statusRecorder.GetDNSStates()
for i := range nsGroupStates {
nsGroupStates[i].Enabled = false
nsGroupStates[i].Error = err
}
e.statusRecorder.UpdateDNSStates(nsGroupStates)
if e.dnsServer != nil {
e.dnsServer.Stop()
e.dnsServer = nil
}
}
// isChecksEqual checks if two slices of checks are equal.
func isChecksEqual(checks []*mgmProto.Checks, oChecks []*mgmProto.Checks) bool {
return slices.EqualFunc(checks, oChecks, func(checks, oChecks *mgmProto.Checks) bool {
return slices.Equal(checks.Files, oChecks.Files)
})
}

View File

@@ -13,12 +13,10 @@ import (
"testing" "testing"
"time" "time"
"github.com/google/uuid"
"github.com/pion/transport/v3/stdnet" "github.com/pion/transport/v3/stdnet"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"go.opentelemetry.io/otel"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes" "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"google.golang.org/grpc" "google.golang.org/grpc"
"google.golang.org/grpc/keepalive" "google.golang.org/grpc/keepalive"
@@ -37,8 +35,6 @@ import (
mgmtProto "github.com/netbirdio/netbird/management/proto" mgmtProto "github.com/netbirdio/netbird/management/proto"
"github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/telemetry"
relayClient "github.com/netbirdio/netbird/relay/client"
"github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/route"
signal "github.com/netbirdio/netbird/signal/client" signal "github.com/netbirdio/netbird/signal/client"
"github.com/netbirdio/netbird/signal/proto" "github.com/netbirdio/netbird/signal/proto"
@@ -60,16 +56,10 @@ var (
} }
) )
func TestMain(m *testing.M) {
_ = util.InitLog("debug", "console")
code := m.Run()
os.Exit(code)
}
func TestEngine_SSH(t *testing.T) { func TestEngine_SSH(t *testing.T) {
// todo resolve test execution on freebsd
if runtime.GOOS == "windows" || runtime.GOOS == "freebsd" { if runtime.GOOS == "windows" {
t.Skip("skipping TestEngine_SSH") t.Skip("skipping TestEngine_SSH on Windows")
} }
key, err := wgtypes.GeneratePrivateKey() key, err := wgtypes.GeneratePrivateKey()
@@ -81,23 +71,13 @@ func TestEngine_SSH(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
defer cancel() defer cancel()
relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String()) engine := NewEngine(ctx, cancel, &signal.MockClient{}, &mgmt.MockClient{}, &EngineConfig{
engine := NewEngine( WgIfaceName: "utun101",
ctx, cancel, WgAddr: "100.64.0.1/24",
&signal.MockClient{}, WgPrivateKey: key,
&mgmt.MockClient{}, WgPort: 33100,
relayMgr, ServerSSHAllowed: true,
&EngineConfig{ }, MobileDependency{}, peer.NewRecorder("https://mgm"))
WgIfaceName: "utun101",
WgAddr: "100.64.0.1/24",
WgPrivateKey: key,
WgPort: 33100,
ServerSSHAllowed: true,
},
MobileDependency{},
peer.NewRecorder("https://mgm"),
nil,
)
engine.dnsServer = &dns.MockServer{ engine.dnsServer = &dns.MockServer{
UpdateDNSServerFunc: func(serial uint64, update nbdns.Config) error { return nil }, UpdateDNSServerFunc: func(serial uint64, update nbdns.Config) error { return nil },
@@ -193,7 +173,7 @@ func TestEngine_SSH(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
// time.Sleep(250 * time.Millisecond) //time.Sleep(250 * time.Millisecond)
assert.NotNil(t, engine.sshServer) assert.NotNil(t, engine.sshServer)
assert.Contains(t, sshPeersRemoved, "MNHf3Ma6z6mdLbriAJbqhX7+nM/B71lgw2+91q3LfhU=") assert.Contains(t, sshPeersRemoved, "MNHf3Ma6z6mdLbriAJbqhX7+nM/B71lgw2+91q3LfhU=")
@@ -226,29 +206,21 @@ func TestEngine_UpdateNetworkMap(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
defer cancel() defer cancel()
relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String()) engine := NewEngine(ctx, cancel, &signal.MockClient{}, &mgmt.MockClient{}, &EngineConfig{
engine := NewEngine( WgIfaceName: "utun102",
ctx, cancel, WgAddr: "100.64.0.1/24",
&signal.MockClient{}, WgPrivateKey: key,
&mgmt.MockClient{}, WgPort: 33100,
relayMgr, }, MobileDependency{}, peer.NewRecorder("https://mgm"))
&EngineConfig{ newNet, err := stdnet.NewNet()
WgIfaceName: "utun102", if err != nil {
WgAddr: "100.64.0.1/24", t.Fatal(err)
WgPrivateKey: key,
WgPort: 33100,
},
MobileDependency{},
peer.NewRecorder("https://mgm"),
nil)
wgIface := &iface.MockWGIface{
RemovePeerFunc: func(peerKey string) error {
return nil
},
} }
engine.wgInterface = wgIface engine.wgInterface, err = iface.NewWGIFace("utun102", "100.64.0.1/24", engine.config.WgPort, key.String(), iface.DefaultMTU, newNet, nil)
engine.routeManager = routemanager.NewManager(ctx, key.PublicKey().String(), time.Minute, engine.wgInterface, engine.statusRecorder, relayMgr, nil) if err != nil {
t.Fatal(err)
}
engine.routeManager = routemanager.NewManager(ctx, key.PublicKey().String(), engine.wgInterface, engine.statusRecorder, nil)
engine.dnsServer = &dns.MockServer{ engine.dnsServer = &dns.MockServer{
UpdateDNSServerFunc: func(serial uint64, update nbdns.Config) error { return nil }, UpdateDNSServerFunc: func(serial uint64, update nbdns.Config) error { return nil },
} }
@@ -257,7 +229,6 @@ func TestEngine_UpdateNetworkMap(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
engine.udpMux = bind.NewUniversalUDPMuxDefault(bind.UniversalUDPMuxParams{UDPConn: conn}) engine.udpMux = bind.NewUniversalUDPMuxDefault(bind.UniversalUDPMuxParams{UDPConn: conn})
engine.ctx = ctx
type testCase struct { type testCase struct {
name string name string
@@ -421,7 +392,7 @@ func TestEngine_Sync(t *testing.T) {
// feed updates to Engine via mocked Management client // feed updates to Engine via mocked Management client
updates := make(chan *mgmtProto.SyncResponse) updates := make(chan *mgmtProto.SyncResponse)
defer close(updates) defer close(updates)
syncFunc := func(ctx context.Context, info *system.Info, msgHandler func(msg *mgmtProto.SyncResponse) error) error { syncFunc := func(msgHandler func(msg *mgmtProto.SyncResponse) error) error {
for msg := range updates { for msg := range updates {
err := msgHandler(msg) err := msgHandler(msg)
if err != nil { if err != nil {
@@ -430,14 +401,13 @@ func TestEngine_Sync(t *testing.T) {
} }
return nil return nil
} }
relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String())
engine := NewEngine(ctx, cancel, &signal.MockClient{}, &mgmt.MockClient{SyncFunc: syncFunc}, relayMgr, &EngineConfig{ engine := NewEngine(ctx, cancel, &signal.MockClient{}, &mgmt.MockClient{SyncFunc: syncFunc}, &EngineConfig{
WgIfaceName: "utun103", WgIfaceName: "utun103",
WgAddr: "100.64.0.1/24", WgAddr: "100.64.0.1/24",
WgPrivateKey: key, WgPrivateKey: key,
WgPort: 33100, WgPort: 33100,
}, MobileDependency{}, peer.NewRecorder("https://mgm"), nil) }, MobileDependency{}, peer.NewRecorder("https://mgm"))
engine.ctx = ctx
engine.dnsServer = &dns.MockServer{ engine.dnsServer = &dns.MockServer{
UpdateDNSServerFunc: func(serial uint64, update nbdns.Config) error { return nil }, UpdateDNSServerFunc: func(serial uint64, update nbdns.Config) error { return nil },
@@ -590,19 +560,17 @@ func TestEngine_UpdateNetworkMapWithRoutes(t *testing.T) {
wgIfaceName := fmt.Sprintf("utun%d", 104+n) wgIfaceName := fmt.Sprintf("utun%d", 104+n)
wgAddr := fmt.Sprintf("100.66.%d.1/24", n) wgAddr := fmt.Sprintf("100.66.%d.1/24", n)
relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String()) engine := NewEngine(ctx, cancel, &signal.MockClient{}, &mgmt.MockClient{}, &EngineConfig{
engine := NewEngine(ctx, cancel, &signal.MockClient{}, &mgmt.MockClient{}, relayMgr, &EngineConfig{
WgIfaceName: wgIfaceName, WgIfaceName: wgIfaceName,
WgAddr: wgAddr, WgAddr: wgAddr,
WgPrivateKey: key, WgPrivateKey: key,
WgPort: 33100, WgPort: 33100,
}, MobileDependency{}, peer.NewRecorder("https://mgm"), nil) }, MobileDependency{}, peer.NewRecorder("https://mgm"))
engine.ctx = ctx
newNet, err := stdnet.NewNet() newNet, err := stdnet.NewNet()
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
engine.wgInterface, err = iface.NewWGIFace(wgIfaceName, wgAddr, engine.config.WgPort, key.String(), iface.DefaultMTU, newNet, nil, nil) engine.wgInterface, err = iface.NewWGIFace(wgIfaceName, wgAddr, engine.config.WgPort, key.String(), iface.DefaultMTU, newNet, nil)
assert.NoError(t, err, "shouldn't return error") assert.NoError(t, err, "shouldn't return error")
input := struct { input := struct {
inputSerial uint64 inputSerial uint64
@@ -610,7 +578,7 @@ func TestEngine_UpdateNetworkMapWithRoutes(t *testing.T) {
}{} }{}
mockRouteManager := &routemanager.MockManager{ mockRouteManager := &routemanager.MockManager{
UpdateRoutesFunc: func(updateSerial uint64, newRoutes []*route.Route) (map[route.ID]*route.Route, route.HAMap, error) { UpdateRoutesFunc: func(updateSerial uint64, newRoutes []*route.Route) (map[string]*route.Route, map[string][]*route.Route, error) {
input.inputSerial = updateSerial input.inputSerial = updateSerial
input.inputRoutes = newRoutes input.inputRoutes = newRoutes
return nil, nil, testCase.inputErr return nil, nil, testCase.inputErr
@@ -761,24 +729,21 @@ func TestEngine_UpdateNetworkMapWithDNSUpdate(t *testing.T) {
wgIfaceName := fmt.Sprintf("utun%d", 104+n) wgIfaceName := fmt.Sprintf("utun%d", 104+n)
wgAddr := fmt.Sprintf("100.66.%d.1/24", n) wgAddr := fmt.Sprintf("100.66.%d.1/24", n)
relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String()) engine := NewEngine(ctx, cancel, &signal.MockClient{}, &mgmt.MockClient{}, &EngineConfig{
engine := NewEngine(ctx, cancel, &signal.MockClient{}, &mgmt.MockClient{}, relayMgr, &EngineConfig{
WgIfaceName: wgIfaceName, WgIfaceName: wgIfaceName,
WgAddr: wgAddr, WgAddr: wgAddr,
WgPrivateKey: key, WgPrivateKey: key,
WgPort: 33100, WgPort: 33100,
}, MobileDependency{}, peer.NewRecorder("https://mgm"), nil) }, MobileDependency{}, peer.NewRecorder("https://mgm"))
engine.ctx = ctx
newNet, err := stdnet.NewNet() newNet, err := stdnet.NewNet()
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
engine.wgInterface, err = iface.NewWGIFace(wgIfaceName, wgAddr, 33100, key.String(), iface.DefaultMTU, newNet, nil, nil) engine.wgInterface, err = iface.NewWGIFace(wgIfaceName, wgAddr, 33100, key.String(), iface.DefaultMTU, newNet, nil)
assert.NoError(t, err, "shouldn't return error") assert.NoError(t, err, "shouldn't return error")
mockRouteManager := &routemanager.MockManager{ mockRouteManager := &routemanager.MockManager{
UpdateRoutesFunc: func(updateSerial uint64, newRoutes []*route.Route) (map[route.ID]*route.Route, route.HAMap, error) { UpdateRoutesFunc: func(updateSerial uint64, newRoutes []*route.Route) (map[string]*route.Route, map[string][]*route.Route, error) {
return nil, nil, nil return nil, nil, nil
}, },
} }
@@ -840,13 +805,13 @@ func TestEngine_MultiplePeers(t *testing.T) {
ctx, cancel := context.WithCancel(CtxInitState(context.Background())) ctx, cancel := context.WithCancel(CtxInitState(context.Background()))
defer cancel() defer cancel()
sigServer, signalAddr, err := startSignal(t) sigServer, signalAddr, err := startSignal()
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
return return
} }
defer sigServer.Stop() defer sigServer.Stop()
mgmtServer, mgmtAddr, err := startManagement(t, dir) mgmtServer, mgmtAddr, err := startManagement(dir)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
return return
@@ -873,8 +838,6 @@ func TestEngine_MultiplePeers(t *testing.T) {
engine.dnsServer = &dns.MockServer{} engine.dnsServer = &dns.MockServer{}
mu.Lock() mu.Lock()
defer mu.Unlock() defer mu.Unlock()
guid := fmt.Sprintf("{%s}", uuid.New().String())
iface.CustomWindowsGUIDString = strings.ToLower(guid)
err = engine.Start() err = engine.Start()
if err != nil { if err != nil {
t.Errorf("unable to start engine for peer %d with error %v", j, err) t.Errorf("unable to start engine for peer %d with error %v", j, err)
@@ -1040,15 +1003,10 @@ func createEngine(ctx context.Context, cancel context.CancelFunc, setupKey strin
WgPort: wgPort, WgPort: wgPort,
} }
relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String()) return NewEngine(ctx, cancel, signalClient, mgmtClient, conf, MobileDependency{}, peer.NewRecorder("https://mgm")), nil
e, err := NewEngine(ctx, cancel, signalClient, mgmtClient, relayMgr, conf, MobileDependency{}, peer.NewRecorder("https://mgm"), nil), nil
e.ctx = ctx
return e, err
} }
func startSignal(t *testing.T) (*grpc.Server, string, error) { func startSignal() (*grpc.Server, string, error) {
t.Helper()
s := grpc.NewServer(grpc.KeepaliveEnforcementPolicy(kaep), grpc.KeepaliveParams(kasp)) s := grpc.NewServer(grpc.KeepaliveEnforcementPolicy(kaep), grpc.KeepaliveParams(kasp))
lis, err := net.Listen("tcp", "localhost:0") lis, err := net.Listen("tcp", "localhost:0")
@@ -1056,9 +1014,7 @@ func startSignal(t *testing.T) (*grpc.Server, string, error) {
log.Fatalf("failed to listen: %v", err) log.Fatalf("failed to listen: %v", err)
} }
srv, err := signalServer.NewServer(otel.Meter("")) proto.RegisterSignalExchangeServer(s, signalServer.NewServer())
require.NoError(t, err)
proto.RegisterSignalExchangeServer(s, srv)
go func() { go func() {
if err = s.Serve(lis); err != nil { if err = s.Serve(lis); err != nil {
@@ -1069,17 +1025,10 @@ func startSignal(t *testing.T) (*grpc.Server, string, error) {
return s, lis.Addr().String(), nil return s, lis.Addr().String(), nil
} }
func startManagement(t *testing.T, dataDir string) (*grpc.Server, string, error) { func startManagement(dataDir string) (*grpc.Server, string, error) {
t.Helper()
config := &server.Config{ config := &server.Config{
Stuns: []*server.Host{}, Stuns: []*server.Host{},
TURNConfig: &server.TURNConfig{}, TURNConfig: &server.TURNConfig{},
Relay: &server.Relay{
Addresses: []string{"127.0.0.1:1234"},
CredentialsTTL: util.Duration{Duration: time.Hour},
Secret: "222222222222222222",
},
Signal: &server.Host{ Signal: &server.Host{
Proto: "http", Proto: "http",
URI: "localhost:10000", URI: "localhost:10000",
@@ -1093,30 +1042,23 @@ func startManagement(t *testing.T, dataDir string) (*grpc.Server, string, error)
return nil, "", err return nil, "", err
} }
s := grpc.NewServer(grpc.KeepaliveEnforcementPolicy(kaep), grpc.KeepaliveParams(kasp)) s := grpc.NewServer(grpc.KeepaliveEnforcementPolicy(kaep), grpc.KeepaliveParams(kasp))
store, err := server.NewStoreFromJson(config.Datadir, nil)
store, cleanUp, err := server.NewTestStoreFromJson(context.Background(), config.Datadir)
if err != nil { if err != nil {
return nil, "", err return nil, "", err
} }
t.Cleanup(cleanUp)
peersUpdateManager := server.NewPeersUpdateManager(nil) peersUpdateManager := server.NewPeersUpdateManager(nil)
eventStore := &activity.InMemoryEventStore{} eventStore := &activity.InMemoryEventStore{}
if err != nil { if err != nil {
return nil, "", err return nil, "", err
} }
ia, _ := integrations.NewIntegratedValidator(context.Background(), eventStore) ia, _ := integrations.NewIntegratedValidator(eventStore)
accountManager, err := server.BuildManager(store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia)
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
require.NoError(t, err)
accountManager, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia, metrics)
if err != nil { if err != nil {
return nil, "", err return nil, "", err
} }
turnManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig)
secretsManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay) mgmtServer, err := server.NewServer(config, accountManager, peersUpdateManager, turnManager, nil, nil)
mgmtServer, err := server.NewServer(context.Background(), config, accountManager, peersUpdateManager, secretsManager, nil, nil)
if err != nil { if err != nil {
return nil, "", err return nil, "", err
} }

View File

@@ -0,0 +1 @@
package internal

View File

@@ -68,7 +68,7 @@ func Login(ctx context.Context, config *Config, setupKey string, jwtToken string
} }
serverKey, err := doMgmLogin(ctx, mgmClient, pubSSHKey) serverKey, err := doMgmLogin(ctx, mgmClient, pubSSHKey)
if serverKey != nil && isRegistrationNeeded(err) { if isRegistrationNeeded(err) {
log.Debugf("peer registration required") log.Debugf("peer registration required")
_, err = registerPeer(ctx, *serverKey, mgmClient, setupKey, jwtToken, pubSSHKey) _, err = registerPeer(ctx, *serverKey, mgmClient, setupKey, jwtToken, pubSSHKey)
return err return err

View File

@@ -1,21 +0,0 @@
package networkmonitor
import (
"context"
"errors"
"sync"
)
var ErrStopped = errors.New("monitor has been stopped")
// NetworkMonitor watches for changes in network configuration.
type NetworkMonitor struct {
cancel context.CancelFunc
wg sync.WaitGroup
mu sync.Mutex
}
// New creates a new network monitor.
func New() *NetworkMonitor {
return &NetworkMonitor{}
}

View File

@@ -1,107 +0,0 @@
//go:build (darwin && !ios) || dragonfly || freebsd || netbsd || openbsd
package networkmonitor
import (
"context"
"errors"
"fmt"
"syscall"
"unsafe"
log "github.com/sirupsen/logrus"
"golang.org/x/net/route"
"golang.org/x/sys/unix"
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
)
func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop, callback func()) error {
fd, err := unix.Socket(syscall.AF_ROUTE, syscall.SOCK_RAW, syscall.AF_UNSPEC)
if err != nil {
return fmt.Errorf("failed to open routing socket: %v", err)
}
defer func() {
err := unix.Close(fd)
if err != nil && !errors.Is(err, unix.EBADF) {
log.Errorf("Network monitor: failed to close routing socket: %v", err)
}
}()
go func() {
<-ctx.Done()
err := unix.Close(fd)
if err != nil && !errors.Is(err, unix.EBADF) {
log.Debugf("Network monitor: closed routing socket")
}
}()
for {
select {
case <-ctx.Done():
return ErrStopped
default:
buf := make([]byte, 2048)
n, err := unix.Read(fd, buf)
if err != nil {
if !errors.Is(err, unix.EBADF) && !errors.Is(err, unix.EINVAL) {
log.Errorf("Network monitor: failed to read from routing socket: %v", err)
}
continue
}
if n < unix.SizeofRtMsghdr {
log.Errorf("Network monitor: read from routing socket returned less than expected: %d bytes", n)
continue
}
msg := (*unix.RtMsghdr)(unsafe.Pointer(&buf[0]))
switch msg.Type {
// handle route changes
case unix.RTM_ADD, syscall.RTM_DELETE:
route, err := parseRouteMessage(buf[:n])
if err != nil {
log.Errorf("Network monitor: error parsing routing message: %v", err)
continue
}
if route.Dst.Bits() != 0 {
continue
}
intf := "<nil>"
if route.Interface != nil {
intf = route.Interface.Name
}
switch msg.Type {
case unix.RTM_ADD:
log.Infof("Network monitor: default route changed: via %s, interface %s", route.Gw, intf)
go callback()
case unix.RTM_DELETE:
if nexthopv4.Intf != nil && route.Gw.Compare(nexthopv4.IP) == 0 || nexthopv6.Intf != nil && route.Gw.Compare(nexthopv6.IP) == 0 {
log.Infof("Network monitor: default route removed: via %s, interface %s", route.Gw, intf)
go callback()
}
}
}
}
}
}
func parseRouteMessage(buf []byte) (*systemops.Route, error) {
msgs, err := route.ParseRIB(route.RIBTypeRoute, buf)
if err != nil {
return nil, fmt.Errorf("parse RIB: %v", err)
}
if len(msgs) != 1 {
return nil, fmt.Errorf("unexpected RIB message msgs: %v", msgs)
}
msg, ok := msgs[0].(*route.RouteMessage)
if !ok {
return nil, fmt.Errorf("unexpected RIB message type: %T", msgs[0])
}
return systemops.MsgToRoute(msg)
}

View File

@@ -1,82 +0,0 @@
//go:build !ios && !android
package networkmonitor
import (
"context"
"errors"
"fmt"
"net/netip"
"runtime/debug"
"github.com/cenkalti/backoff/v4"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
)
// Start begins monitoring network changes. When a change is detected, it calls the callback asynchronously and returns.
func (nw *NetworkMonitor) Start(ctx context.Context, callback func()) (err error) {
if ctx.Err() != nil {
return ctx.Err()
}
nw.mu.Lock()
ctx, nw.cancel = context.WithCancel(ctx)
nw.mu.Unlock()
nw.wg.Add(1)
defer nw.wg.Done()
var nexthop4, nexthop6 systemops.Nexthop
operation := func() error {
var errv4, errv6 error
nexthop4, errv4 = systemops.GetNextHop(netip.IPv4Unspecified())
nexthop6, errv6 = systemops.GetNextHop(netip.IPv6Unspecified())
if errv4 != nil && errv6 != nil {
return errors.New("failed to get default next hops")
}
if errv4 == nil {
log.Debugf("Network monitor: IPv4 default route: %s, interface: %s", nexthop4.IP, nexthop4.Intf.Name)
}
if errv6 == nil {
log.Debugf("Network monitor: IPv6 default route: %s, interface: %s", nexthop6.IP, nexthop6.Intf.Name)
}
// continue if either route was found
return nil
}
expBackOff := backoff.WithContext(backoff.NewExponentialBackOff(), ctx)
if err := backoff.Retry(operation, expBackOff); err != nil {
return fmt.Errorf("failed to get default next hops: %w", err)
}
// recover in case sys ops panic
defer func() {
if r := recover(); r != nil {
err = fmt.Errorf("panic occurred: %v, stack trace: %s", r, debug.Stack())
}
}()
if err := checkChange(ctx, nexthop4, nexthop6, callback); err != nil {
return fmt.Errorf("check change: %w", err)
}
return nil
}
// Stop stops the network monitor.
func (nw *NetworkMonitor) Stop() {
nw.mu.Lock()
defer nw.mu.Unlock()
if nw.cancel != nil {
nw.cancel()
nw.wg.Wait()
}
}

View File

@@ -1,57 +0,0 @@
//go:build !android
package networkmonitor
import (
"context"
"errors"
"fmt"
"syscall"
log "github.com/sirupsen/logrus"
"github.com/vishvananda/netlink"
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
)
func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop, callback func()) error {
if nexthopv4.Intf == nil && nexthopv6.Intf == nil {
return errors.New("no interfaces available")
}
done := make(chan struct{})
defer close(done)
routeChan := make(chan netlink.RouteUpdate)
if err := netlink.RouteSubscribe(routeChan, done); err != nil {
return fmt.Errorf("subscribe to route updates: %v", err)
}
log.Info("Network monitor: started")
for {
select {
case <-ctx.Done():
return ErrStopped
// handle route changes
case route := <-routeChan:
// default route and main table
if route.Dst != nil || route.Table != syscall.RT_TABLE_MAIN {
continue
}
switch route.Type {
// triggered on added/replaced routes
case syscall.RTM_NEWROUTE:
log.Infof("Network monitor: default route changed: via %s, interface %d", route.Gw, route.LinkIndex)
go callback()
return nil
case syscall.RTM_DELROUTE:
if nexthopv4.Intf != nil && route.Gw.Equal(nexthopv4.IP.AsSlice()) || nexthopv6.Intf != nil && route.Gw.Equal(nexthopv6.IP.AsSlice()) {
log.Infof("Network monitor: default route removed: via %s, interface %d", route.Gw, route.LinkIndex)
go callback()
return nil
}
}
}
}
}

View File

@@ -1,12 +0,0 @@
//go:build ios || android
package networkmonitor
import "context"
func (nw *NetworkMonitor) Start(context.Context, func()) error {
return nil
}
func (nw *NetworkMonitor) Stop() {
}

View File

@@ -1,75 +0,0 @@
package networkmonitor
import (
"context"
"fmt"
"strings"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
)
func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop, callback func()) error {
routeMonitor, err := systemops.NewRouteMonitor(ctx)
if err != nil {
return fmt.Errorf("failed to create route monitor: %w", err)
}
defer func() {
if err := routeMonitor.Stop(); err != nil {
log.Errorf("Network monitor: failed to stop route monitor: %v", err)
}
}()
for {
select {
case <-ctx.Done():
return ErrStopped
case route := <-routeMonitor.RouteUpdates():
if route.Destination.Bits() != 0 {
continue
}
if routeChanged(route, nexthopv4, nexthopv6, callback) {
break
}
}
}
}
func routeChanged(route systemops.RouteUpdate, nexthopv4, nexthopv6 systemops.Nexthop, callback func()) bool {
intf := "<nil>"
if route.Interface != nil {
intf = route.Interface.Name
if isSoftInterface(intf) {
log.Debugf("Network monitor: ignoring default route change for soft interface %s", intf)
return false
}
}
switch route.Type {
case systemops.RouteModified:
// TODO: get routing table to figure out if our route is affected for modified routes
log.Infof("Network monitor: default route changed: via %s, interface %s", route.NextHop, intf)
go callback()
return true
case systemops.RouteAdded:
if route.NextHop.Is4() && route.NextHop != nexthopv4.IP || route.NextHop.Is6() && route.NextHop != nexthopv6.IP {
log.Infof("Network monitor: default route added: via %s, interface %s", route.NextHop, intf)
go callback()
return true
}
case systemops.RouteDeleted:
if nexthopv4.Intf != nil && route.NextHop == nexthopv4.IP || nexthopv6.Intf != nil && route.NextHop == nexthopv6.IP {
log.Infof("Network monitor: default route removed: via %s, interface %s", route.NextHop, intf)
go callback()
return true
}
}
return false
}
func isSoftInterface(name string) bool {
return strings.Contains(strings.ToLower(name), "isatap") || strings.Contains(strings.ToLower(name), "teredo")
}

File diff suppressed because it is too large Load Diff

View File

@@ -1,10 +1,6 @@
package peer package peer
import ( import log "github.com/sirupsen/logrus"
"sync/atomic"
log "github.com/sirupsen/logrus"
)
const ( const (
// StatusConnected indicate the peer is in connected state // StatusConnected indicate the peer is in connected state
@@ -16,34 +12,7 @@ const (
) )
// ConnStatus describe the status of a peer's connection // ConnStatus describe the status of a peer's connection
type ConnStatus int32 type ConnStatus int
// AtomicConnStatus is a thread-safe wrapper for ConnStatus
type AtomicConnStatus struct {
status atomic.Int32
}
// NewAtomicConnStatus creates a new AtomicConnStatus with the given initial status
func NewAtomicConnStatus() *AtomicConnStatus {
acs := &AtomicConnStatus{}
acs.Set(StatusDisconnected)
return acs
}
// Get returns the current connection status
func (acs *AtomicConnStatus) Get() ConnStatus {
return ConnStatus(acs.status.Load())
}
// Set updates the connection status
func (acs *AtomicConnStatus) Set(status ConnStatus) {
acs.status.Store(int32(status))
}
// String returns the string representation of the current status
func (acs *AtomicConnStatus) String() string {
return acs.Get().String()
}
func (s ConnStatus) String() string { func (s ConnStatus) String() string {
switch s { switch s {

View File

@@ -1,34 +1,25 @@
package peer package peer
import ( import (
"context"
"os"
"sync" "sync"
"testing" "testing"
"time" "time"
"github.com/magiconair/properties/assert" "github.com/magiconair/properties/assert"
"github.com/pion/stun/v2"
"github.com/netbirdio/netbird/client/internal/stdnet" "github.com/netbirdio/netbird/client/internal/stdnet"
"github.com/netbirdio/netbird/client/internal/wgproxy" "github.com/netbirdio/netbird/client/internal/wgproxy"
"github.com/netbirdio/netbird/iface" "github.com/netbirdio/netbird/iface"
"github.com/netbirdio/netbird/util"
) )
var connConf = ConnConfig{ var connConf = ConnConfig{
Key: "LLHf3Ma6z6mdLbriAJbqhX7+nM/B71lgw2+91q3LfhU=", Key: "LLHf3Ma6z6mdLbriAJbqhX7+nM/B71lgw2+91q3LfhU=",
LocalKey: "RRHf3Ma6z6mdLbriAJbqhX7+nM/B71lgw2+91q3LfhU=", LocalKey: "RRHf3Ma6z6mdLbriAJbqhX7+nM/B71lgw2+91q3LfhU=",
Timeout: time.Second, StunTurn: []*stun.URI{},
LocalWgPort: 51820, InterfaceBlackList: nil,
ICEConfig: ICEConfig{ Timeout: time.Second,
InterfaceBlackList: nil, LocalWgPort: 51820,
},
}
func TestMain(m *testing.M) {
_ = util.InitLog("trace", "console")
code := m.Run()
os.Exit(code)
} }
func TestNewConn_interfaceFilter(t *testing.T) { func TestNewConn_interfaceFilter(t *testing.T) {
@@ -44,11 +35,11 @@ func TestNewConn_interfaceFilter(t *testing.T) {
} }
func TestConn_GetKey(t *testing.T) { func TestConn_GetKey(t *testing.T) {
wgProxyFactory := wgproxy.NewFactory(context.Background(), false, connConf.LocalWgPort) wgProxyFactory := wgproxy.NewFactory(connConf.LocalWgPort)
defer func() { defer func() {
_ = wgProxyFactory.Free() _ = wgProxyFactory.Free()
}() }()
conn, err := NewConn(context.Background(), connConf, nil, wgProxyFactory, nil, nil, nil) conn, err := NewConn(connConf, nil, wgProxyFactory, nil, nil)
if err != nil { if err != nil {
return return
} }
@@ -59,11 +50,11 @@ func TestConn_GetKey(t *testing.T) {
} }
func TestConn_OnRemoteOffer(t *testing.T) { func TestConn_OnRemoteOffer(t *testing.T) {
wgProxyFactory := wgproxy.NewFactory(context.Background(), false, connConf.LocalWgPort) wgProxyFactory := wgproxy.NewFactory(connConf.LocalWgPort)
defer func() { defer func() {
_ = wgProxyFactory.Free() _ = wgProxyFactory.Free()
}() }()
conn, err := NewConn(context.Background(), connConf, NewRecorder("https://mgm"), wgProxyFactory, nil, nil, nil) conn, err := NewConn(connConf, NewRecorder("https://mgm"), wgProxyFactory, nil, nil)
if err != nil { if err != nil {
return return
} }
@@ -71,7 +62,7 @@ func TestConn_OnRemoteOffer(t *testing.T) {
wg := sync.WaitGroup{} wg := sync.WaitGroup{}
wg.Add(2) wg.Add(2)
go func() { go func() {
<-conn.handshaker.remoteOffersCh <-conn.remoteOffersCh
wg.Done() wg.Done()
}() }()
@@ -96,11 +87,11 @@ func TestConn_OnRemoteOffer(t *testing.T) {
} }
func TestConn_OnRemoteAnswer(t *testing.T) { func TestConn_OnRemoteAnswer(t *testing.T) {
wgProxyFactory := wgproxy.NewFactory(context.Background(), false, connConf.LocalWgPort) wgProxyFactory := wgproxy.NewFactory(connConf.LocalWgPort)
defer func() { defer func() {
_ = wgProxyFactory.Free() _ = wgProxyFactory.Free()
}() }()
conn, err := NewConn(context.Background(), connConf, NewRecorder("https://mgm"), wgProxyFactory, nil, nil, nil) conn, err := NewConn(connConf, NewRecorder("https://mgm"), wgProxyFactory, nil, nil)
if err != nil { if err != nil {
return return
} }
@@ -108,7 +99,7 @@ func TestConn_OnRemoteAnswer(t *testing.T) {
wg := sync.WaitGroup{} wg := sync.WaitGroup{}
wg.Add(2) wg.Add(2)
go func() { go func() {
<-conn.handshaker.remoteAnswerCh <-conn.remoteAnswerCh
wg.Done() wg.Done()
}() }()
@@ -132,42 +123,62 @@ func TestConn_OnRemoteAnswer(t *testing.T) {
wg.Wait() wg.Wait()
} }
func TestConn_Status(t *testing.T) { func TestConn_Status(t *testing.T) {
wgProxyFactory := wgproxy.NewFactory(context.Background(), false, connConf.LocalWgPort) wgProxyFactory := wgproxy.NewFactory(connConf.LocalWgPort)
defer func() { defer func() {
_ = wgProxyFactory.Free() _ = wgProxyFactory.Free()
}() }()
conn, err := NewConn(context.Background(), connConf, NewRecorder("https://mgm"), wgProxyFactory, nil, nil, nil) conn, err := NewConn(connConf, NewRecorder("https://mgm"), wgProxyFactory, nil, nil)
if err != nil { if err != nil {
return return
} }
tables := []struct { tables := []struct {
name string name string
statusIce ConnStatus status ConnStatus
statusRelay ConnStatus want ConnStatus
want ConnStatus
}{ }{
{"StatusConnected", StatusConnected, StatusConnected, StatusConnected}, {"StatusConnected", StatusConnected, StatusConnected},
{"StatusDisconnected", StatusDisconnected, StatusDisconnected, StatusDisconnected}, {"StatusDisconnected", StatusDisconnected, StatusDisconnected},
{"StatusConnecting", StatusConnecting, StatusConnecting, StatusConnecting}, {"StatusConnecting", StatusConnecting, StatusConnecting},
{"StatusConnectingIce", StatusConnecting, StatusDisconnected, StatusConnecting},
{"StatusConnectingIceAlternative", StatusConnecting, StatusConnected, StatusConnected},
{"StatusConnectingRelay", StatusDisconnected, StatusConnecting, StatusConnecting},
{"StatusConnectingRelayAlternative", StatusConnected, StatusConnecting, StatusConnected},
} }
for _, table := range tables { for _, table := range tables {
t.Run(table.name, func(t *testing.T) { t.Run(table.name, func(t *testing.T) {
si := NewAtomicConnStatus() conn.status = table.status
si.Set(table.statusIce)
conn.statusICE = si
sr := NewAtomicConnStatus()
sr.Set(table.statusRelay)
conn.statusRelay = sr
got := conn.Status() got := conn.Status()
assert.Equal(t, got, table.want, "they should be equal") assert.Equal(t, got, table.want, "they should be equal")
}) })
} }
} }
func TestConn_Close(t *testing.T) {
wgProxyFactory := wgproxy.NewFactory(connConf.LocalWgPort)
defer func() {
_ = wgProxyFactory.Free()
}()
conn, err := NewConn(connConf, NewRecorder("https://mgm"), wgProxyFactory, nil, nil)
if err != nil {
return
}
wg := sync.WaitGroup{}
wg.Add(1)
go func() {
<-conn.closeCh
wg.Done()
}()
go func() {
for {
err := conn.Close()
if err != nil {
continue
} else {
return
}
}
}()
wg.Wait()
}

View File

@@ -1,192 +0,0 @@
package peer
import (
"context"
"errors"
"sync"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/version"
)
var (
ErrSignalIsNotReady = errors.New("signal is not ready")
)
// IceCredentials ICE protocol credentials struct
type IceCredentials struct {
UFrag string
Pwd string
}
// OfferAnswer represents a session establishment offer or answer
type OfferAnswer struct {
IceCredentials IceCredentials
// WgListenPort is a remote WireGuard listen port.
// This field is used when establishing a direct WireGuard connection without any proxy.
// We can set the remote peer's endpoint with this port.
WgListenPort int
// Version of NetBird Agent
Version string
// RosenpassPubKey is the Rosenpass public key of the remote peer when receiving this message
// This value is the local Rosenpass server public key when sending the message
RosenpassPubKey []byte
// RosenpassAddr is the Rosenpass server address (IP:port) of the remote peer when receiving this message
// This value is the local Rosenpass server address when sending the message
RosenpassAddr string
// relay server address
RelaySrvAddress string
}
type Handshaker struct {
mu sync.Mutex
ctx context.Context
log *log.Entry
config ConnConfig
signaler *Signaler
ice *WorkerICE
relay *WorkerRelay
onNewOfferListeners []func(*OfferAnswer)
// remoteOffersCh is a channel used to wait for remote credentials to proceed with the connection
remoteOffersCh chan OfferAnswer
// remoteAnswerCh is a channel used to wait for remote credentials answer (confirmation of our offer) to proceed with the connection
remoteAnswerCh chan OfferAnswer
}
func NewHandshaker(ctx context.Context, log *log.Entry, config ConnConfig, signaler *Signaler, ice *WorkerICE, relay *WorkerRelay) *Handshaker {
return &Handshaker{
ctx: ctx,
log: log,
config: config,
signaler: signaler,
ice: ice,
relay: relay,
remoteOffersCh: make(chan OfferAnswer),
remoteAnswerCh: make(chan OfferAnswer),
}
}
func (h *Handshaker) AddOnNewOfferListener(offer func(remoteOfferAnswer *OfferAnswer)) {
h.onNewOfferListeners = append(h.onNewOfferListeners, offer)
}
func (h *Handshaker) Listen() {
for {
h.log.Debugf("wait for remote offer confirmation")
remoteOfferAnswer, err := h.waitForRemoteOfferConfirmation()
if err != nil {
var connectionClosedError *ConnectionClosedError
if errors.As(err, &connectionClosedError) {
h.log.Tracef("stop handshaker")
return
}
h.log.Errorf("failed to received remote offer confirmation: %s", err)
continue
}
h.log.Debugf("received connection confirmation, running version %s and with remote WireGuard listen port %d", remoteOfferAnswer.Version, remoteOfferAnswer.WgListenPort)
for _, listener := range h.onNewOfferListeners {
go listener(remoteOfferAnswer)
}
}
}
func (h *Handshaker) SendOffer() error {
h.mu.Lock()
defer h.mu.Unlock()
return h.sendOffer()
}
// OnRemoteOffer handles an offer from the remote peer and returns true if the message was accepted, false otherwise
// doesn't block, discards the message if connection wasn't ready
func (h *Handshaker) OnRemoteOffer(offer OfferAnswer) bool {
select {
case h.remoteOffersCh <- offer:
return true
default:
h.log.Debugf("OnRemoteOffer skipping message because is not ready")
// connection might not be ready yet to receive so we ignore the message
return false
}
}
// OnRemoteAnswer handles an offer from the remote peer and returns true if the message was accepted, false otherwise
// doesn't block, discards the message if connection wasn't ready
func (h *Handshaker) OnRemoteAnswer(answer OfferAnswer) bool {
select {
case h.remoteAnswerCh <- answer:
return true
default:
// connection might not be ready yet to receive so we ignore the message
h.log.Debugf("OnRemoteAnswer skipping message because is not ready")
return false
}
}
func (h *Handshaker) waitForRemoteOfferConfirmation() (*OfferAnswer, error) {
select {
case remoteOfferAnswer := <-h.remoteOffersCh:
// received confirmation from the remote peer -> ready to proceed
err := h.sendAnswer()
if err != nil {
return nil, err
}
return &remoteOfferAnswer, nil
case remoteOfferAnswer := <-h.remoteAnswerCh:
return &remoteOfferAnswer, nil
case <-h.ctx.Done():
// closed externally
return nil, NewConnectionClosedError(h.config.Key)
}
}
// sendOffer prepares local user credentials and signals them to the remote peer
func (h *Handshaker) sendOffer() error {
if !h.signaler.Ready() {
return ErrSignalIsNotReady
}
iceUFrag, icePwd := h.ice.GetLocalUserCredentials()
offer := OfferAnswer{
IceCredentials: IceCredentials{iceUFrag, icePwd},
WgListenPort: h.config.LocalWgPort,
Version: version.NetbirdVersion(),
RosenpassPubKey: h.config.RosenpassPubKey,
RosenpassAddr: h.config.RosenpassAddr,
}
addr, err := h.relay.RelayInstanceAddress()
if err == nil {
offer.RelaySrvAddress = addr
}
return h.signaler.SignalOffer(offer, h.config.Key)
}
func (h *Handshaker) sendAnswer() error {
h.log.Debugf("sending answer")
uFrag, pwd := h.ice.GetLocalUserCredentials()
answer := OfferAnswer{
IceCredentials: IceCredentials{uFrag, pwd},
WgListenPort: h.config.LocalWgPort,
Version: version.NetbirdVersion(),
RosenpassPubKey: h.config.RosenpassPubKey,
RosenpassAddr: h.config.RosenpassAddr,
}
addr, err := h.relay.RelayInstanceAddress()
if err == nil {
answer.RelaySrvAddress = addr
}
err = h.signaler.SignalAnswer(answer, h.config.Key)
if err != nil {
return err
}
return nil
}

View File

@@ -1,70 +0,0 @@
package peer
import (
"github.com/pion/ice/v3"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
signal "github.com/netbirdio/netbird/signal/client"
sProto "github.com/netbirdio/netbird/signal/proto"
)
type Signaler struct {
signal signal.Client
wgPrivateKey wgtypes.Key
}
func NewSignaler(signal signal.Client, wgPrivateKey wgtypes.Key) *Signaler {
return &Signaler{
signal: signal,
wgPrivateKey: wgPrivateKey,
}
}
func (s *Signaler) SignalOffer(offer OfferAnswer, remoteKey string) error {
return s.signalOfferAnswer(offer, remoteKey, sProto.Body_OFFER)
}
func (s *Signaler) SignalAnswer(offer OfferAnswer, remoteKey string) error {
return s.signalOfferAnswer(offer, remoteKey, sProto.Body_ANSWER)
}
func (s *Signaler) SignalICECandidate(candidate ice.Candidate, remoteKey string) error {
return s.signal.Send(&sProto.Message{
Key: s.wgPrivateKey.PublicKey().String(),
RemoteKey: remoteKey,
Body: &sProto.Body{
Type: sProto.Body_CANDIDATE,
Payload: candidate.Marshal(),
},
})
}
func (s *Signaler) Ready() bool {
return s.signal.Ready()
}
// SignalOfferAnswer signals either an offer or an answer to remote peer
func (s *Signaler) signalOfferAnswer(offerAnswer OfferAnswer, remoteKey string, bodyType sProto.Body_Type) error {
msg, err := signal.MarshalCredential(
s.wgPrivateKey,
offerAnswer.WgListenPort,
remoteKey,
&signal.Credential{
UFrag: offerAnswer.IceCredentials.UFrag,
Pwd: offerAnswer.IceCredentials.Pwd,
},
bodyType,
offerAnswer.RosenpassPubKey,
offerAnswer.RosenpassAddr,
offerAnswer.RelaySrvAddress)
if err != nil {
return err
}
err = s.signal.Send(msg)
if err != nil {
return err
}
return nil
}

View File

@@ -2,19 +2,14 @@ package peer
import ( import (
"errors" "errors"
"net/netip"
"slices"
"sync" "sync"
"time" "time"
"golang.org/x/exp/maps"
"google.golang.org/grpc/codes" "google.golang.org/grpc/codes"
gstatus "google.golang.org/grpc/status" gstatus "google.golang.org/grpc/status"
"github.com/netbirdio/netbird/client/internal/relay" "github.com/netbirdio/netbird/client/internal/relay"
"github.com/netbirdio/netbird/iface" "github.com/netbirdio/netbird/iface"
"github.com/netbirdio/netbird/management/domain"
relayClient "github.com/netbirdio/netbird/relay/client"
) )
// State contains the latest state of a peer // State contains the latest state of a peer
@@ -26,11 +21,11 @@ type State struct {
ConnStatus ConnStatus ConnStatus ConnStatus
ConnStatusUpdate time.Time ConnStatusUpdate time.Time
Relayed bool Relayed bool
Direct bool
LocalIceCandidateType string LocalIceCandidateType string
RemoteIceCandidateType string RemoteIceCandidateType string
LocalIceCandidateEndpoint string LocalIceCandidateEndpoint string
RemoteIceCandidateEndpoint string RemoteIceCandidateEndpoint string
RelayServerAddress string
LastWireguardHandshake time.Time LastWireguardHandshake time.Time
BytesTx int64 BytesTx int64
BytesRx int64 BytesRx int64
@@ -42,25 +37,25 @@ type State struct {
// AddRoute add a single route to routes map // AddRoute add a single route to routes map
func (s *State) AddRoute(network string) { func (s *State) AddRoute(network string) {
s.Mux.Lock() s.Mux.Lock()
defer s.Mux.Unlock()
if s.routes == nil { if s.routes == nil {
s.routes = make(map[string]struct{}) s.routes = make(map[string]struct{})
} }
s.routes[network] = struct{}{} s.routes[network] = struct{}{}
s.Mux.Unlock()
} }
// SetRoutes set state routes // SetRoutes set state routes
func (s *State) SetRoutes(routes map[string]struct{}) { func (s *State) SetRoutes(routes map[string]struct{}) {
s.Mux.Lock() s.Mux.Lock()
defer s.Mux.Unlock()
s.routes = routes s.routes = routes
s.Mux.Unlock()
} }
// DeleteRoute removes a route from the network amp // DeleteRoute removes a route from the network amp
func (s *State) DeleteRoute(network string) { func (s *State) DeleteRoute(network string) {
s.Mux.Lock() s.Mux.Lock()
defer s.Mux.Unlock()
delete(s.routes, network) delete(s.routes, network)
s.Mux.Unlock()
} }
// GetRoutes return routes map // GetRoutes return routes map
@@ -122,50 +117,40 @@ type FullStatus struct {
// Status holds a state of peers, signal, management connections and relays // Status holds a state of peers, signal, management connections and relays
type Status struct { type Status struct {
mux sync.Mutex mux sync.Mutex
peers map[string]State peers map[string]State
changeNotify map[string]chan struct{} changeNotify map[string]chan struct{}
signalState bool signalState bool
signalError error signalError error
managementState bool managementState bool
managementError error managementError error
relayStates []relay.ProbeResult relayStates []relay.ProbeResult
localPeer LocalPeerState localPeer LocalPeerState
offlinePeers []State offlinePeers []State
mgmAddress string mgmAddress string
signalAddress string signalAddress string
notifier *notifier notifier *notifier
rosenpassEnabled bool rosenpassEnabled bool
rosenpassPermissive bool rosenpassPermissive bool
nsGroupStates []NSGroupState nsGroupStates []NSGroupState
resolvedDomainsStates map[domain.Domain][]netip.Prefix
// To reduce the number of notification invocation this bool will be true when need to call the notification // To reduce the number of notification invocation this bool will be true when need to call the notification
// Some Peer actions mostly used by in a batch when the network map has been synchronized. In these type of events // Some Peer actions mostly used by in a batch when the network map has been synchronized. In these type of events
// set to true this variable and at the end of the processing we will reset it by the FinishPeerListModifications() // set to true this variable and at the end of the processing we will reset it by the FinishPeerListModifications()
peerListChangedForNotification bool peerListChangedForNotification bool
relayMgr *relayClient.Manager
} }
// NewRecorder returns a new Status instance // NewRecorder returns a new Status instance
func NewRecorder(mgmAddress string) *Status { func NewRecorder(mgmAddress string) *Status {
return &Status{ return &Status{
peers: make(map[string]State), peers: make(map[string]State),
changeNotify: make(map[string]chan struct{}), changeNotify: make(map[string]chan struct{}),
offlinePeers: make([]State, 0), offlinePeers: make([]State, 0),
notifier: newNotifier(), notifier: newNotifier(),
mgmAddress: mgmAddress, mgmAddress: mgmAddress,
resolvedDomainsStates: make(map[domain.Domain][]netip.Prefix),
} }
} }
func (d *Status) SetRelayMgr(manager *relayClient.Manager) {
d.mux.Lock()
defer d.mux.Unlock()
d.relayMgr = manager
}
// ReplaceOfflinePeers replaces // ReplaceOfflinePeers replaces
func (d *Status) ReplaceOfflinePeers(replacement []State) { func (d *Status) ReplaceOfflinePeers(replacement []State) {
d.mux.Lock() d.mux.Lock()
@@ -203,7 +188,7 @@ func (d *Status) GetPeer(peerPubKey string) (State, error) {
state, ok := d.peers[peerPubKey] state, ok := d.peers[peerPubKey]
if !ok { if !ok {
return State{}, iface.ErrPeerNotFound return State{}, errors.New("peer not found")
} }
return state, nil return state, nil
} }
@@ -241,17 +226,17 @@ func (d *Status) UpdatePeerState(receivedState State) error {
peerState.SetRoutes(receivedState.GetRoutes()) peerState.SetRoutes(receivedState.GetRoutes())
} }
skipNotification := shouldSkipNotify(receivedState.ConnStatus, peerState) skipNotification := shouldSkipNotify(receivedState, peerState)
if receivedState.ConnStatus != peerState.ConnStatus { if receivedState.ConnStatus != peerState.ConnStatus {
peerState.ConnStatus = receivedState.ConnStatus peerState.ConnStatus = receivedState.ConnStatus
peerState.ConnStatusUpdate = receivedState.ConnStatusUpdate peerState.ConnStatusUpdate = receivedState.ConnStatusUpdate
peerState.Direct = receivedState.Direct
peerState.Relayed = receivedState.Relayed peerState.Relayed = receivedState.Relayed
peerState.LocalIceCandidateType = receivedState.LocalIceCandidateType peerState.LocalIceCandidateType = receivedState.LocalIceCandidateType
peerState.RemoteIceCandidateType = receivedState.RemoteIceCandidateType peerState.RemoteIceCandidateType = receivedState.RemoteIceCandidateType
peerState.LocalIceCandidateEndpoint = receivedState.LocalIceCandidateEndpoint peerState.LocalIceCandidateEndpoint = receivedState.LocalIceCandidateEndpoint
peerState.RemoteIceCandidateEndpoint = receivedState.RemoteIceCandidateEndpoint peerState.RemoteIceCandidateEndpoint = receivedState.RemoteIceCandidateEndpoint
peerState.RelayServerAddress = receivedState.RelayServerAddress
peerState.RosenpassEnabled = receivedState.RosenpassEnabled peerState.RosenpassEnabled = receivedState.RosenpassEnabled
} }
@@ -271,146 +256,6 @@ func (d *Status) UpdatePeerState(receivedState State) error {
return nil return nil
} }
func (d *Status) UpdatePeerICEState(receivedState State) error {
d.mux.Lock()
defer d.mux.Unlock()
peerState, ok := d.peers[receivedState.PubKey]
if !ok {
return errors.New("peer doesn't exist")
}
if receivedState.IP != "" {
peerState.IP = receivedState.IP
}
skipNotification := shouldSkipNotify(receivedState.ConnStatus, peerState)
peerState.ConnStatus = receivedState.ConnStatus
peerState.ConnStatusUpdate = receivedState.ConnStatusUpdate
peerState.Relayed = receivedState.Relayed
peerState.LocalIceCandidateType = receivedState.LocalIceCandidateType
peerState.RemoteIceCandidateType = receivedState.RemoteIceCandidateType
peerState.LocalIceCandidateEndpoint = receivedState.LocalIceCandidateEndpoint
peerState.RemoteIceCandidateEndpoint = receivedState.RemoteIceCandidateEndpoint
peerState.RosenpassEnabled = receivedState.RosenpassEnabled
d.peers[receivedState.PubKey] = peerState
if skipNotification {
return nil
}
ch, found := d.changeNotify[receivedState.PubKey]
if found && ch != nil {
close(ch)
d.changeNotify[receivedState.PubKey] = nil
}
d.notifyPeerListChanged()
return nil
}
func (d *Status) UpdatePeerRelayedState(receivedState State) error {
d.mux.Lock()
defer d.mux.Unlock()
peerState, ok := d.peers[receivedState.PubKey]
if !ok {
return errors.New("peer doesn't exist")
}
skipNotification := shouldSkipNotify(receivedState.ConnStatus, peerState)
peerState.ConnStatus = receivedState.ConnStatus
peerState.ConnStatusUpdate = receivedState.ConnStatusUpdate
peerState.Relayed = receivedState.Relayed
peerState.RelayServerAddress = receivedState.RelayServerAddress
peerState.RosenpassEnabled = receivedState.RosenpassEnabled
d.peers[receivedState.PubKey] = peerState
if skipNotification {
return nil
}
ch, found := d.changeNotify[receivedState.PubKey]
if found && ch != nil {
close(ch)
d.changeNotify[receivedState.PubKey] = nil
}
d.notifyPeerListChanged()
return nil
}
func (d *Status) UpdatePeerRelayedStateToDisconnected(receivedState State) error {
d.mux.Lock()
defer d.mux.Unlock()
peerState, ok := d.peers[receivedState.PubKey]
if !ok {
return errors.New("peer doesn't exist")
}
skipNotification := shouldSkipNotify(receivedState.ConnStatus, peerState)
peerState.ConnStatus = receivedState.ConnStatus
peerState.Relayed = receivedState.Relayed
peerState.ConnStatusUpdate = receivedState.ConnStatusUpdate
peerState.RelayServerAddress = ""
d.peers[receivedState.PubKey] = peerState
if skipNotification {
return nil
}
ch, found := d.changeNotify[receivedState.PubKey]
if found && ch != nil {
close(ch)
d.changeNotify[receivedState.PubKey] = nil
}
d.notifyPeerListChanged()
return nil
}
func (d *Status) UpdatePeerICEStateToDisconnected(receivedState State) error {
d.mux.Lock()
defer d.mux.Unlock()
peerState, ok := d.peers[receivedState.PubKey]
if !ok {
return errors.New("peer doesn't exist")
}
skipNotification := shouldSkipNotify(receivedState.ConnStatus, peerState)
peerState.ConnStatus = receivedState.ConnStatus
peerState.Relayed = receivedState.Relayed
peerState.ConnStatusUpdate = receivedState.ConnStatusUpdate
peerState.LocalIceCandidateType = receivedState.LocalIceCandidateType
peerState.RemoteIceCandidateType = receivedState.RemoteIceCandidateType
peerState.LocalIceCandidateEndpoint = receivedState.LocalIceCandidateEndpoint
peerState.RemoteIceCandidateEndpoint = receivedState.RemoteIceCandidateEndpoint
d.peers[receivedState.PubKey] = peerState
if skipNotification {
return nil
}
ch, found := d.changeNotify[receivedState.PubKey]
if found && ch != nil {
close(ch)
d.changeNotify[receivedState.PubKey] = nil
}
d.notifyPeerListChanged()
return nil
}
// UpdateWireGuardPeerState updates the WireGuard bits of the peer state // UpdateWireGuardPeerState updates the WireGuard bits of the peer state
func (d *Status) UpdateWireGuardPeerState(pubKey string, wgStats iface.WGStats) error { func (d *Status) UpdateWireGuardPeerState(pubKey string, wgStats iface.WGStats) error {
d.mux.Lock() d.mux.Lock()
@@ -430,13 +275,13 @@ func (d *Status) UpdateWireGuardPeerState(pubKey string, wgStats iface.WGStats)
return nil return nil
} }
func shouldSkipNotify(receivedConnStatus ConnStatus, curr State) bool { func shouldSkipNotify(received, curr State) bool {
switch { switch {
case receivedConnStatus == StatusConnecting: case received.ConnStatus == StatusConnecting:
return true return true
case receivedConnStatus == StatusDisconnected && curr.ConnStatus == StatusConnecting: case received.ConnStatus == StatusDisconnected && curr.ConnStatus == StatusConnecting:
return true return true
case receivedConnStatus == StatusDisconnected && curr.ConnStatus == StatusDisconnected: case received.ConnStatus == StatusDisconnected && curr.ConnStatus == StatusDisconnected:
return curr.IP != "" return curr.IP != ""
default: default:
return false return false
@@ -584,21 +429,7 @@ func (d *Status) UpdateDNSStates(dnsStates []NSGroupState) {
d.nsGroupStates = dnsStates d.nsGroupStates = dnsStates
} }
func (d *Status) UpdateResolvedDomainsStates(domain domain.Domain, prefixes []netip.Prefix) {
d.mux.Lock()
defer d.mux.Unlock()
d.resolvedDomainsStates[domain] = prefixes
}
func (d *Status) DeleteResolvedDomainsStates(domain domain.Domain) {
d.mux.Lock()
defer d.mux.Unlock()
delete(d.resolvedDomainsStates, domain)
}
func (d *Status) GetRosenpassState() RosenpassState { func (d *Status) GetRosenpassState() RosenpassState {
d.mux.Lock()
defer d.mux.Unlock()
return RosenpassState{ return RosenpassState{
d.rosenpassEnabled, d.rosenpassEnabled,
d.rosenpassPermissive, d.rosenpassPermissive,
@@ -606,8 +437,6 @@ func (d *Status) GetRosenpassState() RosenpassState {
} }
func (d *Status) GetManagementState() ManagementState { func (d *Status) GetManagementState() ManagementState {
d.mux.Lock()
defer d.mux.Unlock()
return ManagementState{ return ManagementState{
d.mgmAddress, d.mgmAddress,
d.managementState, d.managementState,
@@ -649,8 +478,6 @@ func (d *Status) IsLoginRequired() bool {
} }
func (d *Status) GetSignalState() SignalState { func (d *Status) GetSignalState() SignalState {
d.mux.Lock()
defer d.mux.Unlock()
return SignalState{ return SignalState{
d.signalAddress, d.signalAddress,
d.signalState, d.signalState,
@@ -658,71 +485,34 @@ func (d *Status) GetSignalState() SignalState {
} }
} }
// GetRelayStates returns the stun/turn/permanent relay states
func (d *Status) GetRelayStates() []relay.ProbeResult { func (d *Status) GetRelayStates() []relay.ProbeResult {
d.mux.Lock() return d.relayStates
defer d.mux.Unlock()
if d.relayMgr == nil {
return d.relayStates
}
// extend the list of stun, turn servers with relay address
relayStates := slices.Clone(d.relayStates)
var relayState relay.ProbeResult
// if the server connection is not established then we will use the general address
// in case of connection we will use the instance specific address
instanceAddr, err := d.relayMgr.RelayInstanceAddress()
if err != nil {
// TODO add their status
if errors.Is(err, relayClient.ErrRelayClientNotConnected) {
for _, r := range d.relayMgr.ServerURLs() {
relayStates = append(relayStates, relay.ProbeResult{
URI: r,
})
}
return relayStates
}
relayState.Err = err
}
relayState.URI = instanceAddr
return append(relayStates, relayState)
} }
func (d *Status) GetDNSStates() []NSGroupState { func (d *Status) GetDNSStates() []NSGroupState {
d.mux.Lock()
defer d.mux.Unlock()
return d.nsGroupStates return d.nsGroupStates
} }
func (d *Status) GetResolvedDomainsStates() map[domain.Domain][]netip.Prefix {
d.mux.Lock()
defer d.mux.Unlock()
return maps.Clone(d.resolvedDomainsStates)
}
// GetFullStatus gets full status // GetFullStatus gets full status
func (d *Status) GetFullStatus() FullStatus { func (d *Status) GetFullStatus() FullStatus {
d.mux.Lock()
defer d.mux.Unlock()
fullStatus := FullStatus{ fullStatus := FullStatus{
ManagementState: d.GetManagementState(), ManagementState: d.GetManagementState(),
SignalState: d.GetSignalState(), SignalState: d.GetSignalState(),
LocalPeerState: d.localPeer,
Relays: d.GetRelayStates(), Relays: d.GetRelayStates(),
RosenpassState: d.GetRosenpassState(), RosenpassState: d.GetRosenpassState(),
NSGroupStates: d.GetDNSStates(), NSGroupStates: d.GetDNSStates(),
} }
d.mux.Lock()
defer d.mux.Unlock()
fullStatus.LocalPeerState = d.localPeer
for _, status := range d.peers { for _, status := range d.peers {
fullStatus.Peers = append(fullStatus.Peers, status) fullStatus.Peers = append(fullStatus.Peers, status)
} }
fullStatus.Peers = append(fullStatus.Peers, d.offlinePeers...) fullStatus.Peers = append(fullStatus.Peers, d.offlinePeers...)
return fullStatus return fullStatus
} }

View File

@@ -2,8 +2,8 @@ package peer
import ( import (
"errors" "errors"
"sync"
"testing" "testing"
"sync"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
@@ -43,7 +43,7 @@ func TestUpdatePeerState(t *testing.T) {
status := NewRecorder("https://mgm") status := NewRecorder("https://mgm")
peerState := State{ peerState := State{
PubKey: key, PubKey: key,
Mux: new(sync.RWMutex), Mux: new(sync.RWMutex),
} }
status.peers[key] = peerState status.peers[key] = peerState
@@ -64,7 +64,7 @@ func TestStatus_UpdatePeerFQDN(t *testing.T) {
status := NewRecorder("https://mgm") status := NewRecorder("https://mgm")
peerState := State{ peerState := State{
PubKey: key, PubKey: key,
Mux: new(sync.RWMutex), Mux: new(sync.RWMutex),
} }
status.peers[key] = peerState status.peers[key] = peerState
@@ -83,7 +83,7 @@ func TestGetPeerStateChangeNotifierLogic(t *testing.T) {
status := NewRecorder("https://mgm") status := NewRecorder("https://mgm")
peerState := State{ peerState := State{
PubKey: key, PubKey: key,
Mux: new(sync.RWMutex), Mux: new(sync.RWMutex),
} }
status.peers[key] = peerState status.peers[key] = peerState
@@ -108,7 +108,7 @@ func TestRemovePeer(t *testing.T) {
status := NewRecorder("https://mgm") status := NewRecorder("https://mgm")
peerState := State{ peerState := State{
PubKey: key, PubKey: key,
Mux: new(sync.RWMutex), Mux: new(sync.RWMutex),
} }
status.peers[key] = peerState status.peers[key] = peerState

View File

@@ -6,6 +6,6 @@ import (
"github.com/netbirdio/netbird/client/internal/stdnet" "github.com/netbirdio/netbird/client/internal/stdnet"
) )
func (w *WorkerICE) newStdNet() (*stdnet.Net, error) { func (conn *Conn) newStdNet() (*stdnet.Net, error) {
return stdnet.NewNet(w.config.ICEConfig.InterfaceBlackList) return stdnet.NewNet(conn.config.InterfaceBlackList)
} }

View File

@@ -2,6 +2,6 @@ package peer
import "github.com/netbirdio/netbird/client/internal/stdnet" import "github.com/netbirdio/netbird/client/internal/stdnet"
func (w *WorkerICE) newStdNet() (*stdnet.Net, error) { func (conn *Conn) newStdNet() (*stdnet.Net, error) {
return stdnet.NewNetWithDiscover(w.iFaceDiscover, w.config.ICEConfig.InterfaceBlackList) return stdnet.NewNetWithDiscover(conn.iFaceDiscover, conn.config.InterfaceBlackList)
} }

View File

@@ -1,470 +0,0 @@
package peer
import (
"context"
"fmt"
"net"
"net/netip"
"runtime"
"sync"
"sync/atomic"
"time"
"github.com/pion/ice/v3"
"github.com/pion/randutil"
"github.com/pion/stun/v2"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/stdnet"
"github.com/netbirdio/netbird/iface"
"github.com/netbirdio/netbird/iface/bind"
"github.com/netbirdio/netbird/route"
)
const (
iceKeepAliveDefault = 4 * time.Second
iceDisconnectedTimeoutDefault = 6 * time.Second
// iceRelayAcceptanceMinWaitDefault is the same as in the Pion ICE package
iceRelayAcceptanceMinWaitDefault = 2 * time.Second
lenUFrag = 16
lenPwd = 32
runesAlpha = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
)
var (
failedTimeout = 6 * time.Second
)
type ICEConfig struct {
// StunTurn is a list of STUN and TURN URLs
StunTurn *atomic.Value // []*stun.URI
// InterfaceBlackList is a list of machine interfaces that should be filtered out by ICE Candidate gathering
// (e.g. if eth0 is in the list, host candidate of this interface won't be used)
InterfaceBlackList []string
DisableIPv6Discovery bool
UDPMux ice.UDPMux
UDPMuxSrflx ice.UniversalUDPMux
NATExternalIPs []string
}
type ICEConnInfo struct {
RemoteConn net.Conn
RosenpassPubKey []byte
RosenpassAddr string
LocalIceCandidateType string
RemoteIceCandidateType string
RemoteIceCandidateEndpoint string
LocalIceCandidateEndpoint string
Relayed bool
RelayedOnLocal bool
}
type WorkerICECallbacks struct {
OnConnReady func(ConnPriority, ICEConnInfo)
OnStatusChanged func(ConnStatus)
}
type WorkerICE struct {
ctx context.Context
log *log.Entry
config ConnConfig
signaler *Signaler
iFaceDiscover stdnet.ExternalIFaceDiscover
statusRecorder *Status
hasRelayOnLocally bool
conn WorkerICECallbacks
selectedPriority ConnPriority
agent *ice.Agent
muxAgent sync.Mutex
StunTurn []*stun.URI
sentExtraSrflx bool
localUfrag string
localPwd string
}
func NewWorkerICE(ctx context.Context, log *log.Entry, config ConnConfig, signaler *Signaler, ifaceDiscover stdnet.ExternalIFaceDiscover, statusRecorder *Status, hasRelayOnLocally bool, callBacks WorkerICECallbacks) (*WorkerICE, error) {
w := &WorkerICE{
ctx: ctx,
log: log,
config: config,
signaler: signaler,
iFaceDiscover: ifaceDiscover,
statusRecorder: statusRecorder,
hasRelayOnLocally: hasRelayOnLocally,
conn: callBacks,
}
localUfrag, localPwd, err := generateICECredentials()
if err != nil {
return nil, err
}
w.localUfrag = localUfrag
w.localPwd = localPwd
return w, nil
}
func (w *WorkerICE) OnNewOffer(remoteOfferAnswer *OfferAnswer) {
w.log.Debugf("OnNewOffer for ICE")
w.muxAgent.Lock()
if w.agent != nil {
w.log.Debugf("agent already exists, skipping the offer")
w.muxAgent.Unlock()
return
}
var preferredCandidateTypes []ice.CandidateType
if w.hasRelayOnLocally && remoteOfferAnswer.RelaySrvAddress != "" {
w.selectedPriority = connPriorityICEP2P
preferredCandidateTypes = candidateTypesP2P()
} else {
w.selectedPriority = connPriorityICETurn
preferredCandidateTypes = candidateTypes()
}
w.log.Debugf("recreate ICE agent")
agentCtx, agentCancel := context.WithCancel(w.ctx)
agent, err := w.reCreateAgent(agentCancel, preferredCandidateTypes)
if err != nil {
w.log.Errorf("failed to recreate ICE Agent: %s", err)
w.muxAgent.Unlock()
return
}
w.agent = agent
w.muxAgent.Unlock()
w.log.Debugf("gather candidates")
err = w.agent.GatherCandidates()
if err != nil {
w.log.Debugf("failed to gather candidates: %s", err)
return
}
// will block until connection succeeded
// but it won't release if ICE Agent went into Disconnected or Failed state,
// so we have to cancel it with the provided context once agent detected a broken connection
w.log.Debugf("turn agent dial")
remoteConn, err := w.turnAgentDial(agentCtx, remoteOfferAnswer)
if err != nil {
w.log.Debugf("failed to dial the remote peer: %s", err)
return
}
w.log.Debugf("agent dial succeeded")
pair, err := w.agent.GetSelectedCandidatePair()
if err != nil {
return
}
if !isRelayCandidate(pair.Local) {
// dynamically set remote WireGuard port if other side specified a different one from the default one
remoteWgPort := iface.DefaultWgPort
if remoteOfferAnswer.WgListenPort != 0 {
remoteWgPort = remoteOfferAnswer.WgListenPort
}
// To support old version's with direct mode we attempt to punch an additional role with the remote WireGuard port
go w.punchRemoteWGPort(pair, remoteWgPort)
}
ci := ICEConnInfo{
RemoteConn: remoteConn,
RosenpassPubKey: remoteOfferAnswer.RosenpassPubKey,
RosenpassAddr: remoteOfferAnswer.RosenpassAddr,
LocalIceCandidateType: pair.Local.Type().String(),
RemoteIceCandidateType: pair.Remote.Type().String(),
LocalIceCandidateEndpoint: fmt.Sprintf("%s:%d", pair.Local.Address(), pair.Local.Port()),
RemoteIceCandidateEndpoint: fmt.Sprintf("%s:%d", pair.Remote.Address(), pair.Remote.Port()),
Relayed: isRelayed(pair),
RelayedOnLocal: isRelayCandidate(pair.Local),
}
w.log.Debugf("on ICE conn read to use ready")
go w.conn.OnConnReady(w.selectedPriority, ci)
}
// OnRemoteCandidate Handles ICE connection Candidate provided by the remote peer.
func (w *WorkerICE) OnRemoteCandidate(candidate ice.Candidate, haRoutes route.HAMap) {
w.muxAgent.Lock()
defer w.muxAgent.Unlock()
w.log.Debugf("OnRemoteCandidate from peer %s -> %s", w.config.Key, candidate.String())
if w.agent == nil {
w.log.Warnf("ICE Agent is not initialized yet")
return
}
if candidateViaRoutes(candidate, haRoutes) {
return
}
err := w.agent.AddRemoteCandidate(candidate)
if err != nil {
w.log.Errorf("error while handling remote candidate")
return
}
}
func (w *WorkerICE) GetLocalUserCredentials() (frag string, pwd string) {
w.muxAgent.Lock()
defer w.muxAgent.Unlock()
return w.localUfrag, w.localPwd
}
func (w *WorkerICE) Close() {
w.muxAgent.Lock()
defer w.muxAgent.Unlock()
if w.agent == nil {
return
}
err := w.agent.Close()
if err != nil {
w.log.Warnf("failed to close ICE agent: %s", err)
}
}
func (w *WorkerICE) reCreateAgent(agentCancel context.CancelFunc, relaySupport []ice.CandidateType) (*ice.Agent, error) {
transportNet, err := w.newStdNet()
if err != nil {
w.log.Errorf("failed to create pion's stdnet: %s", err)
}
iceKeepAlive := iceKeepAlive()
iceDisconnectedTimeout := iceDisconnectedTimeout()
iceRelayAcceptanceMinWait := iceRelayAcceptanceMinWait()
agentConfig := &ice.AgentConfig{
MulticastDNSMode: ice.MulticastDNSModeDisabled,
NetworkTypes: []ice.NetworkType{ice.NetworkTypeUDP4, ice.NetworkTypeUDP6},
Urls: w.config.ICEConfig.StunTurn.Load().([]*stun.URI),
CandidateTypes: relaySupport,
InterfaceFilter: stdnet.InterfaceFilter(w.config.ICEConfig.InterfaceBlackList),
UDPMux: w.config.ICEConfig.UDPMux,
UDPMuxSrflx: w.config.ICEConfig.UDPMuxSrflx,
NAT1To1IPs: w.config.ICEConfig.NATExternalIPs,
Net: transportNet,
FailedTimeout: &failedTimeout,
DisconnectedTimeout: &iceDisconnectedTimeout,
KeepaliveInterval: &iceKeepAlive,
RelayAcceptanceMinWait: &iceRelayAcceptanceMinWait,
LocalUfrag: w.localUfrag,
LocalPwd: w.localPwd,
}
if w.config.ICEConfig.DisableIPv6Discovery {
agentConfig.NetworkTypes = []ice.NetworkType{ice.NetworkTypeUDP4}
}
w.sentExtraSrflx = false
agent, err := ice.NewAgent(agentConfig)
if err != nil {
return nil, err
}
err = agent.OnCandidate(w.onICECandidate)
if err != nil {
return nil, err
}
err = agent.OnConnectionStateChange(func(state ice.ConnectionState) {
w.log.Debugf("ICE ConnectionState has changed to %s", state.String())
if state == ice.ConnectionStateFailed || state == ice.ConnectionStateDisconnected {
w.conn.OnStatusChanged(StatusDisconnected)
w.muxAgent.Lock()
agentCancel()
_ = agent.Close()
w.agent = nil
w.muxAgent.Unlock()
}
})
if err != nil {
return nil, err
}
err = agent.OnSelectedCandidatePairChange(w.onICESelectedCandidatePair)
if err != nil {
return nil, err
}
err = agent.OnSuccessfulSelectedPairBindingResponse(func(p *ice.CandidatePair) {
err := w.statusRecorder.UpdateLatency(w.config.Key, p.Latency())
if err != nil {
w.log.Debugf("failed to update latency for peer: %s", err)
return
}
})
if err != nil {
return nil, fmt.Errorf("failed setting binding response callback: %w", err)
}
return agent, nil
}
func (w *WorkerICE) punchRemoteWGPort(pair *ice.CandidatePair, remoteWgPort int) {
// wait local endpoint configuration
time.Sleep(time.Second)
addr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("%s:%d", pair.Remote.Address(), remoteWgPort))
if err != nil {
w.log.Warnf("got an error while resolving the udp address, err: %s", err)
return
}
mux, ok := w.config.ICEConfig.UDPMuxSrflx.(*bind.UniversalUDPMuxDefault)
if !ok {
w.log.Warn("invalid udp mux conversion")
return
}
_, err = mux.GetSharedConn().WriteTo([]byte{0x6e, 0x62}, addr)
if err != nil {
w.log.Warnf("got an error while sending the punch packet, err: %s", err)
}
}
// onICECandidate is a callback attached to an ICE Agent to receive new local connection candidates
// and then signals them to the remote peer
func (w *WorkerICE) onICECandidate(candidate ice.Candidate) {
// nil means candidate gathering has been ended
if candidate == nil {
return
}
// TODO: reported port is incorrect for CandidateTypeHost, makes understanding ICE use via logs confusing as port is ignored
w.log.Debugf("discovered local candidate %s", candidate.String())
go func() {
err := w.signaler.SignalICECandidate(candidate, w.config.Key)
if err != nil {
w.log.Errorf("failed signaling candidate to the remote peer %s %s", w.config.Key, err)
}
}()
if !w.shouldSendExtraSrflxCandidate(candidate) {
return
}
// sends an extra server reflexive candidate to the remote peer with our related port (usually the wireguard port)
// this is useful when network has an existing port forwarding rule for the wireguard port and this peer
extraSrflx, err := extraSrflxCandidate(candidate)
if err != nil {
w.log.Errorf("failed creating extra server reflexive candidate %s", err)
return
}
w.sentExtraSrflx = true
go func() {
err = w.signaler.SignalICECandidate(extraSrflx, w.config.Key)
if err != nil {
w.log.Errorf("failed signaling the extra server reflexive candidate: %s", err)
}
}()
}
func (w *WorkerICE) onICESelectedCandidatePair(c1 ice.Candidate, c2 ice.Candidate) {
w.log.Debugf("selected candidate pair [local <-> remote] -> [%s <-> %s], peer %s", c1.String(), c2.String(),
w.config.Key)
}
func (w *WorkerICE) shouldSendExtraSrflxCandidate(candidate ice.Candidate) bool {
if !w.sentExtraSrflx && candidate.Type() == ice.CandidateTypeServerReflexive && candidate.Port() != candidate.RelatedAddress().Port {
return true
}
return false
}
func (w *WorkerICE) turnAgentDial(ctx context.Context, remoteOfferAnswer *OfferAnswer) (*ice.Conn, error) {
isControlling := w.config.LocalKey > w.config.Key
if isControlling {
return w.agent.Dial(ctx, remoteOfferAnswer.IceCredentials.UFrag, remoteOfferAnswer.IceCredentials.Pwd)
} else {
return w.agent.Accept(ctx, remoteOfferAnswer.IceCredentials.UFrag, remoteOfferAnswer.IceCredentials.Pwd)
}
}
func extraSrflxCandidate(candidate ice.Candidate) (*ice.CandidateServerReflexive, error) {
relatedAdd := candidate.RelatedAddress()
return ice.NewCandidateServerReflexive(&ice.CandidateServerReflexiveConfig{
Network: candidate.NetworkType().String(),
Address: candidate.Address(),
Port: relatedAdd.Port,
Component: candidate.Component(),
RelAddr: relatedAdd.Address,
RelPort: relatedAdd.Port,
})
}
func candidateViaRoutes(candidate ice.Candidate, clientRoutes route.HAMap) bool {
var routePrefixes []netip.Prefix
for _, routes := range clientRoutes {
if len(routes) > 0 && routes[0] != nil {
routePrefixes = append(routePrefixes, routes[0].Network)
}
}
addr, err := netip.ParseAddr(candidate.Address())
if err != nil {
log.Errorf("Failed to parse IP address %s: %v", candidate.Address(), err)
return false
}
for _, prefix := range routePrefixes {
// default route is
if prefix.Bits() == 0 {
continue
}
if prefix.Contains(addr) {
log.Debugf("Ignoring candidate [%s], its address is part of routed network %s", candidate.String(), prefix)
return true
}
}
return false
}
func candidateTypes() []ice.CandidateType {
if hasICEForceRelayConn() {
return []ice.CandidateType{ice.CandidateTypeRelay}
}
// TODO: remove this once we have refactored userspace proxy into the bind package
if runtime.GOOS == "ios" {
return []ice.CandidateType{ice.CandidateTypeHost, ice.CandidateTypeServerReflexive}
}
return []ice.CandidateType{ice.CandidateTypeHost, ice.CandidateTypeServerReflexive, ice.CandidateTypeRelay}
}
func candidateTypesP2P() []ice.CandidateType {
return []ice.CandidateType{ice.CandidateTypeHost, ice.CandidateTypeServerReflexive}
}
func isRelayCandidate(candidate ice.Candidate) bool {
return candidate.Type() == ice.CandidateTypeRelay
}
func isRelayed(pair *ice.CandidatePair) bool {
if pair.Local.Type() == ice.CandidateTypeRelay || pair.Remote.Type() == ice.CandidateTypeRelay {
return true
}
return false
}
func generateICECredentials() (string, string, error) {
ufrag, err := randutil.GenerateCryptoRandomString(lenUFrag, runesAlpha)
if err != nil {
return "", "", err
}
pwd, err := randutil.GenerateCryptoRandomString(lenPwd, runesAlpha)
if err != nil {
return "", "", err
}
return ufrag, pwd, nil
}

Some files were not shown because too many files have changed in this diff Show More