Compare commits

..

1 Commits

Author SHA1 Message Date
Maycon Santos
840b07c784 add todos 2024-07-05 11:15:28 +02:00
630 changed files with 20108 additions and 63067 deletions

View File

@@ -1,8 +0,0 @@
root = true
[*]
end_of_line = lf
insert_final_newline = true
[*.go]
indent_style = tab

3
.github/FUNDING.yml vendored
View File

@@ -1,3 +0,0 @@
# These are supported funding model platforms
github: [netbirdio]

View File

@@ -31,14 +31,9 @@ Please specify whether you use NetBird Cloud or self-host NetBird's control plan
`netbird version` `netbird version`
**NetBird status -dA output:** **NetBird status -d output:**
If applicable, add the `netbird status -dA' command output. If applicable, add the `netbird status -d' command output.
**Do you face any (non-mobile) client issues?**
Please provide the file created by `netbird debug for 1m -AS`.
We advise reviewing the anonymized files for any remaining PII.
**Screenshots** **Screenshots**

View File

@@ -18,20 +18,18 @@ jobs:
runs-on: macos-latest runs-on: macos-latest
steps: steps:
- name: Install Go - name: Install Go
uses: actions/setup-go@v5 uses: actions/setup-go@v4
with: with:
go-version: "1.23.x" go-version: "1.21.x"
cache: false
- name: Checkout code - name: Checkout code
uses: actions/checkout@v4 uses: actions/checkout@v3
- name: Cache Go modules - name: Cache Go modules
uses: actions/cache@v4 uses: actions/cache@v3
with: with:
path: ~/go/pkg/mod path: ~/go/pkg/mod
key: macos-gotest-${{ hashFiles('**/go.sum') }} key: macos-go-${{ hashFiles('**/go.sum') }}
restore-keys: | restore-keys: |
macos-gotest-
macos-go- macos-go-
- name: Install libpcap - name: Install libpcap
@@ -44,4 +42,4 @@ jobs:
run: git --no-pager diff --exit-code run: git --no-pager diff --exit-code
- name: Test - name: Test
run: NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 5m -p 1 $(go list ./... | grep -v /management) run: NETBIRD_STORE_ENGINE=${{ matrix.store }} go test -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 5m -p 1 ./...

View File

@@ -13,7 +13,7 @@ concurrency:
jobs: jobs:
test: test:
runs-on: ubuntu-22.04 runs-on: ubuntu-latest
steps: steps:
- uses: actions/checkout@v4 - uses: actions/checkout@v4
- name: Test in FreeBSD - name: Test in FreeBSD
@@ -21,26 +21,19 @@ jobs:
uses: vmactions/freebsd-vm@v1 uses: vmactions/freebsd-vm@v1
with: with:
usesh: true usesh: true
copyback: false
release: "14.1"
prepare: | prepare: |
pkg install -y go pkg install -y curl
pkg install -y git
# -x - to print all executed commands
# -e - to faile on first error
run: | run: |
set -e -x set -x
time go build -o netbird client/main.go curl -o go.tar.gz https://go.dev/dl/go1.21.11.freebsd-amd64.tar.gz -L
# check all component except management, since we do not support management server on freebsd tar zxf go.tar.gz
time go test -timeout 1m -failfast ./base62/... mv go /usr/local/go
# NOTE: without -p1 `client/internal/dns` will fail becasue of `listen udp4 :33100: bind: address already in use` ln -s /usr/local/go/bin/go /usr/local/bin/go
time go test -timeout 8m -failfast -p 1 ./client/... go mod tidy
time go test -timeout 1m -failfast ./dns/... go test -timeout 5m -p 1 ./iface/...
time go test -timeout 1m -failfast ./encryption/... go test -timeout 5m -p 1 ./client/...
time go test -timeout 1m -failfast ./formatter/... cd client
time go test -timeout 1m -failfast ./client/iface/... go build .
time go test -timeout 1m -failfast ./route/... cd ..
time go test -timeout 1m -failfast ./sharedsock/...
time go test -timeout 1m -failfast ./signal/...
time go test -timeout 1m -failfast ./util/...
time go test -timeout 1m -failfast ./version/...

View File

@@ -11,163 +11,29 @@ concurrency:
cancel-in-progress: true cancel-in-progress: true
jobs: jobs:
build-cache:
runs-on: ubuntu-22.04
steps:
- name: Checkout code
uses: actions/checkout@v4
- name: Install Go
uses: actions/setup-go@v5
with:
go-version: "1.23.x"
cache: false
- name: Get Go environment
run: |
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
- name: Cache Go modules
uses: actions/cache@v4
id: cache
with:
path: |
${{ env.cache }}
${{ env.modcache }}
key: ${{ runner.os }}-gotest-cache-${{ hashFiles('**/go.sum') }}
restore-keys: |
${{ runner.os }}-gotest-cache-${{ hashFiles('**/go.sum') }}
- name: Install dependencies
if: steps.cache.outputs.cache-hit != 'true'
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev
- name: Install 32-bit libpcap
if: steps.cache.outputs.cache-hit != 'true'
run: sudo dpkg --add-architecture i386 && sudo apt update && sudo apt-get install -y libpcap0.8-dev:i386
- name: Build client
if: steps.cache.outputs.cache-hit != 'true'
working-directory: client
run: CGO_ENABLED=1 go build .
- name: Build client 386
if: steps.cache.outputs.cache-hit != 'true'
working-directory: client
run: CGO_ENABLED=1 GOARCH=386 go build -o client-386 .
- name: Build management
if: steps.cache.outputs.cache-hit != 'true'
working-directory: management
run: CGO_ENABLED=1 go build .
- name: Build management 386
if: steps.cache.outputs.cache-hit != 'true'
working-directory: management
run: CGO_ENABLED=1 GOARCH=386 go build -o management-386 .
- name: Build signal
if: steps.cache.outputs.cache-hit != 'true'
working-directory: signal
run: CGO_ENABLED=1 go build .
- name: Build signal 386
if: steps.cache.outputs.cache-hit != 'true'
working-directory: signal
run: CGO_ENABLED=1 GOARCH=386 go build -o signal-386 .
- name: Build relay
if: steps.cache.outputs.cache-hit != 'true'
working-directory: relay
run: CGO_ENABLED=1 go build .
- name: Build relay 386
if: steps.cache.outputs.cache-hit != 'true'
working-directory: relay
run: CGO_ENABLED=1 GOARCH=386 go build -o relay-386 .
test: test:
needs: [build-cache]
strategy: strategy:
fail-fast: false
matrix:
arch: [ '386','amd64' ]
runs-on: ubuntu-22.04
steps:
- name: Install Go
uses: actions/setup-go@v5
with:
go-version: "1.23.x"
cache: false
- name: Checkout code
uses: actions/checkout@v4
- name: Get Go environment
run: |
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
- name: Cache Go modules
uses: actions/cache/restore@v4
with:
path: |
${{ env.cache }}
${{ env.modcache }}
key: ${{ runner.os }}-gotest-cache-${{ hashFiles('**/go.sum') }}
restore-keys: |
${{ runner.os }}-gotest-cache-
- name: Install dependencies
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev
- name: Install 32-bit libpcap
if: matrix.arch == '386'
run: sudo dpkg --add-architecture i386 && sudo apt update && sudo apt-get install -y libpcap0.8-dev:i386
- name: Install modules
run: go mod tidy
- name: check git status
run: git --no-pager diff --exit-code
- name: Test
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} CI=true go test -exec 'sudo' -timeout 10m -p 1 $(go list ./... | grep -v /management)
test_management:
needs: [ build-cache ]
strategy:
fail-fast: false
matrix: matrix:
arch: [ '386','amd64' ] arch: [ '386','amd64' ]
store: [ 'sqlite', 'postgres'] store: [ 'sqlite', 'postgres']
runs-on: ubuntu-22.04 runs-on: ubuntu-latest
steps: steps:
- name: Install Go - name: Install Go
uses: actions/setup-go@v5 uses: actions/setup-go@v4
with: with:
go-version: "1.23.x" go-version: "1.21.x"
cache: false
- name: Checkout code
uses: actions/checkout@v4
- name: Get Go environment
run: |
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
- name: Cache Go modules - name: Cache Go modules
uses: actions/cache/restore@v4 uses: actions/cache@v3
with: with:
path: | path: ~/go/pkg/mod
${{ env.cache }} key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }}
${{ env.modcache }}
key: ${{ runner.os }}-gotest-cache-${{ hashFiles('**/go.sum') }}
restore-keys: | restore-keys: |
${{ runner.os }}-gotest-cache- ${{ runner.os }}-go-
- name: Checkout code
uses: actions/checkout@v3
- name: Install dependencies - name: Install dependencies
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev
@@ -183,84 +49,26 @@ jobs:
run: git --no-pager diff --exit-code run: git --no-pager diff --exit-code
- name: Test - name: Test
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -p 1 -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 10m $(go list ./... | grep /management) run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} go test -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 5m -p 1 ./...
benchmark:
needs: [ build-cache ]
strategy:
fail-fast: false
matrix:
arch: [ '386','amd64' ]
store: [ 'sqlite', 'postgres' ]
runs-on: ubuntu-22.04
steps:
- name: Install Go
uses: actions/setup-go@v5
with:
go-version: "1.23.x"
cache: false
- name: Checkout code
uses: actions/checkout@v4
- name: Get Go environment
run: |
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
- name: Cache Go modules
uses: actions/cache/restore@v4
with:
path: |
${{ env.cache }}
${{ env.modcache }}
key: ${{ runner.os }}-gotest-cache-${{ hashFiles('**/go.sum') }}
restore-keys: |
${{ runner.os }}-gotest-cache-
- name: Install dependencies
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev
- name: Install 32-bit libpcap
if: matrix.arch == '386'
run: sudo dpkg --add-architecture i386 && sudo apt update && sudo apt-get install -y libpcap0.8-dev:i386
- name: Install modules
run: go mod tidy
- name: check git status
run: git --no-pager diff --exit-code
- name: Test
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -run=^$ -bench=. -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 10m -p 1 ./...
test_client_on_docker: test_client_on_docker:
needs: [ build-cache ]
runs-on: ubuntu-20.04 runs-on: ubuntu-20.04
steps: steps:
- name: Install Go - name: Install Go
uses: actions/setup-go@v5 uses: actions/setup-go@v4
with: with:
go-version: "1.23.x" go-version: "1.21.x"
cache: false
- name: Checkout code
uses: actions/checkout@v4
- name: Get Go environment
run: |
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
- name: Cache Go modules - name: Cache Go modules
uses: actions/cache/restore@v4 uses: actions/cache@v3
with: with:
path: | path: ~/go/pkg/mod
${{ env.cache }} key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }}
${{ env.modcache }}
key: ${{ runner.os }}-gotest-cache-${{ hashFiles('**/go.sum') }}
restore-keys: | restore-keys: |
${{ runner.os }}-gotest-cache- ${{ runner.os }}-go-
- name: Checkout code
uses: actions/checkout@v3
- name: Install dependencies - name: Install dependencies
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev
@@ -271,6 +79,9 @@ jobs:
- name: check git status - name: check git status
run: git --no-pager diff --exit-code run: git --no-pager diff --exit-code
- name: Generate Iface Test bin
run: CGO_ENABLED=0 go test -c -o iface-testing.bin ./iface/
- name: Generate Shared Sock Test bin - name: Generate Shared Sock Test bin
run: CGO_ENABLED=0 go test -c -o sharedsock-testing.bin ./sharedsock run: CGO_ENABLED=0 go test -c -o sharedsock-testing.bin ./sharedsock
@@ -287,7 +98,7 @@ jobs:
run: CGO_ENABLED=1 go test -c -o engine-testing.bin ./client/internal run: CGO_ENABLED=1 go test -c -o engine-testing.bin ./client/internal
- name: Generate Peer Test bin - name: Generate Peer Test bin
run: CGO_ENABLED=0 go test -c -o peer-testing.bin ./client/internal/peer/ run: CGO_ENABLED=0 go test -c -o peer-testing.bin ./client/internal/peer/...
- run: chmod +x *testing.bin - run: chmod +x *testing.bin
@@ -295,7 +106,7 @@ jobs:
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/sharedsock --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/sharedsock-testing.bin -test.timeout 5m -test.parallel 1 run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/sharedsock --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/sharedsock-testing.bin -test.timeout 5m -test.parallel 1
- name: Run Iface tests in docker - name: Run Iface tests in docker
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/netbird -v /tmp/cache:/tmp/cache -v /tmp/modcache:/tmp/modcache -w /netbird -e GOCACHE=/tmp/cache -e GOMODCACHE=/tmp/modcache -e CGO_ENABLED=0 golang:1.23-alpine go test -test.timeout 5m -test.parallel 1 ./client/iface/... run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/iface --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/iface-testing.bin -test.timeout 5m -test.parallel 1
- name: Run RouteManager tests in docker - name: Run RouteManager tests in docker
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal/routemanager --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/routemanager-testing.bin -test.timeout 5m -test.parallel 1 run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal/routemanager --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/routemanager-testing.bin -test.timeout 5m -test.parallel 1

View File

@@ -17,30 +17,13 @@ jobs:
runs-on: windows-latest runs-on: windows-latest
steps: steps:
- name: Checkout code - name: Checkout code
uses: actions/checkout@v4 uses: actions/checkout@v3
- name: Install Go - name: Install Go
uses: actions/setup-go@v5 uses: actions/setup-go@v4
id: go id: go
with: with:
go-version: "1.23.x" go-version: "1.21.x"
cache: false
- name: Get Go environment
run: |
echo "cache=$(go env GOCACHE)" >> $env:GITHUB_ENV
echo "modcache=$(go env GOMODCACHE)" >> $env:GITHUB_ENV
- name: Cache Go modules
uses: actions/cache@v4
with:
path: |
${{ env.cache }}
${{ env.modcache }}
key: ${{ runner.os }}-gotest-${{ hashFiles('**/go.sum') }}
restore-keys: |
${{ runner.os }}-gotest-
${{ runner.os }}-go-
- name: Download wintun - name: Download wintun
uses: carlosperate/download-file-action@v2 uses: carlosperate/download-file-action@v2
@@ -59,13 +42,11 @@ jobs:
- run: choco install -y sysinternals --ignore-checksums - run: choco install -y sysinternals --ignore-checksums
- run: choco install -y mingw - run: choco install -y mingw
- run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe env -w GOMODCACHE=${{ env.cache }} - run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe env -w GOMODCACHE=C:\Users\runneradmin\go\pkg\mod
- run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe env -w GOCACHE=${{ env.modcache }} - run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe env -w GOCACHE=C:\Users\runneradmin\AppData\Local\go-build
- run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe mod tidy
- run: echo "files=$(go list ./... | ForEach-Object { $_ } | Where-Object { $_ -notmatch '/management' })" >> $env:GITHUB_ENV
- name: test - name: test
run: PsExec64 -s -w ${{ github.workspace }} cmd.exe /c "C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe test -timeout 10m -p 1 ${{ env.files }} > test-out.txt 2>&1" run: PsExec64 -s -w ${{ github.workspace }} cmd.exe /c "C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe test -timeout 10m -p 1 ./... > test-out.txt 2>&1"
- name: test output - name: test output
if: ${{ always() }} if: ${{ always() }}
run: Get-Content test-out.txt run: Get-Content test-out.txt

View File

@@ -15,11 +15,11 @@ jobs:
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- name: Checkout code - name: Checkout code
uses: actions/checkout@v4 uses: actions/checkout@v3
- name: codespell - name: codespell
uses: codespell-project/actions-codespell@v2 uses: codespell-project/actions-codespell@v2
with: with:
ignore_words_list: erro,clienta,hastable,iif,groupd,testIn ignore_words_list: erro,clienta,hastable,
skip: go.mod,go.sum skip: go.mod,go.sum
only_warn: 1 only_warn: 1
golangci: golangci:
@@ -32,21 +32,21 @@ jobs:
timeout-minutes: 15 timeout-minutes: 15
steps: steps:
- name: Checkout code - name: Checkout code
uses: actions/checkout@v4 uses: actions/checkout@v3
- name: Check for duplicate constants - name: Check for duplicate constants
if: matrix.os == 'ubuntu-latest' if: matrix.os == 'ubuntu-latest'
run: | run: |
! awk '/const \(/,/)/{print $0}' management/server/activity/codes.go | grep -o '= [0-9]*' | sort | uniq -d | grep . ! awk '/const \(/,/)/{print $0}' management/server/activity/codes.go | grep -o '= [0-9]*' | sort | uniq -d | grep .
- name: Install Go - name: Install Go
uses: actions/setup-go@v5 uses: actions/setup-go@v4
with: with:
go-version: "1.23.x" go-version: "1.21.x"
cache: false cache: false
- name: Install dependencies - name: Install dependencies
if: matrix.os == 'ubuntu-latest' if: matrix.os == 'ubuntu-latest'
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev libpcap-dev run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev libpcap-dev
- name: golangci-lint - name: golangci-lint
uses: golangci/golangci-lint-action@v4 uses: golangci/golangci-lint-action@v3
with: with:
version: latest version: latest
args: --timeout=12m --out-format colored-line-number args: --timeout=12m

View File

@@ -13,7 +13,6 @@ concurrency:
jobs: jobs:
test-install-script: test-install-script:
strategy: strategy:
fail-fast: false
max-parallel: 2 max-parallel: 2
matrix: matrix:
os: [ubuntu-latest, macos-latest] os: [ubuntu-latest, macos-latest]
@@ -22,7 +21,7 @@ jobs:
runs-on: ${{ matrix.os }} runs-on: ${{ matrix.os }}
steps: steps:
- name: Checkout code - name: Checkout code
uses: actions/checkout@v4 uses: actions/checkout@v3
- name: run install script - name: run install script
env: env:

View File

@@ -15,23 +15,23 @@ jobs:
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- name: Checkout repository - name: Checkout repository
uses: actions/checkout@v4 uses: actions/checkout@v3
- name: Install Go - name: Install Go
uses: actions/setup-go@v5 uses: actions/setup-go@v4
with: with:
go-version: "1.23.x" go-version: "1.21.x"
- name: Setup Android SDK - name: Setup Android SDK
uses: android-actions/setup-android@v3 uses: android-actions/setup-android@v3
with: with:
cmdline-tools-version: 8512546 cmdline-tools-version: 8512546
- name: Setup Java - name: Setup Java
uses: actions/setup-java@v4 uses: actions/setup-java@v3
with: with:
java-version: "11" java-version: "11"
distribution: "adopt" distribution: "adopt"
- name: NDK Cache - name: NDK Cache
id: ndk-cache id: ndk-cache
uses: actions/cache@v4 uses: actions/cache@v3
with: with:
path: /usr/local/lib/android/sdk/ndk path: /usr/local/lib/android/sdk/ndk
key: ndk-cache-23.1.7779620 key: ndk-cache-23.1.7779620
@@ -50,11 +50,11 @@ jobs:
runs-on: macos-latest runs-on: macos-latest
steps: steps:
- name: Checkout repository - name: Checkout repository
uses: actions/checkout@v4 uses: actions/checkout@v3
- name: Install Go - name: Install Go
uses: actions/setup-go@v5 uses: actions/setup-go@v4
with: with:
go-version: "1.23.x" go-version: "1.21.x"
- name: install gomobile - name: install gomobile
run: go install golang.org/x/mobile/cmd/gomobile@v0.0.0-20240404231514-09dbf07665ed run: go install golang.org/x/mobile/cmd/gomobile@v0.0.0-20240404231514-09dbf07665ed
- name: gomobile init - name: gomobile init

View File

@@ -3,16 +3,15 @@ name: Release
on: on:
push: push:
tags: tags:
- "v*" - 'v*'
branches: branches:
- main - main
pull_request: pull_request:
env: env:
SIGN_PIPE_VER: "v0.0.17" SIGN_PIPE_VER: "v0.0.11"
GORELEASER_VER: "v2.3.2" GORELEASER_VER: "v1.14.1"
PRODUCT_NAME: "NetBird"
COPYRIGHT: "Wiretrustee UG (haftungsbeschreankt)"
concurrency: concurrency:
group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.actor_id }} group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.actor_id }}
@@ -20,30 +19,26 @@ concurrency:
jobs: jobs:
release: release:
runs-on: ubuntu-22.04 runs-on: ubuntu-latest
env: env:
flags: "" flags: ""
steps: steps:
- name: Parse semver string
id: semver_parser
uses: booxmedialtd/ws-action-parse-semver@v1
with:
input_string: ${{ (startsWith(github.ref, 'refs/tags/v') && github.ref) || 'refs/tags/v0.0.0' }}
version_extractor_regex: '\/v(.*)$'
- if: ${{ !startsWith(github.ref, 'refs/tags/v') }} - if: ${{ !startsWith(github.ref, 'refs/tags/v') }}
run: echo "flags=--snapshot" >> $GITHUB_ENV run: echo "flags=--snapshot" >> $GITHUB_ENV
- name: Checkout -
uses: actions/checkout@v4 name: Checkout
uses: actions/checkout@v3
with: with:
fetch-depth: 0 # It is required for GoReleaser to work properly fetch-depth: 0 # It is required for GoReleaser to work properly
- name: Set up Go -
uses: actions/setup-go@v5 name: Set up Go
uses: actions/setup-go@v4
with: with:
go-version: "1.23" go-version: "1.21"
cache: false cache: false
- name: Cache Go modules -
uses: actions/cache@v4 name: Cache Go modules
uses: actions/cache@v3
with: with:
path: | path: |
~/go/pkg/mod ~/go/pkg/mod
@@ -51,57 +46,73 @@ jobs:
key: ${{ runner.os }}-go-releaser-${{ hashFiles('**/go.sum') }} key: ${{ runner.os }}-go-releaser-${{ hashFiles('**/go.sum') }}
restore-keys: | restore-keys: |
${{ runner.os }}-go-releaser- ${{ runner.os }}-go-releaser-
- name: Install modules -
name: Install modules
run: go mod tidy run: go mod tidy
- name: check git status -
name: check git status
run: git --no-pager diff --exit-code run: git --no-pager diff --exit-code
- name: Set up QEMU -
name: Set up QEMU
uses: docker/setup-qemu-action@v2 uses: docker/setup-qemu-action@v2
- name: Set up Docker Buildx -
name: Set up Docker Buildx
uses: docker/setup-buildx-action@v2 uses: docker/setup-buildx-action@v2
- name: Login to Docker hub -
name: Login to Docker hub
if: github.event_name != 'pull_request' if: github.event_name != 'pull_request'
uses: docker/login-action@v1 uses: docker/login-action@v1
with: with:
username: ${{ secrets.DOCKER_USER }} username: netbirdio
password: ${{ secrets.DOCKER_TOKEN }} password: ${{ secrets.DOCKER_TOKEN }}
- name: Install OS build dependencies - name: Install OS build dependencies
run: sudo apt update && sudo apt install -y -q gcc-arm-linux-gnueabihf gcc-aarch64-linux-gnu run: sudo apt update && sudo apt install -y -q gcc-arm-linux-gnueabihf gcc-aarch64-linux-gnu
- name: Install goversioninfo - name: Install rsrc
run: go install github.com/josephspurrier/goversioninfo/cmd/goversioninfo@233067e run: go install github.com/akavel/rsrc@v0.10.2
- name: Generate windows syso amd64 - name: Generate windows rsrc amd64
run: goversioninfo -icon client/ui/netbird.ico -manifest client/manifest.xml -product-name ${{ env.PRODUCT_NAME }} -copyright "${{ env.COPYRIGHT }}" -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -ver-build 0 -file-version ${{ steps.semver_parser.outputs.fullversion }}.0 -product-version ${{ steps.semver_parser.outputs.fullversion }}.0 -o client/resources_windows_amd64.syso run: rsrc -arch amd64 -ico client/ui/netbird.ico -manifest client/manifest.xml -o client/resources_windows_amd64.syso
- name: Run GoReleaser - name: Generate windows rsrc arm64
run: rsrc -arch arm64 -ico client/ui/netbird.ico -manifest client/manifest.xml -o client/resources_windows_arm64.syso
- name: Generate windows rsrc arm
run: rsrc -arch arm -ico client/ui/netbird.ico -manifest client/manifest.xml -o client/resources_windows_arm.syso
- name: Generate windows rsrc 386
run: rsrc -arch 386 -ico client/ui/netbird.ico -manifest client/manifest.xml -o client/resources_windows_386.syso
-
name: Run GoReleaser
uses: goreleaser/goreleaser-action@v4 uses: goreleaser/goreleaser-action@v4
with: with:
version: ${{ env.GORELEASER_VER }} version: ${{ env.GORELEASER_VER }}
args: release --clean ${{ env.flags }} args: release --rm-dist ${{ env.flags }}
env: env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
HOMEBREW_TAP_GITHUB_TOKEN: ${{ secrets.HOMEBREW_TAP_GITHUB_TOKEN }} HOMEBREW_TAP_GITHUB_TOKEN: ${{ secrets.HOMEBREW_TAP_GITHUB_TOKEN }}
UPLOAD_DEBIAN_SECRET: ${{ secrets.PKG_UPLOAD_SECRET }} UPLOAD_DEBIAN_SECRET: ${{ secrets.PKG_UPLOAD_SECRET }}
UPLOAD_YUM_SECRET: ${{ secrets.PKG_UPLOAD_SECRET }} UPLOAD_YUM_SECRET: ${{ secrets.PKG_UPLOAD_SECRET }}
- name: upload non tags for debug purposes -
uses: actions/upload-artifact@v4 name: upload non tags for debug purposes
uses: actions/upload-artifact@v3
with: with:
name: release name: release
path: dist/ path: dist/
retention-days: 3 retention-days: 3
- name: upload linux packages -
uses: actions/upload-artifact@v4 name: upload linux packages
uses: actions/upload-artifact@v3
with: with:
name: linux-packages name: linux-packages
path: dist/netbird_linux** path: dist/netbird_linux**
retention-days: 3 retention-days: 3
- name: upload windows packages -
uses: actions/upload-artifact@v4 name: upload windows packages
uses: actions/upload-artifact@v3
with: with:
name: windows-packages name: windows-packages
path: dist/netbird_windows** path: dist/netbird_windows**
retention-days: 3 retention-days: 3
- name: upload macos packages -
uses: actions/upload-artifact@v4 name: upload macos packages
uses: actions/upload-artifact@v3
with: with:
name: macos-packages name: macos-packages
path: dist/netbird_darwin** path: dist/netbird_darwin**
@@ -110,27 +121,20 @@ jobs:
release_ui: release_ui:
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- name: Parse semver string
id: semver_parser
uses: booxmedialtd/ws-action-parse-semver@v1
with:
input_string: ${{ (startsWith(github.ref, 'refs/tags/v') && github.ref) || 'refs/tags/v0.0.0' }}
version_extractor_regex: '\/v(.*)$'
- if: ${{ !startsWith(github.ref, 'refs/tags/v') }} - if: ${{ !startsWith(github.ref, 'refs/tags/v') }}
run: echo "flags=--snapshot" >> $GITHUB_ENV run: echo "flags=--snapshot" >> $GITHUB_ENV
- name: Checkout - name: Checkout
uses: actions/checkout@v4 uses: actions/checkout@v3
with: with:
fetch-depth: 0 # It is required for GoReleaser to work properly fetch-depth: 0 # It is required for GoReleaser to work properly
- name: Set up Go - name: Set up Go
uses: actions/setup-go@v5 uses: actions/setup-go@v4
with: with:
go-version: "1.23" go-version: "1.21"
cache: false cache: false
- name: Cache Go modules - name: Cache Go modules
uses: actions/cache@v4 uses: actions/cache@v3
with: with:
path: | path: |
~/go/pkg/mod ~/go/pkg/mod
@@ -147,23 +151,22 @@ jobs:
- name: Install dependencies - name: Install dependencies
run: sudo apt update && sudo apt install -y -q libappindicator3-dev gir1.2-appindicator3-0.1 libxxf86vm-dev gcc-mingw-w64-x86-64 run: sudo apt update && sudo apt install -y -q libappindicator3-dev gir1.2-appindicator3-0.1 libxxf86vm-dev gcc-mingw-w64-x86-64
- name: Install goversioninfo - name: Install rsrc
run: go install github.com/josephspurrier/goversioninfo/cmd/goversioninfo@233067e run: go install github.com/akavel/rsrc@v0.10.2
- name: Generate windows syso amd64 - name: Generate windows rsrc
run: goversioninfo -64 -icon client/ui/netbird.ico -manifest client/ui/manifest.xml -product-name ${{ env.PRODUCT_NAME }}-"UI" -copyright "${{ env.COPYRIGHT }}" -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -ver-build 0 -file-version ${{ steps.semver_parser.outputs.fullversion }}.0 -product-version ${{ steps.semver_parser.outputs.fullversion }}.0 -o client/ui/resources_windows_amd64.syso run: rsrc -arch amd64 -ico client/ui/netbird.ico -manifest client/ui/manifest.xml -o client/ui/resources_windows_amd64.syso
- name: Run GoReleaser - name: Run GoReleaser
uses: goreleaser/goreleaser-action@v4 uses: goreleaser/goreleaser-action@v4
with: with:
version: ${{ env.GORELEASER_VER }} version: ${{ env.GORELEASER_VER }}
args: release --config .goreleaser_ui.yaml --clean ${{ env.flags }} args: release --config .goreleaser_ui.yaml --rm-dist ${{ env.flags }}
env: env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
HOMEBREW_TAP_GITHUB_TOKEN: ${{ secrets.HOMEBREW_TAP_GITHUB_TOKEN }} HOMEBREW_TAP_GITHUB_TOKEN: ${{ secrets.HOMEBREW_TAP_GITHUB_TOKEN }}
UPLOAD_DEBIAN_SECRET: ${{ secrets.PKG_UPLOAD_SECRET }} UPLOAD_DEBIAN_SECRET: ${{ secrets.PKG_UPLOAD_SECRET }}
UPLOAD_YUM_SECRET: ${{ secrets.PKG_UPLOAD_SECRET }} UPLOAD_YUM_SECRET: ${{ secrets.PKG_UPLOAD_SECRET }}
- name: upload non tags for debug purposes - name: upload non tags for debug purposes
uses: actions/upload-artifact@v4 uses: actions/upload-artifact@v3
with: with:
name: release-ui name: release-ui
path: dist/ path: dist/
@@ -174,17 +177,20 @@ jobs:
steps: steps:
- if: ${{ !startsWith(github.ref, 'refs/tags/v') }} - if: ${{ !startsWith(github.ref, 'refs/tags/v') }}
run: echo "flags=--snapshot" >> $GITHUB_ENV run: echo "flags=--snapshot" >> $GITHUB_ENV
- name: Checkout -
uses: actions/checkout@v4 name: Checkout
uses: actions/checkout@v3
with: with:
fetch-depth: 0 # It is required for GoReleaser to work properly fetch-depth: 0 # It is required for GoReleaser to work properly
- name: Set up Go -
uses: actions/setup-go@v5 name: Set up Go
uses: actions/setup-go@v4
with: with:
go-version: "1.23" go-version: "1.21"
cache: false cache: false
- name: Cache Go modules -
uses: actions/cache@v4 name: Cache Go modules
uses: actions/cache@v3
with: with:
path: | path: |
~/go/pkg/mod ~/go/pkg/mod
@@ -192,35 +198,53 @@ jobs:
key: ${{ runner.os }}-ui-go-releaser-darwin-${{ hashFiles('**/go.sum') }} key: ${{ runner.os }}-ui-go-releaser-darwin-${{ hashFiles('**/go.sum') }}
restore-keys: | restore-keys: |
${{ runner.os }}-ui-go-releaser-darwin- ${{ runner.os }}-ui-go-releaser-darwin-
- name: Install modules -
name: Install modules
run: go mod tidy run: go mod tidy
- name: check git status -
name: check git status
run: git --no-pager diff --exit-code run: git --no-pager diff --exit-code
- name: Run GoReleaser -
name: Run GoReleaser
id: goreleaser id: goreleaser
uses: goreleaser/goreleaser-action@v4 uses: goreleaser/goreleaser-action@v4
with: with:
version: ${{ env.GORELEASER_VER }} version: ${{ env.GORELEASER_VER }}
args: release --config .goreleaser_ui_darwin.yaml --clean ${{ env.flags }} args: release --config .goreleaser_ui_darwin.yaml --rm-dist ${{ env.flags }}
env: env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
- name: upload non tags for debug purposes -
uses: actions/upload-artifact@v4 name: upload non tags for debug purposes
uses: actions/upload-artifact@v3
with: with:
name: release-ui-darwin name: release-ui-darwin
path: dist/ path: dist/
retention-days: 3 retention-days: 3
trigger_signer: trigger_windows_signer:
runs-on: ubuntu-latest runs-on: ubuntu-latest
needs: [release, release_ui, release_ui_darwin] needs: [release,release_ui]
if: startsWith(github.ref, 'refs/tags/') if: startsWith(github.ref, 'refs/tags/')
steps: steps:
- name: Trigger binaries sign pipelines - name: Trigger Windows binaries sign pipeline
uses: benc-uk/workflow-dispatch@v1 uses: benc-uk/workflow-dispatch@v1
with: with:
workflow: Sign bin and installer workflow: Sign windows bin and installer
repo: netbirdio/sign-pipelines repo: netbirdio/sign-pipelines
ref: ${{ env.SIGN_PIPE_VER }} ref: ${{ env.SIGN_PIPE_VER }}
token: ${{ secrets.SIGN_GITHUB_TOKEN }} token: ${{ secrets.SIGN_GITHUB_TOKEN }}
inputs: '{ "tag": "${{ github.ref }}", "skipRelease": false }' inputs: '{ "tag": "${{ github.ref }}" }'
trigger_darwin_signer:
runs-on: ubuntu-latest
needs: [release,release_ui_darwin]
if: startsWith(github.ref, 'refs/tags/')
steps:
- name: Trigger Darwin App binaries sign pipeline
uses: benc-uk/workflow-dispatch@v1
with:
workflow: Sign darwin ui app with dispatch
repo: netbirdio/sign-pipelines
ref: ${{ env.SIGN_PIPE_VER }}
token: ${{ secrets.SIGN_GITHUB_TOKEN }}
inputs: '{ "tag": "${{ github.ref }}" }'

View File

@@ -18,31 +18,7 @@ concurrency:
jobs: jobs:
test-docker-compose: test-docker-compose:
runs-on: ubuntu-latest runs-on: ubuntu-latest
strategy:
matrix:
store: [ 'sqlite', 'postgres' ]
services:
postgres:
image: ${{ (matrix.store == 'postgres') && 'postgres' || '' }}
env:
POSTGRES_USER: netbird
POSTGRES_PASSWORD: postgres
POSTGRES_DB: netbird
options: >-
--health-cmd pg_isready
--health-interval 10s
--health-timeout 5s
ports:
- 5432:5432
steps: steps:
- name: Set Database Connection String
run: |
if [ "${{ matrix.store }}" == "postgres" ]; then
echo "NETBIRD_STORE_ENGINE_POSTGRES_DSN=host=$(hostname -I | awk '{print $1}') user=netbird password=postgres dbname=netbird port=5432" >> $GITHUB_ENV
else
echo "NETBIRD_STORE_ENGINE_POSTGRES_DSN==" >> $GITHUB_ENV
fi
- name: Install jq - name: Install jq
run: sudo apt-get install -y jq run: sudo apt-get install -y jq
@@ -50,12 +26,12 @@ jobs:
run: sudo apt-get install -y curl run: sudo apt-get install -y curl
- name: Install Go - name: Install Go
uses: actions/setup-go@v5 uses: actions/setup-go@v4
with: with:
go-version: "1.23.x" go-version: "1.21.x"
- name: Cache Go modules - name: Cache Go modules
uses: actions/cache@v4 uses: actions/cache@v3
with: with:
path: ~/go/pkg/mod path: ~/go/pkg/mod
key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }} key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }}
@@ -63,7 +39,7 @@ jobs:
${{ runner.os }}-go- ${{ runner.os }}-go-
- name: Checkout code - name: Checkout code
uses: actions/checkout@v4 uses: actions/checkout@v3
- name: cp setup.env - name: cp setup.env
run: cp infrastructure_files/tests/setup.env infrastructure_files/ run: cp infrastructure_files/tests/setup.env infrastructure_files/
@@ -82,8 +58,7 @@ jobs:
CI_NETBIRD_IDP_MGMT_CLIENT_ID: testing.client.id CI_NETBIRD_IDP_MGMT_CLIENT_ID: testing.client.id
CI_NETBIRD_IDP_MGMT_CLIENT_SECRET: testing.client.secret CI_NETBIRD_IDP_MGMT_CLIENT_SECRET: testing.client.secret
CI_NETBIRD_AUTH_SUPPORTED_SCOPES: "openid profile email offline_access api email_verified" CI_NETBIRD_AUTH_SUPPORTED_SCOPES: "openid profile email offline_access api email_verified"
CI_NETBIRD_STORE_CONFIG_ENGINE: ${{ matrix.store }} CI_NETBIRD_STORE_CONFIG_ENGINE: "sqlite"
NETBIRD_STORE_ENGINE_POSTGRES_DSN: ${{ env.NETBIRD_STORE_ENGINE_POSTGRES_DSN }}
CI_NETBIRD_MGMT_IDP_SIGNKEY_REFRESH: false CI_NETBIRD_MGMT_IDP_SIGNKEY_REFRESH: false
- name: check values - name: check values
@@ -110,8 +85,7 @@ jobs:
CI_NETBIRD_IDP_MGMT_CLIENT_ID: testing.client.id CI_NETBIRD_IDP_MGMT_CLIENT_ID: testing.client.id
CI_NETBIRD_IDP_MGMT_CLIENT_SECRET: testing.client.secret CI_NETBIRD_IDP_MGMT_CLIENT_SECRET: testing.client.secret
CI_NETBIRD_SIGNAL_PORT: 12345 CI_NETBIRD_SIGNAL_PORT: 12345
CI_NETBIRD_STORE_CONFIG_ENGINE: ${{ matrix.store }} CI_NETBIRD_STORE_CONFIG_ENGINE: "sqlite"
NETBIRD_STORE_ENGINE_POSTGRES_DSN: '${{ env.NETBIRD_STORE_ENGINE_POSTGRES_DSN }}$'
CI_NETBIRD_MGMT_IDP_SIGNKEY_REFRESH: false CI_NETBIRD_MGMT_IDP_SIGNKEY_REFRESH: false
CI_NETBIRD_TURN_EXTERNAL_IP: "1.2.3.4" CI_NETBIRD_TURN_EXTERNAL_IP: "1.2.3.4"
@@ -149,14 +123,6 @@ jobs:
grep -A 10 PKCEAuthorizationFlow management.json | grep -A 10 ProviderConfig | grep Scope | grep "$CI_NETBIRD_AUTH_SUPPORTED_SCOPES" grep -A 10 PKCEAuthorizationFlow management.json | grep -A 10 ProviderConfig | grep Scope | grep "$CI_NETBIRD_AUTH_SUPPORTED_SCOPES"
grep -A 10 PKCEAuthorizationFlow management.json | grep -A 10 ProviderConfig | grep -A 3 RedirectURLs | grep "http://localhost:53000" grep -A 10 PKCEAuthorizationFlow management.json | grep -A 10 ProviderConfig | grep -A 3 RedirectURLs | grep "http://localhost:53000"
grep "external-ip" turnserver.conf | grep $CI_NETBIRD_TURN_EXTERNAL_IP grep "external-ip" turnserver.conf | grep $CI_NETBIRD_TURN_EXTERNAL_IP
grep NETBIRD_STORE_ENGINE_POSTGRES_DSN docker-compose.yml | egrep "$NETBIRD_STORE_ENGINE_POSTGRES_DSN"
# check relay values
grep "NB_EXPOSED_ADDRESS=$CI_NETBIRD_DOMAIN:33445" docker-compose.yml
grep "NB_LISTEN_ADDRESS=:33445" docker-compose.yml
grep '33445:33445' docker-compose.yml
grep -A 10 'relay:' docker-compose.yml | egrep 'NB_AUTH_SECRET=.+$'
grep -A 7 Relay management.json | grep "rel://$CI_NETBIRD_DOMAIN:33445"
grep -A 7 Relay management.json | egrep '"Secret": ".+"'
- name: Install modules - name: Install modules
run: go mod tidy run: go mod tidy
@@ -182,35 +148,26 @@ jobs:
run: | run: |
docker build -t netbirdio/signal:latest . docker build -t netbirdio/signal:latest .
- name: Build relay binary
working-directory: relay
run: CGO_ENABLED=0 go build -o netbird-relay main.go
- name: Build relay docker image
working-directory: relay
run: |
docker build -t netbirdio/relay:latest .
- name: run docker compose up - name: run docker compose up
working-directory: infrastructure_files/artifacts working-directory: infrastructure_files/artifacts
run: | run: |
docker compose up -d docker-compose up -d
sleep 5 sleep 5
docker compose ps docker-compose ps
docker compose logs --tail=20 docker-compose logs --tail=20
- name: test running containers - name: test running containers
run: | run: |
count=$(docker compose ps --format json | jq '. | select(.Name | contains("artifacts")) | .State' | grep -c running) count=$(docker compose ps --format json | jq '. | select(.Name | contains("artifacts")) | .State' | grep -c running)
test $count -eq 5 || docker compose logs test $count -eq 4
working-directory: infrastructure_files/artifacts working-directory: infrastructure_files/artifacts
- name: test geolocation databases - name: test geolocation databases
working-directory: infrastructure_files/artifacts working-directory: infrastructure_files/artifacts
run: | run: |
sleep 30 sleep 30
docker compose exec management ls -l /var/lib/netbird/ | grep -i GeoLite2-City_[0-9]*.mmdb docker compose exec management ls -l /var/lib/netbird/ | grep -i GeoLite2-City.mmdb
docker compose exec management ls -l /var/lib/netbird/ | grep -i geonames_[0-9]*.db docker compose exec management ls -l /var/lib/netbird/ | grep -i geonames.db
test-getting-started-script: test-getting-started-script:
runs-on: ubuntu-latest runs-on: ubuntu-latest
@@ -219,7 +176,7 @@ jobs:
run: sudo apt-get install -y jq run: sudo apt-get install -y jq
- name: Checkout code - name: Checkout code
uses: actions/checkout@v4 uses: actions/checkout@v3
- name: run script with Zitadel PostgreSQL - name: run script with Zitadel PostgreSQL
run: NETBIRD_DOMAIN=use-ip bash -x infrastructure_files/getting-started-with-zitadel.sh run: NETBIRD_DOMAIN=use-ip bash -x infrastructure_files/getting-started-with-zitadel.sh
@@ -245,15 +202,12 @@ jobs:
- name: test dashboard.env file gen postgres - name: test dashboard.env file gen postgres
run: test -f dashboard.env run: test -f dashboard.env
- name: test relay.env file gen postgres
run: test -f relay.env
- name: test zdb.env file gen postgres - name: test zdb.env file gen postgres
run: test -f zdb.env run: test -f zdb.env
- name: Postgres run cleanup - name: Postgres run cleanup
run: | run: |
docker compose down --volumes --rmi all docker-compose down --volumes --rmi all
rm -rf docker-compose.yml Caddyfile zitadel.env dashboard.env machinekey/zitadel-admin-sa.token turnserver.conf management.json zdb.env rm -rf docker-compose.yml Caddyfile zitadel.env dashboard.env machinekey/zitadel-admin-sa.token turnserver.conf management.json zdb.env
- name: run script with Zitadel CockroachDB - name: run script with Zitadel CockroachDB
@@ -283,5 +237,20 @@ jobs:
- name: test dashboard.env file gen CockroachDB - name: test dashboard.env file gen CockroachDB
run: test -f dashboard.env run: test -f dashboard.env
- name: test relay.env file gen CockroachDB test-download-geolite2-script:
run: test -f relay.env runs-on: ubuntu-latest
steps:
- name: Install jq
run: sudo apt-get update && sudo apt-get install -y unzip sqlite3
- name: Checkout code
uses: actions/checkout@v3
- name: test script
run: bash -x infrastructure_files/download-geolite2.sh
- name: test mmdb file exists
run: test -f GeoLite2-City.mmdb
- name: test geonames file exists
run: test -f geonames.db

1
.gitignore vendored
View File

@@ -29,3 +29,4 @@ infrastructure_files/setup.env
infrastructure_files/setup-*.env infrastructure_files/setup-*.env
.vscode .vscode
.DS_Store .DS_Store
GeoLite2-City*

View File

@@ -1,5 +1,3 @@
version: 2
project_name: netbird project_name: netbird
builds: builds:
- id: netbird - id: netbird
@@ -24,7 +22,7 @@ builds:
goarch: 386 goarch: 386
ldflags: ldflags:
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser - -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
mod_timestamp: "{{ .CommitTimestamp }}" mod_timestamp: '{{ .CommitTimestamp }}'
tags: tags:
- load_wgnt_from_rsrc - load_wgnt_from_rsrc
@@ -44,7 +42,7 @@ builds:
- softfloat - softfloat
ldflags: ldflags:
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser - -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
mod_timestamp: "{{ .CommitTimestamp }}" mod_timestamp: '{{ .CommitTimestamp }}'
tags: tags:
- load_wgnt_from_rsrc - load_wgnt_from_rsrc
@@ -66,7 +64,7 @@ builds:
- arm - arm
ldflags: ldflags:
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser - -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
mod_timestamp: "{{ .CommitTimestamp }}" mod_timestamp: '{{ .CommitTimestamp }}'
- id: netbird-signal - id: netbird-signal
dir: signal dir: signal
@@ -80,24 +78,7 @@ builds:
- arm - arm
ldflags: ldflags:
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser - -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
mod_timestamp: "{{ .CommitTimestamp }}" mod_timestamp: '{{ .CommitTimestamp }}'
- id: netbird-relay
dir: relay
env: [CGO_ENABLED=0]
binary: netbird-relay
goos:
- linux
goarch:
- amd64
- arm64
- arm
ldflags:
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
mod_timestamp: "{{ .CommitTimestamp }}"
universal_binaries:
- id: netbird
archives: archives:
- builds: - builds:
@@ -105,6 +86,7 @@ archives:
- netbird-static - netbird-static
nfpms: nfpms:
- maintainer: Netbird <dev@netbird.io> - maintainer: Netbird <dev@netbird.io>
description: Netbird client. description: Netbird client.
homepage: https://netbird.io/ homepage: https://netbird.io/
@@ -179,52 +161,6 @@ dockers:
- "--label=org.opencontainers.image.revision={{.FullCommit}}" - "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=org.opencontainers.image.version={{.Version}}" - "--label=org.opencontainers.image.version={{.Version}}"
- "--label=maintainer=dev@netbird.io" - "--label=maintainer=dev@netbird.io"
- image_templates:
- netbirdio/relay:{{ .Version }}-amd64
ids:
- netbird-relay
goarch: amd64
use: buildx
dockerfile: relay/Dockerfile
build_flag_templates:
- "--platform=linux/amd64"
- "--label=org.opencontainers.image.created={{.Date}}"
- "--label=org.opencontainers.image.title={{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=maintainer=dev@netbird.io"
- image_templates:
- netbirdio/relay:{{ .Version }}-arm64v8
ids:
- netbird-relay
goarch: arm64
use: buildx
dockerfile: relay/Dockerfile
build_flag_templates:
- "--platform=linux/arm64"
- "--label=org.opencontainers.image.created={{.Date}}"
- "--label=org.opencontainers.image.title={{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=maintainer=dev@netbird.io"
- image_templates:
- netbirdio/relay:{{ .Version }}-arm
ids:
- netbird-relay
goarch: arm
goarm: 6
use: buildx
dockerfile: relay/Dockerfile
build_flag_templates:
- "--platform=linux/arm"
- "--label=org.opencontainers.image.created={{.Date}}"
- "--label=org.opencontainers.image.title={{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=maintainer=dev@netbird.io"
- image_templates: - image_templates:
- netbirdio/signal:{{ .Version }}-amd64 - netbirdio/signal:{{ .Version }}-amd64
ids: ids:
@@ -377,18 +313,6 @@ docker_manifests:
- netbirdio/netbird:{{ .Version }}-arm - netbirdio/netbird:{{ .Version }}-arm
- netbirdio/netbird:{{ .Version }}-amd64 - netbirdio/netbird:{{ .Version }}-amd64
- name_template: netbirdio/relay:{{ .Version }}
image_templates:
- netbirdio/relay:{{ .Version }}-arm64v8
- netbirdio/relay:{{ .Version }}-arm
- netbirdio/relay:{{ .Version }}-amd64
- name_template: netbirdio/relay:latest
image_templates:
- netbirdio/relay:{{ .Version }}-arm64v8
- netbirdio/relay:{{ .Version }}-arm
- netbirdio/relay:{{ .Version }}-amd64
- name_template: netbirdio/signal:{{ .Version }} - name_template: netbirdio/signal:{{ .Version }}
image_templates: image_templates:
- netbirdio/signal:{{ .Version }}-arm64v8 - netbirdio/signal:{{ .Version }}-arm64v8
@@ -420,9 +344,10 @@ docker_manifests:
- netbirdio/management:{{ .Version }}-debug-amd64 - netbirdio/management:{{ .Version }}-debug-amd64
brews: brews:
- ids: -
ids:
- default - default
repository: tap:
owner: netbirdio owner: netbirdio
name: homebrew-tap name: homebrew-tap
token: "{{ .Env.HOMEBREW_TAP_GITHUB_TOKEN }}" token: "{{ .Env.HOMEBREW_TAP_GITHUB_TOKEN }}"

View File

@@ -1,5 +1,3 @@
version: 2
project_name: netbird-ui project_name: netbird-ui
builds: builds:
- id: netbird-ui - id: netbird-ui
@@ -13,7 +11,9 @@ builds:
- amd64 - amd64
ldflags: ldflags:
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser - -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
mod_timestamp: "{{ .CommitTimestamp }}" tags:
- legacy_appindicator
mod_timestamp: '{{ .CommitTimestamp }}'
- id: netbird-ui-windows - id: netbird-ui-windows
dir: client/ui dir: client/ui
@@ -28,7 +28,7 @@ builds:
ldflags: ldflags:
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser - -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
- -H windowsgui - -H windowsgui
mod_timestamp: "{{ .CommitTimestamp }}" mod_timestamp: '{{ .CommitTimestamp }}'
archives: archives:
- id: linux-arch - id: linux-arch
@@ -41,6 +41,7 @@ archives:
- netbird-ui-windows - netbird-ui-windows
nfpms: nfpms:
- maintainer: Netbird <dev@netbird.io> - maintainer: Netbird <dev@netbird.io>
description: Netbird client UI. description: Netbird client UI.
homepage: https://netbird.io/ homepage: https://netbird.io/

View File

@@ -1,5 +1,3 @@
version: 2
project_name: netbird-ui project_name: netbird-ui
builds: builds:
- id: netbird-ui-darwin - id: netbird-ui-darwin
@@ -19,13 +17,10 @@ builds:
- softfloat - softfloat
ldflags: ldflags:
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser - -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
mod_timestamp: "{{ .CommitTimestamp }}" mod_timestamp: '{{ .CommitTimestamp }}'
tags: tags:
- load_wgnt_from_rsrc - load_wgnt_from_rsrc
universal_binaries:
- id: netbird-ui-darwin
archives: archives:
- builds: - builds:
- netbird-ui-darwin - netbird-ui-darwin
@@ -33,4 +28,4 @@ archives:
checksum: checksum:
name_template: "{{ .ProjectName }}_darwin_checksums.txt" name_template: "{{ .ProjectName }}_darwin_checksums.txt"
changelog: changelog:
disable: true skip: true

View File

@@ -96,7 +96,7 @@ They can be executed from the repository root before every push or PR:
**Goreleaser** **Goreleaser**
```shell ```shell
goreleaser build --snapshot --clean goreleaser --snapshot --rm-dist
``` ```
**golangci-lint** **golangci-lint**
```shell ```shell

View File

@@ -10,20 +10,14 @@
<img width="234" src="docs/media/logo-full.png"/> <img width="234" src="docs/media/logo-full.png"/>
</p> </p>
<p> <p>
<a href="https://img.shields.io/badge/license-BSD--3-blue)">
<img src="https://sonarcloud.io/api/project_badges/measure?project=netbirdio_netbird&metric=alert_status" />
</a>
<a href="https://github.com/netbirdio/netbird/blob/main/LICENSE"> <a href="https://github.com/netbirdio/netbird/blob/main/LICENSE">
<img src="https://img.shields.io/badge/license-BSD--3-blue" /> <img src="https://img.shields.io/badge/license-BSD--3-blue" />
</a> </a>
<a href="https://www.codacy.com/gh/netbirdio/netbird/dashboard?utm_source=github.com&amp;utm_medium=referral&amp;utm_content=netbirdio/netbird&amp;utm_campaign=Badge_Grade"><img src="https://app.codacy.com/project/badge/Grade/e3013d046aec44cdb7462c8673b00976"/></a>
<br> <br>
<a href="https://join.slack.com/t/netbirdio/shared_invite/zt-2utg2ncdz-W7LEB6toRBLE1Jca37dYpg"> <a href="https://join.slack.com/t/netbirdio/shared_invite/zt-vrahf41g-ik1v7fV8du6t0RwxSrJ96A">
<img src="https://img.shields.io/badge/slack-@netbird-red.svg?logo=slack"/> <img src="https://img.shields.io/badge/slack-@netbird-red.svg?logo=slack"/>
</a> </a>
<br>
<a href="https://gurubase.io/g/netbird">
<img src="https://img.shields.io/badge/Gurubase-Ask%20NetBird%20Guru-006BFF"/>
</a>
</p> </p>
</div> </div>
@@ -34,7 +28,7 @@
<br/> <br/>
See <a href="https://netbird.io/docs/">Documentation</a> See <a href="https://netbird.io/docs/">Documentation</a>
<br/> <br/>
Join our <a href="https://join.slack.com/t/netbirdio/shared_invite/zt-2utg2ncdz-W7LEB6toRBLE1Jca37dYpg">Slack channel</a> Join our <a href="https://join.slack.com/t/netbirdio/shared_invite/zt-vrahf41g-ik1v7fV8du6t0RwxSrJ96A">Slack channel</a>
<br/> <br/>
</strong> </strong>
@@ -53,8 +47,6 @@
![netbird_2](https://github.com/netbirdio/netbird/assets/700848/46bc3b73-508d-4a0e-bb9a-f465d68646ab) ![netbird_2](https://github.com/netbirdio/netbird/assets/700848/46bc3b73-508d-4a0e-bb9a-f465d68646ab)
### NetBird on Lawrence Systems (Video)
[![Watch the video](https://img.youtube.com/vi/Kwrff6h0rEw/0.jpg)](https://www.youtube.com/watch?v=Kwrff6h0rEw)
### Key features ### Key features
@@ -68,7 +60,6 @@
| | | <ul><li> - \[x] [Quantum-resistance with Rosenpass](https://netbird.io/knowledge-hub/the-first-quantum-resistant-mesh-vpn) </ul></li> | | <ul><li> - \[x] OpenWRT </ul></li> | | | | <ul><li> - \[x] [Quantum-resistance with Rosenpass](https://netbird.io/knowledge-hub/the-first-quantum-resistant-mesh-vpn) </ul></li> | | <ul><li> - \[x] OpenWRT </ul></li> |
| | | <ui><li> - \[x] [Periodic re-authentication](https://docs.netbird.io/how-to/enforce-periodic-user-authentication)</ul></li> | | <ul><li> - \[x] [Serverless](https://docs.netbird.io/how-to/netbird-on-faas) </ul></li> | | | | <ui><li> - \[x] [Periodic re-authentication](https://docs.netbird.io/how-to/enforce-periodic-user-authentication)</ul></li> | | <ul><li> - \[x] [Serverless](https://docs.netbird.io/how-to/netbird-on-faas) </ul></li> |
| | | | | <ul><li> - \[x] Docker </ul></li> | | | | | | <ul><li> - \[x] Docker </ul></li> |
### Quickstart with NetBird Cloud ### Quickstart with NetBird Cloud
- Download and install NetBird at [https://app.netbird.io/install](https://app.netbird.io/install) - Download and install NetBird at [https://app.netbird.io/install](https://app.netbird.io/install)

View File

@@ -1,4 +1,4 @@
FROM alpine:3.21.0 FROM alpine:3.19
RUN apk add --no-cache ca-certificates iptables ip6tables RUN apk add --no-cache ca-certificates iptables ip6tables
ENV NB_FOREGROUND_MODE=true ENV NB_FOREGROUND_MODE=true
ENTRYPOINT [ "/usr/local/bin/netbird","up"] ENTRYPOINT [ "/usr/local/bin/netbird","up"]

View File

@@ -8,7 +8,6 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/internal" "github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/dns" "github.com/netbirdio/netbird/client/internal/dns"
"github.com/netbirdio/netbird/client/internal/listener" "github.com/netbirdio/netbird/client/internal/listener"
@@ -16,6 +15,7 @@ import (
"github.com/netbirdio/netbird/client/internal/stdnet" "github.com/netbirdio/netbird/client/internal/stdnet"
"github.com/netbirdio/netbird/client/system" "github.com/netbirdio/netbird/client/system"
"github.com/netbirdio/netbird/formatter" "github.com/netbirdio/netbird/formatter"
"github.com/netbirdio/netbird/iface"
"github.com/netbirdio/netbird/util/net" "github.com/netbirdio/netbird/util/net"
) )
@@ -26,7 +26,7 @@ type ConnectionListener interface {
// TunAdapter export internal TunAdapter for mobile // TunAdapter export internal TunAdapter for mobile
type TunAdapter interface { type TunAdapter interface {
device.TunAdapter iface.TunAdapter
} }
// IFaceDiscover export internal IFaceDiscover for mobile // IFaceDiscover export internal IFaceDiscover for mobile
@@ -51,7 +51,7 @@ func init() {
// Client struct manage the life circle of background service // Client struct manage the life circle of background service
type Client struct { type Client struct {
cfgFile string cfgFile string
tunAdapter device.TunAdapter tunAdapter iface.TunAdapter
iFaceDiscover IFaceDiscover iFaceDiscover IFaceDiscover
recorder *peer.Status recorder *peer.Status
ctxCancel context.CancelFunc ctxCancel context.CancelFunc

View File

@@ -84,7 +84,7 @@ func (a *Auth) SaveConfigIfSSOSupported(listener SSOListener) {
func (a *Auth) saveConfigIfSSOSupported() (bool, error) { func (a *Auth) saveConfigIfSSOSupported() (bool, error) {
supportsSSO := true supportsSSO := true
err := a.withBackOff(a.ctx, func() (err error) { err := a.withBackOff(a.ctx, func() (err error) {
_, err = internal.GetPKCEAuthorizationFlowInfo(a.ctx, a.config.PrivateKey, a.config.ManagementURL, nil) _, err = internal.GetPKCEAuthorizationFlowInfo(a.ctx, a.config.PrivateKey, a.config.ManagementURL)
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.NotFound || s.Code() == codes.Unimplemented) { if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.NotFound || s.Code() == codes.Unimplemented) {
_, err = internal.GetDeviceAuthorizationFlowInfo(a.ctx, a.config.PrivateKey, a.config.ManagementURL) _, err = internal.GetDeviceAuthorizationFlowInfo(a.ctx, a.config.PrivateKey, a.config.ManagementURL)
s, ok := gstatus.FromError(err) s, ok := gstatus.FromError(err)

View File

@@ -12,8 +12,6 @@ import (
"strings" "strings"
) )
const anonTLD = ".domain"
type Anonymizer struct { type Anonymizer struct {
ipAnonymizer map[netip.Addr]netip.Addr ipAnonymizer map[netip.Addr]netip.Addr
domainAnonymizer map[string]string domainAnonymizer map[string]string
@@ -21,8 +19,6 @@ type Anonymizer struct {
currentAnonIPv6 netip.Addr currentAnonIPv6 netip.Addr
startAnonIPv4 netip.Addr startAnonIPv4 netip.Addr
startAnonIPv6 netip.Addr startAnonIPv6 netip.Addr
domainKeyRegex *regexp.Regexp
} }
func DefaultAddresses() (netip.Addr, netip.Addr) { func DefaultAddresses() (netip.Addr, netip.Addr) {
@@ -38,8 +34,6 @@ func NewAnonymizer(startIPv4, startIPv6 netip.Addr) *Anonymizer {
currentAnonIPv6: startIPv6, currentAnonIPv6: startIPv6,
startAnonIPv4: startIPv4, startAnonIPv4: startIPv4,
startAnonIPv6: startIPv6, startAnonIPv6: startIPv6,
domainKeyRegex: regexp.MustCompile(`\bdomain=([^\s,:"]+)`),
} }
} }
@@ -89,39 +83,29 @@ func (a *Anonymizer) AnonymizeIPString(ip string) string {
} }
func (a *Anonymizer) AnonymizeDomain(domain string) string { func (a *Anonymizer) AnonymizeDomain(domain string) string {
baseDomain := domain if strings.HasSuffix(domain, "netbird.io") ||
hasDot := strings.HasSuffix(domain, ".") strings.HasSuffix(domain, "netbird.selfhosted") ||
if hasDot { strings.HasSuffix(domain, "netbird.cloud") ||
baseDomain = domain[:len(domain)-1] strings.HasSuffix(domain, "netbird.stage") ||
} strings.HasSuffix(domain, ".domain") {
if strings.HasSuffix(baseDomain, "netbird.io") ||
strings.HasSuffix(baseDomain, "netbird.selfhosted") ||
strings.HasSuffix(baseDomain, "netbird.cloud") ||
strings.HasSuffix(baseDomain, "netbird.stage") ||
strings.HasSuffix(baseDomain, anonTLD) {
return domain return domain
} }
parts := strings.Split(baseDomain, ".") parts := strings.Split(domain, ".")
if len(parts) < 2 { if len(parts) < 2 {
return domain return domain
} }
baseForLookup := parts[len(parts)-2] + "." + parts[len(parts)-1] baseDomain := parts[len(parts)-2] + "." + parts[len(parts)-1]
anonymized, ok := a.domainAnonymizer[baseForLookup] anonymized, ok := a.domainAnonymizer[baseDomain]
if !ok { if !ok {
anonymizedBase := "anon-" + generateRandomString(5) + anonTLD anonymizedBase := "anon-" + generateRandomString(5) + ".domain"
a.domainAnonymizer[baseForLookup] = anonymizedBase a.domainAnonymizer[baseDomain] = anonymizedBase
anonymized = anonymizedBase anonymized = anonymizedBase
} }
result := strings.Replace(baseDomain, baseForLookup, anonymized, 1) return strings.Replace(domain, baseDomain, anonymized, 1)
if hasDot {
result += "."
}
return result
} }
func (a *Anonymizer) AnonymizeURI(uri string) string { func (a *Anonymizer) AnonymizeURI(uri string) string {
@@ -168,42 +152,32 @@ func (a *Anonymizer) AnonymizeString(str string) string {
return str return str
} }
// AnonymizeSchemeURI finds and anonymizes URIs with ws, wss, rel, rels, stun, stuns, turn, and turns schemes. // AnonymizeSchemeURI finds and anonymizes URIs with stun, stuns, turn, and turns schemes.
func (a *Anonymizer) AnonymizeSchemeURI(text string) string { func (a *Anonymizer) AnonymizeSchemeURI(text string) string {
re := regexp.MustCompile(`(?i)\b(wss?://|rels?://|stuns?:|turns?:|https?://)\S+\b`) re := regexp.MustCompile(`(?i)\b(stuns?:|turns?:|https?://)\S+\b`)
return re.ReplaceAllStringFunc(text, a.AnonymizeURI) return re.ReplaceAllStringFunc(text, a.AnonymizeURI)
} }
// AnonymizeDNSLogLine anonymizes domain names in DNS log entries by replacing them with a random string.
func (a *Anonymizer) AnonymizeDNSLogLine(logEntry string) string { func (a *Anonymizer) AnonymizeDNSLogLine(logEntry string) string {
return a.domainKeyRegex.ReplaceAllStringFunc(logEntry, func(match string) string { domainPattern := `dns\.Question{Name:"([^"]+)",`
parts := strings.SplitN(match, "=", 2) domainRegex := regexp.MustCompile(domainPattern)
return domainRegex.ReplaceAllStringFunc(logEntry, func(match string) string {
parts := strings.Split(match, `"`)
if len(parts) >= 2 { if len(parts) >= 2 {
domain := parts[1] domain := parts[1]
if strings.HasSuffix(domain, anonTLD) { if strings.HasSuffix(domain, ".domain") {
return match return match
} }
return "domain=" + a.AnonymizeDomain(domain) randomDomain := generateRandomString(10) + ".domain"
return strings.Replace(match, domain, randomDomain, 1)
} }
return match return match
}) })
} }
// AnonymizeRoute anonymizes a route string by replacing IP addresses with anonymized versions and
// domain names with random strings.
func (a *Anonymizer) AnonymizeRoute(route string) string {
prefix, err := netip.ParsePrefix(route)
if err == nil {
ip := a.AnonymizeIPString(prefix.Addr().String())
return fmt.Sprintf("%s/%d", ip, prefix.Bits())
}
domains := strings.Split(route, ", ")
for i, domain := range domains {
domains[i] = a.AnonymizeDomain(domain)
}
return strings.Join(domains, ", ")
}
func isWellKnown(addr netip.Addr) bool { func isWellKnown(addr netip.Addr) bool {
wellKnown := []string{ wellKnown := []string{
"8.8.8.8", "8.8.4.4", // Google DNS IPv4 "8.8.8.8", "8.8.4.4", // Google DNS IPv4
@@ -212,8 +186,6 @@ func isWellKnown(addr netip.Addr) bool {
"2606:4700:4700::1111", "2606:4700:4700::1001", // Cloudflare DNS IPv6 "2606:4700:4700::1111", "2606:4700:4700::1001", // Cloudflare DNS IPv6
"9.9.9.9", "149.112.112.112", // Quad9 DNS IPv4 "9.9.9.9", "149.112.112.112", // Quad9 DNS IPv4
"2620:fe::fe", "2620:fe::9", // Quad9 DNS IPv6 "2620:fe::fe", "2620:fe::9", // Quad9 DNS IPv6
"128.0.0.0", "8000::", // 2nd split subnet for default routes
} }
if slices.Contains(wellKnown, addr.String()) { if slices.Contains(wellKnown, addr.String()) {

View File

@@ -46,59 +46,11 @@ func TestAnonymizeIP(t *testing.T) {
func TestAnonymizeDNSLogLine(t *testing.T) { func TestAnonymizeDNSLogLine(t *testing.T) {
anonymizer := anonymize.NewAnonymizer(netip.Addr{}, netip.Addr{}) anonymizer := anonymize.NewAnonymizer(netip.Addr{}, netip.Addr{})
tests := []struct { testLog := `2024-04-23T20:01:11+02:00 TRAC client/internal/dns/local.go:25: received question: dns.Question{Name:"example.com", Qtype:0x1c, Qclass:0x1}`
name string
input string
original string
expect string
}{
{
name: "Basic domain with trailing content",
input: "received DNS request for DNS forwarder: domain=example.com: something happened with code=123",
original: "example.com",
expect: `received DNS request for DNS forwarder: domain=anon-[a-zA-Z0-9]+\.domain: something happened with code=123`,
},
{
name: "Domain with trailing dot",
input: "domain=example.com. processing request with status=pending",
original: "example.com",
expect: `domain=anon-[a-zA-Z0-9]+\.domain\. processing request with status=pending`,
},
{
name: "Multiple domains in log",
input: "forward domain=first.com status=ok, redirect to domain=second.com port=443",
original: "first.com", // testing just one is sufficient as AnonymizeDomain is tested separately
expect: `forward domain=anon-[a-zA-Z0-9]+\.domain status=ok, redirect to domain=anon-[a-zA-Z0-9]+\.domain port=443`,
},
{
name: "Already anonymized domain",
input: "got request domain=anon-xyz123.domain from=client1 to=server2",
original: "", // nothing should be anonymized
expect: `got request domain=anon-xyz123\.domain from=client1 to=server2`,
},
{
name: "Subdomain with trailing dot",
input: "domain=sub.example.com. next_hop=10.0.0.1 proto=udp",
original: "example.com",
expect: `domain=sub\.anon-[a-zA-Z0-9]+\.domain\. next_hop=10\.0\.0\.1 proto=udp`,
},
{
name: "Handler chain pattern log",
input: "pattern: domain=example.com. original: domain=*.example.com. wildcard=true priority=100",
original: "example.com",
expect: `pattern: domain=anon-[a-zA-Z0-9]+\.domain\. original: domain=\*\.anon-[a-zA-Z0-9]+\.domain\. wildcard=true priority=100`,
},
}
for _, tc := range tests { result := anonymizer.AnonymizeDNSLogLine(testLog)
t.Run(tc.name, func(t *testing.T) { require.NotEqual(t, testLog, result)
result := anonymizer.AnonymizeDNSLogLine(tc.input) assert.NotContains(t, result, "example.com")
if tc.original != "" {
assert.NotContains(t, result, tc.original)
}
assert.Regexp(t, tc.expect, result)
})
}
} }
func TestAnonymizeDomain(t *testing.T) { func TestAnonymizeDomain(t *testing.T) {
@@ -115,36 +67,18 @@ func TestAnonymizeDomain(t *testing.T) {
`^anon-[a-zA-Z0-9]+\.domain$`, `^anon-[a-zA-Z0-9]+\.domain$`,
true, true,
}, },
{
"Domain with Trailing Dot",
"example.com.",
`^anon-[a-zA-Z0-9]+\.domain.$`,
true,
},
{ {
"Subdomain", "Subdomain",
"sub.example.com", "sub.example.com",
`^sub\.anon-[a-zA-Z0-9]+\.domain$`, `^sub\.anon-[a-zA-Z0-9]+\.domain$`,
true, true,
}, },
{
"Subdomain with Trailing Dot",
"sub.example.com.",
`^sub\.anon-[a-zA-Z0-9]+\.domain.$`,
true,
},
{ {
"Protected Domain", "Protected Domain",
"netbird.io", "netbird.io",
`^netbird\.io$`, `^netbird\.io$`,
false, false,
}, },
{
"Protected Domain with Trailing Dot",
"netbird.io.",
`^netbird\.io.$`,
false,
},
} }
for _, tc := range tests { for _, tc := range tests {
@@ -206,16 +140,8 @@ func TestAnonymizeSchemeURI(t *testing.T) {
expect string expect string
}{ }{
{"STUN URI in text", "Connection made via stun:example.com", `Connection made via stun:anon-[a-zA-Z0-9]+\.domain`}, {"STUN URI in text", "Connection made via stun:example.com", `Connection made via stun:anon-[a-zA-Z0-9]+\.domain`},
{"STUNS URI in message", "Secure connection to stuns:example.com:443", `Secure connection to stuns:anon-[a-zA-Z0-9]+\.domain:443`},
{"TURN URI in log", "Failed attempt turn:some.example.com:3478?transport=tcp: retrying", `Failed attempt turn:some.anon-[a-zA-Z0-9]+\.domain:3478\?transport=tcp: retrying`}, {"TURN URI in log", "Failed attempt turn:some.example.com:3478?transport=tcp: retrying", `Failed attempt turn:some.anon-[a-zA-Z0-9]+\.domain:3478\?transport=tcp: retrying`},
{"TURNS URI in message", "Secure connection to turns:example.com:5349", `Secure connection to turns:anon-[a-zA-Z0-9]+\.domain:5349`},
{"HTTP URI in text", "Visit http://example.com for more", `Visit http://anon-[a-zA-Z0-9]+\.domain for more`},
{"HTTPS URI in CAPS", "Visit HTTPS://example.com for more", `Visit https://anon-[a-zA-Z0-9]+\.domain for more`},
{"HTTPS URI in message", "Visit https://example.com for more", `Visit https://anon-[a-zA-Z0-9]+\.domain for more`}, {"HTTPS URI in message", "Visit https://example.com for more", `Visit https://anon-[a-zA-Z0-9]+\.domain for more`},
{"WS URI in log", "Connection established to ws://example.com:8080", `Connection established to ws://anon-[a-zA-Z0-9]+\.domain:8080`},
{"WSS URI in message", "Secure connection to wss://example.com", `Secure connection to wss://anon-[a-zA-Z0-9]+\.domain`},
{"Rel URI in text", "Relaying to rel://example.com", `Relaying to rel://anon-[a-zA-Z0-9]+\.domain`},
{"Rels URI in message", "Relaying to rels://example.com", `Relaying to rels://anon-[a-zA-Z0-9]+\.domain`},
} }
for _, tc := range tests { for _, tc := range tests {

View File

@@ -3,10 +3,8 @@ package cmd
import ( import (
"context" "context"
"fmt" "fmt"
"strings"
"time" "time"
log "github.com/sirupsen/logrus"
"github.com/spf13/cobra" "github.com/spf13/cobra"
"google.golang.org/grpc/status" "google.golang.org/grpc/status"
@@ -15,8 +13,6 @@ import (
"github.com/netbirdio/netbird/client/server" "github.com/netbirdio/netbird/client/server"
) )
const errCloseConnection = "Failed to close connection: %v"
var debugCmd = &cobra.Command{ var debugCmd = &cobra.Command{
Use: "debug", Use: "debug",
Short: "Debugging commands", Short: "Debugging commands",
@@ -62,31 +58,17 @@ var forCmd = &cobra.Command{
RunE: runForDuration, RunE: runForDuration,
} }
var persistenceCmd = &cobra.Command{
Use: "persistence [on|off]",
Short: "Set network map memory persistence",
Long: `Configure whether the latest network map should persist in memory. When enabled, the last known network map will be kept in memory.`,
Example: " netbird debug persistence on",
Args: cobra.ExactArgs(1),
RunE: setNetworkMapPersistence,
}
func debugBundle(cmd *cobra.Command, _ []string) error { func debugBundle(cmd *cobra.Command, _ []string) error {
conn, err := getClient(cmd) conn, err := getClient(cmd)
if err != nil { if err != nil {
return err return err
} }
defer func() { defer conn.Close()
if err := conn.Close(); err != nil {
log.Errorf(errCloseConnection, err)
}
}()
client := proto.NewDaemonServiceClient(conn) client := proto.NewDaemonServiceClient(conn)
resp, err := client.DebugBundle(cmd.Context(), &proto.DebugBundleRequest{ resp, err := client.DebugBundle(cmd.Context(), &proto.DebugBundleRequest{
Anonymize: anonymizeFlag, Anonymize: anonymizeFlag,
Status: getStatusOutput(cmd), Status: getStatusOutput(cmd),
SystemInfo: debugSystemInfoFlag,
}) })
if err != nil { if err != nil {
return fmt.Errorf("failed to bundle debug: %v", status.Convert(err).Message()) return fmt.Errorf("failed to bundle debug: %v", status.Convert(err).Message())
@@ -102,11 +84,7 @@ func setLogLevel(cmd *cobra.Command, args []string) error {
if err != nil { if err != nil {
return err return err
} }
defer func() { defer conn.Close()
if err := conn.Close(); err != nil {
log.Errorf(errCloseConnection, err)
}
}()
client := proto.NewDaemonServiceClient(conn) client := proto.NewDaemonServiceClient(conn)
level := server.ParseLogLevel(args[0]) level := server.ParseLogLevel(args[0])
@@ -135,11 +113,7 @@ func runForDuration(cmd *cobra.Command, args []string) error {
if err != nil { if err != nil {
return err return err
} }
defer func() { defer conn.Close()
if err := conn.Close(); err != nil {
log.Errorf(errCloseConnection, err)
}
}()
client := proto.NewDaemonServiceClient(conn) client := proto.NewDaemonServiceClient(conn)
@@ -148,20 +122,17 @@ func runForDuration(cmd *cobra.Command, args []string) error {
return fmt.Errorf("failed to get status: %v", status.Convert(err).Message()) return fmt.Errorf("failed to get status: %v", status.Convert(err).Message())
} }
stateWasDown := stat.Status != string(internal.StatusConnected) && stat.Status != string(internal.StatusConnecting) restoreUp := stat.Status == string(internal.StatusConnected) || stat.Status == string(internal.StatusConnecting)
initialLogLevel, err := client.GetLogLevel(cmd.Context(), &proto.GetLogLevelRequest{}) initialLogLevel, err := client.GetLogLevel(cmd.Context(), &proto.GetLogLevelRequest{})
if err != nil { if err != nil {
return fmt.Errorf("failed to get log level: %v", status.Convert(err).Message()) return fmt.Errorf("failed to get log level: %v", status.Convert(err).Message())
} }
if stateWasDown { if _, err := client.Down(cmd.Context(), &proto.DownRequest{}); err != nil {
if _, err := client.Up(cmd.Context(), &proto.UpRequest{}); err != nil { return fmt.Errorf("failed to down: %v", status.Convert(err).Message())
return fmt.Errorf("failed to up: %v", status.Convert(err).Message())
}
cmd.Println("Netbird up")
time.Sleep(time.Second * 10)
} }
cmd.Println("Netbird down")
initialLevelTrace := initialLogLevel.GetLevel() >= proto.LogLevel_TRACE initialLevelTrace := initialLogLevel.GetLevel() >= proto.LogLevel_TRACE
if !initialLevelTrace { if !initialLevelTrace {
@@ -174,20 +145,8 @@ func runForDuration(cmd *cobra.Command, args []string) error {
cmd.Println("Log level set to trace.") cmd.Println("Log level set to trace.")
} }
if _, err := client.Down(cmd.Context(), &proto.DownRequest{}); err != nil {
return fmt.Errorf("failed to down: %v", status.Convert(err).Message())
}
cmd.Println("Netbird down")
time.Sleep(1 * time.Second) time.Sleep(1 * time.Second)
// Enable network map persistence before bringing the service up
if _, err := client.SetNetworkMapPersistence(cmd.Context(), &proto.SetNetworkMapPersistenceRequest{
Enabled: true,
}); err != nil {
return fmt.Errorf("failed to enable network map persistence: %v", status.Convert(err).Message())
}
if _, err := client.Up(cmd.Context(), &proto.UpRequest{}); err != nil { if _, err := client.Up(cmd.Context(), &proto.UpRequest{}); err != nil {
return fmt.Errorf("failed to up: %v", status.Convert(err).Message()) return fmt.Errorf("failed to up: %v", status.Convert(err).Message())
} }
@@ -203,32 +162,21 @@ func runForDuration(cmd *cobra.Command, args []string) error {
} }
cmd.Println("\nDuration completed") cmd.Println("\nDuration completed")
cmd.Println("Creating debug bundle...")
headerPreDown := fmt.Sprintf("----- Netbird pre-down - Timestamp: %s - Duration: %s", time.Now().Format(time.RFC3339), duration) headerPreDown := fmt.Sprintf("----- Netbird pre-down - Timestamp: %s - Duration: %s", time.Now().Format(time.RFC3339), duration)
statusOutput = fmt.Sprintf("%s\n%s\n%s", statusOutput, headerPreDown, getStatusOutput(cmd)) statusOutput = fmt.Sprintf("%s\n%s\n%s", statusOutput, headerPreDown, getStatusOutput(cmd))
resp, err := client.DebugBundle(cmd.Context(), &proto.DebugBundleRequest{
Anonymize: anonymizeFlag,
Status: statusOutput,
SystemInfo: debugSystemInfoFlag,
})
if err != nil {
return fmt.Errorf("failed to bundle debug: %v", status.Convert(err).Message())
}
// Disable network map persistence after creating the debug bundle
if _, err := client.SetNetworkMapPersistence(cmd.Context(), &proto.SetNetworkMapPersistenceRequest{
Enabled: false,
}); err != nil {
return fmt.Errorf("failed to disable network map persistence: %v", status.Convert(err).Message())
}
if stateWasDown {
if _, err := client.Down(cmd.Context(), &proto.DownRequest{}); err != nil { if _, err := client.Down(cmd.Context(), &proto.DownRequest{}); err != nil {
return fmt.Errorf("failed to down: %v", status.Convert(err).Message()) return fmt.Errorf("failed to down: %v", status.Convert(err).Message())
} }
cmd.Println("Netbird down") cmd.Println("Netbird down")
time.Sleep(1 * time.Second)
if restoreUp {
if _, err := client.Up(cmd.Context(), &proto.UpRequest{}); err != nil {
return fmt.Errorf("failed to up: %v", status.Convert(err).Message())
}
cmd.Println("Netbird up")
} }
if !initialLevelTrace { if !initialLevelTrace {
@@ -238,36 +186,18 @@ func runForDuration(cmd *cobra.Command, args []string) error {
cmd.Println("Log level restored to", initialLogLevel.GetLevel()) cmd.Println("Log level restored to", initialLogLevel.GetLevel())
} }
cmd.Println(resp.GetPath()) cmd.Println("Creating debug bundle...")
return nil resp, err := client.DebugBundle(cmd.Context(), &proto.DebugBundleRequest{
} Anonymize: anonymizeFlag,
Status: statusOutput,
func setNetworkMapPersistence(cmd *cobra.Command, args []string) error {
conn, err := getClient(cmd)
if err != nil {
return err
}
defer func() {
if err := conn.Close(); err != nil {
log.Errorf(errCloseConnection, err)
}
}()
persistence := strings.ToLower(args[0])
if persistence != "on" && persistence != "off" {
return fmt.Errorf("invalid persistence value: %s. Use 'on' or 'off'", args[0])
}
client := proto.NewDaemonServiceClient(conn)
_, err = client.SetNetworkMapPersistence(cmd.Context(), &proto.SetNetworkMapPersistenceRequest{
Enabled: persistence == "on",
}) })
if err != nil { if err != nil {
return fmt.Errorf("failed to set network map persistence: %v", status.Convert(err).Message()) return fmt.Errorf("failed to bundle debug: %v", status.Convert(err).Message())
} }
cmd.Printf("Network map persistence set to: %s\n", persistence) cmd.Println(resp.GetPath())
return nil return nil
} }

View File

@@ -26,7 +26,7 @@ var downCmd = &cobra.Command{
return err return err
} }
ctx, cancel := context.WithTimeout(context.Background(), time.Second*7) ctx, cancel := context.WithTimeout(context.Background(), time.Second*3)
defer cancel() defer cancel()
conn, err := DialClientGRPCServer(ctx, daemonAddr) conn, err := DialClientGRPCServer(ctx, daemonAddr)
@@ -42,8 +42,6 @@ var downCmd = &cobra.Command{
log.Errorf("call service down method: %v", err) log.Errorf("call service down method: %v", err)
return err return err
} }
cmd.Println("Disconnected")
return nil return nil
}, },
} }

View File

@@ -39,11 +39,6 @@ var loginCmd = &cobra.Command{
ctx = context.WithValue(ctx, system.DeviceNameCtxKey, hostName) ctx = context.WithValue(ctx, system.DeviceNameCtxKey, hostName)
} }
providedSetupKey, err := getSetupKey()
if err != nil {
return err
}
// workaround to run without service // workaround to run without service
if logFile == "console" { if logFile == "console" {
err = handleRebrand(cmd) err = handleRebrand(cmd)
@@ -67,7 +62,7 @@ var loginCmd = &cobra.Command{
config, _ = internal.UpdateOldManagementURL(ctx, config, configPath) config, _ = internal.UpdateOldManagementURL(ctx, config, configPath)
err = foregroundLogin(ctx, cmd, config, providedSetupKey) err = foregroundLogin(ctx, cmd, config, setupKey)
if err != nil { if err != nil {
return fmt.Errorf("foreground login failed: %v", err) return fmt.Errorf("foreground login failed: %v", err)
} }
@@ -86,7 +81,7 @@ var loginCmd = &cobra.Command{
client := proto.NewDaemonServiceClient(conn) client := proto.NewDaemonServiceClient(conn)
loginRequest := proto.LoginRequest{ loginRequest := proto.LoginRequest{
SetupKey: providedSetupKey, SetupKey: setupKey,
ManagementUrl: managementURL, ManagementUrl: managementURL,
IsLinuxDesktopClient: isLinuxRunningDesktop(), IsLinuxDesktopClient: isLinuxRunningDesktop(),
Hostname: hostName, Hostname: hostName,

View File

@@ -5,8 +5,8 @@ import (
"strings" "strings"
"testing" "testing"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/internal" "github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/iface"
"github.com/netbirdio/netbird/util" "github.com/netbirdio/netbird/util"
) )

View File

@@ -1,173 +0,0 @@
package cmd
import (
"fmt"
"strings"
"github.com/spf13/cobra"
"google.golang.org/grpc/status"
"github.com/netbirdio/netbird/client/proto"
)
var appendFlag bool
var networksCMD = &cobra.Command{
Use: "networks",
Aliases: []string{"routes"},
Short: "Manage networks",
Long: `Commands to list, select, or deselect networks. Replaces the "routes" command.`,
}
var routesListCmd = &cobra.Command{
Use: "list",
Aliases: []string{"ls"},
Short: "List networks",
Example: " netbird networks list",
Long: "List all available network routes.",
RunE: networksList,
}
var routesSelectCmd = &cobra.Command{
Use: "select network...|all",
Short: "Select network",
Long: "Select a list of networks by identifiers or 'all' to clear all selections and to accept all (including new) networks.\nDefault mode is replace, use -a to append to already selected networks.",
Example: " netbird networks select all\n netbird networks select route1 route2\n netbird routes select -a route3",
Args: cobra.MinimumNArgs(1),
RunE: networksSelect,
}
var routesDeselectCmd = &cobra.Command{
Use: "deselect network...|all",
Short: "Deselect networks",
Long: "Deselect previously selected networks by identifiers or 'all' to disable accepting any networks.",
Example: " netbird networks deselect all\n netbird networks deselect route1 route2",
Args: cobra.MinimumNArgs(1),
RunE: networksDeselect,
}
func init() {
routesSelectCmd.PersistentFlags().BoolVarP(&appendFlag, "append", "a", false, "Append to current network selection instead of replacing")
}
func networksList(cmd *cobra.Command, _ []string) error {
conn, err := getClient(cmd)
if err != nil {
return err
}
defer conn.Close()
client := proto.NewDaemonServiceClient(conn)
resp, err := client.ListNetworks(cmd.Context(), &proto.ListNetworksRequest{})
if err != nil {
return fmt.Errorf("failed to list network: %v", status.Convert(err).Message())
}
if len(resp.Routes) == 0 {
cmd.Println("No networks available.")
return nil
}
printNetworks(cmd, resp)
return nil
}
func printNetworks(cmd *cobra.Command, resp *proto.ListNetworksResponse) {
cmd.Println("Available Networks:")
for _, route := range resp.Routes {
printNetwork(cmd, route)
}
}
func printNetwork(cmd *cobra.Command, route *proto.Network) {
selectedStatus := getSelectedStatus(route)
domains := route.GetDomains()
if len(domains) > 0 {
printDomainRoute(cmd, route, domains, selectedStatus)
} else {
printNetworkRoute(cmd, route, selectedStatus)
}
}
func getSelectedStatus(route *proto.Network) string {
if route.GetSelected() {
return "Selected"
}
return "Not Selected"
}
func printDomainRoute(cmd *cobra.Command, route *proto.Network, domains []string, selectedStatus string) {
cmd.Printf("\n - ID: %s\n Domains: %s\n Status: %s\n", route.GetID(), strings.Join(domains, ", "), selectedStatus)
resolvedIPs := route.GetResolvedIPs()
if len(resolvedIPs) > 0 {
printResolvedIPs(cmd, domains, resolvedIPs)
} else {
cmd.Printf(" Resolved IPs: -\n")
}
}
func printNetworkRoute(cmd *cobra.Command, route *proto.Network, selectedStatus string) {
cmd.Printf("\n - ID: %s\n Network: %s\n Status: %s\n", route.GetID(), route.GetRange(), selectedStatus)
}
func printResolvedIPs(cmd *cobra.Command, _ []string, resolvedIPs map[string]*proto.IPList) {
cmd.Printf(" Resolved IPs:\n")
for resolvedDomain, ipList := range resolvedIPs {
cmd.Printf(" [%s]: %s\n", resolvedDomain, strings.Join(ipList.GetIps(), ", "))
}
}
func networksSelect(cmd *cobra.Command, args []string) error {
conn, err := getClient(cmd)
if err != nil {
return err
}
defer conn.Close()
client := proto.NewDaemonServiceClient(conn)
req := &proto.SelectNetworksRequest{
NetworkIDs: args,
}
if len(args) == 1 && args[0] == "all" {
req.All = true
} else if appendFlag {
req.Append = true
}
if _, err := client.SelectNetworks(cmd.Context(), req); err != nil {
return fmt.Errorf("failed to select networks: %v", status.Convert(err).Message())
}
cmd.Println("Networks selected successfully.")
return nil
}
func networksDeselect(cmd *cobra.Command, args []string) error {
conn, err := getClient(cmd)
if err != nil {
return err
}
defer conn.Close()
client := proto.NewDaemonServiceClient(conn)
req := &proto.SelectNetworksRequest{
NetworkIDs: args,
}
if len(args) == 1 && args[0] == "all" {
req.All = true
}
if _, err := client.DeselectNetworks(cmd.Context(), req); err != nil {
return fmt.Errorf("failed to deselect networks: %v", status.Convert(err).Message())
}
cmd.Println("Networks deselected successfully.")
return nil
}

View File

@@ -1,33 +0,0 @@
//go:build pprof
// +build pprof
package cmd
import (
"net/http"
_ "net/http/pprof"
"os"
log "github.com/sirupsen/logrus"
)
func init() {
addr := pprofAddr()
go pprof(addr)
}
func pprofAddr() string {
listenAddr := os.Getenv("NB_PPROF_ADDR")
if listenAddr == "" {
return "localhost:6969"
}
return listenAddr
}
func pprof(listenAddr string) {
log.Infof("listening pprof on: %s\n", listenAddr)
if err := http.ListenAndServe(listenAddr, nil); err != nil {
log.Fatalf("Failed to start pprof: %v", err)
}
}

View File

@@ -37,7 +37,6 @@ const (
serverSSHAllowedFlag = "allow-server-ssh" serverSSHAllowedFlag = "allow-server-ssh"
extraIFaceBlackListFlag = "extra-iface-blacklist" extraIFaceBlackListFlag = "extra-iface-blacklist"
dnsRouteIntervalFlag = "dns-router-interval" dnsRouteIntervalFlag = "dns-router-interval"
systemInfoFlag = "system-info"
) )
var ( var (
@@ -56,7 +55,6 @@ var (
managementURL string managementURL string
adminURL string adminURL string
setupKey string setupKey string
setupKeyPath string
hostName string hostName string
preSharedKey string preSharedKey string
natExternalIPs []string natExternalIPs []string
@@ -71,7 +69,6 @@ var (
autoConnectDisabled bool autoConnectDisabled bool
extraIFaceBlackList []string extraIFaceBlackList []string
anonymizeFlag bool anonymizeFlag bool
debugSystemInfoFlag bool
dnsRouteInterval time.Duration dnsRouteInterval time.Duration
rootCmd = &cobra.Command{ rootCmd = &cobra.Command{
@@ -94,15 +91,12 @@ func init() {
oldDefaultConfigPathDir = "/etc/wiretrustee/" oldDefaultConfigPathDir = "/etc/wiretrustee/"
oldDefaultLogFileDir = "/var/log/wiretrustee/" oldDefaultLogFileDir = "/var/log/wiretrustee/"
switch runtime.GOOS { if runtime.GOOS == "windows" {
case "windows":
defaultConfigPathDir = os.Getenv("PROGRAMDATA") + "\\Netbird\\" defaultConfigPathDir = os.Getenv("PROGRAMDATA") + "\\Netbird\\"
defaultLogFileDir = os.Getenv("PROGRAMDATA") + "\\Netbird\\" defaultLogFileDir = os.Getenv("PROGRAMDATA") + "\\Netbird\\"
oldDefaultConfigPathDir = os.Getenv("PROGRAMDATA") + "\\Wiretrustee\\" oldDefaultConfigPathDir = os.Getenv("PROGRAMDATA") + "\\Wiretrustee\\"
oldDefaultLogFileDir = os.Getenv("PROGRAMDATA") + "\\Wiretrustee\\" oldDefaultLogFileDir = os.Getenv("PROGRAMDATA") + "\\Wiretrustee\\"
case "freebsd":
defaultConfigPathDir = "/var/db/netbird/"
} }
defaultConfigPath = defaultConfigPathDir + "config.json" defaultConfigPath = defaultConfigPathDir + "config.json"
@@ -127,10 +121,8 @@ func init() {
rootCmd.PersistentFlags().StringVarP(&serviceName, "service", "s", defaultServiceName, "Netbird system service name") rootCmd.PersistentFlags().StringVarP(&serviceName, "service", "s", defaultServiceName, "Netbird system service name")
rootCmd.PersistentFlags().StringVarP(&configPath, "config", "c", defaultConfigPath, "Netbird config file location") rootCmd.PersistentFlags().StringVarP(&configPath, "config", "c", defaultConfigPath, "Netbird config file location")
rootCmd.PersistentFlags().StringVarP(&logLevel, "log-level", "l", "info", "sets Netbird log level") rootCmd.PersistentFlags().StringVarP(&logLevel, "log-level", "l", "info", "sets Netbird log level")
rootCmd.PersistentFlags().StringVar(&logFile, "log-file", defaultLogFile, "sets Netbird log path. If console is specified the log will be output to stdout. If syslog is specified the log will be sent to syslog daemon.") rootCmd.PersistentFlags().StringVar(&logFile, "log-file", defaultLogFile, "sets Netbird log path. If console is specified the log will be output to stdout")
rootCmd.PersistentFlags().StringVarP(&setupKey, "setup-key", "k", "", "Setup key obtained from the Management Service Dashboard (used to register peer)") rootCmd.PersistentFlags().StringVarP(&setupKey, "setup-key", "k", "", "Setup key obtained from the Management Service Dashboard (used to register peer)")
rootCmd.PersistentFlags().StringVar(&setupKeyPath, "setup-key-file", "", "The path to a setup key obtained from the Management Service Dashboard (used to register peer) This is ignored if the setup-key flag is provided.")
rootCmd.MarkFlagsMutuallyExclusive("setup-key", "setup-key-file")
rootCmd.PersistentFlags().StringVar(&preSharedKey, preSharedKeyFlag, "", "Sets Wireguard PreSharedKey property. If set, then only peers that have the same key can communicate.") rootCmd.PersistentFlags().StringVar(&preSharedKey, preSharedKeyFlag, "", "Sets Wireguard PreSharedKey property. If set, then only peers that have the same key can communicate.")
rootCmd.PersistentFlags().StringVarP(&hostName, "hostname", "n", "", "Sets a custom hostname for the device") rootCmd.PersistentFlags().StringVarP(&hostName, "hostname", "n", "", "Sets a custom hostname for the device")
rootCmd.PersistentFlags().BoolVarP(&anonymizeFlag, "anonymize", "A", false, "anonymize IP addresses and non-netbird.io domains in logs and status output") rootCmd.PersistentFlags().BoolVarP(&anonymizeFlag, "anonymize", "A", false, "anonymize IP addresses and non-netbird.io domains in logs and status output")
@@ -142,20 +134,19 @@ func init() {
rootCmd.AddCommand(loginCmd) rootCmd.AddCommand(loginCmd)
rootCmd.AddCommand(versionCmd) rootCmd.AddCommand(versionCmd)
rootCmd.AddCommand(sshCmd) rootCmd.AddCommand(sshCmd)
rootCmd.AddCommand(networksCMD) rootCmd.AddCommand(routesCmd)
rootCmd.AddCommand(debugCmd) rootCmd.AddCommand(debugCmd)
serviceCmd.AddCommand(runCmd, startCmd, stopCmd, restartCmd) // service control commands are subcommands of service serviceCmd.AddCommand(runCmd, startCmd, stopCmd, restartCmd) // service control commands are subcommands of service
serviceCmd.AddCommand(installCmd, uninstallCmd) // service installer commands are subcommands of service serviceCmd.AddCommand(installCmd, uninstallCmd) // service installer commands are subcommands of service
networksCMD.AddCommand(routesListCmd) routesCmd.AddCommand(routesListCmd)
networksCMD.AddCommand(routesSelectCmd, routesDeselectCmd) routesCmd.AddCommand(routesSelectCmd, routesDeselectCmd)
debugCmd.AddCommand(debugBundleCmd) debugCmd.AddCommand(debugBundleCmd)
debugCmd.AddCommand(logCmd) debugCmd.AddCommand(logCmd)
logCmd.AddCommand(logLevelCmd) logCmd.AddCommand(logLevelCmd)
debugCmd.AddCommand(forCmd) debugCmd.AddCommand(forCmd)
debugCmd.AddCommand(persistenceCmd)
upCmd.PersistentFlags().StringSliceVar(&natExternalIPs, externalIPMapFlag, nil, upCmd.PersistentFlags().StringSliceVar(&natExternalIPs, externalIPMapFlag, nil,
`Sets external IPs maps between local addresses and interfaces.`+ `Sets external IPs maps between local addresses and interfaces.`+
@@ -174,8 +165,6 @@ func init() {
upCmd.PersistentFlags().BoolVar(&rosenpassPermissive, rosenpassPermissiveFlag, false, "[Experimental] Enable Rosenpass in permissive mode to allow this peer to accept WireGuard connections without requiring Rosenpass functionality from peers that do not have Rosenpass enabled.") upCmd.PersistentFlags().BoolVar(&rosenpassPermissive, rosenpassPermissiveFlag, false, "[Experimental] Enable Rosenpass in permissive mode to allow this peer to accept WireGuard connections without requiring Rosenpass functionality from peers that do not have Rosenpass enabled.")
upCmd.PersistentFlags().BoolVar(&serverSSHAllowed, serverSSHAllowedFlag, false, "Allow SSH server on peer. If enabled, the SSH server will be permitted") upCmd.PersistentFlags().BoolVar(&serverSSHAllowed, serverSSHAllowedFlag, false, "Allow SSH server on peer. If enabled, the SSH server will be permitted")
upCmd.PersistentFlags().BoolVar(&autoConnectDisabled, disableAutoConnectFlag, false, "Disables auto-connect feature. If enabled, then the client won't connect automatically when the service starts.") upCmd.PersistentFlags().BoolVar(&autoConnectDisabled, disableAutoConnectFlag, false, "Disables auto-connect feature. If enabled, then the client won't connect automatically when the service starts.")
debugCmd.PersistentFlags().BoolVarP(&debugSystemInfoFlag, systemInfoFlag, "S", false, "Adds system information to the debug bundle")
} }
// SetupCloseHandler handles SIGTERM signal and exits with success // SetupCloseHandler handles SIGTERM signal and exits with success
@@ -257,21 +246,6 @@ var CLIBackOffSettings = &backoff.ExponentialBackOff{
Clock: backoff.SystemClock, Clock: backoff.SystemClock,
} }
func getSetupKey() (string, error) {
if setupKeyPath != "" && setupKey == "" {
return getSetupKeyFromFile(setupKeyPath)
}
return setupKey, nil
}
func getSetupKeyFromFile(setupKeyPath string) (string, error) {
data, err := os.ReadFile(setupKeyPath)
if err != nil {
return "", fmt.Errorf("failed to read setup key file: %v", err)
}
return strings.TrimSpace(string(data)), nil
}
func handleRebrand(cmd *cobra.Command) error { func handleRebrand(cmd *cobra.Command) error {
var err error var err error
if logFile == defaultLogFile { if logFile == defaultLogFile {

View File

@@ -4,10 +4,6 @@ import (
"fmt" "fmt"
"io" "io"
"testing" "testing"
"github.com/spf13/cobra"
"github.com/netbirdio/netbird/client/iface"
) )
func TestInitCommands(t *testing.T) { func TestInitCommands(t *testing.T) {
@@ -38,44 +34,3 @@ func TestInitCommands(t *testing.T) {
}) })
} }
} }
func TestSetFlagsFromEnvVars(t *testing.T) {
var cmd = &cobra.Command{
Use: "netbird",
Long: "test",
SilenceUsage: true,
Run: func(cmd *cobra.Command, args []string) {
SetFlagsFromEnvVars(cmd)
},
}
cmd.PersistentFlags().StringSliceVar(&natExternalIPs, externalIPMapFlag, nil,
`comma separated list of external IPs to map to the Wireguard interface`)
cmd.PersistentFlags().StringVar(&interfaceName, interfaceNameFlag, iface.WgInterfaceDefault, "Wireguard interface name")
cmd.PersistentFlags().BoolVar(&rosenpassEnabled, enableRosenpassFlag, false, "Enable Rosenpass feature Rosenpass.")
cmd.PersistentFlags().Uint16Var(&wireguardPort, wireguardPortFlag, iface.DefaultWgPort, "Wireguard interface listening port")
t.Setenv("NB_EXTERNAL_IP_MAP", "abc,dec")
t.Setenv("NB_INTERFACE_NAME", "test-name")
t.Setenv("NB_ENABLE_ROSENPASS", "true")
t.Setenv("NB_WIREGUARD_PORT", "10000")
err := cmd.Execute()
if err != nil {
t.Fatalf("expected no error while running netbird command, got %v", err)
}
if len(natExternalIPs) != 2 {
t.Errorf("expected 2 external ips, got %d", len(natExternalIPs))
}
if natExternalIPs[0] != "abc" || natExternalIPs[1] != "dec" {
t.Errorf("expected abc,dec, got %s,%s", natExternalIPs[0], natExternalIPs[1])
}
if interfaceName != "test-name" {
t.Errorf("expected test-name, got %s", interfaceName)
}
if !rosenpassEnabled {
t.Errorf("expected rosenpassEnabled to be true, got false")
}
if wireguardPort != 10000 {
t.Errorf("expected wireguardPort to be 10000, got %d", wireguardPort)
}
}

174
client/cmd/route.go Normal file
View File

@@ -0,0 +1,174 @@
package cmd
import (
"fmt"
"strings"
"github.com/spf13/cobra"
"google.golang.org/grpc/status"
"github.com/netbirdio/netbird/client/proto"
)
var appendFlag bool
var routesCmd = &cobra.Command{
Use: "routes",
Short: "Manage network routes",
Long: `Commands to list, select, or deselect network routes.`,
}
var routesListCmd = &cobra.Command{
Use: "list",
Aliases: []string{"ls"},
Short: "List routes",
Example: " netbird routes list",
Long: "List all available network routes.",
RunE: routesList,
}
var routesSelectCmd = &cobra.Command{
Use: "select route...|all",
Short: "Select routes",
Long: "Select a list of routes by identifiers or 'all' to clear all selections and to accept all (including new) routes.\nDefault mode is replace, use -a to append to already selected routes.",
Example: " netbird routes select all\n netbird routes select route1 route2\n netbird routes select -a route3",
Args: cobra.MinimumNArgs(1),
RunE: routesSelect,
}
var routesDeselectCmd = &cobra.Command{
Use: "deselect route...|all",
Short: "Deselect routes",
Long: "Deselect previously selected routes by identifiers or 'all' to disable accepting any routes.",
Example: " netbird routes deselect all\n netbird routes deselect route1 route2",
Args: cobra.MinimumNArgs(1),
RunE: routesDeselect,
}
func init() {
routesSelectCmd.PersistentFlags().BoolVarP(&appendFlag, "append", "a", false, "Append to current route selection instead of replacing")
}
func routesList(cmd *cobra.Command, _ []string) error {
conn, err := getClient(cmd)
if err != nil {
return err
}
defer conn.Close()
client := proto.NewDaemonServiceClient(conn)
resp, err := client.ListRoutes(cmd.Context(), &proto.ListRoutesRequest{})
if err != nil {
return fmt.Errorf("failed to list routes: %v", status.Convert(err).Message())
}
if len(resp.Routes) == 0 {
cmd.Println("No routes available.")
return nil
}
printRoutes(cmd, resp)
return nil
}
func printRoutes(cmd *cobra.Command, resp *proto.ListRoutesResponse) {
cmd.Println("Available Routes:")
for _, route := range resp.Routes {
printRoute(cmd, route)
}
}
func printRoute(cmd *cobra.Command, route *proto.Route) {
selectedStatus := getSelectedStatus(route)
domains := route.GetDomains()
if len(domains) > 0 {
printDomainRoute(cmd, route, domains, selectedStatus)
} else {
printNetworkRoute(cmd, route, selectedStatus)
}
}
func getSelectedStatus(route *proto.Route) string {
if route.GetSelected() {
return "Selected"
}
return "Not Selected"
}
func printDomainRoute(cmd *cobra.Command, route *proto.Route, domains []string, selectedStatus string) {
cmd.Printf("\n - ID: %s\n Domains: %s\n Status: %s\n", route.GetID(), strings.Join(domains, ", "), selectedStatus)
resolvedIPs := route.GetResolvedIPs()
if len(resolvedIPs) > 0 {
printResolvedIPs(cmd, domains, resolvedIPs)
} else {
cmd.Printf(" Resolved IPs: -\n")
}
}
func printNetworkRoute(cmd *cobra.Command, route *proto.Route, selectedStatus string) {
cmd.Printf("\n - ID: %s\n Network: %s\n Status: %s\n", route.GetID(), route.GetNetwork(), selectedStatus)
}
func printResolvedIPs(cmd *cobra.Command, domains []string, resolvedIPs map[string]*proto.IPList) {
cmd.Printf(" Resolved IPs:\n")
for _, domain := range domains {
if ipList, exists := resolvedIPs[domain]; exists {
cmd.Printf(" [%s]: %s\n", domain, strings.Join(ipList.GetIps(), ", "))
}
}
}
func routesSelect(cmd *cobra.Command, args []string) error {
conn, err := getClient(cmd)
if err != nil {
return err
}
defer conn.Close()
client := proto.NewDaemonServiceClient(conn)
req := &proto.SelectRoutesRequest{
RouteIDs: args,
}
if len(args) == 1 && args[0] == "all" {
req.All = true
} else if appendFlag {
req.Append = true
}
if _, err := client.SelectRoutes(cmd.Context(), req); err != nil {
return fmt.Errorf("failed to select routes: %v", status.Convert(err).Message())
}
cmd.Println("Routes selected successfully.")
return nil
}
func routesDeselect(cmd *cobra.Command, args []string) error {
conn, err := getClient(cmd)
if err != nil {
return err
}
defer conn.Close()
client := proto.NewDaemonServiceClient(conn)
req := &proto.SelectRoutesRequest{
RouteIDs: args,
}
if len(args) == 1 && args[0] == "all" {
req.All = true
}
if _, err := client.DeselectRoutes(cmd.Context(), req); err != nil {
return fmt.Errorf("failed to deselect routes: %v", status.Convert(err).Message())
}
cmd.Println("Routes deselected successfully.")
return nil
}

View File

@@ -2,23 +2,18 @@ package cmd
import ( import (
"context" "context"
"sync"
"github.com/kardianos/service" "github.com/kardianos/service"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/spf13/cobra" "github.com/spf13/cobra"
"google.golang.org/grpc" "google.golang.org/grpc"
"github.com/netbirdio/netbird/client/internal" "github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/server"
) )
type program struct { type program struct {
ctx context.Context ctx context.Context
cancel context.CancelFunc cancel context.CancelFunc
serv *grpc.Server serv *grpc.Server
serverInstance *server.Server
serverInstanceMu sync.Mutex
} }
func newProgram(ctx context.Context, cancel context.CancelFunc) *program { func newProgram(ctx context.Context, cancel context.CancelFunc) *program {

View File

@@ -61,10 +61,6 @@ func (p *program) Start(svc service.Service) error {
} }
proto.RegisterDaemonServiceServer(p.serv, serverInstance) proto.RegisterDaemonServiceServer(p.serv, serverInstance)
p.serverInstanceMu.Lock()
p.serverInstance = serverInstance
p.serverInstanceMu.Unlock()
log.Printf("started daemon server: %v", split[1]) log.Printf("started daemon server: %v", split[1])
if err := p.serv.Serve(listen); err != nil { if err := p.serv.Serve(listen); err != nil {
log.Errorf("failed to serve daemon requests: %v", err) log.Errorf("failed to serve daemon requests: %v", err)
@@ -74,16 +70,6 @@ func (p *program) Start(svc service.Service) error {
} }
func (p *program) Stop(srv service.Service) error { func (p *program) Stop(srv service.Service) error {
p.serverInstanceMu.Lock()
if p.serverInstance != nil {
in := new(proto.DownRequest)
_, err := p.serverInstance.Down(p.ctx, in)
if err != nil {
log.Errorf("failed to stop daemon: %v", err)
}
}
p.serverInstanceMu.Unlock()
p.cancel() p.cancel()
if p.serv != nil { if p.serv != nil {

View File

@@ -31,8 +31,6 @@ var installCmd = &cobra.Command{
configPath, configPath,
"--log-level", "--log-level",
logLevel, logLevel,
"--daemon-addr",
daemonAddr,
} }
if managementURL != "" { if managementURL != "" {

View File

@@ -1,181 +0,0 @@
package cmd
import (
"fmt"
log "github.com/sirupsen/logrus"
"github.com/spf13/cobra"
"google.golang.org/grpc/status"
"github.com/netbirdio/netbird/client/proto"
)
var (
allFlag bool
)
var stateCmd = &cobra.Command{
Use: "state",
Short: "Manage daemon state",
Long: "Provides commands for managing and inspecting the Netbird daemon state.",
}
var stateListCmd = &cobra.Command{
Use: "list",
Aliases: []string{"ls"},
Short: "List all stored states",
Long: "Lists all registered states with their status and basic information.",
Example: " netbird state list",
RunE: stateList,
}
var stateCleanCmd = &cobra.Command{
Use: "clean [state-name]",
Short: "Clean stored states",
Long: `Clean specific state or all states. The daemon must not be running.
This will perform cleanup operations and remove the state.`,
Example: ` netbird state clean dns_state
netbird state clean --all`,
RunE: stateClean,
PreRunE: func(cmd *cobra.Command, args []string) error {
// Check mutual exclusivity between --all flag and state-name argument
if allFlag && len(args) > 0 {
return fmt.Errorf("cannot specify both --all flag and state name")
}
if !allFlag && len(args) != 1 {
return fmt.Errorf("requires a state name argument or --all flag")
}
return nil
},
}
var stateDeleteCmd = &cobra.Command{
Use: "delete [state-name]",
Short: "Delete stored states",
Long: `Delete specific state or all states from storage. The daemon must not be running.
This will remove the state without performing any cleanup operations.`,
Example: ` netbird state delete dns_state
netbird state delete --all`,
RunE: stateDelete,
PreRunE: func(cmd *cobra.Command, args []string) error {
// Check mutual exclusivity between --all flag and state-name argument
if allFlag && len(args) > 0 {
return fmt.Errorf("cannot specify both --all flag and state name")
}
if !allFlag && len(args) != 1 {
return fmt.Errorf("requires a state name argument or --all flag")
}
return nil
},
}
func init() {
rootCmd.AddCommand(stateCmd)
stateCmd.AddCommand(stateListCmd, stateCleanCmd, stateDeleteCmd)
stateCleanCmd.Flags().BoolVarP(&allFlag, "all", "a", false, "Clean all states")
stateDeleteCmd.Flags().BoolVarP(&allFlag, "all", "a", false, "Delete all states")
}
func stateList(cmd *cobra.Command, _ []string) error {
conn, err := getClient(cmd)
if err != nil {
return err
}
defer func() {
if err := conn.Close(); err != nil {
log.Errorf(errCloseConnection, err)
}
}()
client := proto.NewDaemonServiceClient(conn)
resp, err := client.ListStates(cmd.Context(), &proto.ListStatesRequest{})
if err != nil {
return fmt.Errorf("failed to list states: %v", status.Convert(err).Message())
}
cmd.Printf("\nStored states:\n\n")
for _, state := range resp.States {
cmd.Printf("- %s\n", state.Name)
}
return nil
}
func stateClean(cmd *cobra.Command, args []string) error {
var stateName string
if !allFlag {
stateName = args[0]
}
conn, err := getClient(cmd)
if err != nil {
return err
}
defer func() {
if err := conn.Close(); err != nil {
log.Errorf(errCloseConnection, err)
}
}()
client := proto.NewDaemonServiceClient(conn)
resp, err := client.CleanState(cmd.Context(), &proto.CleanStateRequest{
StateName: stateName,
All: allFlag,
})
if err != nil {
return fmt.Errorf("failed to clean state: %v", status.Convert(err).Message())
}
if resp.CleanedStates == 0 {
cmd.Println("No states were cleaned")
return nil
}
if allFlag {
cmd.Printf("Successfully cleaned %d states\n", resp.CleanedStates)
} else {
cmd.Printf("Successfully cleaned state %q\n", stateName)
}
return nil
}
func stateDelete(cmd *cobra.Command, args []string) error {
var stateName string
if !allFlag {
stateName = args[0]
}
conn, err := getClient(cmd)
if err != nil {
return err
}
defer func() {
if err := conn.Close(); err != nil {
log.Errorf(errCloseConnection, err)
}
}()
client := proto.NewDaemonServiceClient(conn)
resp, err := client.DeleteState(cmd.Context(), &proto.DeleteStateRequest{
StateName: stateName,
All: allFlag,
})
if err != nil {
return fmt.Errorf("failed to delete state: %v", status.Convert(err).Message())
}
if resp.DeletedStates == 0 {
cmd.Println("No states were deleted")
return nil
}
if allFlag {
cmd.Printf("Successfully deleted %d states\n", resp.DeletedStates)
} else {
cmd.Printf("Successfully deleted state %q\n", stateName)
}
return nil
}

View File

@@ -31,16 +31,15 @@ type peerStateDetailOutput struct {
Status string `json:"status" yaml:"status"` Status string `json:"status" yaml:"status"`
LastStatusUpdate time.Time `json:"lastStatusUpdate" yaml:"lastStatusUpdate"` LastStatusUpdate time.Time `json:"lastStatusUpdate" yaml:"lastStatusUpdate"`
ConnType string `json:"connectionType" yaml:"connectionType"` ConnType string `json:"connectionType" yaml:"connectionType"`
Direct bool `json:"direct" yaml:"direct"`
IceCandidateType iceCandidateType `json:"iceCandidateType" yaml:"iceCandidateType"` IceCandidateType iceCandidateType `json:"iceCandidateType" yaml:"iceCandidateType"`
IceCandidateEndpoint iceCandidateType `json:"iceCandidateEndpoint" yaml:"iceCandidateEndpoint"` IceCandidateEndpoint iceCandidateType `json:"iceCandidateEndpoint" yaml:"iceCandidateEndpoint"`
RelayAddress string `json:"relayAddress" yaml:"relayAddress"`
LastWireguardHandshake time.Time `json:"lastWireguardHandshake" yaml:"lastWireguardHandshake"` LastWireguardHandshake time.Time `json:"lastWireguardHandshake" yaml:"lastWireguardHandshake"`
TransferReceived int64 `json:"transferReceived" yaml:"transferReceived"` TransferReceived int64 `json:"transferReceived" yaml:"transferReceived"`
TransferSent int64 `json:"transferSent" yaml:"transferSent"` TransferSent int64 `json:"transferSent" yaml:"transferSent"`
Latency time.Duration `json:"latency" yaml:"latency"` Latency time.Duration `json:"latency" yaml:"latency"`
RosenpassEnabled bool `json:"quantumResistance" yaml:"quantumResistance"` RosenpassEnabled bool `json:"quantumResistance" yaml:"quantumResistance"`
Routes []string `json:"routes" yaml:"routes"` Routes []string `json:"routes" yaml:"routes"`
Networks []string `json:"networks" yaml:"networks"`
} }
type peersStateOutput struct { type peersStateOutput struct {
@@ -99,7 +98,6 @@ type statusOutputOverview struct {
RosenpassEnabled bool `json:"quantumResistance" yaml:"quantumResistance"` RosenpassEnabled bool `json:"quantumResistance" yaml:"quantumResistance"`
RosenpassPermissive bool `json:"quantumResistancePermissive" yaml:"quantumResistancePermissive"` RosenpassPermissive bool `json:"quantumResistancePermissive" yaml:"quantumResistancePermissive"`
Routes []string `json:"routes" yaml:"routes"` Routes []string `json:"routes" yaml:"routes"`
Networks []string `json:"networks" yaml:"networks"`
NSServerGroups []nsServerGroupStateOutput `json:"dnsServers" yaml:"dnsServers"` NSServerGroups []nsServerGroupStateOutput `json:"dnsServers" yaml:"dnsServers"`
} }
@@ -284,8 +282,7 @@ func convertToStatusOutputOverview(resp *proto.StatusResponse) statusOutputOverv
FQDN: pbFullStatus.GetLocalPeerState().GetFqdn(), FQDN: pbFullStatus.GetLocalPeerState().GetFqdn(),
RosenpassEnabled: pbFullStatus.GetLocalPeerState().GetRosenpassEnabled(), RosenpassEnabled: pbFullStatus.GetLocalPeerState().GetRosenpassEnabled(),
RosenpassPermissive: pbFullStatus.GetLocalPeerState().GetRosenpassPermissive(), RosenpassPermissive: pbFullStatus.GetLocalPeerState().GetRosenpassPermissive(),
Routes: pbFullStatus.GetLocalPeerState().GetNetworks(), Routes: pbFullStatus.GetLocalPeerState().GetRoutes(),
Networks: pbFullStatus.GetLocalPeerState().GetNetworks(),
NSServerGroups: mapNSGroups(pbFullStatus.GetDnsServers()), NSServerGroups: mapNSGroups(pbFullStatus.GetDnsServers()),
} }
@@ -338,18 +335,16 @@ func mapNSGroups(servers []*proto.NSGroupState) []nsServerGroupStateOutput {
func mapPeers(peers []*proto.PeerState) peersStateOutput { func mapPeers(peers []*proto.PeerState) peersStateOutput {
var peersStateDetail []peerStateDetailOutput var peersStateDetail []peerStateDetailOutput
peersConnected := 0
for _, pbPeerState := range peers {
localICE := "" localICE := ""
remoteICE := "" remoteICE := ""
localICEEndpoint := "" localICEEndpoint := ""
remoteICEEndpoint := "" remoteICEEndpoint := ""
relayServerAddress := ""
connType := "" connType := ""
peersConnected := 0
lastHandshake := time.Time{} lastHandshake := time.Time{}
transferReceived := int64(0) transferReceived := int64(0)
transferSent := int64(0) transferSent := int64(0)
for _, pbPeerState := range peers {
isPeerConnected := pbPeerState.ConnStatus == peer.StatusConnected.String() isPeerConnected := pbPeerState.ConnStatus == peer.StatusConnected.String()
if skipDetailByFilters(pbPeerState, isPeerConnected) { if skipDetailByFilters(pbPeerState, isPeerConnected) {
continue continue
@@ -365,7 +360,6 @@ func mapPeers(peers []*proto.PeerState) peersStateOutput {
if pbPeerState.Relayed { if pbPeerState.Relayed {
connType = "Relayed" connType = "Relayed"
} }
relayServerAddress = pbPeerState.GetRelayAddress()
lastHandshake = pbPeerState.GetLastWireguardHandshake().AsTime().Local() lastHandshake = pbPeerState.GetLastWireguardHandshake().AsTime().Local()
transferReceived = pbPeerState.GetBytesRx() transferReceived = pbPeerState.GetBytesRx()
transferSent = pbPeerState.GetBytesTx() transferSent = pbPeerState.GetBytesTx()
@@ -378,6 +372,7 @@ func mapPeers(peers []*proto.PeerState) peersStateOutput {
Status: pbPeerState.GetConnStatus(), Status: pbPeerState.GetConnStatus(),
LastStatusUpdate: timeLocal, LastStatusUpdate: timeLocal,
ConnType: connType, ConnType: connType,
Direct: pbPeerState.GetDirect(),
IceCandidateType: iceCandidateType{ IceCandidateType: iceCandidateType{
Local: localICE, Local: localICE,
Remote: remoteICE, Remote: remoteICE,
@@ -386,15 +381,13 @@ func mapPeers(peers []*proto.PeerState) peersStateOutput {
Local: localICEEndpoint, Local: localICEEndpoint,
Remote: remoteICEEndpoint, Remote: remoteICEEndpoint,
}, },
RelayAddress: relayServerAddress,
FQDN: pbPeerState.GetFqdn(), FQDN: pbPeerState.GetFqdn(),
LastWireguardHandshake: lastHandshake, LastWireguardHandshake: lastHandshake,
TransferReceived: transferReceived, TransferReceived: transferReceived,
TransferSent: transferSent, TransferSent: transferSent,
Latency: pbPeerState.GetLatency().AsDuration(), Latency: pbPeerState.GetLatency().AsDuration(),
RosenpassEnabled: pbPeerState.GetRosenpassEnabled(), RosenpassEnabled: pbPeerState.GetRosenpassEnabled(),
Routes: pbPeerState.GetNetworks(), Routes: pbPeerState.GetRoutes(),
Networks: pbPeerState.GetNetworks(),
} }
peersStateDetail = append(peersStateDetail, peerState) peersStateDetail = append(peersStateDetail, peerState)
@@ -495,10 +488,10 @@ func parseGeneralSummary(overview statusOutputOverview, showURL bool, showRelays
relaysString = fmt.Sprintf("%d/%d Available", overview.Relays.Available, overview.Relays.Total) relaysString = fmt.Sprintf("%d/%d Available", overview.Relays.Available, overview.Relays.Total)
} }
networks := "-" routes := "-"
if len(overview.Networks) > 0 { if len(overview.Routes) > 0 {
sort.Strings(overview.Networks) sort.Strings(overview.Routes)
networks = strings.Join(overview.Networks, ", ") routes = strings.Join(overview.Routes, ", ")
} }
var dnsServersString string var dnsServersString string
@@ -560,7 +553,6 @@ func parseGeneralSummary(overview statusOutputOverview, showURL bool, showRelays
"Interface type: %s\n"+ "Interface type: %s\n"+
"Quantum resistance: %s\n"+ "Quantum resistance: %s\n"+
"Routes: %s\n"+ "Routes: %s\n"+
"Networks: %s\n"+
"Peers count: %s\n", "Peers count: %s\n",
fmt.Sprintf("%s/%s%s", goos, goarch, goarm), fmt.Sprintf("%s/%s%s", goos, goarch, goarm),
overview.DaemonVersion, overview.DaemonVersion,
@@ -573,8 +565,7 @@ func parseGeneralSummary(overview statusOutputOverview, showURL bool, showRelays
interfaceIP, interfaceIP,
interfaceTypeString, interfaceTypeString,
rosenpassEnabledStatus, rosenpassEnabledStatus,
networks, routes,
networks,
peersCountString, peersCountString,
) )
return summary return summary
@@ -637,10 +628,10 @@ func parsePeers(peers peersStateOutput, rosenpassEnabled, rosenpassPermissive bo
} }
} }
networks := "-" routes := "-"
if len(peerState.Networks) > 0 { if len(peerState.Routes) > 0 {
sort.Strings(peerState.Networks) sort.Strings(peerState.Routes)
networks = strings.Join(peerState.Networks, ", ") routes = strings.Join(peerState.Routes, ", ")
} }
peerString := fmt.Sprintf( peerString := fmt.Sprintf(
@@ -650,33 +641,31 @@ func parsePeers(peers peersStateOutput, rosenpassEnabled, rosenpassPermissive bo
" Status: %s\n"+ " Status: %s\n"+
" -- detail --\n"+ " -- detail --\n"+
" Connection type: %s\n"+ " Connection type: %s\n"+
" Direct: %t\n"+
" ICE candidate (Local/Remote): %s/%s\n"+ " ICE candidate (Local/Remote): %s/%s\n"+
" ICE candidate endpoints (Local/Remote): %s/%s\n"+ " ICE candidate endpoints (Local/Remote): %s/%s\n"+
" Relay server address: %s\n"+
" Last connection update: %s\n"+ " Last connection update: %s\n"+
" Last WireGuard handshake: %s\n"+ " Last WireGuard handshake: %s\n"+
" Transfer status (received/sent) %s/%s\n"+ " Transfer status (received/sent) %s/%s\n"+
" Quantum resistance: %s\n"+ " Quantum resistance: %s\n"+
" Routes: %s\n"+ " Routes: %s\n"+
" Networks: %s\n"+
" Latency: %s\n", " Latency: %s\n",
peerState.FQDN, peerState.FQDN,
peerState.IP, peerState.IP,
peerState.PubKey, peerState.PubKey,
peerState.Status, peerState.Status,
peerState.ConnType, peerState.ConnType,
peerState.Direct,
localICE, localICE,
remoteICE, remoteICE,
localICEEndpoint, localICEEndpoint,
remoteICEEndpoint, remoteICEEndpoint,
peerState.RelayAddress,
timeAgo(peerState.LastStatusUpdate), timeAgo(peerState.LastStatusUpdate),
timeAgo(peerState.LastWireguardHandshake), timeAgo(peerState.LastWireguardHandshake),
toIEC(peerState.TransferReceived), toIEC(peerState.TransferReceived),
toIEC(peerState.TransferSent), toIEC(peerState.TransferSent),
rosenpassEnabledStatus, rosenpassEnabledStatus,
networks, routes,
networks,
peerState.Latency.String(), peerState.Latency.String(),
) )
@@ -688,7 +677,7 @@ func parsePeers(peers peersStateOutput, rosenpassEnabled, rosenpassPermissive bo
func skipDetailByFilters(peerState *proto.PeerState, isConnected bool) bool { func skipDetailByFilters(peerState *proto.PeerState, isConnected bool) bool {
statusEval := false statusEval := false
ipEval := false ipEval := false
nameEval := true nameEval := false
if statusFilter != "" { if statusFilter != "" {
lowerStatusFilter := strings.ToLower(statusFilter) lowerStatusFilter := strings.ToLower(statusFilter)
@@ -708,13 +697,11 @@ func skipDetailByFilters(peerState *proto.PeerState, isConnected bool) bool {
if len(prefixNamesFilter) > 0 { if len(prefixNamesFilter) > 0 {
for prefixNameFilter := range prefixNamesFilterMap { for prefixNameFilter := range prefixNamesFilterMap {
if strings.HasPrefix(peerState.Fqdn, prefixNameFilter) { if !strings.HasPrefix(peerState.Fqdn, prefixNameFilter) {
nameEval = false nameEval = true
break break
} }
} }
} else {
nameEval = false
} }
return statusEval || ipEval || nameEval return statusEval || ipEval || nameEval
@@ -815,23 +802,12 @@ func anonymizePeerDetail(a *anonymize.Anonymizer, peer *peerStateDetailOutput) {
if remoteIP, port, err := net.SplitHostPort(peer.IceCandidateEndpoint.Remote); err == nil { if remoteIP, port, err := net.SplitHostPort(peer.IceCandidateEndpoint.Remote); err == nil {
peer.IceCandidateEndpoint.Remote = fmt.Sprintf("%s:%s", a.AnonymizeIPString(remoteIP), port) peer.IceCandidateEndpoint.Remote = fmt.Sprintf("%s:%s", a.AnonymizeIPString(remoteIP), port)
} }
peer.RelayAddress = a.AnonymizeURI(peer.RelayAddress)
for i, route := range peer.Networks {
peer.Networks[i] = a.AnonymizeIPString(route)
}
for i, route := range peer.Networks {
peer.Networks[i] = a.AnonymizeRoute(route)
}
for i, route := range peer.Routes { for i, route := range peer.Routes {
peer.Routes[i] = a.AnonymizeIPString(route) peer.Routes[i] = a.AnonymizeIPString(route)
} }
for i, route := range peer.Routes { for i, route := range peer.Routes {
peer.Routes[i] = a.AnonymizeRoute(route) peer.Routes[i] = anonymizeRoute(a, route)
} }
} }
@@ -866,13 +842,22 @@ func anonymizeOverview(a *anonymize.Anonymizer, overview *statusOutputOverview)
} }
} }
for i, route := range overview.Networks {
overview.Networks[i] = a.AnonymizeRoute(route)
}
for i, route := range overview.Routes { for i, route := range overview.Routes {
overview.Routes[i] = a.AnonymizeRoute(route) overview.Routes[i] = anonymizeRoute(a, route)
} }
overview.FQDN = a.AnonymizeDomain(overview.FQDN) overview.FQDN = a.AnonymizeDomain(overview.FQDN)
} }
func anonymizeRoute(a *anonymize.Anonymizer, route string) string {
prefix, err := netip.ParsePrefix(route)
if err == nil {
ip := a.AnonymizeIPString(prefix.Addr().String())
return fmt.Sprintf("%s/%d", ip, prefix.Bits())
}
domains := strings.Split(route, ", ")
for i, domain := range domains {
domains[i] = a.AnonymizeDomain(domain)
}
return strings.Join(domains, ", ")
}

View File

@@ -37,6 +37,7 @@ var resp = &proto.StatusResponse{
ConnStatus: "Connected", ConnStatus: "Connected",
ConnStatusUpdate: timestamppb.New(time.Date(2001, time.Month(1), 1, 1, 1, 1, 0, time.UTC)), ConnStatusUpdate: timestamppb.New(time.Date(2001, time.Month(1), 1, 1, 1, 1, 0, time.UTC)),
Relayed: false, Relayed: false,
Direct: true,
LocalIceCandidateType: "", LocalIceCandidateType: "",
RemoteIceCandidateType: "", RemoteIceCandidateType: "",
LocalIceCandidateEndpoint: "", LocalIceCandidateEndpoint: "",
@@ -44,7 +45,7 @@ var resp = &proto.StatusResponse{
LastWireguardHandshake: timestamppb.New(time.Date(2001, time.Month(1), 1, 1, 1, 2, 0, time.UTC)), LastWireguardHandshake: timestamppb.New(time.Date(2001, time.Month(1), 1, 1, 1, 2, 0, time.UTC)),
BytesRx: 200, BytesRx: 200,
BytesTx: 100, BytesTx: 100,
Networks: []string{ Routes: []string{
"10.1.0.0/24", "10.1.0.0/24",
}, },
Latency: durationpb.New(time.Duration(10000000)), Latency: durationpb.New(time.Duration(10000000)),
@@ -56,6 +57,7 @@ var resp = &proto.StatusResponse{
ConnStatus: "Connected", ConnStatus: "Connected",
ConnStatusUpdate: timestamppb.New(time.Date(2002, time.Month(2), 2, 2, 2, 2, 0, time.UTC)), ConnStatusUpdate: timestamppb.New(time.Date(2002, time.Month(2), 2, 2, 2, 2, 0, time.UTC)),
Relayed: true, Relayed: true,
Direct: false,
LocalIceCandidateType: "relay", LocalIceCandidateType: "relay",
RemoteIceCandidateType: "prflx", RemoteIceCandidateType: "prflx",
LocalIceCandidateEndpoint: "10.0.0.1:10001", LocalIceCandidateEndpoint: "10.0.0.1:10001",
@@ -93,7 +95,7 @@ var resp = &proto.StatusResponse{
PubKey: "Some-Pub-Key", PubKey: "Some-Pub-Key",
KernelInterface: true, KernelInterface: true,
Fqdn: "some-localhost.awesome-domain.com", Fqdn: "some-localhost.awesome-domain.com",
Networks: []string{ Routes: []string{
"10.10.0.0/24", "10.10.0.0/24",
}, },
}, },
@@ -135,6 +137,7 @@ var overview = statusOutputOverview{
Status: "Connected", Status: "Connected",
LastStatusUpdate: time.Date(2001, 1, 1, 1, 1, 1, 0, time.UTC), LastStatusUpdate: time.Date(2001, 1, 1, 1, 1, 1, 0, time.UTC),
ConnType: "P2P", ConnType: "P2P",
Direct: true,
IceCandidateType: iceCandidateType{ IceCandidateType: iceCandidateType{
Local: "", Local: "",
Remote: "", Remote: "",
@@ -149,9 +152,6 @@ var overview = statusOutputOverview{
Routes: []string{ Routes: []string{
"10.1.0.0/24", "10.1.0.0/24",
}, },
Networks: []string{
"10.1.0.0/24",
},
Latency: time.Duration(10000000), Latency: time.Duration(10000000),
}, },
{ {
@@ -161,6 +161,7 @@ var overview = statusOutputOverview{
Status: "Connected", Status: "Connected",
LastStatusUpdate: time.Date(2002, 2, 2, 2, 2, 2, 0, time.UTC), LastStatusUpdate: time.Date(2002, 2, 2, 2, 2, 2, 0, time.UTC),
ConnType: "Relayed", ConnType: "Relayed",
Direct: false,
IceCandidateType: iceCandidateType{ IceCandidateType: iceCandidateType{
Local: "relay", Local: "relay",
Remote: "prflx", Remote: "prflx",
@@ -233,9 +234,6 @@ var overview = statusOutputOverview{
Routes: []string{ Routes: []string{
"10.10.0.0/24", "10.10.0.0/24",
}, },
Networks: []string{
"10.10.0.0/24",
},
} }
func TestConversionFromFullStatusToOutputOverview(t *testing.T) { func TestConversionFromFullStatusToOutputOverview(t *testing.T) {
@@ -285,6 +283,7 @@ func TestParsingToJSON(t *testing.T) {
"status": "Connected", "status": "Connected",
"lastStatusUpdate": "2001-01-01T01:01:01Z", "lastStatusUpdate": "2001-01-01T01:01:01Z",
"connectionType": "P2P", "connectionType": "P2P",
"direct": true,
"iceCandidateType": { "iceCandidateType": {
"local": "", "local": "",
"remote": "" "remote": ""
@@ -293,7 +292,6 @@ func TestParsingToJSON(t *testing.T) {
"local": "", "local": "",
"remote": "" "remote": ""
}, },
"relayAddress": "",
"lastWireguardHandshake": "2001-01-01T01:01:02Z", "lastWireguardHandshake": "2001-01-01T01:01:02Z",
"transferReceived": 200, "transferReceived": 200,
"transferSent": 100, "transferSent": 100,
@@ -301,9 +299,6 @@ func TestParsingToJSON(t *testing.T) {
"quantumResistance": false, "quantumResistance": false,
"routes": [ "routes": [
"10.1.0.0/24" "10.1.0.0/24"
],
"networks": [
"10.1.0.0/24"
] ]
}, },
{ {
@@ -313,6 +308,7 @@ func TestParsingToJSON(t *testing.T) {
"status": "Connected", "status": "Connected",
"lastStatusUpdate": "2002-02-02T02:02:02Z", "lastStatusUpdate": "2002-02-02T02:02:02Z",
"connectionType": "Relayed", "connectionType": "Relayed",
"direct": false,
"iceCandidateType": { "iceCandidateType": {
"local": "relay", "local": "relay",
"remote": "prflx" "remote": "prflx"
@@ -321,14 +317,12 @@ func TestParsingToJSON(t *testing.T) {
"local": "10.0.0.1:10001", "local": "10.0.0.1:10001",
"remote": "10.0.10.1:10002" "remote": "10.0.10.1:10002"
}, },
"relayAddress": "",
"lastWireguardHandshake": "2002-02-02T02:02:03Z", "lastWireguardHandshake": "2002-02-02T02:02:03Z",
"transferReceived": 2000, "transferReceived": 2000,
"transferSent": 1000, "transferSent": 1000,
"latency": 10000000, "latency": 10000000,
"quantumResistance": false, "quantumResistance": false,
"routes": null, "routes": null
"networks": null
} }
] ]
}, },
@@ -369,9 +363,6 @@ func TestParsingToJSON(t *testing.T) {
"routes": [ "routes": [
"10.10.0.0/24" "10.10.0.0/24"
], ],
"networks": [
"10.10.0.0/24"
],
"dnsServers": [ "dnsServers": [
{ {
"servers": [ "servers": [
@@ -417,13 +408,13 @@ func TestParsingToYAML(t *testing.T) {
status: Connected status: Connected
lastStatusUpdate: 2001-01-01T01:01:01Z lastStatusUpdate: 2001-01-01T01:01:01Z
connectionType: P2P connectionType: P2P
direct: true
iceCandidateType: iceCandidateType:
local: "" local: ""
remote: "" remote: ""
iceCandidateEndpoint: iceCandidateEndpoint:
local: "" local: ""
remote: "" remote: ""
relayAddress: ""
lastWireguardHandshake: 2001-01-01T01:01:02Z lastWireguardHandshake: 2001-01-01T01:01:02Z
transferReceived: 200 transferReceived: 200
transferSent: 100 transferSent: 100
@@ -431,28 +422,25 @@ func TestParsingToYAML(t *testing.T) {
quantumResistance: false quantumResistance: false
routes: routes:
- 10.1.0.0/24 - 10.1.0.0/24
networks:
- 10.1.0.0/24
- fqdn: peer-2.awesome-domain.com - fqdn: peer-2.awesome-domain.com
netbirdIp: 192.168.178.102 netbirdIp: 192.168.178.102
publicKey: Pubkey2 publicKey: Pubkey2
status: Connected status: Connected
lastStatusUpdate: 2002-02-02T02:02:02Z lastStatusUpdate: 2002-02-02T02:02:02Z
connectionType: Relayed connectionType: Relayed
direct: false
iceCandidateType: iceCandidateType:
local: relay local: relay
remote: prflx remote: prflx
iceCandidateEndpoint: iceCandidateEndpoint:
local: 10.0.0.1:10001 local: 10.0.0.1:10001
remote: 10.0.10.1:10002 remote: 10.0.10.1:10002
relayAddress: ""
lastWireguardHandshake: 2002-02-02T02:02:03Z lastWireguardHandshake: 2002-02-02T02:02:03Z
transferReceived: 2000 transferReceived: 2000
transferSent: 1000 transferSent: 1000
latency: 10ms latency: 10ms
quantumResistance: false quantumResistance: false
routes: [] routes: []
networks: []
cliVersion: development cliVersion: development
daemonVersion: 0.14.1 daemonVersion: 0.14.1
management: management:
@@ -481,8 +469,6 @@ quantumResistance: false
quantumResistancePermissive: false quantumResistancePermissive: false
routes: routes:
- 10.10.0.0/24 - 10.10.0.0/24
networks:
- 10.10.0.0/24
dnsServers: dnsServers:
- servers: - servers:
- 8.8.8.8:53 - 8.8.8.8:53
@@ -519,15 +505,14 @@ func TestParsingToDetail(t *testing.T) {
Status: Connected Status: Connected
-- detail -- -- detail --
Connection type: P2P Connection type: P2P
Direct: true
ICE candidate (Local/Remote): -/- ICE candidate (Local/Remote): -/-
ICE candidate endpoints (Local/Remote): -/- ICE candidate endpoints (Local/Remote): -/-
Relay server address:
Last connection update: %s Last connection update: %s
Last WireGuard handshake: %s Last WireGuard handshake: %s
Transfer status (received/sent) 200 B/100 B Transfer status (received/sent) 200 B/100 B
Quantum resistance: false Quantum resistance: false
Routes: 10.1.0.0/24 Routes: 10.1.0.0/24
Networks: 10.1.0.0/24
Latency: 10ms Latency: 10ms
peer-2.awesome-domain.com: peer-2.awesome-domain.com:
@@ -536,15 +521,14 @@ func TestParsingToDetail(t *testing.T) {
Status: Connected Status: Connected
-- detail -- -- detail --
Connection type: Relayed Connection type: Relayed
Direct: false
ICE candidate (Local/Remote): relay/prflx ICE candidate (Local/Remote): relay/prflx
ICE candidate endpoints (Local/Remote): 10.0.0.1:10001/10.0.10.1:10002 ICE candidate endpoints (Local/Remote): 10.0.0.1:10001/10.0.10.1:10002
Relay server address:
Last connection update: %s Last connection update: %s
Last WireGuard handshake: %s Last WireGuard handshake: %s
Transfer status (received/sent) 2.0 KiB/1000 B Transfer status (received/sent) 2.0 KiB/1000 B
Quantum resistance: false Quantum resistance: false
Routes: - Routes: -
Networks: -
Latency: 10ms Latency: 10ms
OS: %s/%s OS: %s/%s
@@ -563,7 +547,6 @@ NetBird IP: 192.168.178.100/16
Interface type: Kernel Interface type: Kernel
Quantum resistance: false Quantum resistance: false
Routes: 10.10.0.0/24 Routes: 10.10.0.0/24
Networks: 10.10.0.0/24
Peers count: 2/2 Connected Peers count: 2/2 Connected
`, lastConnectionUpdate1, lastHandshake1, lastConnectionUpdate2, lastHandshake2, runtime.GOOS, runtime.GOARCH, overview.CliVersion) `, lastConnectionUpdate1, lastHandshake1, lastConnectionUpdate2, lastHandshake2, runtime.GOOS, runtime.GOARCH, overview.CliVersion)
@@ -585,7 +568,6 @@ NetBird IP: 192.168.178.100/16
Interface type: Kernel Interface type: Kernel
Quantum resistance: false Quantum resistance: false
Routes: 10.10.0.0/24 Routes: 10.10.0.0/24
Networks: 10.10.0.0/24
Peers count: 2/2 Connected Peers count: 2/2 Connected
` `

View File

@@ -3,6 +3,7 @@ package cmd
import ( import (
"context" "context"
"net" "net"
"path/filepath"
"testing" "testing"
"time" "time"
@@ -10,9 +11,6 @@ import (
"go.opentelemetry.io/otel" "go.opentelemetry.io/otel"
"github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/settings"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/telemetry"
"github.com/netbirdio/netbird/util" "github.com/netbirdio/netbird/util"
@@ -35,12 +33,18 @@ func startTestingServices(t *testing.T) string {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
testDir := t.TempDir()
config.Datadir = testDir
err = util.CopyFileContents("../testdata/store.json", filepath.Join(testDir, "store.json"))
if err != nil {
t.Fatal(err)
}
_, signalLis := startSignal(t) _, signalLis := startSignal(t)
signalAddr := signalLis.Addr().String() signalAddr := signalLis.Addr().String()
config.Signal.URI = signalAddr config.Signal.URI = signalAddr
_, mgmLis := startManagement(t, config, "../testdata/store.sql") _, mgmLis := startManagement(t, config)
mgmAddr := mgmLis.Addr().String() mgmAddr := mgmLis.Addr().String()
return mgmAddr return mgmAddr
} }
@@ -52,7 +56,7 @@ func startSignal(t *testing.T) (*grpc.Server, net.Listener) {
t.Fatal(err) t.Fatal(err)
} }
s := grpc.NewServer() s := grpc.NewServer()
srv, err := sig.NewServer(context.Background(), otel.Meter("")) srv, err := sig.NewServer(otel.Meter(""))
require.NoError(t, err) require.NoError(t, err)
sigProto.RegisterSignalExchangeServer(s, srv) sigProto.RegisterSignalExchangeServer(s, srv)
@@ -65,15 +69,14 @@ func startSignal(t *testing.T) (*grpc.Server, net.Listener) {
return s, lis return s, lis
} }
func startManagement(t *testing.T, config *mgmt.Config, testFile string) (*grpc.Server, net.Listener) { func startManagement(t *testing.T, config *mgmt.Config) (*grpc.Server, net.Listener) {
t.Helper() t.Helper()
lis, err := net.Listen("tcp", ":0") lis, err := net.Listen("tcp", ":0")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
s := grpc.NewServer() s := grpc.NewServer()
store, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), testFile, t.TempDir()) store, cleanUp, err := mgmt.NewTestStoreFromJson(context.Background(), config.Datadir)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -85,17 +88,12 @@ func startManagement(t *testing.T, config *mgmt.Config, testFile string) (*grpc.
return nil, nil return nil, nil
} }
iv, _ := integrations.NewIntegratedValidator(context.Background(), eventStore) iv, _ := integrations.NewIntegratedValidator(context.Background(), eventStore)
accountManager, err := mgmt.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, iv)
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
require.NoError(t, err)
accountManager, err := mgmt.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, iv, metrics)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
turnManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig)
secretsManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay) mgmtServer, err := mgmt.NewServer(context.Background(), config, accountManager, peersUpdateManager, turnManager, nil, nil)
mgmtServer, err := mgmt.NewServer(context.Background(), config, accountManager, settings.NewManager(store), peersUpdateManager, secretsManager, nil, nil)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }

View File

@@ -15,11 +15,11 @@ import (
gstatus "google.golang.org/grpc/status" gstatus "google.golang.org/grpc/status"
"google.golang.org/protobuf/types/known/durationpb" "google.golang.org/protobuf/types/known/durationpb"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/internal" "github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/proto" "github.com/netbirdio/netbird/client/proto"
"github.com/netbirdio/netbird/client/system" "github.com/netbirdio/netbird/client/system"
"github.com/netbirdio/netbird/iface"
"github.com/netbirdio/netbird/util" "github.com/netbirdio/netbird/util"
) )
@@ -147,11 +147,6 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command) error {
ic.DNSRouteInterval = &dnsRouteInterval ic.DNSRouteInterval = &dnsRouteInterval
} }
providedSetupKey, err := getSetupKey()
if err != nil {
return err
}
config, err := internal.UpdateOrCreateConfig(ic) config, err := internal.UpdateOrCreateConfig(ic)
if err != nil { if err != nil {
return fmt.Errorf("get config file: %v", err) return fmt.Errorf("get config file: %v", err)
@@ -159,7 +154,7 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command) error {
config, _ = internal.UpdateOldManagementURL(ctx, config, configPath) config, _ = internal.UpdateOldManagementURL(ctx, config, configPath)
err = foregroundLogin(ctx, cmd, config, providedSetupKey) err = foregroundLogin(ctx, cmd, config, setupKey)
if err != nil { if err != nil {
return fmt.Errorf("foreground login failed: %v", err) return fmt.Errorf("foreground login failed: %v", err)
} }
@@ -168,10 +163,7 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command) error {
ctx, cancel = context.WithCancel(ctx) ctx, cancel = context.WithCancel(ctx)
SetupCloseHandler(ctx, cancel) SetupCloseHandler(ctx, cancel)
r := peer.NewRecorder(config.ManagementURL.String()) connectClient := internal.NewConnectClient(ctx, config, peer.NewRecorder(config.ManagementURL.String()))
r.GetFullStatus()
connectClient := internal.NewConnectClient(ctx, config, r)
return connectClient.Run() return connectClient.Run()
} }
@@ -207,13 +199,8 @@ func runInDaemonMode(ctx context.Context, cmd *cobra.Command) error {
return nil return nil
} }
providedSetupKey, err := getSetupKey()
if err != nil {
return err
}
loginRequest := proto.LoginRequest{ loginRequest := proto.LoginRequest{
SetupKey: providedSetupKey, SetupKey: setupKey,
ManagementUrl: managementURL, ManagementUrl: managementURL,
AdminURL: adminURL, AdminURL: adminURL,
NatExternalIPs: natExternalIPs, NatExternalIPs: natExternalIPs,

View File

@@ -2,7 +2,6 @@ package cmd
import ( import (
"context" "context"
"os"
"testing" "testing"
"time" "time"
@@ -41,36 +40,6 @@ func TestUpDaemon(t *testing.T) {
return return
} }
// Test the setup-key-file flag.
tempFile, err := os.CreateTemp("", "setup-key")
if err != nil {
t.Errorf("could not create temp file, got error %v", err)
return
}
defer os.Remove(tempFile.Name())
if _, err := tempFile.Write([]byte("A2C8E62B-38F5-4553-B31E-DD66C696CEBB")); err != nil {
t.Errorf("could not write to temp file, got error %v", err)
return
}
if err := tempFile.Close(); err != nil {
t.Errorf("unable to close file, got error %v", err)
}
rootCmd.SetArgs([]string{
"login",
"--daemon-addr", "tcp://" + cliAddr,
"--setup-key-file", tempFile.Name(),
"--log-file", "",
})
if err := rootCmd.Execute(); err != nil {
t.Errorf("expected no error while running up command, got %v", err)
return
}
time.Sleep(time.Second * 3)
if status, err := state.Status(); err != nil && status != internal.StatusIdle {
t.Errorf("wrong status after login: %s, %v", internal.StatusIdle, err)
return
}
rootCmd.SetArgs([]string{ rootCmd.SetArgs([]string{
"up", "up",
"--daemon-addr", "tcp://" + cliAddr, "--daemon-addr", "tcp://" + cliAddr,

View File

@@ -8,8 +8,8 @@ import (
) )
func formatError(es []error) string { func formatError(es []error) string {
if len(es) == 1 { if len(es) == 0 {
return fmt.Sprintf("1 error occurred:\n\t* %s", es[0]) return fmt.Sprintf("0 error occurred:\n\t* %s", es[0])
} }
points := make([]string, len(es)) points := make([]string, len(es))

View File

@@ -3,6 +3,7 @@
package firewall package firewall
import ( import (
"context"
"fmt" "fmt"
"runtime" "runtime"
@@ -10,11 +11,10 @@ import (
firewall "github.com/netbirdio/netbird/client/firewall/manager" firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/firewall/uspfilter" "github.com/netbirdio/netbird/client/firewall/uspfilter"
"github.com/netbirdio/netbird/client/internal/statemanager"
) )
// NewFirewall creates a firewall manager instance // NewFirewall creates a firewall manager instance
func NewFirewall(iface IFaceMapper, _ *statemanager.Manager) (firewall.Manager, error) { func NewFirewall(context context.Context, iface IFaceMapper) (firewall.Manager, error) {
if !iface.IsUserspaceBind() { if !iface.IsUserspaceBind() {
return nil, fmt.Errorf("not implemented for this OS: %s", runtime.GOOS) return nil, fmt.Errorf("not implemented for this OS: %s", runtime.GOOS)
} }

View File

@@ -3,7 +3,7 @@
package firewall package firewall
import ( import (
"errors" "context"
"fmt" "fmt"
"os" "os"
@@ -15,7 +15,6 @@ import (
firewall "github.com/netbirdio/netbird/client/firewall/manager" firewall "github.com/netbirdio/netbird/client/firewall/manager"
nbnftables "github.com/netbirdio/netbird/client/firewall/nftables" nbnftables "github.com/netbirdio/netbird/client/firewall/nftables"
"github.com/netbirdio/netbird/client/firewall/uspfilter" "github.com/netbirdio/netbird/client/firewall/uspfilter"
"github.com/netbirdio/netbird/client/internal/statemanager"
) )
const ( const (
@@ -33,60 +32,42 @@ const SKIP_NFTABLES_ENV = "NB_SKIP_NFTABLES_CHECK"
// FWType is the type for the firewall type // FWType is the type for the firewall type
type FWType int type FWType int
func NewFirewall(iface IFaceMapper, stateManager *statemanager.Manager) (firewall.Manager, error) { func NewFirewall(context context.Context, iface IFaceMapper) (firewall.Manager, error) {
// on the linux system we try to user nftables or iptables // on the linux system we try to user nftables or iptables
// in any case, because we need to allow netbird interface traffic // in any case, because we need to allow netbird interface traffic
// so we use AllowNetbird traffic from these firewall managers // so we use AllowNetbird traffic from these firewall managers
// for the userspace packet filtering firewall // for the userspace packet filtering firewall
fm, err := createNativeFirewall(iface, stateManager) var fm firewall.Manager
var errFw error
if !iface.IsUserspaceBind() {
return fm, err
}
if err != nil {
log.Warnf("failed to create native firewall: %v. Proceeding with userspace", err)
}
return createUserspaceFirewall(iface, fm)
}
func createNativeFirewall(iface IFaceMapper, stateManager *statemanager.Manager) (firewall.Manager, error) {
fm, err := createFW(iface)
if err != nil {
return nil, fmt.Errorf("create firewall: %s", err)
}
if err = fm.Init(stateManager); err != nil {
return nil, fmt.Errorf("init firewall: %s", err)
}
return fm, nil
}
func createFW(iface IFaceMapper) (firewall.Manager, error) {
switch check() { switch check() {
case IPTABLES: case IPTABLES:
log.Info("creating an iptables firewall manager") log.Info("creating an iptables firewall manager")
return nbiptables.Create(iface) fm, errFw = nbiptables.Create(context, iface)
if errFw != nil {
log.Errorf("failed to create iptables manager: %s", errFw)
}
case NFTABLES: case NFTABLES:
log.Info("creating an nftables firewall manager") log.Info("creating an nftables firewall manager")
return nbnftables.Create(iface) fm, errFw = nbnftables.Create(context, iface)
default: if errFw != nil {
log.Info("no firewall manager found, trying to use userspace packet filtering firewall") log.Errorf("failed to create nftables manager: %s", errFw)
return nil, errors.New("no firewall manager found")
} }
default:
errFw = fmt.Errorf("no firewall manager found")
log.Info("no firewall manager found, trying to use userspace packet filtering firewall")
} }
func createUserspaceFirewall(iface IFaceMapper, fm firewall.Manager) (firewall.Manager, error) { if iface.IsUserspaceBind() {
var errUsp error var errUsp error
if fm != nil { if errFw == nil {
fm, errUsp = uspfilter.CreateWithNativeFirewall(iface, fm) fm, errUsp = uspfilter.CreateWithNativeFirewall(iface, fm)
} else { } else {
fm, errUsp = uspfilter.Create(iface) fm, errUsp = uspfilter.Create(iface)
} }
if errUsp != nil { if errUsp != nil {
return nil, fmt.Errorf("create userspace firewall: %s", errUsp) log.Debugf("failed to create userspace filtering firewall: %s", errUsp)
return nil, errUsp
} }
if err := fm.AllowNetbird(); err != nil { if err := fm.AllowNetbird(); err != nil {
@@ -95,6 +76,13 @@ func createUserspaceFirewall(iface IFaceMapper, fm firewall.Manager) (firewall.M
return fm, nil return fm, nil
} }
if errFw != nil {
return nil, errFw
}
return fm, nil
}
// check returns the firewall type based on common lib checks. It returns UNKNOWN if no firewall is found. // check returns the firewall type based on common lib checks. It returns UNKNOWN if no firewall is found.
func check() FWType { func check() FWType {
useIPTABLES := false useIPTABLES := false

View File

@@ -1,13 +1,11 @@
package firewall package firewall
import ( import "github.com/netbirdio/netbird/iface"
"github.com/netbirdio/netbird/client/iface/device"
)
// IFaceMapper defines subset methods of interface required for manager // IFaceMapper defines subset methods of interface required for manager
type IFaceMapper interface { type IFaceMapper interface {
Name() string Name() string
Address() device.WGAddress Address() iface.WGAddress
IsUserspaceBind() bool IsUserspaceBind() bool
SetFilter(device.PacketFilter) error SetFilter(iface.PacketFilter) error
} }

View File

@@ -11,8 +11,6 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
firewall "github.com/netbirdio/netbird/client/firewall/manager" firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/internal/statemanager"
nbnet "github.com/netbirdio/netbird/util/net"
) )
const ( const (
@@ -21,65 +19,49 @@ const (
// rules chains contains the effective ACL rules // rules chains contains the effective ACL rules
chainNameInputRules = "NETBIRD-ACL-INPUT" chainNameInputRules = "NETBIRD-ACL-INPUT"
chainNameOutputRules = "NETBIRD-ACL-OUTPUT" chainNameOutputRules = "NETBIRD-ACL-OUTPUT"
postRoutingMark = "0x000007e4"
) )
type aclEntries map[string][][]string
type entry struct {
spec []string
position int
}
type aclManager struct { type aclManager struct {
iptablesClient *iptables.IPTables iptablesClient *iptables.IPTables
wgIface iFaceMapper wgIface iFaceMapper
routingFwChainName string routeingFwChainName string
entries aclEntries entries map[string][][]string
optionalEntries map[string][]entry
ipsetStore *ipsetStore ipsetStore *ipsetStore
stateManager *statemanager.Manager
} }
func newAclManager(iptablesClient *iptables.IPTables, wgIface iFaceMapper, routingFwChainName string) (*aclManager, error) { func newAclManager(iptablesClient *iptables.IPTables, wgIface iFaceMapper, routeingFwChainName string) (*aclManager, error) {
m := &aclManager{ m := &aclManager{
iptablesClient: iptablesClient, iptablesClient: iptablesClient,
wgIface: wgIface, wgIface: wgIface,
routingFwChainName: routingFwChainName, routeingFwChainName: routeingFwChainName,
entries: make(map[string][][]string), entries: make(map[string][][]string),
optionalEntries: make(map[string][]entry),
ipsetStore: newIpsetStore(), ipsetStore: newIpsetStore(),
} }
if err := ipset.Init(); err != nil { err := ipset.Init()
return nil, fmt.Errorf("init ipset: %w", err) if err != nil {
return nil, fmt.Errorf("failed to init ipset: %w", err)
} }
m.seedInitialEntries()
err = m.cleanChains()
if err != nil {
return nil, err
}
err = m.createDefaultChains()
if err != nil {
return nil, err
}
return m, nil return m, nil
} }
func (m *aclManager) init(stateManager *statemanager.Manager) error { func (m *aclManager) AddFiltering(
m.stateManager = stateManager
m.seedInitialEntries()
m.seedInitialOptionalEntries()
if err := m.cleanChains(); err != nil {
return fmt.Errorf("clean chains: %w", err)
}
if err := m.createDefaultChains(); err != nil {
return fmt.Errorf("create default chains: %w", err)
}
m.updateState()
return nil
}
func (m *aclManager) AddPeerFiltering(
ip net.IP, ip net.IP,
protocol firewall.Protocol, protocol firewall.Protocol,
sPort *firewall.Port, sPort *firewall.Port,
@@ -145,7 +127,7 @@ func (m *aclManager) AddPeerFiltering(
return nil, fmt.Errorf("rule already exists") return nil, fmt.Errorf("rule already exists")
} }
if err := m.iptablesClient.Append("filter", chain, specs...); err != nil { if err := m.iptablesClient.Insert("filter", chain, 1, specs...); err != nil {
return nil, err return nil, err
} }
@@ -157,18 +139,28 @@ func (m *aclManager) AddPeerFiltering(
chain: chain, chain: chain,
} }
m.updateState() if !shouldAddToPrerouting(protocol, dPort, direction) {
return []firewall.Rule{rule}, nil return []firewall.Rule{rule}, nil
} }
// DeletePeerRule from the firewall by rule definition rulePrerouting, err := m.addPreroutingFilter(ipsetName, string(protocol), dPortVal, ip)
func (m *aclManager) DeletePeerRule(rule firewall.Rule) error { if err != nil {
return []firewall.Rule{rule}, err
}
return []firewall.Rule{rule, rulePrerouting}, nil
}
// DeleteRule from the firewall by rule definition
func (m *aclManager) DeleteRule(rule firewall.Rule) error {
r, ok := rule.(*Rule) r, ok := rule.(*Rule)
if !ok { if !ok {
return fmt.Errorf("invalid rule type") return fmt.Errorf("invalid rule type")
} }
if r.chain == "PREROUTING" {
goto DELETERULE
}
if ipsetList, ok := m.ipsetStore.ipset(r.ipsetName); ok { if ipsetList, ok := m.ipsetStore.ipset(r.ipsetName); ok {
// delete IP from ruleset IPs list and ipset // delete IP from ruleset IPs list and ipset
if _, ok := ipsetList.ips[r.ip]; ok { if _, ok := ipsetList.ips[r.ip]; ok {
@@ -193,23 +185,60 @@ func (m *aclManager) DeletePeerRule(rule firewall.Rule) error {
} }
} }
if err := m.iptablesClient.Delete(tableName, r.chain, r.specs...); err != nil { DELETERULE:
return fmt.Errorf("failed to delete rule: %s, %v: %w", r.chain, r.specs, err) var table string
if r.chain == "PREROUTING" {
table = "mangle"
} else {
table = "filter"
} }
err := m.iptablesClient.Delete(table, r.chain, r.specs...)
m.updateState() if err != nil {
log.Debugf("failed to delete rule, %s, %v: %s", r.chain, r.specs, err)
return nil }
return err
} }
func (m *aclManager) Reset() error { func (m *aclManager) Reset() error {
if err := m.cleanChains(); err != nil { return m.cleanChains()
return fmt.Errorf("clean chains: %w", err)
} }
m.updateState() func (m *aclManager) addPreroutingFilter(ipsetName string, protocol string, port string, ip net.IP) (*Rule, error) {
var src []string
if ipsetName != "" {
src = []string{"-m", "set", "--set", ipsetName, "src"}
} else {
src = []string{"-s", ip.String()}
}
specs := []string{
"-d", m.wgIface.Address().IP.String(),
"-p", protocol,
"--dport", port,
"-j", "MARK", "--set-mark", postRoutingMark,
}
return nil specs = append(src, specs...)
ok, err := m.iptablesClient.Exists("mangle", "PREROUTING", specs...)
if err != nil {
return nil, fmt.Errorf("failed to check rule: %w", err)
}
if ok {
return nil, fmt.Errorf("rule already exists")
}
if err := m.iptablesClient.Insert("mangle", "PREROUTING", 1, specs...); err != nil {
return nil, err
}
rule := &Rule{
ruleID: uuid.New().String(),
specs: specs,
ipsetName: ipsetName,
ip: ip.String(),
chain: "PREROUTING",
}
return rule, nil
} }
// todo write less destructive cleanup mechanism // todo write less destructive cleanup mechanism
@@ -264,7 +293,8 @@ func (m *aclManager) cleanChains() error {
ok, err = m.iptablesClient.ChainExists("mangle", "PREROUTING") ok, err = m.iptablesClient.ChainExists("mangle", "PREROUTING")
if err != nil { if err != nil {
return fmt.Errorf("list chains: %w", err) log.Debugf("failed to list chains: %s", err)
return err
} }
if ok { if ok {
for _, rule := range m.entries["PREROUTING"] { for _, rule := range m.entries["PREROUTING"] {
@@ -273,6 +303,11 @@ func (m *aclManager) cleanChains() error {
log.Errorf("failed to delete rule: %v, %s", rule, err) log.Errorf("failed to delete rule: %v, %s", rule, err)
} }
} }
err = m.iptablesClient.ClearChain("mangle", "PREROUTING")
if err != nil {
log.Debugf("failed to clear %s chain: %s", "PREROUTING", err)
return err
}
} }
for _, ipsetName := range m.ipsetStore.ipsetNames() { for _, ipsetName := range m.ipsetStore.ipsetNames() {
@@ -303,92 +338,64 @@ func (m *aclManager) createDefaultChains() error {
for chainName, rules := range m.entries { for chainName, rules := range m.entries {
for _, rule := range rules { for _, rule := range rules {
if err := m.iptablesClient.InsertUnique(tableName, chainName, 1, rule...); err != nil { if chainName == "FORWARD" {
// position 2 because we add it after router's, jump rule
if err := m.iptablesClient.InsertUnique(tableName, "FORWARD", 2, rule...); err != nil {
log.Debugf("failed to create input chain jump rule: %s", err)
return err
}
} else {
if err := m.iptablesClient.AppendUnique(tableName, chainName, rule...); err != nil {
log.Debugf("failed to create input chain jump rule: %s", err) log.Debugf("failed to create input chain jump rule: %s", err)
return err return err
} }
} }
} }
for chainName, entries := range m.optionalEntries {
for _, entry := range entries {
if err := m.iptablesClient.InsertUnique(tableName, chainName, entry.position, entry.spec...); err != nil {
log.Errorf("failed to insert optional entry %v: %v", entry.spec, err)
continue
} }
m.entries[chainName] = append(m.entries[chainName], entry.spec)
}
}
clear(m.optionalEntries)
return nil return nil
} }
// seedInitialEntries adds default rules to the entries map, rules are inserted on pos 1, hence the order is reversed.
// We want to make sure our traffic is not dropped by existing rules.
// The existing FORWARD rules/policies decide outbound traffic towards our interface.
// In case the FORWARD policy is set to "drop", we add an established/related rule to allow return traffic for the inbound rule.
// The OUTPUT chain gets an extra rule to allow traffic to any set up routes, the return traffic is handled by the INPUT related/established rule.
func (m *aclManager) seedInitialEntries() { func (m *aclManager) seedInitialEntries() {
established := getConntrackEstablished() m.appendToEntries("INPUT",
[]string{"-i", m.wgIface.Name(), "!", "-s", m.wgIface.Address().String(), "-d", m.wgIface.Address().String(), "-j", "ACCEPT"})
m.appendToEntries("INPUT",
[]string{"-i", m.wgIface.Name(), "-s", m.wgIface.Address().String(), "!", "-d", m.wgIface.Address().String(), "-j", "ACCEPT"})
m.appendToEntries("INPUT",
[]string{"-i", m.wgIface.Name(), "-s", m.wgIface.Address().String(), "-d", m.wgIface.Address().String(), "-j", chainNameInputRules})
m.appendToEntries("INPUT", []string{"-i", m.wgIface.Name(), "-j", "DROP"}) m.appendToEntries("INPUT", []string{"-i", m.wgIface.Name(), "-j", "DROP"})
m.appendToEntries("INPUT", []string{"-i", m.wgIface.Name(), "-j", chainNameInputRules})
m.appendToEntries("INPUT", append([]string{"-i", m.wgIface.Name()}, established...)) m.appendToEntries("OUTPUT",
[]string{"-o", m.wgIface.Name(), "!", "-s", m.wgIface.Address().String(), "-d", m.wgIface.Address().String(), "-j", "ACCEPT"})
m.appendToEntries("OUTPUT",
[]string{"-o", m.wgIface.Name(), "-s", m.wgIface.Address().String(), "!", "-d", m.wgIface.Address().String(), "-j", "ACCEPT"})
m.appendToEntries("OUTPUT",
[]string{"-o", m.wgIface.Name(), "-s", m.wgIface.Address().String(), "-d", m.wgIface.Address().String(), "-j", chainNameOutputRules})
m.appendToEntries("OUTPUT", []string{"-o", m.wgIface.Name(), "-j", "DROP"})
m.appendToEntries("FORWARD", []string{"-i", m.wgIface.Name(), "-j", "DROP"}) m.appendToEntries("FORWARD", []string{"-i", m.wgIface.Name(), "-j", "DROP"})
m.appendToEntries("FORWARD", []string{"-i", m.wgIface.Name(), "-j", m.routingFwChainName}) m.appendToEntries("FORWARD", []string{"-i", m.wgIface.Name(), "-j", chainNameInputRules})
m.appendToEntries("FORWARD", append([]string{"-o", m.wgIface.Name()}, established...)) m.appendToEntries("FORWARD",
} []string{"-o", m.wgIface.Name(), "-m", "mark", "--mark", postRoutingMark, "-j", "ACCEPT"})
m.appendToEntries("FORWARD",
[]string{"-i", m.wgIface.Name(), "-m", "mark", "--mark", postRoutingMark, "-j", "ACCEPT"})
m.appendToEntries("FORWARD", []string{"-o", m.wgIface.Name(), "-j", m.routeingFwChainName})
m.appendToEntries("FORWARD", []string{"-i", m.wgIface.Name(), "-j", m.routeingFwChainName})
func (m *aclManager) seedInitialOptionalEntries() { m.appendToEntries("PREROUTING",
m.optionalEntries["FORWARD"] = []entry{ []string{"-t", "mangle", "-i", m.wgIface.Name(), "!", "-s", m.wgIface.Address().String(), "-d", m.wgIface.Address().IP.String(), "-m", "mark", "--mark", postRoutingMark})
{
spec: []string{"-m", "mark", "--mark", fmt.Sprintf("%#x", nbnet.PreroutingFwmarkRedirected), "-j", chainNameInputRules},
position: 2,
},
}
m.optionalEntries["PREROUTING"] = []entry{
{
spec: []string{"-t", "mangle", "-i", m.wgIface.Name(), "-m", "addrtype", "--dst-type", "LOCAL", "-j", "MARK", "--set-mark", fmt.Sprintf("%#x", nbnet.PreroutingFwmarkRedirected)},
position: 1,
},
}
} }
func (m *aclManager) appendToEntries(chainName string, spec []string) { func (m *aclManager) appendToEntries(chainName string, spec []string) {
m.entries[chainName] = append(m.entries[chainName], spec) m.entries[chainName] = append(m.entries[chainName], spec)
} }
func (m *aclManager) updateState() {
if m.stateManager == nil {
return
}
var currentState *ShutdownState
if existing := m.stateManager.GetState(currentState); existing != nil {
if existingState, ok := existing.(*ShutdownState); ok {
currentState = existingState
}
}
if currentState == nil {
currentState = &ShutdownState{}
}
currentState.Lock()
defer currentState.Unlock()
currentState.ACLEntries = m.entries
currentState.ACLIPsetStore = m.ipsetStore
if err := m.stateManager.UpdateState(currentState); err != nil {
log.Errorf("failed to update state: %v", err)
}
}
// filterRuleSpecs returns the specs of a filtering rule // filterRuleSpecs returns the specs of a filtering rule
func filterRuleSpecs( func filterRuleSpecs(
ip net.IP, protocol string, sPort, dPort string, direction firewall.RuleDirection, action firewall.Action, ipsetName string, ip net.IP, protocol string, sPort, dPort string, direction firewall.RuleDirection, action firewall.Action, ipsetName string,
@@ -449,3 +456,18 @@ func transformIPsetName(ipsetName string, sPort, dPort string) string {
return ipsetName return ipsetName
} }
} }
func shouldAddToPrerouting(proto firewall.Protocol, dPort *firewall.Port, direction firewall.RuleDirection) bool {
if proto == "all" {
return false
}
if direction != firewall.RuleDirectionIN {
return false
}
if dPort == nil {
return false
}
return true
}

View File

@@ -4,17 +4,13 @@ import (
"context" "context"
"fmt" "fmt"
"net" "net"
"net/netip"
"sync" "sync"
"github.com/coreos/go-iptables/iptables" "github.com/coreos/go-iptables/iptables"
"github.com/hashicorp/go-multierror"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
nberrors "github.com/netbirdio/netbird/client/errors"
firewall "github.com/netbirdio/netbird/client/firewall/manager" firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/iface"
"github.com/netbirdio/netbird/client/internal/statemanager"
) )
// Manager of iptables firewall // Manager of iptables firewall
@@ -25,7 +21,7 @@ type Manager struct {
ipv4Client *iptables.IPTables ipv4Client *iptables.IPTables
aclMgr *aclManager aclMgr *aclManager
router *router router *routerManager
} }
// iFaceMapper defines subset methods of interface required for manager // iFaceMapper defines subset methods of interface required for manager
@@ -36,10 +32,10 @@ type iFaceMapper interface {
} }
// Create iptables firewall manager // Create iptables firewall manager
func Create(wgIface iFaceMapper) (*Manager, error) { func Create(context context.Context, wgIface iFaceMapper) (*Manager, error) {
iptablesClient, err := iptables.NewWithProtocol(iptables.ProtocolIPv4) iptablesClient, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
if err != nil { if err != nil {
return nil, fmt.Errorf("init iptables: %w", err) return nil, fmt.Errorf("iptables is not installed in the system or not supported")
} }
m := &Manager{ m := &Manager{
@@ -47,55 +43,24 @@ func Create(wgIface iFaceMapper) (*Manager, error) {
ipv4Client: iptablesClient, ipv4Client: iptablesClient,
} }
m.router, err = newRouter(iptablesClient, wgIface) m.router, err = newRouterManager(context, iptablesClient)
if err != nil { if err != nil {
return nil, fmt.Errorf("create router: %w", err) log.Debugf("failed to initialize route related chains: %s", err)
return nil, err
} }
m.aclMgr, err = newAclManager(iptablesClient, wgIface, m.router.RouteingFwChainName())
m.aclMgr, err = newAclManager(iptablesClient, wgIface, chainRTFWD)
if err != nil { if err != nil {
return nil, fmt.Errorf("create acl manager: %w", err) log.Debugf("failed to initialize ACL manager: %s", err)
return nil, err
} }
return m, nil return m, nil
} }
func (m *Manager) Init(stateManager *statemanager.Manager) error { // AddFiltering rule to the firewall
state := &ShutdownState{
InterfaceState: &InterfaceState{
NameStr: m.wgIface.Name(),
WGAddress: m.wgIface.Address(),
UserspaceBind: m.wgIface.IsUserspaceBind(),
},
}
stateManager.RegisterState(state)
if err := stateManager.UpdateState(state); err != nil {
log.Errorf("failed to update state: %v", err)
}
if err := m.router.init(stateManager); err != nil {
return fmt.Errorf("router init: %w", err)
}
if err := m.aclMgr.init(stateManager); err != nil {
// TODO: cleanup router
return fmt.Errorf("acl manager init: %w", err)
}
// persist early to ensure cleanup of chains
go func() {
if err := stateManager.PersistState(context.Background()); err != nil {
log.Errorf("failed to persist state: %v", err)
}
}()
return nil
}
// AddPeerFiltering adds a rule to the firewall
// //
// Comment will be ignored because some system this feature is not supported // Comment will be ignored because some system this feature is not supported
func (m *Manager) AddPeerFiltering( func (m *Manager) AddFiltering(
ip net.IP, ip net.IP,
protocol firewall.Protocol, protocol firewall.Protocol,
sPort *firewall.Port, sPort *firewall.Port,
@@ -108,86 +73,50 @@ func (m *Manager) AddPeerFiltering(
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
return m.aclMgr.AddPeerFiltering(ip, protocol, sPort, dPort, direction, action, ipsetName) return m.aclMgr.AddFiltering(ip, protocol, sPort, dPort, direction, action, ipsetName)
} }
func (m *Manager) AddRouteFiltering( // DeleteRule from the firewall by rule definition
sources []netip.Prefix, func (m *Manager) DeleteRule(rule firewall.Rule) error {
destination netip.Prefix,
proto firewall.Protocol,
sPort *firewall.Port,
dPort *firewall.Port,
action firewall.Action,
) (firewall.Rule, error) {
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
if !destination.Addr().Is4() { return m.aclMgr.DeleteRule(rule)
return nil, fmt.Errorf("unsupported IP version: %s", destination.Addr().String())
}
return m.router.AddRouteFiltering(sources, destination, proto, sPort, dPort, action)
}
// DeletePeerRule from the firewall by rule definition
func (m *Manager) DeletePeerRule(rule firewall.Rule) error {
m.mutex.Lock()
defer m.mutex.Unlock()
return m.aclMgr.DeletePeerRule(rule)
}
func (m *Manager) DeleteRouteRule(rule firewall.Rule) error {
m.mutex.Lock()
defer m.mutex.Unlock()
return m.router.DeleteRouteRule(rule)
} }
func (m *Manager) IsServerRouteSupported() bool { func (m *Manager) IsServerRouteSupported() bool {
return true return true
} }
func (m *Manager) AddNatRule(pair firewall.RouterPair) error { func (m *Manager) InsertRoutingRules(pair firewall.RouterPair) error {
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
return m.router.AddNatRule(pair) return m.router.InsertRoutingRules(pair)
} }
func (m *Manager) RemoveNatRule(pair firewall.RouterPair) error { func (m *Manager) RemoveRoutingRules(pair firewall.RouterPair) error {
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
return m.router.RemoveNatRule(pair) return m.router.RemoveRoutingRules(pair)
}
func (m *Manager) SetLegacyManagement(isLegacy bool) error {
return firewall.SetLegacyManagement(m.router, isLegacy)
} }
// Reset firewall to the default state // Reset firewall to the default state
func (m *Manager) Reset(stateManager *statemanager.Manager) error { func (m *Manager) Reset() error {
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
var merr *multierror.Error errAcl := m.aclMgr.Reset()
if errAcl != nil {
if err := m.aclMgr.Reset(); err != nil { log.Errorf("failed to clean up ACL rules from firewall: %s", errAcl)
merr = multierror.Append(merr, fmt.Errorf("reset acl manager: %w", err))
} }
if err := m.router.Reset(); err != nil { errMgr := m.router.Reset()
merr = multierror.Append(merr, fmt.Errorf("reset router: %w", err)) if errMgr != nil {
log.Errorf("failed to clean up router rules from firewall: %s", errMgr)
return errMgr
} }
return errAcl
// attempt to delete state only if all other operations succeeded
if merr == nil {
if err := stateManager.DeleteState(&ShutdownState{}); err != nil {
merr = multierror.Append(merr, fmt.Errorf("delete state: %w", err))
}
}
return nberrors.FormatErrorOrNil(merr)
} }
// AllowNetbird allows netbird interface traffic // AllowNetbird allows netbird interface traffic
@@ -196,7 +125,7 @@ func (m *Manager) AllowNetbird() error {
return nil return nil
} }
_, err := m.AddPeerFiltering( _, err := m.AddFiltering(
net.ParseIP("0.0.0.0"), net.ParseIP("0.0.0.0"),
"all", "all",
nil, nil,
@@ -207,19 +136,20 @@ func (m *Manager) AllowNetbird() error {
"", "",
) )
if err != nil { if err != nil {
return fmt.Errorf("allow netbird interface traffic: %w", err) return fmt.Errorf("failed to allow netbird interface traffic: %w", err)
} }
return nil _, err = m.AddFiltering(
net.ParseIP("0.0.0.0"),
"all",
nil,
nil,
firewall.RuleDirectionOUT,
firewall.ActionAccept,
"",
"",
)
return err
} }
// Flush doesn't need to be implemented for this manager // Flush doesn't need to be implemented for this manager
func (m *Manager) Flush() error { return nil } func (m *Manager) Flush() error { return nil }
// CollectStats returns connection tracking statistics
func (m *Manager) CollectStats() []*firewall.FlowStats {
return nil
}
func getConntrackEstablished() []string {
return []string{"-m", "conntrack", "--ctstate", "RELATED,ESTABLISHED", "-j", "ACCEPT"}
}

View File

@@ -1,6 +1,7 @@
package iptables package iptables
import ( import (
"context"
"fmt" "fmt"
"net" "net"
"testing" "testing"
@@ -10,24 +11,9 @@ import (
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
fw "github.com/netbirdio/netbird/client/firewall/manager" fw "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/iface"
) )
var ifaceMock = &iFaceMock{
NameFunc: func() string {
return "lo"
},
AddressFunc: func() iface.WGAddress {
return iface.WGAddress{
IP: net.ParseIP("10.20.0.1"),
Network: &net.IPNet{
IP: net.ParseIP("10.20.0.0"),
Mask: net.IPv4Mask(255, 255, 255, 0),
},
}
},
}
// iFaceMapper defines subset methods of interface required for manager // iFaceMapper defines subset methods of interface required for manager
type iFaceMock struct { type iFaceMock struct {
NameFunc func() string NameFunc func() string
@@ -54,15 +40,29 @@ func TestIptablesManager(t *testing.T) {
ipv4Client, err := iptables.NewWithProtocol(iptables.ProtocolIPv4) ipv4Client, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
require.NoError(t, err) require.NoError(t, err)
mock := &iFaceMock{
NameFunc: func() string {
return "lo"
},
AddressFunc: func() iface.WGAddress {
return iface.WGAddress{
IP: net.ParseIP("10.20.0.1"),
Network: &net.IPNet{
IP: net.ParseIP("10.20.0.0"),
Mask: net.IPv4Mask(255, 255, 255, 0),
},
}
},
}
// just check on the local interface // just check on the local interface
manager, err := Create(ifaceMock) manager, err := Create(context.Background(), mock)
require.NoError(t, err) require.NoError(t, err)
require.NoError(t, manager.Init(nil))
time.Sleep(time.Second) time.Sleep(time.Second)
defer func() { defer func() {
err := manager.Reset(nil) err := manager.Reset()
require.NoError(t, err, "clear the manager state") require.NoError(t, err, "clear the manager state")
time.Sleep(time.Second) time.Sleep(time.Second)
@@ -72,7 +72,7 @@ func TestIptablesManager(t *testing.T) {
t.Run("add first rule", func(t *testing.T) { t.Run("add first rule", func(t *testing.T) {
ip := net.ParseIP("10.20.0.2") ip := net.ParseIP("10.20.0.2")
port := &fw.Port{Values: []int{8080}} port := &fw.Port{Values: []int{8080}}
rule1, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept HTTP traffic") rule1, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept HTTP traffic")
require.NoError(t, err, "failed to add rule") require.NoError(t, err, "failed to add rule")
for _, r := range rule1 { for _, r := range rule1 {
@@ -87,7 +87,7 @@ func TestIptablesManager(t *testing.T) {
port := &fw.Port{ port := &fw.Port{
Values: []int{8043: 8046}, Values: []int{8043: 8046},
} }
rule2, err = manager.AddPeerFiltering( rule2, err = manager.AddFiltering(
ip, "tcp", port, nil, fw.RuleDirectionIN, fw.ActionAccept, "", "accept HTTPS traffic from ports range") ip, "tcp", port, nil, fw.RuleDirectionIN, fw.ActionAccept, "", "accept HTTPS traffic from ports range")
require.NoError(t, err, "failed to add rule") require.NoError(t, err, "failed to add rule")
@@ -99,7 +99,7 @@ func TestIptablesManager(t *testing.T) {
t.Run("delete first rule", func(t *testing.T) { t.Run("delete first rule", func(t *testing.T) {
for _, r := range rule1 { for _, r := range rule1 {
err := manager.DeletePeerRule(r) err := manager.DeleteRule(r)
require.NoError(t, err, "failed to delete rule") require.NoError(t, err, "failed to delete rule")
checkRuleSpecs(t, ipv4Client, chainNameOutputRules, false, r.(*Rule).specs...) checkRuleSpecs(t, ipv4Client, chainNameOutputRules, false, r.(*Rule).specs...)
@@ -108,7 +108,7 @@ func TestIptablesManager(t *testing.T) {
t.Run("delete second rule", func(t *testing.T) { t.Run("delete second rule", func(t *testing.T) {
for _, r := range rule2 { for _, r := range rule2 {
err := manager.DeletePeerRule(r) err := manager.DeleteRule(r)
require.NoError(t, err, "failed to delete rule") require.NoError(t, err, "failed to delete rule")
} }
@@ -119,10 +119,10 @@ func TestIptablesManager(t *testing.T) {
// add second rule // add second rule
ip := net.ParseIP("10.20.0.3") ip := net.ParseIP("10.20.0.3")
port := &fw.Port{Values: []int{5353}} port := &fw.Port{Values: []int{5353}}
_, err = manager.AddPeerFiltering(ip, "udp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept Fake DNS traffic") _, err = manager.AddFiltering(ip, "udp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept Fake DNS traffic")
require.NoError(t, err, "failed to add rule") require.NoError(t, err, "failed to add rule")
err = manager.Reset(nil) err = manager.Reset()
require.NoError(t, err, "failed to reset") require.NoError(t, err, "failed to reset")
ok, err := ipv4Client.ChainExists("filter", chainNameInputRules) ok, err := ipv4Client.ChainExists("filter", chainNameInputRules)
@@ -154,14 +154,13 @@ func TestIptablesManagerIPSet(t *testing.T) {
} }
// just check on the local interface // just check on the local interface
manager, err := Create(mock) manager, err := Create(context.Background(), mock)
require.NoError(t, err) require.NoError(t, err)
require.NoError(t, manager.Init(nil))
time.Sleep(time.Second) time.Sleep(time.Second)
defer func() { defer func() {
err := manager.Reset(nil) err := manager.Reset()
require.NoError(t, err, "clear the manager state") require.NoError(t, err, "clear the manager state")
time.Sleep(time.Second) time.Sleep(time.Second)
@@ -171,7 +170,7 @@ func TestIptablesManagerIPSet(t *testing.T) {
t.Run("add first rule with set", func(t *testing.T) { t.Run("add first rule with set", func(t *testing.T) {
ip := net.ParseIP("10.20.0.2") ip := net.ParseIP("10.20.0.2")
port := &fw.Port{Values: []int{8080}} port := &fw.Port{Values: []int{8080}}
rule1, err = manager.AddPeerFiltering( rule1, err = manager.AddFiltering(
ip, "tcp", nil, port, fw.RuleDirectionOUT, ip, "tcp", nil, port, fw.RuleDirectionOUT,
fw.ActionAccept, "default", "accept HTTP traffic", fw.ActionAccept, "default", "accept HTTP traffic",
) )
@@ -190,7 +189,7 @@ func TestIptablesManagerIPSet(t *testing.T) {
port := &fw.Port{ port := &fw.Port{
Values: []int{443}, Values: []int{443},
} }
rule2, err = manager.AddPeerFiltering( rule2, err = manager.AddFiltering(
ip, "tcp", port, nil, fw.RuleDirectionIN, fw.ActionAccept, ip, "tcp", port, nil, fw.RuleDirectionIN, fw.ActionAccept,
"default", "accept HTTPS traffic from ports range", "default", "accept HTTPS traffic from ports range",
) )
@@ -203,7 +202,7 @@ func TestIptablesManagerIPSet(t *testing.T) {
t.Run("delete first rule", func(t *testing.T) { t.Run("delete first rule", func(t *testing.T) {
for _, r := range rule1 { for _, r := range rule1 {
err := manager.DeletePeerRule(r) err := manager.DeleteRule(r)
require.NoError(t, err, "failed to delete rule") require.NoError(t, err, "failed to delete rule")
require.NotContains(t, manager.aclMgr.ipsetStore.ipsets, r.(*Rule).ruleID, "rule must be removed form the ruleset index") require.NotContains(t, manager.aclMgr.ipsetStore.ipsets, r.(*Rule).ruleID, "rule must be removed form the ruleset index")
@@ -212,7 +211,7 @@ func TestIptablesManagerIPSet(t *testing.T) {
t.Run("delete second rule", func(t *testing.T) { t.Run("delete second rule", func(t *testing.T) {
for _, r := range rule2 { for _, r := range rule2 {
err := manager.DeletePeerRule(r) err := manager.DeleteRule(r)
require.NoError(t, err, "failed to delete rule") require.NoError(t, err, "failed to delete rule")
require.Empty(t, manager.aclMgr.ipsetStore.ipsets, "rulesets index after removed second rule must be empty") require.Empty(t, manager.aclMgr.ipsetStore.ipsets, "rulesets index after removed second rule must be empty")
@@ -220,7 +219,7 @@ func TestIptablesManagerIPSet(t *testing.T) {
}) })
t.Run("reset check", func(t *testing.T) { t.Run("reset check", func(t *testing.T) {
err = manager.Reset(nil) err = manager.Reset()
require.NoError(t, err, "failed to reset") require.NoError(t, err, "failed to reset")
}) })
} }
@@ -252,13 +251,12 @@ func TestIptablesCreatePerformance(t *testing.T) {
for _, testMax := range []int{10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 200, 300, 400, 500, 600, 700, 800, 900, 1000} { for _, testMax := range []int{10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 200, 300, 400, 500, 600, 700, 800, 900, 1000} {
t.Run(fmt.Sprintf("Testing %d rules", testMax), func(t *testing.T) { t.Run(fmt.Sprintf("Testing %d rules", testMax), func(t *testing.T) {
// just check on the local interface // just check on the local interface
manager, err := Create(mock) manager, err := Create(context.Background(), mock)
require.NoError(t, err) require.NoError(t, err)
require.NoError(t, manager.Init(nil))
time.Sleep(time.Second) time.Sleep(time.Second)
defer func() { defer func() {
err := manager.Reset(nil) err := manager.Reset()
require.NoError(t, err, "clear the manager state") require.NoError(t, err, "clear the manager state")
time.Sleep(time.Second) time.Sleep(time.Second)
@@ -271,9 +269,9 @@ func TestIptablesCreatePerformance(t *testing.T) {
for i := 0; i < testMax; i++ { for i := 0; i < testMax; i++ {
port := &fw.Port{Values: []int{1000 + i}} port := &fw.Port{Values: []int{1000 + i}}
if i%2 == 0 { if i%2 == 0 {
_, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept HTTP traffic") _, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept HTTP traffic")
} else { } else {
_, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.RuleDirectionIN, fw.ActionAccept, "", "accept HTTP traffic") _, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionIN, fw.ActionAccept, "", "accept HTTP traffic")
} }
require.NoError(t, err, "failed to add rule") require.NoError(t, err, "failed to add rule")

View File

@@ -3,597 +3,370 @@
package iptables package iptables
import ( import (
"context"
"fmt" "fmt"
"net/netip"
"strconv"
"strings" "strings"
"github.com/coreos/go-iptables/iptables" "github.com/coreos/go-iptables/iptables"
"github.com/hashicorp/go-multierror"
"github.com/nadoo/ipset"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
nberrors "github.com/netbirdio/netbird/client/errors"
firewall "github.com/netbirdio/netbird/client/firewall/manager" firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/internal/acl/id" )
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
"github.com/netbirdio/netbird/client/internal/statemanager" const (
nbnet "github.com/netbirdio/netbird/util/net" Ipv4Forwarding = "netbird-rt-forwarding"
ipv4Nat = "netbird-rt-nat"
) )
// constants needed to manage and create iptable rules // constants needed to manage and create iptable rules
const ( const (
tableFilter = "filter" tableFilter = "filter"
tableNat = "nat" tableNat = "nat"
tableMangle = "mangle" chainFORWARD = "FORWARD"
chainPOSTROUTING = "POSTROUTING" chainPOSTROUTING = "POSTROUTING"
chainPREROUTING = "PREROUTING"
chainRTNAT = "NETBIRD-RT-NAT" chainRTNAT = "NETBIRD-RT-NAT"
chainRTFWD = "NETBIRD-RT-FWD" chainRTFWD = "NETBIRD-RT-FWD"
chainRTPRE = "NETBIRD-RT-PRE"
routingFinalForwardJump = "ACCEPT" routingFinalForwardJump = "ACCEPT"
routingFinalNatJump = "MASQUERADE" routingFinalNatJump = "MASQUERADE"
jumpPre = "jump-pre"
jumpNat = "jump-nat"
matchSet = "--match-set"
) )
type routeFilteringRuleParams struct { type routerManager struct {
Sources []netip.Prefix ctx context.Context
Destination netip.Prefix stop context.CancelFunc
Proto firewall.Protocol
SPort *firewall.Port
DPort *firewall.Port
Direction firewall.RuleDirection
Action firewall.Action
SetName string
}
type routeRules map[string][]string
type ipsetCounter = refcounter.Counter[string, []netip.Prefix, struct{}]
type router struct {
iptablesClient *iptables.IPTables iptablesClient *iptables.IPTables
rules routeRules rules map[string][]string
ipsetCounter *ipsetCounter
wgIface iFaceMapper
legacyManagement bool
stateManager *statemanager.Manager
} }
func newRouter(iptablesClient *iptables.IPTables, wgIface iFaceMapper) (*router, error) { func newRouterManager(parentCtx context.Context, iptablesClient *iptables.IPTables) (*routerManager, error) {
r := &router{ ctx, cancel := context.WithCancel(parentCtx)
m := &routerManager{
ctx: ctx,
stop: cancel,
iptablesClient: iptablesClient, iptablesClient: iptablesClient,
rules: make(map[string][]string), rules: make(map[string][]string),
wgIface: wgIface,
} }
r.ipsetCounter = refcounter.New( err := m.cleanUpDefaultForwardRules()
func(name string, sources []netip.Prefix) (struct{}, error) { if err != nil {
return struct{}{}, r.createIpSet(name, sources) log.Errorf("failed to cleanup routing rules: %s", err)
}, return nil, err
func(name string, _ struct{}) error { }
return r.deleteIpSet(name) err = m.createContainers()
}, if err != nil {
) log.Errorf("failed to create containers for route: %s", err)
}
if err := ipset.Init(); err != nil { return m, err
return nil, fmt.Errorf("init ipset: %w", err)
} }
return r, nil // InsertRoutingRules inserts an iptables rule pair to the forwarding chain and if enabled, to the nat chain
func (i *routerManager) InsertRoutingRules(pair firewall.RouterPair) error {
err := i.insertRoutingRule(firewall.ForwardingFormat, tableFilter, chainRTFWD, routingFinalForwardJump, pair)
if err != nil {
return err
} }
func (r *router) init(stateManager *statemanager.Manager) error { err = i.insertRoutingRule(firewall.InForwardingFormat, tableFilter, chainRTFWD, routingFinalForwardJump, firewall.GetInPair(pair))
r.stateManager = stateManager if err != nil {
return err
if err := r.cleanUpDefaultForwardRules(); err != nil {
log.Errorf("failed to clean up rules from FORWARD chain: %s", err)
}
if err := r.createContainers(); err != nil {
return fmt.Errorf("create containers: %w", err)
}
r.updateState()
return nil
}
func (r *router) AddRouteFiltering(
sources []netip.Prefix,
destination netip.Prefix,
proto firewall.Protocol,
sPort *firewall.Port,
dPort *firewall.Port,
action firewall.Action,
) (firewall.Rule, error) {
ruleKey := id.GenerateRouteRuleKey(sources, destination, proto, sPort, dPort, action)
if _, ok := r.rules[string(ruleKey)]; ok {
return ruleKey, nil
}
var setName string
if len(sources) > 1 {
setName = firewall.GenerateSetName(sources)
if _, err := r.ipsetCounter.Increment(setName, sources); err != nil {
return nil, fmt.Errorf("create or get ipset: %w", err)
}
}
params := routeFilteringRuleParams{
Sources: sources,
Destination: destination,
Proto: proto,
SPort: sPort,
DPort: dPort,
Action: action,
SetName: setName,
}
rule := genRouteFilteringRuleSpec(params)
if err := r.iptablesClient.Append(tableFilter, chainRTFWD, rule...); err != nil {
return nil, fmt.Errorf("add route rule: %v", err)
}
r.rules[string(ruleKey)] = rule
r.updateState()
return ruleKey, nil
}
func (r *router) DeleteRouteRule(rule firewall.Rule) error {
ruleKey := rule.GetRuleID()
if rule, exists := r.rules[ruleKey]; exists {
setName := r.findSetNameInRule(rule)
if err := r.iptablesClient.Delete(tableFilter, chainRTFWD, rule...); err != nil {
return fmt.Errorf("delete route rule: %v", err)
}
delete(r.rules, ruleKey)
if setName != "" {
if _, err := r.ipsetCounter.Decrement(setName); err != nil {
return fmt.Errorf("failed to remove ipset: %w", err)
}
}
} else {
log.Debugf("route rule %s not found", ruleKey)
}
r.updateState()
return nil
}
func (r *router) findSetNameInRule(rule []string) string {
for i, arg := range rule {
if arg == "-m" && i+3 < len(rule) && rule[i+1] == "set" && rule[i+2] == matchSet {
return rule[i+3]
}
}
return ""
}
func (r *router) createIpSet(setName string, sources []netip.Prefix) error {
if err := ipset.Create(setName, ipset.OptTimeout(0)); err != nil {
return fmt.Errorf("create set %s: %w", setName, err)
}
for _, prefix := range sources {
if err := ipset.AddPrefix(setName, prefix); err != nil {
return fmt.Errorf("add element to set %s: %w", setName, err)
}
}
return nil
}
func (r *router) deleteIpSet(setName string) error {
if err := ipset.Destroy(setName); err != nil {
return fmt.Errorf("destroy set %s: %w", setName, err)
}
return nil
}
// AddNatRule inserts an iptables rule pair into the nat chain
func (r *router) AddNatRule(pair firewall.RouterPair) error {
if r.legacyManagement {
log.Warnf("This peer is connected to a NetBird Management service with an older version. Allowing all traffic for %s", pair.Destination)
if err := r.addLegacyRouteRule(pair); err != nil {
return fmt.Errorf("add legacy routing rule: %w", err)
}
} }
if !pair.Masquerade { if !pair.Masquerade {
return nil return nil
} }
if err := r.addNatRule(pair); err != nil { err = i.addNATRule(firewall.NatFormat, tableNat, chainRTNAT, routingFinalNatJump, pair)
return fmt.Errorf("add nat rule: %w", err) if err != nil {
}
if err := r.addNatRule(firewall.GetInversePair(pair)); err != nil {
return fmt.Errorf("add inverse nat rule: %w", err)
}
r.updateState()
return nil
}
// RemoveNatRule removes an iptables rule pair from forwarding and nat chains
func (r *router) RemoveNatRule(pair firewall.RouterPair) error {
if err := r.removeNatRule(pair); err != nil {
return fmt.Errorf("remove nat rule: %w", err)
}
if err := r.removeNatRule(firewall.GetInversePair(pair)); err != nil {
return fmt.Errorf("remove inverse nat rule: %w", err)
}
if err := r.removeLegacyRouteRule(pair); err != nil {
return fmt.Errorf("remove legacy routing rule: %w", err)
}
r.updateState()
return nil
}
// addLegacyRouteRule adds a legacy routing rule for mgmt servers pre route acls
func (r *router) addLegacyRouteRule(pair firewall.RouterPair) error {
ruleKey := firewall.GenKey(firewall.ForwardingFormat, pair)
if err := r.removeLegacyRouteRule(pair); err != nil {
return err return err
} }
rule := []string{"-s", pair.Source.String(), "-d", pair.Destination.String(), "-j", routingFinalForwardJump} err = i.addNATRule(firewall.InNatFormat, tableNat, chainRTNAT, routingFinalNatJump, firewall.GetInPair(pair))
if err := r.iptablesClient.Append(tableFilter, chainRTFWD, rule...); err != nil { if err != nil {
return fmt.Errorf("add legacy forwarding rule %s -> %s: %v", pair.Source, pair.Destination, err) return err
}
r.rules[ruleKey] = rule
return nil
}
func (r *router) removeLegacyRouteRule(pair firewall.RouterPair) error {
ruleKey := firewall.GenKey(firewall.ForwardingFormat, pair)
if rule, exists := r.rules[ruleKey]; exists {
if err := r.iptablesClient.DeleteIfExists(tableFilter, chainRTFWD, rule...); err != nil {
return fmt.Errorf("remove legacy forwarding rule %s -> %s: %v", pair.Source, pair.Destination, err)
}
delete(r.rules, ruleKey)
} else {
log.Debugf("legacy forwarding rule %s not found", ruleKey)
} }
return nil return nil
} }
// GetLegacyManagement returns the current legacy management mode // insertRoutingRule inserts an iptables rule
func (r *router) GetLegacyManagement() bool { func (i *routerManager) insertRoutingRule(keyFormat, table, chain, jump string, pair firewall.RouterPair) error {
return r.legacyManagement var err error
ruleKey := firewall.GenKey(keyFormat, pair.ID)
rule := genRuleSpec(jump, pair.Source, pair.Destination)
existingRule, found := i.rules[ruleKey]
if found {
err = i.iptablesClient.DeleteIfExists(table, chain, existingRule...)
if err != nil {
return fmt.Errorf("error while removing existing %s rule for %s: %v", getIptablesRuleType(table), pair.Destination, err)
}
delete(i.rules, ruleKey)
} }
// SetLegacyManagement sets the route manager to use legacy management mode err = i.iptablesClient.Insert(table, chain, 1, rule...)
func (r *router) SetLegacyManagement(isLegacy bool) { if err != nil {
r.legacyManagement = isLegacy return fmt.Errorf("error while adding new %s rule for %s: %v", getIptablesRuleType(table), pair.Destination, err)
} }
// RemoveAllLegacyRouteRules removes all legacy routing rules for mgmt servers pre route acls i.rules[ruleKey] = rule
func (r *router) RemoveAllLegacyRouteRules() error {
var merr *multierror.Error return nil
for k, rule := range r.rules {
if !strings.HasPrefix(k, firewall.ForwardingFormatPrefix) {
continue
}
if err := r.iptablesClient.DeleteIfExists(tableFilter, chainRTFWD, rule...); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove legacy forwarding rule: %v", err))
} else {
delete(r.rules, k)
}
} }
r.updateState() // RemoveRoutingRules removes an iptables rule pair from forwarding and nat chains
func (i *routerManager) RemoveRoutingRules(pair firewall.RouterPair) error {
return nberrors.FormatErrorOrNil(merr) err := i.removeRoutingRule(firewall.ForwardingFormat, tableFilter, chainRTFWD, pair)
if err != nil {
return err
} }
func (r *router) Reset() error { err = i.removeRoutingRule(firewall.InForwardingFormat, tableFilter, chainRTFWD, firewall.GetInPair(pair))
var merr *multierror.Error if err != nil {
if err := r.cleanUpDefaultForwardRules(); err != nil { return err
merr = multierror.Append(merr, err)
}
r.rules = make(map[string][]string)
if err := r.ipsetCounter.Flush(); err != nil {
merr = multierror.Append(merr, err)
} }
r.updateState() if !pair.Masquerade {
return nil
return nberrors.FormatErrorOrNil(merr)
} }
func (r *router) cleanUpDefaultForwardRules() error { err = i.removeRoutingRule(firewall.NatFormat, tableNat, chainRTNAT, pair)
if err := r.cleanJumpRules(); err != nil { if err != nil {
return fmt.Errorf("clean jump rules: %w", err) return err
}
err = i.removeRoutingRule(firewall.InNatFormat, tableNat, chainRTNAT, firewall.GetInPair(pair))
if err != nil {
return err
}
return nil
}
func (i *routerManager) removeRoutingRule(keyFormat, table, chain string, pair firewall.RouterPair) error {
var err error
ruleKey := firewall.GenKey(keyFormat, pair.ID)
existingRule, found := i.rules[ruleKey]
if found {
err = i.iptablesClient.DeleteIfExists(table, chain, existingRule...)
if err != nil {
return fmt.Errorf("error while removing existing %s rule for %s: %v", getIptablesRuleType(table), pair.Destination, err)
}
}
delete(i.rules, ruleKey)
return nil
}
func (i *routerManager) RouteingFwChainName() string {
return chainRTFWD
}
func (i *routerManager) Reset() error {
err := i.cleanUpDefaultForwardRules()
if err != nil {
return err
}
i.rules = make(map[string][]string)
return nil
}
func (i *routerManager) cleanUpDefaultForwardRules() error {
err := i.cleanJumpRules()
if err != nil {
return err
} }
log.Debug("flushing routing related tables") log.Debug("flushing routing related tables")
for _, chainInfo := range []struct { ok, err := i.iptablesClient.ChainExists(tableFilter, chainRTFWD)
chain string
table string
}{
{chainRTFWD, tableFilter},
{chainRTNAT, tableNat},
{chainRTPRE, tableMangle},
} {
ok, err := r.iptablesClient.ChainExists(chainInfo.table, chainInfo.chain)
if err != nil { if err != nil {
return fmt.Errorf("check chain %s in table %s: %w", chainInfo.chain, chainInfo.table, err) log.Errorf("failed check chain %s,error: %v", chainRTFWD, err)
return err
} else if ok { } else if ok {
if err = r.iptablesClient.ClearAndDeleteChain(chainInfo.table, chainInfo.chain); err != nil { err = i.iptablesClient.ClearAndDeleteChain(tableFilter, chainRTFWD)
return fmt.Errorf("clear and delete chain %s in table %s: %w", chainInfo.chain, chainInfo.table, err)
}
}
}
return nil
}
func (r *router) createContainers() error {
for _, chainInfo := range []struct {
chain string
table string
}{
{chainRTFWD, tableFilter},
{chainRTPRE, tableMangle},
{chainRTNAT, tableNat},
} {
if err := r.createAndSetupChain(chainInfo.chain); err != nil {
return fmt.Errorf("create chain %s in table %s: %w", chainInfo.chain, chainInfo.table, err)
}
}
if err := r.insertEstablishedRule(chainRTFWD); err != nil {
return fmt.Errorf("insert established rule: %w", err)
}
if err := r.addPostroutingRules(); err != nil {
return fmt.Errorf("add static nat rules: %w", err)
}
if err := r.addJumpRules(); err != nil {
return fmt.Errorf("add jump rules: %w", err)
}
return nil
}
func (r *router) addPostroutingRules() error {
// First rule for outbound masquerade
rule1 := []string{
"-m", "mark", "--mark", fmt.Sprintf("%#x", nbnet.PreroutingFwmarkMasquerade),
"!", "-o", "lo",
"-j", routingFinalNatJump,
}
if err := r.iptablesClient.Append(tableNat, chainRTNAT, rule1...); err != nil {
return fmt.Errorf("add outbound masquerade rule: %v", err)
}
r.rules["static-nat-outbound"] = rule1
// Second rule for return traffic masquerade
rule2 := []string{
"-m", "mark", "--mark", fmt.Sprintf("%#x", nbnet.PreroutingFwmarkMasqueradeReturn),
"-o", r.wgIface.Name(),
"-j", routingFinalNatJump,
}
if err := r.iptablesClient.Append(tableNat, chainRTNAT, rule2...); err != nil {
return fmt.Errorf("add return masquerade rule: %v", err)
}
r.rules["static-nat-return"] = rule2
return nil
}
func (r *router) createAndSetupChain(chain string) error {
table := r.getTableForChain(chain)
if err := r.iptablesClient.NewChain(table, chain); err != nil {
return fmt.Errorf("failed creating chain %s, error: %v", chain, err)
}
return nil
}
func (r *router) getTableForChain(chain string) string {
switch chain {
case chainRTNAT:
return tableNat
case chainRTPRE:
return tableMangle
default:
return tableFilter
}
}
func (r *router) insertEstablishedRule(chain string) error {
establishedRule := getConntrackEstablished()
err := r.iptablesClient.Insert(tableFilter, chain, 1, establishedRule...)
if err != nil { if err != nil {
return fmt.Errorf("failed to insert established rule: %v", err) log.Errorf("failed cleaning chain %s,error: %v", chainRTFWD, err)
return err
}
} }
ruleKey := "established-" + chain ok, err = i.iptablesClient.ChainExists(tableNat, chainRTNAT)
r.rules[ruleKey] = establishedRule if err != nil {
log.Errorf("failed check chain %s,error: %v", chainRTNAT, err)
return nil return err
} } else if ok {
err = i.iptablesClient.ClearAndDeleteChain(tableNat, chainRTNAT)
func (r *router) addJumpRules() error { if err != nil {
// Jump to NAT chain log.Errorf("failed cleaning chain %s,error: %v", chainRTNAT, err)
natRule := []string{"-j", chainRTNAT} return err
if err := r.iptablesClient.Insert(tableNat, chainPOSTROUTING, 1, natRule...); err != nil {
return fmt.Errorf("add nat jump rule: %v", err)
}
r.rules[jumpNat] = natRule
// Jump to prerouting chain
preRule := []string{"-j", chainRTPRE}
if err := r.iptablesClient.Insert(tableMangle, chainPREROUTING, 1, preRule...); err != nil {
return fmt.Errorf("add prerouting jump rule: %v", err)
}
r.rules[jumpPre] = preRule
return nil
}
func (r *router) cleanJumpRules() error {
for _, ruleKey := range []string{jumpNat, jumpPre} {
if rule, exists := r.rules[ruleKey]; exists {
table := tableNat
chain := chainPOSTROUTING
if ruleKey == jumpPre {
table = tableMangle
chain = chainPREROUTING
}
if err := r.iptablesClient.DeleteIfExists(table, chain, rule...); err != nil {
return fmt.Errorf("delete rule from chain %s in table %s, err: %v", chain, table, err)
}
delete(r.rules, ruleKey)
} }
} }
return nil return nil
} }
func (r *router) addNatRule(pair firewall.RouterPair) error { func (i *routerManager) createContainers() error {
ruleKey := firewall.GenKey(firewall.NatFormat, pair) if i.rules[Ipv4Forwarding] != nil {
if rule, exists := r.rules[ruleKey]; exists {
if err := r.iptablesClient.DeleteIfExists(tableMangle, chainRTPRE, rule...); err != nil {
return fmt.Errorf("error while removing existing marking rule for %s: %v", pair.Destination, err)
}
delete(r.rules, ruleKey)
}
markValue := nbnet.PreroutingFwmarkMasquerade
if pair.Inverse {
markValue = nbnet.PreroutingFwmarkMasqueradeReturn
}
rule := []string{"-i", r.wgIface.Name()}
if pair.Inverse {
rule = []string{"!", "-i", r.wgIface.Name()}
}
rule = append(rule,
"-m", "conntrack",
"--ctstate", "NEW",
"-s", pair.Source.String(),
"-d", pair.Destination.String(),
"-j", "MARK", "--set-mark", fmt.Sprintf("%#x", markValue),
)
if err := r.iptablesClient.Append(tableMangle, chainRTPRE, rule...); err != nil {
return fmt.Errorf("error while adding marking rule for %s: %v", pair.Destination, err)
}
r.rules[ruleKey] = rule
return nil return nil
} }
func (r *router) removeNatRule(pair firewall.RouterPair) error { errMSGFormat := "failed creating chain %s,error: %v"
ruleKey := firewall.GenKey(firewall.NatFormat, pair) err := i.createChain(tableFilter, chainRTFWD)
if err != nil {
if rule, exists := r.rules[ruleKey]; exists { return fmt.Errorf(errMSGFormat, chainRTFWD, err)
if err := r.iptablesClient.DeleteIfExists(tableMangle, chainRTPRE, rule...); err != nil {
return fmt.Errorf("error while removing marking rule for %s: %v", pair.Destination, err)
} }
delete(r.rules, ruleKey)
} else { err = i.createChain(tableNat, chainRTNAT)
log.Debugf("marking rule %s not found", ruleKey) if err != nil {
return fmt.Errorf(errMSGFormat, chainRTNAT, err)
}
err = i.addJumpRules()
if err != nil {
return fmt.Errorf("error while creating jump rules: %v", err)
} }
return nil return nil
} }
func (r *router) updateState() { // addJumpRules create jump rules to send packets to NetBird chains
if r.stateManager == nil { func (i *routerManager) addJumpRules() error {
return rule := []string{"-j", chainRTFWD}
err := i.iptablesClient.Insert(tableFilter, chainFORWARD, 1, rule...)
if err != nil {
return err
} }
i.rules[Ipv4Forwarding] = rule
var currentState *ShutdownState rule = []string{"-j", chainRTNAT}
if existing := r.stateManager.GetState(currentState); existing != nil { err = i.iptablesClient.Insert(tableNat, chainPOSTROUTING, 1, rule...)
if existingState, ok := existing.(*ShutdownState); ok { if err != nil {
currentState = existingState return err
}
}
if currentState == nil {
currentState = &ShutdownState{}
} }
i.rules[ipv4Nat] = rule
currentState.Lock()
defer currentState.Unlock()
currentState.RouteRules = r.rules
currentState.RouteIPsetCounter = r.ipsetCounter
if err := r.stateManager.UpdateState(currentState); err != nil {
log.Errorf("failed to update state: %v", err)
}
}
func genRouteFilteringRuleSpec(params routeFilteringRuleParams) []string {
var rule []string
if params.SetName != "" {
rule = append(rule, "-m", "set", matchSet, params.SetName, "src")
} else if len(params.Sources) > 0 {
source := params.Sources[0]
rule = append(rule, "-s", source.String())
}
rule = append(rule, "-d", params.Destination.String())
if params.Proto != firewall.ProtocolALL {
rule = append(rule, "-p", strings.ToLower(string(params.Proto)))
rule = append(rule, applyPort("--sport", params.SPort)...)
rule = append(rule, applyPort("--dport", params.DPort)...)
}
rule = append(rule, "-j", actionToStr(params.Action))
return rule
}
func applyPort(flag string, port *firewall.Port) []string {
if port == nil {
return nil return nil
} }
if port.IsRange && len(port.Values) == 2 { // cleanJumpRules cleans jump rules that was sending packets to NetBird chains
return []string{flag, fmt.Sprintf("%d:%d", port.Values[0], port.Values[1])} func (i *routerManager) cleanJumpRules() error {
var err error
errMSGFormat := "failed cleaning rule from chain %s,err: %v"
rule, found := i.rules[Ipv4Forwarding]
if found {
err = i.iptablesClient.DeleteIfExists(tableFilter, chainFORWARD, rule...)
if err != nil {
return fmt.Errorf(errMSGFormat, chainFORWARD, err)
}
}
rule, found = i.rules[ipv4Nat]
if found {
err = i.iptablesClient.DeleteIfExists(tableNat, chainPOSTROUTING, rule...)
if err != nil {
return fmt.Errorf(errMSGFormat, chainPOSTROUTING, err)
}
} }
if len(port.Values) > 1 { rules, err := i.iptablesClient.List("nat", "POSTROUTING")
portList := make([]string, len(port.Values)) if err != nil {
for i, p := range port.Values { return fmt.Errorf("failed to list rules: %s", err)
portList[i] = strconv.Itoa(p)
}
return []string{"-m", "multiport", flag, strings.Join(portList, ",")}
} }
return []string{flag, strconv.Itoa(port.Values[0])} for _, ruleString := range rules {
if !strings.Contains(ruleString, "NETBIRD") {
continue
}
rule := strings.Fields(ruleString)
err := i.iptablesClient.DeleteIfExists("nat", "POSTROUTING", rule[2:]...)
if err != nil {
return fmt.Errorf("failed to delete postrouting jump rule: %s", err)
}
}
rules, err = i.iptablesClient.List(tableFilter, "FORWARD")
if err != nil {
return fmt.Errorf("failed to list rules in FORWARD chain: %s", err)
}
for _, ruleString := range rules {
if !strings.Contains(ruleString, "NETBIRD") {
continue
}
rule := strings.Fields(ruleString)
err := i.iptablesClient.DeleteIfExists(tableFilter, "FORWARD", rule[2:]...)
if err != nil {
return fmt.Errorf("failed to delete FORWARD jump rule: %s", err)
}
}
return nil
}
func (i *routerManager) createChain(table, newChain string) error {
chains, err := i.iptablesClient.ListChains(table)
if err != nil {
return fmt.Errorf("couldn't get %s table chains, error: %v", table, err)
}
shouldCreateChain := true
for _, chain := range chains {
if chain == newChain {
shouldCreateChain = false
}
}
if shouldCreateChain {
err = i.iptablesClient.NewChain(table, newChain)
if err != nil {
return fmt.Errorf("couldn't create chain %s in %s table, error: %v", newChain, table, err)
}
// Add the loopback return rule to the NAT chain
loopbackRule := []string{"-o", "lo", "-j", "RETURN"}
err = i.iptablesClient.Insert(table, newChain, 1, loopbackRule...)
if err != nil {
return fmt.Errorf("failed to add loopback return rule to %s: %v", chainRTNAT, err)
}
err = i.iptablesClient.Append(table, newChain, "-j", "RETURN")
if err != nil {
return fmt.Errorf("couldn't create chain %s default rule, error: %v", newChain, err)
}
}
return nil
}
// addNATRule appends an iptables rule pair to the nat chain
func (i *routerManager) addNATRule(keyFormat, table, chain, jump string, pair firewall.RouterPair) error {
ruleKey := firewall.GenKey(keyFormat, pair.ID)
rule := genRuleSpec(jump, pair.Source, pair.Destination)
existingRule, found := i.rules[ruleKey]
if found {
err := i.iptablesClient.DeleteIfExists(table, chain, existingRule...)
if err != nil {
return fmt.Errorf("error while removing existing NAT rule for %s: %v", pair.Destination, err)
}
delete(i.rules, ruleKey)
}
// inserting after loopback ignore rule
err := i.iptablesClient.Insert(table, chain, 2, rule...)
if err != nil {
return fmt.Errorf("error while appending new NAT rule for %s: %v", pair.Destination, err)
}
i.rules[ruleKey] = rule
return nil
}
// genRuleSpec generates rule specification
func genRuleSpec(jump, source, destination string) []string {
return []string{"-s", source, "-d", destination, "-j", jump}
}
func getIptablesRuleType(table string) string {
ruleType := "forwarding"
if table == tableNat {
ruleType = "nat"
}
return ruleType
} }

View File

@@ -3,18 +3,16 @@
package iptables package iptables
import ( import (
"fmt" "context"
"net/netip"
"os/exec" "os/exec"
"testing" "testing"
"github.com/coreos/go-iptables/iptables" "github.com/coreos/go-iptables/iptables"
"github.com/stretchr/testify/assert" log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
firewall "github.com/netbirdio/netbird/client/firewall/manager" firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/firewall/test" "github.com/netbirdio/netbird/client/firewall/test"
nbnet "github.com/netbirdio/netbird/util/net"
) )
func isIptablesSupported() bool { func isIptablesSupported() bool {
@@ -30,45 +28,45 @@ func TestIptablesManager_RestoreOrCreateContainers(t *testing.T) {
iptablesClient, err := iptables.NewWithProtocol(iptables.ProtocolIPv4) iptablesClient, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
require.NoError(t, err, "failed to init iptables client") require.NoError(t, err, "failed to init iptables client")
manager, err := newRouter(iptablesClient, ifaceMock) manager, err := newRouterManager(context.TODO(), iptablesClient)
require.NoError(t, err, "should return a valid iptables manager") require.NoError(t, err, "should return a valid iptables manager")
require.NoError(t, manager.init(nil))
defer func() { defer func() {
assert.NoError(t, manager.Reset(), "shouldn't return error") _ = manager.Reset()
}() }()
// Now 5 rules: require.Len(t, manager.rules, 2, "should have created rules map")
// 1. established rule in forward chain
// 2. jump rule to NAT chain
// 3. jump rule to PRE chain
// 4. static outbound masquerade rule
// 5. static return masquerade rule
require.Len(t, manager.rules, 5, "should have created rules map")
exists, err := manager.iptablesClient.Exists(tableNat, chainPOSTROUTING, "-j", chainRTNAT) exists, err := manager.iptablesClient.Exists(tableFilter, chainFORWARD, manager.rules[Ipv4Forwarding]...)
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableFilter, chainFORWARD)
require.True(t, exists, "forwarding rule should exist")
exists, err = manager.iptablesClient.Exists(tableNat, chainPOSTROUTING, manager.rules[ipv4Nat]...)
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableNat, chainPOSTROUTING) require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableNat, chainPOSTROUTING)
require.True(t, exists, "postrouting jump rule should exist") require.True(t, exists, "postrouting rule should exist")
exists, err = manager.iptablesClient.Exists(tableMangle, chainPREROUTING, "-j", chainRTPRE)
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableMangle, chainPREROUTING)
require.True(t, exists, "prerouting jump rule should exist")
pair := firewall.RouterPair{ pair := firewall.RouterPair{
ID: "abc", ID: "abc",
Source: netip.MustParsePrefix("100.100.100.1/32"), Source: "100.100.100.1/32",
Destination: netip.MustParsePrefix("100.100.100.0/24"), Destination: "100.100.100.0/24",
Masquerade: true, Masquerade: true,
} }
forward4Rule := genRuleSpec(routingFinalForwardJump, pair.Source, pair.Destination)
err = manager.AddNatRule(pair) err = manager.iptablesClient.Insert(tableFilter, chainRTFWD, 1, forward4Rule...)
require.NoError(t, err, "adding NAT rule should not return error") require.NoError(t, err, "inserting rule should not return error")
nat4Rule := genRuleSpec(routingFinalNatJump, pair.Source, pair.Destination)
err = manager.iptablesClient.Insert(tableNat, chainRTNAT, 1, nat4Rule...)
require.NoError(t, err, "inserting rule should not return error")
err = manager.Reset() err = manager.Reset()
require.NoError(t, err, "shouldn't return error") require.NoError(t, err, "shouldn't return error")
} }
func TestIptablesManager_AddNatRule(t *testing.T) { func TestIptablesManager_InsertRoutingRules(t *testing.T) {
if !isIptablesSupported() { if !isIptablesSupported() {
t.SkipNow() t.SkipNow()
} }
@@ -78,71 +76,78 @@ func TestIptablesManager_AddNatRule(t *testing.T) {
iptablesClient, err := iptables.NewWithProtocol(iptables.ProtocolIPv4) iptablesClient, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
require.NoError(t, err, "failed to init iptables client") require.NoError(t, err, "failed to init iptables client")
manager, err := newRouter(iptablesClient, ifaceMock) manager, err := newRouterManager(context.TODO(), iptablesClient)
require.NoError(t, err, "shouldn't return error") require.NoError(t, err, "shouldn't return error")
require.NoError(t, manager.init(nil))
defer func() { defer func() {
assert.NoError(t, manager.Reset(), "shouldn't return error") err := manager.Reset()
if err != nil {
log.Errorf("failed to reset iptables manager: %s", err)
}
}() }()
err = manager.AddNatRule(testCase.InputPair) err = manager.InsertRoutingRules(testCase.InputPair)
require.NoError(t, err, "marking rule should be inserted") require.NoError(t, err, "forwarding pair should be inserted")
natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair) forwardRuleKey := firewall.GenKey(firewall.ForwardingFormat, testCase.InputPair.ID)
markingRule := []string{ forwardRule := genRuleSpec(routingFinalForwardJump, testCase.InputPair.Source, testCase.InputPair.Destination)
"-i", ifaceMock.Name(),
"-m", "conntrack",
"--ctstate", "NEW",
"-s", testCase.InputPair.Source.String(),
"-d", testCase.InputPair.Destination.String(),
"-j", "MARK", "--set-mark",
fmt.Sprintf("%#x", nbnet.PreroutingFwmarkMasquerade),
}
exists, err := iptablesClient.Exists(tableMangle, chainRTPRE, markingRule...) exists, err := iptablesClient.Exists(tableFilter, chainRTFWD, forwardRule...)
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableMangle, chainRTPRE) require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableFilter, chainRTFWD)
require.True(t, exists, "forwarding rule should exist")
foundRule, found := manager.rules[forwardRuleKey]
require.True(t, found, "forwarding rule should exist in the manager map")
require.Equal(t, forwardRule[:4], foundRule[:4], "stored forwarding rule should match")
inForwardRuleKey := firewall.GenKey(firewall.InForwardingFormat, testCase.InputPair.ID)
inForwardRule := genRuleSpec(routingFinalForwardJump, firewall.GetInPair(testCase.InputPair).Source, firewall.GetInPair(testCase.InputPair).Destination)
exists, err = iptablesClient.Exists(tableFilter, chainRTFWD, inForwardRule...)
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableFilter, chainRTFWD)
require.True(t, exists, "income forwarding rule should exist")
foundRule, found = manager.rules[inForwardRuleKey]
require.True(t, found, "income forwarding rule should exist in the manager map")
require.Equal(t, inForwardRule[:4], foundRule[:4], "stored income forwarding rule should match")
natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair.ID)
natRule := genRuleSpec(routingFinalNatJump, testCase.InputPair.Source, testCase.InputPair.Destination)
exists, err = iptablesClient.Exists(tableNat, chainRTNAT, natRule...)
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableNat, chainRTNAT)
if testCase.InputPair.Masquerade { if testCase.InputPair.Masquerade {
require.True(t, exists, "marking rule should be created") require.True(t, exists, "nat rule should be created")
foundRule, found := manager.rules[natRuleKey] foundNatRule, foundNat := manager.rules[natRuleKey]
require.True(t, found, "marking rule should exist in the map") require.True(t, foundNat, "nat rule should exist in the map")
require.Equal(t, markingRule, foundRule, "stored marking rule should match") require.Equal(t, natRule[:4], foundNatRule[:4], "stored nat rule should match")
} else { } else {
require.False(t, exists, "marking rule should not be created") require.False(t, exists, "nat rule should not be created")
_, found := manager.rules[natRuleKey] _, foundNat := manager.rules[natRuleKey]
require.False(t, found, "marking rule should not exist in the map") require.False(t, foundNat, "nat rule should not exist in the map")
} }
// Check inverse rule inNatRuleKey := firewall.GenKey(firewall.InNatFormat, testCase.InputPair.ID)
inversePair := firewall.GetInversePair(testCase.InputPair) inNatRule := genRuleSpec(routingFinalNatJump, firewall.GetInPair(testCase.InputPair).Source, firewall.GetInPair(testCase.InputPair).Destination)
inverseRuleKey := firewall.GenKey(firewall.NatFormat, inversePair)
inverseMarkingRule := []string{
"!", "-i", ifaceMock.Name(),
"-m", "conntrack",
"--ctstate", "NEW",
"-s", inversePair.Source.String(),
"-d", inversePair.Destination.String(),
"-j", "MARK", "--set-mark",
fmt.Sprintf("%#x", nbnet.PreroutingFwmarkMasqueradeReturn),
}
exists, err = iptablesClient.Exists(tableMangle, chainRTPRE, inverseMarkingRule...) exists, err = iptablesClient.Exists(tableNat, chainRTNAT, inNatRule...)
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableMangle, chainRTPRE) require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableNat, chainRTNAT)
if testCase.InputPair.Masquerade { if testCase.InputPair.Masquerade {
require.True(t, exists, "inverse marking rule should be created") require.True(t, exists, "income nat rule should be created")
foundRule, found := manager.rules[inverseRuleKey] foundNatRule, foundNat := manager.rules[inNatRuleKey]
require.True(t, found, "inverse marking rule should exist in the map") require.True(t, foundNat, "income nat rule should exist in the map")
require.Equal(t, inverseMarkingRule, foundRule, "stored inverse marking rule should match") require.Equal(t, inNatRule[:4], foundNatRule[:4], "stored income nat rule should match")
} else { } else {
require.False(t, exists, "inverse marking rule should not be created") require.False(t, exists, "nat rule should not be created")
_, found := manager.rules[inverseRuleKey] _, foundNat := manager.rules[inNatRuleKey]
require.False(t, found, "inverse marking rule should not exist in the map") require.False(t, foundNat, "income nat rule should not exist in the map")
} }
}) })
} }
} }
func TestIptablesManager_RemoveNatRule(t *testing.T) { func TestIptablesManager_RemoveRoutingRules(t *testing.T) {
if !isIptablesSupported() { if !isIptablesSupported() {
t.SkipNow() t.SkipNow()
} }
@@ -151,226 +156,72 @@ func TestIptablesManager_RemoveNatRule(t *testing.T) {
t.Run(testCase.Name, func(t *testing.T) { t.Run(testCase.Name, func(t *testing.T) {
iptablesClient, _ := iptables.NewWithProtocol(iptables.ProtocolIPv4) iptablesClient, _ := iptables.NewWithProtocol(iptables.ProtocolIPv4)
manager, err := newRouter(iptablesClient, ifaceMock) manager, err := newRouterManager(context.TODO(), iptablesClient)
require.NoError(t, err, "shouldn't return error") require.NoError(t, err, "shouldn't return error")
require.NoError(t, manager.init(nil))
defer func() { defer func() {
assert.NoError(t, manager.Reset(), "shouldn't return error") _ = manager.Reset()
}() }()
err = manager.AddNatRule(testCase.InputPair)
require.NoError(t, err, "should add NAT rule without error")
err = manager.RemoveNatRule(testCase.InputPair)
require.NoError(t, err, "shouldn't return error") require.NoError(t, err, "shouldn't return error")
natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair) forwardRuleKey := firewall.GenKey(firewall.ForwardingFormat, testCase.InputPair.ID)
markingRule := []string{ forwardRule := genRuleSpec(routingFinalForwardJump, testCase.InputPair.Source, testCase.InputPair.Destination)
"-i", ifaceMock.Name(),
"-m", "conntrack",
"--ctstate", "NEW",
"-s", testCase.InputPair.Source.String(),
"-d", testCase.InputPair.Destination.String(),
"-j", "MARK", "--set-mark",
fmt.Sprintf("%#x", nbnet.PreroutingFwmarkMasquerade),
}
exists, err := iptablesClient.Exists(tableMangle, chainRTPRE, markingRule...) err = iptablesClient.Insert(tableFilter, chainRTFWD, 1, forwardRule...)
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableMangle, chainRTPRE) require.NoError(t, err, "inserting rule should not return error")
require.False(t, exists, "marking rule should not exist")
_, found := manager.rules[natRuleKey] inForwardRuleKey := firewall.GenKey(firewall.InForwardingFormat, testCase.InputPair.ID)
require.False(t, found, "marking rule should not exist in the manager map") inForwardRule := genRuleSpec(routingFinalForwardJump, firewall.GetInPair(testCase.InputPair).Source, firewall.GetInPair(testCase.InputPair).Destination)
// Check inverse rule removal err = iptablesClient.Insert(tableFilter, chainRTFWD, 1, inForwardRule...)
inversePair := firewall.GetInversePair(testCase.InputPair) require.NoError(t, err, "inserting rule should not return error")
inverseRuleKey := firewall.GenKey(firewall.NatFormat, inversePair)
inverseMarkingRule := []string{
"!", "-i", ifaceMock.Name(),
"-m", "conntrack",
"--ctstate", "NEW",
"-s", inversePair.Source.String(),
"-d", inversePair.Destination.String(),
"-j", "MARK", "--set-mark",
fmt.Sprintf("%#x", nbnet.PreroutingFwmarkMasqueradeReturn),
}
exists, err = iptablesClient.Exists(tableMangle, chainRTPRE, inverseMarkingRule...) natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair.ID)
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableMangle, chainRTPRE) natRule := genRuleSpec(routingFinalNatJump, testCase.InputPair.Source, testCase.InputPair.Destination)
require.False(t, exists, "inverse marking rule should not exist")
err = iptablesClient.Insert(tableNat, chainRTNAT, 1, natRule...)
require.NoError(t, err, "inserting rule should not return error")
inNatRuleKey := firewall.GenKey(firewall.InNatFormat, testCase.InputPair.ID)
inNatRule := genRuleSpec(routingFinalNatJump, firewall.GetInPair(testCase.InputPair).Source, firewall.GetInPair(testCase.InputPair).Destination)
err = iptablesClient.Insert(tableNat, chainRTNAT, 1, inNatRule...)
require.NoError(t, err, "inserting rule should not return error")
err = manager.Reset()
require.NoError(t, err, "shouldn't return error")
err = manager.RemoveRoutingRules(testCase.InputPair)
require.NoError(t, err, "shouldn't return error")
exists, err := iptablesClient.Exists(tableFilter, chainRTFWD, forwardRule...)
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableFilter, chainRTFWD)
require.False(t, exists, "forwarding rule should not exist")
_, found := manager.rules[forwardRuleKey]
require.False(t, found, "forwarding rule should exist in the manager map")
exists, err = iptablesClient.Exists(tableFilter, chainRTFWD, inForwardRule...)
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableFilter, chainRTFWD)
require.False(t, exists, "income forwarding rule should not exist")
_, found = manager.rules[inForwardRuleKey]
require.False(t, found, "income forwarding rule should exist in the manager map")
exists, err = iptablesClient.Exists(tableNat, chainRTNAT, natRule...)
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableNat, chainRTNAT)
require.False(t, exists, "nat rule should not exist")
_, found = manager.rules[natRuleKey]
require.False(t, found, "nat rule should exist in the manager map")
exists, err = iptablesClient.Exists(tableNat, chainRTNAT, inNatRule...)
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableNat, chainRTNAT)
require.False(t, exists, "income nat rule should not exist")
_, found = manager.rules[inNatRuleKey]
require.False(t, found, "income nat rule should exist in the manager map")
_, found = manager.rules[inverseRuleKey]
require.False(t, found, "inverse marking rule should not exist in the map")
})
}
}
func TestRouter_AddRouteFiltering(t *testing.T) {
if !isIptablesSupported() {
t.Skip("iptables not supported on this system")
}
iptablesClient, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
require.NoError(t, err, "Failed to create iptables client")
r, err := newRouter(iptablesClient, ifaceMock)
require.NoError(t, err, "Failed to create router manager")
require.NoError(t, r.init(nil))
defer func() {
err := r.Reset()
require.NoError(t, err, "Failed to reset router")
}()
tests := []struct {
name string
sources []netip.Prefix
destination netip.Prefix
proto firewall.Protocol
sPort *firewall.Port
dPort *firewall.Port
direction firewall.RuleDirection
action firewall.Action
expectSet bool
}{
{
name: "Basic TCP rule with single source",
sources: []netip.Prefix{netip.MustParsePrefix("192.168.1.0/24")},
destination: netip.MustParsePrefix("10.0.0.0/24"),
proto: firewall.ProtocolTCP,
sPort: nil,
dPort: &firewall.Port{Values: []int{80}},
direction: firewall.RuleDirectionIN,
action: firewall.ActionAccept,
expectSet: false,
},
{
name: "UDP rule with multiple sources",
sources: []netip.Prefix{
netip.MustParsePrefix("172.16.0.0/16"),
netip.MustParsePrefix("192.168.0.0/16"),
},
destination: netip.MustParsePrefix("10.0.0.0/8"),
proto: firewall.ProtocolUDP,
sPort: &firewall.Port{Values: []int{1024, 2048}, IsRange: true},
dPort: nil,
direction: firewall.RuleDirectionOUT,
action: firewall.ActionDrop,
expectSet: true,
},
{
name: "All protocols rule",
sources: []netip.Prefix{netip.MustParsePrefix("10.0.0.0/8")},
destination: netip.MustParsePrefix("0.0.0.0/0"),
proto: firewall.ProtocolALL,
sPort: nil,
dPort: nil,
direction: firewall.RuleDirectionIN,
action: firewall.ActionAccept,
expectSet: false,
},
{
name: "ICMP rule",
sources: []netip.Prefix{netip.MustParsePrefix("192.168.0.0/16")},
destination: netip.MustParsePrefix("10.0.0.0/8"),
proto: firewall.ProtocolICMP,
sPort: nil,
dPort: nil,
direction: firewall.RuleDirectionIN,
action: firewall.ActionAccept,
expectSet: false,
},
{
name: "TCP rule with multiple source ports",
sources: []netip.Prefix{netip.MustParsePrefix("172.16.0.0/12")},
destination: netip.MustParsePrefix("192.168.0.0/16"),
proto: firewall.ProtocolTCP,
sPort: &firewall.Port{Values: []int{80, 443, 8080}},
dPort: nil,
direction: firewall.RuleDirectionOUT,
action: firewall.ActionAccept,
expectSet: false,
},
{
name: "UDP rule with single IP and port range",
sources: []netip.Prefix{netip.MustParsePrefix("192.168.1.1/32")},
destination: netip.MustParsePrefix("10.0.0.0/24"),
proto: firewall.ProtocolUDP,
sPort: nil,
dPort: &firewall.Port{Values: []int{5000, 5100}, IsRange: true},
direction: firewall.RuleDirectionIN,
action: firewall.ActionDrop,
expectSet: false,
},
{
name: "TCP rule with source and destination ports",
sources: []netip.Prefix{netip.MustParsePrefix("10.0.0.0/24")},
destination: netip.MustParsePrefix("172.16.0.0/16"),
proto: firewall.ProtocolTCP,
sPort: &firewall.Port{Values: []int{1024, 65535}, IsRange: true},
dPort: &firewall.Port{Values: []int{22}},
direction: firewall.RuleDirectionOUT,
action: firewall.ActionAccept,
expectSet: false,
},
{
name: "Drop all incoming traffic",
sources: []netip.Prefix{netip.MustParsePrefix("0.0.0.0/0")},
destination: netip.MustParsePrefix("192.168.0.0/24"),
proto: firewall.ProtocolALL,
sPort: nil,
dPort: nil,
direction: firewall.RuleDirectionIN,
action: firewall.ActionDrop,
expectSet: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ruleKey, err := r.AddRouteFiltering(tt.sources, tt.destination, tt.proto, tt.sPort, tt.dPort, tt.action)
require.NoError(t, err, "AddRouteFiltering failed")
// Check if the rule is in the internal map
rule, ok := r.rules[ruleKey.GetRuleID()]
assert.True(t, ok, "Rule not found in internal map")
// Log the internal rule
t.Logf("Internal rule: %v", rule)
// Check if the rule exists in iptables
exists, err := iptablesClient.Exists(tableFilter, chainRTFWD, rule...)
assert.NoError(t, err, "Failed to check rule existence")
assert.True(t, exists, "Rule not found in iptables")
// Verify rule content
params := routeFilteringRuleParams{
Sources: tt.sources,
Destination: tt.destination,
Proto: tt.proto,
SPort: tt.sPort,
DPort: tt.dPort,
Action: tt.action,
SetName: "",
}
expectedRule := genRouteFilteringRuleSpec(params)
if tt.expectSet {
setName := firewall.GenerateSetName(tt.sources)
params.SetName = setName
expectedRule = genRouteFilteringRuleSpec(params)
// Check if the set was created
_, exists := r.ipsetCounter.Get(setName)
assert.True(t, exists, "IPSet not created")
}
assert.Equal(t, expectedRule, rule, "Rule content mismatch")
// Clean up
err = r.DeleteRouteRule(ruleKey)
require.NoError(t, err, "Failed to delete rule")
}) })
} }
} }

View File

@@ -1,16 +1,14 @@
package iptables package iptables
import "encoding/json"
type ipList struct { type ipList struct {
ips map[string]struct{} ips map[string]struct{}
} }
func newIpList(ip string) *ipList { func newIpList(ip string) ipList {
ips := make(map[string]struct{}) ips := make(map[string]struct{})
ips[ip] = struct{}{} ips[ip] = struct{}{}
return &ipList{ return ipList{
ips: ips, ips: ips,
} }
} }
@@ -19,52 +17,27 @@ func (s *ipList) addIP(ip string) {
s.ips[ip] = struct{}{} s.ips[ip] = struct{}{}
} }
// MarshalJSON implements json.Marshaler
func (s *ipList) MarshalJSON() ([]byte, error) {
return json.Marshal(struct {
IPs map[string]struct{} `json:"ips"`
}{
IPs: s.ips,
})
}
// UnmarshalJSON implements json.Unmarshaler
func (s *ipList) UnmarshalJSON(data []byte) error {
temp := struct {
IPs map[string]struct{} `json:"ips"`
}{}
if err := json.Unmarshal(data, &temp); err != nil {
return err
}
s.ips = temp.IPs
if temp.IPs == nil {
temp.IPs = make(map[string]struct{})
}
return nil
}
type ipsetStore struct { type ipsetStore struct {
ipsets map[string]*ipList ipsets map[string]ipList // ipsetName -> ruleset
} }
func newIpsetStore() *ipsetStore { func newIpsetStore() *ipsetStore {
return &ipsetStore{ return &ipsetStore{
ipsets: make(map[string]*ipList), ipsets: make(map[string]ipList),
} }
} }
func (s *ipsetStore) ipset(ipsetName string) (*ipList, bool) { func (s *ipsetStore) ipset(ipsetName string) (ipList, bool) {
r, ok := s.ipsets[ipsetName] r, ok := s.ipsets[ipsetName]
return r, ok return r, ok
} }
func (s *ipsetStore) addIpList(ipsetName string, list *ipList) { func (s *ipsetStore) addIpList(ipsetName string, list ipList) {
s.ipsets[ipsetName] = list s.ipsets[ipsetName] = list
} }
func (s *ipsetStore) deleteIpset(ipsetName string) { func (s *ipsetStore) deleteIpset(ipsetName string) {
s.ipsets[ipsetName] = ipList{}
delete(s.ipsets, ipsetName) delete(s.ipsets, ipsetName)
} }
@@ -75,29 +48,3 @@ func (s *ipsetStore) ipsetNames() []string {
} }
return names return names
} }
// MarshalJSON implements json.Marshaler
func (s *ipsetStore) MarshalJSON() ([]byte, error) {
return json.Marshal(struct {
IPSets map[string]*ipList `json:"ipsets"`
}{
IPSets: s.ipsets,
})
}
// UnmarshalJSON implements json.Unmarshaler
func (s *ipsetStore) UnmarshalJSON(data []byte) error {
temp := struct {
IPSets map[string]*ipList `json:"ipsets"`
}{}
if err := json.Unmarshal(data, &temp); err != nil {
return err
}
s.ipsets = temp.IPSets
if temp.IPSets == nil {
temp.IPSets = make(map[string]*ipList)
}
return nil
}

View File

@@ -1,70 +0,0 @@
package iptables
import (
"fmt"
"sync"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/device"
)
type InterfaceState struct {
NameStr string `json:"name"`
WGAddress iface.WGAddress `json:"wg_address"`
UserspaceBind bool `json:"userspace_bind"`
}
func (i *InterfaceState) Name() string {
return i.NameStr
}
func (i *InterfaceState) Address() device.WGAddress {
return i.WGAddress
}
func (i *InterfaceState) IsUserspaceBind() bool {
return i.UserspaceBind
}
type ShutdownState struct {
sync.Mutex
InterfaceState *InterfaceState `json:"interface_state,omitempty"`
RouteRules routeRules `json:"route_rules,omitempty"`
RouteIPsetCounter *ipsetCounter `json:"route_ipset_counter,omitempty"`
ACLEntries aclEntries `json:"acl_entries,omitempty"`
ACLIPsetStore *ipsetStore `json:"acl_ipset_store,omitempty"`
}
func (s *ShutdownState) Name() string {
return "iptables_state"
}
func (s *ShutdownState) Cleanup() error {
ipt, err := Create(s.InterfaceState)
if err != nil {
return fmt.Errorf("create iptables manager: %w", err)
}
if s.RouteRules != nil {
ipt.router.rules = s.RouteRules
}
if s.RouteIPsetCounter != nil {
ipt.router.ipsetCounter.LoadData(s.RouteIPsetCounter)
}
if s.ACLEntries != nil {
ipt.aclMgr.entries = s.ACLEntries
}
if s.ACLIPsetStore != nil {
ipt.aclMgr.ipsetStore = s.ACLIPsetStore
}
if err := ipt.Reset(nil); err != nil {
return fmt.Errorf("reset iptables manager: %w", err)
}
return nil
}

View File

@@ -1,24 +1,15 @@
package manager package manager
import ( import (
"crypto/sha256"
"encoding/hex"
"fmt" "fmt"
"net" "net"
"net/netip"
"sort"
"strings"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/statemanager"
) )
const ( const (
ForwardingFormatPrefix = "netbird-fwd-" NatFormat = "netbird-nat-%s"
ForwardingFormat = "netbird-fwd-%s-%t" ForwardingFormat = "netbird-fwd-%s"
PreroutingFormat = "netbird-prerouting-%s-%t" InNatFormat = "netbird-nat-in-%s"
NatFormat = "netbird-nat-%s-%t" InForwardingFormat = "netbird-fwd-in-%s"
) )
// Rule abstraction should be implemented by each firewall manager // Rule abstraction should be implemented by each firewall manager
@@ -55,16 +46,14 @@ const (
// It declares methods which handle actions required by the // It declares methods which handle actions required by the
// Netbird client for ACL and routing functionality // Netbird client for ACL and routing functionality
type Manager interface { type Manager interface {
Init(stateManager *statemanager.Manager) error
// AllowNetbird allows netbird interface traffic // AllowNetbird allows netbird interface traffic
AllowNetbird() error AllowNetbird() error
// AddPeerFiltering adds a rule to the firewall // AddFiltering rule to the firewall
// //
// If comment argument is empty firewall manager should set // If comment argument is empty firewall manager should set
// rule ID as comment for the rule // rule ID as comment for the rule
AddPeerFiltering( AddFiltering(
ip net.IP, ip net.IP,
proto Protocol, proto Protocol,
sPort *Port, sPort *Port,
@@ -75,119 +64,25 @@ type Manager interface {
comment string, comment string,
) ([]Rule, error) ) ([]Rule, error)
// DeletePeerRule from the firewall by rule definition // DeleteRule from the firewall by rule definition
DeletePeerRule(rule Rule) error DeleteRule(rule Rule) error
// IsServerRouteSupported returns true if the firewall supports server side routing operations // IsServerRouteSupported returns true if the firewall supports server side routing operations
IsServerRouteSupported() bool IsServerRouteSupported() bool
AddRouteFiltering(source []netip.Prefix, destination netip.Prefix, proto Protocol, sPort *Port, dPort *Port, action Action) (Rule, error) // InsertRoutingRules inserts a routing firewall rule
InsertRoutingRules(pair RouterPair) error
// DeleteRouteRule deletes a routing rule // RemoveRoutingRules removes a routing firewall rule
DeleteRouteRule(rule Rule) error RemoveRoutingRules(pair RouterPair) error
// AddNatRule inserts a routing NAT rule
AddNatRule(pair RouterPair) error
// RemoveNatRule removes a routing NAT rule
RemoveNatRule(pair RouterPair) error
// SetLegacyManagement sets the legacy management mode
SetLegacyManagement(legacy bool) error
// Reset firewall to the default state // Reset firewall to the default state
Reset(stateManager *statemanager.Manager) error Reset() error
// Flush the changes to firewall controller // Flush the changes to firewall controller
Flush() error Flush() error
// CollectStats returns the statistics of the firewall manager
CollectStats() []*FlowStats
} }
func GenKey(format string, pair RouterPair) string { func GenKey(format string, input string) string {
return fmt.Sprintf(format, pair.ID, pair.Inverse) return fmt.Sprintf(format, input)
}
// LegacyManager defines the interface for legacy management operations
type LegacyManager interface {
RemoveAllLegacyRouteRules() error
GetLegacyManagement() bool
SetLegacyManagement(bool)
}
// SetLegacyManagement sets the route manager to use legacy management
func SetLegacyManagement(router LegacyManager, isLegacy bool) error {
oldLegacy := router.GetLegacyManagement()
if oldLegacy != isLegacy {
router.SetLegacyManagement(isLegacy)
log.Debugf("Set legacy management to %v", isLegacy)
}
// client reconnected to a newer mgmt, we need to clean up the legacy rules
if !isLegacy && oldLegacy {
if err := router.RemoveAllLegacyRouteRules(); err != nil {
return fmt.Errorf("remove legacy routing rules: %v", err)
}
log.Debugf("Legacy routing rules removed")
}
return nil
}
// GenerateSetName generates a unique name for an ipset based on the given sources.
func GenerateSetName(sources []netip.Prefix) string {
// sort for consistent naming
SortPrefixes(sources)
var sourcesStr strings.Builder
for _, src := range sources {
sourcesStr.WriteString(src.String())
}
hash := sha256.Sum256([]byte(sourcesStr.String()))
shortHash := hex.EncodeToString(hash[:])[:8]
return fmt.Sprintf("nb-%s", shortHash)
}
// MergeIPRanges merges overlapping IP ranges and returns a slice of non-overlapping netip.Prefix
func MergeIPRanges(prefixes []netip.Prefix) []netip.Prefix {
if len(prefixes) == 0 {
return prefixes
}
merged := []netip.Prefix{prefixes[0]}
for _, prefix := range prefixes[1:] {
last := merged[len(merged)-1]
if last.Contains(prefix.Addr()) {
// If the current prefix is contained within the last merged prefix, skip it
continue
}
if prefix.Contains(last.Addr()) {
// If the current prefix contains the last merged prefix, replace it
merged[len(merged)-1] = prefix
} else {
// Otherwise, add the current prefix to the merged list
merged = append(merged, prefix)
}
}
return merged
}
// SortPrefixes sorts the given slice of netip.Prefix in place.
// It sorts first by IP address, then by prefix length (most specific to least specific).
func SortPrefixes(prefixes []netip.Prefix) {
sort.Slice(prefixes, func(i, j int) bool {
addrCmp := prefixes[i].Addr().Compare(prefixes[j].Addr())
if addrCmp != 0 {
return addrCmp < 0
}
// If IP addresses are the same, compare prefix lengths (longer prefixes first)
return prefixes[i].Bits() > prefixes[j].Bits()
})
} }

View File

@@ -1,192 +0,0 @@
package manager_test
import (
"net/netip"
"reflect"
"regexp"
"testing"
"github.com/netbirdio/netbird/client/firewall/manager"
)
func TestGenerateSetName(t *testing.T) {
t.Run("Different orders result in same hash", func(t *testing.T) {
prefixes1 := []netip.Prefix{
netip.MustParsePrefix("192.168.1.0/24"),
netip.MustParsePrefix("10.0.0.0/8"),
}
prefixes2 := []netip.Prefix{
netip.MustParsePrefix("10.0.0.0/8"),
netip.MustParsePrefix("192.168.1.0/24"),
}
result1 := manager.GenerateSetName(prefixes1)
result2 := manager.GenerateSetName(prefixes2)
if result1 != result2 {
t.Errorf("Different orders produced different hashes: %s != %s", result1, result2)
}
})
t.Run("Result format is correct", func(t *testing.T) {
prefixes := []netip.Prefix{
netip.MustParsePrefix("192.168.1.0/24"),
netip.MustParsePrefix("10.0.0.0/8"),
}
result := manager.GenerateSetName(prefixes)
matched, err := regexp.MatchString(`^nb-[0-9a-f]{8}$`, result)
if err != nil {
t.Fatalf("Error matching regex: %v", err)
}
if !matched {
t.Errorf("Result format is incorrect: %s", result)
}
})
t.Run("Empty input produces consistent result", func(t *testing.T) {
result1 := manager.GenerateSetName([]netip.Prefix{})
result2 := manager.GenerateSetName([]netip.Prefix{})
if result1 != result2 {
t.Errorf("Empty input produced inconsistent results: %s != %s", result1, result2)
}
})
t.Run("IPv4 and IPv6 mixing", func(t *testing.T) {
prefixes1 := []netip.Prefix{
netip.MustParsePrefix("192.168.1.0/24"),
netip.MustParsePrefix("2001:db8::/32"),
}
prefixes2 := []netip.Prefix{
netip.MustParsePrefix("2001:db8::/32"),
netip.MustParsePrefix("192.168.1.0/24"),
}
result1 := manager.GenerateSetName(prefixes1)
result2 := manager.GenerateSetName(prefixes2)
if result1 != result2 {
t.Errorf("Different orders of IPv4 and IPv6 produced different hashes: %s != %s", result1, result2)
}
})
}
func TestMergeIPRanges(t *testing.T) {
tests := []struct {
name string
input []netip.Prefix
expected []netip.Prefix
}{
{
name: "Empty input",
input: []netip.Prefix{},
expected: []netip.Prefix{},
},
{
name: "Single range",
input: []netip.Prefix{
netip.MustParsePrefix("192.168.1.0/24"),
},
expected: []netip.Prefix{
netip.MustParsePrefix("192.168.1.0/24"),
},
},
{
name: "Two non-overlapping ranges",
input: []netip.Prefix{
netip.MustParsePrefix("192.168.1.0/24"),
netip.MustParsePrefix("10.0.0.0/8"),
},
expected: []netip.Prefix{
netip.MustParsePrefix("192.168.1.0/24"),
netip.MustParsePrefix("10.0.0.0/8"),
},
},
{
name: "One range containing another",
input: []netip.Prefix{
netip.MustParsePrefix("192.168.0.0/16"),
netip.MustParsePrefix("192.168.1.0/24"),
},
expected: []netip.Prefix{
netip.MustParsePrefix("192.168.0.0/16"),
},
},
{
name: "One range containing another (different order)",
input: []netip.Prefix{
netip.MustParsePrefix("192.168.1.0/24"),
netip.MustParsePrefix("192.168.0.0/16"),
},
expected: []netip.Prefix{
netip.MustParsePrefix("192.168.0.0/16"),
},
},
{
name: "Overlapping ranges",
input: []netip.Prefix{
netip.MustParsePrefix("192.168.1.0/24"),
netip.MustParsePrefix("192.168.1.128/25"),
},
expected: []netip.Prefix{
netip.MustParsePrefix("192.168.1.0/24"),
},
},
{
name: "Overlapping ranges (different order)",
input: []netip.Prefix{
netip.MustParsePrefix("192.168.1.128/25"),
netip.MustParsePrefix("192.168.1.0/24"),
},
expected: []netip.Prefix{
netip.MustParsePrefix("192.168.1.0/24"),
},
},
{
name: "Multiple overlapping ranges",
input: []netip.Prefix{
netip.MustParsePrefix("192.168.0.0/16"),
netip.MustParsePrefix("192.168.1.0/24"),
netip.MustParsePrefix("192.168.2.0/24"),
netip.MustParsePrefix("192.168.1.128/25"),
},
expected: []netip.Prefix{
netip.MustParsePrefix("192.168.0.0/16"),
},
},
{
name: "Partially overlapping ranges",
input: []netip.Prefix{
netip.MustParsePrefix("192.168.0.0/23"),
netip.MustParsePrefix("192.168.1.0/24"),
netip.MustParsePrefix("192.168.2.0/25"),
},
expected: []netip.Prefix{
netip.MustParsePrefix("192.168.0.0/23"),
netip.MustParsePrefix("192.168.2.0/25"),
},
},
{
name: "IPv6 ranges",
input: []netip.Prefix{
netip.MustParsePrefix("2001:db8::/32"),
netip.MustParsePrefix("2001:db8:1::/48"),
netip.MustParsePrefix("2001:db8:2::/48"),
},
expected: []netip.Prefix{
netip.MustParsePrefix("2001:db8::/32"),
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := manager.MergeIPRanges(tt.input)
if !reflect.DeepEqual(result, tt.expected) {
t.Errorf("MergeIPRanges() = %v, want %v", result, tt.expected)
}
})
}
}

View File

@@ -1,26 +1,18 @@
package manager package manager
import (
"net/netip"
"github.com/netbirdio/netbird/route"
)
type RouterPair struct { type RouterPair struct {
ID route.ID ID string
Source netip.Prefix Source string
Destination netip.Prefix Destination string
Masquerade bool Masquerade bool
Inverse bool
} }
func GetInversePair(pair RouterPair) RouterPair { func GetInPair(pair RouterPair) RouterPair {
return RouterPair{ return RouterPair{
ID: pair.ID, ID: pair.ID,
// invert Source/Destination // invert Source/Destination
Source: pair.Destination, Source: pair.Destination,
Destination: pair.Source, Destination: pair.Source,
Masquerade: pair.Masquerade, Masquerade: pair.Masquerade,
Inverse: true,
} }
} }

View File

@@ -1,107 +0,0 @@
package manager
import (
"encoding/json"
"net"
"slices"
"strconv"
"sync/atomic"
"time"
)
const (
DirectionInbound Direction = 0
DirectionOutbound Direction = 1
)
type Direction uint8
func (d Direction) String() string {
switch d {
case DirectionInbound:
return "inbound"
case DirectionOutbound:
return "outbound"
default:
return "unknown"
}
}
// FlowStats tracks statistics for an individual connection
type FlowStats struct {
StartTime time.Time
LastSeen time.Time
BytesIn atomic.Uint64
BytesOut atomic.Uint64
PacketsIn atomic.Uint64
PacketsOut atomic.Uint64
Protocol uint8
Direction Direction
SourceIP net.IP
DestIP net.IP
SourcePort uint16
DestPort uint16
}
func (f *FlowStats) Clone() *FlowStats {
flowCopy := FlowStats{
StartTime: f.StartTime,
LastSeen: f.LastSeen,
Protocol: f.Protocol,
Direction: f.Direction,
SourceIP: slices.Clone(f.SourceIP),
DestIP: slices.Clone(f.DestIP),
SourcePort: f.SourcePort,
DestPort: f.DestPort,
}
flowCopy.BytesIn.Store(f.BytesIn.Load())
flowCopy.BytesOut.Store(f.BytesOut.Load())
flowCopy.PacketsIn.Store(f.PacketsIn.Load())
flowCopy.PacketsOut.Store(f.PacketsOut.Load())
return &flowCopy
}
// MarshalJSON implements json.Marshaler interface
func (f *FlowStats) MarshalJSON() ([]byte, error) {
return json.Marshal(&struct {
StartTime time.Time
LastSeen time.Time
BytesIn uint64
BytesOut uint64
PacketsIn uint64
PacketsOut uint64
Protocol Protocol
Direction string
SourceIP net.IP
DestIP net.IP
SourcePort uint16
DestPort uint16
}{
StartTime: f.StartTime,
LastSeen: f.LastSeen,
BytesIn: f.BytesIn.Load(),
BytesOut: f.BytesOut.Load(),
PacketsIn: f.PacketsIn.Load(),
PacketsOut: f.PacketsOut.Load(),
Protocol: protoFromInt(f.Protocol),
Direction: f.Direction.String(),
SourceIP: f.SourceIP,
DestIP: f.DestIP,
SourcePort: f.SourcePort,
DestPort: f.DestPort,
})
}
func protoFromInt(p uint8) Protocol {
switch p {
case 6:
return ProtocolTCP
case 17:
return ProtocolUDP
case 1:
return ProtocolICMP
default:
return Protocol(strconv.Itoa(int(p)))
}
}

View File

@@ -5,18 +5,18 @@ import (
"encoding/binary" "encoding/binary"
"fmt" "fmt"
"net" "net"
"net/netip"
"strconv" "strconv"
"strings" "strings"
"time" "time"
"github.com/google/nftables" "github.com/google/nftables"
"github.com/google/nftables/binaryutil"
"github.com/google/nftables/expr" "github.com/google/nftables/expr"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"golang.org/x/sys/unix" "golang.org/x/sys/unix"
firewall "github.com/netbirdio/netbird/client/firewall/manager" firewall "github.com/netbirdio/netbird/client/firewall/manager"
nbnet "github.com/netbirdio/netbird/util/net" "github.com/netbirdio/netbird/iface"
) )
const ( const (
@@ -27,64 +27,74 @@ const (
// filter chains contains the rules that jump to the rules chains // filter chains contains the rules that jump to the rules chains
chainNameInputFilter = "netbird-acl-input-filter" chainNameInputFilter = "netbird-acl-input-filter"
chainNameOutputFilter = "netbird-acl-output-filter"
chainNameForwardFilter = "netbird-acl-forward-filter" chainNameForwardFilter = "netbird-acl-forward-filter"
chainNamePrerouting = "netbird-rt-prerouting"
allowNetbirdInputRuleID = "allow Netbird incoming traffic" allowNetbirdInputRuleID = "allow Netbird incoming traffic"
) )
const flushError = "flush: %w"
var ( var (
anyIP = []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0} anyIP = []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}
postroutingMark = []byte{0xe4, 0x7, 0x0, 0x00}
) )
type AclManager struct { type AclManager struct {
rConn *nftables.Conn rConn *nftables.Conn
sConn *nftables.Conn sConn *nftables.Conn
wgIface iFaceMapper wgIface iFaceMapper
routingFwChainName string routeingFwChainName string
workTable *nftables.Table workTable *nftables.Table
chainInputRules *nftables.Chain chainInputRules *nftables.Chain
chainOutputRules *nftables.Chain chainOutputRules *nftables.Chain
chainFwFilter *nftables.Chain
chainPrerouting *nftables.Chain
ipsetStore *ipsetStore ipsetStore *ipsetStore
rules map[string]*Rule rules map[string]*Rule
} }
func newAclManager(table *nftables.Table, wgIface iFaceMapper, routingFwChainName string) (*AclManager, error) { // iFaceMapper defines subset methods of interface required for manager
type iFaceMapper interface {
Name() string
Address() iface.WGAddress
IsUserspaceBind() bool
}
func newAclManager(table *nftables.Table, wgIface iFaceMapper, routeingFwChainName string) (*AclManager, error) {
// sConn is used for creating sets and adding/removing elements from them // sConn is used for creating sets and adding/removing elements from them
// it's differ then rConn (which does create new conn for each flush operation) // it's differ then rConn (which does create new conn for each flush operation)
// and is permanent. Using same connection for both type of operations // and is permanent. Using same connection for booth type of operations
// overloads netlink with high amount of rules ( > 10000) // overloads netlink with high amount of rules ( > 10000)
sConn, err := nftables.New(nftables.AsLasting()) sConn, err := nftables.New(nftables.AsLasting())
if err != nil { if err != nil {
return nil, fmt.Errorf("create nf conn: %w", err) return nil, err
} }
return &AclManager{ m := &AclManager{
rConn: &nftables.Conn{}, rConn: &nftables.Conn{},
sConn: sConn, sConn: sConn,
wgIface: wgIface, wgIface: wgIface,
workTable: table, workTable: table,
routingFwChainName: routingFwChainName, routeingFwChainName: routeingFwChainName,
ipsetStore: newIpsetStore(), ipsetStore: newIpsetStore(),
rules: make(map[string]*Rule), rules: make(map[string]*Rule),
}, nil
} }
func (m *AclManager) init(workTable *nftables.Table) error { err = m.createDefaultChains()
m.workTable = workTable if err != nil {
return m.createDefaultChains() return nil, err
} }
// AddPeerFiltering rule to the firewall return m, nil
}
// AddFiltering rule to the firewall
// //
// If comment argument is empty firewall manager should set // If comment argument is empty firewall manager should set
// rule ID as comment for the rule // rule ID as comment for the rule
func (m *AclManager) AddPeerFiltering( func (m *AclManager) AddFiltering(
ip net.IP, ip net.IP,
proto firewall.Protocol, proto firewall.Protocol,
sPort *firewall.Port, sPort *firewall.Port,
@@ -110,11 +120,20 @@ func (m *AclManager) AddPeerFiltering(
} }
newRules = append(newRules, ioRule) newRules = append(newRules, ioRule)
if !shouldAddToPrerouting(proto, dPort, direction) {
return newRules, nil return newRules, nil
} }
// DeletePeerRule from the firewall by rule definition preroutingRule, err := m.addPreroutingFiltering(ipset, proto, dPort, ip)
func (m *AclManager) DeletePeerRule(rule firewall.Rule) error { if err != nil {
return newRules, err
}
newRules = append(newRules, preroutingRule)
return newRules, nil
}
// DeleteRule from the firewall by rule definition
func (m *AclManager) DeleteRule(rule firewall.Rule) error {
r, ok := rule.(*Rule) r, ok := rule.(*Rule)
if !ok { if !ok {
return fmt.Errorf("invalid rule type") return fmt.Errorf("invalid rule type")
@@ -180,7 +199,8 @@ func (m *AclManager) DeletePeerRule(rule firewall.Rule) error {
return nil return nil
} }
// createDefaultAllowRules creates default allow rules for the input and output chains // createDefaultAllowRules In case if the USP firewall manager can use the native firewall manager we must to create allow rules for
// input and output chains
func (m *AclManager) createDefaultAllowRules() error { func (m *AclManager) createDefaultAllowRules() error {
expIn := []expr.Any{ expIn := []expr.Any{
&expr.Payload{ &expr.Payload{
@@ -194,13 +214,13 @@ func (m *AclManager) createDefaultAllowRules() error {
SourceRegister: 1, SourceRegister: 1,
DestRegister: 1, DestRegister: 1,
Len: 4, Len: 4,
Mask: []byte{0, 0, 0, 0}, Mask: []byte{0x00, 0x00, 0x00, 0x00},
Xor: []byte{0, 0, 0, 0}, Xor: zeroXor,
}, },
// net address // net address
&expr.Cmp{ &expr.Cmp{
Register: 1, Register: 1,
Data: []byte{0, 0, 0, 0}, Data: []byte{0x00, 0x00, 0x00, 0x00},
}, },
&expr.Verdict{ &expr.Verdict{
Kind: expr.VerdictAccept, Kind: expr.VerdictAccept,
@@ -226,13 +246,13 @@ func (m *AclManager) createDefaultAllowRules() error {
SourceRegister: 1, SourceRegister: 1,
DestRegister: 1, DestRegister: 1,
Len: 4, Len: 4,
Mask: []byte{0, 0, 0, 0}, Mask: []byte{0x00, 0x00, 0x00, 0x00},
Xor: []byte{0, 0, 0, 0}, Xor: zeroXor,
}, },
// net address // net address
&expr.Cmp{ &expr.Cmp{
Register: 1, Register: 1,
Data: []byte{0, 0, 0, 0}, Data: []byte{0x00, 0x00, 0x00, 0x00},
}, },
&expr.Verdict{ &expr.Verdict{
Kind: expr.VerdictAccept, Kind: expr.VerdictAccept,
@@ -246,8 +266,10 @@ func (m *AclManager) createDefaultAllowRules() error {
Exprs: expOut, Exprs: expOut,
}) })
if err := m.rConn.Flush(); err != nil { err := m.rConn.Flush()
return fmt.Errorf(flushError, err) if err != nil {
log.Debugf("failed to create default allow rules: %s", err)
return err
} }
return nil return nil
} }
@@ -268,11 +290,15 @@ func (m *AclManager) Flush() error {
log.Errorf("failed to refresh rule handles IPv4 output chain: %v", err) log.Errorf("failed to refresh rule handles IPv4 output chain: %v", err)
} }
if err := m.refreshRuleHandles(m.chainPrerouting); err != nil {
log.Errorf("failed to refresh rule handles IPv4 prerouting chain: %v", err)
}
return nil return nil
} }
func (m *AclManager) addIOFiltering(ip net.IP, proto firewall.Protocol, sPort *firewall.Port, dPort *firewall.Port, direction firewall.RuleDirection, action firewall.Action, ipset *nftables.Set, comment string) (*Rule, error) { func (m *AclManager) addIOFiltering(ip net.IP, proto firewall.Protocol, sPort *firewall.Port, dPort *firewall.Port, direction firewall.RuleDirection, action firewall.Action, ipset *nftables.Set, comment string) (*Rule, error) {
ruleId := generatePeerRuleId(ip, sPort, dPort, direction, action, ipset) ruleId := generateRuleId(ip, sPort, dPort, direction, action, ipset)
if r, ok := m.rules[ruleId]; ok { if r, ok := m.rules[ruleId]; ok {
return &Rule{ return &Rule{
r.nftRule, r.nftRule,
@@ -282,7 +308,18 @@ func (m *AclManager) addIOFiltering(ip net.IP, proto firewall.Protocol, sPort *f
}, nil }, nil
} }
var expressions []expr.Any ifaceKey := expr.MetaKeyIIFNAME
if direction == firewall.RuleDirectionOUT {
ifaceKey = expr.MetaKeyOIFNAME
}
expressions := []expr.Any{
&expr.Meta{Key: ifaceKey, Register: 1},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: ifname(m.wgIface.Name()),
},
}
if proto != firewall.ProtocolALL { if proto != firewall.ProtocolALL {
expressions = append(expressions, &expr.Payload{ expressions = append(expressions, &expr.Payload{
@@ -292,15 +329,21 @@ func (m *AclManager) addIOFiltering(ip net.IP, proto firewall.Protocol, sPort *f
Len: uint32(1), Len: uint32(1),
}) })
protoData, err := protoToInt(proto) var protoData []byte
if err != nil { switch proto {
return nil, fmt.Errorf("convert protocol to number: %v", err) case firewall.ProtocolTCP:
protoData = []byte{unix.IPPROTO_TCP}
case firewall.ProtocolUDP:
protoData = []byte{unix.IPPROTO_UDP}
case firewall.ProtocolICMP:
protoData = []byte{unix.IPPROTO_ICMP}
default:
return nil, fmt.Errorf("unsupported protocol: %s", proto)
} }
expressions = append(expressions, &expr.Cmp{ expressions = append(expressions, &expr.Cmp{
Register: 1, Register: 1,
Op: expr.CmpOpEq, Op: expr.CmpOpEq,
Data: []byte{protoData}, Data: protoData,
}) })
} }
@@ -389,9 +432,10 @@ func (m *AclManager) addIOFiltering(ip net.IP, proto firewall.Protocol, sPort *f
} else { } else {
chain = m.chainOutputRules chain = m.chainOutputRules
} }
nftRule := m.rConn.AddRule(&nftables.Rule{ nftRule := m.rConn.InsertRule(&nftables.Rule{
Table: m.workTable, Table: m.workTable,
Chain: chain, Chain: chain,
Position: 0,
Exprs: expressions, Exprs: expressions,
UserData: userData, UserData: userData,
}) })
@@ -409,13 +453,139 @@ func (m *AclManager) addIOFiltering(ip net.IP, proto firewall.Protocol, sPort *f
return rule, nil return rule, nil
} }
func (m *AclManager) addPreroutingFiltering(ipset *nftables.Set, proto firewall.Protocol, port *firewall.Port, ip net.IP) (*Rule, error) {
var protoData []byte
switch proto {
case firewall.ProtocolTCP:
protoData = []byte{unix.IPPROTO_TCP}
case firewall.ProtocolUDP:
protoData = []byte{unix.IPPROTO_UDP}
case firewall.ProtocolICMP:
protoData = []byte{unix.IPPROTO_ICMP}
default:
return nil, fmt.Errorf("unsupported protocol: %s", proto)
}
ruleId := generateRuleIdForMangle(ipset, ip, proto, port)
if r, ok := m.rules[ruleId]; ok {
return &Rule{
r.nftRule,
r.nftSet,
r.ruleID,
ip,
}, nil
}
var ipExpression expr.Any
// add individual IP for match if no ipset defined
rawIP := ip.To4()
if ipset == nil {
ipExpression = &expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: rawIP,
}
} else {
ipExpression = &expr.Lookup{
SourceRegister: 1,
SetName: ipset.Name,
SetID: ipset.ID,
}
}
expressions := []expr.Any{
&expr.Payload{
DestRegister: 1,
Base: expr.PayloadBaseNetworkHeader,
Offset: 12,
Len: 4,
},
ipExpression,
&expr.Payload{
DestRegister: 1,
Base: expr.PayloadBaseNetworkHeader,
Offset: 16,
Len: 4,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: m.wgIface.Address().IP.To4(),
},
&expr.Payload{
DestRegister: 1,
Base: expr.PayloadBaseNetworkHeader,
Offset: uint32(9),
Len: uint32(1),
},
&expr.Cmp{
Register: 1,
Op: expr.CmpOpEq,
Data: protoData,
},
}
if port != nil {
expressions = append(expressions,
&expr.Payload{
DestRegister: 1,
Base: expr.PayloadBaseTransportHeader,
Offset: 2,
Len: 2,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: encodePort(*port),
},
)
}
expressions = append(expressions,
&expr.Immediate{
Register: 1,
Data: postroutingMark,
},
&expr.Meta{
Key: expr.MetaKeyMARK,
SourceRegister: true,
Register: 1,
},
)
nftRule := m.rConn.InsertRule(&nftables.Rule{
Table: m.workTable,
Chain: m.chainPrerouting,
Position: 0,
Exprs: expressions,
UserData: []byte(ruleId),
})
if err := m.rConn.Flush(); err != nil {
return nil, fmt.Errorf("flush insert rule: %v", err)
}
rule := &Rule{
nftRule: nftRule,
nftSet: ipset,
ruleID: ruleId,
ip: ip,
}
m.rules[ruleId] = rule
if ipset != nil {
m.ipsetStore.AddReferenceToIpset(ipset.Name)
}
return rule, nil
}
func (m *AclManager) createDefaultChains() (err error) { func (m *AclManager) createDefaultChains() (err error) {
// chainNameInputRules // chainNameInputRules
chain := m.createChain(chainNameInputRules) chain := m.createChain(chainNameInputRules)
err = m.rConn.Flush() err = m.rConn.Flush()
if err != nil { if err != nil {
log.Debugf("failed to create chain (%s): %s", chain.Name, err) log.Debugf("failed to create chain (%s): %s", chain.Name, err)
return fmt.Errorf(flushError, err) return err
} }
m.chainInputRules = chain m.chainInputRules = chain
@@ -431,6 +601,9 @@ func (m *AclManager) createDefaultChains() (err error) {
// netbird-acl-input-filter // netbird-acl-input-filter
// type filter hook input priority filter; policy accept; // type filter hook input priority filter; policy accept;
chain = m.createFilterChainWithHook(chainNameInputFilter, nftables.ChainHookInput) chain = m.createFilterChainWithHook(chainNameInputFilter, nftables.ChainHookInput)
//netbird-acl-input-filter iifname "wt0" ip saddr 100.72.0.0/16 ip daddr != 100.72.0.0/16 accept
m.addRouteAllowRule(chain, expr.MetaKeyIIFNAME)
m.addFwdAllow(chain, expr.MetaKeyIIFNAME)
m.addJumpRule(chain, m.chainInputRules.Name, expr.MetaKeyIIFNAME) // to netbird-acl-input-rules m.addJumpRule(chain, m.chainInputRules.Name, expr.MetaKeyIIFNAME) // to netbird-acl-input-rules
m.addDropExpressions(chain, expr.MetaKeyIIFNAME) m.addDropExpressions(chain, expr.MetaKeyIIFNAME)
err = m.rConn.Flush() err = m.rConn.Flush()
@@ -439,107 +612,43 @@ func (m *AclManager) createDefaultChains() (err error) {
return err return err
} }
// netbird-acl-forward-filter // netbird-acl-output-filter
chainFwFilter := m.createFilterChainWithHook(chainNameForwardFilter, nftables.ChainHookForward) // type filter hook output priority filter; policy accept;
m.addJumpRulesToRtForward(chainFwFilter) // to netbird-rt-fwd chain = m.createFilterChainWithHook(chainNameOutputFilter, nftables.ChainHookOutput)
m.addDropExpressions(chainFwFilter, expr.MetaKeyIIFNAME) m.addRouteAllowRule(chain, expr.MetaKeyOIFNAME)
m.addFwdAllow(chain, expr.MetaKeyOIFNAME)
m.addJumpRule(chain, m.chainOutputRules.Name, expr.MetaKeyOIFNAME) // to netbird-acl-output-rules
m.addDropExpressions(chain, expr.MetaKeyOIFNAME)
err = m.rConn.Flush()
if err != nil {
log.Debugf("failed to create chain (%s): %s", chainNameOutputFilter, err)
return err
}
// netbird-acl-forward-filter
m.chainFwFilter = m.createFilterChainWithHook(chainNameForwardFilter, nftables.ChainHookForward)
m.addJumpRulesToRtForward() // to
m.addMarkAccept()
m.addJumpRuleToInputChain() // to netbird-acl-input-rules
m.addDropExpressions(m.chainFwFilter, expr.MetaKeyIIFNAME)
err = m.rConn.Flush() err = m.rConn.Flush()
if err != nil { if err != nil {
log.Debugf("failed to create chain (%s): %s", chainNameForwardFilter, err) log.Debugf("failed to create chain (%s): %s", chainNameForwardFilter, err)
return fmt.Errorf(flushError, err) return err
} }
if err := m.allowRedirectedTraffic(chainFwFilter); err != nil { // netbird-acl-output-filter
log.Errorf("failed to allow redirected traffic: %s", err) // type filter hook output priority filter; policy accept;
m.chainPrerouting = m.createPreroutingMangle()
err = m.rConn.Flush()
if err != nil {
log.Debugf("failed to create chain (%s): %s", m.chainPrerouting.Name, err)
return err
} }
return nil return nil
} }
// Makes redirected traffic originally destined for the host itself (now subject to the forward filter) func (m *AclManager) addJumpRulesToRtForward() {
// go through the input filter as well. This will enable e.g. Docker services to keep working by accessing the
// netbird peer IP.
func (m *AclManager) allowRedirectedTraffic(chainFwFilter *nftables.Chain) error {
preroutingChain := m.rConn.AddChain(&nftables.Chain{
Name: chainNamePrerouting,
Table: m.workTable,
Type: nftables.ChainTypeFilter,
Hooknum: nftables.ChainHookPrerouting,
Priority: nftables.ChainPriorityMangle,
})
m.addPreroutingRule(preroutingChain)
m.addFwmarkToForward(chainFwFilter)
if err := m.rConn.Flush(); err != nil {
return fmt.Errorf(flushError, err)
}
return nil
}
func (m *AclManager) addPreroutingRule(preroutingChain *nftables.Chain) {
m.rConn.AddRule(&nftables.Rule{
Table: m.workTable,
Chain: preroutingChain,
Exprs: []expr.Any{
&expr.Meta{
Key: expr.MetaKeyIIFNAME,
Register: 1,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: ifname(m.wgIface.Name()),
},
&expr.Fib{
Register: 1,
ResultADDRTYPE: true,
FlagDADDR: true,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: binaryutil.NativeEndian.PutUint32(unix.RTN_LOCAL),
},
&expr.Immediate{
Register: 1,
Data: binaryutil.NativeEndian.PutUint32(nbnet.PreroutingFwmarkRedirected),
},
&expr.Meta{
Key: expr.MetaKeyMARK,
Register: 1,
SourceRegister: true,
},
},
})
}
func (m *AclManager) addFwmarkToForward(chainFwFilter *nftables.Chain) {
m.rConn.InsertRule(&nftables.Rule{
Table: m.workTable,
Chain: chainFwFilter,
Exprs: []expr.Any{
&expr.Meta{
Key: expr.MetaKeyMARK,
Register: 1,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: binaryutil.NativeEndian.PutUint32(nbnet.PreroutingFwmarkRedirected),
},
&expr.Verdict{
Kind: expr.VerdictJump,
Chain: m.chainInputRules.Name,
},
},
})
}
func (m *AclManager) addJumpRulesToRtForward(chainFwFilter *nftables.Chain) {
expressions := []expr.Any{ expressions := []expr.Any{
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1}, &expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
&expr.Cmp{ &expr.Cmp{
@@ -549,15 +658,68 @@ func (m *AclManager) addJumpRulesToRtForward(chainFwFilter *nftables.Chain) {
}, },
&expr.Verdict{ &expr.Verdict{
Kind: expr.VerdictJump, Kind: expr.VerdictJump,
Chain: m.routingFwChainName, Chain: m.routeingFwChainName,
}, },
} }
_ = m.rConn.AddRule(&nftables.Rule{ _ = m.rConn.AddRule(&nftables.Rule{
Table: m.workTable, Table: m.workTable,
Chain: chainFwFilter, Chain: m.chainFwFilter,
Exprs: expressions, Exprs: expressions,
}) })
expressions = []expr.Any{
&expr.Meta{Key: expr.MetaKeyOIFNAME, Register: 1},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: ifname(m.wgIface.Name()),
},
&expr.Verdict{
Kind: expr.VerdictJump,
Chain: m.routeingFwChainName,
},
}
_ = m.rConn.AddRule(&nftables.Rule{
Table: m.workTable,
Chain: m.chainFwFilter,
Exprs: expressions,
})
}
func (m *AclManager) addMarkAccept() {
// oifname "wt0" meta mark 0x000007e4 accept
// iifname "wt0" meta mark 0x000007e4 accept
ifaces := []expr.MetaKey{expr.MetaKeyIIFNAME, expr.MetaKeyOIFNAME}
for _, iface := range ifaces {
expressions := []expr.Any{
&expr.Meta{Key: iface, Register: 1},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: ifname(m.wgIface.Name()),
},
&expr.Meta{
Key: expr.MetaKeyMARK,
Register: 1,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: postroutingMark,
},
&expr.Verdict{
Kind: expr.VerdictAccept,
},
}
_ = m.rConn.AddRule(&nftables.Rule{
Table: m.workTable,
Chain: m.chainFwFilter,
Exprs: expressions,
})
}
} }
func (m *AclManager) createChain(name string) *nftables.Chain { func (m *AclManager) createChain(name string) *nftables.Chain {
@@ -567,13 +729,10 @@ func (m *AclManager) createChain(name string) *nftables.Chain {
} }
chain = m.rConn.AddChain(chain) chain = m.rConn.AddChain(chain)
insertReturnTrafficRule(m.rConn, m.workTable, chain)
return chain return chain
} }
func (m *AclManager) createFilterChainWithHook(name string, hookNum *nftables.ChainHook) *nftables.Chain { func (m *AclManager) createFilterChainWithHook(name string, hookNum nftables.ChainHook) *nftables.Chain {
polAccept := nftables.ChainPolicyAccept polAccept := nftables.ChainPolicyAccept
chain := &nftables.Chain{ chain := &nftables.Chain{
Name: name, Name: name,
@@ -587,6 +746,74 @@ func (m *AclManager) createFilterChainWithHook(name string, hookNum *nftables.Ch
return m.rConn.AddChain(chain) return m.rConn.AddChain(chain)
} }
func (m *AclManager) createPreroutingMangle() *nftables.Chain {
polAccept := nftables.ChainPolicyAccept
chain := &nftables.Chain{
Name: "netbird-acl-prerouting-filter",
Table: m.workTable,
Hooknum: nftables.ChainHookPrerouting,
Priority: nftables.ChainPriorityMangle,
Type: nftables.ChainTypeFilter,
Policy: &polAccept,
}
chain = m.rConn.AddChain(chain)
ip, _ := netip.AddrFromSlice(m.wgIface.Address().Network.IP.To4())
expressions := []expr.Any{
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: ifname(m.wgIface.Name()),
},
&expr.Payload{
DestRegister: 2,
Base: expr.PayloadBaseNetworkHeader,
Offset: 12,
Len: 4,
},
&expr.Bitwise{
SourceRegister: 2,
DestRegister: 2,
Len: 4,
Xor: []byte{0x0, 0x0, 0x0, 0x0},
Mask: m.wgIface.Address().Network.Mask,
},
&expr.Cmp{
Op: expr.CmpOpNeq,
Register: 2,
Data: ip.Unmap().AsSlice(),
},
&expr.Payload{
DestRegister: 1,
Base: expr.PayloadBaseNetworkHeader,
Offset: 16,
Len: 4,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: m.wgIface.Address().IP.To4(),
},
&expr.Immediate{
Register: 1,
Data: postroutingMark,
},
&expr.Meta{
Key: expr.MetaKeyMARK,
SourceRegister: true,
Register: 1,
},
}
_ = m.rConn.AddRule(&nftables.Rule{
Table: m.workTable,
Chain: chain,
Exprs: expressions,
})
return chain
}
func (m *AclManager) addDropExpressions(chain *nftables.Chain, ifaceKey expr.MetaKey) []expr.Any { func (m *AclManager) addDropExpressions(chain *nftables.Chain, ifaceKey expr.MetaKey) []expr.Any {
expressions := []expr.Any{ expressions := []expr.Any{
&expr.Meta{Key: ifaceKey, Register: 1}, &expr.Meta{Key: ifaceKey, Register: 1},
@@ -605,9 +832,9 @@ func (m *AclManager) addDropExpressions(chain *nftables.Chain, ifaceKey expr.Met
return nil return nil
} }
func (m *AclManager) addJumpRule(chain *nftables.Chain, to string, ifaceKey expr.MetaKey) { func (m *AclManager) addJumpRuleToInputChain() {
expressions := []expr.Any{ expressions := []expr.Any{
&expr.Meta{Key: ifaceKey, Register: 1}, &expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
&expr.Cmp{ &expr.Cmp{
Op: expr.CmpOpEq, Op: expr.CmpOpEq,
Register: 1, Register: 1,
@@ -615,10 +842,195 @@ func (m *AclManager) addJumpRule(chain *nftables.Chain, to string, ifaceKey expr
}, },
&expr.Verdict{ &expr.Verdict{
Kind: expr.VerdictJump, Kind: expr.VerdictJump,
Chain: to, Chain: m.chainInputRules.Name,
}, },
} }
_ = m.rConn.AddRule(&nftables.Rule{
Table: m.workTable,
Chain: m.chainFwFilter,
Exprs: expressions,
})
}
func (m *AclManager) addRouteAllowRule(chain *nftables.Chain, netIfName expr.MetaKey) {
ip, _ := netip.AddrFromSlice(m.wgIface.Address().Network.IP.To4())
var srcOp, dstOp expr.CmpOp
if netIfName == expr.MetaKeyIIFNAME {
srcOp = expr.CmpOpEq
dstOp = expr.CmpOpNeq
} else {
srcOp = expr.CmpOpNeq
dstOp = expr.CmpOpEq
}
expressions := []expr.Any{
&expr.Meta{Key: netIfName, Register: 1},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: ifname(m.wgIface.Name()),
},
&expr.Payload{
DestRegister: 2,
Base: expr.PayloadBaseNetworkHeader,
Offset: 12,
Len: 4,
},
&expr.Bitwise{
SourceRegister: 2,
DestRegister: 2,
Len: 4,
Xor: []byte{0x0, 0x0, 0x0, 0x0},
Mask: m.wgIface.Address().Network.Mask,
},
&expr.Cmp{
Op: srcOp,
Register: 2,
Data: ip.Unmap().AsSlice(),
},
&expr.Payload{
DestRegister: 2,
Base: expr.PayloadBaseNetworkHeader,
Offset: 16,
Len: 4,
},
&expr.Bitwise{
SourceRegister: 2,
DestRegister: 2,
Len: 4,
Xor: []byte{0x0, 0x0, 0x0, 0x0},
Mask: m.wgIface.Address().Network.Mask,
},
&expr.Cmp{
Op: dstOp,
Register: 2,
Data: ip.Unmap().AsSlice(),
},
&expr.Verdict{
Kind: expr.VerdictAccept,
},
}
_ = m.rConn.AddRule(&nftables.Rule{
Table: chain.Table,
Chain: chain,
Exprs: expressions,
})
}
func (m *AclManager) addFwdAllow(chain *nftables.Chain, iifname expr.MetaKey) {
ip, _ := netip.AddrFromSlice(m.wgIface.Address().Network.IP.To4())
var srcOp, dstOp expr.CmpOp
if iifname == expr.MetaKeyIIFNAME {
srcOp = expr.CmpOpNeq
dstOp = expr.CmpOpEq
} else {
srcOp = expr.CmpOpEq
dstOp = expr.CmpOpNeq
}
expressions := []expr.Any{
&expr.Meta{Key: iifname, Register: 1},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: ifname(m.wgIface.Name()),
},
&expr.Payload{
DestRegister: 2,
Base: expr.PayloadBaseNetworkHeader,
Offset: 12,
Len: 4,
},
&expr.Bitwise{
SourceRegister: 2,
DestRegister: 2,
Len: 4,
Xor: []byte{0x0, 0x0, 0x0, 0x0},
Mask: m.wgIface.Address().Network.Mask,
},
&expr.Cmp{
Op: srcOp,
Register: 2,
Data: ip.Unmap().AsSlice(),
},
&expr.Payload{
DestRegister: 2,
Base: expr.PayloadBaseNetworkHeader,
Offset: 16,
Len: 4,
},
&expr.Bitwise{
SourceRegister: 2,
DestRegister: 2,
Len: 4,
Xor: []byte{0x0, 0x0, 0x0, 0x0},
Mask: m.wgIface.Address().Network.Mask,
},
&expr.Cmp{
Op: dstOp,
Register: 2,
Data: ip.Unmap().AsSlice(),
},
&expr.Verdict{
Kind: expr.VerdictAccept,
},
}
_ = m.rConn.AddRule(&nftables.Rule{
Table: chain.Table,
Chain: chain,
Exprs: expressions,
})
}
func (m *AclManager) addJumpRule(chain *nftables.Chain, to string, ifaceKey expr.MetaKey) {
ip, _ := netip.AddrFromSlice(m.wgIface.Address().Network.IP.To4())
expressions := []expr.Any{
&expr.Meta{Key: ifaceKey, Register: 1},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: ifname(m.wgIface.Name()),
},
&expr.Payload{
DestRegister: 2,
Base: expr.PayloadBaseNetworkHeader,
Offset: 12,
Len: 4,
},
&expr.Bitwise{
SourceRegister: 2,
DestRegister: 2,
Len: 4,
Xor: []byte{0x0, 0x0, 0x0, 0x0},
Mask: m.wgIface.Address().Network.Mask,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 2,
Data: ip.Unmap().AsSlice(),
},
&expr.Payload{
DestRegister: 2,
Base: expr.PayloadBaseNetworkHeader,
Offset: 16,
Len: 4,
},
&expr.Bitwise{
SourceRegister: 2,
DestRegister: 2,
Len: 4,
Xor: []byte{0x0, 0x0, 0x0, 0x0},
Mask: m.wgIface.Address().Network.Mask,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 2,
Data: ip.Unmap().AsSlice(),
},
&expr.Verdict{
Kind: expr.VerdictJump,
Chain: to,
},
}
_ = m.rConn.AddRule(&nftables.Rule{ _ = m.rConn.AddRule(&nftables.Rule{
Table: chain.Table, Table: chain.Table,
Chain: chain, Chain: chain,
@@ -720,7 +1132,7 @@ func (m *AclManager) refreshRuleHandles(chain *nftables.Chain) error {
return nil return nil
} }
func generatePeerRuleId( func generateRuleId(
ip net.IP, ip net.IP,
sPort *firewall.Port, sPort *firewall.Port,
dPort *firewall.Port, dPort *firewall.Port,
@@ -743,6 +1155,33 @@ func generatePeerRuleId(
} }
return "set:" + ipset.Name + rulesetID return "set:" + ipset.Name + rulesetID
} }
func generateRuleIdForMangle(ipset *nftables.Set, ip net.IP, proto firewall.Protocol, port *firewall.Port) string {
// case of icmp port is empty
var p string
if port != nil {
p = port.String()
}
if ipset != nil {
return fmt.Sprintf("p:set:%s:%s:%v", ipset.Name, proto, p)
} else {
return fmt.Sprintf("p:ip:%s:%s:%v", ip.String(), proto, p)
}
}
func shouldAddToPrerouting(proto firewall.Protocol, dPort *firewall.Port, direction firewall.RuleDirection) bool {
if proto == "all" {
return false
}
if direction != firewall.RuleDirectionIN {
return false
}
if dPort == nil && proto != firewall.ProtocolICMP {
return false
}
return true
}
func encodePort(port firewall.Port) []byte { func encodePort(port firewall.Port) []byte {
bs := make([]byte, 2) bs := make([]byte, 2)
@@ -752,19 +1191,6 @@ func encodePort(port firewall.Port) []byte {
func ifname(n string) []byte { func ifname(n string) []byte {
b := make([]byte, 16) b := make([]byte, 16)
copy(b, n+"\x00") copy(b, []byte(n+"\x00"))
return b return b
} }
func protoToInt(protocol firewall.Protocol) (uint8, error) {
switch protocol {
case firewall.ProtocolTCP:
return unix.IPPROTO_TCP, nil
case firewall.ProtocolUDP:
return unix.IPPROTO_UDP, nil
case firewall.ProtocolICMP:
return unix.IPPROTO_ICMP, nil
}
return 0, fmt.Errorf("unsupported protocol: %s", protocol)
}

View File

@@ -5,34 +5,20 @@ import (
"context" "context"
"fmt" "fmt"
"net" "net"
"net/netip"
"sync" "sync"
"github.com/google/nftables" "github.com/google/nftables"
"github.com/google/nftables/binaryutil"
"github.com/google/nftables/expr" "github.com/google/nftables/expr"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
firewall "github.com/netbirdio/netbird/client/firewall/manager" firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/internal/statemanager"
) )
const ( const (
// tableNameNetbird is the name of the table that is used for filtering by the Netbird client // tableName is the name of the table that is used for filtering by the Netbird client
tableNameNetbird = "netbird" tableName = "netbird"
tableNameFilter = "filter"
chainNameInput = "INPUT"
) )
// iFaceMapper defines subset methods of interface required for manager
type iFaceMapper interface {
Name() string
Address() iface.WGAddress
IsUserspaceBind() bool
}
// Manager of iptables firewall // Manager of iptables firewall
type Manager struct { type Manager struct {
mutex sync.Mutex mutex sync.Mutex
@@ -44,75 +30,35 @@ type Manager struct {
} }
// Create nftables firewall manager // Create nftables firewall manager
func Create(wgIface iFaceMapper) (*Manager, error) { func Create(context context.Context, wgIface iFaceMapper) (*Manager, error) {
m := &Manager{ m := &Manager{
rConn: &nftables.Conn{}, rConn: &nftables.Conn{},
wgIface: wgIface, wgIface: wgIface,
} }
workTable := &nftables.Table{Name: tableNameNetbird, Family: nftables.TableFamilyIPv4} workTable, err := m.createWorkTable()
var err error
m.router, err = newRouter(workTable, wgIface)
if err != nil { if err != nil {
return nil, fmt.Errorf("create router: %w", err) return nil, err
} }
m.aclManager, err = newAclManager(workTable, wgIface, chainNameRoutingFw) m.router, err = newRouter(context, workTable)
if err != nil { if err != nil {
return nil, fmt.Errorf("create acl manager: %w", err) return nil, err
}
m.aclManager, err = newAclManager(workTable, wgIface, m.router.RouteingFwChainName())
if err != nil {
return nil, err
} }
return m, nil return m, nil
} }
// Init nftables firewall manager // AddFiltering rule to the firewall
func (m *Manager) Init(stateManager *statemanager.Manager) error {
workTable, err := m.createWorkTable()
if err != nil {
return fmt.Errorf("create work table: %w", err)
}
if err := m.router.init(workTable); err != nil {
return fmt.Errorf("router init: %w", err)
}
if err := m.aclManager.init(workTable); err != nil {
// TODO: cleanup router
return fmt.Errorf("acl manager init: %w", err)
}
stateManager.RegisterState(&ShutdownState{})
// We only need to record minimal interface state for potential recreation.
// Unlike iptables, which requires tracking individual rules, nftables maintains
// a known state (our netbird table plus a few static rules). This allows for easy
// cleanup using Reset() without needing to store specific rules.
if err := stateManager.UpdateState(&ShutdownState{
InterfaceState: &InterfaceState{
NameStr: m.wgIface.Name(),
WGAddress: m.wgIface.Address(),
UserspaceBind: m.wgIface.IsUserspaceBind(),
},
}); err != nil {
log.Errorf("failed to update state: %v", err)
}
// persist early
go func() {
if err := stateManager.PersistState(context.Background()); err != nil {
log.Errorf("failed to persist state: %v", err)
}
}()
return nil
}
// AddPeerFiltering rule to the firewall
// //
// If comment argument is empty firewall manager should set // If comment argument is empty firewall manager should set
// rule ID as comment for the rule // rule ID as comment for the rule
func (m *Manager) AddPeerFiltering( func (m *Manager) AddFiltering(
ip net.IP, ip net.IP,
proto firewall.Protocol, proto firewall.Protocol,
sPort *firewall.Port, sPort *firewall.Port,
@@ -130,52 +76,33 @@ func (m *Manager) AddPeerFiltering(
return nil, fmt.Errorf("unsupported IP version: %s", ip.String()) return nil, fmt.Errorf("unsupported IP version: %s", ip.String())
} }
return m.aclManager.AddPeerFiltering(ip, proto, sPort, dPort, direction, action, ipsetName, comment) return m.aclManager.AddFiltering(ip, proto, sPort, dPort, direction, action, ipsetName, comment)
} }
func (m *Manager) AddRouteFiltering(sources []netip.Prefix, destination netip.Prefix, proto firewall.Protocol, sPort *firewall.Port, dPort *firewall.Port, action firewall.Action) (firewall.Rule, error) { // DeleteRule from the firewall by rule definition
func (m *Manager) DeleteRule(rule firewall.Rule) error {
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
if !destination.Addr().Is4() { return m.aclManager.DeleteRule(rule)
return nil, fmt.Errorf("unsupported IP version: %s", destination.Addr().String())
}
return m.router.AddRouteFiltering(sources, destination, proto, sPort, dPort, action)
}
// DeletePeerRule from the firewall by rule definition
func (m *Manager) DeletePeerRule(rule firewall.Rule) error {
m.mutex.Lock()
defer m.mutex.Unlock()
return m.aclManager.DeletePeerRule(rule)
}
// DeleteRouteRule deletes a routing rule
func (m *Manager) DeleteRouteRule(rule firewall.Rule) error {
m.mutex.Lock()
defer m.mutex.Unlock()
return m.router.DeleteRouteRule(rule)
} }
func (m *Manager) IsServerRouteSupported() bool { func (m *Manager) IsServerRouteSupported() bool {
return true return true
} }
func (m *Manager) AddNatRule(pair firewall.RouterPair) error { func (m *Manager) InsertRoutingRules(pair firewall.RouterPair) error {
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
return m.router.AddNatRule(pair) return m.router.AddRoutingRules(pair)
} }
func (m *Manager) RemoveNatRule(pair firewall.RouterPair) error { func (m *Manager) RemoveRoutingRules(pair firewall.RouterPair) error {
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
return m.router.RemoveNatRule(pair) return m.router.RemoveRoutingRules(pair)
} }
// AllowNetbird allows netbird interface traffic // AllowNetbird allows netbird interface traffic
@@ -199,7 +126,7 @@ func (m *Manager) AllowNetbird() error {
var chain *nftables.Chain var chain *nftables.Chain
for _, c := range chains { for _, c := range chains {
if c.Table.Name == tableNameFilter && c.Name == chainNameInput { if c.Table.Name == "filter" && c.Name == "INPUT" {
chain = c chain = c
break break
} }
@@ -230,65 +157,24 @@ func (m *Manager) AllowNetbird() error {
return nil return nil
} }
// SetLegacyManagement sets the route manager to use legacy management
func (m *Manager) SetLegacyManagement(isLegacy bool) error {
return firewall.SetLegacyManagement(m.router, isLegacy)
}
// Reset firewall to the default state // Reset firewall to the default state
func (m *Manager) Reset(stateManager *statemanager.Manager) error { func (m *Manager) Reset() error {
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
if err := m.resetNetbirdInputRules(); err != nil {
return fmt.Errorf("reset netbird input rules: %v", err)
}
if err := m.router.Reset(); err != nil {
return fmt.Errorf("reset router: %v", err)
}
if err := m.cleanupNetbirdTables(); err != nil {
return fmt.Errorf("cleanup netbird tables: %v", err)
}
if err := m.rConn.Flush(); err != nil {
return fmt.Errorf(flushError, err)
}
if err := stateManager.DeleteState(&ShutdownState{}); err != nil {
return fmt.Errorf("delete state: %v", err)
}
return nil
}
func (m *Manager) resetNetbirdInputRules() error {
chains, err := m.rConn.ListChains() chains, err := m.rConn.ListChains()
if err != nil { if err != nil {
return fmt.Errorf("list chains: %w", err) return fmt.Errorf("list of chains: %w", err)
} }
m.deleteNetbirdInputRules(chains)
return nil
}
func (m *Manager) deleteNetbirdInputRules(chains []*nftables.Chain) {
for _, c := range chains { for _, c := range chains {
if c.Table.Name == tableNameFilter && c.Name == chainNameInput { // delete Netbird allow input traffic rule if it exists
if c.Table.Name == "filter" && c.Name == "INPUT" {
rules, err := m.rConn.GetRules(c.Table, c) rules, err := m.rConn.GetRules(c.Table, c)
if err != nil { if err != nil {
log.Errorf("get rules for chain %q: %v", c.Name, err) log.Errorf("get rules for chain %q: %v", c.Name, err)
continue continue
} }
m.deleteMatchingRules(rules)
}
}
}
func (m *Manager) deleteMatchingRules(rules []*nftables.Rule) {
for _, r := range rules { for _, r := range rules {
if bytes.Equal(r.UserData, []byte(allowNetbirdInputRuleID)) { if bytes.Equal(r.UserData, []byte(allowNetbirdInputRuleID)) {
if err := m.rConn.DelRule(r); err != nil { if err := m.rConn.DelRule(r); err != nil {
@@ -297,19 +183,21 @@ func (m *Manager) deleteMatchingRules(rules []*nftables.Rule) {
} }
} }
} }
func (m *Manager) cleanupNetbirdTables() error {
tables, err := m.rConn.ListTables()
if err != nil {
return fmt.Errorf("list tables: %w", err)
} }
m.router.ResetForwardRules()
tables, err := m.rConn.ListTables()
if err != nil {
return fmt.Errorf("list of tables: %w", err)
}
for _, t := range tables { for _, t := range tables {
if t.Name == tableNameNetbird { if t.Name == tableName {
m.rConn.DelTable(t) m.rConn.DelTable(t)
} }
} }
return nil
return m.rConn.Flush()
} }
// Flush rule/chain/set operations from the buffer // Flush rule/chain/set operations from the buffer
@@ -323,11 +211,6 @@ func (m *Manager) Flush() error {
return m.aclManager.Flush() return m.aclManager.Flush()
} }
// CollectStats returns connection tracking statistics
func (m *Manager) CollectStats() []*firewall.FlowStats {
return nil
}
func (m *Manager) createWorkTable() (*nftables.Table, error) { func (m *Manager) createWorkTable() (*nftables.Table, error) {
tables, err := m.rConn.ListTablesOfFamily(nftables.TableFamilyIPv4) tables, err := m.rConn.ListTablesOfFamily(nftables.TableFamilyIPv4)
if err != nil { if err != nil {
@@ -335,12 +218,12 @@ func (m *Manager) createWorkTable() (*nftables.Table, error) {
} }
for _, t := range tables { for _, t := range tables {
if t.Name == tableNameNetbird { if t.Name == tableName {
m.rConn.DelTable(t) m.rConn.DelTable(t)
} }
} }
table := m.rConn.AddTable(&nftables.Table{Name: tableNameNetbird, Family: nftables.TableFamilyIPv4}) table := m.rConn.AddTable(&nftables.Table{Name: tableName, Family: nftables.TableFamilyIPv4})
err = m.rConn.Flush() err = m.rConn.Flush()
return table, err return table, err
} }
@@ -368,7 +251,7 @@ func (m *Manager) applyAllowNetbirdRules(chain *nftables.Chain) {
func (m *Manager) detectAllowNetbirdRule(existedRules []*nftables.Rule) *nftables.Rule { func (m *Manager) detectAllowNetbirdRule(existedRules []*nftables.Rule) *nftables.Rule {
ifName := ifname(m.wgIface.Name()) ifName := ifname(m.wgIface.Name())
for _, rule := range existedRules { for _, rule := range existedRules {
if rule.Table.Name == tableNameFilter && rule.Chain.Name == chainNameInput { if rule.Table.Name == "filter" && rule.Chain.Name == "INPUT" {
if len(rule.Exprs) < 4 { if len(rule.Exprs) < 4 {
if e, ok := rule.Exprs[0].(*expr.Meta); !ok || e.Key != expr.MetaKeyIIFNAME { if e, ok := rule.Exprs[0].(*expr.Meta); !ok || e.Key != expr.MetaKeyIIFNAME {
continue continue
@@ -382,38 +265,3 @@ func (m *Manager) detectAllowNetbirdRule(existedRules []*nftables.Rule) *nftable
} }
return nil return nil
} }
func insertReturnTrafficRule(conn *nftables.Conn, table *nftables.Table, chain *nftables.Chain) {
rule := &nftables.Rule{
Table: table,
Chain: chain,
Exprs: getEstablishedExprs(1),
}
conn.InsertRule(rule)
}
func getEstablishedExprs(register uint32) []expr.Any {
return []expr.Any{
&expr.Ct{
Key: expr.CtKeySTATE,
Register: register,
},
&expr.Bitwise{
SourceRegister: register,
DestRegister: register,
Len: 4,
Mask: binaryutil.NativeEndian.PutUint32(expr.CtStateBitESTABLISHED | expr.CtStateBitRELATED),
Xor: binaryutil.NativeEndian.PutUint32(0),
},
&expr.Cmp{
Op: expr.CmpOpNeq,
Register: register,
Data: []byte{0, 0, 0, 0},
},
&expr.Counter{},
&expr.Verdict{
Kind: expr.VerdictAccept,
},
}
}

View File

@@ -1,39 +1,22 @@
package nftables package nftables
import ( import (
"bytes" "context"
"fmt" "fmt"
"net" "net"
"net/netip" "net/netip"
"os/exec"
"testing" "testing"
"time" "time"
"github.com/google/nftables" "github.com/google/nftables"
"github.com/google/nftables/binaryutil"
"github.com/google/nftables/expr" "github.com/google/nftables/expr"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"golang.org/x/sys/unix" "golang.org/x/sys/unix"
fw "github.com/netbirdio/netbird/client/firewall/manager" fw "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/iface"
) )
var ifaceMock = &iFaceMock{
NameFunc: func() string {
return "lo"
},
AddressFunc: func() iface.WGAddress {
return iface.WGAddress{
IP: net.ParseIP("100.96.0.1"),
Network: &net.IPNet{
IP: net.ParseIP("100.96.0.0"),
Mask: net.IPv4Mask(255, 255, 255, 0),
},
}
},
}
// iFaceMapper defines subset methods of interface required for manager // iFaceMapper defines subset methods of interface required for manager
type iFaceMock struct { type iFaceMock struct {
NameFunc func() string NameFunc func() string
@@ -57,15 +40,28 @@ func (i *iFaceMock) Address() iface.WGAddress {
func (i *iFaceMock) IsUserspaceBind() bool { return false } func (i *iFaceMock) IsUserspaceBind() bool { return false }
func TestNftablesManager(t *testing.T) { func TestNftablesManager(t *testing.T) {
mock := &iFaceMock{
NameFunc: func() string {
return "lo"
},
AddressFunc: func() iface.WGAddress {
return iface.WGAddress{
IP: net.ParseIP("100.96.0.1"),
Network: &net.IPNet{
IP: net.ParseIP("100.96.0.0"),
Mask: net.IPv4Mask(255, 255, 255, 0),
},
}
},
}
// just check on the local interface // just check on the local interface
manager, err := Create(ifaceMock) manager, err := Create(context.Background(), mock)
require.NoError(t, err) require.NoError(t, err)
require.NoError(t, manager.Init(nil))
time.Sleep(time.Second * 3) time.Sleep(time.Second * 3)
defer func() { defer func() {
err = manager.Reset(nil) err = manager.Reset()
require.NoError(t, err, "failed to reset") require.NoError(t, err, "failed to reset")
time.Sleep(time.Second) time.Sleep(time.Second)
}() }()
@@ -74,7 +70,7 @@ func TestNftablesManager(t *testing.T) {
testClient := &nftables.Conn{} testClient := &nftables.Conn{}
rule, err := manager.AddPeerFiltering( rule, err := manager.AddFiltering(
ip, ip,
fw.ProtocolTCP, fw.ProtocolTCP,
nil, nil,
@@ -92,35 +88,17 @@ func TestNftablesManager(t *testing.T) {
rules, err := testClient.GetRules(manager.aclManager.workTable, manager.aclManager.chainInputRules) rules, err := testClient.GetRules(manager.aclManager.workTable, manager.aclManager.chainInputRules)
require.NoError(t, err, "failed to get rules") require.NoError(t, err, "failed to get rules")
require.Len(t, rules, 2, "expected 2 rules") require.Len(t, rules, 1, "expected 1 rules")
expectedExprs1 := []expr.Any{
&expr.Ct{
Key: expr.CtKeySTATE,
Register: 1,
},
&expr.Bitwise{
SourceRegister: 1,
DestRegister: 1,
Len: 4,
Mask: binaryutil.NativeEndian.PutUint32(expr.CtStateBitESTABLISHED | expr.CtStateBitRELATED),
Xor: binaryutil.NativeEndian.PutUint32(0),
},
&expr.Cmp{
Op: expr.CmpOpNeq,
Register: 1,
Data: []byte{0, 0, 0, 0},
},
&expr.Counter{},
&expr.Verdict{
Kind: expr.VerdictAccept,
},
}
require.ElementsMatch(t, rules[0].Exprs, expectedExprs1, "expected the same expressions")
ipToAdd, _ := netip.AddrFromSlice(ip) ipToAdd, _ := netip.AddrFromSlice(ip)
add := ipToAdd.Unmap() add := ipToAdd.Unmap()
expectedExprs2 := []expr.Any{ expectedExprs := []expr.Any{
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: ifname("lo"),
},
&expr.Payload{ &expr.Payload{
DestRegister: 1, DestRegister: 1,
Base: expr.PayloadBaseNetworkHeader, Base: expr.PayloadBaseNetworkHeader,
@@ -156,10 +134,10 @@ func TestNftablesManager(t *testing.T) {
}, },
&expr.Verdict{Kind: expr.VerdictDrop}, &expr.Verdict{Kind: expr.VerdictDrop},
} }
require.ElementsMatch(t, rules[1].Exprs, expectedExprs2, "expected the same expressions") require.ElementsMatch(t, rules[0].Exprs, expectedExprs, "expected the same expressions")
for _, r := range rule { for _, r := range rule {
err = manager.DeletePeerRule(r) err = manager.DeleteRule(r)
require.NoError(t, err, "failed to delete rule") require.NoError(t, err, "failed to delete rule")
} }
@@ -168,10 +146,9 @@ func TestNftablesManager(t *testing.T) {
rules, err = testClient.GetRules(manager.aclManager.workTable, manager.aclManager.chainInputRules) rules, err = testClient.GetRules(manager.aclManager.workTable, manager.aclManager.chainInputRules)
require.NoError(t, err, "failed to get rules") require.NoError(t, err, "failed to get rules")
// established rule remains require.Len(t, rules, 0, "expected 0 rules after deletion")
require.Len(t, rules, 1, "expected 1 rules after deletion")
err = manager.Reset(nil) err = manager.Reset()
require.NoError(t, err, "failed to reset") require.NoError(t, err, "failed to reset")
} }
@@ -194,13 +171,12 @@ func TestNFtablesCreatePerformance(t *testing.T) {
for _, testMax := range []int{10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 200, 300, 400, 500, 600, 700, 800, 900, 1000} { for _, testMax := range []int{10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 200, 300, 400, 500, 600, 700, 800, 900, 1000} {
t.Run(fmt.Sprintf("Testing %d rules", testMax), func(t *testing.T) { t.Run(fmt.Sprintf("Testing %d rules", testMax), func(t *testing.T) {
// just check on the local interface // just check on the local interface
manager, err := Create(mock) manager, err := Create(context.Background(), mock)
require.NoError(t, err) require.NoError(t, err)
require.NoError(t, manager.Init(nil))
time.Sleep(time.Second * 3) time.Sleep(time.Second * 3)
defer func() { defer func() {
if err := manager.Reset(nil); err != nil { if err := manager.Reset(); err != nil {
t.Errorf("clear the manager state: %v", err) t.Errorf("clear the manager state: %v", err)
} }
time.Sleep(time.Second) time.Sleep(time.Second)
@@ -211,9 +187,9 @@ func TestNFtablesCreatePerformance(t *testing.T) {
for i := 0; i < testMax; i++ { for i := 0; i < testMax; i++ {
port := &fw.Port{Values: []int{1000 + i}} port := &fw.Port{Values: []int{1000 + i}}
if i%2 == 0 { if i%2 == 0 {
_, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept HTTP traffic") _, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept HTTP traffic")
} else { } else {
_, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.RuleDirectionIN, fw.ActionAccept, "", "accept HTTP traffic") _, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionIN, fw.ActionAccept, "", "accept HTTP traffic")
} }
require.NoError(t, err, "failed to add rule") require.NoError(t, err, "failed to add rule")
@@ -227,105 +203,3 @@ func TestNFtablesCreatePerformance(t *testing.T) {
}) })
} }
} }
func runIptablesSave(t *testing.T) (string, string) {
t.Helper()
var stdout, stderr bytes.Buffer
cmd := exec.Command("iptables-save")
cmd.Stdout = &stdout
cmd.Stderr = &stderr
err := cmd.Run()
require.NoError(t, err, "iptables-save failed to run")
return stdout.String(), stderr.String()
}
func verifyIptablesOutput(t *testing.T, stdout, stderr string) {
t.Helper()
// Check for any incompatibility warnings
require.NotContains(t,
stderr,
"incompatible",
"iptables-save produced compatibility warning. Full stderr: %s",
stderr,
)
// Verify standard tables are present
expectedTables := []string{
"*filter",
"*nat",
"*mangle",
}
for _, table := range expectedTables {
require.Contains(t,
stdout,
table,
"iptables-save output missing expected table: %s\nFull stdout: %s",
table,
stdout,
)
}
}
func TestNftablesManagerCompatibilityWithIptables(t *testing.T) {
if check() != NFTABLES {
t.Skip("nftables not supported on this system")
}
if _, err := exec.LookPath("iptables-save"); err != nil {
t.Skipf("iptables-save not available on this system: %v", err)
}
// First ensure iptables-nft tables exist by running iptables-save
stdout, stderr := runIptablesSave(t)
verifyIptablesOutput(t, stdout, stderr)
manager, err := Create(ifaceMock)
require.NoError(t, err, "failed to create manager")
require.NoError(t, manager.Init(nil))
t.Cleanup(func() {
err := manager.Reset(nil)
require.NoError(t, err, "failed to reset manager state")
// Verify iptables output after reset
stdout, stderr := runIptablesSave(t)
verifyIptablesOutput(t, stdout, stderr)
})
ip := net.ParseIP("100.96.0.1")
_, err = manager.AddPeerFiltering(
ip,
fw.ProtocolTCP,
nil,
&fw.Port{Values: []int{80}},
fw.RuleDirectionIN,
fw.ActionAccept,
"",
"test rule",
)
require.NoError(t, err, "failed to add peer filtering rule")
_, err = manager.AddRouteFiltering(
[]netip.Prefix{netip.MustParsePrefix("192.168.2.0/24")},
netip.MustParsePrefix("10.1.0.0/24"),
fw.ProtocolTCP,
nil,
&fw.Port{Values: []int{443}},
fw.ActionAccept,
)
require.NoError(t, err, "failed to add route filtering rule")
pair := fw.RouterPair{
Source: netip.MustParsePrefix("192.168.1.0/24"),
Destination: netip.MustParsePrefix("10.0.0.0/24"),
Masquerade: true,
}
err = manager.AddNatRule(pair)
require.NoError(t, err, "failed to add NAT rule")
stdout, stderr = runIptablesSave(t)
verifyIptablesOutput(t, stdout, stderr)
}

View File

@@ -0,0 +1,431 @@
package nftables
import (
"bytes"
"context"
"errors"
"fmt"
"net"
"net/netip"
"github.com/google/nftables"
"github.com/google/nftables/binaryutil"
"github.com/google/nftables/expr"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/firewall/manager"
)
const (
chainNameRouteingFw = "netbird-rt-fwd"
chainNameRoutingNat = "netbird-rt-nat"
userDataAcceptForwardRuleSrc = "frwacceptsrc"
userDataAcceptForwardRuleDst = "frwacceptdst"
loopbackInterface = "lo\x00"
)
// some presets for building nftable rules
var (
zeroXor = binaryutil.NativeEndian.PutUint32(0)
exprCounterAccept = []expr.Any{
&expr.Counter{},
&expr.Verdict{
Kind: expr.VerdictAccept,
},
}
errFilterTableNotFound = fmt.Errorf("nftables: 'filter' table not found")
)
type router struct {
ctx context.Context
stop context.CancelFunc
conn *nftables.Conn
workTable *nftables.Table
filterTable *nftables.Table
chains map[string]*nftables.Chain
// rules is useful to avoid duplicates and to get missing attributes that we don't have when adding new rules
rules map[string]*nftables.Rule
isDefaultFwdRulesEnabled bool
}
func newRouter(parentCtx context.Context, workTable *nftables.Table) (*router, error) {
ctx, cancel := context.WithCancel(parentCtx)
r := &router{
ctx: ctx,
stop: cancel,
conn: &nftables.Conn{},
workTable: workTable,
chains: make(map[string]*nftables.Chain),
rules: make(map[string]*nftables.Rule),
}
var err error
r.filterTable, err = r.loadFilterTable()
if err != nil {
if errors.Is(err, errFilterTableNotFound) {
log.Warnf("table 'filter' not found for forward rules")
} else {
return nil, err
}
}
err = r.cleanUpDefaultForwardRules()
if err != nil {
log.Errorf("failed to clean up rules from FORWARD chain: %s", err)
}
err = r.createContainers()
if err != nil {
log.Errorf("failed to create containers for route: %s", err)
}
return r, err
}
func (r *router) RouteingFwChainName() string {
return chainNameRouteingFw
}
// ResetForwardRules cleans existing nftables default forward rules from the system
func (r *router) ResetForwardRules() {
err := r.cleanUpDefaultForwardRules()
if err != nil {
log.Errorf("failed to reset forward rules: %s", err)
}
}
func (r *router) loadFilterTable() (*nftables.Table, error) {
tables, err := r.conn.ListTablesOfFamily(nftables.TableFamilyIPv4)
if err != nil {
return nil, fmt.Errorf("nftables: unable to list tables: %v", err)
}
for _, table := range tables {
if table.Name == "filter" {
return table, nil
}
}
return nil, errFilterTableNotFound
}
func (r *router) createContainers() error {
r.chains[chainNameRouteingFw] = r.conn.AddChain(&nftables.Chain{
Name: chainNameRouteingFw,
Table: r.workTable,
})
r.chains[chainNameRoutingNat] = r.conn.AddChain(&nftables.Chain{
Name: chainNameRoutingNat,
Table: r.workTable,
Hooknum: nftables.ChainHookPostrouting,
Priority: nftables.ChainPriorityNATSource - 1,
Type: nftables.ChainTypeNAT,
})
// Add RETURN rule for loopback interface
loRule := &nftables.Rule{
Table: r.workTable,
Chain: r.chains[chainNameRoutingNat],
Exprs: []expr.Any{
&expr.Meta{Key: expr.MetaKeyOIFNAME, Register: 1},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: []byte(loopbackInterface),
},
&expr.Verdict{Kind: expr.VerdictReturn},
},
}
r.conn.InsertRule(loRule)
err := r.refreshRulesMap()
if err != nil {
log.Errorf("failed to clean up rules from FORWARD chain: %s", err)
}
err = r.conn.Flush()
if err != nil {
return fmt.Errorf("nftables: unable to initialize table: %v", err)
}
return nil
}
// AddRoutingRules appends a nftable rule pair to the forwarding chain and if enabled, to the nat chain
func (r *router) AddRoutingRules(pair manager.RouterPair) error {
err := r.refreshRulesMap()
if err != nil {
return err
}
err = r.addRoutingRule(manager.ForwardingFormat, chainNameRouteingFw, pair, false)
if err != nil {
return err
}
err = r.addRoutingRule(manager.InForwardingFormat, chainNameRouteingFw, manager.GetInPair(pair), false)
if err != nil {
return err
}
if pair.Masquerade {
err = r.addRoutingRule(manager.NatFormat, chainNameRoutingNat, pair, true)
if err != nil {
return err
}
err = r.addRoutingRule(manager.InNatFormat, chainNameRoutingNat, manager.GetInPair(pair), true)
if err != nil {
return err
}
}
if r.filterTable != nil && !r.isDefaultFwdRulesEnabled {
log.Debugf("add default accept forward rule")
r.acceptForwardRule(pair.Source)
}
err = r.conn.Flush()
if err != nil {
return fmt.Errorf("nftables: unable to insert rules for %s: %v", pair.Destination, err)
}
return nil
}
// addRoutingRule inserts a nftable rule to the conn client flush queue
func (r *router) addRoutingRule(format, chainName string, pair manager.RouterPair, isNat bool) error {
sourceExp := generateCIDRMatcherExpressions(true, pair.Source)
destExp := generateCIDRMatcherExpressions(false, pair.Destination)
var expression []expr.Any
if isNat {
expression = append(sourceExp, append(destExp, &expr.Counter{}, &expr.Masq{})...) // nolint:gocritic
} else {
expression = append(sourceExp, append(destExp, exprCounterAccept...)...) // nolint:gocritic
}
ruleKey := manager.GenKey(format, pair.ID)
_, exists := r.rules[ruleKey]
if exists {
err := r.removeRoutingRule(format, pair)
if err != nil {
return err
}
}
r.rules[ruleKey] = r.conn.AddRule(&nftables.Rule{
Table: r.workTable,
Chain: r.chains[chainName],
Exprs: expression,
UserData: []byte(ruleKey),
})
return nil
}
func (r *router) acceptForwardRule(sourceNetwork string) {
src := generateCIDRMatcherExpressions(true, sourceNetwork)
dst := generateCIDRMatcherExpressions(false, "0.0.0.0/0")
var exprs []expr.Any
exprs = append(src, append(dst, &expr.Verdict{ // nolint:gocritic
Kind: expr.VerdictAccept,
})...)
rule := &nftables.Rule{
Table: r.filterTable,
Chain: &nftables.Chain{
Name: "FORWARD",
Table: r.filterTable,
Type: nftables.ChainTypeFilter,
Hooknum: nftables.ChainHookForward,
Priority: nftables.ChainPriorityFilter,
},
Exprs: exprs,
UserData: []byte(userDataAcceptForwardRuleSrc),
}
r.conn.AddRule(rule)
src = generateCIDRMatcherExpressions(true, "0.0.0.0/0")
dst = generateCIDRMatcherExpressions(false, sourceNetwork)
exprs = append(src, append(dst, &expr.Verdict{ //nolint:gocritic
Kind: expr.VerdictAccept,
})...)
rule = &nftables.Rule{
Table: r.filterTable,
Chain: &nftables.Chain{
Name: "FORWARD",
Table: r.filterTable,
Type: nftables.ChainTypeFilter,
Hooknum: nftables.ChainHookForward,
Priority: nftables.ChainPriorityFilter,
},
Exprs: exprs,
UserData: []byte(userDataAcceptForwardRuleDst),
}
r.conn.AddRule(rule)
r.isDefaultFwdRulesEnabled = true
}
// RemoveRoutingRules removes a nftable rule pair from forwarding and nat chains
func (r *router) RemoveRoutingRules(pair manager.RouterPair) error {
err := r.refreshRulesMap()
if err != nil {
return err
}
err = r.removeRoutingRule(manager.ForwardingFormat, pair)
if err != nil {
return err
}
err = r.removeRoutingRule(manager.InForwardingFormat, manager.GetInPair(pair))
if err != nil {
return err
}
err = r.removeRoutingRule(manager.NatFormat, pair)
if err != nil {
return err
}
err = r.removeRoutingRule(manager.InNatFormat, manager.GetInPair(pair))
if err != nil {
return err
}
if len(r.rules) == 0 {
err := r.cleanUpDefaultForwardRules()
if err != nil {
log.Errorf("failed to clean up rules from FORWARD chain: %s", err)
}
}
err = r.conn.Flush()
if err != nil {
return fmt.Errorf("nftables: received error while applying rule removal for %s: %v", pair.Destination, err)
}
log.Debugf("nftables: removed rules for %s", pair.Destination)
return nil
}
// removeRoutingRule add a nftable rule to the removal queue and delete from rules map
func (r *router) removeRoutingRule(format string, pair manager.RouterPair) error {
ruleKey := manager.GenKey(format, pair.ID)
rule, found := r.rules[ruleKey]
if found {
ruleType := "forwarding"
if rule.Chain.Type == nftables.ChainTypeNAT {
ruleType = "nat"
}
err := r.conn.DelRule(rule)
if err != nil {
return fmt.Errorf("nftables: unable to remove %s rule for %s: %v", ruleType, pair.Destination, err)
}
log.Debugf("nftables: removing %s rule for %s", ruleType, pair.Destination)
delete(r.rules, ruleKey)
}
return nil
}
// refreshRulesMap refreshes the rule map with the latest rules. this is useful to avoid
// duplicates and to get missing attributes that we don't have when adding new rules
func (r *router) refreshRulesMap() error {
for _, chain := range r.chains {
rules, err := r.conn.GetRules(chain.Table, chain)
if err != nil {
return fmt.Errorf("nftables: unable to list rules: %v", err)
}
for _, rule := range rules {
if len(rule.UserData) > 0 {
r.rules[string(rule.UserData)] = rule
}
}
}
return nil
}
func (r *router) cleanUpDefaultForwardRules() error {
if r.filterTable == nil {
r.isDefaultFwdRulesEnabled = false
return nil
}
chains, err := r.conn.ListChainsOfTableFamily(nftables.TableFamilyIPv4)
if err != nil {
return err
}
var rules []*nftables.Rule
for _, chain := range chains {
if chain.Table.Name != r.filterTable.Name {
continue
}
if chain.Name != "FORWARD" {
continue
}
rules, err = r.conn.GetRules(r.filterTable, chain)
if err != nil {
return err
}
}
for _, rule := range rules {
if bytes.Equal(rule.UserData, []byte(userDataAcceptForwardRuleSrc)) || bytes.Equal(rule.UserData, []byte(userDataAcceptForwardRuleDst)) {
err := r.conn.DelRule(rule)
if err != nil {
return err
}
}
}
r.isDefaultFwdRulesEnabled = false
return r.conn.Flush()
}
// generateCIDRMatcherExpressions generates nftables expressions that matches a CIDR
func generateCIDRMatcherExpressions(source bool, cidr string) []expr.Any {
ip, network, _ := net.ParseCIDR(cidr)
ipToAdd, _ := netip.AddrFromSlice(ip)
add := ipToAdd.Unmap()
var offSet uint32
if source {
offSet = 12 // src offset
} else {
offSet = 16 // dst offset
}
return []expr.Any{
// fetch src add
&expr.Payload{
DestRegister: 1,
Base: expr.PayloadBaseNetworkHeader,
Offset: offSet,
Len: 4,
},
// net mask
&expr.Bitwise{
DestRegister: 1,
SourceRegister: 1,
Len: 4,
Mask: network.Mask,
Xor: zeroXor,
},
// net address
&expr.Cmp{
Register: 1,
Data: add.AsSlice(),
},
}
}

View File

@@ -1,989 +0,0 @@
package nftables
import (
"bytes"
"encoding/binary"
"errors"
"fmt"
"net"
"net/netip"
"strings"
"github.com/coreos/go-iptables/iptables"
"github.com/davecgh/go-spew/spew"
"github.com/google/nftables"
"github.com/google/nftables/binaryutil"
"github.com/google/nftables/expr"
"github.com/hashicorp/go-multierror"
log "github.com/sirupsen/logrus"
nberrors "github.com/netbirdio/netbird/client/errors"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/internal/acl/id"
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
nbnet "github.com/netbirdio/netbird/util/net"
)
const (
chainNameRoutingFw = "netbird-rt-fwd"
chainNameRoutingNat = "netbird-rt-postrouting"
chainNameForward = "FORWARD"
userDataAcceptForwardRuleIif = "frwacceptiif"
userDataAcceptForwardRuleOif = "frwacceptoif"
)
const refreshRulesMapError = "refresh rules map: %w"
var (
errFilterTableNotFound = fmt.Errorf("nftables: 'filter' table not found")
)
type router struct {
conn *nftables.Conn
workTable *nftables.Table
filterTable *nftables.Table
chains map[string]*nftables.Chain
// rules is useful to avoid duplicates and to get missing attributes that we don't have when adding new rules
rules map[string]*nftables.Rule
ipsetCounter *refcounter.Counter[string, []netip.Prefix, *nftables.Set]
wgIface iFaceMapper
legacyManagement bool
}
func newRouter(workTable *nftables.Table, wgIface iFaceMapper) (*router, error) {
r := &router{
conn: &nftables.Conn{},
workTable: workTable,
chains: make(map[string]*nftables.Chain),
rules: make(map[string]*nftables.Rule),
wgIface: wgIface,
}
r.ipsetCounter = refcounter.New(
r.createIpSet,
r.deleteIpSet,
)
var err error
r.filterTable, err = r.loadFilterTable()
if err != nil {
if errors.Is(err, errFilterTableNotFound) {
log.Warnf("table 'filter' not found for forward rules")
} else {
return nil, fmt.Errorf("load filter table: %w", err)
}
}
return r, nil
}
func (r *router) init(workTable *nftables.Table) error {
r.workTable = workTable
if err := r.removeAcceptForwardRules(); err != nil {
log.Errorf("failed to clean up rules from FORWARD chain: %s", err)
}
if err := r.createContainers(); err != nil {
return fmt.Errorf("create containers: %w", err)
}
return nil
}
// Reset cleans existing nftables default forward rules from the system
func (r *router) Reset() error {
// clear without deleting the ipsets, the nf table will be deleted by the caller
r.ipsetCounter.Clear()
return r.removeAcceptForwardRules()
}
func (r *router) loadFilterTable() (*nftables.Table, error) {
tables, err := r.conn.ListTablesOfFamily(nftables.TableFamilyIPv4)
if err != nil {
return nil, fmt.Errorf("nftables: unable to list tables: %v", err)
}
for _, table := range tables {
if table.Name == "filter" {
return table, nil
}
}
return nil, errFilterTableNotFound
}
func (r *router) createContainers() error {
r.chains[chainNameRoutingFw] = r.conn.AddChain(&nftables.Chain{
Name: chainNameRoutingFw,
Table: r.workTable,
})
insertReturnTrafficRule(r.conn, r.workTable, r.chains[chainNameRoutingFw])
prio := *nftables.ChainPriorityNATSource - 1
r.chains[chainNameRoutingNat] = r.conn.AddChain(&nftables.Chain{
Name: chainNameRoutingNat,
Table: r.workTable,
Hooknum: nftables.ChainHookPostrouting,
Priority: &prio,
Type: nftables.ChainTypeNAT,
})
// Chain is created by acl manager
// TODO: move creation to a common place
r.chains[chainNamePrerouting] = &nftables.Chain{
Name: chainNamePrerouting,
Table: r.workTable,
Type: nftables.ChainTypeFilter,
Hooknum: nftables.ChainHookPrerouting,
Priority: nftables.ChainPriorityMangle,
}
// Add the single NAT rule that matches on mark
if err := r.addPostroutingRules(); err != nil {
return fmt.Errorf("add single nat rule: %v", err)
}
if err := r.acceptForwardRules(); err != nil {
log.Errorf("failed to add accept rules for the forward chain: %s", err)
}
if err := r.refreshRulesMap(); err != nil {
log.Errorf("failed to clean up rules from FORWARD chain: %s", err)
}
if err := r.conn.Flush(); err != nil {
return fmt.Errorf("nftables: unable to initialize table: %v", err)
}
return nil
}
// AddRouteFiltering appends a nftables rule to the routing chain
func (r *router) AddRouteFiltering(
sources []netip.Prefix,
destination netip.Prefix,
proto firewall.Protocol,
sPort *firewall.Port,
dPort *firewall.Port,
action firewall.Action,
) (firewall.Rule, error) {
ruleKey := id.GenerateRouteRuleKey(sources, destination, proto, sPort, dPort, action)
if _, ok := r.rules[string(ruleKey)]; ok {
return ruleKey, nil
}
chain := r.chains[chainNameRoutingFw]
var exprs []expr.Any
switch {
case len(sources) == 1 && sources[0].Bits() == 0:
// If it's 0.0.0.0/0, we don't need to add any source matching
case len(sources) == 1:
// If there's only one source, we can use it directly
exprs = append(exprs, generateCIDRMatcherExpressions(true, sources[0])...)
default:
// If there are multiple sources, create or get an ipset
var err error
exprs, err = r.getIpSetExprs(sources, exprs)
if err != nil {
return nil, fmt.Errorf("get ipset expressions: %w", err)
}
}
// Handle destination
exprs = append(exprs, generateCIDRMatcherExpressions(false, destination)...)
// Handle protocol
if proto != firewall.ProtocolALL {
protoNum, err := protoToInt(proto)
if err != nil {
return nil, fmt.Errorf("convert protocol to number: %w", err)
}
exprs = append(exprs, &expr.Meta{Key: expr.MetaKeyL4PROTO, Register: 1})
exprs = append(exprs, &expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: []byte{protoNum},
})
exprs = append(exprs, applyPort(sPort, true)...)
exprs = append(exprs, applyPort(dPort, false)...)
}
exprs = append(exprs, &expr.Counter{})
var verdict expr.VerdictKind
if action == firewall.ActionAccept {
verdict = expr.VerdictAccept
} else {
verdict = expr.VerdictDrop
}
exprs = append(exprs, &expr.Verdict{Kind: verdict})
rule := &nftables.Rule{
Table: r.workTable,
Chain: chain,
Exprs: exprs,
UserData: []byte(ruleKey),
}
rule = r.conn.AddRule(rule)
log.Tracef("Adding route rule %s", spew.Sdump(rule))
if err := r.conn.Flush(); err != nil {
return nil, fmt.Errorf(flushError, err)
}
r.rules[string(ruleKey)] = rule
log.Debugf("nftables: added route rule: sources=%v, destination=%v, proto=%v, sPort=%v, dPort=%v, action=%v", sources, destination, proto, sPort, dPort, action)
return ruleKey, nil
}
func (r *router) getIpSetExprs(sources []netip.Prefix, exprs []expr.Any) ([]expr.Any, error) {
setName := firewall.GenerateSetName(sources)
ref, err := r.ipsetCounter.Increment(setName, sources)
if err != nil {
return nil, fmt.Errorf("create or get ipset for sources: %w", err)
}
exprs = append(exprs,
&expr.Payload{
DestRegister: 1,
Base: expr.PayloadBaseNetworkHeader,
Offset: 12,
Len: 4,
},
&expr.Lookup{
SourceRegister: 1,
SetName: ref.Out.Name,
SetID: ref.Out.ID,
},
)
return exprs, nil
}
func (r *router) DeleteRouteRule(rule firewall.Rule) error {
if err := r.refreshRulesMap(); err != nil {
return fmt.Errorf(refreshRulesMapError, err)
}
ruleKey := rule.GetRuleID()
nftRule, exists := r.rules[ruleKey]
if !exists {
log.Debugf("route rule %s not found", ruleKey)
return nil
}
if nftRule.Handle == 0 {
return fmt.Errorf("route rule %s has no handle", ruleKey)
}
setName := r.findSetNameInRule(nftRule)
if err := r.deleteNftRule(nftRule, ruleKey); err != nil {
return fmt.Errorf("delete: %w", err)
}
if setName != "" {
if _, err := r.ipsetCounter.Decrement(setName); err != nil {
return fmt.Errorf("decrement ipset reference: %w", err)
}
}
if err := r.conn.Flush(); err != nil {
return fmt.Errorf(flushError, err)
}
return nil
}
func (r *router) createIpSet(setName string, sources []netip.Prefix) (*nftables.Set, error) {
// overlapping prefixes will result in an error, so we need to merge them
sources = firewall.MergeIPRanges(sources)
set := &nftables.Set{
Name: setName,
Table: r.workTable,
// required for prefixes
Interval: true,
KeyType: nftables.TypeIPAddr,
}
var elements []nftables.SetElement
for _, prefix := range sources {
// TODO: Implement IPv6 support
if prefix.Addr().Is6() {
log.Printf("Skipping IPv6 prefix %s: IPv6 support not yet implemented", prefix)
continue
}
// nftables needs half-open intervals [firstIP, lastIP) for prefixes
// e.g. 10.0.0.0/24 becomes [10.0.0.0, 10.0.1.0), 10.1.1.1/32 becomes [10.1.1.1, 10.1.1.2) etc
firstIP := prefix.Addr()
lastIP := calculateLastIP(prefix).Next()
elements = append(elements,
// the nft tool also adds a line like this, see https://github.com/google/nftables/issues/247
// nftables.SetElement{Key: []byte{0, 0, 0, 0}, IntervalEnd: true},
nftables.SetElement{Key: firstIP.AsSlice()},
nftables.SetElement{Key: lastIP.AsSlice(), IntervalEnd: true},
)
}
if err := r.conn.AddSet(set, elements); err != nil {
return nil, fmt.Errorf("error adding elements to set %s: %w", setName, err)
}
if err := r.conn.Flush(); err != nil {
return nil, fmt.Errorf("flush error: %w", err)
}
log.Printf("Created new ipset: %s with %d elements", setName, len(elements)/2)
return set, nil
}
// calculateLastIP determines the last IP in a given prefix.
func calculateLastIP(prefix netip.Prefix) netip.Addr {
hostMask := ^uint32(0) >> prefix.Masked().Bits()
lastIP := uint32FromNetipAddr(prefix.Addr()) | hostMask
return netip.AddrFrom4(uint32ToBytes(lastIP))
}
// Utility function to convert netip.Addr to uint32.
func uint32FromNetipAddr(addr netip.Addr) uint32 {
b := addr.As4()
return binary.BigEndian.Uint32(b[:])
}
// Utility function to convert uint32 to a netip-compatible byte slice.
func uint32ToBytes(ip uint32) [4]byte {
var b [4]byte
binary.BigEndian.PutUint32(b[:], ip)
return b
}
func (r *router) deleteIpSet(setName string, set *nftables.Set) error {
r.conn.DelSet(set)
if err := r.conn.Flush(); err != nil {
return fmt.Errorf(flushError, err)
}
log.Debugf("Deleted unused ipset %s", setName)
return nil
}
func (r *router) findSetNameInRule(rule *nftables.Rule) string {
for _, e := range rule.Exprs {
if lookup, ok := e.(*expr.Lookup); ok {
return lookup.SetName
}
}
return ""
}
func (r *router) deleteNftRule(rule *nftables.Rule, ruleKey string) error {
if err := r.conn.DelRule(rule); err != nil {
return fmt.Errorf("delete rule %s: %w", ruleKey, err)
}
delete(r.rules, ruleKey)
log.Debugf("removed route rule %s", ruleKey)
return nil
}
// AddNatRule appends a nftables rule pair to the nat chain
func (r *router) AddNatRule(pair firewall.RouterPair) error {
if err := r.refreshRulesMap(); err != nil {
return fmt.Errorf(refreshRulesMapError, err)
}
if r.legacyManagement {
log.Warnf("This peer is connected to a NetBird Management service with an older version. Allowing all traffic for %s", pair.Destination)
if err := r.addLegacyRouteRule(pair); err != nil {
return fmt.Errorf("add legacy routing rule: %w", err)
}
}
if pair.Masquerade {
if err := r.addNatRule(pair); err != nil {
return fmt.Errorf("add nat rule: %w", err)
}
if err := r.addNatRule(firewall.GetInversePair(pair)); err != nil {
return fmt.Errorf("add inverse nat rule: %w", err)
}
}
if err := r.conn.Flush(); err != nil {
return fmt.Errorf("nftables: insert rules for %s: %v", pair.Destination, err)
}
return nil
}
// addNatRule inserts a nftables rule to the conn client flush queue
func (r *router) addNatRule(pair firewall.RouterPair) error {
sourceExp := generateCIDRMatcherExpressions(true, pair.Source)
destExp := generateCIDRMatcherExpressions(false, pair.Destination)
op := expr.CmpOpEq
if pair.Inverse {
op = expr.CmpOpNeq
}
exprs := []expr.Any{
// We only care about NEW connections to mark them and later identify them in the postrouting chain for masquerading.
// Masquerading will take care of the conntrack state, which means we won't need to mark established connections.
&expr.Ct{
Key: expr.CtKeySTATE,
Register: 1,
},
&expr.Bitwise{
SourceRegister: 1,
DestRegister: 1,
Len: 4,
Mask: binaryutil.NativeEndian.PutUint32(expr.CtStateBitNEW),
Xor: binaryutil.NativeEndian.PutUint32(0),
},
&expr.Cmp{
Op: expr.CmpOpNeq,
Register: 1,
Data: []byte{0, 0, 0, 0},
},
// interface matching
&expr.Meta{
Key: expr.MetaKeyIIFNAME,
Register: 1,
},
&expr.Cmp{
Op: op,
Register: 1,
Data: ifname(r.wgIface.Name()),
},
}
exprs = append(exprs, sourceExp...)
exprs = append(exprs, destExp...)
var markValue uint32 = nbnet.PreroutingFwmarkMasquerade
if pair.Inverse {
markValue = nbnet.PreroutingFwmarkMasqueradeReturn
}
exprs = append(exprs,
&expr.Immediate{
Register: 1,
Data: binaryutil.NativeEndian.PutUint32(markValue),
},
&expr.Meta{
Key: expr.MetaKeyMARK,
SourceRegister: true,
Register: 1,
},
)
ruleKey := firewall.GenKey(firewall.PreroutingFormat, pair)
if _, exists := r.rules[ruleKey]; exists {
if err := r.removeNatRule(pair); err != nil {
return fmt.Errorf("remove prerouting rule: %w", err)
}
}
r.rules[ruleKey] = r.conn.AddRule(&nftables.Rule{
Table: r.workTable,
Chain: r.chains[chainNamePrerouting],
Exprs: exprs,
UserData: []byte(ruleKey),
})
return nil
}
// addPostroutingRules adds the masquerade rules
func (r *router) addPostroutingRules() error {
// First masquerade rule for traffic coming in from WireGuard interface
exprs := []expr.Any{
// Match on the first fwmark
&expr.Meta{
Key: expr.MetaKeyMARK,
Register: 1,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: binaryutil.NativeEndian.PutUint32(nbnet.PreroutingFwmarkMasquerade),
},
// We need to exclude the loopback interface as this changes the ebpf proxy port
&expr.Meta{
Key: expr.MetaKeyOIFNAME,
Register: 1,
},
&expr.Cmp{
Op: expr.CmpOpNeq,
Register: 1,
Data: ifname("lo"),
},
&expr.Counter{},
&expr.Masq{},
}
r.conn.AddRule(&nftables.Rule{
Table: r.workTable,
Chain: r.chains[chainNameRoutingNat],
Exprs: exprs,
})
// Second masquerade rule for traffic going out through WireGuard interface
exprs2 := []expr.Any{
// Match on the second fwmark
&expr.Meta{
Key: expr.MetaKeyMARK,
Register: 1,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: binaryutil.NativeEndian.PutUint32(nbnet.PreroutingFwmarkMasqueradeReturn),
},
// Match WireGuard interface
&expr.Meta{
Key: expr.MetaKeyOIFNAME,
Register: 1,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: ifname(r.wgIface.Name()),
},
&expr.Counter{},
&expr.Masq{},
}
r.conn.AddRule(&nftables.Rule{
Table: r.workTable,
Chain: r.chains[chainNameRoutingNat],
Exprs: exprs2,
})
return nil
}
// addLegacyRouteRule adds a legacy routing rule for mgmt servers pre route acls
func (r *router) addLegacyRouteRule(pair firewall.RouterPair) error {
sourceExp := generateCIDRMatcherExpressions(true, pair.Source)
destExp := generateCIDRMatcherExpressions(false, pair.Destination)
exprs := []expr.Any{
&expr.Counter{},
&expr.Verdict{
Kind: expr.VerdictAccept,
},
}
expression := append(sourceExp, append(destExp, exprs...)...) // nolint:gocritic
ruleKey := firewall.GenKey(firewall.ForwardingFormat, pair)
if _, exists := r.rules[ruleKey]; exists {
if err := r.removeLegacyRouteRule(pair); err != nil {
return fmt.Errorf("remove legacy routing rule: %w", err)
}
}
r.rules[ruleKey] = r.conn.AddRule(&nftables.Rule{
Table: r.workTable,
Chain: r.chains[chainNameRoutingFw],
Exprs: expression,
UserData: []byte(ruleKey),
})
return nil
}
// removeLegacyRouteRule removes a legacy routing rule for mgmt servers pre route acls
func (r *router) removeLegacyRouteRule(pair firewall.RouterPair) error {
ruleKey := firewall.GenKey(firewall.ForwardingFormat, pair)
if rule, exists := r.rules[ruleKey]; exists {
if err := r.conn.DelRule(rule); err != nil {
return fmt.Errorf("remove legacy forwarding rule %s -> %s: %v", pair.Source, pair.Destination, err)
}
log.Debugf("nftables: removed legacy forwarding rule %s -> %s", pair.Source, pair.Destination)
delete(r.rules, ruleKey)
} else {
log.Debugf("nftables: legacy forwarding rule %s not found", ruleKey)
}
return nil
}
// GetLegacyManagement returns the route manager's legacy management mode
func (r *router) GetLegacyManagement() bool {
return r.legacyManagement
}
// SetLegacyManagement sets the route manager to use legacy management mode
func (r *router) SetLegacyManagement(isLegacy bool) {
r.legacyManagement = isLegacy
}
// RemoveAllLegacyRouteRules removes all legacy routing rules for mgmt servers pre route acls
func (r *router) RemoveAllLegacyRouteRules() error {
if err := r.refreshRulesMap(); err != nil {
return fmt.Errorf(refreshRulesMapError, err)
}
var merr *multierror.Error
for k, rule := range r.rules {
if !strings.HasPrefix(k, firewall.ForwardingFormatPrefix) {
continue
}
if err := r.conn.DelRule(rule); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove legacy forwarding rule: %v", err))
} else {
delete(r.rules, k)
}
}
return nberrors.FormatErrorOrNil(merr)
}
// acceptForwardRules adds iif/oif rules in the filter table/forward chain to make sure
// that our traffic is not dropped by existing rules there.
// The existing FORWARD rules/policies decide outbound traffic towards our interface.
// In case the FORWARD policy is set to "drop", we add an established/related rule to allow return traffic for the inbound rule.
func (r *router) acceptForwardRules() error {
if r.filterTable == nil {
log.Debugf("table 'filter' not found for forward rules, skipping accept rules")
return nil
}
fw := "iptables"
defer func() {
log.Debugf("Used %s to add accept forward rules", fw)
}()
// Try iptables first and fallback to nftables if iptables is not available
ipt, err := iptables.New()
if err != nil {
// filter table exists but iptables is not
log.Warnf("Will use nftables to manipulate the filter table because iptables is not available: %v", err)
fw = "nftables"
return r.acceptForwardRulesNftables()
}
return r.acceptForwardRulesIptables(ipt)
}
func (r *router) acceptForwardRulesIptables(ipt *iptables.IPTables) error {
var merr *multierror.Error
for _, rule := range r.getAcceptForwardRules() {
if err := ipt.Insert("filter", chainNameForward, 1, rule...); err != nil {
merr = multierror.Append(err, fmt.Errorf("add iptables rule: %v", err))
} else {
log.Debugf("added iptables rule: %v", rule)
}
}
return nberrors.FormatErrorOrNil(merr)
}
func (r *router) getAcceptForwardRules() [][]string {
intf := r.wgIface.Name()
return [][]string{
{"-i", intf, "-j", "ACCEPT"},
{"-o", intf, "-m", "conntrack", "--ctstate", "RELATED,ESTABLISHED", "-j", "ACCEPT"},
}
}
func (r *router) acceptForwardRulesNftables() error {
intf := ifname(r.wgIface.Name())
// Rule for incoming interface (iif) with counter
iifRule := &nftables.Rule{
Table: r.filterTable,
Chain: &nftables.Chain{
Name: chainNameForward,
Table: r.filterTable,
Type: nftables.ChainTypeFilter,
Hooknum: nftables.ChainHookForward,
Priority: nftables.ChainPriorityFilter,
},
Exprs: []expr.Any{
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: intf,
},
&expr.Counter{},
&expr.Verdict{Kind: expr.VerdictAccept},
},
UserData: []byte(userDataAcceptForwardRuleIif),
}
r.conn.InsertRule(iifRule)
oifExprs := []expr.Any{
&expr.Meta{Key: expr.MetaKeyOIFNAME, Register: 1},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: intf,
},
}
// Rule for outgoing interface (oif) with counter
oifRule := &nftables.Rule{
Table: r.filterTable,
Chain: &nftables.Chain{
Name: "FORWARD",
Table: r.filterTable,
Type: nftables.ChainTypeFilter,
Hooknum: nftables.ChainHookForward,
Priority: nftables.ChainPriorityFilter,
},
Exprs: append(oifExprs, getEstablishedExprs(2)...),
UserData: []byte(userDataAcceptForwardRuleOif),
}
r.conn.InsertRule(oifRule)
return nil
}
func (r *router) removeAcceptForwardRules() error {
if r.filterTable == nil {
return nil
}
// Try iptables first and fallback to nftables if iptables is not available
ipt, err := iptables.New()
if err != nil {
log.Warnf("Will use nftables to manipulate the filter table because iptables is not available: %v", err)
return r.removeAcceptForwardRulesNftables()
}
return r.removeAcceptForwardRulesIptables(ipt)
}
func (r *router) removeAcceptForwardRulesNftables() error {
chains, err := r.conn.ListChainsOfTableFamily(nftables.TableFamilyIPv4)
if err != nil {
return fmt.Errorf("list chains: %v", err)
}
for _, chain := range chains {
if chain.Table.Name != r.filterTable.Name || chain.Name != chainNameForward {
continue
}
rules, err := r.conn.GetRules(r.filterTable, chain)
if err != nil {
return fmt.Errorf("get rules: %v", err)
}
for _, rule := range rules {
if bytes.Equal(rule.UserData, []byte(userDataAcceptForwardRuleIif)) ||
bytes.Equal(rule.UserData, []byte(userDataAcceptForwardRuleOif)) {
if err := r.conn.DelRule(rule); err != nil {
return fmt.Errorf("delete rule: %v", err)
}
}
}
}
if err := r.conn.Flush(); err != nil {
return fmt.Errorf(flushError, err)
}
return nil
}
func (r *router) removeAcceptForwardRulesIptables(ipt *iptables.IPTables) error {
var merr *multierror.Error
for _, rule := range r.getAcceptForwardRules() {
if err := ipt.DeleteIfExists("filter", chainNameForward, rule...); err != nil {
merr = multierror.Append(err, fmt.Errorf("remove iptables rule: %v", err))
}
}
return nberrors.FormatErrorOrNil(merr)
}
// RemoveNatRule removes the prerouting mark rule
func (r *router) RemoveNatRule(pair firewall.RouterPair) error {
if err := r.refreshRulesMap(); err != nil {
return fmt.Errorf(refreshRulesMapError, err)
}
if err := r.removeNatRule(pair); err != nil {
return fmt.Errorf("remove prerouting rule: %w", err)
}
if err := r.removeNatRule(firewall.GetInversePair(pair)); err != nil {
return fmt.Errorf("remove inverse prerouting rule: %w", err)
}
if err := r.removeLegacyRouteRule(pair); err != nil {
return fmt.Errorf("remove legacy routing rule: %w", err)
}
if err := r.conn.Flush(); err != nil {
return fmt.Errorf("nftables: received error while applying rule removal for %s: %v", pair.Destination, err)
}
log.Debugf("nftables: removed nat rules for %s", pair.Destination)
return nil
}
func (r *router) removeNatRule(pair firewall.RouterPair) error {
ruleKey := firewall.GenKey(firewall.PreroutingFormat, pair)
if rule, exists := r.rules[ruleKey]; exists {
err := r.conn.DelRule(rule)
if err != nil {
return fmt.Errorf("remove prerouting rule %s -> %s: %v", pair.Source, pair.Destination, err)
}
log.Debugf("nftables: removed prerouting rule %s -> %s", pair.Source, pair.Destination)
delete(r.rules, ruleKey)
} else {
log.Debugf("nftables: prerouting rule %s not found", ruleKey)
}
return nil
}
// refreshRulesMap refreshes the rule map with the latest rules. this is useful to avoid
// duplicates and to get missing attributes that we don't have when adding new rules
func (r *router) refreshRulesMap() error {
for _, chain := range r.chains {
rules, err := r.conn.GetRules(chain.Table, chain)
if err != nil {
return fmt.Errorf("nftables: unable to list rules: %v", err)
}
for _, rule := range rules {
if len(rule.UserData) > 0 {
r.rules[string(rule.UserData)] = rule
}
}
}
return nil
}
// generateCIDRMatcherExpressions generates nftables expressions that matches a CIDR
func generateCIDRMatcherExpressions(source bool, prefix netip.Prefix) []expr.Any {
var offset uint32
if source {
offset = 12 // src offset
} else {
offset = 16 // dst offset
}
ones := prefix.Bits()
// 0.0.0.0/0 doesn't need extra expressions
if ones == 0 {
return nil
}
mask := net.CIDRMask(ones, 32)
return []expr.Any{
&expr.Payload{
DestRegister: 1,
Base: expr.PayloadBaseNetworkHeader,
Offset: offset,
Len: 4,
},
// netmask
&expr.Bitwise{
DestRegister: 1,
SourceRegister: 1,
Len: 4,
Mask: mask,
Xor: []byte{0, 0, 0, 0},
},
// net address
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: prefix.Masked().Addr().AsSlice(),
},
}
}
func applyPort(port *firewall.Port, isSource bool) []expr.Any {
if port == nil {
return nil
}
var exprs []expr.Any
offset := uint32(2) // Default offset for destination port
if isSource {
offset = 0 // Offset for source port
}
exprs = append(exprs, &expr.Payload{
DestRegister: 1,
Base: expr.PayloadBaseTransportHeader,
Offset: offset,
Len: 2,
})
if port.IsRange && len(port.Values) == 2 {
// Handle port range
exprs = append(exprs,
&expr.Cmp{
Op: expr.CmpOpGte,
Register: 1,
Data: binaryutil.BigEndian.PutUint16(uint16(port.Values[0])),
},
&expr.Cmp{
Op: expr.CmpOpLte,
Register: 1,
Data: binaryutil.BigEndian.PutUint16(uint16(port.Values[1])),
},
)
} else {
// Handle single port or multiple ports
for i, p := range port.Values {
if i > 0 {
// Add a bitwise OR operation between port checks
exprs = append(exprs, &expr.Bitwise{
SourceRegister: 1,
DestRegister: 1,
Len: 4,
Mask: []byte{0x00, 0x00, 0xff, 0xff},
Xor: []byte{0x00, 0x00, 0x00, 0x00},
})
}
exprs = append(exprs, &expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: binaryutil.BigEndian.PutUint16(uint16(p)),
})
}
}
return exprs
}

View File

@@ -3,16 +3,12 @@
package nftables package nftables
import ( import (
"encoding/binary" "context"
"net/netip"
"os/exec"
"testing" "testing"
"github.com/coreos/go-iptables/iptables" "github.com/coreos/go-iptables/iptables"
"github.com/google/nftables" "github.com/google/nftables"
"github.com/google/nftables/binaryutil"
"github.com/google/nftables/expr" "github.com/google/nftables/expr"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
firewall "github.com/netbirdio/netbird/client/firewall/manager" firewall "github.com/netbirdio/netbird/client/firewall/manager"
@@ -28,629 +24,197 @@ const (
NFTABLES NFTABLES
) )
func TestNftablesManager_AddNatRule(t *testing.T) { func TestNftablesManager_InsertRoutingRules(t *testing.T) {
if check() != NFTABLES { if check() != NFTABLES {
t.Skip("nftables not supported on this OS") t.Skip("nftables not supported on this OS")
} }
table, err := createWorkTable()
if err != nil {
t.Fatal(err)
}
defer deleteWorkTable()
for _, testCase := range test.InsertRuleTestCases { for _, testCase := range test.InsertRuleTestCases {
t.Run(testCase.Name, func(t *testing.T) { t.Run(testCase.Name, func(t *testing.T) {
// need fw manager to init both acl mgr and router for all chains to be present manager, err := newRouter(context.TODO(), table)
manager, err := Create(ifaceMock) require.NoError(t, err, "failed to create router")
t.Cleanup(func() {
require.NoError(t, manager.Reset(nil))
})
require.NoError(t, err)
require.NoError(t, manager.Init(nil))
nftablesTestingClient := &nftables.Conn{} nftablesTestingClient := &nftables.Conn{}
rtr := manager.router defer manager.ResetForwardRules()
err = rtr.AddNatRule(testCase.InputPair)
require.NoError(t, err, "pair should be inserted")
t.Cleanup(func() { require.NoError(t, err, "shouldn't return error")
require.NoError(t, rtr.RemoveNatRule(testCase.InputPair), "failed to remove rule")
})
if testCase.InputPair.Masquerade { err = manager.AddRoutingRules(testCase.InputPair)
// Build expected expressions for connection tracking defer func() {
conntrackExprs := []expr.Any{ _ = manager.RemoveRoutingRules(testCase.InputPair)
&expr.Ct{ }()
Key: expr.CtKeySTATE, require.NoError(t, err, "forwarding pair should be inserted")
Register: 1,
},
&expr.Bitwise{
SourceRegister: 1,
DestRegister: 1,
Len: 4,
Mask: binaryutil.NativeEndian.PutUint32(expr.CtStateBitNEW),
Xor: binaryutil.NativeEndian.PutUint32(0),
},
&expr.Cmp{
Op: expr.CmpOpNeq,
Register: 1,
Data: []byte{0, 0, 0, 0},
},
}
// Build interface matching expression
ifaceExprs := []expr.Any{
&expr.Meta{
Key: expr.MetaKeyIIFNAME,
Register: 1,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: ifname(ifaceMock.Name()),
},
}
// Build CIDR matching expressions
sourceExp := generateCIDRMatcherExpressions(true, testCase.InputPair.Source) sourceExp := generateCIDRMatcherExpressions(true, testCase.InputPair.Source)
destExp := generateCIDRMatcherExpressions(false, testCase.InputPair.Destination) destExp := generateCIDRMatcherExpressions(false, testCase.InputPair.Destination)
testingExpression := append(sourceExp, destExp...) //nolint:gocritic
fwdRuleKey := firewall.GenKey(firewall.ForwardingFormat, testCase.InputPair.ID)
// Combine all expressions in the correct order
// nolint:gocritic
testingExpression := append(conntrackExprs, ifaceExprs...)
testingExpression = append(testingExpression, sourceExp...)
testingExpression = append(testingExpression, destExp...)
natRuleKey := firewall.GenKey(firewall.PreroutingFormat, testCase.InputPair)
found := 0 found := 0
for _, chain := range rtr.chains { for _, chain := range manager.chains {
if chain.Name == chainNamePrerouting {
rules, err := nftablesTestingClient.GetRules(chain.Table, chain) rules, err := nftablesTestingClient.GetRules(chain.Table, chain)
require.NoError(t, err, "should list rules for %s table and %s chain", chain.Table.Name, chain.Name) require.NoError(t, err, "should list rules for %s table and %s chain", chain.Table.Name, chain.Name)
for _, rule := range rules { for _, rule := range rules {
if len(rule.UserData) > 0 && string(rule.UserData) == natRuleKey { if len(rule.UserData) > 0 && string(rule.UserData) == fwdRuleKey {
// Compare expressions up to the mark setting expressions require.ElementsMatchf(t, rule.Exprs[:len(testingExpression)], testingExpression, "forwarding rule elements should match")
require.ElementsMatchf(t, rule.Exprs[:len(testingExpression)], testingExpression, "prerouting nat rule elements should match")
found = 1 found = 1
} }
} }
} }
require.Equal(t, 1, found, "should find at least 1 rule to test")
if testCase.InputPair.Masquerade {
natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair.ID)
found := 0
for _, chain := range manager.chains {
rules, err := nftablesTestingClient.GetRules(chain.Table, chain)
require.NoError(t, err, "should list rules for %s table and %s chain", chain.Table.Name, chain.Name)
for _, rule := range rules {
if len(rule.UserData) > 0 && string(rule.UserData) == natRuleKey {
require.ElementsMatchf(t, rule.Exprs[:len(testingExpression)], testingExpression, "nat rule elements should match")
found = 1
} }
require.Equal(t, 1, found, "should find at least 1 rule in prerouting chain") }
}
require.Equal(t, 1, found, "should find at least 1 rule to test")
}
sourceExp = generateCIDRMatcherExpressions(true, firewall.GetInPair(testCase.InputPair).Source)
destExp = generateCIDRMatcherExpressions(false, firewall.GetInPair(testCase.InputPair).Destination)
testingExpression = append(sourceExp, destExp...) //nolint:gocritic
inFwdRuleKey := firewall.GenKey(firewall.InForwardingFormat, testCase.InputPair.ID)
found = 0
for _, chain := range manager.chains {
rules, err := nftablesTestingClient.GetRules(chain.Table, chain)
require.NoError(t, err, "should list rules for %s table and %s chain", chain.Table.Name, chain.Name)
for _, rule := range rules {
if len(rule.UserData) > 0 && string(rule.UserData) == inFwdRuleKey {
require.ElementsMatchf(t, rule.Exprs[:len(testingExpression)], testingExpression, "income forwarding rule elements should match")
found = 1
}
}
}
require.Equal(t, 1, found, "should find at least 1 rule to test")
if testCase.InputPair.Masquerade {
inNatRuleKey := firewall.GenKey(firewall.InNatFormat, testCase.InputPair.ID)
found := 0
for _, chain := range manager.chains {
rules, err := nftablesTestingClient.GetRules(chain.Table, chain)
require.NoError(t, err, "should list rules for %s table and %s chain", chain.Table.Name, chain.Name)
for _, rule := range rules {
if len(rule.UserData) > 0 && string(rule.UserData) == inNatRuleKey {
require.ElementsMatchf(t, rule.Exprs[:len(testingExpression)], testingExpression, "income nat rule elements should match")
found = 1
}
}
}
require.Equal(t, 1, found, "should find at least 1 rule to test")
} }
}) })
} }
} }
func TestNftablesManager_RemoveNatRule(t *testing.T) { func TestNftablesManager_RemoveRoutingRules(t *testing.T) {
if check() != NFTABLES { if check() != NFTABLES {
t.Skip("nftables not supported on this OS") t.Skip("nftables not supported on this OS")
} }
table, err := createWorkTable()
if err != nil {
t.Fatal(err)
}
defer deleteWorkTable()
for _, testCase := range test.RemoveRuleTestCases { for _, testCase := range test.RemoveRuleTestCases {
t.Run(testCase.Name, func(t *testing.T) { t.Run(testCase.Name, func(t *testing.T) {
manager, err := Create(ifaceMock) manager, err := newRouter(context.TODO(), table)
t.Cleanup(func() { require.NoError(t, err, "failed to create router")
require.NoError(t, manager.Reset(nil))
nftablesTestingClient := &nftables.Conn{}
defer manager.ResetForwardRules()
sourceExp := generateCIDRMatcherExpressions(true, testCase.InputPair.Source)
destExp := generateCIDRMatcherExpressions(false, testCase.InputPair.Destination)
forwardExp := append(sourceExp, append(destExp, exprCounterAccept...)...) //nolint:gocritic
forwardRuleKey := firewall.GenKey(firewall.ForwardingFormat, testCase.InputPair.ID)
insertedForwarding := nftablesTestingClient.InsertRule(&nftables.Rule{
Table: manager.workTable,
Chain: manager.chains[chainNameRouteingFw],
Exprs: forwardExp,
UserData: []byte(forwardRuleKey),
}) })
require.NoError(t, err)
require.NoError(t, manager.Init(nil))
rtr := manager.router natExp := append(sourceExp, append(destExp, &expr.Counter{}, &expr.Masq{})...) //nolint:gocritic
natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair.ID)
// First add the NAT rule using the router's method insertedNat := nftablesTestingClient.InsertRule(&nftables.Rule{
err = rtr.AddNatRule(testCase.InputPair) Table: manager.workTable,
require.NoError(t, err, "should add NAT rule") Chain: manager.chains[chainNameRoutingNat],
Exprs: natExp,
UserData: []byte(natRuleKey),
})
// Verify the rule was added sourceExp = generateCIDRMatcherExpressions(true, firewall.GetInPair(testCase.InputPair).Source)
natRuleKey := firewall.GenKey(firewall.PreroutingFormat, testCase.InputPair) destExp = generateCIDRMatcherExpressions(false, firewall.GetInPair(testCase.InputPair).Destination)
found := false
rules, err := rtr.conn.GetRules(rtr.workTable, rtr.chains[chainNamePrerouting]) forwardExp = append(sourceExp, append(destExp, exprCounterAccept...)...) //nolint:gocritic
require.NoError(t, err, "should list rules") inForwardRuleKey := firewall.GenKey(firewall.InForwardingFormat, testCase.InputPair.ID)
insertedInForwarding := nftablesTestingClient.InsertRule(&nftables.Rule{
Table: manager.workTable,
Chain: manager.chains[chainNameRouteingFw],
Exprs: forwardExp,
UserData: []byte(inForwardRuleKey),
})
natExp = append(sourceExp, append(destExp, &expr.Counter{}, &expr.Masq{})...) //nolint:gocritic
inNatRuleKey := firewall.GenKey(firewall.InNatFormat, testCase.InputPair.ID)
insertedInNat := nftablesTestingClient.InsertRule(&nftables.Rule{
Table: manager.workTable,
Chain: manager.chains[chainNameRoutingNat],
Exprs: natExp,
UserData: []byte(inNatRuleKey),
})
err = nftablesTestingClient.Flush()
require.NoError(t, err, "shouldn't return error")
manager.ResetForwardRules()
err = manager.RemoveRoutingRules(testCase.InputPair)
require.NoError(t, err, "shouldn't return error")
for _, chain := range manager.chains {
rules, err := nftablesTestingClient.GetRules(chain.Table, chain)
require.NoError(t, err, "should list rules for %s table and %s chain", chain.Table.Name, chain.Name)
for _, rule := range rules { for _, rule := range rules {
if len(rule.UserData) > 0 && string(rule.UserData) == natRuleKey { if len(rule.UserData) > 0 {
found = true require.NotEqual(t, insertedForwarding.UserData, rule.UserData, "forwarding rule should not exist")
break require.NotEqual(t, insertedNat.UserData, rule.UserData, "nat rule should not exist")
require.NotEqual(t, insertedInForwarding.UserData, rule.UserData, "income forwarding rule should not exist")
require.NotEqual(t, insertedInNat.UserData, rule.UserData, "income nat rule should not exist")
} }
} }
require.True(t, found, "NAT rule should exist before removal")
// Now remove the rule
err = rtr.RemoveNatRule(testCase.InputPair)
require.NoError(t, err, "shouldn't return error when removing rule")
// Verify the rule was removed
found = false
rules, err = rtr.conn.GetRules(rtr.workTable, rtr.chains[chainNamePrerouting])
require.NoError(t, err, "should list rules after removal")
for _, rule := range rules {
if len(rule.UserData) > 0 && string(rule.UserData) == natRuleKey {
found = true
break
} }
}
require.False(t, found, "NAT rule should not exist after removal")
// Verify the static postrouting rules still exist
rules, err = rtr.conn.GetRules(rtr.workTable, rtr.chains[chainNameRoutingNat])
require.NoError(t, err, "should list postrouting rules")
foundCounter := false
for _, rule := range rules {
for _, e := range rule.Exprs {
if _, ok := e.(*expr.Counter); ok {
foundCounter = true
break
}
}
if foundCounter {
break
}
}
require.True(t, foundCounter, "static postrouting rule should remain")
}) })
} }
} }
func TestRouter_AddRouteFiltering(t *testing.T) {
if check() != NFTABLES {
t.Skip("nftables not supported on this system")
}
workTable, err := createWorkTable()
require.NoError(t, err, "Failed to create work table")
defer deleteWorkTable()
r, err := newRouter(workTable, ifaceMock)
require.NoError(t, err, "Failed to create router")
require.NoError(t, r.init(workTable))
defer func(r *router) {
require.NoError(t, r.Reset(), "Failed to reset rules")
}(r)
tests := []struct {
name string
sources []netip.Prefix
destination netip.Prefix
proto firewall.Protocol
sPort *firewall.Port
dPort *firewall.Port
direction firewall.RuleDirection
action firewall.Action
expectSet bool
}{
{
name: "Basic TCP rule with single source",
sources: []netip.Prefix{netip.MustParsePrefix("192.168.1.0/24")},
destination: netip.MustParsePrefix("10.0.0.0/24"),
proto: firewall.ProtocolTCP,
sPort: nil,
dPort: &firewall.Port{Values: []int{80}},
direction: firewall.RuleDirectionIN,
action: firewall.ActionAccept,
expectSet: false,
},
{
name: "UDP rule with multiple sources",
sources: []netip.Prefix{
netip.MustParsePrefix("172.16.0.0/16"),
netip.MustParsePrefix("192.168.0.0/16"),
},
destination: netip.MustParsePrefix("10.0.0.0/8"),
proto: firewall.ProtocolUDP,
sPort: &firewall.Port{Values: []int{1024, 2048}, IsRange: true},
dPort: nil,
direction: firewall.RuleDirectionOUT,
action: firewall.ActionDrop,
expectSet: true,
},
{
name: "All protocols rule",
sources: []netip.Prefix{netip.MustParsePrefix("10.0.0.0/8")},
destination: netip.MustParsePrefix("0.0.0.0/0"),
proto: firewall.ProtocolALL,
sPort: nil,
dPort: nil,
direction: firewall.RuleDirectionIN,
action: firewall.ActionAccept,
expectSet: false,
},
{
name: "ICMP rule",
sources: []netip.Prefix{netip.MustParsePrefix("192.168.0.0/16")},
destination: netip.MustParsePrefix("10.0.0.0/8"),
proto: firewall.ProtocolICMP,
sPort: nil,
dPort: nil,
direction: firewall.RuleDirectionIN,
action: firewall.ActionAccept,
expectSet: false,
},
{
name: "TCP rule with multiple source ports",
sources: []netip.Prefix{netip.MustParsePrefix("172.16.0.0/12")},
destination: netip.MustParsePrefix("192.168.0.0/16"),
proto: firewall.ProtocolTCP,
sPort: &firewall.Port{Values: []int{80, 443, 8080}},
dPort: nil,
direction: firewall.RuleDirectionOUT,
action: firewall.ActionAccept,
expectSet: false,
},
{
name: "UDP rule with single IP and port range",
sources: []netip.Prefix{netip.MustParsePrefix("192.168.1.1/32")},
destination: netip.MustParsePrefix("10.0.0.0/24"),
proto: firewall.ProtocolUDP,
sPort: nil,
dPort: &firewall.Port{Values: []int{5000, 5100}, IsRange: true},
direction: firewall.RuleDirectionIN,
action: firewall.ActionDrop,
expectSet: false,
},
{
name: "TCP rule with source and destination ports",
sources: []netip.Prefix{netip.MustParsePrefix("10.0.0.0/24")},
destination: netip.MustParsePrefix("172.16.0.0/16"),
proto: firewall.ProtocolTCP,
sPort: &firewall.Port{Values: []int{1024, 65535}, IsRange: true},
dPort: &firewall.Port{Values: []int{22}},
direction: firewall.RuleDirectionOUT,
action: firewall.ActionAccept,
expectSet: false,
},
{
name: "Drop all incoming traffic",
sources: []netip.Prefix{netip.MustParsePrefix("0.0.0.0/0")},
destination: netip.MustParsePrefix("192.168.0.0/24"),
proto: firewall.ProtocolALL,
sPort: nil,
dPort: nil,
direction: firewall.RuleDirectionIN,
action: firewall.ActionDrop,
expectSet: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ruleKey, err := r.AddRouteFiltering(tt.sources, tt.destination, tt.proto, tt.sPort, tt.dPort, tt.action)
require.NoError(t, err, "AddRouteFiltering failed")
t.Cleanup(func() {
require.NoError(t, r.DeleteRouteRule(ruleKey), "Failed to delete rule")
})
// Check if the rule is in the internal map
rule, ok := r.rules[ruleKey.GetRuleID()]
assert.True(t, ok, "Rule not found in internal map")
t.Log("Internal rule expressions:")
for i, expr := range rule.Exprs {
t.Logf(" [%d] %T: %+v", i, expr, expr)
}
// Verify internal rule content
verifyRule(t, rule, tt.sources, tt.destination, tt.proto, tt.sPort, tt.dPort, tt.direction, tt.action, tt.expectSet)
// Check if the rule exists in nftables and verify its content
rules, err := r.conn.GetRules(r.workTable, r.chains[chainNameRoutingFw])
require.NoError(t, err, "Failed to get rules from nftables")
var nftRule *nftables.Rule
for _, rule := range rules {
if string(rule.UserData) == ruleKey.GetRuleID() {
nftRule = rule
break
}
}
require.NotNil(t, nftRule, "Rule not found in nftables")
t.Log("Actual nftables rule expressions:")
for i, expr := range nftRule.Exprs {
t.Logf(" [%d] %T: %+v", i, expr, expr)
}
// Verify actual nftables rule content
verifyRule(t, nftRule, tt.sources, tt.destination, tt.proto, tt.sPort, tt.dPort, tt.direction, tt.action, tt.expectSet)
})
}
}
func TestNftablesCreateIpSet(t *testing.T) {
if check() != NFTABLES {
t.Skip("nftables not supported on this system")
}
workTable, err := createWorkTable()
require.NoError(t, err, "Failed to create work table")
defer deleteWorkTable()
r, err := newRouter(workTable, ifaceMock)
require.NoError(t, err, "Failed to create router")
require.NoError(t, r.init(workTable))
defer func() {
require.NoError(t, r.Reset(), "Failed to reset router")
}()
tests := []struct {
name string
sources []netip.Prefix
expected []netip.Prefix
}{
{
name: "Single IP",
sources: []netip.Prefix{netip.MustParsePrefix("192.168.1.1/32")},
},
{
name: "Multiple IPs",
sources: []netip.Prefix{
netip.MustParsePrefix("192.168.1.1/32"),
netip.MustParsePrefix("10.0.0.1/32"),
netip.MustParsePrefix("172.16.0.1/32"),
},
},
{
name: "Single Subnet",
sources: []netip.Prefix{netip.MustParsePrefix("192.168.0.0/24")},
},
{
name: "Multiple Subnets with Various Prefix Lengths",
sources: []netip.Prefix{
netip.MustParsePrefix("10.0.0.0/8"),
netip.MustParsePrefix("172.16.0.0/16"),
netip.MustParsePrefix("192.168.1.0/24"),
netip.MustParsePrefix("203.0.113.0/26"),
},
},
{
name: "Mix of Single IPs and Subnets in Different Positions",
sources: []netip.Prefix{
netip.MustParsePrefix("192.168.1.1/32"),
netip.MustParsePrefix("10.0.0.0/16"),
netip.MustParsePrefix("172.16.0.1/32"),
netip.MustParsePrefix("203.0.113.0/24"),
},
},
{
name: "Overlapping IPs/Subnets",
sources: []netip.Prefix{
netip.MustParsePrefix("10.0.0.0/8"),
netip.MustParsePrefix("10.0.0.0/16"),
netip.MustParsePrefix("10.0.0.1/32"),
netip.MustParsePrefix("192.168.0.0/16"),
netip.MustParsePrefix("192.168.1.0/24"),
netip.MustParsePrefix("192.168.1.1/32"),
},
expected: []netip.Prefix{
netip.MustParsePrefix("10.0.0.0/8"),
netip.MustParsePrefix("192.168.0.0/16"),
},
},
}
// Add this helper function inside TestNftablesCreateIpSet
printNftSets := func() {
cmd := exec.Command("nft", "list", "sets")
output, err := cmd.CombinedOutput()
if err != nil {
t.Logf("Failed to run 'nft list sets': %v", err)
} else {
t.Logf("Current nft sets:\n%s", output)
}
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
setName := firewall.GenerateSetName(tt.sources)
set, err := r.createIpSet(setName, tt.sources)
if err != nil {
t.Logf("Failed to create IP set: %v", err)
printNftSets()
require.NoError(t, err, "Failed to create IP set")
}
require.NotNil(t, set, "Created set is nil")
// Verify set properties
assert.Equal(t, setName, set.Name, "Set name mismatch")
assert.Equal(t, r.workTable, set.Table, "Set table mismatch")
assert.True(t, set.Interval, "Set interval property should be true")
assert.Equal(t, nftables.TypeIPAddr, set.KeyType, "Set key type mismatch")
// Fetch the created set from nftables
fetchedSet, err := r.conn.GetSetByName(r.workTable, setName)
require.NoError(t, err, "Failed to fetch created set")
require.NotNil(t, fetchedSet, "Fetched set is nil")
// Verify set elements
elements, err := r.conn.GetSetElements(fetchedSet)
require.NoError(t, err, "Failed to get set elements")
// Count the number of unique prefixes (excluding interval end markers)
uniquePrefixes := make(map[string]bool)
for _, elem := range elements {
if !elem.IntervalEnd {
ip := netip.AddrFrom4(*(*[4]byte)(elem.Key))
uniquePrefixes[ip.String()] = true
}
}
// Check against expected merged prefixes
expectedCount := len(tt.expected)
if expectedCount == 0 {
expectedCount = len(tt.sources)
}
assert.Equal(t, expectedCount, len(uniquePrefixes), "Number of unique prefixes in set doesn't match expected")
// Verify each expected prefix is in the set
for _, expected := range tt.expected {
found := false
for _, elem := range elements {
if !elem.IntervalEnd {
ip := netip.AddrFrom4(*(*[4]byte)(elem.Key))
if expected.Contains(ip) {
found = true
break
}
}
}
assert.True(t, found, "Expected prefix %s not found in set", expected)
}
r.conn.DelSet(set)
if err := r.conn.Flush(); err != nil {
t.Logf("Failed to delete set: %v", err)
printNftSets()
}
require.NoError(t, err, "Failed to delete set")
})
}
}
func verifyRule(t *testing.T, rule *nftables.Rule, sources []netip.Prefix, destination netip.Prefix, proto firewall.Protocol, sPort, dPort *firewall.Port, direction firewall.RuleDirection, action firewall.Action, expectSet bool) {
t.Helper()
assert.NotNil(t, rule, "Rule should not be nil")
// Verify sources and destination
if expectSet {
assert.True(t, containsSetLookup(rule.Exprs), "Rule should contain set lookup for multiple sources")
} else if len(sources) == 1 && sources[0].Bits() != 0 {
if direction == firewall.RuleDirectionIN {
assert.True(t, containsCIDRMatcher(rule.Exprs, sources[0], true), "Rule should contain source CIDR matcher for %s", sources[0])
} else {
assert.True(t, containsCIDRMatcher(rule.Exprs, sources[0], false), "Rule should contain destination CIDR matcher for %s", sources[0])
}
}
if direction == firewall.RuleDirectionIN {
assert.True(t, containsCIDRMatcher(rule.Exprs, destination, false), "Rule should contain destination CIDR matcher for %s", destination)
} else {
assert.True(t, containsCIDRMatcher(rule.Exprs, destination, true), "Rule should contain source CIDR matcher for %s", destination)
}
// Verify protocol
if proto != firewall.ProtocolALL {
assert.True(t, containsProtocol(rule.Exprs, proto), "Rule should contain protocol matcher for %s", proto)
}
// Verify ports
if sPort != nil {
assert.True(t, containsPort(rule.Exprs, sPort, true), "Rule should contain source port matcher for %v", sPort)
}
if dPort != nil {
assert.True(t, containsPort(rule.Exprs, dPort, false), "Rule should contain destination port matcher for %v", dPort)
}
// Verify action
assert.True(t, containsAction(rule.Exprs, action), "Rule should contain correct action: %s", action)
}
func containsSetLookup(exprs []expr.Any) bool {
for _, e := range exprs {
if _, ok := e.(*expr.Lookup); ok {
return true
}
}
return false
}
func containsCIDRMatcher(exprs []expr.Any, prefix netip.Prefix, isSource bool) bool {
var offset uint32
if isSource {
offset = 12 // src offset
} else {
offset = 16 // dst offset
}
var payloadFound, bitwiseFound, cmpFound bool
for _, e := range exprs {
switch ex := e.(type) {
case *expr.Payload:
if ex.Base == expr.PayloadBaseNetworkHeader && ex.Offset == offset && ex.Len == 4 {
payloadFound = true
}
case *expr.Bitwise:
if ex.Len == 4 && len(ex.Mask) == 4 && len(ex.Xor) == 4 {
bitwiseFound = true
}
case *expr.Cmp:
if ex.Op == expr.CmpOpEq && len(ex.Data) == 4 {
cmpFound = true
}
}
}
return (payloadFound && bitwiseFound && cmpFound) || prefix.Bits() == 0
}
func containsPort(exprs []expr.Any, port *firewall.Port, isSource bool) bool {
var offset uint32 = 2 // Default offset for destination port
if isSource {
offset = 0 // Offset for source port
}
var payloadFound, portMatchFound bool
for _, e := range exprs {
switch ex := e.(type) {
case *expr.Payload:
if ex.Base == expr.PayloadBaseTransportHeader && ex.Offset == offset && ex.Len == 2 {
payloadFound = true
}
case *expr.Cmp:
if port.IsRange {
if ex.Op == expr.CmpOpGte || ex.Op == expr.CmpOpLte {
portMatchFound = true
}
} else {
if ex.Op == expr.CmpOpEq && len(ex.Data) == 2 {
portValue := binary.BigEndian.Uint16(ex.Data)
for _, p := range port.Values {
if uint16(p) == portValue {
portMatchFound = true
break
}
}
}
}
}
if payloadFound && portMatchFound {
return true
}
}
return false
}
func containsProtocol(exprs []expr.Any, proto firewall.Protocol) bool {
var metaFound, cmpFound bool
expectedProto, _ := protoToInt(proto)
for _, e := range exprs {
switch ex := e.(type) {
case *expr.Meta:
if ex.Key == expr.MetaKeyL4PROTO {
metaFound = true
}
case *expr.Cmp:
if ex.Op == expr.CmpOpEq && len(ex.Data) == 1 && ex.Data[0] == expectedProto {
cmpFound = true
}
}
}
return metaFound && cmpFound
}
func containsAction(exprs []expr.Any, action firewall.Action) bool {
for _, e := range exprs {
if verdict, ok := e.(*expr.Verdict); ok {
switch action {
case firewall.ActionAccept:
return verdict.Kind == expr.VerdictAccept
case firewall.ActionDrop:
return verdict.Kind == expr.VerdictDrop
}
}
}
return false
}
// check returns the firewall type based on common lib checks. It returns UNKNOWN if no firewall is found. // check returns the firewall type based on common lib checks. It returns UNKNOWN if no firewall is found.
func check() int { func check() int {
nf := nftables.Conn{} nf := nftables.Conn{}
@@ -686,12 +250,12 @@ func createWorkTable() (*nftables.Table, error) {
} }
for _, t := range tables { for _, t := range tables {
if t.Name == tableNameNetbird { if t.Name == tableName {
sConn.DelTable(t) sConn.DelTable(t)
} }
} }
table := sConn.AddTable(&nftables.Table{Name: tableNameNetbird, Family: nftables.TableFamilyIPv4}) table := sConn.AddTable(&nftables.Table{Name: tableName, Family: nftables.TableFamilyIPv4})
err = sConn.Flush() err = sConn.Flush()
return table, err return table, err
@@ -709,7 +273,7 @@ func deleteWorkTable() {
} }
for _, t := range tables { for _, t := range tables {
if t.Name == tableNameNetbird { if t.Name == tableName {
sConn.DelTable(t) sConn.DelTable(t)
} }
} }

View File

@@ -1,47 +0,0 @@
package nftables
import (
"fmt"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/device"
)
type InterfaceState struct {
NameStr string `json:"name"`
WGAddress iface.WGAddress `json:"wg_address"`
UserspaceBind bool `json:"userspace_bind"`
}
func (i *InterfaceState) Name() string {
return i.NameStr
}
func (i *InterfaceState) Address() device.WGAddress {
return i.WGAddress
}
func (i *InterfaceState) IsUserspaceBind() bool {
return i.UserspaceBind
}
type ShutdownState struct {
InterfaceState *InterfaceState `json:"interface_state,omitempty"`
}
func (s *ShutdownState) Name() string {
return "nftables_state"
}
func (s *ShutdownState) Cleanup() error {
nft, err := Create(s.InterfaceState)
if err != nil {
return fmt.Errorf("create nftables manager: %w", err)
}
if err := nft.Reset(nil); err != nil {
return fmt.Errorf("reset nftables manager: %w", err)
}
return nil
}

View File

@@ -1,10 +1,8 @@
//go:build !android
package test package test
import ( import firewall "github.com/netbirdio/netbird/client/firewall/manager"
"net/netip"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
)
var ( var (
InsertRuleTestCases = []struct { InsertRuleTestCases = []struct {
@@ -15,8 +13,8 @@ var (
Name: "Insert Forwarding IPV4 Rule", Name: "Insert Forwarding IPV4 Rule",
InputPair: firewall.RouterPair{ InputPair: firewall.RouterPair{
ID: "zxa", ID: "zxa",
Source: netip.MustParsePrefix("100.100.100.1/32"), Source: "100.100.100.1/32",
Destination: netip.MustParsePrefix("100.100.200.0/24"), Destination: "100.100.200.0/24",
Masquerade: false, Masquerade: false,
}, },
}, },
@@ -24,8 +22,8 @@ var (
Name: "Insert Forwarding And Nat IPV4 Rules", Name: "Insert Forwarding And Nat IPV4 Rules",
InputPair: firewall.RouterPair{ InputPair: firewall.RouterPair{
ID: "zxa", ID: "zxa",
Source: netip.MustParsePrefix("100.100.100.1/32"), Source: "100.100.100.1/32",
Destination: netip.MustParsePrefix("100.100.200.0/24"), Destination: "100.100.200.0/24",
Masquerade: true, Masquerade: true,
}, },
}, },
@@ -40,8 +38,8 @@ var (
Name: "Remove Forwarding And Nat IPV4 Rules", Name: "Remove Forwarding And Nat IPV4 Rules",
InputPair: firewall.RouterPair{ InputPair: firewall.RouterPair{
ID: "zxa", ID: "zxa",
Source: netip.MustParsePrefix("100.100.100.1/32"), Source: "100.100.100.1/32",
Destination: netip.MustParsePrefix("100.100.200.0/24"), Destination: "100.100.200.0/24",
Masquerade: true, Masquerade: true,
}, },
}, },

View File

@@ -2,36 +2,16 @@
package uspfilter package uspfilter
import (
"github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack"
"github.com/netbirdio/netbird/client/internal/statemanager"
)
// Reset firewall to the default state // Reset firewall to the default state
func (m *Manager) Reset(stateManager *statemanager.Manager) error { func (m *Manager) Reset() error {
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
m.outgoingRules = make(map[string]RuleSet) m.outgoingRules = make(map[string]RuleSet)
m.incomingRules = make(map[string]RuleSet) m.incomingRules = make(map[string]RuleSet)
if m.udpTracker != nil {
m.udpTracker.Close()
m.udpTracker = conntrack.NewUDPTracker(conntrack.DefaultUDPTimeout, nil)
}
if m.icmpTracker != nil {
m.icmpTracker.Close()
m.icmpTracker = conntrack.NewICMPTracker(conntrack.DefaultICMPTimeout, nil)
}
if m.tcpTracker != nil {
m.tcpTracker.Close()
m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout, nil)
}
if m.nativeFirewall != nil { if m.nativeFirewall != nil {
return m.nativeFirewall.Reset(stateManager) return m.nativeFirewall.Reset()
} }
return nil return nil
} }

View File

@@ -6,9 +6,6 @@ import (
"syscall" "syscall"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack"
"github.com/netbirdio/netbird/client/internal/statemanager"
) )
type action string type action string
@@ -20,28 +17,13 @@ const (
) )
// Reset firewall to the default state // Reset firewall to the default state
func (m *Manager) Reset(*statemanager.Manager) error { func (m *Manager) Reset() error {
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
m.outgoingRules = make(map[string]RuleSet) m.outgoingRules = make(map[string]RuleSet)
m.incomingRules = make(map[string]RuleSet) m.incomingRules = make(map[string]RuleSet)
if m.udpTracker != nil {
m.udpTracker.Close()
m.udpTracker = conntrack.NewUDPTracker(conntrack.DefaultUDPTimeout, nil)
}
if m.icmpTracker != nil {
m.icmpTracker.Close()
m.icmpTracker = conntrack.NewICMPTracker(conntrack.DefaultICMPTimeout, nil)
}
if m.tcpTracker != nil {
m.tcpTracker.Close()
m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout, nil)
}
if !isWindowsFirewallReachable() { if !isWindowsFirewallReachable() {
return nil return nil
} }

View File

@@ -1,138 +0,0 @@
// common.go
package conntrack
import (
"net"
"sync"
"sync/atomic"
"time"
)
// BaseConnTrack provides common fields and locking for all connection types
type BaseConnTrack struct {
sync.RWMutex
SourceIP net.IP
DestIP net.IP
SourcePort uint16
DestPort uint16
lastSeen atomic.Int64 // Unix nano for atomic access
established atomic.Bool
}
// these small methods will be inlined by the compiler
// UpdateLastSeen safely updates the last seen timestamp
func (b *BaseConnTrack) UpdateLastSeen() {
b.lastSeen.Store(time.Now().UnixNano())
}
// IsEstablished safely checks if connection is established
func (b *BaseConnTrack) IsEstablished() bool {
return b.established.Load()
}
// SetEstablished safely sets the established state
func (b *BaseConnTrack) SetEstablished(state bool) {
b.established.Store(state)
}
// GetLastSeen safely gets the last seen timestamp
func (b *BaseConnTrack) GetLastSeen() time.Time {
return time.Unix(0, b.lastSeen.Load())
}
// timeoutExceeded checks if the connection has exceeded the given timeout
func (b *BaseConnTrack) timeoutExceeded(timeout time.Duration) bool {
lastSeen := time.Unix(0, b.lastSeen.Load())
return time.Since(lastSeen) > timeout
}
// IPAddr is a fixed-size IP address to avoid allocations
type IPAddr [16]byte
// MakeIPAddr creates an IPAddr from net.IP
func MakeIPAddr(ip net.IP) (addr IPAddr) {
// Optimization: check for v4 first as it's more common
if ip4 := ip.To4(); ip4 != nil {
copy(addr[12:], ip4)
} else {
copy(addr[:], ip.To16())
}
return addr
}
// ConnKey uniquely identifies a connection
type ConnKey struct {
SrcIP IPAddr
DstIP IPAddr
SrcPort uint16
DstPort uint16
}
// makeConnKey creates a connection key
func makeConnKey(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16) ConnKey {
return ConnKey{
SrcIP: MakeIPAddr(srcIP),
DstIP: MakeIPAddr(dstIP),
SrcPort: srcPort,
DstPort: dstPort,
}
}
// ValidateIPs checks if IPs match without allocation
func ValidateIPs(connIP IPAddr, pktIP net.IP) bool {
if ip4 := pktIP.To4(); ip4 != nil {
// Compare IPv4 addresses (last 4 bytes)
for i := 0; i < 4; i++ {
if connIP[12+i] != ip4[i] {
return false
}
}
return true
}
// Compare full IPv6 addresses
ip6 := pktIP.To16()
for i := 0; i < 16; i++ {
if connIP[i] != ip6[i] {
return false
}
}
return true
}
// PreallocatedIPs is a pool of IP byte slices to reduce allocations
type PreallocatedIPs struct {
sync.Pool
}
// NewPreallocatedIPs creates a new IP pool
func NewPreallocatedIPs() *PreallocatedIPs {
return &PreallocatedIPs{
Pool: sync.Pool{
New: func() interface{} {
ip := make(net.IP, 16)
return &ip
},
},
}
}
// Get retrieves an IP from the pool
func (p *PreallocatedIPs) Get() net.IP {
return *p.Pool.Get().(*net.IP)
}
// Put returns an IP to the pool
func (p *PreallocatedIPs) Put(ip net.IP) {
p.Pool.Put(&ip)
}
// copyIP copies an IP address efficiently
func copyIP(dst, src net.IP) {
if len(src) == 16 {
copy(dst, src)
} else {
// Handle IPv4
copy(dst[12:], src.To4())
}
}

View File

@@ -1,115 +0,0 @@
package conntrack
import (
"net"
"testing"
)
func BenchmarkIPOperations(b *testing.B) {
b.Run("MakeIPAddr", func(b *testing.B) {
ip := net.ParseIP("192.168.1.1")
b.ResetTimer()
for i := 0; i < b.N; i++ {
_ = MakeIPAddr(ip)
}
})
b.Run("ValidateIPs", func(b *testing.B) {
ip1 := net.ParseIP("192.168.1.1")
ip2 := net.ParseIP("192.168.1.1")
addr := MakeIPAddr(ip1)
b.ResetTimer()
for i := 0; i < b.N; i++ {
_ = ValidateIPs(addr, ip2)
}
})
b.Run("IPPool", func(b *testing.B) {
pool := NewPreallocatedIPs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
ip := pool.Get()
pool.Put(ip)
}
})
}
func BenchmarkAtomicOperations(b *testing.B) {
conn := &BaseConnTrack{}
b.Run("UpdateLastSeen", func(b *testing.B) {
for i := 0; i < b.N; i++ {
conn.UpdateLastSeen()
}
})
b.Run("IsEstablished", func(b *testing.B) {
for i := 0; i < b.N; i++ {
_ = conn.IsEstablished()
}
})
b.Run("SetEstablished", func(b *testing.B) {
for i := 0; i < b.N; i++ {
conn.SetEstablished(i%2 == 0)
}
})
b.Run("GetLastSeen", func(b *testing.B) {
for i := 0; i < b.N; i++ {
_ = conn.GetLastSeen()
}
})
}
// Memory pressure tests
func BenchmarkMemoryPressure(b *testing.B) {
b.Run("TCPHighLoad", func(b *testing.B) {
tracker := NewTCPTracker(DefaultTCPTimeout, nil)
defer tracker.Close()
// Generate different IPs
srcIPs := make([]net.IP, 100)
dstIPs := make([]net.IP, 100)
for i := 0; i < 100; i++ {
srcIPs[i] = net.IPv4(192, 168, byte(i/256), byte(i%256))
dstIPs[i] = net.IPv4(10, 0, byte(i/256), byte(i%256))
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
srcIdx := i % len(srcIPs)
dstIdx := (i + 1) % len(dstIPs)
tracker.TrackOutbound(srcIPs[srcIdx], dstIPs[dstIdx], uint16(i%65535), 80, TCPSyn, nil)
// Simulate some valid inbound packets
if i%3 == 0 {
tracker.IsValidInbound(dstIPs[dstIdx], srcIPs[srcIdx], 80, uint16(i%65535), TCPAck, nil)
}
}
})
b.Run("UDPHighLoad", func(b *testing.B) {
tracker := NewUDPTracker(DefaultUDPTimeout, nil)
defer tracker.Close()
// Generate different IPs
srcIPs := make([]net.IP, 100)
dstIPs := make([]net.IP, 100)
for i := 0; i < 100; i++ {
srcIPs[i] = net.IPv4(192, 168, byte(i/256), byte(i%256))
dstIPs[i] = net.IPv4(10, 0, byte(i/256), byte(i%256))
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
srcIdx := i % len(srcIPs)
dstIdx := (i + 1) % len(dstIPs)
tracker.TrackOutbound(srcIPs[srcIdx], dstIPs[dstIdx], uint16(i%65535), 80, nil)
// Simulate some valid inbound packets
if i%3 == 0 {
tracker.IsValidInbound(dstIPs[dstIdx], srcIPs[srcIdx], 80, uint16(i%65535), nil)
}
}
})
}

View File

@@ -1,187 +0,0 @@
package conntrack
import (
"net"
"sync"
"time"
"github.com/google/gopacket/layers"
fw "github.com/netbirdio/netbird/client/firewall/manager"
)
const (
// DefaultICMPTimeout is the default timeout for ICMP connections
DefaultICMPTimeout = 30 * time.Second
// ICMPCleanupInterval is how often we check for stale ICMP connections
ICMPCleanupInterval = 15 * time.Second
)
// ICMPConnKey uniquely identifies an ICMP connection
type ICMPConnKey struct {
// Supports both IPv4 and IPv6
SrcIP [16]byte
DstIP [16]byte
Sequence uint16 // ICMP sequence number
ID uint16 // ICMP identifier
}
// ICMPConnTrack represents an ICMP connection state
type ICMPConnTrack struct {
BaseConnTrack
Sequence uint16
ID uint16
}
// ICMPTracker manages ICMP connection states
type ICMPTracker struct {
connections map[ICMPConnKey]*ICMPConnTrack
timeout time.Duration
cleanupTicker *time.Ticker
mutex sync.RWMutex
done chan struct{}
ipPool *PreallocatedIPs
stats *Stats
}
// NewICMPTracker creates a new ICMP connection tracker
func NewICMPTracker(timeout time.Duration, stats *Stats) *ICMPTracker {
if timeout == 0 {
timeout = DefaultICMPTimeout
}
tracker := &ICMPTracker{
connections: make(map[ICMPConnKey]*ICMPConnTrack),
timeout: timeout,
cleanupTicker: time.NewTicker(ICMPCleanupInterval),
done: make(chan struct{}),
ipPool: NewPreallocatedIPs(),
stats: stats,
}
go tracker.cleanupRoutine()
return tracker
}
// TrackOutbound records an outbound ICMP Echo Request
func (t *ICMPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, id uint16, seq uint16, packetData []byte) {
key := makeICMPKey(srcIP, dstIP, id, seq)
now := time.Now().UnixNano()
t.mutex.Lock()
conn, exists := t.connections[key]
if !exists {
srcIPCopy := t.ipPool.Get()
dstIPCopy := t.ipPool.Get()
copyIP(srcIPCopy, srcIP)
copyIP(dstIPCopy, dstIP)
conn = &ICMPConnTrack{
BaseConnTrack: BaseConnTrack{
SourceIP: srcIPCopy,
DestIP: dstIPCopy,
},
ID: id,
Sequence: seq,
}
conn.lastSeen.Store(now)
conn.established.Store(true)
t.connections[key] = conn
if t.stats != nil {
t.stats.TrackNewConnection(1, srcIP, dstIP, 0, 0, fw.DirectionOutbound)
}
}
t.mutex.Unlock()
if t.stats != nil {
key := makeConnKey(srcIP, dstIP, 0, 0)
t.stats.TrackPacket(1, false, uint64(len(packetData)), false, key)
}
conn.lastSeen.Store(now)
}
// IsValidInbound checks if an inbound ICMP Echo Reply matches a tracked request
func (t *ICMPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, id uint16, seq uint16, icmpType uint8, packetData []byte) bool {
switch icmpType {
case uint8(layers.ICMPv4TypeDestinationUnreachable),
uint8(layers.ICMPv4TypeTimeExceeded):
return true
case uint8(layers.ICMPv4TypeEchoReply):
// continue processing
default:
return false
}
key := makeICMPKey(dstIP, srcIP, id, seq)
t.mutex.RLock()
conn, exists := t.connections[key]
t.mutex.RUnlock()
if !exists {
return false
}
if conn.timeoutExceeded(t.timeout) {
return false
}
if t.stats != nil {
key := makeConnKey(srcIP, dstIP, 0, 0)
t.stats.TrackPacket(1, false, uint64(len(packetData)), true, key)
}
return conn.IsEstablished() &&
ValidateIPs(MakeIPAddr(srcIP), conn.DestIP) &&
ValidateIPs(MakeIPAddr(dstIP), conn.SourceIP) &&
conn.ID == id &&
conn.Sequence == seq
}
func (t *ICMPTracker) cleanupRoutine() {
for {
select {
case <-t.cleanupTicker.C:
t.cleanup()
case <-t.done:
return
}
}
}
func (t *ICMPTracker) cleanup() {
t.mutex.Lock()
defer t.mutex.Unlock()
for key, conn := range t.connections {
if conn.timeoutExceeded(t.timeout) {
t.ipPool.Put(conn.SourceIP)
t.ipPool.Put(conn.DestIP)
delete(t.connections, key)
}
}
}
// Close stops the cleanup routine and releases resources
func (t *ICMPTracker) Close() {
t.cleanupTicker.Stop()
close(t.done)
t.mutex.Lock()
for _, conn := range t.connections {
t.ipPool.Put(conn.SourceIP)
t.ipPool.Put(conn.DestIP)
}
t.connections = nil
t.mutex.Unlock()
}
// makeICMPKey creates an ICMP connection key
func makeICMPKey(srcIP net.IP, dstIP net.IP, id uint16, seq uint16) ICMPConnKey {
return ICMPConnKey{
SrcIP: MakeIPAddr(srcIP),
DstIP: MakeIPAddr(dstIP),
ID: id,
Sequence: seq,
}
}

View File

@@ -1,39 +0,0 @@
package conntrack
import (
"net"
"testing"
)
func BenchmarkICMPTracker(b *testing.B) {
b.Run("TrackOutbound", func(b *testing.B) {
tracker := NewICMPTracker(DefaultICMPTimeout, nil)
defer tracker.Close()
srcIP := net.ParseIP("192.168.1.1")
dstIP := net.ParseIP("192.168.1.2")
b.ResetTimer()
for i := 0; i < b.N; i++ {
tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), uint16(i%65535), nil)
}
})
b.Run("IsValidInbound", func(b *testing.B) {
tracker := NewICMPTracker(DefaultICMPTimeout, nil)
defer tracker.Close()
srcIP := net.ParseIP("192.168.1.1")
dstIP := net.ParseIP("192.168.1.2")
// Pre-populate some connections
for i := 0; i < 1000; i++ {
tracker.TrackOutbound(srcIP, dstIP, uint16(i), uint16(i), nil)
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
tracker.IsValidInbound(dstIP, srcIP, uint16(i%1000), uint16(i%1000), 0, nil)
}
})
}

View File

@@ -1,172 +0,0 @@
package conntrack
import (
"net"
"slices"
"sync"
"sync/atomic"
"time"
fw "github.com/netbirdio/netbird/client/firewall/manager"
)
// Stats represents connection tracking statistics
type Stats struct {
TotalConnsCreated atomic.Uint64
TotalConnsTimedOut atomic.Uint64
TotalPacketsDropped atomic.Uint64
ActiveConns atomic.Int64
TCPConns atomic.Int64
UDPConns atomic.Int64
ICMPConns atomic.Int64
TCPStateStats struct {
SynReceived atomic.Uint64
Established atomic.Uint64
FinWait atomic.Uint64
TimeWait atomic.Uint64
InvalidStates atomic.Uint64
}
PacketStats struct {
TCPPackets atomic.Uint64
UDPPackets atomic.Uint64
ICMPPackets atomic.Uint64
}
flowMutex sync.RWMutex
flows map[ConnKey]*fw.FlowStats
}
// NewStats creates a new Stats instance
func NewStats() *Stats {
return &Stats{
flows: make(map[ConnKey]*fw.FlowStats),
}
}
// TrackNewConnection records a new connection
func (s *Stats) TrackNewConnection(proto uint8, srcIP net.IP, dstIP net.IP, srcPort, dstPort uint16, direction fw.Direction) {
s.TotalConnsCreated.Add(1)
s.ActiveConns.Add(1)
switch proto {
case 6: // TCP
s.TCPConns.Add(1)
case 17: // UDP
s.UDPConns.Add(1)
case 1: // ICMP
s.ICMPConns.Add(1)
}
flow := &fw.FlowStats{
StartTime: time.Now(),
LastSeen: time.Now(),
Protocol: proto,
Direction: direction,
SourceIP: slices.Clone(srcIP),
DestIP: slices.Clone(dstIP),
SourcePort: srcPort,
DestPort: dstPort,
}
key := makeConnKey(srcIP, dstIP, srcPort, dstPort)
s.flowMutex.Lock()
s.flows[key] = flow
s.flowMutex.Unlock()
}
// TrackConnectionClosed records a connection closure
func (s *Stats) TrackConnectionClosed(proto uint8, timedOut bool, key ConnKey) {
s.ActiveConns.Add(-1)
if timedOut {
s.TotalConnsTimedOut.Add(1)
}
switch proto {
case 6: // TCP
s.TCPConns.Add(-1)
case 17: // UDP
s.UDPConns.Add(-1)
case 1: // ICMP
s.ICMPConns.Add(-1)
}
s.flowMutex.Lock()
delete(s.flows, key)
s.flowMutex.Unlock()
}
// TrackPacket records packet statistics
func (s *Stats) TrackPacket(proto uint8, dropped bool, bytes uint64, isInbound bool, key ConnKey) {
if dropped {
s.TotalPacketsDropped.Add(1)
return
}
switch proto {
case 6: // TCP
s.PacketStats.TCPPackets.Add(1)
case 17: // UDP
s.PacketStats.UDPPackets.Add(1)
case 1: // ICMP
s.PacketStats.ICMPPackets.Add(1)
}
s.flowMutex.RLock()
if flow, exists := s.flows[key]; exists {
if isInbound {
flow.BytesIn.Add(bytes)
flow.PacketsIn.Add(1)
} else {
flow.BytesOut.Add(bytes)
flow.PacketsOut.Add(1)
}
flow.LastSeen = time.Now()
}
s.flowMutex.RUnlock()
}
// TrackTCPState updates TCP state statistics
func (s *Stats) TrackTCPState(newState TCPState) {
switch newState {
case TCPStateSynReceived:
s.TCPStateStats.SynReceived.Add(1)
case TCPStateEstablished:
s.TCPStateStats.Established.Add(1)
case TCPStateFinWait1, TCPStateFinWait2:
s.TCPStateStats.FinWait.Add(1)
case TCPStateTimeWait:
s.TCPStateStats.TimeWait.Add(1)
default:
s.TCPStateStats.InvalidStates.Add(1)
}
}
// GetFlowSnapshot returns a copy of current flow statistics if enabled
func (s *Stats) GetFlowSnapshot() []*fw.FlowStats {
s.flowMutex.RLock()
defer s.flowMutex.RUnlock()
snapshot := make([]*fw.FlowStats, 0, len(s.flows))
for _, flow := range s.flows {
snapshot = append(snapshot, flow.Clone())
}
return snapshot
}
// CleanupFlows removes flow entries older than the specified duration if enabled
func (s *Stats) CleanupFlows(maxAge time.Duration) {
threshold := time.Now().Add(-maxAge)
s.flowMutex.Lock()
defer s.flowMutex.Unlock()
for key, flow := range s.flows {
if flow.LastSeen.Before(threshold) {
delete(s.flows, key)
}
}
}

View File

@@ -1,399 +0,0 @@
package conntrack
// TODO: Send RST packets for invalid/timed-out connections
import (
"net"
"sync"
"time"
fw "github.com/netbirdio/netbird/client/firewall/manager"
)
const (
// MSL (Maximum Segment Lifetime) is typically 2 minutes
MSL = 2 * time.Minute
// TimeWaitTimeout (TIME-WAIT) should last 2*MSL
TimeWaitTimeout = 2 * MSL
)
const (
TCPSyn uint8 = 0x02
TCPAck uint8 = 0x10
TCPFin uint8 = 0x01
TCPRst uint8 = 0x04
TCPPush uint8 = 0x08
TCPUrg uint8 = 0x20
)
const (
// DefaultTCPTimeout is the default timeout for established TCP connections
DefaultTCPTimeout = 3 * time.Hour
// TCPHandshakeTimeout is timeout for TCP handshake completion
TCPHandshakeTimeout = 60 * time.Second
// TCPCleanupInterval is how often we check for stale connections
TCPCleanupInterval = 5 * time.Minute
)
// TCPState represents the state of a TCP connection
type TCPState int
const (
TCPStateNew TCPState = iota
TCPStateSynSent
TCPStateSynReceived
TCPStateEstablished
TCPStateFinWait1
TCPStateFinWait2
TCPStateClosing
TCPStateTimeWait
TCPStateCloseWait
TCPStateLastAck
TCPStateClosed
)
// TCPConnKey uniquely identifies a TCP connection
type TCPConnKey struct {
SrcIP [16]byte
DstIP [16]byte
SrcPort uint16
DstPort uint16
}
// TCPConnTrack represents a TCP connection state
type TCPConnTrack struct {
BaseConnTrack
State TCPState
}
// TCPTracker manages TCP connection states
type TCPTracker struct {
connections map[ConnKey]*TCPConnTrack
mutex sync.RWMutex
cleanupTicker *time.Ticker
done chan struct{}
timeout time.Duration
ipPool *PreallocatedIPs
stats *Stats
}
// NewTCPTracker creates a new TCP connection tracker
func NewTCPTracker(timeout time.Duration, stats *Stats) *TCPTracker {
tracker := &TCPTracker{
connections: make(map[ConnKey]*TCPConnTrack),
cleanupTicker: time.NewTicker(TCPCleanupInterval),
done: make(chan struct{}),
timeout: timeout,
ipPool: NewPreallocatedIPs(),
stats: stats,
}
go tracker.cleanupRoutine()
return tracker
}
// TrackOutbound processes an outbound TCP packet and updates connection state
func (t *TCPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16, flags uint8, packetData []byte) {
key := makeConnKey(srcIP, dstIP, srcPort, dstPort)
now := time.Now().UnixNano()
t.mutex.Lock()
conn, exists := t.connections[key]
if !exists {
srcIPCopy := t.ipPool.Get()
dstIPCopy := t.ipPool.Get()
copyIP(srcIPCopy, srcIP)
copyIP(dstIPCopy, dstIP)
conn = &TCPConnTrack{
BaseConnTrack: BaseConnTrack{
SourceIP: srcIPCopy,
DestIP: dstIPCopy,
SourcePort: srcPort,
DestPort: dstPort,
},
State: TCPStateNew,
}
conn.lastSeen.Store(now)
conn.established.Store(false)
t.connections[key] = conn
if t.stats != nil {
t.stats.TrackNewConnection(6, srcIP, dstIP, srcPort, dstPort, fw.DirectionOutbound)
}
}
t.mutex.Unlock()
conn.Lock()
oldState := conn.State
t.updateState(conn, flags, true)
if oldState != conn.State && t.stats != nil {
t.stats.TrackTCPState(conn.State)
}
conn.Unlock()
if t.stats != nil {
t.stats.TrackPacket(6, false, uint64(len(packetData)), false, key)
}
conn.lastSeen.Store(now)
}
// IsValidInbound checks if an inbound TCP packet matches a tracked connection
func (t *TCPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16, flags uint8, packetData []byte) bool {
if !isValidFlagCombination(flags) {
return false
}
// Handle new SYN packets
if flags&TCPSyn != 0 && flags&TCPAck == 0 {
key := makeConnKey(dstIP, srcIP, dstPort, srcPort)
t.mutex.Lock()
if _, exists := t.connections[key]; !exists {
// Use preallocated IPs
srcIPCopy := t.ipPool.Get()
dstIPCopy := t.ipPool.Get()
copyIP(srcIPCopy, dstIP)
copyIP(dstIPCopy, srcIP)
conn := &TCPConnTrack{
BaseConnTrack: BaseConnTrack{
SourceIP: srcIPCopy,
DestIP: dstIPCopy,
SourcePort: dstPort,
DestPort: srcPort,
},
State: TCPStateSynReceived,
}
conn.lastSeen.Store(time.Now().UnixNano())
conn.established.Store(false)
t.connections[key] = conn
}
t.mutex.Unlock()
if t.stats != nil {
t.stats.TrackPacket(6, false, uint64(len(packetData)), true, key)
}
return true
}
// Look up existing connection
key := makeConnKey(dstIP, srcIP, dstPort, srcPort)
t.mutex.RLock()
conn, exists := t.connections[key]
t.mutex.RUnlock()
if !exists {
return false
}
if t.stats != nil {
t.stats.TrackPacket(6, false, uint64(len(packetData)), true, key)
}
// Handle RST packets
if flags&TCPRst != 0 {
conn.Lock()
isEstablished := conn.IsEstablished()
if isEstablished || conn.State == TCPStateSynSent || conn.State == TCPStateSynReceived {
conn.State = TCPStateClosed
conn.SetEstablished(false)
conn.Unlock()
return true
}
conn.Unlock()
return false
}
// Update state
conn.Lock()
t.updateState(conn, flags, false)
conn.UpdateLastSeen()
isEstablished := conn.IsEstablished()
isValidState := t.isValidStateForFlags(conn.State, flags)
conn.Unlock()
return isEstablished || isValidState
}
// updateState updates the TCP connection state based on flags
func (t *TCPTracker) updateState(conn *TCPConnTrack, flags uint8, isOutbound bool) {
// Handle RST flag specially - it always causes transition to closed
if flags&TCPRst != 0 {
conn.State = TCPStateClosed
conn.SetEstablished(false)
return
}
switch conn.State {
case TCPStateNew:
if flags&TCPSyn != 0 && flags&TCPAck == 0 {
conn.State = TCPStateSynSent
}
case TCPStateSynSent:
if flags&TCPSyn != 0 && flags&TCPAck != 0 {
if isOutbound {
conn.State = TCPStateSynReceived
} else {
// Simultaneous open
conn.State = TCPStateEstablished
conn.SetEstablished(true)
}
}
case TCPStateSynReceived:
if flags&TCPAck != 0 && flags&TCPSyn == 0 {
conn.State = TCPStateEstablished
conn.SetEstablished(true)
}
case TCPStateEstablished:
if flags&TCPFin != 0 {
if isOutbound {
conn.State = TCPStateFinWait1
} else {
conn.State = TCPStateCloseWait
}
conn.SetEstablished(false)
}
case TCPStateFinWait1:
switch {
case flags&TCPFin != 0 && flags&TCPAck != 0:
// Simultaneous close - both sides sent FIN
conn.State = TCPStateClosing
case flags&TCPFin != 0:
conn.State = TCPStateFinWait2
case flags&TCPAck != 0:
conn.State = TCPStateFinWait2
}
case TCPStateFinWait2:
if flags&TCPFin != 0 {
conn.State = TCPStateTimeWait
}
case TCPStateClosing:
if flags&TCPAck != 0 {
conn.State = TCPStateTimeWait
// Keep established = false from previous state
}
case TCPStateCloseWait:
if flags&TCPFin != 0 {
conn.State = TCPStateLastAck
}
case TCPStateLastAck:
if flags&TCPAck != 0 {
conn.State = TCPStateClosed
}
case TCPStateTimeWait:
// Stay in TIME-WAIT for 2MSL before transitioning to closed
// This is handled by the cleanup routine
}
}
// isValidStateForFlags checks if the TCP flags are valid for the current connection state
func (t *TCPTracker) isValidStateForFlags(state TCPState, flags uint8) bool {
if !isValidFlagCombination(flags) {
return false
}
switch state {
case TCPStateNew:
return flags&TCPSyn != 0 && flags&TCPAck == 0
case TCPStateSynSent:
return flags&TCPSyn != 0 && flags&TCPAck != 0
case TCPStateSynReceived:
return flags&TCPAck != 0
case TCPStateEstablished:
if flags&TCPRst != 0 {
return true
}
return flags&TCPAck != 0
case TCPStateFinWait1:
return flags&TCPFin != 0 || flags&TCPAck != 0
case TCPStateFinWait2:
return flags&TCPFin != 0 || flags&TCPAck != 0
case TCPStateClosing:
// In CLOSING state, we should accept the final ACK
return flags&TCPAck != 0
case TCPStateTimeWait:
// In TIME_WAIT, we might see retransmissions
return flags&TCPAck != 0
case TCPStateCloseWait:
return flags&TCPFin != 0 || flags&TCPAck != 0
case TCPStateLastAck:
return flags&TCPAck != 0
}
return false
}
func (t *TCPTracker) cleanupRoutine() {
for {
select {
case <-t.cleanupTicker.C:
t.cleanup()
case <-t.done:
return
}
}
}
func (t *TCPTracker) cleanup() {
t.mutex.Lock()
defer t.mutex.Unlock()
for key, conn := range t.connections {
var timeout time.Duration
switch {
case conn.State == TCPStateTimeWait:
timeout = TimeWaitTimeout
case conn.IsEstablished():
timeout = t.timeout
default:
timeout = TCPHandshakeTimeout
}
lastSeen := conn.GetLastSeen()
if time.Since(lastSeen) > timeout {
// Return IPs to pool
t.ipPool.Put(conn.SourceIP)
t.ipPool.Put(conn.DestIP)
delete(t.connections, key)
}
}
}
// Close stops the cleanup routine and releases resources
func (t *TCPTracker) Close() {
t.cleanupTicker.Stop()
close(t.done)
// Clean up all remaining IPs
t.mutex.Lock()
for _, conn := range t.connections {
t.ipPool.Put(conn.SourceIP)
t.ipPool.Put(conn.DestIP)
}
t.connections = nil
t.mutex.Unlock()
}
func isValidFlagCombination(flags uint8) bool {
// Invalid: SYN+FIN
if flags&TCPSyn != 0 && flags&TCPFin != 0 {
return false
}
// Invalid: RST with SYN or FIN
if flags&TCPRst != 0 && (flags&TCPSyn != 0 || flags&TCPFin != 0) {
return false
}
return true
}

View File

@@ -1,311 +0,0 @@
package conntrack
import (
"net"
"testing"
"time"
"github.com/stretchr/testify/require"
)
func TestTCPStateMachine(t *testing.T) {
tracker := NewTCPTracker(DefaultTCPTimeout, nil)
defer tracker.Close()
srcIP := net.ParseIP("100.64.0.1")
dstIP := net.ParseIP("100.64.0.2")
srcPort := uint16(12345)
dstPort := uint16(80)
t.Run("Security Tests", func(t *testing.T) {
tests := []struct {
name string
flags uint8
wantDrop bool
desc string
}{
{
name: "Block unsolicited SYN-ACK",
flags: TCPSyn | TCPAck,
wantDrop: true,
desc: "Should block SYN-ACK without prior SYN",
},
{
name: "Block invalid SYN-FIN",
flags: TCPSyn | TCPFin,
wantDrop: true,
desc: "Should block invalid SYN-FIN combination",
},
{
name: "Block unsolicited RST",
flags: TCPRst,
wantDrop: true,
desc: "Should block RST without connection",
},
{
name: "Block unsolicited ACK",
flags: TCPAck,
wantDrop: true,
desc: "Should block ACK without connection",
},
{
name: "Block data without connection",
flags: TCPAck | TCPPush,
wantDrop: true,
desc: "Should block data without established connection",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
isValid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, tt.flags, nil)
require.Equal(t, !tt.wantDrop, isValid, tt.desc)
})
}
})
t.Run("Connection Flow Tests", func(t *testing.T) {
tests := []struct {
name string
test func(*testing.T)
desc string
}{
{
name: "Normal Handshake",
test: func(t *testing.T) {
t.Helper()
// Send initial SYN
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn, nil)
// Receive SYN-ACK
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck, nil)
require.True(t, valid, "SYN-ACK should be allowed")
// Send ACK
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, nil)
// Test data transfer
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPPush|TCPAck, nil)
require.True(t, valid, "Data should be allowed after handshake")
},
},
{
name: "Normal Close",
test: func(t *testing.T) {
t.Helper()
// First establish connection
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
// Send FIN
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, nil)
// Receive ACK for FIN
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, nil)
require.True(t, valid, "ACK for FIN should be allowed")
// Receive FIN from other side
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, nil)
require.True(t, valid, "FIN should be allowed")
// Send final ACK
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, nil)
},
},
{
name: "RST During Connection",
test: func(t *testing.T) {
t.Helper()
// First establish connection
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
// Receive RST
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPRst, nil)
require.True(t, valid, "RST should be allowed for established connection")
// Verify connection is closed
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPPush|TCPAck, nil)
t.Helper()
require.False(t, valid, "Data should be blocked after RST")
},
},
{
name: "Simultaneous Close",
test: func(t *testing.T) {
t.Helper()
// First establish connection
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
// Both sides send FIN+ACK
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, nil)
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, nil)
require.True(t, valid, "Simultaneous FIN should be allowed")
// Both sides send final ACK
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, nil)
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, nil)
require.True(t, valid, "Final ACKs should be allowed")
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Helper()
tracker = NewTCPTracker(DefaultTCPTimeout, nil)
tt.test(t)
})
}
})
}
func TestRSTHandling(t *testing.T) {
tracker := NewTCPTracker(DefaultTCPTimeout, nil)
defer tracker.Close()
srcIP := net.ParseIP("100.64.0.1")
dstIP := net.ParseIP("100.64.0.2")
srcPort := uint16(12345)
dstPort := uint16(80)
tests := []struct {
name string
setupState func()
sendRST func()
wantValid bool
desc string
}{
{
name: "RST in established",
setupState: func() {
// Establish connection first
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn, nil)
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck, nil)
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, nil)
},
sendRST: func() {
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPRst, nil)
},
wantValid: true,
desc: "Should accept RST for established connection",
},
{
name: "RST without connection",
setupState: func() {},
sendRST: func() {
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPRst, nil)
},
wantValid: false,
desc: "Should reject RST without connection",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
tt.setupState()
tt.sendRST()
// Verify connection state is as expected
key := makeConnKey(srcIP, dstIP, srcPort, dstPort)
conn := tracker.connections[key]
if tt.wantValid {
require.NotNil(t, conn)
require.Equal(t, TCPStateClosed, conn.State)
require.False(t, conn.IsEstablished())
}
})
}
}
// Helper to establish a TCP connection
func establishConnection(t *testing.T, tracker *TCPTracker, srcIP, dstIP net.IP, srcPort, dstPort uint16) {
t.Helper()
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn, nil)
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck, nil)
require.True(t, valid, "SYN-ACK should be allowed")
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, nil)
}
func BenchmarkTCPTracker(b *testing.B) {
b.Run("TrackOutbound", func(b *testing.B) {
tracker := NewTCPTracker(DefaultTCPTimeout, nil)
defer tracker.Close()
srcIP := net.ParseIP("192.168.1.1")
dstIP := net.ParseIP("192.168.1.2")
b.ResetTimer()
for i := 0; i < b.N; i++ {
tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 80, TCPSyn, nil)
}
})
b.Run("IsValidInbound", func(b *testing.B) {
tracker := NewTCPTracker(DefaultTCPTimeout, nil)
defer tracker.Close()
srcIP := net.ParseIP("192.168.1.1")
dstIP := net.ParseIP("192.168.1.2")
// Pre-populate some connections
for i := 0; i < 1000; i++ {
tracker.TrackOutbound(srcIP, dstIP, uint16(i), 80, TCPSyn, nil)
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
tracker.IsValidInbound(dstIP, srcIP, 80, uint16(i%1000), TCPAck, nil)
}
})
b.Run("ConcurrentAccess", func(b *testing.B) {
tracker := NewTCPTracker(DefaultTCPTimeout, nil)
defer tracker.Close()
srcIP := net.ParseIP("192.168.1.1")
dstIP := net.ParseIP("192.168.1.2")
b.RunParallel(func(pb *testing.PB) {
i := 0
for pb.Next() {
if i%2 == 0 {
tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 80, TCPSyn, nil)
} else {
tracker.IsValidInbound(dstIP, srcIP, 80, uint16(i%65535), TCPAck, nil)
}
i++
}
})
})
}
// Benchmark connection cleanup
func BenchmarkCleanup(b *testing.B) {
b.Run("TCPCleanup", func(b *testing.B) {
tracker := NewTCPTracker(100*time.Millisecond, nil) // Short timeout for testing
defer tracker.Close()
// Pre-populate with expired connections
srcIP := net.ParseIP("192.168.1.1")
dstIP := net.ParseIP("192.168.1.2")
for i := 0; i < 10000; i++ {
tracker.TrackOutbound(srcIP, dstIP, uint16(i), 80, TCPSyn, nil)
}
// Wait for connections to expire
time.Sleep(200 * time.Millisecond)
b.ResetTimer()
for i := 0; i < b.N; i++ {
tracker.cleanup()
}
})
}

View File

@@ -1,173 +0,0 @@
package conntrack
import (
"net"
"sync"
"time"
fw "github.com/netbirdio/netbird/client/firewall/manager"
)
const (
// DefaultUDPTimeout is the default timeout for UDP connections
DefaultUDPTimeout = 30 * time.Second
// UDPCleanupInterval is how often we check for stale connections
UDPCleanupInterval = 15 * time.Second
)
// UDPConnTrack represents a UDP connection state
type UDPConnTrack struct {
BaseConnTrack
}
// UDPTracker manages UDP connection states
type UDPTracker struct {
connections map[ConnKey]*UDPConnTrack
timeout time.Duration
cleanupTicker *time.Ticker
mutex sync.RWMutex
done chan struct{}
ipPool *PreallocatedIPs
stats *Stats
}
// NewUDPTracker creates a new UDP connection tracker
func NewUDPTracker(timeout time.Duration, stats *Stats) *UDPTracker {
if timeout == 0 {
timeout = DefaultUDPTimeout
}
tracker := &UDPTracker{
connections: make(map[ConnKey]*UDPConnTrack),
timeout: timeout,
cleanupTicker: time.NewTicker(UDPCleanupInterval),
done: make(chan struct{}),
ipPool: NewPreallocatedIPs(),
stats: stats,
}
go tracker.cleanupRoutine()
return tracker
}
// TrackOutbound records an outbound UDP connection
func (t *UDPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16, packetData []byte) {
key := makeConnKey(srcIP, dstIP, srcPort, dstPort)
now := time.Now().UnixNano()
t.mutex.Lock()
conn, exists := t.connections[key]
if !exists {
srcIPCopy := t.ipPool.Get()
dstIPCopy := t.ipPool.Get()
copyIP(srcIPCopy, srcIP)
copyIP(dstIPCopy, dstIP)
conn = &UDPConnTrack{
BaseConnTrack: BaseConnTrack{
SourceIP: srcIPCopy,
DestIP: dstIPCopy,
SourcePort: srcPort,
DestPort: dstPort,
},
}
conn.lastSeen.Store(now)
conn.established.Store(true)
t.connections[key] = conn
if t.stats != nil {
t.stats.TrackNewConnection(17, srcIP, dstIP, srcPort, dstPort, fw.DirectionOutbound)
}
}
t.mutex.Unlock()
if t.stats != nil {
t.stats.TrackPacket(17, false, uint64(len(packetData)), false, key)
}
conn.lastSeen.Store(now)
}
// IsValidInbound checks if an inbound packet matches a tracked connection
func (t *UDPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16, packetData []byte) bool {
key := makeConnKey(dstIP, srcIP, dstPort, srcPort)
t.mutex.RLock()
conn, exists := t.connections[key]
t.mutex.RUnlock()
if !exists {
return false
}
if conn.timeoutExceeded(t.timeout) {
return false
}
if t.stats != nil {
t.stats.TrackPacket(17, false, uint64(len(packetData)), true, key)
}
return conn.IsEstablished() &&
ValidateIPs(MakeIPAddr(srcIP), conn.DestIP) &&
ValidateIPs(MakeIPAddr(dstIP), conn.SourceIP) &&
conn.DestPort == srcPort &&
conn.SourcePort == dstPort
}
// cleanupRoutine periodically removes stale connections
func (t *UDPTracker) cleanupRoutine() {
for {
select {
case <-t.cleanupTicker.C:
t.cleanup()
case <-t.done:
return
}
}
}
func (t *UDPTracker) cleanup() {
t.mutex.Lock()
defer t.mutex.Unlock()
for key, conn := range t.connections {
if conn.timeoutExceeded(t.timeout) {
t.ipPool.Put(conn.SourceIP)
t.ipPool.Put(conn.DestIP)
delete(t.connections, key)
}
}
}
// Close stops the cleanup routine and releases resources
func (t *UDPTracker) Close() {
t.cleanupTicker.Stop()
close(t.done)
t.mutex.Lock()
for _, conn := range t.connections {
t.ipPool.Put(conn.SourceIP)
t.ipPool.Put(conn.DestIP)
}
t.connections = nil
t.mutex.Unlock()
}
// GetConnection safely retrieves a connection state
func (t *UDPTracker) GetConnection(srcIP net.IP, srcPort uint16, dstIP net.IP, dstPort uint16) (*UDPConnTrack, bool) {
t.mutex.RLock()
defer t.mutex.RUnlock()
key := makeConnKey(srcIP, dstIP, srcPort, dstPort)
conn, exists := t.connections[key]
if !exists {
return nil, false
}
return conn, true
}
// Timeout returns the configured timeout duration for the tracker
func (t *UDPTracker) Timeout() time.Duration {
return t.timeout
}

View File

@@ -1,243 +0,0 @@
package conntrack
import (
"net"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestNewUDPTracker(t *testing.T) {
tests := []struct {
name string
timeout time.Duration
wantTimeout time.Duration
}{
{
name: "with custom timeout",
timeout: 1 * time.Minute,
wantTimeout: 1 * time.Minute,
},
{
name: "with zero timeout uses default",
timeout: 0,
wantTimeout: DefaultUDPTimeout,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
tracker := NewUDPTracker(tt.timeout, nil)
assert.NotNil(t, tracker)
assert.Equal(t, tt.wantTimeout, tracker.timeout)
assert.NotNil(t, tracker.connections)
assert.NotNil(t, tracker.cleanupTicker)
assert.NotNil(t, tracker.done)
})
}
}
func TestUDPTracker_TrackOutbound(t *testing.T) {
tracker := NewUDPTracker(DefaultUDPTimeout, nil)
defer tracker.Close()
srcIP := net.ParseIP("192.168.1.2")
dstIP := net.ParseIP("192.168.1.3")
srcPort := uint16(12345)
dstPort := uint16(53)
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, nil)
// Verify connection was tracked
key := makeConnKey(srcIP, dstIP, srcPort, dstPort)
conn, exists := tracker.connections[key]
require.True(t, exists)
assert.True(t, conn.SourceIP.Equal(srcIP))
assert.True(t, conn.DestIP.Equal(dstIP))
assert.Equal(t, srcPort, conn.SourcePort)
assert.Equal(t, dstPort, conn.DestPort)
assert.True(t, conn.IsEstablished())
assert.WithinDuration(t, time.Now(), conn.GetLastSeen(), 1*time.Second)
}
func TestUDPTracker_IsValidInbound(t *testing.T) {
tracker := NewUDPTracker(1*time.Second, nil)
defer tracker.Close()
srcIP := net.ParseIP("192.168.1.2")
dstIP := net.ParseIP("192.168.1.3")
srcPort := uint16(12345)
dstPort := uint16(53)
// Track outbound connection
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, nil)
tests := []struct {
name string
srcIP net.IP
dstIP net.IP
srcPort uint16
dstPort uint16
sleep time.Duration
want bool
}{
{
name: "valid inbound response",
srcIP: dstIP, // Original destination is now source
dstIP: srcIP, // Original source is now destination
srcPort: dstPort, // Original destination port is now source
dstPort: srcPort, // Original source port is now destination
sleep: 0,
want: true,
},
{
name: "invalid source IP",
srcIP: net.ParseIP("192.168.1.4"),
dstIP: srcIP,
srcPort: dstPort,
dstPort: srcPort,
sleep: 0,
want: false,
},
{
name: "invalid destination IP",
srcIP: dstIP,
dstIP: net.ParseIP("192.168.1.4"),
srcPort: dstPort,
dstPort: srcPort,
sleep: 0,
want: false,
},
{
name: "invalid source port",
srcIP: dstIP,
dstIP: srcIP,
srcPort: 54321,
dstPort: srcPort,
sleep: 0,
want: false,
},
{
name: "invalid destination port",
srcIP: dstIP,
dstIP: srcIP,
srcPort: dstPort,
dstPort: 54321,
sleep: 0,
want: false,
},
{
name: "expired connection",
srcIP: dstIP,
dstIP: srcIP,
srcPort: dstPort,
dstPort: srcPort,
sleep: 2 * time.Second, // Longer than tracker timeout
want: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if tt.sleep > 0 {
time.Sleep(tt.sleep)
}
got := tracker.IsValidInbound(tt.srcIP, tt.dstIP, tt.srcPort, tt.dstPort, nil)
assert.Equal(t, tt.want, got)
})
}
}
func TestUDPTracker_Cleanup(t *testing.T) {
// Use shorter intervals for testing
timeout := 50 * time.Millisecond
cleanupInterval := 25 * time.Millisecond
// Create tracker with custom cleanup interval
tracker := &UDPTracker{
connections: make(map[ConnKey]*UDPConnTrack),
timeout: timeout,
cleanupTicker: time.NewTicker(cleanupInterval),
done: make(chan struct{}),
ipPool: NewPreallocatedIPs(),
}
// Start cleanup routine
go tracker.cleanupRoutine()
// Add some connections
connections := []struct {
srcIP net.IP
dstIP net.IP
srcPort uint16
dstPort uint16
}{
{
srcIP: net.ParseIP("192.168.1.2"),
dstIP: net.ParseIP("192.168.1.3"),
srcPort: 12345,
dstPort: 53,
},
{
srcIP: net.ParseIP("192.168.1.4"),
dstIP: net.ParseIP("192.168.1.5"),
srcPort: 12346,
dstPort: 53,
},
}
for _, conn := range connections {
tracker.TrackOutbound(conn.srcIP, conn.dstIP, conn.srcPort, conn.dstPort, nil)
}
// Verify initial connections
assert.Len(t, tracker.connections, 2)
// Wait for connection timeout and cleanup interval
time.Sleep(timeout + 2*cleanupInterval)
tracker.mutex.RLock()
connCount := len(tracker.connections)
tracker.mutex.RUnlock()
// Verify connections were cleaned up
assert.Equal(t, 0, connCount, "Expected all connections to be cleaned up")
// Properly close the tracker
tracker.Close()
}
func BenchmarkUDPTracker(b *testing.B) {
b.Run("TrackOutbound", func(b *testing.B) {
tracker := NewUDPTracker(DefaultUDPTimeout, nil)
defer tracker.Close()
srcIP := net.ParseIP("192.168.1.1")
dstIP := net.ParseIP("192.168.1.2")
b.ResetTimer()
for i := 0; i < b.N; i++ {
tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 80, nil)
}
})
b.Run("IsValidInbound", func(b *testing.B) {
tracker := NewUDPTracker(DefaultUDPTimeout, nil)
defer tracker.Close()
srcIP := net.ParseIP("192.168.1.1")
dstIP := net.ParseIP("192.168.1.2")
// Pre-populate some connections
for i := 0; i < 1000; i++ {
tracker.TrackOutbound(srcIP, dstIP, uint16(i), 80, nil)
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
tracker.IsValidInbound(dstIP, srcIP, 80, uint16(i%1000), nil)
}
})
}

View File

@@ -3,9 +3,6 @@ package uspfilter
import ( import (
"fmt" "fmt"
"net" "net"
"net/netip"
"os"
"strconv"
"sync" "sync"
"github.com/google/gopacket" "github.com/google/gopacket"
@@ -14,26 +11,18 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
firewall "github.com/netbirdio/netbird/client/firewall/manager" firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack" "github.com/netbirdio/netbird/iface"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/internal/statemanager"
) )
const layerTypeAll = 0 const layerTypeAll = 0
const (
EnvDisableConntrack = "NB_DISABLE_CONNTRACK"
EnvEnableStats = "NB_ENABLE_CONNTRACK_STATS"
)
var ( var (
errRouteNotSupported = fmt.Errorf("route not supported with userspace firewall") errRouteNotSupported = fmt.Errorf("route not supported with userspace firewall")
) )
// IFaceMapper defines subset methods of interface required for manager // IFaceMapper defines subset methods of interface required for manager
type IFaceMapper interface { type IFaceMapper interface {
SetFilter(device.PacketFilter) error SetFilter(iface.PacketFilter) error
Address() iface.WGAddress Address() iface.WGAddress
} }
@@ -50,14 +39,6 @@ type Manager struct {
nativeFirewall firewall.Manager nativeFirewall firewall.Manager
mutex sync.RWMutex mutex sync.RWMutex
stateful bool
udpTracker *conntrack.UDPTracker
icmpTracker *conntrack.ICMPTracker
tcpTracker *conntrack.TCPTracker
statsEnabled bool
stats *conntrack.Stats
} }
// decoder for packages // decoder for packages
@@ -89,9 +70,6 @@ func CreateWithNativeFirewall(iface IFaceMapper, nativeFirewall firewall.Manager
} }
func create(iface IFaceMapper) (*Manager, error) { func create(iface IFaceMapper) (*Manager, error) {
disableConntrack, _ := strconv.ParseBool(os.Getenv(EnvDisableConntrack))
enableStats, _ := strconv.ParseBool(os.Getenv(EnvEnableStats))
m := &Manager{ m := &Manager{
decoders: sync.Pool{ decoders: sync.Pool{
New: func() any { New: func() any {
@@ -109,22 +87,6 @@ func create(iface IFaceMapper) (*Manager, error) {
outgoingRules: make(map[string]RuleSet), outgoingRules: make(map[string]RuleSet),
incomingRules: make(map[string]RuleSet), incomingRules: make(map[string]RuleSet),
wgIface: iface, wgIface: iface,
stateful: !disableConntrack,
statsEnabled: enableStats,
}
if enableStats {
m.stats = conntrack.NewStats()
log.Info("connection tracking statistics enabled")
}
// Only initialize trackers if stateful mode is enabled
if disableConntrack {
log.Info("conntrack is disabled")
} else {
m.udpTracker = conntrack.NewUDPTracker(conntrack.DefaultUDPTimeout, m.stats)
m.icmpTracker = conntrack.NewICMPTracker(conntrack.DefaultICMPTimeout, m.stats)
m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout, m.stats)
} }
if err := iface.SetFilter(m); err != nil { if err := iface.SetFilter(m); err != nil {
@@ -133,10 +95,6 @@ func create(iface IFaceMapper) (*Manager, error) {
return m, nil return m, nil
} }
func (m *Manager) Init(*statemanager.Manager) error {
return nil
}
func (m *Manager) IsServerRouteSupported() bool { func (m *Manager) IsServerRouteSupported() bool {
if m.nativeFirewall == nil { if m.nativeFirewall == nil {
return false return false
@@ -145,26 +103,26 @@ func (m *Manager) IsServerRouteSupported() bool {
} }
} }
func (m *Manager) AddNatRule(pair firewall.RouterPair) error { func (m *Manager) InsertRoutingRules(pair firewall.RouterPair) error {
if m.nativeFirewall == nil { if m.nativeFirewall == nil {
return errRouteNotSupported return errRouteNotSupported
} }
return m.nativeFirewall.AddNatRule(pair) return m.nativeFirewall.InsertRoutingRules(pair)
} }
// RemoveNatRule removes a routing firewall rule // RemoveRoutingRules removes a routing firewall rule
func (m *Manager) RemoveNatRule(pair firewall.RouterPair) error { func (m *Manager) RemoveRoutingRules(pair firewall.RouterPair) error {
if m.nativeFirewall == nil { if m.nativeFirewall == nil {
return errRouteNotSupported return errRouteNotSupported
} }
return m.nativeFirewall.RemoveNatRule(pair) return m.nativeFirewall.RemoveRoutingRules(pair)
} }
// AddPeerFiltering rule to the firewall // AddFiltering rule to the firewall
// //
// If comment argument is empty firewall manager should set // If comment argument is empty firewall manager should set
// rule ID as comment for the rule // rule ID as comment for the rule
func (m *Manager) AddPeerFiltering( func (m *Manager) AddFiltering(
ip net.IP, ip net.IP,
proto firewall.Protocol, proto firewall.Protocol,
sPort *firewall.Port, sPort *firewall.Port,
@@ -230,22 +188,8 @@ func (m *Manager) AddPeerFiltering(
return []firewall.Rule{&r}, nil return []firewall.Rule{&r}, nil
} }
func (m *Manager) AddRouteFiltering(sources []netip.Prefix, destination netip.Prefix, proto firewall.Protocol, sPort *firewall.Port, dPort *firewall.Port, action firewall.Action) (firewall.Rule, error) { // DeleteRule from the firewall by rule definition
if m.nativeFirewall == nil { func (m *Manager) DeleteRule(rule firewall.Rule) error {
return nil, errRouteNotSupported
}
return m.nativeFirewall.AddRouteFiltering(sources, destination, proto, sPort, dPort, action)
}
func (m *Manager) DeleteRouteRule(rule firewall.Rule) error {
if m.nativeFirewall == nil {
return errRouteNotSupported
}
return m.nativeFirewall.DeleteRouteRule(rule)
}
// DeletePeerRule from the firewall by rule definition
func (m *Manager) DeletePeerRule(rule firewall.Rule) error {
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
@@ -271,225 +215,83 @@ func (m *Manager) DeletePeerRule(rule firewall.Rule) error {
return nil return nil
} }
// SetLegacyManagement doesn't need to be implemented for this manager
func (m *Manager) SetLegacyManagement(isLegacy bool) error {
if m.nativeFirewall == nil {
return nil
}
return m.nativeFirewall.SetLegacyManagement(isLegacy)
}
// Flush doesn't need to be implemented for this manager // Flush doesn't need to be implemented for this manager
func (m *Manager) Flush() error { return nil } func (m *Manager) Flush() error { return nil }
// DropOutgoing filter outgoing packets // DropOutgoing filter outgoing packets
func (m *Manager) DropOutgoing(packetData []byte) bool { func (m *Manager) DropOutgoing(packetData []byte) bool {
return m.processOutgoingHooks(packetData) return m.dropFilter(packetData, m.outgoingRules, false)
} }
// DropIncoming filter incoming packets // DropIncoming filter incoming packets
func (m *Manager) DropIncoming(packetData []byte) bool { func (m *Manager) DropIncoming(packetData []byte) bool {
return m.dropFilter(packetData, m.incomingRules) return m.dropFilter(packetData, m.incomingRules, true)
} }
// processOutgoingHooks processes UDP hooks for outgoing packets and tracks TCP/UDP/ICMP // dropFilter implements same logic for booth direction of the traffic
func (m *Manager) processOutgoingHooks(packetData []byte) bool { func (m *Manager) dropFilter(packetData []byte, rules map[string]RuleSet, isIncomingPacket bool) bool {
m.mutex.RLock() m.mutex.RLock()
defer m.mutex.RUnlock() defer m.mutex.RUnlock()
d := m.decoders.Get().(*decoder) d := m.decoders.Get().(*decoder)
defer m.decoders.Put(d) defer m.decoders.Put(d)
if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil {
return false
}
if len(d.decoded) < 2 {
return false
}
srcIP, dstIP := m.extractIPs(d)
if srcIP == nil {
return false
}
// Always process UDP hooks
if d.decoded[1] == layers.LayerTypeUDP {
// Track UDP state only if enabled
if m.stateful {
m.udpTracker.TrackOutbound(srcIP, dstIP,
uint16(d.udp.SrcPort),
uint16(d.udp.DstPort),
packetData)
}
return m.checkUDPHooks(d, dstIP, packetData)
}
// Track other protocols only if stateful mode is enabled
if m.stateful {
switch d.decoded[1] {
case layers.LayerTypeTCP:
m.tcpTracker.TrackOutbound(srcIP, dstIP,
uint16(d.tcp.SrcPort),
uint16(d.tcp.DstPort),
getTCPFlags(&d.tcp),
packetData)
case layers.LayerTypeICMPv4:
m.icmpTracker.TrackOutbound(srcIP, dstIP,
d.icmp4.Id,
d.icmp4.Seq,
packetData)
}
}
return false
}
func (m *Manager) extractIPs(d *decoder) (srcIP, dstIP net.IP) {
switch d.decoded[0] {
case layers.LayerTypeIPv4:
return d.ip4.SrcIP, d.ip4.DstIP
case layers.LayerTypeIPv6:
return d.ip6.SrcIP, d.ip6.DstIP
default:
return nil, nil
}
}
func getTCPFlags(tcp *layers.TCP) uint8 {
var flags uint8
if tcp.SYN {
flags |= conntrack.TCPSyn
}
if tcp.ACK {
flags |= conntrack.TCPAck
}
if tcp.FIN {
flags |= conntrack.TCPFin
}
if tcp.RST {
flags |= conntrack.TCPRst
}
if tcp.PSH {
flags |= conntrack.TCPPush
}
if tcp.URG {
flags |= conntrack.TCPUrg
}
return flags
}
func (m *Manager) checkUDPHooks(d *decoder, dstIP net.IP, packetData []byte) bool {
for _, ipKey := range []string{dstIP.String(), "0.0.0.0", "::"} {
if rules, exists := m.outgoingRules[ipKey]; exists {
for _, rule := range rules {
if rule.udpHook != nil && (rule.dPort == 0 || rule.dPort == uint16(d.udp.DstPort)) {
return rule.udpHook(packetData)
}
}
}
}
return false
}
// dropFilter implements filtering logic for incoming packets
func (m *Manager) dropFilter(packetData []byte, rules map[string]RuleSet) bool {
m.mutex.RLock()
defer m.mutex.RUnlock()
d := m.decoders.Get().(*decoder)
defer m.decoders.Put(d)
if !m.isValidPacket(d, packetData) {
return true
}
srcIP, dstIP := m.extractIPs(d)
if srcIP == nil {
log.Errorf("unknown layer: %v", d.decoded[0])
return true
}
if !m.isWireguardTraffic(srcIP, dstIP) {
return false
}
// Check connection state only if enabled
if m.stateful && m.isValidTrackedConnection(d, srcIP, dstIP, packetData) {
return false
}
return m.applyRules(srcIP, packetData, rules, d)
}
func (m *Manager) isValidPacket(d *decoder, packetData []byte) bool {
if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil { if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil {
log.Tracef("couldn't decode layer, err: %s", err) log.Tracef("couldn't decode layer, err: %s", err)
return false return true
} }
if len(d.decoded) < 2 { if len(d.decoded) < 2 {
log.Tracef("not enough levels in network packet") log.Tracef("not enough levels in network packet")
return false
}
return true return true
} }
func (m *Manager) isWireguardTraffic(srcIP, dstIP net.IP) bool { ipLayer := d.decoded[0]
return m.wgNetwork.Contains(srcIP) && m.wgNetwork.Contains(dstIP)
}
func (m *Manager) isValidTrackedConnection(d *decoder, srcIP, dstIP net.IP, packetData []byte) bool {
switch d.decoded[1] {
case layers.LayerTypeTCP:
return m.tcpTracker.IsValidInbound(
srcIP,
dstIP,
uint16(d.tcp.SrcPort),
uint16(d.tcp.DstPort),
getTCPFlags(&d.tcp),
packetData,
)
case layers.LayerTypeUDP:
return m.udpTracker.IsValidInbound(
srcIP,
dstIP,
uint16(d.udp.SrcPort),
uint16(d.udp.DstPort),
packetData,
)
case layers.LayerTypeICMPv4:
return m.icmpTracker.IsValidInbound(
srcIP,
dstIP,
d.icmp4.Id,
d.icmp4.Seq,
d.icmp4.TypeCode.Type(),
packetData,
)
// TODO: ICMPv6
}
switch ipLayer {
case layers.LayerTypeIPv4:
if !m.wgNetwork.Contains(d.ip4.SrcIP) || !m.wgNetwork.Contains(d.ip4.DstIP) {
return false return false
} }
case layers.LayerTypeIPv6:
if !m.wgNetwork.Contains(d.ip6.SrcIP) || !m.wgNetwork.Contains(d.ip6.DstIP) {
return false
}
default:
log.Errorf("unknown layer: %v", d.decoded[0])
return true
}
func (m *Manager) applyRules(srcIP net.IP, packetData []byte, rules map[string]RuleSet, d *decoder) bool { var ip net.IP
if filter, ok := validateRule(srcIP, packetData, rules[srcIP.String()], d); ok { switch ipLayer {
case layers.LayerTypeIPv4:
if isIncomingPacket {
ip = d.ip4.SrcIP
} else {
ip = d.ip4.DstIP
}
case layers.LayerTypeIPv6:
if isIncomingPacket {
ip = d.ip6.SrcIP
} else {
ip = d.ip6.DstIP
}
}
filter, ok := validateRule(ip, packetData, rules[ip.String()], d)
if ok {
return filter
}
filter, ok = validateRule(ip, packetData, rules["0.0.0.0"], d)
if ok {
return filter
}
filter, ok = validateRule(ip, packetData, rules["::"], d)
if ok {
return filter return filter
} }
if filter, ok := validateRule(srcIP, packetData, rules["0.0.0.0"], d); ok { // default policy is DROP ALL
return filter
}
if filter, ok := validateRule(srcIP, packetData, rules["::"], d); ok {
return filter
}
// Default policy: DROP ALL
return true return true
} }
@@ -535,6 +337,7 @@ func validateRule(ip net.IP, packetData []byte, rules map[string]Rule, d *decode
if rule.dPort != 0 && rule.dPort == uint16(d.udp.DstPort) { if rule.dPort != 0 && rule.dPort == uint16(d.udp.DstPort) {
return rule.drop, true return rule.drop, true
} }
return rule.drop, true
case layers.LayerTypeICMPv4, layers.LayerTypeICMPv6: case layers.LayerTypeICMPv4, layers.LayerTypeICMPv6:
return rule.drop, true return rule.drop, true
} }
@@ -593,7 +396,7 @@ func (m *Manager) RemovePacketHook(hookID string) error {
for _, r := range arr { for _, r := range arr {
if r.id == hookID { if r.id == hookID {
rule := r rule := r
return m.DeletePeerRule(&rule) return m.DeleteRule(&rule)
} }
} }
} }
@@ -601,17 +404,9 @@ func (m *Manager) RemovePacketHook(hookID string) error {
for _, r := range arr { for _, r := range arr {
if r.id == hookID { if r.id == hookID {
rule := r rule := r
return m.DeletePeerRule(&rule) return m.DeleteRule(&rule)
} }
} }
} }
return fmt.Errorf("hook with given id not found") return fmt.Errorf("hook with given id not found")
} }
// CollectStats returns connection tracking statistics
func (m *Manager) CollectStats() []*firewall.FlowStats {
if m.stats == nil {
return nil
}
return m.stats.GetFlowSnapshot()
}

File diff suppressed because it is too large Load Diff

View File

@@ -3,7 +3,6 @@ package uspfilter
import ( import (
"fmt" "fmt"
"net" "net"
"sync"
"testing" "testing"
"time" "time"
@@ -12,17 +11,15 @@ import (
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
fw "github.com/netbirdio/netbird/client/firewall/manager" fw "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack" "github.com/netbirdio/netbird/iface"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/device"
) )
type IFaceMock struct { type IFaceMock struct {
SetFilterFunc func(device.PacketFilter) error SetFilterFunc func(iface.PacketFilter) error
AddressFunc func() iface.WGAddress AddressFunc func() iface.WGAddress
} }
func (i *IFaceMock) SetFilter(iface device.PacketFilter) error { func (i *IFaceMock) SetFilter(iface iface.PacketFilter) error {
if i.SetFilterFunc == nil { if i.SetFilterFunc == nil {
return fmt.Errorf("not implemented") return fmt.Errorf("not implemented")
} }
@@ -38,7 +35,7 @@ func (i *IFaceMock) Address() iface.WGAddress {
func TestManagerCreate(t *testing.T) { func TestManagerCreate(t *testing.T) {
ifaceMock := &IFaceMock{ ifaceMock := &IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil }, SetFilterFunc: func(iface.PacketFilter) error { return nil },
} }
m, err := Create(ifaceMock) m, err := Create(ifaceMock)
@@ -52,10 +49,10 @@ func TestManagerCreate(t *testing.T) {
} }
} }
func TestManagerAddPeerFiltering(t *testing.T) { func TestManagerAddFiltering(t *testing.T) {
isSetFilterCalled := false isSetFilterCalled := false
ifaceMock := &IFaceMock{ ifaceMock := &IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { SetFilterFunc: func(iface.PacketFilter) error {
isSetFilterCalled = true isSetFilterCalled = true
return nil return nil
}, },
@@ -74,7 +71,7 @@ func TestManagerAddPeerFiltering(t *testing.T) {
action := fw.ActionDrop action := fw.ActionDrop
comment := "Test rule" comment := "Test rule"
rule, err := m.AddPeerFiltering(ip, proto, nil, port, direction, action, "", comment) rule, err := m.AddFiltering(ip, proto, nil, port, direction, action, "", comment)
if err != nil { if err != nil {
t.Errorf("failed to add filtering: %v", err) t.Errorf("failed to add filtering: %v", err)
return return
@@ -93,7 +90,7 @@ func TestManagerAddPeerFiltering(t *testing.T) {
func TestManagerDeleteRule(t *testing.T) { func TestManagerDeleteRule(t *testing.T) {
ifaceMock := &IFaceMock{ ifaceMock := &IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil }, SetFilterFunc: func(iface.PacketFilter) error { return nil },
} }
m, err := Create(ifaceMock) m, err := Create(ifaceMock)
@@ -109,7 +106,7 @@ func TestManagerDeleteRule(t *testing.T) {
action := fw.ActionDrop action := fw.ActionDrop
comment := "Test rule" comment := "Test rule"
rule, err := m.AddPeerFiltering(ip, proto, nil, port, direction, action, "", comment) rule, err := m.AddFiltering(ip, proto, nil, port, direction, action, "", comment)
if err != nil { if err != nil {
t.Errorf("failed to add filtering: %v", err) t.Errorf("failed to add filtering: %v", err)
return return
@@ -122,14 +119,14 @@ func TestManagerDeleteRule(t *testing.T) {
action = fw.ActionDrop action = fw.ActionDrop
comment = "Test rule 2" comment = "Test rule 2"
rule2, err := m.AddPeerFiltering(ip, proto, nil, port, direction, action, "", comment) rule2, err := m.AddFiltering(ip, proto, nil, port, direction, action, "", comment)
if err != nil { if err != nil {
t.Errorf("failed to add filtering: %v", err) t.Errorf("failed to add filtering: %v", err)
return return
} }
for _, r := range rule { for _, r := range rule {
err = m.DeletePeerRule(r) err = m.DeleteRule(r)
if err != nil { if err != nil {
t.Errorf("failed to delete rule: %v", err) t.Errorf("failed to delete rule: %v", err)
return return
@@ -143,7 +140,7 @@ func TestManagerDeleteRule(t *testing.T) {
} }
for _, r := range rule2 { for _, r := range rule2 {
err = m.DeletePeerRule(r) err = m.DeleteRule(r)
if err != nil { if err != nil {
t.Errorf("failed to delete rule: %v", err) t.Errorf("failed to delete rule: %v", err)
return return
@@ -187,10 +184,10 @@ func TestAddUDPPacketHook(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
manager, err := Create(&IFaceMock{ manager := &Manager{
SetFilterFunc: func(device.PacketFilter) error { return nil }, incomingRules: map[string]RuleSet{},
}) outgoingRules: map[string]RuleSet{},
require.NoError(t, err) }
manager.AddUDPPacketHook(tt.in, tt.ip, tt.dPort, tt.hook) manager.AddUDPPacketHook(tt.in, tt.ip, tt.dPort, tt.hook)
@@ -239,7 +236,7 @@ func TestAddUDPPacketHook(t *testing.T) {
func TestManagerReset(t *testing.T) { func TestManagerReset(t *testing.T) {
ifaceMock := &IFaceMock{ ifaceMock := &IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil }, SetFilterFunc: func(iface.PacketFilter) error { return nil },
} }
m, err := Create(ifaceMock) m, err := Create(ifaceMock)
@@ -255,13 +252,13 @@ func TestManagerReset(t *testing.T) {
action := fw.ActionDrop action := fw.ActionDrop
comment := "Test rule" comment := "Test rule"
_, err = m.AddPeerFiltering(ip, proto, nil, port, direction, action, "", comment) _, err = m.AddFiltering(ip, proto, nil, port, direction, action, "", comment)
if err != nil { if err != nil {
t.Errorf("failed to add filtering: %v", err) t.Errorf("failed to add filtering: %v", err)
return return
} }
err = m.Reset(nil) err = m.Reset()
if err != nil { if err != nil {
t.Errorf("failed to reset Manager: %v", err) t.Errorf("failed to reset Manager: %v", err)
return return
@@ -274,7 +271,7 @@ func TestManagerReset(t *testing.T) {
func TestNotMatchByIP(t *testing.T) { func TestNotMatchByIP(t *testing.T) {
ifaceMock := &IFaceMock{ ifaceMock := &IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil }, SetFilterFunc: func(iface.PacketFilter) error { return nil },
} }
m, err := Create(ifaceMock) m, err := Create(ifaceMock)
@@ -293,7 +290,7 @@ func TestNotMatchByIP(t *testing.T) {
action := fw.ActionAccept action := fw.ActionAccept
comment := "Test rule" comment := "Test rule"
_, err = m.AddPeerFiltering(ip, proto, nil, nil, direction, action, "", comment) _, err = m.AddFiltering(ip, proto, nil, nil, direction, action, "", comment)
if err != nil { if err != nil {
t.Errorf("failed to add filtering: %v", err) t.Errorf("failed to add filtering: %v", err)
return return
@@ -315,7 +312,7 @@ func TestNotMatchByIP(t *testing.T) {
t.Errorf("failed to set network layer for checksum: %v", err) t.Errorf("failed to set network layer for checksum: %v", err)
return return
} }
payload := gopacket.Payload("test") payload := gopacket.Payload([]byte("test"))
buf := gopacket.NewSerializeBuffer() buf := gopacket.NewSerializeBuffer()
opts := gopacket.SerializeOptions{ opts := gopacket.SerializeOptions{
@@ -327,12 +324,12 @@ func TestNotMatchByIP(t *testing.T) {
return return
} }
if m.dropFilter(buf.Bytes(), m.outgoingRules) { if m.dropFilter(buf.Bytes(), m.outgoingRules, false) {
t.Errorf("expected packet to be accepted") t.Errorf("expected packet to be accepted")
return return
} }
if err = m.Reset(nil); err != nil { if err = m.Reset(); err != nil {
t.Errorf("failed to reset Manager: %v", err) t.Errorf("failed to reset Manager: %v", err)
return return
} }
@@ -342,7 +339,7 @@ func TestNotMatchByIP(t *testing.T) {
func TestRemovePacketHook(t *testing.T) { func TestRemovePacketHook(t *testing.T) {
// creating mock iface // creating mock iface
iface := &IFaceMock{ iface := &IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil }, SetFilterFunc: func(iface.PacketFilter) error { return nil },
} }
// creating manager instance // creating manager instance
@@ -350,9 +347,6 @@ func TestRemovePacketHook(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("Failed to create Manager: %s", err) t.Fatalf("Failed to create Manager: %s", err)
} }
defer func() {
require.NoError(t, manager.Reset(nil))
}()
// Add a UDP packet hook // Add a UDP packet hook
hookFunc := func(data []byte) bool { return true } hookFunc := func(data []byte) bool { return true }
@@ -389,101 +383,19 @@ func TestRemovePacketHook(t *testing.T) {
} }
} }
func TestProcessOutgoingHooks(t *testing.T) {
manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
})
require.NoError(t, err)
manager.wgNetwork = &net.IPNet{
IP: net.ParseIP("100.10.0.0"),
Mask: net.CIDRMask(16, 32),
}
manager.udpTracker.Close()
manager.udpTracker = conntrack.NewUDPTracker(100*time.Millisecond, nil)
defer func() {
require.NoError(t, manager.Reset(nil))
}()
manager.decoders = sync.Pool{
New: func() any {
d := &decoder{
decoded: []gopacket.LayerType{},
}
d.parser = gopacket.NewDecodingLayerParser(
layers.LayerTypeIPv4,
&d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp,
)
d.parser.IgnoreUnsupported = true
return d
},
}
hookCalled := false
hookID := manager.AddUDPPacketHook(
false,
net.ParseIP("100.10.0.100"),
53,
func([]byte) bool {
hookCalled = true
return true
},
)
require.NotEmpty(t, hookID)
// Create test UDP packet
ipv4 := &layers.IPv4{
TTL: 64,
Version: 4,
SrcIP: net.ParseIP("100.10.0.1"),
DstIP: net.ParseIP("100.10.0.100"),
Protocol: layers.IPProtocolUDP,
}
udp := &layers.UDP{
SrcPort: 51334,
DstPort: 53,
}
err = udp.SetNetworkLayerForChecksum(ipv4)
require.NoError(t, err)
payload := gopacket.Payload("test")
buf := gopacket.NewSerializeBuffer()
opts := gopacket.SerializeOptions{
ComputeChecksums: true,
FixLengths: true,
}
err = gopacket.SerializeLayers(buf, opts, ipv4, udp, payload)
require.NoError(t, err)
// Test hook gets called
result := manager.processOutgoingHooks(buf.Bytes())
require.True(t, result)
require.True(t, hookCalled)
// Test non-UDP packet is ignored
ipv4.Protocol = layers.IPProtocolTCP
buf = gopacket.NewSerializeBuffer()
err = gopacket.SerializeLayers(buf, opts, ipv4)
require.NoError(t, err)
result = manager.processOutgoingHooks(buf.Bytes())
require.False(t, result)
}
func TestUSPFilterCreatePerformance(t *testing.T) { func TestUSPFilterCreatePerformance(t *testing.T) {
for _, testMax := range []int{10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 200, 300, 400, 500, 600, 700, 800, 900, 1000} { for _, testMax := range []int{10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 200, 300, 400, 500, 600, 700, 800, 900, 1000} {
t.Run(fmt.Sprintf("Testing %d rules", testMax), func(t *testing.T) { t.Run(fmt.Sprintf("Testing %d rules", testMax), func(t *testing.T) {
// just check on the local interface // just check on the local interface
ifaceMock := &IFaceMock{ ifaceMock := &IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil }, SetFilterFunc: func(iface.PacketFilter) error { return nil },
} }
manager, err := Create(ifaceMock) manager, err := Create(ifaceMock)
require.NoError(t, err) require.NoError(t, err)
time.Sleep(time.Second) time.Sleep(time.Second)
defer func() { defer func() {
if err := manager.Reset(nil); err != nil { if err := manager.Reset(); err != nil {
t.Errorf("clear the manager state: %v", err) t.Errorf("clear the manager state: %v", err)
} }
time.Sleep(time.Second) time.Sleep(time.Second)
@@ -494,9 +406,9 @@ func TestUSPFilterCreatePerformance(t *testing.T) {
for i := 0; i < testMax; i++ { for i := 0; i < testMax; i++ {
port := &fw.Port{Values: []int{1000 + i}} port := &fw.Port{Values: []int{1000 + i}}
if i%2 == 0 { if i%2 == 0 {
_, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept HTTP traffic") _, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept HTTP traffic")
} else { } else {
_, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.RuleDirectionIN, fw.ActionAccept, "", "accept HTTP traffic") _, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionIN, fw.ActionAccept, "", "accept HTTP traffic")
} }
require.NoError(t, err, "failed to add rule") require.NoError(t, err, "failed to add rule")
@@ -505,213 +417,3 @@ func TestUSPFilterCreatePerformance(t *testing.T) {
}) })
} }
} }
func TestStatefulFirewall_UDPTracking(t *testing.T) {
manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
})
require.NoError(t, err)
manager.wgNetwork = &net.IPNet{
IP: net.ParseIP("100.10.0.0"),
Mask: net.CIDRMask(16, 32),
}
manager.udpTracker.Close() // Close the existing tracker
manager.udpTracker = conntrack.NewUDPTracker(200*time.Millisecond, nil)
manager.decoders = sync.Pool{
New: func() any {
d := &decoder{
decoded: []gopacket.LayerType{},
}
d.parser = gopacket.NewDecodingLayerParser(
layers.LayerTypeIPv4,
&d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp,
)
d.parser.IgnoreUnsupported = true
return d
},
}
defer func() {
require.NoError(t, manager.Reset(nil))
}()
// Set up packet parameters
srcIP := net.ParseIP("100.10.0.1")
dstIP := net.ParseIP("100.10.0.100")
srcPort := uint16(51334)
dstPort := uint16(53)
// Create outbound packet
outboundIPv4 := &layers.IPv4{
TTL: 64,
Version: 4,
SrcIP: srcIP,
DstIP: dstIP,
Protocol: layers.IPProtocolUDP,
}
outboundUDP := &layers.UDP{
SrcPort: layers.UDPPort(srcPort),
DstPort: layers.UDPPort(dstPort),
}
err = outboundUDP.SetNetworkLayerForChecksum(outboundIPv4)
require.NoError(t, err)
outboundBuf := gopacket.NewSerializeBuffer()
opts := gopacket.SerializeOptions{
ComputeChecksums: true,
FixLengths: true,
}
err = gopacket.SerializeLayers(outboundBuf, opts,
outboundIPv4,
outboundUDP,
gopacket.Payload("test"),
)
require.NoError(t, err)
// Process outbound packet and verify connection tracking
drop := manager.DropOutgoing(outboundBuf.Bytes())
require.False(t, drop, "Initial outbound packet should not be dropped")
// Verify connection was tracked
conn, exists := manager.udpTracker.GetConnection(srcIP, srcPort, dstIP, dstPort)
require.True(t, exists, "Connection should be tracked after outbound packet")
require.True(t, conntrack.ValidateIPs(conntrack.MakeIPAddr(srcIP), conn.SourceIP), "Source IP should match")
require.True(t, conntrack.ValidateIPs(conntrack.MakeIPAddr(dstIP), conn.DestIP), "Destination IP should match")
require.Equal(t, srcPort, conn.SourcePort, "Source port should match")
require.Equal(t, dstPort, conn.DestPort, "Destination port should match")
// Create valid inbound response packet
inboundIPv4 := &layers.IPv4{
TTL: 64,
Version: 4,
SrcIP: dstIP, // Original destination is now source
DstIP: srcIP, // Original source is now destination
Protocol: layers.IPProtocolUDP,
}
inboundUDP := &layers.UDP{
SrcPort: layers.UDPPort(dstPort), // Original destination port is now source
DstPort: layers.UDPPort(srcPort), // Original source port is now destination
}
err = inboundUDP.SetNetworkLayerForChecksum(inboundIPv4)
require.NoError(t, err)
inboundBuf := gopacket.NewSerializeBuffer()
err = gopacket.SerializeLayers(inboundBuf, opts,
inboundIPv4,
inboundUDP,
gopacket.Payload("response"),
)
require.NoError(t, err)
// Test roundtrip response handling over time
checkPoints := []struct {
sleep time.Duration
shouldAllow bool
description string
}{
{
sleep: 0,
shouldAllow: true,
description: "Immediate response should be allowed",
},
{
sleep: 50 * time.Millisecond,
shouldAllow: true,
description: "Response within timeout should be allowed",
},
{
sleep: 100 * time.Millisecond,
shouldAllow: true,
description: "Response at half timeout should be allowed",
},
{
// tracker hasn't updated conn for 250ms -> greater than 200ms timeout
sleep: 250 * time.Millisecond,
shouldAllow: false,
description: "Response after timeout should be dropped",
},
}
for _, cp := range checkPoints {
time.Sleep(cp.sleep)
drop = manager.dropFilter(inboundBuf.Bytes(), manager.incomingRules)
require.Equal(t, cp.shouldAllow, !drop, cp.description)
// If the connection should still be valid, verify it exists
if cp.shouldAllow {
conn, exists := manager.udpTracker.GetConnection(srcIP, srcPort, dstIP, dstPort)
require.True(t, exists, "Connection should still exist during valid window")
require.True(t, time.Since(conn.GetLastSeen()) < manager.udpTracker.Timeout(),
"LastSeen should be updated for valid responses")
}
}
// Test invalid response packets (while connection is expired)
invalidCases := []struct {
name string
modifyFunc func(*layers.IPv4, *layers.UDP)
description string
}{
{
name: "wrong source IP",
modifyFunc: func(ip *layers.IPv4, udp *layers.UDP) {
ip.SrcIP = net.ParseIP("100.10.0.101")
},
description: "Response from wrong IP should be dropped",
},
{
name: "wrong destination IP",
modifyFunc: func(ip *layers.IPv4, udp *layers.UDP) {
ip.DstIP = net.ParseIP("100.10.0.2")
},
description: "Response to wrong IP should be dropped",
},
{
name: "wrong source port",
modifyFunc: func(ip *layers.IPv4, udp *layers.UDP) {
udp.SrcPort = 54
},
description: "Response from wrong port should be dropped",
},
{
name: "wrong destination port",
modifyFunc: func(ip *layers.IPv4, udp *layers.UDP) {
udp.DstPort = 51335
},
description: "Response to wrong port should be dropped",
},
}
// Create a new outbound connection for invalid tests
drop = manager.processOutgoingHooks(outboundBuf.Bytes())
require.False(t, drop, "Second outbound packet should not be dropped")
for _, tc := range invalidCases {
t.Run(tc.name, func(t *testing.T) {
testIPv4 := *inboundIPv4
testUDP := *inboundUDP
tc.modifyFunc(&testIPv4, &testUDP)
err = testUDP.SetNetworkLayerForChecksum(&testIPv4)
require.NoError(t, err)
testBuf := gopacket.NewSerializeBuffer()
err = gopacket.SerializeLayers(testBuf, opts,
&testIPv4,
&testUDP,
gopacket.Payload("response"),
)
require.NoError(t, err)
// Verify the invalid packet is dropped
drop = manager.dropFilter(testBuf.Bytes(), manager.incomingRules)
require.True(t, drop, tc.description)
})
}
}

View File

@@ -1,12 +0,0 @@
package bind
import (
wireguard "golang.zx2c4.com/wireguard/conn"
nbnet "github.com/netbirdio/netbird/util/net"
)
func init() {
// ControlFns is not thread safe and should only be modified during init.
*wireguard.ControlFns = append(*wireguard.ControlFns, nbnet.ControlProtectSocket)
}

View File

@@ -1,5 +0,0 @@
package bind
import wgConn "golang.zx2c4.com/wireguard/conn"
type Endpoint = wgConn.StdNetEndpoint

View File

@@ -1,303 +0,0 @@
package bind
import (
"fmt"
"net"
"net/netip"
"runtime"
"strings"
"sync"
"github.com/pion/stun/v2"
"github.com/pion/transport/v3"
log "github.com/sirupsen/logrus"
"golang.org/x/net/ipv4"
"golang.org/x/net/ipv6"
wgConn "golang.zx2c4.com/wireguard/conn"
)
type RecvMessage struct {
Endpoint *Endpoint
Buffer []byte
}
type receiverCreator struct {
iceBind *ICEBind
}
func (rc receiverCreator) CreateIPv4ReceiverFn(pc *ipv4.PacketConn, conn *net.UDPConn, rxOffload bool, msgPool *sync.Pool) wgConn.ReceiveFunc {
return rc.iceBind.createIPv4ReceiverFn(pc, conn, rxOffload, msgPool)
}
// ICEBind is a bind implementation with two main features:
// 1. filter out STUN messages and handle them
// 2. forward the received packets to the WireGuard interface from the relayed connection
//
// ICEBind.endpoints var is a map that stores the connection for each relayed peer. Fake address is just an IP address
// without port, in the format of 127.1.x.x where x.x is the last two octets of the peer address. We try to avoid to
// use the port because in the Send function the wgConn.Endpoint the port info is not exported.
type ICEBind struct {
*wgConn.StdNetBind
RecvChan chan RecvMessage
transportNet transport.Net
filterFn FilterFn
endpoints map[netip.Addr]net.Conn
endpointsMu sync.Mutex
// every time when Close() is called (i.e. BindUpdate()) we need to close exit from the receiveRelayed and create a
// new closed channel. With the closedChanMu we can safely close the channel and create a new one
closedChan chan struct{}
closedChanMu sync.RWMutex // protect the closeChan recreation from reading from it.
closed bool
muUDPMux sync.Mutex
udpMux *UniversalUDPMuxDefault
}
func NewICEBind(transportNet transport.Net, filterFn FilterFn) *ICEBind {
b, _ := wgConn.NewStdNetBind().(*wgConn.StdNetBind)
ib := &ICEBind{
StdNetBind: b,
RecvChan: make(chan RecvMessage, 1),
transportNet: transportNet,
filterFn: filterFn,
endpoints: make(map[netip.Addr]net.Conn),
closedChan: make(chan struct{}),
closed: true,
}
rc := receiverCreator{
ib,
}
ib.StdNetBind = wgConn.NewStdNetBindWithReceiverCreator(rc)
return ib
}
func (s *ICEBind) Open(uport uint16) ([]wgConn.ReceiveFunc, uint16, error) {
s.closed = false
s.closedChanMu.Lock()
s.closedChan = make(chan struct{})
s.closedChanMu.Unlock()
fns, port, err := s.StdNetBind.Open(uport)
if err != nil {
return nil, 0, err
}
fns = append(fns, s.receiveRelayed)
return fns, port, nil
}
func (s *ICEBind) Close() error {
if s.closed {
return nil
}
s.closed = true
close(s.closedChan)
return s.StdNetBind.Close()
}
// GetICEMux returns the ICE UDPMux that was created and used by ICEBind
func (s *ICEBind) GetICEMux() (*UniversalUDPMuxDefault, error) {
s.muUDPMux.Lock()
defer s.muUDPMux.Unlock()
if s.udpMux == nil {
return nil, fmt.Errorf("ICEBind has not been initialized yet")
}
return s.udpMux, nil
}
func (b *ICEBind) SetEndpoint(peerAddress *net.UDPAddr, conn net.Conn) (*net.UDPAddr, error) {
fakeUDPAddr, err := fakeAddress(peerAddress)
if err != nil {
return nil, err
}
// force IPv4
fakeAddr, ok := netip.AddrFromSlice(fakeUDPAddr.IP.To4())
if !ok {
return nil, fmt.Errorf("failed to convert IP to netip.Addr")
}
b.endpointsMu.Lock()
b.endpoints[fakeAddr] = conn
b.endpointsMu.Unlock()
return fakeUDPAddr, nil
}
func (b *ICEBind) RemoveEndpoint(fakeUDPAddr *net.UDPAddr) {
fakeAddr, ok := netip.AddrFromSlice(fakeUDPAddr.IP.To4())
if !ok {
log.Warnf("failed to convert IP to netip.Addr")
return
}
b.endpointsMu.Lock()
defer b.endpointsMu.Unlock()
delete(b.endpoints, fakeAddr)
}
func (b *ICEBind) Send(bufs [][]byte, ep wgConn.Endpoint) error {
b.endpointsMu.Lock()
conn, ok := b.endpoints[ep.DstIP()]
b.endpointsMu.Unlock()
if !ok {
return b.StdNetBind.Send(bufs, ep)
}
for _, buf := range bufs {
if _, err := conn.Write(buf); err != nil {
return err
}
}
return nil
}
func (s *ICEBind) createIPv4ReceiverFn(pc *ipv4.PacketConn, conn *net.UDPConn, rxOffload bool, msgsPool *sync.Pool) wgConn.ReceiveFunc {
s.muUDPMux.Lock()
defer s.muUDPMux.Unlock()
s.udpMux = NewUniversalUDPMuxDefault(
UniversalUDPMuxParams{
UDPConn: conn,
Net: s.transportNet,
FilterFn: s.filterFn,
},
)
return func(bufs [][]byte, sizes []int, eps []wgConn.Endpoint) (n int, err error) {
msgs := getMessages(msgsPool)
for i := range bufs {
(*msgs)[i].Buffers[0] = bufs[i]
(*msgs)[i].OOB = (*msgs)[i].OOB[:cap((*msgs)[i].OOB)]
}
defer putMessages(msgs, msgsPool)
var numMsgs int
if runtime.GOOS == "linux" || runtime.GOOS == "android" {
if rxOffload {
readAt := len(*msgs) - (wgConn.IdealBatchSize / wgConn.UdpSegmentMaxDatagrams)
//nolint
numMsgs, err = pc.ReadBatch((*msgs)[readAt:], 0)
if err != nil {
return 0, err
}
numMsgs, err = wgConn.SplitCoalescedMessages(*msgs, readAt, wgConn.GetGSOSize)
if err != nil {
return 0, err
}
} else {
numMsgs, err = pc.ReadBatch(*msgs, 0)
if err != nil {
return 0, err
}
}
} else {
msg := &(*msgs)[0]
msg.N, msg.NN, _, msg.Addr, err = conn.ReadMsgUDP(msg.Buffers[0], msg.OOB)
if err != nil {
return 0, err
}
numMsgs = 1
}
for i := 0; i < numMsgs; i++ {
msg := &(*msgs)[i]
// todo: handle err
ok, _ := s.filterOutStunMessages(msg.Buffers, msg.N, msg.Addr)
if ok {
continue
}
sizes[i] = msg.N
if sizes[i] == 0 {
continue
}
addrPort := msg.Addr.(*net.UDPAddr).AddrPort()
ep := &wgConn.StdNetEndpoint{AddrPort: addrPort} // TODO: remove allocation
wgConn.GetSrcFromControl(msg.OOB[:msg.NN], ep)
eps[i] = ep
}
return numMsgs, nil
}
}
func (s *ICEBind) filterOutStunMessages(buffers [][]byte, n int, addr net.Addr) (bool, error) {
for i := range buffers {
if !stun.IsMessage(buffers[i]) {
continue
}
msg, err := s.parseSTUNMessage(buffers[i][:n])
if err != nil {
buffers[i] = []byte{}
return true, err
}
muxErr := s.udpMux.HandleSTUNMessage(msg, addr)
if muxErr != nil {
log.Warnf("failed to handle STUN packet")
}
buffers[i] = []byte{}
return true, nil
}
return false, nil
}
func (s *ICEBind) parseSTUNMessage(raw []byte) (*stun.Message, error) {
msg := &stun.Message{
Raw: raw,
}
if err := msg.Decode(); err != nil {
return nil, err
}
return msg, nil
}
// receiveRelayed is a receive function that is used to receive packets from the relayed connection and forward to the
// WireGuard. Critical part is do not block if the Closed() has been called.
func (c *ICEBind) receiveRelayed(buffs [][]byte, sizes []int, eps []wgConn.Endpoint) (int, error) {
c.closedChanMu.RLock()
defer c.closedChanMu.RUnlock()
select {
case <-c.closedChan:
return 0, net.ErrClosed
case msg, ok := <-c.RecvChan:
if !ok {
return 0, net.ErrClosed
}
copy(buffs[0], msg.Buffer)
sizes[0] = len(msg.Buffer)
eps[0] = wgConn.Endpoint(msg.Endpoint)
return 1, nil
}
}
// fakeAddress returns a fake address that is used to as an identifier for the peer.
// The fake address is in the format of 127.1.x.x where x.x is the last two octets of the peer address.
func fakeAddress(peerAddress *net.UDPAddr) (*net.UDPAddr, error) {
octets := strings.Split(peerAddress.IP.String(), ".")
if len(octets) != 4 {
return nil, fmt.Errorf("invalid IP format")
}
newAddr := &net.UDPAddr{
IP: net.ParseIP(fmt.Sprintf("127.1.%s.%s", octets[2], octets[3])),
Port: peerAddress.Port,
}
return newAddr, nil
}
func getMessages(msgsPool *sync.Pool) *[]ipv6.Message {
return msgsPool.Get().(*[]ipv6.Message)
}
func putMessages(msgs *[]ipv6.Message, msgsPool *sync.Pool) {
for i := range *msgs {
(*msgs)[i].OOB = (*msgs)[i].OOB[:0]
(*msgs)[i] = ipv6.Message{Buffers: (*msgs)[i].Buffers, OOB: (*msgs)[i].OOB}
}
msgsPool.Put(msgs)
}

View File

@@ -1,5 +0,0 @@
package configurer
import "errors"
var ErrPeerNotFound = errors.New("peer not found")

View File

@@ -1,9 +0,0 @@
package configurer
import "time"
type WGStats struct {
LastHandshake time.Time
TxBytes int64
RxBytes int64
}

View File

@@ -1,18 +0,0 @@
//go:build !android
package iface
import (
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/device"
)
type WGTunDevice interface {
Create() (device.WGConfigurer, error)
Up() (*bind.UniversalUDPMuxDefault, error)
UpdateAddr(address WGAddress) error
WgAddress() WGAddress
DeviceName() string
Close() error
FilteredDevice() *device.FilteredDevice
}

View File

@@ -1,119 +0,0 @@
//go:build !android
// +build !android
package device
import (
"fmt"
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/device"
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/netstack"
)
type TunNetstackDevice struct {
name string
address WGAddress
port int
key string
mtu int
listenAddress string
iceBind *bind.ICEBind
device *device.Device
filteredDevice *FilteredDevice
nsTun *netstack.NetStackTun
udpMux *bind.UniversalUDPMuxDefault
configurer WGConfigurer
}
func NewNetstackDevice(name string, address WGAddress, wgPort int, key string, mtu int, iceBind *bind.ICEBind, listenAddress string) *TunNetstackDevice {
return &TunNetstackDevice{
name: name,
address: address,
port: wgPort,
key: key,
mtu: mtu,
listenAddress: listenAddress,
iceBind: iceBind,
}
}
func (t *TunNetstackDevice) Create() (WGConfigurer, error) {
log.Info("create netstack tun interface")
t.nsTun = netstack.NewNetStackTun(t.listenAddress, t.address.IP.String(), t.mtu)
tunIface, err := t.nsTun.Create()
if err != nil {
return nil, fmt.Errorf("error creating tun device: %s", err)
}
t.filteredDevice = newDeviceFilter(tunIface)
t.device = device.NewDevice(
t.filteredDevice,
t.iceBind,
device.NewLogger(wgLogLevel(), "[netbird] "),
)
t.configurer = configurer.NewUSPConfigurer(t.device, t.name)
err = t.configurer.ConfigureInterface(t.key, t.port)
if err != nil {
_ = tunIface.Close()
return nil, fmt.Errorf("error configuring interface: %s", err)
}
log.Debugf("device has been created: %s", t.name)
return t.configurer, nil
}
func (t *TunNetstackDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
if t.device == nil {
return nil, fmt.Errorf("device is not ready yet")
}
err := t.device.Up()
if err != nil {
return nil, err
}
udpMux, err := t.iceBind.GetICEMux()
if err != nil {
return nil, err
}
t.udpMux = udpMux
log.Debugf("netstack device is ready to use")
return udpMux, nil
}
func (t *TunNetstackDevice) UpdateAddr(WGAddress) error {
return nil
}
func (t *TunNetstackDevice) Close() error {
if t.configurer != nil {
t.configurer.Close()
}
if t.device != nil {
t.device.Close()
}
if t.udpMux != nil {
return t.udpMux.Close()
}
return nil
}
func (t *TunNetstackDevice) WgAddress() WGAddress {
return t.address
}
func (t *TunNetstackDevice) DeviceName() string {
return t.name
}
func (t *TunNetstackDevice) FilteredDevice() *FilteredDevice {
return t.filteredDevice
}

View File

@@ -1,20 +0,0 @@
package device
import (
"net"
"time"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"github.com/netbirdio/netbird/client/iface/configurer"
)
type WGConfigurer interface {
ConfigureInterface(privateKey string, port int) error
UpdatePeer(peerKey string, allowedIps string, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error
RemovePeer(peerKey string) error
AddAllowedIP(peerKey string, allowedIP string) error
RemoveAllowedIP(peerKey string, allowedIP string) error
Close()
GetStats(peerKey string) (configurer.WGStats, error)
}

View File

@@ -1,15 +0,0 @@
package device
import (
"os"
"golang.zx2c4.com/wireguard/device"
)
func wgLogLevel() int {
if os.Getenv("NB_WG_DEBUG") == "true" {
return device.LogLevelVerbose
} else {
return device.LogLevelSilent
}
}

View File

@@ -1,4 +0,0 @@
package device
// CustomWindowsGUIDString is a custom GUID string for the interface
var CustomWindowsGUIDString string

View File

@@ -1,16 +0,0 @@
package iface
import (
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/device"
)
type WGTunDevice interface {
Create(routes []string, dns string, searchDomains []string) (device.WGConfigurer, error)
Up() (*bind.UniversalUDPMuxDefault, error)
UpdateAddr(address WGAddress) error
WgAddress() WGAddress
DeviceName() string
Close() error
FilteredDevice() *device.FilteredDevice
}

View File

@@ -1,24 +0,0 @@
package iface
import (
"fmt"
)
// CreateOnAndroid creates a new Wireguard interface, sets a given IP and brings it up.
// Will reuse an existing one.
func (w *WGIface) CreateOnAndroid(routes []string, dns string, searchDomains []string) error {
w.mu.Lock()
defer w.mu.Unlock()
cfgr, err := w.tun.Create(routes, dns, searchDomains)
if err != nil {
return err
}
w.configurer = cfgr
return nil
}
// Create this function make sense on mobile only
func (w *WGIface) Create() error {
return fmt.Errorf("this function has not implemented on this platform")
}

View File

@@ -1,41 +0,0 @@
//go:build !ios
package iface
import (
"fmt"
"time"
"github.com/cenkalti/backoff/v4"
)
// Create creates a new Wireguard interface, sets a given IP and brings it up.
// Will reuse an existing one.
// this function is different on Android
func (w *WGIface) Create() error {
w.mu.Lock()
defer w.mu.Unlock()
backOff := &backoff.ExponentialBackOff{
InitialInterval: 20 * time.Millisecond,
MaxElapsedTime: 500 * time.Millisecond,
Stop: backoff.Stop,
Clock: backoff.SystemClock,
}
operation := func() error {
cfgr, err := w.tun.Create()
if err != nil {
return err
}
w.configurer = cfgr
return nil
}
return backoff.Retry(operation, backOff)
}
// CreateOnAndroid this function make sense on mobile only
func (w *WGIface) CreateOnAndroid([]string, string, []string) error {
return fmt.Errorf("this function has not implemented on this platform")
}

View File

@@ -1,17 +0,0 @@
//go:build darwin || dragonfly || freebsd || netbsd || openbsd
package iface
import (
"fmt"
"os/exec"
)
func (w *WGIface) Destroy() error {
out, err := exec.Command("ifconfig", w.Name(), "destroy").CombinedOutput()
if err != nil {
return fmt.Errorf("failed to remove interface %s: %w - %s", w.Name(), err, out)
}
return nil
}

View File

@@ -1,22 +0,0 @@
//go:build linux && !android
package iface
import (
"fmt"
"github.com/vishvananda/netlink"
)
func (w *WGIface) Destroy() error {
link, err := netlink.LinkByName(w.Name())
if err != nil {
return fmt.Errorf("failed to get link by name %s: %w", w.Name(), err)
}
if err := netlink.LinkDel(link); err != nil {
return fmt.Errorf("failed to delete link %s: %w", w.Name(), err)
}
return nil
}

View File

@@ -1,9 +0,0 @@
//go:build android || (ios && !darwin)
package iface
import "errors"
func (w *WGIface) Destroy() error {
return errors.New("not supported on mobile")
}

View File

@@ -1,32 +0,0 @@
//go:build windows
package iface
import (
"fmt"
"os/exec"
log "github.com/sirupsen/logrus"
)
func (w *WGIface) Destroy() error {
netshCmd := GetSystem32Command("netsh")
out, err := exec.Command(netshCmd, "interface", "set", "interface", w.Name(), "admin=disable").CombinedOutput()
if err != nil {
return fmt.Errorf("failed to remove interface %s: %w - %s", w.Name(), err, out)
}
return nil
}
// GetSystem32Command checks if a command can be found in the system path and returns it. In case it can't find it
// in the path it will return the full path of a command assuming C:\windows\system32 as the base path.
func GetSystem32Command(command string) string {
_, err := exec.LookPath(command)
if err == nil {
return command
}
log.Tracef("Command %s not found in PATH, using C:\\windows\\system32\\%s.exe path", command, command)
return "C:\\windows\\system32\\" + command + ".exe"
}

View File

@@ -1,10 +0,0 @@
package iface
import (
"github.com/netbirdio/netbird/client/iface/device"
)
// GetInterfaceGUIDString returns an interface GUID. This is useful on Windows only
func (w *WGIface) GetInterfaceGUIDString() (string, error) {
return w.tun.(*device.TunDevice).GetInterfaceGUIDString()
}

View File

@@ -1,112 +0,0 @@
package iface
import (
"net"
"time"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/wgproxy"
)
type MockWGIface struct {
CreateFunc func() error
CreateOnAndroidFunc func(routeRange []string, ip string, domains []string) error
IsUserspaceBindFunc func() bool
NameFunc func() string
AddressFunc func() device.WGAddress
ToInterfaceFunc func() *net.Interface
UpFunc func() (*bind.UniversalUDPMuxDefault, error)
UpdateAddrFunc func(newAddr string) error
UpdatePeerFunc func(peerKey string, allowedIps string, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error
RemovePeerFunc func(peerKey string) error
AddAllowedIPFunc func(peerKey string, allowedIP string) error
RemoveAllowedIPFunc func(peerKey string, allowedIP string) error
CloseFunc func() error
SetFilterFunc func(filter device.PacketFilter) error
GetFilterFunc func() device.PacketFilter
GetDeviceFunc func() *device.FilteredDevice
GetStatsFunc func(peerKey string) (configurer.WGStats, error)
GetInterfaceGUIDStringFunc func() (string, error)
GetProxyFunc func() wgproxy.Proxy
}
func (m *MockWGIface) GetInterfaceGUIDString() (string, error) {
return m.GetInterfaceGUIDStringFunc()
}
func (m *MockWGIface) Create() error {
return m.CreateFunc()
}
func (m *MockWGIface) CreateOnAndroid(routeRange []string, ip string, domains []string) error {
return m.CreateOnAndroidFunc(routeRange, ip, domains)
}
func (m *MockWGIface) IsUserspaceBind() bool {
return m.IsUserspaceBindFunc()
}
func (m *MockWGIface) Name() string {
return m.NameFunc()
}
func (m *MockWGIface) Address() device.WGAddress {
return m.AddressFunc()
}
func (m *MockWGIface) ToInterface() *net.Interface {
return m.ToInterfaceFunc()
}
func (m *MockWGIface) Up() (*bind.UniversalUDPMuxDefault, error) {
return m.UpFunc()
}
func (m *MockWGIface) UpdateAddr(newAddr string) error {
return m.UpdateAddrFunc(newAddr)
}
func (m *MockWGIface) UpdatePeer(peerKey string, allowedIps string, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error {
return m.UpdatePeerFunc(peerKey, allowedIps, keepAlive, endpoint, preSharedKey)
}
func (m *MockWGIface) RemovePeer(peerKey string) error {
return m.RemovePeerFunc(peerKey)
}
func (m *MockWGIface) AddAllowedIP(peerKey string, allowedIP string) error {
return m.AddAllowedIPFunc(peerKey, allowedIP)
}
func (m *MockWGIface) RemoveAllowedIP(peerKey string, allowedIP string) error {
return m.RemoveAllowedIPFunc(peerKey, allowedIP)
}
func (m *MockWGIface) Close() error {
return m.CloseFunc()
}
func (m *MockWGIface) SetFilter(filter device.PacketFilter) error {
return m.SetFilterFunc(filter)
}
func (m *MockWGIface) GetFilter() device.PacketFilter {
return m.GetFilterFunc()
}
func (m *MockWGIface) GetDevice() *device.FilteredDevice {
return m.GetDeviceFunc()
}
func (m *MockWGIface) GetStats(peerKey string) (configurer.WGStats, error) {
return m.GetStatsFunc(peerKey)
}
func (m *MockWGIface) GetProxy() wgproxy.Proxy {
//TODO implement me
panic("implement me")
}

View File

@@ -1,24 +0,0 @@
package iface
import (
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/wgproxy"
)
// NewWGIFace Creates a new WireGuard interface instance
func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
wgAddress, err := device.ParseWGAddress(opts.Address)
if err != nil {
return nil, err
}
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn)
wgIFace := &WGIface{
userspaceBind: true,
tun: device.NewTunDevice(wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, opts.MobileArgs.TunAdapter),
wgProxyFactory: wgproxy.NewUSPFactory(iceBind),
}
return wgIFace, nil
}

View File

@@ -1,34 +0,0 @@
//go:build !ios
package iface
import (
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/netstack"
"github.com/netbirdio/netbird/client/iface/wgproxy"
)
// NewWGIFace Creates a new WireGuard interface instance
func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
wgAddress, err := device.ParseWGAddress(opts.Address)
if err != nil {
return nil, err
}
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn)
var tun WGTunDevice
if netstack.IsEnabled() {
tun = device.NewNetstackDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, netstack.ListenAddr())
} else {
tun = device.NewTunDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind)
}
wgIFace := &WGIface{
userspaceBind: true,
tun: tun,
wgProxyFactory: wgproxy.NewUSPFactory(iceBind),
}
return wgIFace, nil
}

View File

@@ -1,26 +0,0 @@
//go:build ios
package iface
import (
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/wgproxy"
)
// NewWGIFace Creates a new WireGuard interface instance
func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
wgAddress, err := device.ParseWGAddress(opts.Address)
if err != nil {
return nil, err
}
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn)
wgIFace := &WGIface{
tun: device.NewTunDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, iceBind, opts.MobileArgs.TunFd),
userspaceBind: true,
wgProxyFactory: wgproxy.NewUSPFactory(iceBind),
}
return wgIFace, nil
}

Some files were not shown because too many files have changed in this diff Show More