Compare commits

..

55 Commits

Author SHA1 Message Date
pascal-fischer
0371f529ca Add sonar badge (#2381) 2024-08-02 18:48:12 +02:00
pascal-fischer
501fd93e47 Fix DNS resolution for routes on iOS (#2378) 2024-08-02 18:43:00 +02:00
Misha Bragin
727a4f0753 Remove Codacy badge as it is broken (#2379) 2024-08-02 18:20:13 +02:00
Maycon Santos
e6f7222034 Fix Windows file version (#2380)
Systems that validates the binary version didn't like the build number as we set

This fixes the versioning and will use a static build number
2024-08-02 18:07:57 +02:00
Maycon Santos
bfc33a3f6f Move Bundle to before netbird down (#2377)
This allows to get interface and route information added by the agent
2024-08-02 14:54:37 +02:00
Viktor Liu
5ad4ae769a Extend client debug bundle (#2341)
Adds readme (with --anonymize)
Fixes archive file timestamps
Adds routes info
Adds interfaces
Adds client config
2024-08-02 11:47:12 +02:00
David Fry
f84b606506 add extra auth audience (#2350) 2024-08-01 18:52:50 +02:00
keacwu
216d9f2ee8 Adding geolocation download log message. (#2085)
* Adding geolocation download prompt message.

* import log file and remove unnecessary else

---------

Co-authored-by: Maycon Santos <mlsmaycon@gmail.com>
2024-08-01 18:52:38 +02:00
Viktor Liu
57624203c9 Allow route updates even if some domains failed resolution (#2368) 2024-08-01 18:38:19 +02:00
Jakub Kołodziejczak
24e031ab74 Fix syslog output containing duplicated timestamps (#2292)
```console
journalctl
```
```diff
- Jul 19 14:41:01 rpi /usr/bin/netbird[614]: 2024-07-19T14:41:01+02:00 ERRO %!s(<nil>): error while handling message of Peer [key: REDACTED] error: [wrongly addressed message REDACTED]
- Jul 19 21:53:03 rpi /usr/bin/netbird[614]: 2024-07-19T21:53:03+02:00 WARN %!s(<nil>): disconnected from the Signal service but will retry silently. Reason: rpc error: code = Internal desc = server closed the stream without sending trailers
- Jul 19 21:53:04 rpi /usr/bin/netbird[614]: 2024-07-19T21:53:04+02:00 INFO %!s(<nil>): connected to the Signal Service stream
- Jul 19 22:24:10 rpi /usr/bin/netbird[614]: 2024-07-19T22:24:10+02:00 WARN [error: read udp 192.168.1.11:48398->9.9.9.9:53: i/o timeout, upstream: 9.9.9.9:53] %!s(<nil>): got an error while connecting to upstream
+ Jul 19 14:41:01 rpi /usr/bin/netbird[614]: error while handling message of Peer [key: REDACTED] error: [wrongly addressed message REDACTED]
+ Jul 19 21:53:03 rpi /usr/bin/netbird[614]: disconnected from the Signal service but will retry silently. Reason: rpc error: code = Internal desc = server closed the stream without sending trailers
+ Jul 19 21:53:04 rpi /usr/bin/netbird[614]: connected to the Signal Service stream
+ Jul 19 22:24:10 rpi /usr/bin/netbird[614]: [error: read udp 192.168.1.11:48398->9.9.9.9:53: i/o timeout, upstream: 9.9.9.9:53] got an error while connecting to upstream
```

please notice that although log level is no longer present in the syslog
message it is still respected by syslog logger, so the log levels are
not lost:
```console
journalctl -p 3
```
```diff
- Jul 19 14:41:01 rpi /usr/bin/netbird[614]: 2024-07-19T14:41:01+02:00 ERRO %!s(<nil>): error while handling message of Peer [key: REDACTED] error: [wrongly addressed message REDACTED]
+ Jul 19 14:41:01 rpi /usr/bin/netbird[614]: error while handling message of Peer [key: REDACTED] error: [wrongly addressed message REDACTED]
```
2024-08-01 18:22:02 +02:00
dependabot[bot]
df8b8db068 Bump github.com/docker/docker (#2356)
Bumps [github.com/docker/docker](https://github.com/docker/docker) from 26.1.3+incompatible to 26.1.4+incompatible.
- [Release notes](https://github.com/docker/docker/releases)
- [Commits](https://github.com/docker/docker/compare/v26.1.3...v26.1.4)

---
updated-dependencies:
- dependency-name: github.com/docker/docker
  dependency-type: indirect
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2024-08-01 17:20:15 +02:00
Zoltan Papp
3506ac4234 When creating new setup key, "revoked" field doesn't do anything (#2357)
Remove unused field from API
2024-08-01 17:13:58 +02:00
Zoltan Papp
0c8f8a62c7 Handling invalid UTF-8 character in sys info (#2360)
In some operation systems, the sys info contains invalid characters.
In this patch try to keep the original fallback logic but filter out the cases when the character is invalid.
2024-08-01 16:46:55 +02:00
Maycon Santos
cbf9f2058e Use accountID retrieved from the sync call to acquire read lock sooner (#2369)
Use accountID retrieved from the sync call to acquire read lock sooner and avoiding extra DB calls.
- Use the account ID across sync calls
- Moved account read lock
- Renamed CancelPeerRoutines to OnPeerDisconnected
- Added race tests
2024-08-01 16:21:43 +02:00
Evgenii
02f3105e48 Freebsd test all root component (#2361)
* chore(tests): add all root component into FreeBSD check

* change timeout for each component

* add client tests execution measure

* revert -p1 for client tests and explain why

* measure duration of all test run
2024-08-01 11:56:18 +02:00
Maycon Santos
5ee9c77e90 Move write peer lock (#2364)
Moved the write peer lock to avoid latency caused by disk access

Updated the method CancelPeerRoutines to use the peer public key
2024-07-31 21:51:45 +02:00
Bethuel Mmbaga
c832cef44c Update SaveUsers and SaveGroups to SaveAccount (#2362)
Changed SaveUsers and SaveGroups method calls to SaveAccount for consistency in data persistence operations.
2024-07-31 19:48:12 +03:00
Maycon Santos
165988429c Add write lock for peer when saving its connection status (#2359) 2024-07-31 14:53:32 +02:00
Evgenii
9d2047a08a Fix freebsd tests (#2346) 2024-07-31 09:58:04 +02:00
Maycon Santos
da39c8bbca Refactor login with store.SavePeer (#2334)
This pull request refactors the login functionality by integrating store.SavePeer. The changes aim to improve the handling of peer login processes, particularly focusing on synchronization and error handling.

Changes:
- Refactored login logic to use store.SavePeer.
- Added checks for login without lock for login necessary checks from the client and utilized write lock for full login flow.
- Updated error handling with status.NewPeerLoginExpiredError().
- Moved geoIP check logic to a more appropriate place.
- Removed redundant calls and improved documentation.
- Moved the code to smaller methods to improve readability.
2024-07-29 13:30:27 +02:00
Bethuel Mmbaga
7321046cd6 Remove redundant check for empty JWT groups (#2323)
* Remove redundant check for empty group names in SetJWTGroups

* add test
2024-07-26 16:33:54 +02:00
Maycon Santos
ea3205643a Save daemon address on service install (#2328) 2024-07-26 16:33:20 +02:00
Zoltan Papp
1a15b0f900 Fix race issue in set listener (#2332) 2024-07-26 16:27:51 +02:00
Maycon Santos
1f48fdf6ca Add SavePeer method to prevent a possible account inconsistency (#2296)
SyncPeer was storing the account with a simple read lock

This change introduces the SavePeer method to the store to be used in these cases
2024-07-26 07:49:05 +02:00
Maycon Santos
45fd1e9c21 add save peer status test for connected peers (#2321) 2024-07-25 16:22:04 +02:00
Zoltan Papp
63aeeb834d Fix error handling (#2316) 2024-07-24 13:27:01 +02:00
Maycon Santos
268e801ec5 Ignore network monitor checks for software interfaces (#2302)
ignore checks for Teredo and ISATAP interfaces
2024-07-22 19:44:15 +02:00
Maycon Santos
788f130941 Retry management connection only on context canceled (#2301) 2024-07-22 15:49:25 +02:00
Maycon Santos
926e11b086 Remove default allow for UDP on unmatched packet (#2300)
This fixes an issue where UDP rules were ineffective for userspace clients (Windows/macOS)
2024-07-22 15:35:17 +02:00
Carlos Hernandez
0a8c78deb1 Minor fix local dns search domain (#2287) 2024-07-19 16:44:12 +02:00
Maycon Santos
c815ad86fd Fix macOS DNS unclean shutdown restore call on startup (#2286)
previously, we called the restore method from the startup when there was an unclean shutdown. But it never had the state keys to clean since they are stored in memory

this change addresses the issue by falling back to default values when restoring the host's DNS
2024-07-18 18:06:09 +02:00
Carlos Hernandez
ef1a39cb01 Refactor macOS system DNS configuration (#2284)
On macOS use the recommended settings for providing split DNS. As per
the docs an empty string will force the configuration to be the default.
In order to to support split DNS an additional service config is added
for the local server and search domain settings.

see: https://developer.apple.com/documentation/devicemanagement/vpn/dns
2024-07-18 16:39:41 +02:00
Maycon Santos
c900fa81bb Remove copy functions from signal (#2285)
remove migration function for wiretrustee directories to netbird
2024-07-18 12:15:14 +02:00
Maycon Santos
9a6de52dd0 Check if route interface is a Microsoft ISATAP device (#2282)
check if the nexthop interfaces are Microsoft ISATAP devices and ignore their suffixes when comparing them
2024-07-17 23:49:09 +02:00
Maycon Santos
19147f518e Add faster availability DNS probe and update test domain to .com (#2280)
* Add faster availability DNS probe and update test domain to .com

- Count success queries and compare it before doing after network map probes.

- Reduce the first dns probe to 500ms

- Updated test domain with com instead of . due to Palo alto DNS proxy server issues

* use fqdn

* Update client/internal/dns/upstream.go

Co-authored-by: Viktor Liu <17948409+lixmal@users.noreply.github.com>

---------

Co-authored-by: Viktor Liu <17948409+lixmal@users.noreply.github.com>
2024-07-17 23:48:37 +02:00
Viktor Liu
e78ec2e985 Don't add exclusion routes for IPs that are part of connected networks (#2258)
This prevents arp/ndp issues on macOS leading to unreachability of that IP.
2024-07-17 19:50:06 +02:00
pascal-fischer
95d725f2c1 Wait on daemon down (#2279) 2024-07-17 16:26:06 +02:00
benniekiss
4fad0e521f Support custom SSL certificates for the signal service (#2257) 2024-07-16 20:44:21 +02:00
ctrl-zzz
a711e116a3 fix: save peer status correctly in sqlstore (#2262)
* fix: save peer status correctly in sqlstore

https://github.com/netbirdio/netbird/issues/2110#issuecomment-2162768273

* feat: update test function

* refactor: simplify status update
2024-07-16 18:38:12 +03:00
Maycon Santos
668d229b67 Fix metric label typo (#2278) 2024-07-16 16:55:57 +02:00
Maycon Santos
7c595e8493 Add get_registration_delay_milliseconds metric (#2275) 2024-07-16 15:36:51 +02:00
Jakub Kołodziejczak
f9c59a7131 Refactor log util (#2276) 2024-07-16 11:50:35 +02:00
Jakub Kołodziejczak
1d6f5482dd feat(client): send logs to syslog (#2259) 2024-07-16 10:19:58 +02:00
Carlos Hernandez
12ff93ba72 Ignore no unique route updates (#2266) 2024-07-16 10:19:01 +02:00
Maycon Santos
88d1c5a0fd fix forwarded metrics (#2273) 2024-07-16 10:14:30 +02:00
Bethuel Mmbaga
1537b0f5e7 Add batch save/update for groups and users (#2245)
* Add functionality to update multiple users

* Remove SaveUsers from DefaultAccountManager

* Add SaveGroups method to AccountManager interface

* Refactoring

* Add SaveUsers and SaveGroups methods to store interface

* Refactor method SaveAccount to SaveUsers and SaveGroups

The method SaveAccount in user.go and group.go files was split into two separate methods. Now, user-specific data is handled by SaveUsers and group-specific data is handled by SaveGroups method. This provides a cleaner and more efficient way to save user and group data.

* Add account ID to user and group in SqlStore

* Refactor SaveUsers and SaveGroups in store

* Remove unnecessary ID assignment in SaveUsers and SaveGroups
2024-07-15 17:04:06 +03:00
Maycon Santos
2577100096 Limit GUI process execution to one per UID (#2267)
replaces PID with checking process name and path and UID checks
2024-07-15 14:53:52 +02:00
Zoltan Papp
bc09348f5a Add logging option for wg device (#2271) 2024-07-15 14:45:18 +02:00
Edouard Vanbelle
d5ba2ef6ec fix 2260: fallback serial to Board (#2263) 2024-07-15 14:43:50 +02:00
pascal-fischer
47752e1573 Support DNS routes on iOS (#2254) 2024-07-15 10:40:57 +02:00
Maycon Santos
58fbc1249c Fix parameter limit issue for Postgres store (#2261)
Added CreateBatchSize for both SQL stores and updated tests to test large accounts with Postgres, too. Increased the account peer size to 6K.
2024-07-12 09:28:53 +02:00
dependabot[bot]
1cc341a268 Bump google.golang.org/grpc from 1.64.0 to 1.64.1 (#2248)
Bumps [google.golang.org/grpc](https://github.com/grpc/grpc-go) from 1.64.0 to 1.64.1.
- [Release notes](https://github.com/grpc/grpc-go/releases)
- [Commits](https://github.com/grpc/grpc-go/compare/v1.64.0...v1.64.1)

---
updated-dependencies:
- dependency-name: google.golang.org/grpc
  dependency-type: direct:production
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2024-07-12 08:52:27 +02:00
Viktor Liu
89df6e7242 Get client ui locale on windows natively (#2251) 2024-07-12 08:25:33 +02:00
Maycon Santos
f74646a3ac Add release version to windows binaries and update sign pipeline version (#2256) 2024-07-11 19:06:55 +02:00
pascal-fischer
e8c2fafccd Avoid empty domain overwrite (#2252) 2024-07-10 14:08:35 +02:00
106 changed files with 2975 additions and 1147 deletions

8
.editorconfig Normal file
View File

@@ -0,0 +1,8 @@
root = true
[*]
end_of_line = lf
insert_final_newline = true
[*.go]
indent_style = tab

View File

@@ -13,7 +13,7 @@ concurrency:
jobs:
test:
runs-on: ubuntu-latest
runs-on: ubuntu-22.04
steps:
- uses: actions/checkout@v4
- name: Test in FreeBSD
@@ -21,19 +21,26 @@ jobs:
uses: vmactions/freebsd-vm@v1
with:
usesh: true
copyback: false
release: "14.1"
prepare: |
pkg install -y curl
pkg install -y git
pkg install -y go
# -x - to print all executed commands
# -e - to faile on first error
run: |
set -x
curl -o go.tar.gz https://go.dev/dl/go1.21.11.freebsd-amd64.tar.gz -L
tar zxf go.tar.gz
mv go /usr/local/go
ln -s /usr/local/go/bin/go /usr/local/bin/go
go mod tidy
go test -timeout 5m -p 1 ./iface/...
go test -timeout 5m -p 1 ./client/...
cd client
go build .
cd ..
set -e -x
time go build -o netbird client/main.go
# check all component except management, since we do not support management server on freebsd
time go test -timeout 1m -failfast ./base62/...
# NOTE: without -p1 `client/internal/dns` will fail becasue of `listen udp4 :33100: bind: address already in use`
time go test -timeout 8m -failfast -p 1 ./client/...
time go test -timeout 1m -failfast ./dns/...
time go test -timeout 1m -failfast ./encryption/...
time go test -timeout 1m -failfast ./formatter/...
time go test -timeout 1m -failfast ./iface/...
time go test -timeout 1m -failfast ./route/...
time go test -timeout 1m -failfast ./sharedsock/...
time go test -timeout 1m -failfast ./signal/...
time go test -timeout 1m -failfast ./util/...
time go test -timeout 1m -failfast ./version/...

View File

@@ -10,8 +10,10 @@ on:
env:
SIGN_PIPE_VER: "v0.0.11"
SIGN_PIPE_VER: "v0.0.12"
GORELEASER_VER: "v1.14.1"
PRODUCT_NAME: "NetBird"
COPYRIGHT: "Wiretrustee UG (haftungsbeschreankt)"
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.actor_id }}
@@ -23,6 +25,13 @@ jobs:
env:
flags: ""
steps:
- name: Parse semver string
id: semver_parser
uses: booxmedialtd/ws-action-parse-semver@v1
with:
input_string: ${{ (startsWith(github.ref, 'refs/tags/v') && github.ref) || 'refs/tags/v0.0.0' }}
version_extractor_regex: '\/v(.*)$'
- if: ${{ !startsWith(github.ref, 'refs/tags/v') }}
run: echo "flags=--snapshot" >> $GITHUB_ENV
-
@@ -68,18 +77,18 @@ jobs:
- name: Install OS build dependencies
run: sudo apt update && sudo apt install -y -q gcc-arm-linux-gnueabihf gcc-aarch64-linux-gnu
- name: Install rsrc
run: go install github.com/akavel/rsrc@v0.10.2
- name: Generate windows rsrc amd64
run: rsrc -arch amd64 -ico client/ui/netbird.ico -manifest client/manifest.xml -o client/resources_windows_amd64.syso
- name: Generate windows rsrc arm64
run: rsrc -arch arm64 -ico client/ui/netbird.ico -manifest client/manifest.xml -o client/resources_windows_arm64.syso
- name: Generate windows rsrc arm
run: rsrc -arch arm -ico client/ui/netbird.ico -manifest client/manifest.xml -o client/resources_windows_arm.syso
- name: Generate windows rsrc 386
run: rsrc -arch 386 -ico client/ui/netbird.ico -manifest client/manifest.xml -o client/resources_windows_386.syso
-
name: Run GoReleaser
- name: Install goversioninfo
run: go install github.com/josephspurrier/goversioninfo/cmd/goversioninfo@233067e
- name: Generate windows syso 386
run: goversioninfo -icon client/ui/netbird.ico -manifest client/manifest.xml -product-name ${{ env.PRODUCT_NAME }} -copyright "${{ env.COPYRIGHT }}" -file-version ${{ steps.semver_parser.outputs.fullversion }}.0 -product-version ${{ steps.semver_parser.outputs.fullversion }}.0 -o client/resources_windows_386.syso
- name: Generate windows syso arm
run: goversioninfo -icon client/ui/netbird.ico -manifest client/manifest.xml -product-name ${{ env.PRODUCT_NAME }} -copyright "${{ env.COPYRIGHT }}" -file-version ${{ steps.semver_parser.outputs.fullversion }}.0 -product-version ${{ steps.semver_parser.outputs.fullversion }}.0 -o client/resources_windows_arm.syso
- name: Generate windows syso arm64
run: goversioninfo -icon client/ui/netbird.ico -manifest client/manifest.xml -product-name ${{ env.PRODUCT_NAME }} -copyright "${{ env.COPYRIGHT }}" -file-version ${{ steps.semver_parser.outputs.fullversion }}.0 -product-version ${{ steps.semver_parser.outputs.fullversion }}.0 -o client/resources_windows_arm64.syso
- name: Generate windows syso amd64
run: goversioninfo -icon client/ui/netbird.ico -manifest client/manifest.xml -product-name ${{ env.PRODUCT_NAME }} -copyright "${{ env.COPYRIGHT }}" -file-version ${{ steps.semver_parser.outputs.fullversion }}.0 -product-version ${{ steps.semver_parser.outputs.fullversion }}.0 -o client/resources_windows_amd64.syso
- name: Run GoReleaser
uses: goreleaser/goreleaser-action@v4
with:
version: ${{ env.GORELEASER_VER }}
@@ -121,6 +130,13 @@ jobs:
release_ui:
runs-on: ubuntu-latest
steps:
- name: Parse semver string
id: semver_parser
uses: booxmedialtd/ws-action-parse-semver@v1
with:
input_string: ${{ (startsWith(github.ref, 'refs/tags/v') && github.ref) || 'refs/tags/v0.0.0' }}
version_extractor_regex: '\/v(.*)$'
- if: ${{ !startsWith(github.ref, 'refs/tags/v') }}
run: echo "flags=--snapshot" >> $GITHUB_ENV
- name: Checkout
@@ -151,10 +167,11 @@ jobs:
- name: Install dependencies
run: sudo apt update && sudo apt install -y -q libappindicator3-dev gir1.2-appindicator3-0.1 libxxf86vm-dev gcc-mingw-w64-x86-64
- name: Install rsrc
run: go install github.com/akavel/rsrc@v0.10.2
- name: Generate windows rsrc
run: rsrc -arch amd64 -ico client/ui/netbird.ico -manifest client/ui/manifest.xml -o client/ui/resources_windows_amd64.syso
- name: Install goversioninfo
run: go install github.com/josephspurrier/goversioninfo/cmd/goversioninfo@233067e
- name: Generate windows syso amd64
run: goversioninfo -64 -icon client/ui/netbird.ico -manifest client/ui/manifest.xml -product-name ${{ env.PRODUCT_NAME }}-"UI" -copyright "${{ env.COPYRIGHT }}" -file-version ${{ steps.semver_parser.outputs.fullversion }}.0 -product-version ${{ steps.semver_parser.outputs.fullversion }}.0 -o client/ui/resources_windows_amd64.syso
- name: Run GoReleaser
uses: goreleaser/goreleaser-action@v4
with:

View File

@@ -10,10 +10,12 @@
<img width="234" src="docs/media/logo-full.png"/>
</p>
<p>
<a href="https://img.shields.io/badge/license-BSD--3-blue)">
<img src="https://sonarcloud.io/api/project_badges/measure?project=netbirdio_netbird&metric=alert_status" />
</a>
<a href="https://github.com/netbirdio/netbird/blob/main/LICENSE">
<img src="https://img.shields.io/badge/license-BSD--3-blue" />
</a>
<a href="https://www.codacy.com/gh/netbirdio/netbird/dashboard?utm_source=github.com&amp;utm_medium=referral&amp;utm_content=netbirdio/netbird&amp;utm_campaign=Badge_Grade"><img src="https://app.codacy.com/project/badge/Grade/e3013d046aec44cdb7462c8673b00976"/></a>
<br>
<a href="https://join.slack.com/t/netbirdio/shared_invite/zt-vrahf41g-ik1v7fV8du6t0RwxSrJ96A">
<img src="https://img.shields.io/badge/slack-@netbird-red.svg?logo=slack"/>

View File

@@ -178,6 +178,21 @@ func (a *Anonymizer) AnonymizeDNSLogLine(logEntry string) string {
})
}
// AnonymizeRoute anonymizes a route string by replacing IP addresses with anonymized versions and
// domain names with random strings.
func (a *Anonymizer) AnonymizeRoute(route string) string {
prefix, err := netip.ParsePrefix(route)
if err == nil {
ip := a.AnonymizeIPString(prefix.Addr().String())
return fmt.Sprintf("%s/%d", ip, prefix.Bits())
}
domains := strings.Split(route, ", ")
for i, domain := range domains {
domains[i] = a.AnonymizeDomain(domain)
}
return strings.Join(domains, ", ")
}
func isWellKnown(addr netip.Addr) bool {
wellKnown := []string{
"8.8.8.8", "8.8.4.4", // Google DNS IPv4

View File

@@ -5,6 +5,7 @@ import (
"fmt"
"time"
log "github.com/sirupsen/logrus"
"github.com/spf13/cobra"
"google.golang.org/grpc/status"
@@ -13,6 +14,8 @@ import (
"github.com/netbirdio/netbird/client/server"
)
const errCloseConnection = "Failed to close connection: %v"
var debugCmd = &cobra.Command{
Use: "debug",
Short: "Debugging commands",
@@ -63,12 +66,17 @@ func debugBundle(cmd *cobra.Command, _ []string) error {
if err != nil {
return err
}
defer conn.Close()
defer func() {
if err := conn.Close(); err != nil {
log.Errorf(errCloseConnection, err)
}
}()
client := proto.NewDaemonServiceClient(conn)
resp, err := client.DebugBundle(cmd.Context(), &proto.DebugBundleRequest{
Anonymize: anonymizeFlag,
Status: getStatusOutput(cmd),
Anonymize: anonymizeFlag,
Status: getStatusOutput(cmd),
SystemInfo: debugSystemInfoFlag,
})
if err != nil {
return fmt.Errorf("failed to bundle debug: %v", status.Convert(err).Message())
@@ -84,7 +92,11 @@ func setLogLevel(cmd *cobra.Command, args []string) error {
if err != nil {
return err
}
defer conn.Close()
defer func() {
if err := conn.Close(); err != nil {
log.Errorf(errCloseConnection, err)
}
}()
client := proto.NewDaemonServiceClient(conn)
level := server.ParseLogLevel(args[0])
@@ -113,7 +125,11 @@ func runForDuration(cmd *cobra.Command, args []string) error {
if err != nil {
return err
}
defer conn.Close()
defer func() {
if err := conn.Close(); err != nil {
log.Errorf(errCloseConnection, err)
}
}()
client := proto.NewDaemonServiceClient(conn)
@@ -122,17 +138,20 @@ func runForDuration(cmd *cobra.Command, args []string) error {
return fmt.Errorf("failed to get status: %v", status.Convert(err).Message())
}
restoreUp := stat.Status == string(internal.StatusConnected) || stat.Status == string(internal.StatusConnecting)
stateWasDown := stat.Status != string(internal.StatusConnected) && stat.Status != string(internal.StatusConnecting)
initialLogLevel, err := client.GetLogLevel(cmd.Context(), &proto.GetLogLevelRequest{})
if err != nil {
return fmt.Errorf("failed to get log level: %v", status.Convert(err).Message())
}
if _, err := client.Down(cmd.Context(), &proto.DownRequest{}); err != nil {
return fmt.Errorf("failed to down: %v", status.Convert(err).Message())
if stateWasDown {
if _, err := client.Up(cmd.Context(), &proto.UpRequest{}); err != nil {
return fmt.Errorf("failed to up: %v", status.Convert(err).Message())
}
cmd.Println("Netbird up")
time.Sleep(time.Second * 10)
}
cmd.Println("Netbird down")
initialLevelTrace := initialLogLevel.GetLevel() >= proto.LogLevel_TRACE
if !initialLevelTrace {
@@ -145,6 +164,11 @@ func runForDuration(cmd *cobra.Command, args []string) error {
cmd.Println("Log level set to trace.")
}
if _, err := client.Down(cmd.Context(), &proto.DownRequest{}); err != nil {
return fmt.Errorf("failed to down: %v", status.Convert(err).Message())
}
cmd.Println("Netbird down")
time.Sleep(1 * time.Second)
if _, err := client.Up(cmd.Context(), &proto.UpRequest{}); err != nil {
@@ -162,21 +186,25 @@ func runForDuration(cmd *cobra.Command, args []string) error {
}
cmd.Println("\nDuration completed")
cmd.Println("Creating debug bundle...")
headerPreDown := fmt.Sprintf("----- Netbird pre-down - Timestamp: %s - Duration: %s", time.Now().Format(time.RFC3339), duration)
statusOutput = fmt.Sprintf("%s\n%s\n%s", statusOutput, headerPreDown, getStatusOutput(cmd))
if _, err := client.Down(cmd.Context(), &proto.DownRequest{}); err != nil {
return fmt.Errorf("failed to down: %v", status.Convert(err).Message())
resp, err := client.DebugBundle(cmd.Context(), &proto.DebugBundleRequest{
Anonymize: anonymizeFlag,
Status: statusOutput,
SystemInfo: debugSystemInfoFlag,
})
if err != nil {
return fmt.Errorf("failed to bundle debug: %v", status.Convert(err).Message())
}
cmd.Println("Netbird down")
time.Sleep(1 * time.Second)
if restoreUp {
if _, err := client.Up(cmd.Context(), &proto.UpRequest{}); err != nil {
return fmt.Errorf("failed to up: %v", status.Convert(err).Message())
if stateWasDown {
if _, err := client.Down(cmd.Context(), &proto.DownRequest{}); err != nil {
return fmt.Errorf("failed to down: %v", status.Convert(err).Message())
}
cmd.Println("Netbird up")
cmd.Println("Netbird down")
}
if !initialLevelTrace {
@@ -186,16 +214,6 @@ func runForDuration(cmd *cobra.Command, args []string) error {
cmd.Println("Log level restored to", initialLogLevel.GetLevel())
}
cmd.Println("Creating debug bundle...")
resp, err := client.DebugBundle(cmd.Context(), &proto.DebugBundleRequest{
Anonymize: anonymizeFlag,
Status: statusOutput,
})
if err != nil {
return fmt.Errorf("failed to bundle debug: %v", status.Convert(err).Message())
}
cmd.Println(resp.GetPath())
return nil

View File

@@ -26,7 +26,7 @@ var downCmd = &cobra.Command{
return err
}
ctx, cancel := context.WithTimeout(context.Background(), time.Second*3)
ctx, cancel := context.WithTimeout(context.Background(), time.Second*7)
defer cancel()
conn, err := DialClientGRPCServer(ctx, daemonAddr)

View File

@@ -37,6 +37,7 @@ const (
serverSSHAllowedFlag = "allow-server-ssh"
extraIFaceBlackListFlag = "extra-iface-blacklist"
dnsRouteIntervalFlag = "dns-router-interval"
systemInfoFlag = "system-info"
)
var (
@@ -69,6 +70,7 @@ var (
autoConnectDisabled bool
extraIFaceBlackList []string
anonymizeFlag bool
debugSystemInfoFlag bool
dnsRouteInterval time.Duration
rootCmd = &cobra.Command{
@@ -121,7 +123,7 @@ func init() {
rootCmd.PersistentFlags().StringVarP(&serviceName, "service", "s", defaultServiceName, "Netbird system service name")
rootCmd.PersistentFlags().StringVarP(&configPath, "config", "c", defaultConfigPath, "Netbird config file location")
rootCmd.PersistentFlags().StringVarP(&logLevel, "log-level", "l", "info", "sets Netbird log level")
rootCmd.PersistentFlags().StringVar(&logFile, "log-file", defaultLogFile, "sets Netbird log path. If console is specified the log will be output to stdout")
rootCmd.PersistentFlags().StringVar(&logFile, "log-file", defaultLogFile, "sets Netbird log path. If console is specified the log will be output to stdout. If syslog is specified the log will be sent to syslog daemon.")
rootCmd.PersistentFlags().StringVarP(&setupKey, "setup-key", "k", "", "Setup key obtained from the Management Service Dashboard (used to register peer)")
rootCmd.PersistentFlags().StringVar(&preSharedKey, preSharedKeyFlag, "", "Sets Wireguard PreSharedKey property. If set, then only peers that have the same key can communicate.")
rootCmd.PersistentFlags().StringVarP(&hostName, "hostname", "n", "", "Sets a custom hostname for the device")
@@ -165,6 +167,8 @@ func init() {
upCmd.PersistentFlags().BoolVar(&rosenpassPermissive, rosenpassPermissiveFlag, false, "[Experimental] Enable Rosenpass in permissive mode to allow this peer to accept WireGuard connections without requiring Rosenpass functionality from peers that do not have Rosenpass enabled.")
upCmd.PersistentFlags().BoolVar(&serverSSHAllowed, serverSSHAllowedFlag, false, "Allow SSH server on peer. If enabled, the SSH server will be permitted")
upCmd.PersistentFlags().BoolVar(&autoConnectDisabled, disableAutoConnectFlag, false, "Disables auto-connect feature. If enabled, then the client won't connect automatically when the service starts.")
debugCmd.PersistentFlags().BoolVarP(&debugSystemInfoFlag, systemInfoFlag, "S", false, "Adds system information to the debug bundle")
}
// SetupCloseHandler handles SIGTERM signal and exits with success

View File

@@ -31,6 +31,8 @@ var installCmd = &cobra.Command{
configPath,
"--log-level",
logLevel,
"--daemon-addr",
daemonAddr,
}
if managementURL != "" {

View File

@@ -807,7 +807,7 @@ func anonymizePeerDetail(a *anonymize.Anonymizer, peer *peerStateDetailOutput) {
}
for i, route := range peer.Routes {
peer.Routes[i] = anonymizeRoute(a, route)
peer.Routes[i] = a.AnonymizeRoute(route)
}
}
@@ -843,21 +843,8 @@ func anonymizeOverview(a *anonymize.Anonymizer, overview *statusOutputOverview)
}
for i, route := range overview.Routes {
overview.Routes[i] = anonymizeRoute(a, route)
overview.Routes[i] = a.AnonymizeRoute(route)
}
overview.FQDN = a.AnonymizeDomain(overview.FQDN)
}
func anonymizeRoute(a *anonymize.Anonymizer, route string) string {
prefix, err := netip.ParsePrefix(route)
if err == nil {
ip := a.AnonymizeIPString(prefix.Addr().String())
return fmt.Sprintf("%s/%d", ip, prefix.Bits())
}
domains := strings.Split(route, ", ")
for i, domain := range domains {
domains[i] = a.AnonymizeDomain(domain)
}
return strings.Join(domains, ", ")
}

View File

@@ -337,7 +337,6 @@ func validateRule(ip net.IP, packetData []byte, rules map[string]Rule, d *decode
if rule.dPort != 0 && rule.dPort == uint16(d.udp.DstPort) {
return rule.drop, true
}
return rule.drop, true
case layers.LayerTypeICMPv4, layers.LayerTypeICMPv6:
return rule.drop, true
}

View File

@@ -15,6 +15,12 @@ type hostManager interface {
restoreUncleanShutdownDNS(storedDNSAddress *netip.Addr) error
}
type SystemDNSSettings struct {
Domains []string
ServerIP string
ServerPort int
}
type HostDNSConfig struct {
Domains []DomainConfig `json:"domains"`
RouteAll bool `json:"routeAll"`

View File

@@ -7,6 +7,7 @@ import (
"bytes"
"fmt"
"io"
"net"
"net/netip"
"os/exec"
"strconv"
@@ -18,7 +19,7 @@ import (
const (
netbirdDNSStateKeyFormat = "State:/Network/Service/NetBird-%s/DNS"
globalIPv4State = "State:/Network/Global/IPv4"
primaryServiceSetupKeyFormat = "Setup:/Network/Service/%s/DNS"
primaryServiceStateKeyFormat = "State:/Network/Service/%s/DNS"
keySupplementalMatchDomains = "SupplementalMatchDomains"
keySupplementalMatchDomainsNoSearch = "SupplementalMatchDomainsNoSearch"
keyServerAddresses = "ServerAddresses"
@@ -28,12 +29,12 @@ const (
scutilPath = "/usr/sbin/scutil"
searchSuffix = "Search"
matchSuffix = "Match"
localSuffix = "Local"
)
type systemConfigurator struct {
// primaryServiceID primary interface in the system. AKA the interface with the default route
primaryServiceID string
createdKeys map[string]struct{}
createdKeys map[string]struct{}
systemDNSSettings SystemDNSSettings
}
func newHostManager() (hostManager, error) {
@@ -49,20 +50,6 @@ func (s *systemConfigurator) supportCustomPort() bool {
func (s *systemConfigurator) applyDNSConfig(config HostDNSConfig) error {
var err error
if config.RouteAll {
err = s.addDNSSetupForAll(config.ServerIP, config.ServerPort)
if err != nil {
return fmt.Errorf("add dns setup for all: %w", err)
}
} else if s.primaryServiceID != "" {
err = s.removeKeyFromSystemConfig(getKeyWithInput(primaryServiceSetupKeyFormat, s.primaryServiceID))
if err != nil {
return fmt.Errorf("remote key from system config: %w", err)
}
s.primaryServiceID = ""
log.Infof("removed %s:%d as main DNS resolver for this peer", config.ServerIP, config.ServerPort)
}
// create a file for unclean shutdown detection
if err := createUncleanShutdownIndicator(); err != nil {
log.Errorf("failed to create unclean shutdown file: %s", err)
@@ -73,6 +60,19 @@ func (s *systemConfigurator) applyDNSConfig(config HostDNSConfig) error {
matchDomains []string
)
err = s.recordSystemDNSSettings(true)
if err != nil {
log.Errorf("unable to update record of System's DNS config: %s", err.Error())
}
if config.RouteAll {
searchDomains = append(searchDomains, "\"\"")
err = s.addLocalDNS()
if err != nil {
log.Infof("failed to enable split DNS")
}
}
for _, dConf := range config.Domains {
if dConf.Disabled {
continue
@@ -110,23 +110,17 @@ func (s *systemConfigurator) applyDNSConfig(config HostDNSConfig) error {
}
func (s *systemConfigurator) restoreHostDNS() error {
lines := ""
for key := range s.createdKeys {
lines += buildRemoveKeyOperation(key)
keys := s.getRemovableKeysWithDefaults()
for _, key := range keys {
keyType := "search"
if strings.Contains(key, matchSuffix) {
keyType = "match"
}
log.Infof("removing %s domains from system", keyType)
}
if s.primaryServiceID != "" {
lines += buildRemoveKeyOperation(getKeyWithInput(primaryServiceSetupKeyFormat, s.primaryServiceID))
log.Infof("restoring DNS resolver configuration for system")
}
_, err := runSystemConfigCommand(wrapCommand(lines))
if err != nil {
log.Errorf("got an error while cleaning the system configuration: %s", err)
return fmt.Errorf("clean system: %w", err)
err := s.removeKeyFromSystemConfig(key)
if err != nil {
log.Errorf("failed to remove %s domains from system: %s", keyType, err)
}
}
if err := removeUncleanShutdownIndicator(); err != nil {
@@ -136,6 +130,19 @@ func (s *systemConfigurator) restoreHostDNS() error {
return nil
}
func (s *systemConfigurator) getRemovableKeysWithDefaults() []string {
if len(s.createdKeys) == 0 {
// return defaults for startup calls
return []string{getKeyWithInput(netbirdDNSStateKeyFormat, searchSuffix), getKeyWithInput(netbirdDNSStateKeyFormat, matchSuffix)}
}
keys := make([]string, 0, len(s.createdKeys))
for key := range s.createdKeys {
keys = append(keys, key)
}
return keys
}
func (s *systemConfigurator) removeKeyFromSystemConfig(key string) error {
line := buildRemoveKeyOperation(key)
_, err := runSystemConfigCommand(wrapCommand(line))
@@ -148,6 +155,97 @@ func (s *systemConfigurator) removeKeyFromSystemConfig(key string) error {
return nil
}
func (s *systemConfigurator) addLocalDNS() error {
if s.systemDNSSettings.ServerIP == "" || len(s.systemDNSSettings.Domains) == 0 {
err := s.recordSystemDNSSettings(true)
log.Errorf("Unable to get system DNS configuration")
return err
}
localKey := getKeyWithInput(netbirdDNSStateKeyFormat, localSuffix)
if s.systemDNSSettings.ServerIP != "" && len(s.systemDNSSettings.Domains) != 0 {
err := s.addSearchDomains(localKey, strings.Join(s.systemDNSSettings.Domains, " "), s.systemDNSSettings.ServerIP, s.systemDNSSettings.ServerPort)
if err != nil {
return fmt.Errorf("couldn't add local network DNS conf: %w", err)
}
} else {
log.Info("Not enabling local DNS server")
}
return nil
}
func (s *systemConfigurator) recordSystemDNSSettings(force bool) error {
if s.systemDNSSettings.ServerIP != "" && len(s.systemDNSSettings.Domains) != 0 && !force {
return nil
}
systemDNSSettings, err := s.getSystemDNSSettings()
if err != nil {
return fmt.Errorf("couldn't get current DNS config: %w", err)
}
s.systemDNSSettings = systemDNSSettings
return nil
}
func (s *systemConfigurator) getSystemDNSSettings() (SystemDNSSettings, error) {
primaryServiceKey, _, err := s.getPrimaryService()
if err != nil || primaryServiceKey == "" {
return SystemDNSSettings{}, fmt.Errorf("couldn't find the primary service key: %w", err)
}
dnsServiceKey := getKeyWithInput(primaryServiceStateKeyFormat, primaryServiceKey)
line := buildCommandLine("show", dnsServiceKey, "")
stdinCommands := wrapCommand(line)
b, err := runSystemConfigCommand(stdinCommands)
if err != nil {
return SystemDNSSettings{}, fmt.Errorf("sending the command: %w", err)
}
var dnsSettings SystemDNSSettings
inSearchDomainsArray := false
inServerAddressesArray := false
scanner := bufio.NewScanner(bytes.NewReader(b))
for scanner.Scan() {
line := strings.TrimSpace(scanner.Text())
switch {
case strings.HasPrefix(line, "DomainName :"):
domainName := strings.TrimSpace(strings.Split(line, ":")[1])
dnsSettings.Domains = append(dnsSettings.Domains, domainName)
case line == "SearchDomains : <array> {":
inSearchDomainsArray = true
continue
case line == "ServerAddresses : <array> {":
inServerAddressesArray = true
continue
case line == "}":
inSearchDomainsArray = false
inServerAddressesArray = false
}
if inSearchDomainsArray {
searchDomain := strings.Split(line, " : ")[1]
dnsSettings.Domains = append(dnsSettings.Domains, searchDomain)
} else if inServerAddressesArray {
address := strings.Split(line, " : ")[1]
if ip := net.ParseIP(address); ip != nil && ip.To4() != nil {
dnsSettings.ServerIP = address
inServerAddressesArray = false // Stop reading after finding the first IPv4 address
}
}
}
if err := scanner.Err(); err != nil {
return dnsSettings, err
}
// default to 53 port
dnsSettings.ServerPort = 53
return dnsSettings, nil
}
func (s *systemConfigurator) addSearchDomains(key, domains string, ip string, port int) error {
err := s.addDNSState(key, domains, ip, port, true)
if err != nil {
@@ -194,23 +292,6 @@ func (s *systemConfigurator) addDNSState(state, domains, dnsServer string, port
return nil
}
func (s *systemConfigurator) addDNSSetupForAll(dnsServer string, port int) error {
primaryServiceKey, existingNameserver, err := s.getPrimaryService()
if err != nil || primaryServiceKey == "" {
return fmt.Errorf("couldn't find the primary service key: %w", err)
}
err = s.addDNSSetup(getKeyWithInput(primaryServiceSetupKeyFormat, primaryServiceKey), dnsServer, port, existingNameserver)
if err != nil {
return fmt.Errorf("add dns setup: %w", err)
}
log.Infof("configured %s:%d as main DNS resolver for this peer", dnsServer, port)
s.primaryServiceID = primaryServiceKey
return nil
}
func (s *systemConfigurator) getPrimaryService() (string, string, error) {
line := buildCommandLine("show", globalIPv4State, "")
stdinCommands := wrapCommand(line)
@@ -239,19 +320,6 @@ func (s *systemConfigurator) getPrimaryService() (string, string, error) {
return primaryService, router, nil
}
func (s *systemConfigurator) addDNSSetup(setupKey, dnsServer string, port int, existingDNSServer string) error {
lines := buildAddCommandLine(keySupplementalMatchDomainsNoSearch, digitSymbol+strconv.Itoa(0))
lines += buildAddCommandLine(keyServerAddresses, arraySymbol+dnsServer+" "+existingDNSServer)
lines += buildAddCommandLine(keyServerPort, digitSymbol+strconv.Itoa(port))
addDomainCommand := buildCreateStateWithOperation(setupKey, lines)
stdinCommands := wrapCommand(addDomainCommand)
_, err := runSystemConfigCommand(stdinCommands)
if err != nil {
return fmt.Errorf("applying dns setup, error: %w", err)
}
return nil
}
func (s *systemConfigurator) restoreUncleanShutdownDNS(*netip.Addr) error {
if err := s.restoreHostDNS(); err != nil {
return fmt.Errorf("restoring dns via scutil: %w", err)

View File

@@ -94,7 +94,7 @@ func NewDefaultServer(
var dnsService service
if wgInterface.IsUserspaceBind() {
dnsService = newServiceViaMemory(wgInterface)
dnsService = NewServiceViaMemory(wgInterface)
} else {
dnsService = newServiceViaListener(wgInterface, addrPort)
}
@@ -112,7 +112,7 @@ func NewDefaultServerPermanentUpstream(
statusRecorder *peer.Status,
) *DefaultServer {
log.Debugf("host dns address list is: %v", hostsDnsList)
ds := newDefaultServer(ctx, wgInterface, newServiceViaMemory(wgInterface), statusRecorder)
ds := newDefaultServer(ctx, wgInterface, NewServiceViaMemory(wgInterface), statusRecorder)
ds.hostsDNSHolder.set(hostsDnsList)
ds.permanent = true
ds.addHostRootZone()
@@ -130,7 +130,7 @@ func NewDefaultServerIos(
iosDnsManager IosDnsManager,
statusRecorder *peer.Status,
) *DefaultServer {
ds := newDefaultServer(ctx, wgInterface, newServiceViaMemory(wgInterface), statusRecorder)
ds := newDefaultServer(ctx, wgInterface, NewServiceViaMemory(wgInterface), statusRecorder)
ds.iosDnsManager = iosDnsManager
return ds
}

View File

@@ -534,7 +534,7 @@ func TestDNSServerStartStop(t *testing.T) {
func TestDNSServerUpstreamDeactivateCallback(t *testing.T) {
hostManager := &mockHostConfigurator{}
server := DefaultServer{
service: newServiceViaMemory(&mocWGIface{}),
service: NewServiceViaMemory(&mocWGIface{}),
localResolver: &localResolver{
registeredMap: make(registrationMap),
},

View File

@@ -128,6 +128,9 @@ func (s *serviceViaListener) RuntimeIP() string {
}
func (s *serviceViaListener) setListenerStatus(running bool) {
s.listenerFlagLock.Lock()
defer s.listenerFlagLock.Unlock()
s.listenerIsRunning = running
}

View File

@@ -12,7 +12,7 @@ import (
log "github.com/sirupsen/logrus"
)
type serviceViaMemory struct {
type ServiceViaMemory struct {
wgInterface WGIface
dnsMux *dns.ServeMux
runtimeIP string
@@ -22,8 +22,8 @@ type serviceViaMemory struct {
listenerFlagLock sync.Mutex
}
func newServiceViaMemory(wgIface WGIface) *serviceViaMemory {
s := &serviceViaMemory{
func NewServiceViaMemory(wgIface WGIface) *ServiceViaMemory {
s := &ServiceViaMemory{
wgInterface: wgIface,
dnsMux: dns.NewServeMux(),
@@ -33,7 +33,7 @@ func newServiceViaMemory(wgIface WGIface) *serviceViaMemory {
return s
}
func (s *serviceViaMemory) Listen() error {
func (s *ServiceViaMemory) Listen() error {
s.listenerFlagLock.Lock()
defer s.listenerFlagLock.Unlock()
@@ -52,7 +52,7 @@ func (s *serviceViaMemory) Listen() error {
return nil
}
func (s *serviceViaMemory) Stop() {
func (s *ServiceViaMemory) Stop() {
s.listenerFlagLock.Lock()
defer s.listenerFlagLock.Unlock()
@@ -67,23 +67,23 @@ func (s *serviceViaMemory) Stop() {
s.listenerIsRunning = false
}
func (s *serviceViaMemory) RegisterMux(pattern string, handler dns.Handler) {
func (s *ServiceViaMemory) RegisterMux(pattern string, handler dns.Handler) {
s.dnsMux.Handle(pattern, handler)
}
func (s *serviceViaMemory) DeregisterMux(pattern string) {
func (s *ServiceViaMemory) DeregisterMux(pattern string) {
s.dnsMux.HandleRemove(pattern)
}
func (s *serviceViaMemory) RuntimePort() int {
func (s *ServiceViaMemory) RuntimePort() int {
return s.runtimePort
}
func (s *serviceViaMemory) RuntimeIP() string {
func (s *ServiceViaMemory) RuntimeIP() string {
return s.runtimeIP
}
func (s *serviceViaMemory) filterDNSTraffic() (string, error) {
func (s *ServiceViaMemory) filterDNSTraffic() (string, error) {
filter := s.wgInterface.GetFilter()
if filter == nil {
return "", fmt.Errorf("can't set DNS filter, filter not initialized")

View File

@@ -24,7 +24,7 @@ const (
probeTimeout = 2 * time.Second
)
const testRecord = "."
const testRecord = "com."
type upstreamClient interface {
exchange(ctx context.Context, upstream string, r *dns.Msg) (*dns.Msg, time.Duration, error)
@@ -42,6 +42,7 @@ type upstreamResolverBase struct {
upstreamServers []string
disabled bool
failsCount atomic.Int32
successCount atomic.Int32
failsTillDeact int32
mutex sync.Mutex
reactivatePeriod time.Duration
@@ -124,6 +125,7 @@ func (u *upstreamResolverBase) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
return
}
u.successCount.Add(1)
log.Tracef("took %s to query the upstream %s", t, upstream)
err = w.WriteMsg(rm)
@@ -172,6 +174,11 @@ func (u *upstreamResolverBase) probeAvailability() {
default:
}
// avoid probe if upstreams could resolve at least one query and fails count is less than failsTillDeact
if u.successCount.Load() > 0 && u.failsCount.Load() < u.failsTillDeact {
return
}
var success bool
var mu sync.Mutex
var wg sync.WaitGroup
@@ -183,7 +190,7 @@ func (u *upstreamResolverBase) probeAvailability() {
wg.Add(1)
go func() {
defer wg.Done()
err := u.testNameserver(upstream)
err := u.testNameserver(upstream, 500*time.Millisecond)
if err != nil {
errors = multierror.Append(errors, err)
log.Warnf("probing upstream nameserver %s: %s", upstream, err)
@@ -224,7 +231,7 @@ func (u *upstreamResolverBase) waitUntilResponse() {
}
for _, upstream := range u.upstreamServers {
if err := u.testNameserver(upstream); err != nil {
if err := u.testNameserver(upstream, probeTimeout); err != nil {
log.Tracef("upstream check for %s: %s", upstream, err)
} else {
// at least one upstream server is available, stop probing
@@ -244,6 +251,7 @@ func (u *upstreamResolverBase) waitUntilResponse() {
log.Infof("upstreams %s are responsive again. Adding them back to system", u.upstreamServers)
u.failsCount.Store(0)
u.successCount.Add(1)
u.reactivate()
u.disabled = false
}
@@ -265,13 +273,14 @@ func (u *upstreamResolverBase) disable(err error) {
}
log.Warnf("Upstream resolving is Disabled for %v", reactivatePeriod)
u.successCount.Store(0)
u.deactivate(err)
u.disabled = true
go u.waitUntilResponse()
}
func (u *upstreamResolverBase) testNameserver(server string) error {
ctx, cancel := context.WithTimeout(u.ctx, probeTimeout)
func (u *upstreamResolverBase) testNameserver(server string, timeout time.Duration) error {
ctx, cancel := context.WithTimeout(u.ctx, timeout)
defer cancel()
r := new(dns.Msg).SetQuestion(testRecord, dns.TypeSOA)

View File

@@ -4,6 +4,7 @@ package dns
import (
"context"
"fmt"
"net"
"syscall"
"time"
@@ -17,9 +18,9 @@ import (
type upstreamResolverIOS struct {
*upstreamResolverBase
lIP net.IP
lNet *net.IPNet
iIndex int
lIP net.IP
lNet *net.IPNet
interfaceName string
}
func newUpstreamResolver(
@@ -32,17 +33,11 @@ func newUpstreamResolver(
) (*upstreamResolverIOS, error) {
upstreamResolverBase := newUpstreamResolverBase(ctx, statusRecorder)
index, err := getInterfaceIndex(interfaceName)
if err != nil {
log.Debugf("unable to get interface index for %s: %s", interfaceName, err)
return nil, err
}
ios := &upstreamResolverIOS{
upstreamResolverBase: upstreamResolverBase,
lIP: ip,
lNet: net,
iIndex: index,
interfaceName: interfaceName,
}
ios.upstreamClient = ios
@@ -53,7 +48,7 @@ func (u *upstreamResolverIOS) exchange(ctx context.Context, upstream string, r *
client := &dns.Client{}
upstreamHost, _, err := net.SplitHostPort(upstream)
if err != nil {
log.Errorf("error while parsing upstream host: %s", err)
return nil, 0, fmt.Errorf("error while parsing upstream host: %s", err)
}
timeout := upstreamTimeout
@@ -65,26 +60,35 @@ func (u *upstreamResolverIOS) exchange(ctx context.Context, upstream string, r *
upstreamIP := net.ParseIP(upstreamHost)
if u.lNet.Contains(upstreamIP) || net.IP.IsPrivate(upstreamIP) {
log.Debugf("using private client to query upstream: %s", upstream)
client = u.getClientPrivate(timeout)
client, err = GetClientPrivate(u.lIP, u.interfaceName, timeout)
if err != nil {
return nil, 0, fmt.Errorf("error while creating private client: %s", err)
}
}
// Cannot use client.ExchangeContext because it overwrites our Dialer
return client.Exchange(r, upstream)
}
// getClientPrivate returns a new DNS client bound to the local IP address of the Netbird interface
// GetClientPrivate returns a new DNS client bound to the local IP address of the Netbird interface
// This method is needed for iOS
func (u *upstreamResolverIOS) getClientPrivate(dialTimeout time.Duration) *dns.Client {
func GetClientPrivate(ip net.IP, interfaceName string, dialTimeout time.Duration) (*dns.Client, error) {
index, err := getInterfaceIndex(interfaceName)
if err != nil {
log.Debugf("unable to get interface index for %s: %s", interfaceName, err)
return nil, err
}
dialer := &net.Dialer{
LocalAddr: &net.UDPAddr{
IP: u.lIP,
IP: ip,
Port: 0, // Let the OS pick a free port
},
Timeout: dialTimeout,
Control: func(network, address string, c syscall.RawConn) error {
var operr error
fn := func(s uintptr) {
operr = unix.SetsockoptInt(int(s), unix.IPPROTO_IP, unix.IP_BOUND_IF, u.iIndex)
operr = unix.SetsockoptInt(int(s), unix.IPPROTO_IP, unix.IP_BOUND_IF, index)
}
if err := c.Control(fn); err != nil {
@@ -101,7 +105,7 @@ func (u *upstreamResolverIOS) getClientPrivate(dialTimeout time.Duration) *dns.C
client := &dns.Client{
Dialer: dialer,
}
return client
return client, nil
}
func getInterfaceIndex(interfaceName string) (int, error) {

View File

@@ -266,8 +266,23 @@ func (e *Engine) Stop() error {
e.close()
e.wgConnWorker.Wait()
log.Infof("stopped Netbird Engine")
return nil
maxWaitTime := 5 * time.Second
timeout := time.After(maxWaitTime)
for {
if !e.IsWGIfaceUp() {
log.Infof("stopped Netbird Engine")
return nil
}
select {
case <-timeout:
return fmt.Errorf("timeout when waiting for interface shutdown")
default:
time.Sleep(100 * time.Millisecond)
}
}
}
// Start creates a new WireGuard tunnel interface and listens to events from Signal and Management services
@@ -1533,3 +1548,20 @@ func isChecksEqual(checks []*mgmProto.Checks, oChecks []*mgmProto.Checks) bool {
return slices.Equal(checks.Files, oChecks.Files)
})
}
func (e *Engine) IsWGIfaceUp() bool {
if e == nil || e.wgInterface == nil {
return false
}
iface, err := net.InterfaceByName(e.wgInterface.Name())
if err != nil {
log.Debugf("failed to get interface by name %s: %v", e.wgInterface.Name(), err)
return false
}
if iface.Flags&net.FlagUp != 0 {
return true
}
return false
}

View File

@@ -4,6 +4,7 @@ package networkmonitor
import (
"context"
"errors"
"fmt"
"syscall"
"unsafe"
@@ -21,11 +22,20 @@ func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop, ca
return fmt.Errorf("failed to open routing socket: %v", err)
}
defer func() {
if err := unix.Close(fd); err != nil {
err := unix.Close(fd)
if err != nil && !errors.Is(err, unix.EBADF) {
log.Errorf("Network monitor: failed to close routing socket: %v", err)
}
}()
go func() {
<-ctx.Done()
err := unix.Close(fd)
if err != nil && !errors.Is(err, unix.EBADF) {
log.Debugf("Network monitor: closed routing socket")
}
}()
for {
select {
case <-ctx.Done():
@@ -34,7 +44,9 @@ func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop, ca
buf := make([]byte, 2048)
n, err := unix.Read(fd, buf)
if err != nil {
log.Errorf("Network monitor: failed to read from routing socket: %v", err)
if !errors.Is(err, unix.EBADF) && !errors.Is(err, unix.EINVAL) {
log.Errorf("Network monitor: failed to read from routing socket: %v", err)
}
continue
}
if n < unix.SizeofRtMsghdr {

View File

@@ -99,6 +99,11 @@ func routeChanged(nexthop systemops.Nexthop, intf *net.Interface, routes []syste
return false
}
if isSoftInterface(nexthop.Intf.Name) {
log.Tracef("network monitor: ignoring default route change for soft interface %s", nexthop.Intf.Name)
return false
}
unspec := getUnspecifiedPrefix(nexthop.IP)
defaultRoutes, foundMatchingRoute := processRoutes(nexthop, intf, routes, unspec)
@@ -119,7 +124,7 @@ func getUnspecifiedPrefix(ip netip.Addr) netip.Prefix {
return netip.PrefixFrom(netip.IPv4Unspecified(), 0)
}
func processRoutes(nexthop systemops.Nexthop, intf *net.Interface, routes []systemops.Route, unspec netip.Prefix) ([]string, bool) {
func processRoutes(nexthop systemops.Nexthop, nexthopIntf *net.Interface, routes []systemops.Route, unspec netip.Prefix) ([]string, bool) {
var defaultRoutes []string
foundMatchingRoute := false
@@ -128,7 +133,7 @@ func processRoutes(nexthop systemops.Nexthop, intf *net.Interface, routes []syst
routeInfo := formatRouteInfo(r)
defaultRoutes = append(defaultRoutes, routeInfo)
if r.Nexthop == nexthop.IP && compareIntf(r.Interface, intf) == 0 {
if r.Nexthop == nexthop.IP && compareIntf(r.Interface, nexthopIntf) == 0 {
foundMatchingRoute = true
log.Debugf("network monitor: found matching default route: %s", routeInfo)
}
@@ -232,14 +237,18 @@ func stateFromInt(state uint8) string {
}
func compareIntf(a, b *net.Interface) int {
if a == nil && b == nil {
switch {
case a == nil && b == nil:
return 0
}
if a == nil {
case a == nil:
return -1
}
if b == nil {
case b == nil:
return 1
default:
return a.Index - b.Index
}
return a.Index - b.Index
}
func isSoftInterface(name string) bool {
return strings.Contains(strings.ToLower(name), "isatap") || strings.Contains(strings.ToLower(name), "teredo")
}

View File

@@ -3,12 +3,14 @@ package routemanager
import (
"context"
"fmt"
"reflect"
"time"
"github.com/hashicorp/go-multierror"
log "github.com/sirupsen/logrus"
nberrors "github.com/netbirdio/netbird/client/errors"
nbdns "github.com/netbirdio/netbird/client/internal/dns"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/routemanager/dynamic"
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
@@ -64,7 +66,7 @@ func newClientNetworkWatcher(ctx context.Context, dnsRouteInterval time.Duration
routePeersNotifiers: make(map[string]chan struct{}),
routeUpdate: make(chan routesUpdate),
peerStateUpdate: make(chan struct{}),
handler: handlerFromRoute(rt, routeRefCounter, allowedIPsRefCounter, dnsRouteInterval, statusRecorder),
handler: handlerFromRoute(rt, routeRefCounter, allowedIPsRefCounter, dnsRouteInterval, statusRecorder, wgInterface),
}
return client
}
@@ -309,22 +311,33 @@ func (c *clientNetwork) sendUpdateToClientNetworkWatcher(update routesUpdate) {
}()
}
func (c *clientNetwork) handleUpdate(update routesUpdate) {
func (c *clientNetwork) handleUpdate(update routesUpdate) bool {
isUpdateMapDifferent := false
updateMap := make(map[route.ID]*route.Route)
for _, r := range update.routes {
updateMap[r.ID] = r
}
if len(c.routes) != len(updateMap) {
isUpdateMapDifferent = true
}
for id, r := range c.routes {
_, found := updateMap[id]
if !found {
close(c.routePeersNotifiers[r.Peer])
delete(c.routePeersNotifiers, r.Peer)
isUpdateMapDifferent = true
continue
}
if !reflect.DeepEqual(c.routes[id], updateMap[id]) {
isUpdateMapDifferent = true
}
}
c.routes = updateMap
return isUpdateMapDifferent
}
// peersStateAndUpdateWatcher is the main point of reacting on client network routing events.
@@ -351,13 +364,19 @@ func (c *clientNetwork) peersStateAndUpdateWatcher() {
log.Debugf("Received a new client network route update for [%v]", c.handler)
c.handleUpdate(update)
// hash update somehow
isTrueRouteUpdate := c.handleUpdate(update)
c.updateSerial = update.updateSerial
err := c.recalculateRouteAndUpdatePeerAndSystem()
if err != nil {
log.Errorf("Failed to recalculate routes for network [%v]: %v", c.handler, err)
if isTrueRouteUpdate {
log.Debug("Client network update contains different routes, recalculating routes")
err := c.recalculateRouteAndUpdatePeerAndSystem()
if err != nil {
log.Errorf("Failed to recalculate routes for network [%v]: %v", c.handler, err)
}
} else {
log.Debug("Route update is not different, skipping route recalculation")
}
c.startPeersStatusChangeWatcher()
@@ -365,9 +384,10 @@ func (c *clientNetwork) peersStateAndUpdateWatcher() {
}
}
func handlerFromRoute(rt *route.Route, routeRefCounter *refcounter.RouteRefCounter, allowedIPsRefCounter *refcounter.AllowedIPsRefCounter, dnsRouterInteval time.Duration, statusRecorder *peer.Status) RouteHandler {
func handlerFromRoute(rt *route.Route, routeRefCounter *refcounter.RouteRefCounter, allowedIPsRefCounter *refcounter.AllowedIPsRefCounter, dnsRouterInteval time.Duration, statusRecorder *peer.Status, wgInterface *iface.WGIface) RouteHandler {
if rt.IsDynamic() {
return dynamic.NewRoute(rt, routeRefCounter, allowedIPsRefCounter, dnsRouterInteval, statusRecorder)
dns := nbdns.NewServiceViaMemory(wgInterface)
return dynamic.NewRoute(rt, routeRefCounter, allowedIPsRefCounter, dnsRouterInteval, statusRecorder, wgInterface, fmt.Sprintf("%s:%d", dns.RuntimeIP(), dns.RuntimePort()))
}
return static.NewRoute(rt, routeRefCounter, allowedIPsRefCounter)
}

View File

@@ -16,6 +16,7 @@ import (
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
"github.com/netbirdio/netbird/client/internal/routemanager/util"
"github.com/netbirdio/netbird/iface"
"github.com/netbirdio/netbird/management/domain"
"github.com/netbirdio/netbird/route"
)
@@ -47,6 +48,8 @@ type Route struct {
currentPeerKey string
cancel context.CancelFunc
statusRecorder *peer.Status
wgInterface *iface.WGIface
resolverAddr string
}
func NewRoute(
@@ -55,6 +58,8 @@ func NewRoute(
allowedIPsRefCounter *refcounter.AllowedIPsRefCounter,
interval time.Duration,
statusRecorder *peer.Status,
wgInterface *iface.WGIface,
resolverAddr string,
) *Route {
return &Route{
route: rt,
@@ -63,6 +68,8 @@ func NewRoute(
interval: interval,
dynamicDomains: domainMap{},
statusRecorder: statusRecorder,
wgInterface: wgInterface,
resolverAddr: resolverAddr,
}
}
@@ -189,9 +196,14 @@ func (r *Route) startResolver(ctx context.Context) {
}
func (r *Route) update(ctx context.Context) error {
if resolved, err := r.resolveDomains(); err != nil {
return fmt.Errorf("resolve domains: %w", err)
} else if err := r.updateDynamicRoutes(ctx, resolved); err != nil {
resolved, err := r.resolveDomains()
if err != nil {
if len(resolved) == 0 {
return fmt.Errorf("resolve domains: %w", err)
}
log.Warnf("Failed to resolve domains: %v", err)
}
if err := r.updateDynamicRoutes(ctx, resolved); err != nil {
return fmt.Errorf("update dynamic routes: %w", err)
}
@@ -223,11 +235,17 @@ func (r *Route) resolve(results chan resolveResult) {
wg.Add(1)
go func(domain domain.Domain) {
defer wg.Done()
ips, err := net.LookupIP(string(domain))
ips, err := r.getIPsFromResolver(domain)
if err != nil {
results <- resolveResult{domain: domain, err: fmt.Errorf("resolve d %s: %w", domain.SafeString(), err)}
return
log.Tracef("Failed to resolve domain %s with private resolver: %v", domain.SafeString(), err)
ips, err = net.LookupIP(string(domain))
if err != nil {
results <- resolveResult{domain: domain, err: fmt.Errorf("resolve d %s: %w", domain.SafeString(), err)}
return
}
}
for _, ip := range ips {
prefix, err := util.GetPrefixFromIP(ip)
if err != nil {

View File

@@ -0,0 +1,13 @@
//go:build !ios
package dynamic
import (
"net"
"github.com/netbirdio/netbird/management/domain"
)
func (r *Route) getIPsFromResolver(domain domain.Domain) ([]net.IP, error) {
return net.LookupIP(string(domain))
}

View File

@@ -0,0 +1,55 @@
//go:build ios
package dynamic
import (
"fmt"
"net"
"time"
"github.com/miekg/dns"
nbdns "github.com/netbirdio/netbird/client/internal/dns"
"github.com/netbirdio/netbird/management/domain"
)
const dialTimeout = 10 * time.Second
func (r *Route) getIPsFromResolver(domain domain.Domain) ([]net.IP, error) {
privateClient, err := nbdns.GetClientPrivate(r.wgInterface.Address().IP, r.wgInterface.Name(), dialTimeout)
if err != nil {
return nil, fmt.Errorf("error while creating private client: %s", err)
}
msg := new(dns.Msg)
msg.SetQuestion(dns.Fqdn(string(domain)), dns.TypeA)
startTime := time.Now()
response, _, err := privateClient.Exchange(msg, r.resolverAddr)
if err != nil {
return nil, fmt.Errorf("DNS query for %s failed after %s: %s ", domain.SafeString(), time.Since(startTime), err)
}
if response.Rcode != dns.RcodeSuccess {
return nil, fmt.Errorf("dns response code: %s", dns.RcodeToString[response.Rcode])
}
ips := make([]net.IP, 0)
for _, answ := range response.Answer {
if aRecord, ok := answ.(*dns.A); ok {
ips = append(ips, aRecord.A)
}
if aaaaRecord, ok := answ.(*dns.AAAA); ok {
ips = append(ips, aaaaRecord.AAAA)
}
}
if len(ips) == 0 {
return nil, fmt.Errorf("no A or AAAA records found for %s", domain.SafeString())
}
return ips, nil
}

View File

@@ -16,6 +16,7 @@ import (
firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/internal/listener"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/routemanager/notifier"
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
"github.com/netbirdio/netbird/client/internal/routemanager/vars"
@@ -50,7 +51,7 @@ type DefaultManager struct {
statusRecorder *peer.Status
wgInterface *iface.WGIface
pubKey string
notifier *notifier
notifier *notifier.Notifier
routeRefCounter *refcounter.RouteRefCounter
allowedIPsRefCounter *refcounter.AllowedIPsRefCounter
dnsRouteInterval time.Duration
@@ -65,7 +66,8 @@ func NewManager(
initialRoutes []*route.Route,
) *DefaultManager {
mCTX, cancel := context.WithCancel(ctx)
sysOps := systemops.NewSysOps(wgInterface)
notifier := notifier.NewNotifier()
sysOps := systemops.NewSysOps(wgInterface, notifier)
dm := &DefaultManager{
ctx: mCTX,
@@ -77,7 +79,7 @@ func NewManager(
statusRecorder: statusRecorder,
wgInterface: wgInterface,
pubKey: pubKey,
notifier: newNotifier(),
notifier: notifier,
}
dm.routeRefCounter = refcounter.New(
@@ -107,7 +109,7 @@ func NewManager(
if runtime.GOOS == "android" {
cr := dm.clientRoutes(initialRoutes)
dm.notifier.setInitialClientRoutes(cr)
dm.notifier.SetInitialClientRoutes(cr)
}
return dm
}
@@ -186,7 +188,7 @@ func (m *DefaultManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Ro
filteredClientRoutes := m.routeSelector.FilterSelected(newClientRoutesIDMap)
m.updateClientNetworks(updateSerial, filteredClientRoutes)
m.notifier.onNewRoutes(filteredClientRoutes)
m.notifier.OnNewRoutes(filteredClientRoutes)
if m.serverRouter != nil {
err := m.serverRouter.updateRoutes(newServerRoutesMap)
@@ -199,14 +201,14 @@ func (m *DefaultManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Ro
}
}
// SetRouteChangeListener set RouteListener for route change notifier
// SetRouteChangeListener set RouteListener for route change Notifier
func (m *DefaultManager) SetRouteChangeListener(listener listener.NetworkChangeListener) {
m.notifier.setListener(listener)
m.notifier.SetListener(listener)
}
// InitialRouteRange return the list of initial routes. It used by mobile systems
func (m *DefaultManager) InitialRouteRange() []string {
return m.notifier.getInitialRouteRanges()
return m.notifier.GetInitialRouteRanges()
}
// GetRouteSelector returns the route selector
@@ -226,7 +228,7 @@ func (m *DefaultManager) TriggerSelection(networks route.HAMap) {
networks = m.routeSelector.FilterSelected(networks)
m.notifier.onNewRoutes(networks)
m.notifier.OnNewRoutes(networks)
m.stopObsoleteClients(networks)

View File

@@ -1,6 +1,7 @@
package routemanager
package notifier
import (
"net/netip"
"runtime"
"sort"
"strings"
@@ -10,7 +11,7 @@ import (
"github.com/netbirdio/netbird/route"
)
type notifier struct {
type Notifier struct {
initialRouteRanges []string
routeRanges []string
@@ -18,17 +19,17 @@ type notifier struct {
listenerMux sync.Mutex
}
func newNotifier() *notifier {
return &notifier{}
func NewNotifier() *Notifier {
return &Notifier{}
}
func (n *notifier) setListener(listener listener.NetworkChangeListener) {
func (n *Notifier) SetListener(listener listener.NetworkChangeListener) {
n.listenerMux.Lock()
defer n.listenerMux.Unlock()
n.listener = listener
}
func (n *notifier) setInitialClientRoutes(clientRoutes []*route.Route) {
func (n *Notifier) SetInitialClientRoutes(clientRoutes []*route.Route) {
nets := make([]string, 0)
for _, r := range clientRoutes {
nets = append(nets, r.Network.String())
@@ -37,7 +38,10 @@ func (n *notifier) setInitialClientRoutes(clientRoutes []*route.Route) {
n.initialRouteRanges = nets
}
func (n *notifier) onNewRoutes(idMap route.HAMap) {
func (n *Notifier) OnNewRoutes(idMap route.HAMap) {
if runtime.GOOS != "android" {
return
}
newNets := make([]string, 0)
for _, routes := range idMap {
for _, r := range routes {
@@ -62,7 +66,30 @@ func (n *notifier) onNewRoutes(idMap route.HAMap) {
n.notify()
}
func (n *notifier) notify() {
func (n *Notifier) OnNewPrefixes(prefixes []netip.Prefix) {
newNets := make([]string, 0)
for _, prefix := range prefixes {
newNets = append(newNets, prefix.String())
}
sort.Strings(newNets)
switch runtime.GOOS {
case "android":
if !n.hasDiff(n.initialRouteRanges, newNets) {
return
}
default:
if !n.hasDiff(n.routeRanges, newNets) {
return
}
}
n.routeRanges = newNets
n.notify()
}
func (n *Notifier) notify() {
n.listenerMux.Lock()
defer n.listenerMux.Unlock()
if n.listener == nil {
@@ -74,7 +101,7 @@ func (n *notifier) notify() {
}(n.listener)
}
func (n *notifier) hasDiff(a []string, b []string) bool {
func (n *Notifier) hasDiff(a []string, b []string) bool {
if len(a) != len(b) {
return true
}
@@ -86,7 +113,7 @@ func (n *notifier) hasDiff(a []string, b []string) bool {
return false
}
func (n *notifier) getInitialRouteRanges() []string {
func (n *Notifier) GetInitialRouteRanges() []string {
return addIPv6RangeIfNeeded(n.initialRouteRanges)
}

View File

@@ -3,7 +3,9 @@ package systemops
import (
"net"
"net/netip"
"sync"
"github.com/netbirdio/netbird/client/internal/routemanager/notifier"
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
"github.com/netbirdio/netbird/iface"
)
@@ -18,10 +20,19 @@ type ExclusionCounter = refcounter.Counter[any, Nexthop]
type SysOps struct {
refCounter *ExclusionCounter
wgInterface *iface.WGIface
// prefixes is tracking all the current added prefixes im memory
// (this is used in iOS as all route updates require a full table update)
//nolint
prefixes map[netip.Prefix]struct{}
//nolint
mu sync.Mutex
// notifier is used to notify the system of route changes (also used on mobile)
notifier *notifier.Notifier
}
func NewSysOps(wgInterface *iface.WGIface) *SysOps {
func NewSysOps(wgInterface *iface.WGIface, notifier *notifier.Notifier) *SysOps {
return &SysOps{
wgInterface: wgInterface,
notifier: notifier,
}
}

View File

@@ -1,4 +1,4 @@
//go:build ios || android
//go:build android
package systemops

View File

@@ -22,7 +22,7 @@ type Route struct {
Interface *net.Interface
}
func getRoutesFromTable() ([]netip.Prefix, error) {
func GetRoutesFromTable() ([]netip.Prefix, error) {
tab, err := retryFetchRIB()
if err != nil {
return nil, fmt.Errorf("fetch RIB: %v", err)

View File

@@ -36,7 +36,7 @@ func TestConcurrentRoutes(t *testing.T) {
baseIP := netip.MustParseAddr("192.0.2.0")
intf := &net.Interface{Name: "lo0"}
r := NewSysOps(nil)
r := NewSysOps(nil, nil)
var wg sync.WaitGroup
for i := 0; i < 1024; i++ {

View File

@@ -50,7 +50,7 @@ func (r *SysOps) setupRefCounter(initAddresses []net.IP) (nbnet.AddHookFunc, nbn
nexthop, err := r.addRouteToNonVPNIntf(prefix, r.wgInterface, initialNexthop)
if errors.Is(err, vars.ErrRouteNotAllowed) || errors.Is(err, vars.ErrRouteNotFound) {
log.Tracef("Adding for prefix %s: %v", prefix, err)
// These errors are not critical but also we should not track and try to remove the routes either.
// These errors are not critical, but also we should not track and try to remove the routes either.
return nexthop, refcounter.ErrIgnore
}
return nexthop, err
@@ -135,6 +135,11 @@ func (r *SysOps) addRouteToNonVPNIntf(prefix netip.Prefix, vpnIntf *iface.WGIfac
return Nexthop{}, vars.ErrRouteNotAllowed
}
// Check if the prefix is part of any local subnets
if isLocal, subnet := r.isPrefixInLocalSubnets(prefix); isLocal {
return Nexthop{}, fmt.Errorf("prefix %s is part of local subnet %s: %w", prefix, subnet, vars.ErrRouteNotAllowed)
}
// Determine the exit interface and next hop for the prefix, so we can add a specific route
nexthop, err := GetNextHop(addr)
if err != nil {
@@ -167,6 +172,36 @@ func (r *SysOps) addRouteToNonVPNIntf(prefix netip.Prefix, vpnIntf *iface.WGIfac
return exitNextHop, nil
}
func (r *SysOps) isPrefixInLocalSubnets(prefix netip.Prefix) (bool, *net.IPNet) {
localInterfaces, err := net.Interfaces()
if err != nil {
log.Errorf("Failed to get local interfaces: %v", err)
return false, nil
}
for _, intf := range localInterfaces {
addrs, err := intf.Addrs()
if err != nil {
log.Errorf("Failed to get addresses for interface %s: %v", intf.Name, err)
continue
}
for _, addr := range addrs {
ipnet, ok := addr.(*net.IPNet)
if !ok {
log.Errorf("Failed to convert address to IPNet: %v", addr)
continue
}
if ipnet.Contains(prefix.Addr().AsSlice()) {
return true, ipnet
}
}
}
return false, nil
}
// genericAddVPNRoute adds a new route to the vpn interface, it splits the default prefix
// in two /1 prefixes to avoid replacing the existing default route
func (r *SysOps) genericAddVPNRoute(prefix netip.Prefix, intf *net.Interface) error {
@@ -392,7 +427,7 @@ func ipToAddr(ip net.IP, intf *net.Interface) (netip.Addr, error) {
}
func existsInRouteTable(prefix netip.Prefix) (bool, error) {
routes, err := getRoutesFromTable()
routes, err := GetRoutesFromTable()
if err != nil {
return false, fmt.Errorf("get routes from table: %w", err)
}
@@ -405,7 +440,7 @@ func existsInRouteTable(prefix netip.Prefix) (bool, error) {
}
func isSubRange(prefix netip.Prefix) (bool, error) {
routes, err := getRoutesFromTable()
routes, err := GetRoutesFromTable()
if err != nil {
return false, fmt.Errorf("get routes from table: %w", err)
}

View File

@@ -68,7 +68,7 @@ func TestAddRemoveRoutes(t *testing.T) {
err = wgInterface.Create()
require.NoError(t, err, "should create testing wireguard interface")
r := NewSysOps(wgInterface)
r := NewSysOps(wgInterface, nil)
_, _, err = r.SetupRouting(nil)
require.NoError(t, err)
@@ -224,7 +224,7 @@ func TestAddExistAndRemoveRoute(t *testing.T) {
require.NoError(t, err, "InterfaceByName should not return err")
intf := &net.Interface{Index: index.Index, Name: wgInterface.Name()}
r := NewSysOps(wgInterface)
r := NewSysOps(wgInterface, nil)
// Prepare the environment
if testCase.preExistingPrefix.IsValid() {
@@ -379,7 +379,7 @@ func setupTestEnv(t *testing.T) {
assert.NoError(t, wgInterface.Close())
})
r := NewSysOps(wgInterface)
r := NewSysOps(wgInterface, nil)
_, _, err := r.SetupRouting(nil)
require.NoError(t, err, "setupRouting should not return err")
t.Cleanup(func() {

View File

@@ -0,0 +1,64 @@
//go:build ios
package systemops
import (
"net"
"net/netip"
"runtime"
log "github.com/sirupsen/logrus"
nbnet "github.com/netbirdio/netbird/util/net"
)
func (r *SysOps) SetupRouting([]net.IP) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) {
r.mu.Lock()
defer r.mu.Unlock()
r.prefixes = make(map[netip.Prefix]struct{})
return nil, nil, nil
}
func (r *SysOps) CleanupRouting() error {
r.mu.Lock()
defer r.mu.Unlock()
r.prefixes = make(map[netip.Prefix]struct{})
r.notify()
return nil
}
func (r *SysOps) AddVPNRoute(prefix netip.Prefix, _ *net.Interface) error {
r.mu.Lock()
defer r.mu.Unlock()
r.prefixes[prefix] = struct{}{}
r.notify()
return nil
}
func (r *SysOps) RemoveVPNRoute(prefix netip.Prefix, _ *net.Interface) error {
r.mu.Lock()
defer r.mu.Unlock()
delete(r.prefixes, prefix)
r.notify()
return nil
}
func EnableIPForwarding() error {
log.Infof("Enable IP forwarding is not implemented on %s", runtime.GOOS)
return nil
}
func IsAddrRouted(netip.Addr, []netip.Prefix) (bool, netip.Prefix) {
return false, netip.Prefix{}
}
func (r *SysOps) notify() {
prefixes := make([]netip.Prefix, 0, len(r.prefixes))
for prefix := range r.prefixes {
prefixes = append(prefixes, prefix)
}
r.notifier.OnNewPrefixes(prefixes)
}

View File

@@ -206,7 +206,7 @@ func (r *SysOps) RemoveVPNRoute(prefix netip.Prefix, intf *net.Interface) error
return nil
}
func getRoutesFromTable() ([]netip.Prefix, error) {
func GetRoutesFromTable() ([]netip.Prefix, error) {
v4Routes, err := getRoutes(syscall.RT_TABLE_MAIN, netlink.FAMILY_V4)
if err != nil {
return nil, fmt.Errorf("get v4 routes: %w", err)
@@ -504,7 +504,7 @@ func getAddressFamily(prefix netip.Prefix) int {
func hasSeparateRouting() ([]netip.Prefix, error) {
if isLegacy() {
return getRoutesFromTable()
return GetRoutesFromTable()
}
return nil, ErrRoutingIsSeparate
}

View File

@@ -24,5 +24,5 @@ func EnableIPForwarding() error {
}
func hasSeparateRouting() ([]netip.Prefix, error) {
return getRoutesFromTable()
return GetRoutesFromTable()
}

View File

@@ -94,7 +94,7 @@ func (r *SysOps) removeFromRouteTable(prefix netip.Prefix, nexthop Nexthop) erro
return nil
}
func getRoutesFromTable() ([]netip.Prefix, error) {
func GetRoutesFromTable() ([]netip.Prefix, error) {
mux.Lock()
defer mux.Unlock()

View File

@@ -73,7 +73,7 @@ var testCases = []testCase{
{
name: "To duplicate internal route without custom dialer via physical interface", // local route takes precedence
destination: "10.0.0.2:53",
expectedSourceIP: "10.0.0.1",
expectedSourceIP: "127.0.0.1",
expectedDestPrefix: "10.0.0.0/8",
expectedNextHop: "0.0.0.0",
expectedInterface: "Loopback Pseudo-Interface 1",
@@ -110,7 +110,7 @@ var testCases = []testCase{
{
name: "To more specific route (local) without custom dialer via physical interface",
destination: "127.0.10.2:53",
expectedSourceIP: "10.0.0.1",
expectedSourceIP: "127.0.0.1",
expectedDestPrefix: "127.0.0.0/8",
expectedNextHop: "0.0.0.0",
expectedInterface: "Loopback Pseudo-Interface 1",
@@ -181,31 +181,6 @@ func testRoute(t *testing.T, destination string, dialer dialer) *FindNetRouteOut
return combinedOutput
}
func createAndSetupDummyInterface(t *testing.T, interfaceName, ipAddressCIDR string) string {
t.Helper()
ip, ipNet, err := net.ParseCIDR(ipAddressCIDR)
require.NoError(t, err)
subnetMaskSize, _ := ipNet.Mask.Size()
script := fmt.Sprintf(`New-NetIPAddress -InterfaceAlias "%s" -IPAddress "%s" -PrefixLength %d -PolicyStore ActiveStore -Confirm:$False`, interfaceName, ip.String(), subnetMaskSize)
_, err = exec.Command("powershell", "-Command", script).CombinedOutput()
require.NoError(t, err, "Failed to assign IP address to loopback adapter")
// Wait for the IP address to be applied
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
defer cancel()
err = waitForIPAddress(ctx, interfaceName, ip.String())
require.NoError(t, err, "IP address not applied within timeout")
t.Cleanup(func() {
script = fmt.Sprintf(`Remove-NetIPAddress -InterfaceAlias "%s" -IPAddress "%s" -Confirm:$False`, interfaceName, ip.String())
_, err = exec.Command("powershell", "-Command", script).CombinedOutput()
require.NoError(t, err, "Failed to remove IP address from loopback adapter")
})
return interfaceName
}
func fetchOriginalGateway() (*RouteInfo, error) {
cmd := exec.Command("powershell", "-Command", "Get-NetRoute -DestinationPrefix 0.0.0.0/0 | Select-Object Nexthop, RouteMetric, InterfaceAlias | ConvertTo-Json")
output, err := cmd.CombinedOutput()
@@ -231,30 +206,6 @@ func verifyOutput(t *testing.T, output *FindNetRouteOutput, sourceIP, destPrefix
assert.Equal(t, intf, output.InterfaceAlias, "Interface mismatch")
}
func waitForIPAddress(ctx context.Context, interfaceAlias, expectedIPAddress string) error {
ticker := time.NewTicker(1 * time.Second)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return ctx.Err()
case <-ticker.C:
out, err := exec.Command("powershell", "-Command", fmt.Sprintf(`Get-NetIPAddress -InterfaceAlias "%s" | Select-Object -ExpandProperty IPAddress`, interfaceAlias)).CombinedOutput()
if err != nil {
return err
}
ipAddresses := strings.Split(strings.TrimSpace(string(out)), "\n")
for _, ip := range ipAddresses {
if strings.TrimSpace(ip) == expectedIPAddress {
return nil
}
}
}
}
}
func combineOutputs(outputs []FindNetRouteOutput) *FindNetRouteOutput {
var combined FindNetRouteOutput
@@ -285,5 +236,25 @@ func combineOutputs(outputs []FindNetRouteOutput) *FindNetRouteOutput {
func setupDummyInterfacesAndRoutes(t *testing.T) {
t.Helper()
createAndSetupDummyInterface(t, "Loopback Pseudo-Interface 1", "10.0.0.1/8")
addDummyRoute(t, "10.0.0.0/8")
}
func addDummyRoute(t *testing.T, dstCIDR string) {
t.Helper()
script := fmt.Sprintf(`New-NetRoute -DestinationPrefix "%s" -InterfaceIndex 1 -PolicyStore ActiveStore`, dstCIDR)
output, err := exec.Command("powershell", "-Command", script).CombinedOutput()
if err != nil {
t.Logf("Failed to add dummy route: %v\nOutput: %s", err, output)
t.FailNow()
}
t.Cleanup(func() {
script = fmt.Sprintf(`Remove-NetRoute -DestinationPrefix "%s" -InterfaceIndex 1 -Confirm:$false`, dstCIDR)
output, err := exec.Command("powershell", "-Command", script).CombinedOutput()
if err != nil {
t.Logf("Failed to remove dummy route: %v\nOutput: %s", err, output)
}
})
}

View File

@@ -19,6 +19,7 @@ import (
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/system"
"github.com/netbirdio/netbird/formatter"
"github.com/netbirdio/netbird/management/domain"
"github.com/netbirdio/netbird/route"
)
@@ -47,6 +48,7 @@ type CustomLogger interface {
type selectRoute struct {
NetID string
Network netip.Prefix
Domains domain.List
Selected bool
}
@@ -279,6 +281,7 @@ func (c *Client) GetRoutesSelectionDetails() (*RoutesSelectionDetails, error) {
route := &selectRoute{
NetID: string(id),
Network: rt[0].Network,
Domains: rt[0].Domains,
Selected: routeSelector.IsSelected(id),
}
routes = append(routes, route)
@@ -299,17 +302,40 @@ func (c *Client) GetRoutesSelectionDetails() (*RoutesSelectionDetails, error) {
return iPrefix < jPrefix
})
resolvedDomains := c.recorder.GetResolvedDomainsStates()
return prepareRouteSelectionDetails(routes, resolvedDomains), nil
}
func prepareRouteSelectionDetails(routes []*selectRoute, resolvedDomains map[domain.Domain][]netip.Prefix) *RoutesSelectionDetails {
var routeSelection []RoutesSelectionInfo
for _, r := range routes {
domainList := make([]DomainInfo, 0)
for _, d := range r.Domains {
domainResp := DomainInfo{
Domain: d.SafeString(),
}
if prefixes, exists := resolvedDomains[d]; exists {
var ipStrings []string
for _, prefix := range prefixes {
ipStrings = append(ipStrings, prefix.Addr().String())
}
domainResp.ResolvedIPs = strings.Join(ipStrings, ", ")
}
domainList = append(domainList, domainResp)
}
domainDetails := DomainDetails{items: domainList}
routeSelection = append(routeSelection, RoutesSelectionInfo{
ID: r.NetID,
Network: r.Network.String(),
Domains: &domainDetails,
Selected: r.Selected,
})
}
routeSelectionDetails := RoutesSelectionDetails{items: routeSelection}
return &routeSelectionDetails, nil
return &routeSelectionDetails
}
func (c *Client) SelectRoute(id string) error {

View File

@@ -16,9 +16,25 @@ type RoutesSelectionDetails struct {
type RoutesSelectionInfo struct {
ID string
Network string
Domains *DomainDetails
Selected bool
}
type DomainCollection interface {
Add(s DomainInfo) DomainCollection
Get(i int) *DomainInfo
Size() int
}
type DomainDetails struct {
items []DomainInfo
}
type DomainInfo struct {
Domain string
ResolvedIPs string
}
// Add new PeerInfo to the collection
func (array RoutesSelectionDetails) Add(s RoutesSelectionInfo) RoutesSelectionDetails {
array.items = append(array.items, s)
@@ -34,3 +50,16 @@ func (array RoutesSelectionDetails) Get(i int) *RoutesSelectionInfo {
func (array RoutesSelectionDetails) Size() int {
return len(array.items)
}
func (array DomainDetails) Add(s DomainInfo) DomainCollection {
array.items = append(array.items, s)
return array
}
func (array DomainDetails) Get(i int) *DomainInfo {
return &array.items[i]
}
func (array DomainDetails) Size() int {
return len(array.items)
}

View File

@@ -1828,8 +1828,9 @@ type DebugBundleRequest struct {
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
Anonymize bool `protobuf:"varint,1,opt,name=anonymize,proto3" json:"anonymize,omitempty"`
Status string `protobuf:"bytes,2,opt,name=status,proto3" json:"status,omitempty"`
Anonymize bool `protobuf:"varint,1,opt,name=anonymize,proto3" json:"anonymize,omitempty"`
Status string `protobuf:"bytes,2,opt,name=status,proto3" json:"status,omitempty"`
SystemInfo bool `protobuf:"varint,3,opt,name=systemInfo,proto3" json:"systemInfo,omitempty"`
}
func (x *DebugBundleRequest) Reset() {
@@ -1878,6 +1879,13 @@ func (x *DebugBundleRequest) GetStatus() string {
return ""
}
func (x *DebugBundleRequest) GetSystemInfo() bool {
if x != nil {
return x.SystemInfo
}
return false
}
type DebugBundleResponse struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
@@ -2370,11 +2378,13 @@ var file_daemon_proto_rawDesc = []byte{
0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x6b, 0x65, 0x79, 0x12, 0x24, 0x0a, 0x05, 0x76, 0x61, 0x6c,
0x75, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x0e, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f,
0x6e, 0x2e, 0x49, 0x50, 0x4c, 0x69, 0x73, 0x74, 0x52, 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x3a,
0x02, 0x38, 0x01, 0x22, 0x4a, 0x0a, 0x12, 0x44, 0x65, 0x62, 0x75, 0x67, 0x42, 0x75, 0x6e, 0x64,
0x02, 0x38, 0x01, 0x22, 0x6a, 0x0a, 0x12, 0x44, 0x65, 0x62, 0x75, 0x67, 0x42, 0x75, 0x6e, 0x64,
0x6c, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x1c, 0x0a, 0x09, 0x61, 0x6e, 0x6f,
0x6e, 0x79, 0x6d, 0x69, 0x7a, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x08, 0x52, 0x09, 0x61, 0x6e,
0x6f, 0x6e, 0x79, 0x6d, 0x69, 0x7a, 0x65, 0x12, 0x16, 0x0a, 0x06, 0x73, 0x74, 0x61, 0x74, 0x75,
0x73, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x73, 0x74, 0x61, 0x74, 0x75, 0x73, 0x22,
0x73, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x73, 0x74, 0x61, 0x74, 0x75, 0x73, 0x12,
0x1e, 0x0a, 0x0a, 0x73, 0x79, 0x73, 0x74, 0x65, 0x6d, 0x49, 0x6e, 0x66, 0x6f, 0x18, 0x03, 0x20,
0x01, 0x28, 0x08, 0x52, 0x0a, 0x73, 0x79, 0x73, 0x74, 0x65, 0x6d, 0x49, 0x6e, 0x66, 0x6f, 0x22,
0x29, 0x0a, 0x13, 0x44, 0x65, 0x62, 0x75, 0x67, 0x42, 0x75, 0x6e, 0x64, 0x6c, 0x65, 0x52, 0x65,
0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x70, 0x61, 0x74, 0x68, 0x18, 0x01,
0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x70, 0x61, 0x74, 0x68, 0x22, 0x14, 0x0a, 0x12, 0x47, 0x65,

View File

@@ -263,6 +263,7 @@ message Route {
message DebugBundleRequest {
bool anonymize = 1;
string status = 2;
bool systemInfo = 3;
}
message DebugBundleResponse {

View File

@@ -1,3 +1,5 @@
//go:build !android && !ios
package server
import (
@@ -6,16 +8,70 @@ import (
"context"
"fmt"
"io"
"net"
"net/netip"
"os"
"sort"
"strings"
"time"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/anonymize"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
"github.com/netbirdio/netbird/client/proto"
)
const readmeContent = `Netbird debug bundle
This debug bundle contains the following files:
status.txt: Anonymized status information of the NetBird client.
client.log: Most recent, anonymized log file of the NetBird client.
routes.txt: Anonymized system routes, if --system-info flag was provided.
interfaces.txt: Anonymized network interface information, if --system-info flag was provided.
config.txt: Anonymized configuration information of the NetBird client.
Anonymization Process
The files in this bundle have been anonymized to protect sensitive information. Here's how the anonymization was applied:
IP Addresses
IPv4 addresses are replaced with addresses starting from 192.51.100.0
IPv6 addresses are replaced with addresses starting from 100::
IP addresses from non public ranges and well known addresses are not anonymized (e.g. 8.8.8.8, 100.64.0.0/10, addresses starting with 192.168., 172.16., 10., etc.).
Reoccuring IP addresses are replaced with the same anonymized address.
Note: The anonymized IP addresses in the status file do not match those in the log and routes files. However, the anonymized IP addresses are consistent within the status file and across the routes and log files.
Domains
All domain names (except for the netbird domains) are replaced with randomly generated strings ending in ".domain". Anonymized domains are consistent across all files in the bundle.
Reoccuring domain names are replaced with the same anonymized domain.
Routes
For anonymized routes, the IP addresses are replaced as described above. The prefix length remains unchanged. Note that for prefixes, the anonymized IP might not be a network address, but the prefix length is still correct.
Network Interfaces
The interfaces.txt file contains information about network interfaces, including:
- Interface name
- Interface index
- MTU (Maximum Transmission Unit)
- Flags
- IP addresses associated with each interface
The IP addresses in the interfaces file are anonymized using the same process as described above. Interface names, indexes, MTUs, and flags are not anonymized.
Configuration
The config.txt file contains anonymized configuration information of the NetBird client. Sensitive information such as private keys and SSH keys are excluded. The following fields are anonymized:
- ManagementURL
- AdminURL
- NATExternalIPs
- CustomDNSAddress
Other non-sensitive configuration options are included without anonymization.
`
// DebugBundle creates a debug bundle and returns the location.
func (s *Server) DebugBundle(_ context.Context, req *proto.DebugBundleRequest) (resp *proto.DebugBundleResponse, err error) {
s.mutex.Lock()
@@ -30,93 +86,211 @@ func (s *Server) DebugBundle(_ context.Context, req *proto.DebugBundleRequest) (
return nil, fmt.Errorf("create zip file: %w", err)
}
defer func() {
if err := bundlePath.Close(); err != nil {
log.Errorf("failed to close zip file: %v", err)
if closeErr := bundlePath.Close(); closeErr != nil && err == nil {
err = fmt.Errorf("close zip file: %w", closeErr)
}
if err != nil {
if err2 := os.Remove(bundlePath.Name()); err2 != nil {
log.Errorf("Failed to remove zip file: %v", err2)
if removeErr := os.Remove(bundlePath.Name()); removeErr != nil {
log.Errorf("Failed to remove zip file: %v", removeErr)
}
}
}()
archive := zip.NewWriter(bundlePath)
defer func() {
if err := archive.Close(); err != nil {
log.Errorf("failed to close archive writer: %v", err)
}
}()
if status := req.GetStatus(); status != "" {
filename := "status.txt"
if req.GetAnonymize() {
filename = "status.anon.txt"
}
statusReader := strings.NewReader(status)
if err := addFileToZip(archive, statusReader, filename); err != nil {
return nil, fmt.Errorf("add status file to zip: %w", err)
}
}
logFile, err := os.Open(s.logFile)
if err != nil {
return nil, fmt.Errorf("open log file: %w", err)
}
defer func() {
if err := logFile.Close(); err != nil {
log.Errorf("failed to close original log file: %v", err)
}
}()
filename := "client.log.txt"
var logReader io.Reader
errChan := make(chan error, 1)
if req.GetAnonymize() {
filename = "client.anon.log.txt"
var writer io.WriteCloser
logReader, writer = io.Pipe()
go s.anonymize(logFile, writer, errChan)
} else {
logReader = logFile
}
if err := addFileToZip(archive, logReader, filename); err != nil {
return nil, fmt.Errorf("add log file to zip: %w", err)
}
select {
case err := <-errChan:
if err != nil {
return nil, err
}
default:
if err := s.createArchive(bundlePath, req); err != nil {
return nil, err
}
return &proto.DebugBundleResponse{Path: bundlePath.Name()}, nil
}
func (s *Server) anonymize(reader io.Reader, writer io.WriteCloser, errChan chan<- error) {
scanner := bufio.NewScanner(reader)
anonymizer := anonymize.NewAnonymizer(anonymize.DefaultAddresses())
func (s *Server) createArchive(bundlePath *os.File, req *proto.DebugBundleRequest) error {
archive := zip.NewWriter(bundlePath)
if err := s.addReadme(req, archive); err != nil {
return fmt.Errorf("add readme: %w", err)
}
if err := s.addStatus(req, archive); err != nil {
return fmt.Errorf("add status: %w", err)
}
anonymizer := anonymize.NewAnonymizer(anonymize.DefaultAddresses())
status := s.statusRecorder.GetFullStatus()
seedFromStatus(anonymizer, &status)
if err := s.addConfig(req, anonymizer, archive); err != nil {
return fmt.Errorf("add config: %w", err)
}
if req.GetSystemInfo() {
if err := s.addRoutes(req, anonymizer, archive); err != nil {
return fmt.Errorf("add routes: %w", err)
}
if err := s.addInterfaces(req, anonymizer, archive); err != nil {
return fmt.Errorf("add interfaces: %w", err)
}
}
if err := s.addLogfile(req, anonymizer, archive); err != nil {
return fmt.Errorf("add log file: %w", err)
}
if err := archive.Close(); err != nil {
return fmt.Errorf("close archive writer: %w", err)
}
return nil
}
func (s *Server) addReadme(req *proto.DebugBundleRequest, archive *zip.Writer) error {
if req.GetAnonymize() {
readmeReader := strings.NewReader(readmeContent)
if err := addFileToZip(archive, readmeReader, "README.txt"); err != nil {
return fmt.Errorf("add README file to zip: %w", err)
}
}
return nil
}
func (s *Server) addStatus(req *proto.DebugBundleRequest, archive *zip.Writer) error {
if status := req.GetStatus(); status != "" {
statusReader := strings.NewReader(status)
if err := addFileToZip(archive, statusReader, "status.txt"); err != nil {
return fmt.Errorf("add status file to zip: %w", err)
}
}
return nil
}
func (s *Server) addConfig(req *proto.DebugBundleRequest, anonymizer *anonymize.Anonymizer, archive *zip.Writer) error {
var configContent strings.Builder
s.addCommonConfigFields(&configContent)
if req.GetAnonymize() {
if s.config.ManagementURL != nil {
configContent.WriteString(fmt.Sprintf("ManagementURL: %s\n", anonymizer.AnonymizeURI(s.config.ManagementURL.String())))
}
if s.config.AdminURL != nil {
configContent.WriteString(fmt.Sprintf("AdminURL: %s\n", anonymizer.AnonymizeURI(s.config.AdminURL.String())))
}
configContent.WriteString(fmt.Sprintf("NATExternalIPs: %v\n", anonymizeNATExternalIPs(s.config.NATExternalIPs, anonymizer)))
if s.config.CustomDNSAddress != "" {
configContent.WriteString(fmt.Sprintf("CustomDNSAddress: %s\n", anonymizer.AnonymizeString(s.config.CustomDNSAddress)))
}
} else {
if s.config.ManagementURL != nil {
configContent.WriteString(fmt.Sprintf("ManagementURL: %s\n", s.config.ManagementURL.String()))
}
if s.config.AdminURL != nil {
configContent.WriteString(fmt.Sprintf("AdminURL: %s\n", s.config.AdminURL.String()))
}
configContent.WriteString(fmt.Sprintf("NATExternalIPs: %v\n", s.config.NATExternalIPs))
if s.config.CustomDNSAddress != "" {
configContent.WriteString(fmt.Sprintf("CustomDNSAddress: %s\n", s.config.CustomDNSAddress))
}
}
// Add config content to zip file
configReader := strings.NewReader(configContent.String())
if err := addFileToZip(archive, configReader, "config.txt"); err != nil {
return fmt.Errorf("add config file to zip: %w", err)
}
return nil
}
func (s *Server) addCommonConfigFields(configContent *strings.Builder) {
configContent.WriteString("NetBird Client Configuration:\n\n")
// Add non-sensitive fields
configContent.WriteString(fmt.Sprintf("WgIface: %s\n", s.config.WgIface))
configContent.WriteString(fmt.Sprintf("WgPort: %d\n", s.config.WgPort))
if s.config.NetworkMonitor != nil {
configContent.WriteString(fmt.Sprintf("NetworkMonitor: %v\n", *s.config.NetworkMonitor))
}
configContent.WriteString(fmt.Sprintf("IFaceBlackList: %v\n", s.config.IFaceBlackList))
configContent.WriteString(fmt.Sprintf("DisableIPv6Discovery: %v\n", s.config.DisableIPv6Discovery))
configContent.WriteString(fmt.Sprintf("RosenpassEnabled: %v\n", s.config.RosenpassEnabled))
configContent.WriteString(fmt.Sprintf("RosenpassPermissive: %v\n", s.config.RosenpassPermissive))
if s.config.ServerSSHAllowed != nil {
configContent.WriteString(fmt.Sprintf("ServerSSHAllowed: %v\n", *s.config.ServerSSHAllowed))
}
configContent.WriteString(fmt.Sprintf("DisableAutoConnect: %v\n", s.config.DisableAutoConnect))
configContent.WriteString(fmt.Sprintf("DNSRouteInterval: %s\n", s.config.DNSRouteInterval))
}
func (s *Server) addRoutes(req *proto.DebugBundleRequest, anonymizer *anonymize.Anonymizer, archive *zip.Writer) error {
if routes, err := systemops.GetRoutesFromTable(); err != nil {
log.Errorf("Failed to get routes: %v", err)
} else {
// TODO: get routes including nexthop
routesContent := formatRoutes(routes, req.GetAnonymize(), anonymizer)
routesReader := strings.NewReader(routesContent)
if err := addFileToZip(archive, routesReader, "routes.txt"); err != nil {
return fmt.Errorf("add routes file to zip: %w", err)
}
}
return nil
}
func (s *Server) addInterfaces(req *proto.DebugBundleRequest, anonymizer *anonymize.Anonymizer, archive *zip.Writer) error {
interfaces, err := net.Interfaces()
if err != nil {
return fmt.Errorf("get interfaces: %w", err)
}
interfacesContent := formatInterfaces(interfaces, req.GetAnonymize(), anonymizer)
interfacesReader := strings.NewReader(interfacesContent)
if err := addFileToZip(archive, interfacesReader, "interfaces.txt"); err != nil {
return fmt.Errorf("add interfaces file to zip: %w", err)
}
return nil
}
func (s *Server) addLogfile(req *proto.DebugBundleRequest, anonymizer *anonymize.Anonymizer, archive *zip.Writer) (err error) {
logFile, err := os.Open(s.logFile)
if err != nil {
return fmt.Errorf("open log file: %w", err)
}
defer func() {
if err := writer.Close(); err != nil {
log.Errorf("Failed to close writer: %v", err)
if err := logFile.Close(); err != nil {
log.Errorf("Failed to close original log file: %v", err)
}
}()
var logReader io.Reader
if req.GetAnonymize() {
var writer *io.PipeWriter
logReader, writer = io.Pipe()
go s.anonymize(logFile, writer, anonymizer)
} else {
logReader = logFile
}
if err := addFileToZip(archive, logReader, "client.log"); err != nil {
return fmt.Errorf("add log file to zip: %w", err)
}
return nil
}
func (s *Server) anonymize(reader io.Reader, writer *io.PipeWriter, anonymizer *anonymize.Anonymizer) {
defer func() {
// always nil
_ = writer.Close()
}()
scanner := bufio.NewScanner(reader)
for scanner.Scan() {
line := anonymizer.AnonymizeString(scanner.Text())
if _, err := writer.Write([]byte(line + "\n")); err != nil {
errChan <- fmt.Errorf("write line to writer: %w", err)
writer.CloseWithError(fmt.Errorf("anonymize write: %w", err))
return
}
}
if err := scanner.Err(); err != nil {
errChan <- fmt.Errorf("read line from scanner: %w", err)
writer.CloseWithError(fmt.Errorf("anonymize scan: %w", err))
return
}
}
@@ -141,8 +315,22 @@ func (s *Server) SetLogLevel(_ context.Context, req *proto.SetLogLevelRequest) (
func addFileToZip(archive *zip.Writer, reader io.Reader, filename string) error {
header := &zip.FileHeader{
Name: filename,
Method: zip.Deflate,
Name: filename,
Method: zip.Deflate,
Modified: time.Now(),
CreatorVersion: 20, // Version 2.0
ReaderVersion: 20, // Version 2.0
Flags: 0x800, // UTF-8 filename
}
// If the reader is a file, we can get more accurate information
if f, ok := reader.(*os.File); ok {
if stat, err := f.Stat(); err != nil {
log.Tracef("Failed to get file stat for %s: %v", filename, err)
} else {
header.Modified = stat.ModTime()
}
}
writer, err := archive.CreateHeader(header)
@@ -165,6 +353,13 @@ func seedFromStatus(a *anonymize.Anonymizer, status *peer.FullStatus) {
for _, peer := range status.Peers {
a.AnonymizeDomain(peer.FQDN)
for route := range peer.GetRoutes() {
a.AnonymizeRoute(route)
}
}
for route := range status.LocalPeerState.Routes {
a.AnonymizeRoute(route)
}
for _, nsGroup := range status.NSGroupStates {
@@ -179,3 +374,113 @@ func seedFromStatus(a *anonymize.Anonymizer, status *peer.FullStatus) {
}
}
}
func formatRoutes(routes []netip.Prefix, anonymize bool, anonymizer *anonymize.Anonymizer) string {
var ipv4Routes, ipv6Routes []netip.Prefix
// Separate IPv4 and IPv6 routes
for _, route := range routes {
if route.Addr().Is4() {
ipv4Routes = append(ipv4Routes, route)
} else {
ipv6Routes = append(ipv6Routes, route)
}
}
// Sort IPv4 and IPv6 routes separately
sort.Slice(ipv4Routes, func(i, j int) bool {
return ipv4Routes[i].Bits() > ipv4Routes[j].Bits()
})
sort.Slice(ipv6Routes, func(i, j int) bool {
return ipv6Routes[i].Bits() > ipv6Routes[j].Bits()
})
var builder strings.Builder
// Format IPv4 routes
builder.WriteString("IPv4 Routes:\n")
for _, route := range ipv4Routes {
formatRoute(&builder, route, anonymize, anonymizer)
}
// Format IPv6 routes
builder.WriteString("\nIPv6 Routes:\n")
for _, route := range ipv6Routes {
formatRoute(&builder, route, anonymize, anonymizer)
}
return builder.String()
}
func formatRoute(builder *strings.Builder, route netip.Prefix, anonymize bool, anonymizer *anonymize.Anonymizer) {
if anonymize {
anonymizedIP := anonymizer.AnonymizeIP(route.Addr())
builder.WriteString(fmt.Sprintf("%s/%d\n", anonymizedIP, route.Bits()))
} else {
builder.WriteString(fmt.Sprintf("%s\n", route))
}
}
func formatInterfaces(interfaces []net.Interface, anonymize bool, anonymizer *anonymize.Anonymizer) string {
sort.Slice(interfaces, func(i, j int) bool {
return interfaces[i].Name < interfaces[j].Name
})
var builder strings.Builder
builder.WriteString("Network Interfaces:\n")
for _, iface := range interfaces {
builder.WriteString(fmt.Sprintf("\nInterface: %s\n", iface.Name))
builder.WriteString(fmt.Sprintf(" Index: %d\n", iface.Index))
builder.WriteString(fmt.Sprintf(" MTU: %d\n", iface.MTU))
builder.WriteString(fmt.Sprintf(" Flags: %v\n", iface.Flags))
addrs, err := iface.Addrs()
if err != nil {
builder.WriteString(fmt.Sprintf(" Addresses: Error retrieving addresses: %v\n", err))
} else {
builder.WriteString(" Addresses:\n")
for _, addr := range addrs {
prefix, err := netip.ParsePrefix(addr.String())
if err != nil {
builder.WriteString(fmt.Sprintf(" Error parsing address: %v\n", err))
continue
}
ip := prefix.Addr()
if anonymize {
ip = anonymizer.AnonymizeIP(ip)
}
builder.WriteString(fmt.Sprintf(" %s/%d\n", ip, prefix.Bits()))
}
}
}
return builder.String()
}
func anonymizeNATExternalIPs(ips []string, anonymizer *anonymize.Anonymizer) []string {
anonymizedIPs := make([]string, len(ips))
for i, ip := range ips {
parts := strings.SplitN(ip, "/", 2)
ip1, err := netip.ParseAddr(parts[0])
if err != nil {
anonymizedIPs[i] = ip
continue
}
ip1anon := anonymizer.AnonymizeIP(ip1)
if len(parts) == 2 {
ip2, err := netip.ParseAddr(parts[1])
if err != nil {
anonymizedIPs[i] = fmt.Sprintf("%s/%s", ip1anon, parts[1])
} else {
ip2anon := anonymizer.AnonymizeIP(ip2)
anonymizedIPs[i] = fmt.Sprintf("%s/%s", ip1anon, ip2anon)
}
} else {
anonymizedIPs[i] = ip1anon.String()
}
}
return anonymizedIPs
}

View File

@@ -582,7 +582,7 @@ func (s *Server) Up(callerCtx context.Context, _ *proto.UpRequest) (*proto.UpRes
}
// Down engine work in the daemon.
func (s *Server) Down(_ context.Context, _ *proto.DownRequest) (*proto.DownResponse, error) {
func (s *Server) Down(ctx context.Context, _ *proto.DownRequest) (*proto.DownResponse, error) {
s.mutex.Lock()
defer s.mutex.Unlock()
@@ -593,7 +593,25 @@ func (s *Server) Down(_ context.Context, _ *proto.DownRequest) (*proto.DownRespo
state := internal.CtxGetState(s.rootCtx)
state.Set(internal.StatusIdle)
return &proto.DownResponse{}, nil
maxWaitTime := 5 * time.Second
timeout := time.After(maxWaitTime)
engine := s.connectClient.Engine()
for {
if !engine.IsWGIfaceUp() {
return &proto.DownResponse{}, nil
}
select {
case <-ctx.Done():
return &proto.DownResponse{}, nil
case <-timeout:
return nil, fmt.Errorf("failed to shut down properly")
default:
time.Sleep(100 * time.Millisecond)
}
}
}
// Status returns the daemon status

View File

@@ -1,11 +1,12 @@
package system
import (
log "github.com/sirupsen/logrus"
"testing"
log "github.com/sirupsen/logrus"
)
func Test_sysInfo(t *testing.T) {
func Test_sysInfoMac(t *testing.T) {
t.Skip("skipping darwin test")
serialNum, prodName, manufacturer := sysInfo()
if serialNum == "" {

View File

@@ -8,6 +8,7 @@ import (
"context"
"os"
"os/exec"
"regexp"
"runtime"
"strings"
"time"
@@ -20,6 +21,26 @@ import (
"github.com/netbirdio/netbird/version"
)
type SysInfoGetter interface {
GetSysInfo() SysInfo
}
type SysInfoWrapper struct {
si sysinfo.SysInfo
}
func (s SysInfoWrapper) GetSysInfo() SysInfo {
s.si.GetSysInfo()
return SysInfo{
ChassisSerial: s.si.Chassis.Serial,
ProductSerial: s.si.Product.Serial,
BoardSerial: s.si.Board.Serial,
ProductName: s.si.Product.Name,
BoardName: s.si.Board.Name,
ProductVendor: s.si.Product.Vendor,
}
}
// GetInfo retrieves and parses the system information
func GetInfo(ctx context.Context) *Info {
info := _getInfo()
@@ -44,7 +65,8 @@ func GetInfo(ctx context.Context) *Info {
log.Warnf("failed to discover network addresses: %s", err)
}
serialNum, prodName, manufacturer := sysInfo()
si := SysInfoWrapper{}
serialNum, prodName, manufacturer := sysInfo(si.GetSysInfo())
env := Environment{
Cloud: detect_cloud.Detect(ctx),
@@ -86,12 +108,36 @@ func _getInfo() string {
return out.String()
}
func sysInfo() (serialNumber string, productName string, manufacturer string) {
var si sysinfo.SysInfo
si.GetSysInfo()
serial := si.Chassis.Serial
if (serial == "Default string" || serial == "") && si.Product.Serial != "" {
serial = si.Product.Serial
func sysInfo(si SysInfo) (string, string, string) {
isascii := regexp.MustCompile("^[[:ascii:]]+$")
serials := []string{si.ChassisSerial, si.ProductSerial}
serial := ""
for _, s := range serials {
if isascii.MatchString(s) {
serial = s
if s != "Default string" {
break
}
}
}
return serial, si.Product.Name, si.Product.Vendor
if serial == "" && isascii.MatchString(si.BoardSerial) {
serial = si.BoardSerial
}
var name string
for _, n := range []string{si.ProductName, si.BoardName} {
if isascii.MatchString(n) {
name = n
break
}
}
var manufacturer string
if isascii.MatchString(si.ProductVendor) {
manufacturer = si.ProductVendor
}
return serial, name, manufacturer
}

View File

@@ -0,0 +1,12 @@
package system
// SysInfo used to moc out the sysinfo getter
type SysInfo struct {
ChassisSerial string
ProductSerial string
BoardSerial string
ProductName string
BoardName string
ProductVendor string
}

View File

@@ -0,0 +1,198 @@
package system
import "testing"
func Test_sysInfo(t *testing.T) {
tests := []struct {
name string
sysInfo SysInfo
wantSerialNum string
wantProdName string
wantManufacturer string
}{
{
name: "Test Case 1",
sysInfo: SysInfo{
ChassisSerial: "Default string",
ProductSerial: "Default string",
BoardSerial: "M80-G8013200245",
ProductName: "B650M-HDV/M.2",
BoardName: "B650M-HDV/M.2",
ProductVendor: "ASRock",
},
wantSerialNum: "Default string",
wantProdName: "B650M-HDV/M.2",
wantManufacturer: "ASRock",
},
{
name: "Empty Chassis Serial",
sysInfo: SysInfo{
ChassisSerial: "",
ProductSerial: "Default string",
BoardSerial: "M80-G8013200245",
ProductName: "B650M-HDV/M.2",
BoardName: "B650M-HDV/M.2",
ProductVendor: "ASRock",
},
wantSerialNum: "Default string",
wantProdName: "B650M-HDV/M.2",
wantManufacturer: "ASRock",
},
{
name: "Empty Chassis Serial",
sysInfo: SysInfo{
ChassisSerial: "",
ProductSerial: "Default string",
BoardSerial: "M80-G8013200245",
ProductName: "B650M-HDV/M.2",
BoardName: "B650M-HDV/M.2",
ProductVendor: "ASRock",
},
wantSerialNum: "Default string",
wantProdName: "B650M-HDV/M.2",
wantManufacturer: "ASRock",
},
{
name: "Fallback to Product Serial",
sysInfo: SysInfo{
ChassisSerial: "Default string",
ProductSerial: "Product serial",
BoardSerial: "M80-G8013200245",
ProductName: "B650M-HDV/M.2",
BoardName: "B650M-HDV/M.2",
ProductVendor: "ASRock",
},
wantSerialNum: "Product serial",
wantProdName: "B650M-HDV/M.2",
wantManufacturer: "ASRock",
},
{
name: "Fallback to Product Serial with default string",
sysInfo: SysInfo{
ChassisSerial: "Default string",
ProductSerial: "Default string",
BoardSerial: "M80-G8013200245",
ProductName: "B650M-HDV/M.2",
BoardName: "B650M-HDV/M.2",
ProductVendor: "ASRock",
},
wantSerialNum: "Default string",
wantProdName: "B650M-HDV/M.2",
wantManufacturer: "ASRock",
},
{
name: "Non UTF-8 in Chassis Serial",
sysInfo: SysInfo{
ChassisSerial: "\x80",
ProductSerial: "Product serial",
BoardSerial: "M80-G8013200245",
ProductName: "B650M-HDV/M.2",
BoardName: "B650M-HDV/M.2",
ProductVendor: "ASRock",
},
wantSerialNum: "Product serial",
wantProdName: "B650M-HDV/M.2",
wantManufacturer: "ASRock",
},
{
name: "Non UTF-8 in Chassis Serial and Product Serial",
sysInfo: SysInfo{
ChassisSerial: "\x80",
ProductSerial: "\x80",
BoardSerial: "M80-G8013200245",
ProductName: "B650M-HDV/M.2",
BoardName: "B650M-HDV/M.2",
ProductVendor: "ASRock",
},
wantSerialNum: "M80-G8013200245",
wantProdName: "B650M-HDV/M.2",
wantManufacturer: "ASRock",
},
{
name: "Non UTF-8 in Chassis Serial and Product Serial and BoardSerial",
sysInfo: SysInfo{
ChassisSerial: "\x80",
ProductSerial: "\x80",
BoardSerial: "\x80",
ProductName: "B650M-HDV/M.2",
BoardName: "B650M-HDV/M.2",
ProductVendor: "ASRock",
},
wantSerialNum: "",
wantProdName: "B650M-HDV/M.2",
wantManufacturer: "ASRock",
},
{
name: "Empty Product Name",
sysInfo: SysInfo{
ChassisSerial: "Default string",
ProductSerial: "Default string",
BoardSerial: "M80-G8013200245",
ProductName: "",
BoardName: "boardname",
ProductVendor: "ASRock",
},
wantSerialNum: "Default string",
wantProdName: "boardname",
wantManufacturer: "ASRock",
},
{
name: "Invalid Product Name",
sysInfo: SysInfo{
ChassisSerial: "Default string",
ProductSerial: "Default string",
BoardSerial: "M80-G8013200245",
ProductName: "\x80",
BoardName: "boardname",
ProductVendor: "ASRock",
},
wantSerialNum: "Default string",
wantProdName: "boardname",
wantManufacturer: "ASRock",
},
{
name: "Invalid BoardName Name",
sysInfo: SysInfo{
ChassisSerial: "Default string",
ProductSerial: "Default string",
BoardSerial: "M80-G8013200245",
ProductName: "\x80",
BoardName: "\x80",
ProductVendor: "ASRock",
},
wantSerialNum: "Default string",
wantProdName: "",
wantManufacturer: "ASRock",
},
{
name: "Invalid chars",
sysInfo: SysInfo{
ChassisSerial: "\x80",
ProductSerial: "\x80",
BoardSerial: "\x80",
ProductName: "\x80",
BoardName: "\x80",
ProductVendor: "\x80",
},
wantSerialNum: "",
wantProdName: "",
wantManufacturer: "",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
gotSerialNum, gotProdName, gotManufacturer := sysInfo(tt.sysInfo)
if gotSerialNum != tt.wantSerialNum {
t.Errorf("sysInfo() gotSerialNum = %v, want %v", gotSerialNum, tt.wantSerialNum)
}
if gotProdName != tt.wantProdName {
t.Errorf("sysInfo() gotProdName = %v, want %v", gotProdName, tt.wantProdName)
}
if gotManufacturer != tt.wantManufacturer {
t.Errorf("sysInfo() gotManufacturer = %v, want %v", gotManufacturer, tt.wantManufacturer)
}
})
}
}

View File

@@ -15,7 +15,6 @@ import (
"strconv"
"strings"
"sync"
"syscall"
"time"
"unicode"
@@ -34,6 +33,7 @@ import (
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/proto"
"github.com/netbirdio/netbird/client/system"
"github.com/netbirdio/netbird/util"
"github.com/netbirdio/netbird/version"
)
@@ -62,8 +62,25 @@ func main() {
var errorMSG string
flag.StringVar(&errorMSG, "error-msg", "", "displays a error message window")
tmpDir := "/tmp"
if runtime.GOOS == "windows" {
tmpDir = os.TempDir()
}
var saveLogsInFile bool
flag.BoolVar(&saveLogsInFile, "use-log-file", false, fmt.Sprintf("save logs in a file: %s/netbird-ui-PID.log", tmpDir))
flag.Parse()
if saveLogsInFile {
logFile := path.Join(tmpDir, fmt.Sprintf("netbird-ui-%d.log", os.Getpid()))
err := util.InitLog("trace", logFile)
if err != nil {
log.Errorf("error while initializing log: %v", err)
return
}
}
a := app.NewWithID("NetBird")
a.SetIcon(fyne.NewStaticResource("netbird", iconDisconnectedPNG))
@@ -76,8 +93,12 @@ func main() {
if showSettings || showRoutes {
a.Run()
} else {
if err := checkPIDFile(); err != nil {
log.Errorf("check PID file: %v", err)
running, err := isAnotherProcessRunning()
if err != nil {
log.Errorf("error while checking process: %v", err)
}
if running {
log.Warn("another process is running")
return
}
client.setDefaultFonts()
@@ -861,104 +882,3 @@ func openURL(url string) error {
}
return err
}
// checkPIDFile exists and return error, or write new.
func checkPIDFile() error {
pidFile := path.Join(os.TempDir(), "wiretrustee-ui.pid")
if piddata, err := os.ReadFile(pidFile); err == nil {
if pid, err := strconv.Atoi(string(piddata)); err == nil {
if process, err := os.FindProcess(pid); err == nil {
if err := process.Signal(syscall.Signal(0)); err == nil {
return fmt.Errorf("process already exists: %d", pid)
}
}
}
}
return os.WriteFile(pidFile, []byte(fmt.Sprintf("%d", os.Getpid())), 0o664) //nolint:gosec
}
func (s *serviceClient) setDefaultFonts() {
var (
defaultFontPath string
)
//TODO: Linux Multiple Language Support
switch runtime.GOOS {
case "darwin":
defaultFontPath = "/Library/Fonts/Arial Unicode.ttf"
case "windows":
fontPath := s.getWindowsFontFilePath()
defaultFontPath = fontPath
}
_, err := os.Stat(defaultFontPath)
if err == nil {
os.Setenv("FYNE_FONT", defaultFontPath)
}
}
func (s *serviceClient) getWindowsFontFilePath() (fontPath string) {
/*
https://learn.microsoft.com/en-us/windows/apps/design/globalizing/loc-international-fonts
https://learn.microsoft.com/en-us/typography/fonts/windows_11_font_list
*/
var (
fontFolder string = "C:/Windows/Fonts"
fontMapping = map[string]string{
"default": "Segoeui.ttf",
"zh-CN": "Msyh.ttc",
"am-ET": "Ebrima.ttf",
"nirmala": "Nirmala.ttf",
"chr-CHER-US": "Gadugi.ttf",
"zh-HK": "Msjh.ttc",
"zh-TW": "Msjh.ttc",
"ja-JP": "Yugothm.ttc",
"km-KH": "Leelawui.ttf",
"ko-KR": "Malgun.ttf",
"th-TH": "Leelawui.ttf",
"ti-ET": "Ebrima.ttf",
}
nirMalaLang = []string{
"as-IN",
"bn-BD",
"bn-IN",
"gu-IN",
"hi-IN",
"kn-IN",
"kok-IN",
"ml-IN",
"mr-IN",
"ne-NP",
"or-IN",
"pa-IN",
"si-LK",
"ta-IN",
"te-IN",
}
)
cmd := exec.Command("powershell", "-Command", "(Get-Culture).Name")
output, err := cmd.Output()
if err != nil {
log.Errorf("Failed to get Windows default language setting: %v", err)
fontPath = path.Join(fontFolder, fontMapping["default"])
return
}
defaultLanguage := strings.TrimSpace(string(output))
for _, lang := range nirMalaLang {
if defaultLanguage == lang {
fontPath = path.Join(fontFolder, fontMapping["nirmala"])
return
}
}
if font, ok := fontMapping[defaultLanguage]; ok {
fontPath = path.Join(fontFolder, font)
} else {
fontPath = path.Join(fontFolder, fontMapping["default"])
}
return
}

26
client/ui/font_bsd.go Normal file
View File

@@ -0,0 +1,26 @@
//go:build darwin
package main
import (
"os"
"runtime"
log "github.com/sirupsen/logrus"
)
const defaultFontPath = "/Library/Fonts/Arial Unicode.ttf"
func (s *serviceClient) setDefaultFonts() {
// TODO: add other bsd paths
if runtime.GOOS != "darwin" {
return
}
if _, err := os.Stat(defaultFontPath); err != nil {
log.Errorf("Failed to find default font file: %v", err)
return
}
os.Setenv("FYNE_FONT", defaultFontPath)
}

7
client/ui/font_linux.go Normal file
View File

@@ -0,0 +1,7 @@
//go:build !386
package main
func (s *serviceClient) setDefaultFonts() {
//TODO: Linux Multiple Language Support
}

91
client/ui/font_windows.go Normal file
View File

@@ -0,0 +1,91 @@
package main
import (
"os"
"path"
"unsafe"
log "github.com/sirupsen/logrus"
"golang.org/x/sys/windows"
)
func (s *serviceClient) setDefaultFonts() {
defaultFontPath := s.getWindowsFontFilePath()
if _, err := os.Stat(defaultFontPath); err != nil {
log.Errorf("Failed to find default font file: %v", err)
return
}
os.Setenv("FYNE_FONT", defaultFontPath)
}
func (s *serviceClient) getWindowsFontFilePath() string {
var (
fontFolder = "C:/Windows/Fonts"
fontMapping = map[string]string{
"default": "Segoeui.ttf",
"zh-CN": "Msyh.ttc",
"am-ET": "Ebrima.ttf",
"nirmala": "Nirmala.ttf",
"chr-CHER-US": "Gadugi.ttf",
"zh-HK": "Msjh.ttc",
"zh-TW": "Msjh.ttc",
"ja-JP": "Yugothm.ttc",
"km-KH": "Leelawui.ttf",
"ko-KR": "Malgun.ttf",
"th-TH": "Leelawui.ttf",
"ti-ET": "Ebrima.ttf",
}
nirMalaLang = []string{
"as-IN",
"bn-BD",
"bn-IN",
"gu-IN",
"hi-IN",
"kn-IN",
"kok-IN",
"ml-IN",
"mr-IN",
"ne-NP",
"or-IN",
"pa-IN",
"si-LK",
"ta-IN",
"te-IN",
}
)
// getUserDefaultLocaleName.Call() panics if the func is not found
defer func() {
if r := recover(); r != nil {
log.Errorf("Recovered from panic: %v", r)
}
}()
kernel32 := windows.NewLazySystemDLL("kernel32.dll")
getUserDefaultLocaleName := kernel32.NewProc("GetUserDefaultLocaleName")
buf := make([]uint16, 85) // LOCALE_NAME_MAX_LENGTH is usually 85
r, _, err := getUserDefaultLocaleName.Call(uintptr(unsafe.Pointer(&buf[0])), uintptr(len(buf)))
// returns 0 on failure, err is always non-nil
// https://learn.microsoft.com/en-us/windows/win32/api/winnls/nf-winnls-getuserdefaultlocalename
if r == 0 {
log.Errorf("GetUserDefaultLocaleName call failed: %v", err)
return path.Join(fontFolder, fontMapping["default"])
}
defaultLanguage := windows.UTF16ToString(buf)
for _, lang := range nirMalaLang {
if defaultLanguage == lang {
return path.Join(fontFolder, fontMapping["nirmala"])
}
}
if font, ok := fontMapping[defaultLanguage]; ok {
return path.Join(fontFolder, font)
}
return path.Join(fontFolder, fontMapping["default"])
}

37
client/ui/process.go Normal file
View File

@@ -0,0 +1,37 @@
package main
import (
"os"
"path/filepath"
"strings"
"github.com/shirou/gopsutil/v3/process"
)
func isAnotherProcessRunning() (bool, error) {
processes, err := process.Processes()
if err != nil {
return false, err
}
pid := os.Getpid()
processName := strings.ToLower(filepath.Base(os.Args[0]))
for _, p := range processes {
if int(p.Pid) == pid {
continue
}
runningProcessPath, err := p.Exe()
// most errors are related to short-lived processes
if err != nil {
continue
}
if strings.Contains(strings.ToLower(runningProcessPath), processName) && isProcessOwnedByCurrentUser(p) {
return true, nil
}
}
return false, nil
}

View File

@@ -0,0 +1,26 @@
//go:build !windows
package main
import (
"os"
"github.com/shirou/gopsutil/v3/process"
log "github.com/sirupsen/logrus"
)
func isProcessOwnedByCurrentUser(p *process.Process) bool {
currentUserID := os.Getuid()
uids, err := p.Uids()
if err != nil {
log.Errorf("get process uids: %v", err)
return false
}
for _, id := range uids {
log.Debugf("checking process uid: %d", id)
if int(id) == currentUserID {
return true
}
}
return false
}

View File

@@ -0,0 +1,24 @@
package main
import (
"os/user"
"github.com/shirou/gopsutil/v3/process"
log "github.com/sirupsen/logrus"
)
func isProcessOwnedByCurrentUser(p *process.Process) bool {
processUsername, err := p.Username()
if err != nil {
log.Errorf("get process username error: %v", err)
return false
}
currUser, err := user.Current()
if err != nil {
log.Errorf("get current user error: %v", err)
return false
}
return processUsername == currUser.Username
}

View File

@@ -10,7 +10,7 @@ import (
func EncryptMessage(remotePubKey wgtypes.Key, ourPrivateKey wgtypes.Key, message pb.Message) ([]byte, error) {
byteResp, err := pb.Marshal(message)
if err != nil {
log.Errorf("failed marshalling message %v", err)
log.Errorf("failed marshalling message %v, %+v", err, message.String())
return nil, err
}

View File

@@ -14,14 +14,29 @@ type TextFormatter struct {
levelDesc []string
}
// SyslogFormatter formats logs into text
type SyslogFormatter struct {
levelDesc []string
}
var validLevelDesc = []string{"PANC", "FATL", "ERRO", "WARN", "INFO", "DEBG", "TRAC"}
// NewTextFormatter create new MyTextFormatter instance
func NewTextFormatter() *TextFormatter {
return &TextFormatter{
levelDesc: []string{"PANC", "FATL", "ERRO", "WARN", "INFO", "DEBG", "TRAC"},
levelDesc: validLevelDesc,
timestampFormat: time.RFC3339, // or RFC3339
}
}
// NewSyslogFormatter create new MySyslogFormatter instance
func NewSyslogFormatter() *SyslogFormatter {
return &SyslogFormatter{
levelDesc: validLevelDesc,
}
}
// Format renders a single log entry
func (f *TextFormatter) Format(entry *logrus.Entry) ([]byte, error) {
var fields string
@@ -49,3 +64,20 @@ func (f *TextFormatter) parseLevel(level logrus.Level) string {
return f.levelDesc[level]
}
// Format renders a single log entry
func (f *SyslogFormatter) Format(entry *logrus.Entry) ([]byte, error) {
var fields string
keys := make([]string, 0, len(entry.Data))
for k, v := range entry.Data {
if k == "source" {
continue
}
keys = append(keys, fmt.Sprintf("%s: %v", k, v))
}
if len(keys) > 0 {
fields = fmt.Sprintf("[%s] ", strings.Join(keys, ", "))
}
return []byte(fmt.Sprintf("%s%s\n", fields, entry.Message)), nil
}

View File

@@ -8,7 +8,7 @@ import (
"github.com/stretchr/testify/assert"
)
func TestLogMessageFormat(t *testing.T) {
func TestLogTextFormat(t *testing.T) {
someEntry := &logrus.Entry{
Data: logrus.Fields{"att1": 1, "att2": 2, "source": "some/fancy/path.go:46"},
@@ -24,3 +24,20 @@ func TestLogMessageFormat(t *testing.T) {
expectedString := "^2021-02-21T01:10:30Z WARN \\[(att1: 1, att2: 2|att2: 2, att1: 1)\\] some/fancy/path.go:46: Some Message\\s+$"
assert.Regexp(t, expectedString, parsedString)
}
func TestLogSyslogFormat(t *testing.T) {
someEntry := &logrus.Entry{
Data: logrus.Fields{"att1": 1, "att2": 2, "source": "some/fancy/path.go:46"},
Time: time.Date(2021, time.Month(2), 21, 1, 10, 30, 0, time.UTC),
Level: 3,
Message: "Some Message",
}
formatter := NewSyslogFormatter()
result, _ := formatter.Format(someEntry)
parsedString := string(result)
expectedString := "^\\[(att1: 1, att2: 2|att2: 2, att1: 1)\\] Some Message\\s+$"
assert.Regexp(t, expectedString, parsedString)
}

View File

@@ -10,6 +10,12 @@ func SetTextFormatter(logger *logrus.Logger) {
logger.ReportCaller = true
logger.AddHook(NewContextHook())
}
// SetSyslogFormatter set the text formatter for given logger.
func SetSyslogFormatter(logger *logrus.Logger) {
logger.Formatter = NewSyslogFormatter()
logger.ReportCaller = true
logger.AddHook(NewContextHook())
}
// SetJSONFormatter set the JSON formatter for given logger.
func SetJSONFormatter(logger *logrus.Logger) {

12
go.mod
View File

@@ -19,12 +19,12 @@ require (
github.com/spf13/cobra v1.7.0
github.com/spf13/pflag v1.0.5
github.com/vishvananda/netlink v1.2.1-beta.2
golang.org/x/crypto v0.23.0
golang.org/x/sys v0.20.0
golang.org/x/crypto v0.24.0
golang.org/x/sys v0.21.0
golang.zx2c4.com/wireguard v0.0.0-20230704135630-469159ecf7d1
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6
golang.zx2c4.com/wireguard/windows v0.5.3
google.golang.org/grpc v1.64.0
google.golang.org/grpc v1.64.1
google.golang.org/protobuf v1.34.1
gopkg.in/natefinch/lumberjack.v2 v2.0.0
)
@@ -83,10 +83,10 @@ require (
goauthentik.io/api/v3 v3.2023051.3
golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842
golang.org/x/mobile v0.0.0-20190719004257-d2bd2a29d028
golang.org/x/net v0.25.0
golang.org/x/net v0.26.0
golang.org/x/oauth2 v0.19.0
golang.org/x/sync v0.7.0
golang.org/x/term v0.20.0
golang.org/x/term v0.21.0
google.golang.org/api v0.177.0
gopkg.in/yaml.v3 v3.0.1
gorm.io/driver/postgres v1.5.7
@@ -115,7 +115,7 @@ require (
github.com/dgraph-io/ristretto v0.1.1 // indirect
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
github.com/distribution/reference v0.6.0 // indirect
github.com/docker/docker v26.1.3+incompatible // indirect
github.com/docker/docker v26.1.4+incompatible // indirect
github.com/docker/go-connections v0.5.0 // indirect
github.com/docker/go-units v0.5.0 // indirect
github.com/felixge/httpsnoop v1.0.4 // indirect

24
go.sum
View File

@@ -81,8 +81,8 @@ github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/r
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc=
github.com/distribution/reference v0.6.0 h1:0IXCQ5g4/QMHHkarYzh5l+u8T3t73zM5QvfrDyIgxBk=
github.com/distribution/reference v0.6.0/go.mod h1:BbU0aIcezP1/5jX/8MP0YiH4SdvB5Y4f/wlDRiLyi3E=
github.com/docker/docker v26.1.3+incompatible h1:lLCzRbrVZrljpVNobJu1J2FHk8V0s4BawoZippkc+xo=
github.com/docker/docker v26.1.3+incompatible/go.mod h1:eEKB0N0r5NX/I1kEveEz05bcu8tLC/8azJZsviup8Sk=
github.com/docker/docker v26.1.4+incompatible h1:vuTpXDuoga+Z38m1OZHzl7NKisKWaWlhjQk7IDPSLsU=
github.com/docker/docker v26.1.4+incompatible/go.mod h1:eEKB0N0r5NX/I1kEveEz05bcu8tLC/8azJZsviup8Sk=
github.com/docker/go-connections v0.5.0 h1:USnMq7hx7gwdVZq1L49hLXaFtUdTADjXGp+uj1Br63c=
github.com/docker/go-connections v0.5.0/go.mod h1:ov60Kzw0kKElRwhNs9UlUHAE/F9Fe6GLaXnqyDdmEXc=
github.com/docker/go-units v0.5.0 h1:69rxXcBk27SvSaaxTtLh/8llcHD8vYHT7WSdRZ/jvr4=
@@ -532,8 +532,8 @@ golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5y
golang.org/x/crypto v0.8.0/go.mod h1:mRqEX+O9/h5TFCrQhkgjo2yKi0yYA+9ecGkdQoHrywE=
golang.org/x/crypto v0.12.0/go.mod h1:NF0Gs7EO5K4qLn+Ylc+fih8BSTeIjAP05siRnAh98yw=
golang.org/x/crypto v0.18.0/go.mod h1:R0j02AL6hcrfOiy9T4ZYp/rcWeMxM3L6QYxlOuEG1mg=
golang.org/x/crypto v0.23.0 h1:dIJU/v2J8Mdglj/8rJ6UUOM3Zc9zLZxVZwwxMooUSAI=
golang.org/x/crypto v0.23.0/go.mod h1:CKFgDieR+mRhux2Lsu27y0fO304Db0wZe70UKqHu0v8=
golang.org/x/crypto v0.24.0 h1:mnl8DM0o513X8fdIkmyFE/5hTYxbwYOjDS/+rK6qpRI=
golang.org/x/crypto v0.24.0/go.mod h1:Z1PMYSOR5nyMcyAVAIQSKCDwalqy85Aqn1x3Ws4L5DM=
golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842 h1:vr/HnozRka3pE4EsMEg1lgkXJkTFJCVUX+S/ZT6wYzM=
@@ -577,8 +577,8 @@ golang.org/x/net v0.9.0/go.mod h1:d48xBJpPfHeWQsugry2m+kC02ZBRGRgulfHnEXEuWns=
golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg=
golang.org/x/net v0.14.0/go.mod h1:PpSgVXXLK0OxS0F31C1/tv6XNguvCrnXIDrFMspZIUI=
golang.org/x/net v0.20.0/go.mod h1:z8BVo6PvndSri0LbOE3hAn0apkU+1YvI6E70E9jsnvY=
golang.org/x/net v0.25.0 h1:d/OCCoBEUq33pjydKrGQhw7IlUPI2Oylr+8qLx49kac=
golang.org/x/net v0.25.0/go.mod h1:JkAGAh7GEvH74S6FOH42FLoXpXbE/aqXSrIQjXgsiwM=
golang.org/x/net v0.26.0 h1:soB7SVo0PWrY4vPW/+ay0jKDNScG2X9wFeYlXIvJsOQ=
golang.org/x/net v0.26.0/go.mod h1:5YKkiSynbBIh3p6iOc/vibscux0x38BZDkn8sCUPxHE=
golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
golang.org/x/oauth2 v0.8.0/go.mod h1:yr7u4HXZRm1R1kBWqr/xKNqewf0plRYoB7sla+BCIXE=
golang.org/x/oauth2 v0.19.0 h1:9+E/EZBCbTLNrbN35fHv/a/d/mOBatymz1zbtQrXpIg=
@@ -634,8 +634,8 @@ golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.16.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/sys v0.19.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/sys v0.20.0 h1:Od9JTbYCk261bKm4M/mw7AklTlFYIa0bIp9BgSm1S8Y=
golang.org/x/sys v0.20.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/sys v0.21.0 h1:rF+pYz3DAGSQAxAu1CbC7catZg4ebC4UIeIhKxBZvws=
golang.org/x/sys v0.21.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k=
@@ -643,8 +643,8 @@ golang.org/x/term v0.7.0/go.mod h1:P32HKFT3hSsZrRxla30E9HqToFYAQPCMs/zFMBUFqPY=
golang.org/x/term v0.8.0/go.mod h1:xPskH00ivmX89bAKVGSKKtLOWNx2+17Eiy94tnKShWo=
golang.org/x/term v0.11.0/go.mod h1:zC9APTIj3jG3FdV/Ons+XE1riIZXG4aZ4GTHiPZJPIU=
golang.org/x/term v0.16.0/go.mod h1:yn7UURbUtPyrVJPGPq404EukNFxcm/foM+bV/bfcDsY=
golang.org/x/term v0.20.0 h1:VnkxpohqXaOBYJtBmEppKUG6mXpi+4O6purfc2+sMhw=
golang.org/x/term v0.20.0/go.mod h1:8UkIAJTvZgivsXaD6/pH6U9ecQzZ45awqEOzuCvwpFY=
golang.org/x/term v0.21.0 h1:WVXCp+/EBEHOj53Rvu+7KiT/iElMrO8ACK16SMZ3jaA=
golang.org/x/term v0.21.0/go.mod h1:ooXLefLobQVslOqselCNF4SxFAaoS6KujMbsGzSDmX0=
golang.org/x/text v0.0.0-20160726164857-2910a502d2bf/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
@@ -701,8 +701,8 @@ google.golang.org/grpc v1.23.0/go.mod h1:Y5yQAOtifL1yxbo5wqy6BxZv8vAUGQwXBOALyac
google.golang.org/grpc v1.25.1/go.mod h1:c3i+UQWmh7LiEpx4sFZnkU36qjEYZ0imhYfXVyQciAY=
google.golang.org/grpc v1.27.0/go.mod h1:qbnxyOmOxrQa7FizSgH+ReBfzJrCY1pSN7KXBS8abTk=
google.golang.org/grpc v1.33.2/go.mod h1:JMHMWHQWaTccqQQlmk3MJZS+GWXOdAesneDmEnv2fbc=
google.golang.org/grpc v1.64.0 h1:KH3VH9y/MgNQg1dE7b3XfVK0GsPSIzJwdF617gUSbvY=
google.golang.org/grpc v1.64.0/go.mod h1:oxjF8E3FBnjp+/gVFYdWacaLDx9na1aqy9oovLpxQYg=
google.golang.org/grpc v1.64.1 h1:LKtvyfbX3UGVPFcGqJ9ItpVWW6oN/2XqTxfAnwRRXiA=
google.golang.org/grpc v1.64.1/go.mod h1:hiQF4LFZelK2WKaP6W0L92zGHtiQdZxk8CrSdvyjeP0=
google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8=
google.golang.org/protobuf v0.0.0-20200221191635-4d8936d0db64/go.mod h1:kwYJMbMJ01Woi6D6+Kah6886xMZcty6N08ah7+eCXa0=
google.golang.org/protobuf v0.0.0-20200228230310-ab0ca4ff8a60/go.mod h1:cfTl7dwQJ+fmap5saPgwCLgHXTUD7jkjRqWcaiX5VyM=

View File

@@ -64,7 +64,7 @@ func (t *wgTunDevice) Create(routes []string, dns string, searchDomains []string
t.wrapper = newDeviceWrapper(tunDevice)
log.Debugf("attaching to interface %v", name)
t.device = device.NewDevice(t.wrapper, t.iceBind, device.NewLogger(device.LogLevelSilent, "[wiretrustee] "))
t.device = device.NewDevice(t.wrapper, t.iceBind, device.NewLogger(wgLogLevel(), "[wiretrustee] "))
// without this property mobile devices can discover remote endpoints if the configured one was wrong.
// this helps with support for the older NetBird clients that had a hardcoded direct mode
// t.device.DisableSomeRoamingForBrokenMobileSemantics()

View File

@@ -49,7 +49,7 @@ func (t *tunDevice) Create() (wgConfigurer, error) {
t.device = device.NewDevice(
t.wrapper,
t.iceBind,
device.NewLogger(device.LogLevelSilent, "[netbird] "),
device.NewLogger(wgLogLevel(), "[netbird] "),
)
err = t.assignAddr()

View File

@@ -64,7 +64,7 @@ func (t *tunDevice) Create() (wgConfigurer, error) {
t.wrapper = newDeviceWrapper(tunDevice)
log.Debug("Attaching to interface")
t.device = device.NewDevice(t.wrapper, t.iceBind, device.NewLogger(device.LogLevelSilent, "[wiretrustee] "))
t.device = device.NewDevice(t.wrapper, t.iceBind, device.NewLogger(wgLogLevel(), "[wiretrustee] "))
// without this property mobile devices can discover remote endpoints if the configured one was wrong.
// this helps with support for the older NetBird clients that had a hardcoded direct mode
// t.device.DisableSomeRoamingForBrokenMobileSemantics()

View File

@@ -54,7 +54,7 @@ func (t *tunNetstackDevice) Create() (wgConfigurer, error) {
t.device = device.NewDevice(
t.wrapper,
t.iceBind,
device.NewLogger(device.LogLevelSilent, "[netbird] "),
device.NewLogger(wgLogLevel(), "[netbird] "),
)
t.configurer = newWGUSPConfigurer(t.device, t.name)

View File

@@ -57,7 +57,7 @@ func (t *tunUSPDevice) Create() (wgConfigurer, error) {
t.device = device.NewDevice(
t.wrapper,
t.iceBind,
device.NewLogger(device.LogLevelSilent, "[netbird] "),
device.NewLogger(wgLogLevel(), "[netbird] "),
)
err = t.assignAddr()

View File

@@ -41,6 +41,7 @@ func newTunDevice(name string, address WGAddress, port int, key string, mtu int,
}
func (t *tunDevice) Create() (wgConfigurer, error) {
log.Info("create tun interface")
tunDevice, err := tun.CreateTUN(t.name, t.mtu)
if err != nil {
return nil, err
@@ -52,7 +53,7 @@ func (t *tunDevice) Create() (wgConfigurer, error) {
t.device = device.NewDevice(
t.wrapper,
t.iceBind,
device.NewLogger(device.LogLevelSilent, "[netbird] "),
device.NewLogger(wgLogLevel(), "[netbird] "),
)
luid := winipcfg.LUID(t.nativeTunDevice.LUID())

15
iface/wg_log.go Normal file
View File

@@ -0,0 +1,15 @@
package iface
import (
"os"
"golang.zx2c4.com/wireguard/device"
)
func wgLogLevel() int {
if os.Getenv("NB_WG_DEBUG") == "true" {
return device.LogLevelVerbose
} else {
return device.LogLevelSilent
}
}

View File

@@ -2,7 +2,6 @@ package client
import (
"context"
"crypto/tls"
"fmt"
"io"
"sync"
@@ -11,15 +10,11 @@ import (
"google.golang.org/grpc/codes"
gstatus "google.golang.org/grpc/status"
"github.com/cenkalti/backoff/v4"
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"google.golang.org/grpc"
"google.golang.org/grpc/connectivity"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/keepalive"
"github.com/cenkalti/backoff/v4"
"github.com/netbirdio/netbird/client/system"
"github.com/netbirdio/netbird/encryption"
@@ -51,26 +46,21 @@ type GrpcClient struct {
// NewClient creates a new client to Management service
func NewClient(ctx context.Context, addr string, ourPrivateKey wgtypes.Key, tlsEnabled bool) (*GrpcClient, error) {
transportOption := grpc.WithTransportCredentials(insecure.NewCredentials())
var conn *grpc.ClientConn
if tlsEnabled {
transportOption = grpc.WithTransportCredentials(credentials.NewTLS(&tls.Config{}))
operation := func() error {
var err error
conn, err = nbgrpc.CreateConnection(addr, tlsEnabled)
if err != nil {
log.Printf("createConnection error: %v", err)
return err
}
return nil
}
mgmCtx, cancel := context.WithTimeout(ctx, ConnectTimeout)
defer cancel()
conn, err := grpc.DialContext(
mgmCtx,
addr,
transportOption,
nbgrpc.WithCustomDialer(),
grpc.WithBlock(),
grpc.WithKeepaliveParams(keepalive.ClientParameters{
Time: 30 * time.Second,
Timeout: 10 * time.Second,
}))
err := backoff.Retry(operation, nbgrpc.Backoff(ctx))
if err != nil {
log.Errorf("failed creating connection to Management Service %v", err)
log.Errorf("failed creating connection to Management Service: %v", err)
return nil, err
}
@@ -326,25 +316,44 @@ func (c *GrpcClient) login(serverKey wgtypes.Key, req *proto.LoginRequest) (*pro
if !c.ready() {
return nil, fmt.Errorf(errMsgNoMgmtConnection)
}
loginReq, err := encryption.EncryptMessage(serverKey, c.key, req)
if err != nil {
log.Errorf("failed to encrypt message: %s", err)
return nil, err
}
mgmCtx, cancel := context.WithTimeout(c.ctx, ConnectTimeout)
defer cancel()
resp, err := c.realClient.Login(mgmCtx, &proto.EncryptedMessage{
WgPubKey: c.key.PublicKey().String(),
Body: loginReq,
})
var resp *proto.EncryptedMessage
operation := func() error {
mgmCtx, cancel := context.WithTimeout(context.Background(), ConnectTimeout)
defer cancel()
var err error
resp, err = c.realClient.Login(mgmCtx, &proto.EncryptedMessage{
WgPubKey: c.key.PublicKey().String(),
Body: loginReq,
})
if err != nil {
// retry only on context canceled
if s, ok := gstatus.FromError(err); ok && s.Code() == codes.Canceled {
return err
}
return backoff.Permanent(err)
}
return nil
}
err = backoff.Retry(operation, nbgrpc.Backoff(c.ctx))
if err != nil {
log.Errorf("failed to login to Management Service: %v", err)
return nil, err
}
loginResp := &proto.LoginResponse{}
err = encryption.DecryptMessage(serverKey, c.key, resp.Body, loginResp)
if err != nil {
log.Errorf("failed to decrypt registration message: %s", err)
log.Errorf("failed to decrypt login response: %s", err)
return nil, err
}

View File

@@ -69,6 +69,7 @@ type AccountManager interface {
ListSetupKeys(ctx context.Context, accountID, userID string) ([]*SetupKey, error)
SaveUser(ctx context.Context, accountID, initiatorUserID string, update *User) (*UserInfo, error)
SaveOrAddUser(ctx context.Context, accountID, initiatorUserID string, update *User, addIfNotExists bool) (*UserInfo, error)
SaveOrAddUsers(ctx context.Context, accountID, initiatorUserID string, updates []*User, addIfNotExists bool) ([]*UserInfo, error)
GetSetupKey(ctx context.Context, accountID, userID, keyID string) (*SetupKey, error)
GetAccountByUserOrAccountID(ctx context.Context, userID, accountID, domain string) (*Account, error)
GetAccountFromToken(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*Account, *User, error)
@@ -95,6 +96,7 @@ type AccountManager interface {
GetAllGroups(ctx context.Context, accountID, userID string) ([]*nbgroup.Group, error)
GetGroupByName(ctx context.Context, groupName, accountID string) (*nbgroup.Group, error)
SaveGroup(ctx context.Context, accountID, userID string, group *nbgroup.Group) error
SaveGroups(ctx context.Context, accountID, userID string, newGroups []*nbgroup.Group) error
DeleteGroup(ctx context.Context, accountId, userId, groupID string) error
ListGroups(ctx context.Context, accountId string) ([]*nbgroup.Group, error)
GroupAddPeer(ctx context.Context, accountId, groupID, peerID string) error
@@ -133,8 +135,8 @@ type AccountManager interface {
UpdateIntegratedValidatorGroups(ctx context.Context, accountID string, userID string, groups []string) error
GroupValidation(ctx context.Context, accountId string, groups []string) (bool, error)
GetValidatedPeers(account *Account) (map[string]struct{}, error)
SyncAndMarkPeer(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error)
CancelPeerRoutines(ctx context.Context, peer *nbpeer.Peer) error
SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error)
OnPeerDisconnected(ctx context.Context, accountID string, peerPubKey string) error
SyncPeerMeta(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta) error
FindExistingPostureCheck(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error)
GetAccountIDForPeerKey(ctx context.Context, peerKey string) (string, error)
@@ -768,10 +770,6 @@ func (a *Account) GetPeer(peerID string) *nbpeer.Peer {
// SetJWTGroups updates the user's auto groups by synchronizing JWT groups.
// Returns true if there are changes in the JWT group membership.
func (a *Account) SetJWTGroups(userID string, groupsNames []string) bool {
if len(groupsNames) == 0 {
return false
}
user, ok := a.Users[userID]
if !ok {
return false
@@ -976,7 +974,7 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco
return nil, status.Errorf(status.InvalidArgument, "peer login expiration can't be smaller than one hour")
}
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
account, err := am.Store.GetAccount(ctx, accountID)
@@ -1027,7 +1025,7 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco
func (am *DefaultAccountManager) peerLoginExpirationJob(ctx context.Context, accountID string) func() (time.Duration, bool) {
return func() (time.Duration, bool) {
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
account, err := am.Store.GetAccount(ctx, accountID)
@@ -1126,7 +1124,7 @@ func (am *DefaultAccountManager) warmupIDPCache(ctx context.Context) error {
// DeleteAccount deletes an account and all its users from local store and from the remote IDP if the requester is an admin and account owner
func (am *DefaultAccountManager) DeleteAccount(ctx context.Context, accountID, userID string) error {
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
account, err := am.Store.GetAccount(ctx, accountID)
if err != nil {
@@ -1586,7 +1584,7 @@ func (am *DefaultAccountManager) MarkPATUsed(ctx context.Context, tokenID string
return err
}
unlock := am.Store.AcquireAccountWriteLock(ctx, account.Id)
unlock := am.Store.AcquireWriteLockByUID(ctx, account.Id)
defer unlock()
account, err = am.Store.GetAccountByUser(ctx, user.Id)
@@ -1669,7 +1667,7 @@ func (am *DefaultAccountManager) GetAccountFromToken(ctx context.Context, claims
if err != nil {
return nil, nil, err
}
unlock := am.Store.AcquireAccountWriteLock(ctx, newAcc.Id)
unlock := am.Store.AcquireWriteLockByUID(ctx, newAcc.Id)
alreadyUnlocked := false
defer func() {
if !alreadyUnlocked {
@@ -1726,7 +1724,6 @@ func (am *DefaultAccountManager) GetAccountFromToken(ctx context.Context, claims
log.WithContext(ctx).Errorf("failed to save account: %v", err)
} else {
log.WithContext(ctx).Tracef("user %s: JWT group membership changed, updating account peers", claims.UserId)
// todo: optimize this as part of the group optimizations
am.updateAccountPeers(ctx, account)
unlock()
alreadyUnlocked = true
@@ -1826,7 +1823,7 @@ func (am *DefaultAccountManager) getAccountWithAuthorizationClaims(ctx context.C
account, err := am.Store.GetAccountByUser(ctx, claims.UserId)
if err == nil {
unlockAccount := am.Store.AcquireAccountWriteLock(ctx, account.Id)
unlockAccount := am.Store.AcquireWriteLockByUID(ctx, account.Id)
defer unlockAccount()
account, err = am.Store.GetAccountByUser(ctx, claims.UserId)
if err != nil {
@@ -1846,7 +1843,7 @@ func (am *DefaultAccountManager) getAccountWithAuthorizationClaims(ctx context.C
return account, nil
} else if s, ok := status.FromError(err); ok && s.Type() == status.NotFound {
if domainAccount != nil {
unlockAccount := am.Store.AcquireAccountWriteLock(ctx, domainAccount.Id)
unlockAccount := am.Store.AcquireWriteLockByUID(ctx, domainAccount.Id)
defer unlockAccount()
domainAccount, err = am.Store.GetAccountByPrivateDomain(ctx, claims.Domain)
if err != nil {
@@ -1860,17 +1857,11 @@ func (am *DefaultAccountManager) getAccountWithAuthorizationClaims(ctx context.C
}
}
func (am *DefaultAccountManager) SyncAndMarkPeer(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error) {
accountID, err := am.Store.GetAccountIDByPeerPubKey(ctx, peerPubKey)
if err != nil {
if errStatus, ok := status.FromError(err); ok && errStatus.Type() == status.NotFound {
return nil, nil, nil, status.Errorf(status.Unauthenticated, "peer not registered")
}
return nil, nil, nil, err
}
unlock := am.Store.AcquireAccountReadLock(ctx, accountID)
defer unlock()
func (am *DefaultAccountManager) SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error) {
accountUnlock := am.Store.AcquireReadLockByUID(ctx, accountID)
defer accountUnlock()
peerUnlock := am.Store.AcquireWriteLockByUID(ctx, peerPubKey)
defer peerUnlock()
account, err := am.Store.GetAccount(ctx, accountID)
if err != nil {
@@ -1890,26 +1881,20 @@ func (am *DefaultAccountManager) SyncAndMarkPeer(ctx context.Context, peerPubKey
return peer, netMap, postureChecks, nil
}
func (am *DefaultAccountManager) CancelPeerRoutines(ctx context.Context, peer *nbpeer.Peer) error {
accountID, err := am.Store.GetAccountIDByPeerPubKey(ctx, peer.Key)
if err != nil {
if errStatus, ok := status.FromError(err); ok && errStatus.Type() == status.NotFound {
return status.Errorf(status.Unauthenticated, "peer not registered")
}
return err
}
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
defer unlock()
func (am *DefaultAccountManager) OnPeerDisconnected(ctx context.Context, accountID string, peerPubKey string) error {
accountUnlock := am.Store.AcquireReadLockByUID(ctx, accountID)
defer accountUnlock()
peerUnlock := am.Store.AcquireWriteLockByUID(ctx, peerPubKey)
defer peerUnlock()
account, err := am.Store.GetAccount(ctx, accountID)
if err != nil {
return err
}
err = am.MarkPeerConnected(ctx, peer.Key, false, nil, account)
err = am.MarkPeerConnected(ctx, peerPubKey, false, nil, account)
if err != nil {
log.WithContext(ctx).Warnf("failed marking peer as connected %s %v", peer.Key, err)
log.WithContext(ctx).Warnf("failed marking peer as connected %s %v", peerPubKey, err)
}
return nil
@@ -1922,7 +1907,7 @@ func (am *DefaultAccountManager) SyncPeerMeta(ctx context.Context, peerPubKey st
return err
}
unlock := am.Store.AcquireAccountReadLock(ctx, accountID)
unlock := am.Store.AcquireReadLockByUID(ctx, accountID)
defer unlock()
account, err := am.Store.GetAccount(ctx, accountID)

View File

@@ -2219,6 +2219,13 @@ func TestAccount_SetJWTGroups(t *testing.T) {
assert.Len(t, account.Users["user2"].AutoGroups, 1, "new group should be added")
assert.Contains(t, account.Groups, account.Users["user2"].AutoGroups[0], "groups must contain group3 from user groups")
})
t.Run("remove all JWT groups", func(t *testing.T) {
updated := account.SetJWTGroups("user1", []string{})
assert.True(t, updated, "account should be updated")
assert.Len(t, account.Users["user1"].AutoGroups, 1, "only non-JWT groups should remain")
assert.Contains(t, account.Users["user1"].AutoGroups, "group1", " group1 should still be present")
})
}
func TestAccount_UserGroupsAddToPeers(t *testing.T) {

View File

@@ -56,6 +56,10 @@ type Config struct {
func (c Config) GetAuthAudiences() []string {
audiences := []string{c.HttpConfig.AuthAudience}
if c.HttpConfig.ExtraAuthAudience != "" {
audiences = append(audiences, c.HttpConfig.ExtraAuthAudience)
}
if c.DeviceAuthorizationFlow != nil && c.DeviceAuthorizationFlow.ProviderConfig.Audience != "" {
audiences = append(audiences, c.DeviceAuthorizationFlow.ProviderConfig.Audience)
}
@@ -90,6 +94,8 @@ type HttpServerConfig struct {
OIDCConfigEndpoint string
// IdpSignKeyRefreshEnabled identifies the signing key is currently being rotated or not
IdpSignKeyRefreshEnabled bool
// Extra audience
ExtraAuthAudience string
}
// Host represents a Wiretrustee host (e.g. STUN, TURN, Signal)

View File

@@ -36,7 +36,7 @@ func (d DNSSettings) Copy() DNSSettings {
// GetDNSSettings validates a user role and returns the DNS settings for the provided account ID
func (am *DefaultAccountManager) GetDNSSettings(ctx context.Context, accountID string, userID string) (*DNSSettings, error) {
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
account, err := am.Store.GetAccount(ctx, accountID)
@@ -58,7 +58,7 @@ func (am *DefaultAccountManager) GetDNSSettings(ctx context.Context, accountID s
// SaveDNSSettings validates a user role and updates the account's DNS settings
func (am *DefaultAccountManager) SaveDNSSettings(ctx context.Context, accountID string, userID string, dnsSettingsToSave *DNSSettings) error {
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
account, err := am.Store.GetAccount(ctx, accountID)
@@ -108,7 +108,6 @@ func (am *DefaultAccountManager) SaveDNSSettings(ctx context.Context, accountID
am.StoreEvent(ctx, userID, accountID, accountID, activity.GroupRemovedFromDisabledManagementGroups, meta)
}
// todo: check if before/after groups are in use by dns, acl, routes and if it has peers
am.updateAccountPeers(ctx, account)
return nil

View File

@@ -13,7 +13,7 @@ import (
// GetEvents returns a list of activity events of an account
func (am *DefaultAccountManager) GetEvents(ctx context.Context, accountID, userID string) ([]*activity.Event, error) {
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
account, err := am.Store.GetAccount(ctx, accountID)

View File

@@ -39,8 +39,8 @@ type FileStore struct {
mux sync.Mutex `json:"-"`
storeFile string `json:"-"`
// sync.Mutex indexed by accountID
accountLocks sync.Map `json:"-"`
// sync.Mutex indexed by resource ID
resourceLocks sync.Map `json:"-"`
globalAccountLock sync.Mutex `json:"-"`
metrics telemetry.AppMetrics `json:"-"`
@@ -281,26 +281,26 @@ func (s *FileStore) AcquireGlobalLock(ctx context.Context) (unlock func()) {
return unlock
}
// AcquireAccountWriteLock acquires account lock for writing to a resource and returns a function that releases the lock
func (s *FileStore) AcquireAccountWriteLock(ctx context.Context, accountID string) (unlock func()) {
log.WithContext(ctx).Debugf("acquiring lock for account %s", accountID)
// AcquireWriteLockByUID acquires an ID lock for writing to a resource and returns a function that releases the lock
func (s *FileStore) AcquireWriteLockByUID(ctx context.Context, uniqueID string) (unlock func()) {
log.WithContext(ctx).Debugf("acquiring lock for ID %s", uniqueID)
start := time.Now()
value, _ := s.accountLocks.LoadOrStore(accountID, &sync.Mutex{})
value, _ := s.resourceLocks.LoadOrStore(uniqueID, &sync.Mutex{})
mtx := value.(*sync.Mutex)
mtx.Lock()
unlock = func() {
mtx.Unlock()
log.WithContext(ctx).Debugf("released lock for account %s in %v", accountID, time.Since(start))
log.WithContext(ctx).Debugf("released lock for ID %s in %v", uniqueID, time.Since(start))
}
return unlock
}
// AcquireAccountReadLock AcquireAccountWriteLock acquires account lock for reading a resource and returns a function that releases the lock
// AcquireReadLockByUID acquires an ID lock for reading a resource and returns a function that releases the lock
// This method is still returns a write lock as file store can't handle read locks
func (s *FileStore) AcquireAccountReadLock(ctx context.Context, accountID string) (unlock func()) {
return s.AcquireAccountWriteLock(ctx, accountID)
func (s *FileStore) AcquireReadLockByUID(ctx context.Context, uniqueID string) (unlock func()) {
return s.AcquireWriteLockByUID(ctx, uniqueID)
}
func (s *FileStore) SaveAccount(ctx context.Context, account *Account) error {
@@ -666,6 +666,26 @@ func (s *FileStore) SaveInstallationID(ctx context.Context, ID string) error {
return s.persist(ctx, s.storeFile)
}
// SavePeer saves the peer in the account
func (s *FileStore) SavePeer(_ context.Context, accountID string, peer *nbpeer.Peer) error {
s.mux.Lock()
defer s.mux.Unlock()
account, err := s.getAccount(accountID)
if err != nil {
return err
}
newPeer := peer.Copy()
account.Peers[peer.ID] = newPeer
s.PeerKeyID2AccountID[peer.Key] = accountID
s.PeerID2AccountID[peer.ID] = accountID
return nil
}
// SavePeerStatus stores the PeerStatus in memory. It doesn't attempt to persist data to speed up things.
// PeerStatus will be saved eventually when some other changes occur.
func (s *FileStore) SavePeerStatus(accountID, peerID string, peerStatus nbpeer.PeerStatus) error {
@@ -746,3 +766,11 @@ func (s *FileStore) Close(ctx context.Context) error {
func (s *FileStore) GetStoreEngine() StoreEngine {
return FileStoreEngine
}
func (s *FileStore) SaveUsers(accountID string, users map[string]*User) error {
return status.Errorf(status.Internal, "SaveUsers is not implemented")
}
func (s *FileStore) SaveGroups(accountID string, groups map[string]*nbgroup.Group) error {
return status.Errorf(status.Internal, "SaveGroups is not implemented")
}

View File

@@ -9,6 +9,7 @@ import (
"path"
"strconv"
log "github.com/sirupsen/logrus"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
"gorm.io/gorm/logger"
@@ -30,6 +31,8 @@ func loadGeolocationDatabases(dataDir string) error {
continue
}
log.Infof("geo location file %s not found , file will be downloaded", file)
switch file {
case MMDBFileName:
extractFunc := func(src string, dst string) error {

View File

@@ -23,7 +23,7 @@ func (e *GroupLinkError) Error() string {
// GetGroup object of the peers
func (am *DefaultAccountManager) GetGroup(ctx context.Context, accountID, groupID, userID string) (*nbgroup.Group, error) {
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
account, err := am.Store.GetAccount(ctx, accountID)
@@ -50,7 +50,7 @@ func (am *DefaultAccountManager) GetGroup(ctx context.Context, accountID, groupI
// GetAllGroups returns all groups in an account
func (am *DefaultAccountManager) GetAllGroups(ctx context.Context, accountID string, userID string) ([]*nbgroup.Group, error) {
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
account, err := am.Store.GetAccount(ctx, accountID)
@@ -77,7 +77,7 @@ func (am *DefaultAccountManager) GetAllGroups(ctx context.Context, accountID str
// GetGroupByName filters all groups in an account by name and returns the one with the most peers
func (am *DefaultAccountManager) GetGroupByName(ctx context.Context, groupName, accountID string) (*nbgroup.Group, error) {
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
account, err := am.Store.GetAccount(ctx, accountID)
@@ -110,64 +110,87 @@ func (am *DefaultAccountManager) GetGroupByName(ctx context.Context, groupName,
// SaveGroup object of the peers
func (am *DefaultAccountManager) SaveGroup(ctx context.Context, accountID, userID string, newGroup *nbgroup.Group) error {
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
return am.SaveGroups(ctx, accountID, userID, []*nbgroup.Group{newGroup})
}
// SaveGroups adds new groups to the account.
// Note: This function does not acquire the global lock.
// It is the caller's responsibility to ensure proper locking is in place before invoking this method.
func (am *DefaultAccountManager) SaveGroups(ctx context.Context, accountID, userID string, newGroups []*nbgroup.Group) error {
account, err := am.Store.GetAccount(ctx, accountID)
if err != nil {
return err
}
if newGroup.ID == "" && newGroup.Issued != nbgroup.GroupIssuedAPI {
return status.Errorf(status.InvalidArgument, "%s group without ID set", newGroup.Issued)
}
var eventsToStore []func()
if newGroup.ID == "" && newGroup.Issued == nbgroup.GroupIssuedAPI {
for _, newGroup := range newGroups {
if newGroup.ID == "" && newGroup.Issued != nbgroup.GroupIssuedAPI {
return status.Errorf(status.InvalidArgument, "%s group without ID set", newGroup.Issued)
}
existingGroup, err := account.FindGroupByName(newGroup.Name)
if err != nil {
s, ok := status.FromError(err)
if !ok || s.ErrorType != status.NotFound {
return err
if newGroup.ID == "" && newGroup.Issued == nbgroup.GroupIssuedAPI {
existingGroup, err := account.FindGroupByName(newGroup.Name)
if err != nil {
s, ok := status.FromError(err)
if !ok || s.ErrorType != status.NotFound {
return err
}
}
// Avoid duplicate groups only for the API issued groups.
// Integration or JWT groups can be duplicated as they are coming from the IdP that we don't have control of.
if existingGroup != nil {
return status.Errorf(status.AlreadyExists, "group with name %s already exists", newGroup.Name)
}
newGroup.ID = xid.New().String()
}
for _, peerID := range newGroup.Peers {
if account.Peers[peerID] == nil {
return status.Errorf(status.InvalidArgument, "peer with ID \"%s\" not found", peerID)
}
}
// avoid duplicate groups only for the API issued groups. Integration or JWT groups can be duplicated as they are
// coming from the IdP that we don't have control of.
if existingGroup != nil {
return status.Errorf(status.AlreadyExists, "group with name %s already exists", newGroup.Name)
}
oldGroup := account.Groups[newGroup.ID]
account.Groups[newGroup.ID] = newGroup
newGroup.ID = xid.New().String()
events := am.prepareGroupEvents(ctx, userID, accountID, newGroup, oldGroup, account)
eventsToStore = append(eventsToStore, events...)
}
for _, peerID := range newGroup.Peers {
if account.Peers[peerID] == nil {
return status.Errorf(status.InvalidArgument, "peer with ID \"%s\" not found", peerID)
}
}
oldGroup, exists := account.Groups[newGroup.ID]
account.Groups[newGroup.ID] = newGroup
account.Network.IncSerial()
if err = am.Store.SaveAccount(ctx, account); err != nil {
return err
}
// todo: check if groups is in use by dns, acl, routes and before/after peers
am.updateAccountPeers(ctx, account)
// the following snippet tracks the activity and stores the group events in the event store.
// It has to happen after all the operations have been successfully performed.
for _, storeEvent := range eventsToStore {
storeEvent()
}
return nil
}
// prepareGroupEvents prepares a list of event functions to be stored.
func (am *DefaultAccountManager) prepareGroupEvents(ctx context.Context, userID string, accountID string, newGroup, oldGroup *nbgroup.Group, account *Account) []func() {
var eventsToStore []func()
addedPeers := make([]string, 0)
removedPeers := make([]string, 0)
if exists {
if oldGroup != nil {
addedPeers = difference(newGroup.Peers, oldGroup.Peers)
removedPeers = difference(oldGroup.Peers, newGroup.Peers)
} else {
addedPeers = append(addedPeers, newGroup.Peers...)
am.StoreEvent(ctx, userID, newGroup.ID, accountID, activity.GroupCreated, newGroup.EventMeta())
eventsToStore = append(eventsToStore, func() {
am.StoreEvent(ctx, userID, newGroup.ID, accountID, activity.GroupCreated, newGroup.EventMeta())
})
}
for _, p := range addedPeers {
@@ -176,11 +199,14 @@ func (am *DefaultAccountManager) SaveGroup(ctx context.Context, accountID, userI
log.WithContext(ctx).Errorf("peer %s not found under account %s while saving group", p, accountID)
continue
}
am.StoreEvent(ctx, userID, peer.ID, accountID, activity.GroupAddedToPeer,
map[string]any{
"group": newGroup.Name, "group_id": newGroup.ID, "peer_ip": peer.IP.String(),
"peer_fqdn": peer.FQDN(am.GetDNSDomain()),
})
peerCopy := peer // copy to avoid closure issues
eventsToStore = append(eventsToStore, func() {
am.StoreEvent(ctx, userID, peerCopy.ID, accountID, activity.GroupAddedToPeer,
map[string]any{
"group": newGroup.Name, "group_id": newGroup.ID, "peer_ip": peerCopy.IP.String(),
"peer_fqdn": peerCopy.FQDN(am.GetDNSDomain()),
})
})
}
for _, p := range removedPeers {
@@ -189,14 +215,17 @@ func (am *DefaultAccountManager) SaveGroup(ctx context.Context, accountID, userI
log.WithContext(ctx).Errorf("peer %s not found under account %s while saving group", p, accountID)
continue
}
am.StoreEvent(ctx, userID, peer.ID, accountID, activity.GroupRemovedFromPeer,
map[string]any{
"group": newGroup.Name, "group_id": newGroup.ID, "peer_ip": peer.IP.String(),
"peer_fqdn": peer.FQDN(am.GetDNSDomain()),
})
peerCopy := peer // copy to avoid closure issues
eventsToStore = append(eventsToStore, func() {
am.StoreEvent(ctx, userID, peerCopy.ID, accountID, activity.GroupRemovedFromPeer,
map[string]any{
"group": newGroup.Name, "group_id": newGroup.ID, "peer_ip": peerCopy.IP.String(),
"peer_fqdn": peerCopy.FQDN(am.GetDNSDomain()),
})
})
}
return nil
return eventsToStore
}
// difference returns the elements in `a` that aren't in `b`.
@@ -216,7 +245,7 @@ func difference(a, b []string) []string {
// DeleteGroup object of the peers
func (am *DefaultAccountManager) DeleteGroup(ctx context.Context, accountId, userId, groupID string) error {
unlock := am.Store.AcquireAccountWriteLock(ctx, accountId)
unlock := am.Store.AcquireWriteLockByUID(ctx, accountId)
defer unlock()
account, err := am.Store.GetAccount(ctx, accountId)
@@ -323,7 +352,6 @@ func (am *DefaultAccountManager) DeleteGroup(ctx context.Context, accountId, use
am.StoreEvent(ctx, userId, groupID, accountId, activity.GroupDeleted, g.EventMeta())
// todo: check if groups is in use by dns, acl, routes and if it has peers
am.updateAccountPeers(ctx, account)
return nil
@@ -331,7 +359,7 @@ func (am *DefaultAccountManager) DeleteGroup(ctx context.Context, accountId, use
// ListGroups objects of the peers
func (am *DefaultAccountManager) ListGroups(ctx context.Context, accountID string) ([]*nbgroup.Group, error) {
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
account, err := am.Store.GetAccount(ctx, accountID)
@@ -349,7 +377,7 @@ func (am *DefaultAccountManager) ListGroups(ctx context.Context, accountID strin
// GroupAddPeer appends peer to the group
func (am *DefaultAccountManager) GroupAddPeer(ctx context.Context, accountID, groupID, peerID string) error {
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
account, err := am.Store.GetAccount(ctx, accountID)
@@ -378,7 +406,6 @@ func (am *DefaultAccountManager) GroupAddPeer(ctx context.Context, accountID, gr
return err
}
// todo: check if groups is in use by dns, acl, routes
am.updateAccountPeers(ctx, account)
return nil
@@ -386,7 +413,7 @@ func (am *DefaultAccountManager) GroupAddPeer(ctx context.Context, accountID, gr
// GroupDeletePeer removes peer from the group
func (am *DefaultAccountManager) GroupDeletePeer(ctx context.Context, accountID, groupID, peerID string) error {
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
account, err := am.Store.GetAccount(ctx, accountID)
@@ -409,7 +436,6 @@ func (am *DefaultAccountManager) GroupDeletePeer(ctx context.Context, accountID,
}
}
// todo: check if groups is in use by dns, acl, routes
am.updateAccountPeers(ctx, account)
return nil

View File

@@ -156,7 +156,7 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi
log.WithContext(ctx).Tracef("peer system meta has to be provided on sync. Peer %s, remote addr %s", peerKey.String(), realIP)
}
peer, netMap, postureChecks, err := s.accountManager.SyncAndMarkPeer(ctx, peerKey.String(), extractPeerMeta(ctx, syncReq.GetMeta()), realIP)
peer, netMap, postureChecks, err := s.accountManager.SyncAndMarkPeer(ctx, accountID, peerKey.String(), extractPeerMeta(ctx, syncReq.GetMeta()), realIP)
if err != nil {
return mapError(ctx, err)
}
@@ -179,11 +179,11 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi
s.appMetrics.GRPCMetrics().CountSyncRequestDuration(time.Since(reqStart))
}
return s.handleUpdates(ctx, peerKey, peer, updates, srv)
return s.handleUpdates(ctx, accountID, peerKey, peer, updates, srv)
}
// handleUpdates sends updates to the connected peer until the updates channel is closed.
func (s *GRPCServer) handleUpdates(ctx context.Context, peerKey wgtypes.Key, peer *nbpeer.Peer, updates chan *UpdateMessage, srv proto.ManagementService_SyncServer) error {
func (s *GRPCServer) handleUpdates(ctx context.Context, accountID string, peerKey wgtypes.Key, peer *nbpeer.Peer, updates chan *UpdateMessage, srv proto.ManagementService_SyncServer) error {
for {
select {
// condition when there are some updates
@@ -194,12 +194,12 @@ func (s *GRPCServer) handleUpdates(ctx context.Context, peerKey wgtypes.Key, pee
if !open {
log.WithContext(ctx).Debugf("updates channel for peer %s was closed", peerKey.String())
s.cancelPeerRoutines(ctx, peer)
s.cancelPeerRoutines(ctx, accountID, peer)
return nil
}
log.WithContext(ctx).Debugf("received an update for peer %s", peerKey.String())
if err := s.sendUpdate(ctx, peerKey, peer, update, srv); err != nil {
if err := s.sendUpdate(ctx, accountID, peerKey, peer, update, srv); err != nil {
return err
}
@@ -207,7 +207,7 @@ func (s *GRPCServer) handleUpdates(ctx context.Context, peerKey wgtypes.Key, pee
case <-srv.Context().Done():
// happens when connection drops, e.g. client disconnects
log.WithContext(ctx).Debugf("stream of peer %s has been closed", peerKey.String())
s.cancelPeerRoutines(ctx, peer)
s.cancelPeerRoutines(ctx, accountID, peer)
return srv.Context().Err()
}
}
@@ -215,10 +215,10 @@ func (s *GRPCServer) handleUpdates(ctx context.Context, peerKey wgtypes.Key, pee
// sendUpdate encrypts the update message using the peer key and the server's wireguard key,
// then sends the encrypted message to the connected peer via the sync server.
func (s *GRPCServer) sendUpdate(ctx context.Context, peerKey wgtypes.Key, peer *nbpeer.Peer, update *UpdateMessage, srv proto.ManagementService_SyncServer) error {
func (s *GRPCServer) sendUpdate(ctx context.Context, accountID string, peerKey wgtypes.Key, peer *nbpeer.Peer, update *UpdateMessage, srv proto.ManagementService_SyncServer) error {
encryptedResp, err := encryption.EncryptMessage(peerKey, s.wgKey, update.Update)
if err != nil {
s.cancelPeerRoutines(ctx, peer)
s.cancelPeerRoutines(ctx, accountID, peer)
return status.Errorf(codes.Internal, "failed processing update message")
}
err = srv.SendMsg(&proto.EncryptedMessage{
@@ -226,17 +226,17 @@ func (s *GRPCServer) sendUpdate(ctx context.Context, peerKey wgtypes.Key, peer *
Body: encryptedResp,
})
if err != nil {
s.cancelPeerRoutines(ctx, peer)
s.cancelPeerRoutines(ctx, accountID, peer)
return status.Errorf(codes.Internal, "failed sending update message")
}
log.WithContext(ctx).Debugf("sent an update to peer %s", peerKey.String())
return nil
}
func (s *GRPCServer) cancelPeerRoutines(ctx context.Context, peer *nbpeer.Peer) {
func (s *GRPCServer) cancelPeerRoutines(ctx context.Context, accountID string, peer *nbpeer.Peer) {
s.peersUpdateManager.CloseChannel(ctx, peer.ID)
s.turnCredentialsManager.CancelRefresh(peer.ID)
_ = s.accountManager.CancelPeerRoutines(ctx, peer)
_ = s.accountManager.OnPeerDisconnected(ctx, accountID, peer.Key)
s.ephemeralManager.OnPeerDisconnected(ctx, peer)
}

View File

@@ -526,6 +526,43 @@ components:
- revoked
- auto_groups
- usage_limit
CreateSetupKeyRequest:
type: object
properties:
name:
description: Setup Key name
type: string
example: Default key
type:
description: Setup key type, one-off for single time usage and reusable
type: string
example: reusable
expires_in:
description: Expiration time in seconds
type: integer
minimum: 86400
maximum: 31536000
example: 86400
auto_groups:
description: List of group IDs to auto-assign to peers registered with this key
type: array
items:
type: string
example: "ch8i4ug6lnn4g9hqv7m0"
usage_limit:
description: A number of times this key can be used. The value of 0 indicates the unlimited usage.
type: integer
example: 0
ephemeral:
description: Indicate that the peer will be ephemeral or not
type: boolean
example: true
required:
- name
- type
- expires_in
- auto_groups
- usage_limit
PersonalAccessToken:
type: object
properties:
@@ -1806,7 +1843,7 @@ paths:
content:
'application/json':
schema:
$ref: '#/components/schemas/SetupKeyRequest'
$ref: '#/components/schemas/CreateSetupKeyRequest'
responses:
'200':
description: A Setup Keys Object

View File

@@ -254,6 +254,27 @@ type Country struct {
// CountryCode 2-letter ISO 3166-1 alpha-2 code that represents the country
type CountryCode = string
// CreateSetupKeyRequest defines model for CreateSetupKeyRequest.
type CreateSetupKeyRequest struct {
// AutoGroups List of group IDs to auto-assign to peers registered with this key
AutoGroups []string `json:"auto_groups"`
// Ephemeral Indicate that the peer will be ephemeral or not
Ephemeral *bool `json:"ephemeral,omitempty"`
// ExpiresIn Expiration time in seconds
ExpiresIn int `json:"expires_in"`
// Name Setup Key name
Name string `json:"name"`
// Type Setup key type, one-off for single time usage and reusable
Type string `json:"type"`
// UsageLimit A number of times this key can be used. The value of 0 indicates the unlimited usage.
UsageLimit int `json:"usage_limit"`
}
// DNSSettings defines model for DNSSettings.
type DNSSettings struct {
// DisabledManagementGroups Groups whose DNS management is disabled
@@ -1241,7 +1262,7 @@ type PostApiRoutesJSONRequestBody = RouteRequest
type PutApiRoutesRouteIdJSONRequestBody = RouteRequest
// PostApiSetupKeysJSONRequestBody defines body for PostApiSetupKeys for application/json ContentType.
type PostApiSetupKeysJSONRequestBody = SetupKeyRequest
type PostApiSetupKeysJSONRequestBody = CreateSetupKeyRequest
// PutApiSetupKeysKeyIdJSONRequestBody defines body for PutApiSetupKeysKeyId for application/json ContentType.
type PutApiSetupKeysKeyIdJSONRequestBody = SetupKeyRequest

View File

@@ -32,7 +32,7 @@ func (am *DefaultAccountManager) UpdateIntegratedValidatorGroups(ctx context.Con
return errors.New("invalid groups")
}
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
a, err := am.Store.GetAccountByUser(ctx, userID)

View File

@@ -2,6 +2,7 @@ package server
import (
"context"
"fmt"
"net"
"os"
"path/filepath"
@@ -16,6 +17,7 @@ import (
"google.golang.org/grpc/keepalive"
"github.com/netbirdio/netbird/encryption"
"github.com/netbirdio/netbird/formatter"
mgmtProto "github.com/netbirdio/netbird/management/proto"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/util"
@@ -83,7 +85,7 @@ func Test_SyncProtocol(t *testing.T) {
defer func() {
os.Remove(filepath.Join(dir, "store.json")) //nolint
}()
mgmtServer, mgmtAddr, err := startManagement(t, &Config{
mgmtServer, _, mgmtAddr, err := startManagement(t, &Config{
Stuns: []*Host{{
Proto: "udp",
URI: "stun:stun.wiretrustee.com:3468",
@@ -399,32 +401,35 @@ func TestServer_GetDeviceAuthorizationFlow(t *testing.T) {
}
}
func startManagement(t *testing.T, config *Config) (*grpc.Server, string, error) {
func startManagement(t *testing.T, config *Config) (*grpc.Server, *DefaultAccountManager, string, error) {
t.Helper()
lis, err := net.Listen("tcp", "localhost:0")
if err != nil {
return nil, "", err
return nil, nil, "", err
}
s := grpc.NewServer(grpc.KeepaliveEnforcementPolicy(kaep), grpc.KeepaliveParams(kasp))
store, cleanUp, err := NewTestStoreFromJson(context.Background(), config.Datadir)
if err != nil {
return nil, "", err
return nil, nil, "", err
}
t.Cleanup(cleanUp)
peersUpdateManager := NewPeersUpdateManager(nil)
eventStore := &activity.InMemoryEventStore{}
accountManager, err := BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted",
ctx := context.WithValue(context.Background(), formatter.ExecutionContextKey, formatter.SystemSource) //nolint:staticcheck
accountManager, err := BuildManager(ctx, store, peersUpdateManager, nil, "", "netbird.selfhosted",
eventStore, nil, false, MocIntegratedValidator{})
if err != nil {
return nil, "", err
return nil, nil, "", err
}
turnManager := NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig)
ephemeralMgr := NewEphemeralManager(store, accountManager)
mgmtServer, err := NewServer(context.Background(), config, accountManager, peersUpdateManager, turnManager, nil, ephemeralMgr)
if err != nil {
return nil, "", err
return nil, nil, "", err
}
mgmtProto.RegisterManagementServiceServer(s, mgmtServer)
@@ -434,7 +439,7 @@ func startManagement(t *testing.T, config *Config) (*grpc.Server, string, error)
}
}()
return s, lis.Addr().String(), nil
return s, accountManager, lis.Addr().String(), nil
}
func createRawClient(addr string) (mgmtProto.ManagementServiceClient, *grpc.ClientConn, error) {
@@ -454,3 +459,165 @@ func createRawClient(addr string) (mgmtProto.ManagementServiceClient, *grpc.Clie
return mgmtProto.NewManagementServiceClient(conn), conn, nil
}
func Test_SyncStatusRace(t *testing.T) {
if os.Getenv("CI") == "true" && os.Getenv("NETBIRD_STORE_ENGINE") == "postgres" {
t.Skip("Skipping on CI and Postgres store")
}
for i := 0; i < 500; i++ {
t.Run(fmt.Sprintf("TestRun-%d", i), func(t *testing.T) {
testSyncStatusRace(t)
})
}
}
func testSyncStatusRace(t *testing.T) {
t.Helper()
dir := t.TempDir()
err := util.CopyFileContents("testdata/store_with_expired_peers.json", filepath.Join(dir, "store.json"))
if err != nil {
t.Fatal(err)
}
defer func() {
os.Remove(filepath.Join(dir, "store.json")) //nolint
}()
mgmtServer, am, mgmtAddr, err := startManagement(t, &Config{
Stuns: []*Host{{
Proto: "udp",
URI: "stun:stun.wiretrustee.com:3468",
}},
TURNConfig: &TURNConfig{
TimeBasedCredentials: false,
CredentialsTTL: util.Duration{},
Secret: "whatever",
Turns: []*Host{{
Proto: "udp",
URI: "turn:stun.wiretrustee.com:3468",
}},
},
Signal: &Host{
Proto: "http",
URI: "signal.wiretrustee.com:10000",
},
Datadir: dir,
HttpConfig: nil,
})
if err != nil {
t.Fatal(err)
return
}
defer mgmtServer.GracefulStop()
client, clientConn, err := createRawClient(mgmtAddr)
if err != nil {
t.Fatal(err)
return
}
defer clientConn.Close()
// there are two peers already in the store, add two more
peers, err := registerPeers(2, client)
if err != nil {
t.Fatal(err)
return
}
serverKey, err := getServerKey(client)
if err != nil {
t.Fatal(err)
return
}
concurrentPeerKey2 := peers[1]
t.Log("Public key of concurrent peer: ", concurrentPeerKey2.PublicKey().String())
syncReq2 := &mgmtProto.SyncRequest{Meta: &mgmtProto.PeerSystemMeta{}}
message2, err := encryption.EncryptMessage(*serverKey, *concurrentPeerKey2, syncReq2)
if err != nil {
t.Fatal(err)
return
}
ctx2, cancelFunc2 := context.WithCancel(context.Background())
//client.
sync2, err := client.Sync(ctx2, &mgmtProto.EncryptedMessage{
WgPubKey: concurrentPeerKey2.PublicKey().String(),
Body: message2,
})
if err != nil {
t.Fatal(err)
return
}
resp2 := &mgmtProto.EncryptedMessage{}
err = sync2.RecvMsg(resp2)
if err != nil {
t.Fatal(err)
return
}
peerWithInvalidStatus := peers[0]
t.Log("Public key of peer with invalid status: ", peerWithInvalidStatus.PublicKey().String())
syncReq := &mgmtProto.SyncRequest{Meta: &mgmtProto.PeerSystemMeta{}}
message, err := encryption.EncryptMessage(*serverKey, *peerWithInvalidStatus, syncReq)
if err != nil {
t.Fatal(err)
return
}
ctx, cancelFunc := context.WithCancel(context.Background())
//client.
sync, err := client.Sync(ctx, &mgmtProto.EncryptedMessage{
WgPubKey: peerWithInvalidStatus.PublicKey().String(),
Body: message,
})
if err != nil {
t.Fatal(err)
return
}
// take the first registered peer as a base for the test. Total four.
resp := &mgmtProto.EncryptedMessage{}
err = sync.RecvMsg(resp)
if err != nil {
t.Fatal(err)
return
}
cancelFunc2()
time.Sleep(1 * time.Millisecond)
cancelFunc()
time.Sleep(10 * time.Millisecond)
ctx, cancelFunc = context.WithCancel(context.Background())
defer cancelFunc()
sync, err = client.Sync(ctx, &mgmtProto.EncryptedMessage{
WgPubKey: peerWithInvalidStatus.PublicKey().String(),
Body: message,
})
if err != nil {
t.Fatal(err)
return
}
resp = &mgmtProto.EncryptedMessage{}
err = sync.RecvMsg(resp)
if err != nil {
t.Fatal(err)
return
}
time.Sleep(10 * time.Millisecond)
peer, err := am.Store.GetPeerByPeerPubKey(context.Background(), peerWithInvalidStatus.PublicKey().String())
if err != nil {
t.Fatal(err)
return
}
if !peer.Status.Connected {
t.Fatal("Peer should be connected")
}
}

View File

@@ -31,7 +31,7 @@ type MockAccountManager struct {
ListUsersFunc func(ctx context.Context, accountID string) ([]*server.User, error)
GetPeersFunc func(ctx context.Context, accountID, userID string) ([]*nbpeer.Peer, error)
MarkPeerConnectedFunc func(ctx context.Context, peerKey string, connected bool, realIP net.IP) error
SyncAndMarkPeerFunc func(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *server.NetworkMap, []*posture.Checks, error)
SyncAndMarkPeerFunc func(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *server.NetworkMap, []*posture.Checks, error)
DeletePeerFunc func(ctx context.Context, accountID, peerKey, userID string) error
GetNetworkMapFunc func(ctx context.Context, peerKey string) (*server.NetworkMap, error)
GetPeerNetworkFunc func(ctx context.Context, peerKey string) (*server.Network, error)
@@ -40,6 +40,7 @@ type MockAccountManager struct {
GetAllGroupsFunc func(ctx context.Context, accountID, userID string) ([]*group.Group, error)
GetGroupByNameFunc func(ctx context.Context, accountID, groupName string) (*group.Group, error)
SaveGroupFunc func(ctx context.Context, accountID, userID string, group *group.Group) error
SaveGroupsFunc func(ctx context.Context, accountID, userID string, groups []*group.Group) error
DeleteGroupFunc func(ctx context.Context, accountID, userId, groupID string) error
ListGroupsFunc func(ctx context.Context, accountID string) ([]*group.Group, error)
GroupAddPeerFunc func(ctx context.Context, accountID, groupID, peerID string) error
@@ -64,6 +65,7 @@ type MockAccountManager struct {
ListSetupKeysFunc func(ctx context.Context, accountID, userID string) ([]*server.SetupKey, error)
SaveUserFunc func(ctx context.Context, accountID, userID string, user *server.User) (*server.UserInfo, error)
SaveOrAddUserFunc func(ctx context.Context, accountID, userID string, user *server.User, addIfNotExists bool) (*server.UserInfo, error)
SaveOrAddUsersFunc func(ctx context.Context, accountID, initiatorUserID string, update []*server.User, addIfNotExists bool) ([]*server.UserInfo, error)
DeleteUserFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserID string) error
CreatePATFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserId string, tokenName string, expiresIn int) (*server.PersonalAccessTokenGenerated, error)
DeletePATFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserId string, tokenID string) error
@@ -103,14 +105,14 @@ type MockAccountManager struct {
GetAccountIDForPeerKeyFunc func(ctx context.Context, peerKey string) (string, error)
}
func (am *MockAccountManager) SyncAndMarkPeer(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *server.NetworkMap, []*posture.Checks, error) {
func (am *MockAccountManager) SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *server.NetworkMap, []*posture.Checks, error) {
if am.SyncAndMarkPeerFunc != nil {
return am.SyncAndMarkPeerFunc(ctx, peerPubKey, meta, realIP)
return am.SyncAndMarkPeerFunc(ctx, accountID, peerPubKey, meta, realIP)
}
return nil, nil, nil, status.Errorf(codes.Unimplemented, "method MarkPeerConnected is not implemented")
}
func (am *MockAccountManager) CancelPeerRoutines(_ context.Context, peer *nbpeer.Peer) error {
func (am *MockAccountManager) OnPeerDisconnected(_ context.Context, accountID string, peerPubKey string) error {
// TODO implement me
panic("implement me")
}
@@ -308,6 +310,14 @@ func (am *MockAccountManager) SaveGroup(ctx context.Context, accountID, userID s
return status.Errorf(codes.Unimplemented, "method SaveGroup is not implemented")
}
// SaveGroups mock implementation of SaveGroups from server.AccountManager interface
func (am *MockAccountManager) SaveGroups(ctx context.Context, accountID, userID string, groups []*group.Group) error {
if am.SaveGroupsFunc != nil {
return am.SaveGroupsFunc(ctx, accountID, userID, groups)
}
return status.Errorf(codes.Unimplemented, "method SaveGroups is not implemented")
}
// DeleteGroup mock implementation of DeleteGroup from server.AccountManager interface
func (am *MockAccountManager) DeleteGroup(ctx context.Context, accountId, userId, groupID string) error {
if am.DeleteGroupFunc != nil {
@@ -502,6 +512,14 @@ func (am *MockAccountManager) SaveOrAddUser(ctx context.Context, accountID, user
return nil, status.Errorf(codes.Unimplemented, "method SaveOrAddUser is not implemented")
}
// SaveOrAddUsers mocks SaveOrAddUsers of the AccountManager interface
func (am *MockAccountManager) SaveOrAddUsers(ctx context.Context, accountID, userID string, users []*server.User, addIfNotExists bool) ([]*server.UserInfo, error) {
if am.SaveOrAddUsersFunc != nil {
return am.SaveOrAddUsersFunc(ctx, accountID, userID, users, addIfNotExists)
}
return nil, status.Errorf(codes.Unimplemented, "method SaveOrAddUsers is not implemented")
}
// DeleteUser mocks DeleteUser of the AccountManager interface
func (am *MockAccountManager) DeleteUser(ctx context.Context, accountID string, initiatorUserID string, targetUserID string) error {
if am.DeleteUserFunc != nil {

View File

@@ -20,7 +20,7 @@ const domainPattern = `^(?i)[a-z0-9]+([\-\.]{1}[a-z0-9]+)*\.[a-z]{2,}$`
// GetNameServerGroup gets a nameserver group object from account and nameserver group IDs
func (am *DefaultAccountManager) GetNameServerGroup(ctx context.Context, accountID, userID, nsGroupID string) (*nbdns.NameServerGroup, error) {
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
account, err := am.Store.GetAccount(ctx, accountID)
@@ -48,7 +48,7 @@ func (am *DefaultAccountManager) GetNameServerGroup(ctx context.Context, account
// CreateNameServerGroup creates and saves a new nameserver group
func (am *DefaultAccountManager) CreateNameServerGroup(ctx context.Context, accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, primary bool, domains []string, enabled bool, userID string, searchDomainEnabled bool) (*nbdns.NameServerGroup, error) {
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
account, err := am.Store.GetAccount(ctx, accountID)
@@ -85,7 +85,6 @@ func (am *DefaultAccountManager) CreateNameServerGroup(ctx context.Context, acco
return nil, err
}
// todo: check distribution groups if they have any peers
am.updateAccountPeers(ctx, account)
am.StoreEvent(ctx, userID, newNSGroup.ID, accountID, activity.NameserverGroupCreated, newNSGroup.EventMeta())
@@ -96,7 +95,7 @@ func (am *DefaultAccountManager) CreateNameServerGroup(ctx context.Context, acco
// SaveNameServerGroup saves nameserver group
func (am *DefaultAccountManager) SaveNameServerGroup(ctx context.Context, accountID, userID string, nsGroupToSave *nbdns.NameServerGroup) error {
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
if nsGroupToSave == nil {
@@ -121,7 +120,6 @@ func (am *DefaultAccountManager) SaveNameServerGroup(ctx context.Context, accoun
return err
}
// todo: check before and after distribution groups if they have any peers
am.updateAccountPeers(ctx, account)
am.StoreEvent(ctx, userID, nsGroupToSave.ID, accountID, activity.NameserverGroupUpdated, nsGroupToSave.EventMeta())
@@ -132,7 +130,7 @@ func (am *DefaultAccountManager) SaveNameServerGroup(ctx context.Context, accoun
// DeleteNameServerGroup deletes nameserver group with nsGroupID
func (am *DefaultAccountManager) DeleteNameServerGroup(ctx context.Context, accountID, nsGroupID, userID string) error {
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
account, err := am.Store.GetAccount(ctx, accountID)
@@ -152,7 +150,6 @@ func (am *DefaultAccountManager) DeleteNameServerGroup(ctx context.Context, acco
return err
}
// todo: check distribution groups if they have any peers
am.updateAccountPeers(ctx, account)
am.StoreEvent(ctx, userID, nsGroup.ID, accountID, activity.NameserverGroupDeleted, nsGroup.EventMeta())
@@ -163,7 +160,7 @@ func (am *DefaultAccountManager) DeleteNameServerGroup(ctx context.Context, acco
// ListNameServerGroups returns a list of nameserver groups from account
func (am *DefaultAccountManager) ListNameServerGroups(ctx context.Context, accountID string, userID string) ([]*nbdns.NameServerGroup, error) {
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
account, err := am.Store.GetAccount(ctx, accountID)

View File

@@ -150,7 +150,7 @@ func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubK
// UpdatePeer updates peer. Only Peer.Name, Peer.SSHEnabled, and Peer.LoginExpirationEnabled can be updated.
func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, userID string, update *nbpeer.Peer) (*nbpeer.Peer, error) {
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
account, err := am.Store.GetAccount(ctx, accountID)
@@ -218,7 +218,6 @@ func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, user
return nil, err
}
// todo: don't call it if peer is not expired and Peer.LoginExpirationEnabled was set to false
am.updateAccountPeers(ctx, account)
return peer, nil
@@ -273,7 +272,7 @@ func (am *DefaultAccountManager) deletePeers(ctx context.Context, account *Accou
// DeletePeer removes peer from the account by its IP
func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peerID, userID string) error {
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
account, err := am.Store.GetAccount(ctx, accountID)
@@ -291,7 +290,6 @@ func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peer
return err
}
// todo: evaluate if peer was part of a group that has is used in a active dns, route, acl
am.updateAccountPeers(ctx, account)
return nil
@@ -358,7 +356,7 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
return nil, nil, nil, status.Errorf(status.NotFound, "failed adding new peer: account not found")
}
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer func() {
if unlock != nil {
unlock()
@@ -382,7 +380,7 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
}
// This is a handling for the case when the same machine (with the same WireGuard pub key) tries to register twice.
// Such case is possible when AddPeer function takes long time to finish after AcquireAccountWriteLock (e.g., database is slow)
// Such case is possible when AddPeer function takes long time to finish after AcquireWriteLockByUID (e.g., database is slow)
// and the peer disconnects with a timeout and tries to register again.
// We just check if this machine has been registered before and reject the second registration.
// The connecting peer should be able to recover with a retry.
@@ -455,6 +453,17 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
Location: peer.Location,
}
if am.geo != nil && newPeer.Location.ConnectionIP != nil {
location, err := am.geo.Lookup(newPeer.Location.ConnectionIP)
if err != nil {
log.WithContext(ctx).Warnf("failed to get location for new peer realip: [%s]: %v", newPeer.Location.ConnectionIP.String(), err)
} else {
newPeer.Location.CountryCode = location.Country.ISOCode
newPeer.Location.CityName = location.City.Names.En
newPeer.Location.GeoNameID = location.City.GeonameID
}
}
// add peer to 'All' group
group, err := account.GetGroupAll()
if err != nil {
@@ -512,7 +521,6 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
am.StoreEvent(ctx, opEvent.InitiatorID, opEvent.TargetID, opEvent.AccountID, opEvent.Activity, opEvent.Meta)
// todo: evaluate if peer is part of a group that has is used in a active dns, route, acl
am.updateAccountPeers(ctx, account)
approvedPeersMap, err := am.GetValidatedPeers(account)
@@ -538,17 +546,16 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync PeerSync, ac
}
if peerLoginExpired(ctx, peer, account.Settings) {
return nil, nil, nil, status.Errorf(status.PermissionDenied, "peer login has expired, please log in once more")
return nil, nil, nil, status.NewPeerLoginExpiredError()
}
peer, updated := updatePeerMeta(peer, sync.Meta, account)
if updated {
err = am.Store.SaveAccount(ctx, account)
err = am.Store.SavePeer(ctx, account.Id, peer)
if err != nil {
return nil, nil, nil, err
}
// todo: review this logic
if sync.UpdateAccountPeers {
am.updateAccountPeers(ctx, account)
}
@@ -568,7 +575,6 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync PeerSync, ac
return peer, emptyMap, postureChecks, nil
}
// todo: review this logic and combine with the previous
if isStatusChanged {
am.updateAccountPeers(ctx, account)
}
@@ -591,21 +597,10 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login PeerLogin)
// we couldn't find this peer by its public key which can mean that peer hasn't been registered yet.
// Try registering it.
newPeer := &nbpeer.Peer{
Key: login.WireGuardPubKey,
Meta: login.Meta,
SSHKey: login.SSHKey,
}
if am.geo != nil && login.ConnectionIP != nil {
location, err := am.geo.Lookup(login.ConnectionIP)
if err != nil {
log.WithContext(ctx).Warnf("failed to get location for new peer realip: [%s]: %v", login.ConnectionIP.String(), err)
} else {
newPeer.Location.ConnectionIP = login.ConnectionIP
newPeer.Location.CountryCode = location.Country.ISOCode
newPeer.Location.CityName = location.City.Names.En
newPeer.Location.GeoNameID = location.City.GeonameID
}
Key: login.WireGuardPubKey,
Meta: login.Meta,
SSHKey: login.SSHKey,
Location: nbpeer.Location{ConnectionIP: login.ConnectionIP},
}
return am.AddPeer(ctx, login.SetupKey, login.UserID, newPeer)
@@ -615,44 +610,17 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login PeerLogin)
return nil, nil, nil, status.Errorf(status.Internal, "failed while logging in peer")
}
peer, err := am.Store.GetPeerByPeerPubKey(ctx, login.WireGuardPubKey)
if err != nil {
return nil, nil, nil, status.NewPeerNotRegisteredError()
}
accSettings, err := am.Store.GetAccountSettings(ctx, accountID)
if err != nil {
return nil, nil, nil, status.Errorf(status.Internal, "failed to get account settings: %s", err)
}
var isWriteLock bool
// duplicated logic from after the lock to have an early exit
expired := peerLoginExpired(ctx, peer, accSettings)
switch {
case expired:
if err := checkAuth(ctx, login.UserID, peer); err != nil {
// when the client sends a login request with a JWT which is used to get the user ID,
// it means that the client has already checked if it needs login and had been through the SSO flow
// so, we can skip this check and directly proceed with the login
if login.UserID == "" {
err = am.checkIFPeerNeedsLoginWithoutLock(ctx, accountID, login)
if err != nil {
return nil, nil, nil, err
}
isWriteLock = true
log.WithContext(ctx).Debugf("peer login expired, acquiring write lock")
case peer.UpdateMetaIfNew(login.Meta):
isWriteLock = true
log.WithContext(ctx).Debugf("peer changed meta, acquiring write lock")
default:
isWriteLock = false
log.WithContext(ctx).Debugf("peer meta is the same, acquiring read lock")
}
var unlock func()
if isWriteLock {
unlock = am.Store.AcquireAccountWriteLock(ctx, accountID)
} else {
unlock = am.Store.AcquireAccountReadLock(ctx, accountID)
}
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer func() {
if unlock != nil {
unlock()
@@ -665,7 +633,7 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login PeerLogin)
return nil, nil, nil, err
}
peer, err = account.FindPeerByPubKey(login.WireGuardPubKey)
peer, err := account.FindPeerByPubKey(login.WireGuardPubKey)
if err != nil {
return nil, nil, nil, status.NewPeerNotRegisteredError()
}
@@ -676,53 +644,39 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login PeerLogin)
}
// this flag prevents unnecessary calls to the persistent store.
shouldStoreAccount := false
shouldStorePeer := false
updateRemotePeers := false
if peerLoginExpired(ctx, peer, account.Settings) {
err = checkAuth(ctx, login.UserID, peer)
err = am.handleExpiredPeer(ctx, login, account, peer)
if err != nil {
return nil, nil, nil, err
}
// If peer was expired before and if it reached this point, it is re-authenticated.
// UserID is present, meaning that JWT validation passed successfully in the API layer.
updatePeerLastLogin(peer, account)
updateRemotePeers = true
shouldStoreAccount = true
// sync user last login with peer last login
user, err := account.FindUser(login.UserID)
if err != nil {
return nil, nil, nil, status.Errorf(status.Internal, "couldn't find user")
}
user.updateLastLogin(peer.LastLogin)
am.StoreEvent(ctx, login.UserID, peer.ID, account.Id, activity.UserLoggedInPeer, peer.EventMeta(am.GetDNSDomain()))
shouldStorePeer = true
}
isRequiresApproval, isStatusChanged, err := am.integratedPeerValidator.IsNotValidPeer(ctx, account.Id, peer, account.GetPeerGroupsList(peer.ID), account.Settings.Extra)
if err != nil {
return nil, nil, nil, err
}
peer, updated := updatePeerMeta(peer, login.Meta, account)
if updated {
shouldStoreAccount = true
shouldStorePeer = true
}
peer, err = am.checkAndUpdatePeerSSHKey(ctx, peer, account, login.SSHKey)
if err != nil {
return nil, nil, nil, err
if peer.SSHKey != login.SSHKey {
peer.SSHKey = login.SSHKey
shouldStorePeer = true
}
if shouldStoreAccount {
if !isWriteLock {
log.WithContext(ctx).Errorf("account %s should be stored but is not write locked", accountID)
return nil, nil, nil, status.Errorf(status.Internal, "account should be stored but is not write locked")
}
err = am.Store.SaveAccount(ctx, account)
if shouldStorePeer {
err = am.Store.SavePeer(ctx, accountID, peer)
if err != nil {
return nil, nil, nil, err
}
}
unlock()
unlock = nil
@@ -730,13 +684,46 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login PeerLogin)
am.updateAccountPeers(ctx, account)
}
return am.getValidatedPeerWithMap(ctx, isRequiresApproval, account, peer)
}
// checkIFPeerNeedsLoginWithoutLock checks if the peer needs login without acquiring the account lock. The check validate if the peer was not added via SSO
// and if the peer login is expired.
// The NetBird client doesn't have a way to check if the peer needs login besides sending a login request
// with no JWT token and usually no setup-key. As the client can send up to two login request to check if it is expired
// and before starting the engine, we do the checks without an account lock to avoid piling up requests.
func (am *DefaultAccountManager) checkIFPeerNeedsLoginWithoutLock(ctx context.Context, accountID string, login PeerLogin) error {
peer, err := am.Store.GetPeerByPeerPubKey(ctx, login.WireGuardPubKey)
if err != nil {
return err
}
// if the peer was not added with SSO login we can exit early because peers activated with setup-key
// doesn't expire, and we avoid extra databases calls.
if !peer.AddedWithSSOLogin() {
return nil
}
settings, err := am.Store.GetAccountSettings(ctx, accountID)
if err != nil {
return err
}
if peerLoginExpired(ctx, peer, settings) {
return status.NewPeerLoginExpiredError()
}
return nil
}
func (am *DefaultAccountManager) getValidatedPeerWithMap(ctx context.Context, isRequiresApproval bool, account *Account, peer *nbpeer.Peer) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error) {
var postureChecks []*posture.Checks
if isRequiresApproval {
emptyMap := &NetworkMap{
Network: account.Network.Copy(),
}
return peer, emptyMap, postureChecks, nil
return peer, emptyMap, nil, nil
}
approvedPeersMap, err := am.GetValidatedPeers(account)
@@ -748,6 +735,30 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login PeerLogin)
return peer, account.GetPeerNetworkMap(ctx, peer.ID, am.dnsDomain, approvedPeersMap), postureChecks, nil
}
func (am *DefaultAccountManager) handleExpiredPeer(ctx context.Context, login PeerLogin, account *Account, peer *nbpeer.Peer) error {
err := checkAuth(ctx, login.UserID, peer)
if err != nil {
return err
}
// If peer was expired before and if it reached this point, it is re-authenticated.
// UserID is present, meaning that JWT validation passed successfully in the API layer.
updatePeerLastLogin(peer, account)
// sync user last login with peer last login
user, err := account.FindUser(login.UserID)
if err != nil {
return status.Errorf(status.Internal, "couldn't find user")
}
err = am.Store.SaveUserLastLogin(account.Id, user.Id, peer.LastLogin)
if err != nil {
return err
}
am.StoreEvent(ctx, login.UserID, peer.ID, account.Id, activity.UserLoggedInPeer, peer.EventMeta(am.GetDNSDomain()))
return nil
}
func checkIfPeerOwnerIsBlocked(peer *nbpeer.Peer, account *Account) error {
if peer.AddedWithSSOLogin() {
user, err := account.FindUser(peer.UserID)
@@ -764,11 +775,11 @@ func checkIfPeerOwnerIsBlocked(peer *nbpeer.Peer, account *Account) error {
func checkAuth(ctx context.Context, loginUserID string, peer *nbpeer.Peer) error {
if loginUserID == "" {
// absence of a user ID indicates that JWT wasn't provided.
return status.Errorf(status.PermissionDenied, "peer login has expired, please log in once more")
return status.NewPeerLoginExpiredError()
}
if peer.UserID != loginUserID {
log.WithContext(ctx).Warnf("user mismatch when logging in peer %s: peer user %s, login user %s ", peer.ID, peer.UserID, loginUserID)
return status.Errorf(status.Unauthenticated, "can't login")
return status.Errorf(status.Unauthenticated, "can't login with this credentials")
}
return nil
}
@@ -788,33 +799,6 @@ func updatePeerLastLogin(peer *nbpeer.Peer, account *Account) {
account.UpdatePeer(peer)
}
func (am *DefaultAccountManager) checkAndUpdatePeerSSHKey(ctx context.Context, peer *nbpeer.Peer, account *Account, newSSHKey string) (*nbpeer.Peer, error) {
if len(newSSHKey) == 0 {
log.WithContext(ctx).Debugf("no new SSH key provided for peer %s, skipping update", peer.ID)
return peer, nil
}
if peer.SSHKey == newSSHKey {
log.WithContext(ctx).Debugf("same SSH key provided for peer %s, skipping update", peer.ID)
return peer, nil
}
peer.SSHKey = newSSHKey
account.UpdatePeer(peer)
err := am.Store.SaveAccount(ctx, account)
if err != nil {
return nil, err
}
// trigger network map update
// todo: remove this since it is called by the caller function
am.updateAccountPeers(ctx, account)
return peer, nil
}
// todo: not in use, remove it
// UpdatePeerSSHKey updates peer's public SSH key
func (am *DefaultAccountManager) UpdatePeerSSHKey(ctx context.Context, peerID string, sshKey string) error {
if sshKey == "" {
@@ -827,7 +811,7 @@ func (am *DefaultAccountManager) UpdatePeerSSHKey(ctx context.Context, peerID st
return err
}
unlock := am.Store.AcquireAccountWriteLock(ctx, account.Id)
unlock := am.Store.AcquireWriteLockByUID(ctx, account.Id)
defer unlock()
// ensure that we consider modification happened meanwhile (because we were outside the account lock when we fetched the account)
@@ -862,7 +846,7 @@ func (am *DefaultAccountManager) UpdatePeerSSHKey(ctx context.Context, peerID st
// GetPeer for a given accountID, peerID and userID error if not found.
func (am *DefaultAccountManager) GetPeer(ctx context.Context, accountID, peerID, userID string) (*nbpeer.Peer, error) {
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
account, err := am.Store.GetAccount(ctx, accountID)

View File

@@ -315,7 +315,7 @@ func (a *Account) connResourcesGenerator(ctx context.Context) (func(*PolicyRule,
// GetPolicy from the store
func (am *DefaultAccountManager) GetPolicy(ctx context.Context, accountID, policyID, userID string) (*Policy, error) {
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
account, err := am.Store.GetAccount(ctx, accountID)
@@ -343,7 +343,7 @@ func (am *DefaultAccountManager) GetPolicy(ctx context.Context, accountID, polic
// SavePolicy in the store
func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, userID string, policy *Policy) error {
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
account, err := am.Store.GetAccount(ctx, accountID)
@@ -364,7 +364,6 @@ func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, user
}
am.StoreEvent(ctx, userID, policy.ID, accountID, action, policy.EventMeta())
// todo: call if before and after source and destination groups are not empty
am.updateAccountPeers(ctx, account)
return nil
@@ -372,7 +371,7 @@ func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, user
// DeletePolicy from the store
func (am *DefaultAccountManager) DeletePolicy(ctx context.Context, accountID, policyID, userID string) error {
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
account, err := am.Store.GetAccount(ctx, accountID)
@@ -392,7 +391,6 @@ func (am *DefaultAccountManager) DeletePolicy(ctx context.Context, accountID, po
am.StoreEvent(ctx, userID, policy.ID, accountID, activity.PolicyRemoved, policy.EventMeta())
// todo: call if source and destination groups are not empty
am.updateAccountPeers(ctx, account)
return nil
@@ -400,7 +398,7 @@ func (am *DefaultAccountManager) DeletePolicy(ctx context.Context, accountID, po
// ListPolicies from the store
func (am *DefaultAccountManager) ListPolicies(ctx context.Context, accountID, userID string) ([]*Policy, error) {
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
account, err := am.Store.GetAccount(ctx, accountID)

View File

@@ -15,7 +15,7 @@ const (
)
func (am *DefaultAccountManager) GetPostureChecks(ctx context.Context, accountID, postureChecksID, userID string) (*posture.Checks, error) {
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
account, err := am.Store.GetAccount(ctx, accountID)
@@ -42,7 +42,7 @@ func (am *DefaultAccountManager) GetPostureChecks(ctx context.Context, accountID
}
func (am *DefaultAccountManager) SavePostureChecks(ctx context.Context, accountID, userID string, postureChecks *posture.Checks) error {
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
account, err := am.Store.GetAccount(ctx, accountID)
@@ -82,7 +82,6 @@ func (am *DefaultAccountManager) SavePostureChecks(ctx context.Context, accountI
am.StoreEvent(ctx, userID, postureChecks.ID, accountID, action, postureChecks.EventMeta())
if exists {
// todo: check if posture check is linked to a policy
am.updateAccountPeers(ctx, account)
}
@@ -90,7 +89,7 @@ func (am *DefaultAccountManager) SavePostureChecks(ctx context.Context, accountI
}
func (am *DefaultAccountManager) DeletePostureChecks(ctx context.Context, accountID, postureChecksID, userID string) error {
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
account, err := am.Store.GetAccount(ctx, accountID)
@@ -122,7 +121,7 @@ func (am *DefaultAccountManager) DeletePostureChecks(ctx context.Context, accoun
}
func (am *DefaultAccountManager) ListPostureChecks(ctx context.Context, accountID, userID string) ([]*posture.Checks, error) {
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
account, err := am.Store.GetAccount(ctx, accountID)

View File

@@ -17,7 +17,7 @@ import (
// GetRoute gets a route object from account and route IDs
func (am *DefaultAccountManager) GetRoute(ctx context.Context, accountID string, routeID route.ID, userID string) (*route.Route, error) {
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
account, err := am.Store.GetAccount(ctx, accountID)
@@ -126,7 +126,7 @@ func getRouteDescriptor(prefix netip.Prefix, domains domain.List) string {
// CreateRoute creates and saves a new route
func (am *DefaultAccountManager) CreateRoute(ctx context.Context, accountID string, prefix netip.Prefix, networkType route.NetworkType, domains domain.List, peerID string, peerGroupIDs []string, description string, netID route.NetID, masquerade bool, metric int, groups []string, enabled bool, userID string, keepRoute bool) (*route.Route, error) {
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
account, err := am.Store.GetAccount(ctx, accountID)
@@ -204,10 +204,7 @@ func (am *DefaultAccountManager) CreateRoute(ctx context.Context, accountID stri
if err = am.Store.SaveAccount(ctx, account); err != nil {
return nil, err
}
// todo: call if one of the three is true:
// 1. distribution groups are not empty
// 2. routing groups are not empy
// 3. there is a routing peer
am.updateAccountPeers(ctx, account)
am.StoreEvent(ctx, userID, string(newRoute.ID), accountID, activity.RouteCreated, newRoute.EventMeta())
@@ -217,7 +214,7 @@ func (am *DefaultAccountManager) CreateRoute(ctx context.Context, accountID stri
// SaveRoute saves route
func (am *DefaultAccountManager) SaveRoute(ctx context.Context, accountID, userID string, routeToSave *route.Route) error {
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
if routeToSave == nil {
@@ -276,10 +273,7 @@ func (am *DefaultAccountManager) SaveRoute(ctx context.Context, accountID, userI
if err = am.Store.SaveAccount(ctx, account); err != nil {
return err
}
// todo: call if one of the three is true:
// 1. before and after distribution groups are not empty
// 2. before and after routing groups are not empy
// 3. there is a routing peer
am.updateAccountPeers(ctx, account)
am.StoreEvent(ctx, userID, string(routeToSave.ID), accountID, activity.RouteUpdated, routeToSave.EventMeta())
@@ -289,7 +283,7 @@ func (am *DefaultAccountManager) SaveRoute(ctx context.Context, accountID, userI
// DeleteRoute deletes route with routeID
func (am *DefaultAccountManager) DeleteRoute(ctx context.Context, accountID string, routeID route.ID, userID string) error {
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
account, err := am.Store.GetAccount(ctx, accountID)
@@ -309,10 +303,7 @@ func (am *DefaultAccountManager) DeleteRoute(ctx context.Context, accountID stri
}
am.StoreEvent(ctx, userID, string(routy.ID), accountID, activity.RouteRemoved, routy.EventMeta())
// todo: call if one of the three is true:
// 1. distribution groups are not empty
// 2. routing groups are not empy
// 3. there is a routing peer
am.updateAccountPeers(ctx, account)
return nil
@@ -320,7 +311,7 @@ func (am *DefaultAccountManager) DeleteRoute(ctx context.Context, accountID stri
// ListRoutes returns a list of routes from account
func (am *DefaultAccountManager) ListRoutes(ctx context.Context, accountID, userID string) ([]*route.Route, error) {
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
account, err := am.Store.GetAccount(ctx, accountID)

View File

@@ -210,7 +210,7 @@ func Hash(s string) uint32 {
// and adds it to the specified account. A list of autoGroups IDs can be empty.
func (am *DefaultAccountManager) CreateSetupKey(ctx context.Context, accountID string, keyName string, keyType SetupKeyType,
expiresIn time.Duration, autoGroups []string, usageLimit int, userID string, ephemeral bool) (*SetupKey, error) {
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
keyDuration := DefaultSetupKeyDuration
@@ -256,7 +256,7 @@ func (am *DefaultAccountManager) CreateSetupKey(ctx context.Context, accountID s
// (e.g. the key itself, creation date, ID, etc).
// These properties are overwritten: Name, AutoGroups, Revoked. The rest is copied from the existing key.
func (am *DefaultAccountManager) SaveSetupKey(ctx context.Context, accountID string, keyToSave *SetupKey, userID string) (*SetupKey, error) {
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
if keyToSave == nil {
@@ -320,7 +320,7 @@ func (am *DefaultAccountManager) SaveSetupKey(ctx context.Context, accountID str
}
}
}()
// todo: remove it, not needed here since we don't update anything else
am.updateAccountPeers(ctx, account)
return newKey, nil
@@ -328,7 +328,7 @@ func (am *DefaultAccountManager) SaveSetupKey(ctx context.Context, accountID str
// ListSetupKeys returns a list of all setup keys of the account
func (am *DefaultAccountManager) ListSetupKeys(ctx context.Context, accountID, userID string) ([]*SetupKey, error) {
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
account, err := am.Store.GetAccount(ctx, accountID)
if err != nil {
@@ -360,7 +360,7 @@ func (am *DefaultAccountManager) ListSetupKeys(ctx context.Context, accountID, u
// GetSetupKey looks up a SetupKey by KeyID, returns NotFound error if not found.
func (am *DefaultAccountManager) GetSetupKey(ctx context.Context, accountID, userID, keyID string) (*SetupKey, error) {
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
account, err := am.Store.GetAccount(ctx, accountID)

View File

@@ -31,14 +31,16 @@ import (
)
const (
storeSqliteFileName = "store.db"
idQueryCondition = "id = ?"
storeSqliteFileName = "store.db"
idQueryCondition = "id = ?"
accountAndIDQueryCondition = "account_id = ? and id = ?"
peerNotFoundFMT = "peer %s not found"
)
// SqlStore represents an account storage backed by a Sql DB persisted to disk
type SqlStore struct {
db *gorm.DB
accountLocks sync.Map
resourceLocks sync.Map
globalAccountLock sync.Mutex
metrics telemetry.AppMetrics
installationPK int
@@ -96,33 +98,35 @@ func (s *SqlStore) AcquireGlobalLock(ctx context.Context) (unlock func()) {
return unlock
}
func (s *SqlStore) AcquireAccountWriteLock(ctx context.Context, accountID string) (unlock func()) {
log.WithContext(ctx).Tracef("acquiring write lock for account %s", accountID)
// AcquireWriteLockByUID acquires an ID lock for writing to a resource and returns a function that releases the lock
func (s *SqlStore) AcquireWriteLockByUID(ctx context.Context, uniqueID string) (unlock func()) {
log.WithContext(ctx).Tracef("acquiring write lock for ID %s", uniqueID)
start := time.Now()
value, _ := s.accountLocks.LoadOrStore(accountID, &sync.RWMutex{})
value, _ := s.resourceLocks.LoadOrStore(uniqueID, &sync.RWMutex{})
mtx := value.(*sync.RWMutex)
mtx.Lock()
unlock = func() {
mtx.Unlock()
log.WithContext(ctx).Tracef("released write lock for account %s in %v", accountID, time.Since(start))
log.WithContext(ctx).Tracef("released write lock for ID %s in %v", uniqueID, time.Since(start))
}
return unlock
}
func (s *SqlStore) AcquireAccountReadLock(ctx context.Context, accountID string) (unlock func()) {
log.WithContext(ctx).Tracef("acquiring read lock for account %s", accountID)
// AcquireReadLockByUID acquires an ID lock for writing to a resource and returns a function that releases the lock
func (s *SqlStore) AcquireReadLockByUID(ctx context.Context, uniqueID string) (unlock func()) {
log.WithContext(ctx).Tracef("acquiring read lock for ID %s", uniqueID)
start := time.Now()
value, _ := s.accountLocks.LoadOrStore(accountID, &sync.RWMutex{})
value, _ := s.resourceLocks.LoadOrStore(uniqueID, &sync.RWMutex{})
mtx := value.(*sync.RWMutex)
mtx.RLock()
unlock = func() {
mtx.RUnlock()
log.WithContext(ctx).Tracef("released read lock for account %s in %v", accountID, time.Since(start))
log.WithContext(ctx).Tracef("released read lock for ID %s in %v", uniqueID, time.Since(start))
}
return unlock
@@ -271,19 +275,56 @@ func (s *SqlStore) GetInstallationID() string {
return installation.InstallationIDValue
}
func (s *SqlStore) SavePeer(ctx context.Context, accountID string, peer *nbpeer.Peer) error {
// To maintain data integrity, we create a copy of the peer's to prevent unintended updates to other fields.
peerCopy := peer.Copy()
peerCopy.AccountID = accountID
err := s.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
// check if peer exists before saving
var peerID string
result := tx.Model(&nbpeer.Peer{}).Select("id").Find(&peerID, accountAndIDQueryCondition, accountID, peer.ID)
if result.Error != nil {
return result.Error
}
if peerID == "" {
return status.Errorf(status.NotFound, peerNotFoundFMT, peer.ID)
}
result = tx.Model(&nbpeer.Peer{}).Where(accountAndIDQueryCondition, accountID, peer.ID).Save(peerCopy)
if result.Error != nil {
return result.Error
}
return nil
})
if err != nil {
return err
}
return nil
}
func (s *SqlStore) SavePeerStatus(accountID, peerID string, peerStatus nbpeer.PeerStatus) error {
var peerCopy nbpeer.Peer
peerCopy.Status = &peerStatus
result := s.db.Model(&nbpeer.Peer{}).
Where("account_id = ? AND id = ?", accountID, peerID).
Updates(peerCopy)
fieldsToUpdate := []string{
"peer_status_last_seen", "peer_status_connected",
"peer_status_login_expired", "peer_status_required_approval",
}
result := s.db.Model(&nbpeer.Peer{}).
Select(fieldsToUpdate).
Where(accountAndIDQueryCondition, accountID, peerID).
Updates(&peerCopy)
if result.Error != nil {
return result.Error
}
if result.RowsAffected == 0 {
return status.Errorf(status.NotFound, "peer %s not found", peerID)
return status.Errorf(status.NotFound, peerNotFoundFMT, peerID)
}
return nil
@@ -297,7 +338,7 @@ func (s *SqlStore) SavePeerLocation(accountID string, peerWithLocation *nbpeer.P
peerCopy.Location = peerWithLocation.Location
result := s.db.Model(&nbpeer.Peer{}).
Where("account_id = ? and id = ?", accountID, peerWithLocation.ID).
Where(accountAndIDQueryCondition, accountID, peerWithLocation.ID).
Updates(peerCopy)
if result.Error != nil {
@@ -305,12 +346,40 @@ func (s *SqlStore) SavePeerLocation(accountID string, peerWithLocation *nbpeer.P
}
if result.RowsAffected == 0 {
return status.Errorf(status.NotFound, "peer %s not found", peerWithLocation.ID)
return status.Errorf(status.NotFound, peerNotFoundFMT, peerWithLocation.ID)
}
return nil
}
// SaveUsers saves the given list of users to the database.
// It updates existing users if a conflict occurs.
func (s *SqlStore) SaveUsers(accountID string, users map[string]*User) error {
usersToSave := make([]User, 0, len(users))
for _, user := range users {
user.AccountID = accountID
for id, pat := range user.PATs {
pat.ID = id
user.PATsG = append(user.PATsG, *pat)
}
usersToSave = append(usersToSave, *user)
}
return s.db.Session(&gorm.Session{FullSaveAssociations: true}).
Clauses(clause.OnConflict{UpdateAll: true}).
Create(&usersToSave).Error
}
// SaveGroups saves the given list of groups to the database.
// It updates existing groups if a conflict occurs.
func (s *SqlStore) SaveGroups(accountID string, groups map[string]*nbgroup.Group) error {
groupsToSave := make([]nbgroup.Group, 0, len(groups))
for _, group := range groups {
group.AccountID = accountID
groupsToSave = append(groupsToSave, *group)
}
return s.db.Clauses(clause.OnConflict{UpdateAll: true}).Create(&groupsToSave).Error
}
// DeleteHashedPAT2TokenIDIndex is noop in SqlStore
func (s *SqlStore) DeleteHashedPAT2TokenIDIndex(hashedToken string) error {
return nil
@@ -611,7 +680,7 @@ func (s *SqlStore) GetAccountSettings(ctx context.Context, accountID string) (*S
func (s *SqlStore) SaveUserLastLogin(accountID, userID string, lastLogin time.Time) error {
var user User
result := s.db.First(&user, "account_id = ? and id = ?", accountID, userID)
result := s.db.First(&user, accountAndIDQueryCondition, accountID, userID)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return status.Errorf(status.NotFound, "user %s not found", userID)
@@ -662,11 +731,7 @@ func NewSqliteStore(ctx context.Context, dataDir string, metrics telemetry.AppMe
}
file := filepath.Join(dataDir, storeStr)
db, err := gorm.Open(sqlite.Open(file), &gorm.Config{
Logger: logger.Default.LogMode(logger.Silent),
CreateBatchSize: 400,
PrepareStmt: true,
})
db, err := gorm.Open(sqlite.Open(file), getGormConfig())
if err != nil {
return nil, err
}
@@ -676,10 +741,7 @@ func NewSqliteStore(ctx context.Context, dataDir string, metrics telemetry.AppMe
// NewPostgresqlStore creates a new Postgres store.
func NewPostgresqlStore(ctx context.Context, dsn string, metrics telemetry.AppMetrics) (*SqlStore, error) {
db, err := gorm.Open(postgres.Open(dsn), &gorm.Config{
Logger: logger.Default.LogMode(logger.Silent),
PrepareStmt: true,
})
db, err := gorm.Open(postgres.Open(dsn), getGormConfig())
if err != nil {
return nil, err
}
@@ -687,6 +749,14 @@ func NewPostgresqlStore(ctx context.Context, dsn string, metrics telemetry.AppMe
return NewSqlStore(ctx, db, PostgresStoreEngine, metrics)
}
func getGormConfig() *gorm.Config {
return &gorm.Config{
Logger: logger.Default.LogMode(logger.Silent),
CreateBatchSize: 400,
PrepareStmt: true,
}
}
// newPostgresStore initializes a new Postgres store.
func newPostgresStore(ctx context.Context, metrics telemetry.AppMetrics) (Store, error) {
dsn, ok := os.LookupEnv(postgresDsnEnv)

View File

@@ -41,11 +41,22 @@ func TestSqlite_NewStore(t *testing.T) {
}
func TestSqlite_SaveAccount_Large(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("The SQLite store is not properly supported by Windows yet")
if runtime.GOOS != "linux" && os.Getenv("CI") == "true" || runtime.GOOS == "windows" {
t.Skip("skip large test on non-linux OS due to environment restrictions")
}
t.Run("SQLite", func(t *testing.T) {
store := newSqliteStore(t)
runLargeTest(t, store)
})
// create store outside to have a better time counter for the test
store := newPostgresqlStore(t)
t.Run("PostgreSQL", func(t *testing.T) {
runLargeTest(t, store)
})
}
store := newSqliteStore(t)
func runLargeTest(t *testing.T, store Store) {
t.Helper()
account := newAccountWithId(context.Background(), "account_id", "testuser", "")
groupALL, err := account.GetGroupAll()
@@ -54,7 +65,7 @@ func TestSqlite_SaveAccount_Large(t *testing.T) {
}
setupKey := GenerateDefaultSetupKey()
account.SetupKeys[setupKey.Key] = setupKey
const numPerAccount = 2000
const numPerAccount = 6000
for n := 0; n < numPerAccount; n++ {
netIP := randomIPv4()
peerID := fmt.Sprintf("%s-peer-%d", account.Id, n)
@@ -351,6 +362,54 @@ func TestSqlite_GetAccount(t *testing.T) {
require.Equal(t, status.NotFound, parsedErr.Type(), "should return not found error")
}
func TestSqlite_SavePeer(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("The SQLite store is not properly supported by Windows yet")
}
store := newSqliteStoreFromFile(t, "testdata/store.json")
account, err := store.GetAccount(context.Background(), "bf1c8084-ba50-4ce7-9439-34653001fc3b")
require.NoError(t, err)
// save status of non-existing peer
peer := &nbpeer.Peer{
Key: "peerkey",
ID: "testpeer",
SetupKey: "peerkeysetupkey",
IP: net.IP{127, 0, 0, 1},
Meta: nbpeer.PeerSystemMeta{Hostname: "testingpeer"},
Name: "peer name",
Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()},
}
ctx := context.Background()
err = store.SavePeer(ctx, account.Id, peer)
assert.Error(t, err)
parsedErr, ok := status.FromError(err)
require.True(t, ok)
require.Equal(t, status.NotFound, parsedErr.Type(), "should return not found error")
// save new status of existing peer
account.Peers[peer.ID] = peer
err = store.SaveAccount(context.Background(), account)
require.NoError(t, err)
updatedPeer := peer.Copy()
updatedPeer.Status.Connected = false
updatedPeer.Meta.Hostname = "updatedpeer"
err = store.SavePeer(ctx, account.Id, updatedPeer)
require.NoError(t, err)
account, err = store.GetAccount(context.Background(), account.Id)
require.NoError(t, err)
actual := account.Peers[peer.ID]
assert.Equal(t, updatedPeer.Status, actual.Status)
assert.Equal(t, updatedPeer.Meta, actual.Meta)
}
func TestSqlite_SavePeerStatus(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("The SQLite store is not properly supported by Windows yet")
@@ -362,7 +421,7 @@ func TestSqlite_SavePeerStatus(t *testing.T) {
require.NoError(t, err)
// save status of non-existing peer
newStatus := nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()}
newStatus := nbpeer.PeerStatus{Connected: false, LastSeen: time.Now().UTC()}
err = store.SavePeerStatus(account.Id, "non-existing-peer", newStatus)
assert.Error(t, err)
parsedErr, ok := status.FromError(err)
@@ -377,7 +436,7 @@ func TestSqlite_SavePeerStatus(t *testing.T) {
IP: net.IP{127, 0, 0, 1},
Meta: nbpeer.PeerSystemMeta{},
Name: "peer name",
Status: &nbpeer.PeerStatus{Connected: false, LastSeen: time.Now().UTC()},
Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()},
}
err = store.SaveAccount(context.Background(), account)
@@ -391,7 +450,19 @@ func TestSqlite_SavePeerStatus(t *testing.T) {
actual := account.Peers["testpeer"].Status
assert.Equal(t, newStatus, *actual)
newStatus.Connected = true
err = store.SavePeerStatus(account.Id, "testpeer", newStatus)
require.NoError(t, err)
account, err = store.GetAccount(context.Background(), account.Id)
require.NoError(t, err)
actual = account.Peers["testpeer"].Status
assert.Equal(t, newStatus, *actual)
}
func TestSqlite_SavePeerLocation(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("The SQLite store is not properly supported by Windows yet")

View File

@@ -95,3 +95,8 @@ func NewUserNotFoundError(userKey string) error {
func NewPeerNotRegisteredError() error {
return Errorf(Unauthenticated, "peer is not registered")
}
// NewPeerLoginExpiredError creates a new Error with PermissionDenied type for an expired peer
func NewPeerLoginExpiredError() error {
return Errorf(PermissionDenied, "peer login has expired, please log in once more")
}

View File

@@ -15,6 +15,8 @@ import (
log "github.com/sirupsen/logrus"
"gorm.io/gorm"
nbgroup "github.com/netbirdio/netbird/management/server/group"
"github.com/netbirdio/netbird/management/server/telemetry"
"github.com/netbirdio/netbird/util"
@@ -41,16 +43,19 @@ type Store interface {
GetUserByTokenID(ctx context.Context, tokenID string) (*User, error)
GetPostureCheckByChecksDefinition(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error)
SaveAccount(ctx context.Context, account *Account) error
SaveUsers(accountID string, users map[string]*User) error
SaveGroups(accountID string, groups map[string]*nbgroup.Group) error
DeleteHashedPAT2TokenIDIndex(hashedToken string) error
DeleteTokenID2UserIDIndex(tokenID string) error
GetInstallationID() string
SaveInstallationID(ctx context.Context, ID string) error
// AcquireAccountWriteLock should attempt to acquire account lock for write purposes and return a function that releases the lock
AcquireAccountWriteLock(ctx context.Context, accountID string) func()
// AcquireAccountReadLock should attempt to acquire account lock for read purposes and return a function that releases the lock
AcquireAccountReadLock(ctx context.Context, accountID string) func()
// AcquireWriteLockByUID should attempt to acquire a lock for write purposes and return a function that releases the lock
AcquireWriteLockByUID(ctx context.Context, uniqueID string) func()
// AcquireReadLockByUID should attempt to acquire lock for read purposes and return a function that releases the lock
AcquireReadLockByUID(ctx context.Context, uniqueID string) func()
// AcquireGlobalLock should attempt to acquire a global lock and return a function that releases the lock
AcquireGlobalLock(ctx context.Context) func()
SavePeer(ctx context.Context, accountID string, peer *nbpeer.Peer) error
SavePeerStatus(accountID, peerID string, status nbpeer.PeerStatus) error
SavePeerLocation(accountID string, peer *nbpeer.Peer) error
SaveUserLastLogin(accountID, userID string, lastLogin time.Time) error

View File

@@ -211,7 +211,7 @@ func NewOwnerUser(id string) *User {
// createServiceUser creates a new service user under the given account.
func (am *DefaultAccountManager) createServiceUser(ctx context.Context, accountID string, initiatorUserID string, role UserRole, serviceUserName string, nonDeletable bool, autoGroups []string) (*UserInfo, error) {
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
account, err := am.Store.GetAccount(ctx, accountID)
@@ -267,7 +267,7 @@ func (am *DefaultAccountManager) CreateUser(ctx context.Context, accountID, user
// inviteNewUser Invites a USer to a given account and creates reference in datastore
func (am *DefaultAccountManager) inviteNewUser(ctx context.Context, accountID, userID string, invite *UserInfo) (*UserInfo, error) {
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
if am.idpManager == nil {
@@ -368,7 +368,7 @@ func (am *DefaultAccountManager) GetUser(ctx context.Context, claims jwtclaims.A
return nil, fmt.Errorf("failed to get account with token claims %v", err)
}
unlock := am.Store.AcquireAccountWriteLock(ctx, account.Id)
unlock := am.Store.AcquireWriteLockByUID(ctx, account.Id)
defer unlock()
account, err = am.Store.GetAccount(ctx, account.Id)
@@ -401,7 +401,7 @@ func (am *DefaultAccountManager) GetUser(ctx context.Context, claims jwtclaims.A
// ListUsers returns lists of all users under the account.
// It doesn't populate user information such as email or name.
func (am *DefaultAccountManager) ListUsers(ctx context.Context, accountID string) ([]*User, error) {
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
account, err := am.Store.GetAccount(ctx, accountID)
@@ -428,7 +428,7 @@ func (am *DefaultAccountManager) DeleteUser(ctx context.Context, accountID, init
if initiatorUserID == targetUserID {
return status.Errorf(status.InvalidArgument, "self deletion is not allowed")
}
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
account, err := am.Store.GetAccount(ctx, accountID)
@@ -517,7 +517,6 @@ func (am *DefaultAccountManager) deleteRegularUser(ctx context.Context, account
meta := map[string]any{"name": tuName, "email": tuEmail, "created_at": tuCreatedAt}
am.StoreEvent(ctx, initiatorUserID, targetUserID, account.Id, activity.UserDeleted, meta)
// todo: call only if user had a peer linked to it and peer propagation is enabled
am.updateAccountPeers(ctx, account)
return nil
@@ -539,7 +538,7 @@ func (am *DefaultAccountManager) deleteUserPeers(ctx context.Context, initiatorU
// InviteUser resend invitations to users who haven't activated their accounts prior to the expiration period.
func (am *DefaultAccountManager) InviteUser(ctx context.Context, accountID string, initiatorUserID string, targetUserID string) error {
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
if am.idpManager == nil {
@@ -579,7 +578,7 @@ func (am *DefaultAccountManager) InviteUser(ctx context.Context, accountID strin
// CreatePAT creates a new PAT for the given user
func (am *DefaultAccountManager) CreatePAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenName string, expiresIn int) (*PersonalAccessTokenGenerated, error) {
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
if tokenName == "" {
@@ -629,7 +628,7 @@ func (am *DefaultAccountManager) CreatePAT(ctx context.Context, accountID string
// DeletePAT deletes a specific PAT from a user
func (am *DefaultAccountManager) DeletePAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) error {
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
account, err := am.Store.GetAccount(ctx, accountID)
@@ -679,7 +678,7 @@ func (am *DefaultAccountManager) DeletePAT(ctx context.Context, accountID string
// GetPAT returns a specific PAT from a user
func (am *DefaultAccountManager) GetPAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) (*PersonalAccessToken, error) {
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
account, err := am.Store.GetAccount(ctx, accountID)
@@ -711,7 +710,7 @@ func (am *DefaultAccountManager) GetPAT(ctx context.Context, accountID string, i
// GetAllPATs returns all PATs for a user
func (am *DefaultAccountManager) GetAllPATs(ctx context.Context, accountID string, initiatorUserID string, targetUserID string) ([]*PersonalAccessToken, error) {
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
account, err := am.Store.GetAccount(ctx, accountID)
@@ -741,7 +740,7 @@ func (am *DefaultAccountManager) GetAllPATs(ctx context.Context, accountID strin
return pats, nil
}
// SaveUser saves updates to the given user. If the user doesn't exit it will throw status.NotFound error.
// SaveUser saves updates to the given user. If the user doesn't exist, it will throw status.NotFound error.
func (am *DefaultAccountManager) SaveUser(ctx context.Context, accountID, initiatorUserID string, update *User) (*UserInfo, error) {
return am.SaveOrAddUser(ctx, accountID, initiatorUserID, update, false) // false means do not create user and throw status.NotFound
}
@@ -749,13 +748,33 @@ func (am *DefaultAccountManager) SaveUser(ctx context.Context, accountID, initia
// SaveOrAddUser updates the given user. If addIfNotExists is set to true it will add user when no exist
// Only User.AutoGroups, User.Role, and User.Blocked fields are allowed to be updated for now.
func (am *DefaultAccountManager) SaveOrAddUser(ctx context.Context, accountID, initiatorUserID string, update *User, addIfNotExists bool) (*UserInfo, error) {
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
defer unlock()
if update == nil {
return nil, status.Errorf(status.InvalidArgument, "provided user update is nil")
}
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
updatedUsers, err := am.SaveOrAddUsers(ctx, accountID, initiatorUserID, []*User{update}, addIfNotExists)
if err != nil {
return nil, err
}
if len(updatedUsers) == 0 {
return nil, status.Errorf(status.Internal, "user was not updated")
}
return updatedUsers[0], nil
}
// SaveOrAddUsers updates existing users or adds new users to the account.
// Note: This function does not acquire the global lock.
// It is the caller's responsibility to ensure proper locking is in place before invoking this method.
func (am *DefaultAccountManager) SaveOrAddUsers(ctx context.Context, accountID, initiatorUserID string, updates []*User, addIfNotExists bool) ([]*UserInfo, error) {
if len(updates) == 0 {
return nil, nil //nolint:nilnil
}
account, err := am.Store.GetAccount(ctx, accountID)
if err != nil {
return nil, err
@@ -770,145 +789,200 @@ func (am *DefaultAccountManager) SaveOrAddUser(ctx context.Context, accountID, i
return nil, status.Errorf(status.PermissionDenied, "only users with admin power are authorized to perform user update operations")
}
oldUser := account.Users[update.Id]
if oldUser == nil {
if !addIfNotExists {
return nil, status.Errorf(status.NotFound, "user to update doesn't exist")
updatedUsers := make([]*UserInfo, 0, len(updates))
var (
expiredPeers []*nbpeer.Peer
eventsToStore []func()
)
for _, update := range updates {
if update == nil {
return nil, status.Errorf(status.InvalidArgument, "provided user update is nil")
}
// when addIfNotExists is set to true the newUser will use all fields from the update input
oldUser = update
}
if initiatorUser.HasAdminPower() && initiatorUserID == update.Id && oldUser.Blocked != update.Blocked {
return nil, status.Errorf(status.PermissionDenied, "admins can't block or unblock themselves")
}
if initiatorUser.HasAdminPower() && initiatorUserID == update.Id && update.Role != initiatorUser.Role {
return nil, status.Errorf(status.PermissionDenied, "admins can't change their role")
}
if initiatorUser.Role == UserRoleAdmin && oldUser.Role == UserRoleOwner && update.Role != oldUser.Role {
return nil, status.Errorf(status.PermissionDenied, "only owners can remove owner role from their user")
}
if initiatorUser.Role == UserRoleAdmin && oldUser.Role == UserRoleOwner && update.IsBlocked() && !oldUser.IsBlocked() {
return nil, status.Errorf(status.PermissionDenied, "unable to block owner user")
}
if initiatorUser.Role == UserRoleAdmin && update.Role == UserRoleOwner && update.Role != oldUser.Role {
return nil, status.Errorf(status.PermissionDenied, "only owners can add owner role to other users")
}
if oldUser.IsServiceUser && update.Role == UserRoleOwner {
return nil, status.Errorf(status.PermissionDenied, "can't update a service user with owner role")
}
transferedOwnerRole := false
if initiatorUser.Role == UserRoleOwner && initiatorUserID != update.Id && update.Role == UserRoleOwner {
newInitiatorUser := initiatorUser.Copy()
newInitiatorUser.Role = UserRoleAdmin
account.Users[initiatorUserID] = newInitiatorUser
transferedOwnerRole = true
}
// only auto groups, revoked status, and integration reference can be updated for now
newUser := oldUser.Copy()
newUser.Role = update.Role
newUser.Blocked = update.Blocked
// these two fields can't be set via API, only via direct call to the method
newUser.Issued = update.Issued
newUser.IntegrationReference = update.IntegrationReference
for _, newGroupID := range update.AutoGroups {
if _, ok := account.Groups[newGroupID]; !ok {
return nil, status.Errorf(status.InvalidArgument, "provided group ID %s in the user %s update doesn't exist",
newGroupID, update.Id)
oldUser := account.Users[update.Id]
if oldUser == nil {
if !addIfNotExists {
return nil, status.Errorf(status.NotFound, "user to update doesn't exist: %s", update.Id)
}
// when addIfNotExists is set to true, the newUser will use all fields from the update input
oldUser = update
}
}
newUser.AutoGroups = update.AutoGroups
account.Users[newUser.Id] = newUser
if !oldUser.IsBlocked() && update.IsBlocked() {
// expire peers that belong to the user who's getting blocked
blockedPeers, err := account.FindUserPeers(update.Id)
if err != nil {
if err := validateUserUpdate(account, initiatorUser, oldUser, update); err != nil {
return nil, err
}
if err := am.expireAndUpdatePeers(ctx, account, blockedPeers); err != nil {
// only auto groups, revoked status, and integration reference can be updated for now
newUser := oldUser.Copy()
newUser.Role = update.Role
newUser.Blocked = update.Blocked
newUser.AutoGroups = update.AutoGroups
// these two fields can't be set via API, only via direct call to the method
newUser.Issued = update.Issued
newUser.IntegrationReference = update.IntegrationReference
transferredOwnerRole := handleOwnerRoleTransfer(account, initiatorUser, update)
account.Users[newUser.Id] = newUser
if !oldUser.IsBlocked() && update.IsBlocked() {
// expire peers that belong to the user who's getting blocked
blockedPeers, err := account.FindUserPeers(update.Id)
if err != nil {
return nil, err
}
expiredPeers = append(expiredPeers, blockedPeers...)
}
if update.AutoGroups != nil && account.Settings.GroupsPropagationEnabled {
removedGroups := difference(oldUser.AutoGroups, update.AutoGroups)
// need force update all auto groups in any case they will not be duplicated
account.UserGroupsAddToPeers(oldUser.Id, update.AutoGroups...)
account.UserGroupsRemoveFromPeers(oldUser.Id, removedGroups...)
}
events := am.prepareUserUpdateEvents(ctx, initiatorUser.Id, oldUser, newUser, account, transferredOwnerRole)
eventsToStore = append(eventsToStore, events...)
updatedUserInfo, err := getUserInfo(ctx, am, newUser, account)
if err != nil {
return nil, err
}
updatedUsers = append(updatedUsers, updatedUserInfo)
}
if len(expiredPeers) > 0 {
if err := am.expireAndUpdatePeers(ctx, account, expiredPeers); err != nil {
log.WithContext(ctx).Errorf("failed update expired peers: %s", err)
return nil, err
}
}
if update.AutoGroups != nil && account.Settings.GroupsPropagationEnabled {
removedGroups := difference(oldUser.AutoGroups, update.AutoGroups)
// need force update all auto groups in any case they will not be duplicated
account.UserGroupsAddToPeers(oldUser.Id, update.AutoGroups...)
account.UserGroupsRemoveFromPeers(oldUser.Id, removedGroups...)
account.Network.IncSerial()
if err = am.Store.SaveAccount(ctx, account); err != nil {
return nil, err
}
account.Network.IncSerial()
if err = am.Store.SaveAccount(ctx, account); err != nil {
return nil, err
}
// todo: call only if is existing user, it has a peer linked to it and peer propagation is enabled
// new users don't need to call this
if account.Settings.GroupsPropagationEnabled {
am.updateAccountPeers(ctx, account)
} else {
if err = am.Store.SaveAccount(ctx, account); err != nil {
return nil, err
}
for _, storeEvent := range eventsToStore {
storeEvent()
}
return updatedUsers, nil
}
// prepareUserUpdateEvents prepares a list user update events based on the changes between the old and new user data.
func (am *DefaultAccountManager) prepareUserUpdateEvents(ctx context.Context, initiatorUserID string, oldUser, newUser *User, account *Account, transferredOwnerRole bool) []func() {
var eventsToStore []func()
if oldUser.IsBlocked() != newUser.IsBlocked() {
if newUser.IsBlocked() {
eventsToStore = append(eventsToStore, func() {
am.StoreEvent(ctx, initiatorUserID, oldUser.Id, account.Id, activity.UserBlocked, nil)
})
} else {
eventsToStore = append(eventsToStore, func() {
am.StoreEvent(ctx, initiatorUserID, oldUser.Id, account.Id, activity.UserUnblocked, nil)
})
}
}
defer func() {
if oldUser.IsBlocked() != update.IsBlocked() {
if update.IsBlocked() {
am.StoreEvent(ctx, initiatorUserID, oldUser.Id, accountID, activity.UserBlocked, nil)
switch {
case transferredOwnerRole:
eventsToStore = append(eventsToStore, func() {
am.StoreEvent(ctx, initiatorUserID, oldUser.Id, account.Id, activity.TransferredOwnerRole, nil)
})
case oldUser.Role != newUser.Role:
eventsToStore = append(eventsToStore, func() {
am.StoreEvent(ctx, initiatorUserID, oldUser.Id, account.Id, activity.UserRoleUpdated, map[string]any{"role": newUser.Role})
})
}
if newUser.AutoGroups != nil {
removedGroups := difference(oldUser.AutoGroups, newUser.AutoGroups)
addedGroups := difference(newUser.AutoGroups, oldUser.AutoGroups)
for _, g := range removedGroups {
group := account.GetGroup(g)
if group != nil {
eventsToStore = append(eventsToStore, func() {
am.StoreEvent(ctx, initiatorUserID, oldUser.Id, account.Id, activity.GroupRemovedFromUser,
map[string]any{"group": group.Name, "group_id": group.ID, "is_service_user": newUser.IsServiceUser, "user_name": newUser.ServiceUserName})
})
} else {
am.StoreEvent(ctx, initiatorUserID, oldUser.Id, accountID, activity.UserUnblocked, nil)
log.WithContext(ctx).Errorf("group %s not found while saving user activity event of account %s", g, account.Id)
}
}
switch {
case transferedOwnerRole:
am.StoreEvent(ctx, initiatorUserID, oldUser.Id, accountID, activity.TransferredOwnerRole, nil)
case oldUser.Role != newUser.Role:
am.StoreEvent(ctx, initiatorUserID, oldUser.Id, accountID, activity.UserRoleUpdated, map[string]any{"role": newUser.Role})
default:
}
if update.AutoGroups != nil {
removedGroups := difference(oldUser.AutoGroups, update.AutoGroups)
addedGroups := difference(newUser.AutoGroups, oldUser.AutoGroups)
for _, g := range removedGroups {
group := account.GetGroup(g)
if group != nil {
am.StoreEvent(ctx, initiatorUserID, oldUser.Id, accountID, activity.GroupRemovedFromUser,
for _, g := range addedGroups {
group := account.GetGroup(g)
if group != nil {
eventsToStore = append(eventsToStore, func() {
am.StoreEvent(ctx, initiatorUserID, oldUser.Id, account.Id, activity.GroupAddedToUser,
map[string]any{"group": group.Name, "group_id": group.ID, "is_service_user": newUser.IsServiceUser, "user_name": newUser.ServiceUserName})
} else {
log.WithContext(ctx).Errorf("group %s not found while saving user activity event of account %s", g, account.Id)
}
}
for _, g := range addedGroups {
group := account.GetGroup(g)
if group != nil {
am.StoreEvent(ctx, initiatorUserID, oldUser.Id, accountID, activity.GroupAddedToUser,
map[string]any{"group": group.Name, "group_id": group.ID, "is_service_user": newUser.IsServiceUser, "user_name": newUser.ServiceUserName})
}
})
}
}
}()
}
if !isNil(am.idpManager) && !newUser.IsServiceUser {
userData, err := am.lookupUserInCache(ctx, newUser.Id, account)
return eventsToStore
}
func handleOwnerRoleTransfer(account *Account, initiatorUser, update *User) bool {
if initiatorUser.Role == UserRoleOwner && initiatorUser.Id != update.Id && update.Role == UserRoleOwner {
newInitiatorUser := initiatorUser.Copy()
newInitiatorUser.Role = UserRoleAdmin
account.Users[initiatorUser.Id] = newInitiatorUser
return true
}
return false
}
// getUserInfo retrieves the UserInfo for a given User and Account.
// If the AccountManager has a non-nil idpManager and the User is not a service user,
// it will attempt to look up the UserData from the cache.
func getUserInfo(ctx context.Context, am *DefaultAccountManager, user *User, account *Account) (*UserInfo, error) {
if !isNil(am.idpManager) && !user.IsServiceUser {
userData, err := am.lookupUserInCache(ctx, user.Id, account)
if err != nil {
return nil, err
}
return newUser.ToUserInfo(userData, account.Settings)
return user.ToUserInfo(userData, account.Settings)
}
return newUser.ToUserInfo(nil, account.Settings)
return user.ToUserInfo(nil, account.Settings)
}
// validateUserUpdate validates the update operation for a user.
func validateUserUpdate(account *Account, initiatorUser, oldUser, update *User) error {
if initiatorUser.HasAdminPower() && initiatorUser.Id == update.Id && oldUser.Blocked != update.Blocked {
return status.Errorf(status.PermissionDenied, "admins can't block or unblock themselves")
}
if initiatorUser.HasAdminPower() && initiatorUser.Id == update.Id && update.Role != initiatorUser.Role {
return status.Errorf(status.PermissionDenied, "admins can't change their role")
}
if initiatorUser.Role == UserRoleAdmin && oldUser.Role == UserRoleOwner && update.Role != oldUser.Role {
return status.Errorf(status.PermissionDenied, "only owners can remove owner role from their user")
}
if initiatorUser.Role == UserRoleAdmin && oldUser.Role == UserRoleOwner && update.IsBlocked() && !oldUser.IsBlocked() {
return status.Errorf(status.PermissionDenied, "unable to block owner user")
}
if initiatorUser.Role == UserRoleAdmin && update.Role == UserRoleOwner && update.Role != oldUser.Role {
return status.Errorf(status.PermissionDenied, "only owners can add owner role to other users")
}
if oldUser.IsServiceUser && update.Role == UserRoleOwner {
return status.Errorf(status.PermissionDenied, "can't update a service user with owner role")
}
for _, newGroupID := range update.AutoGroups {
if _, ok := account.Groups[newGroupID]; !ok {
return status.Errorf(status.InvalidArgument, "provided group ID %s in the user %s update doesn't exist",
newGroupID, update.Id)
}
}
return nil
}
// GetOrCreateAccountByUser returns an existing account for a given user id or creates a new one if doesn't exist
@@ -939,7 +1013,7 @@ func (am *DefaultAccountManager) GetOrCreateAccountByUser(ctx context.Context, u
userObj := account.Users[userID]
if account.Domain != lowerDomain && userObj.Role == UserRoleOwner {
if lowerDomain != "" && account.Domain != lowerDomain && userObj.Role == UserRoleOwner {
account.Domain = lowerDomain
err = am.Store.SaveAccount(ctx, account)
if err != nil {

View File

@@ -18,6 +18,8 @@ Flags:
--letsencrypt-domain string a domain to issue Let's Encrypt certificate for. Enables TLS using Let's Encrypt. Will fetch and renew certificate, and run the server with TLS
--port int Server port to listen on (e.g. 10000) (default 10000)
--ssl-dir string server ssl directory location. *Required only for Let's Encrypt certificates. (default "/var/lib/netbird/")
--cert-file string Location of your SSL certificate. Can be used when you have an existing certificate and don't want a new certificate be generated automatically. If letsencrypt-domain is specified this property has no effect
--cert-key string Location of your SSL certificate private key. Can be used when you have an existing certificate and don't want a new certificate be generated automatically. If letsencrypt-domain is specified this property has no effect
Global Flags:
--log-file string sets Netbird log path. If console is specified the the log will be output to stdout (default "/var/log/netbird/signal.log")
@@ -90,6 +92,9 @@ The Signal Server exposes the following metrics in Prometheus format:
- **registration_delay_milliseconds**: A Histogram metric that measures the time
it took to register a peer in
milliseconds.
- **get_registration_delay_milliseconds**: A Histogram metric that measures the time
it took to get a peer registration in
milliseconds.
- **messages_forwarded_total**: A Counter metric that counts the total number of
messages forwarded between peers.
- **message_forward_failures_total**: A Counter metric that counts the total

View File

@@ -2,7 +2,6 @@ package client
import (
"context"
"crypto/tls"
"fmt"
"io"
"sync"
@@ -14,9 +13,6 @@ import (
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/connectivity"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/keepalive"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/status"
@@ -64,28 +60,21 @@ func (c *GrpcClient) Close() error {
// NewClient creates a new Signal client
func NewClient(ctx context.Context, addr string, key wgtypes.Key, tlsEnabled bool) (*GrpcClient, error) {
var conn *grpc.ClientConn
transportOption := grpc.WithTransportCredentials(insecure.NewCredentials())
if tlsEnabled {
transportOption = grpc.WithTransportCredentials(credentials.NewTLS(&tls.Config{}))
operation := func() error {
var err error
conn, err = nbgrpc.CreateConnection(addr, tlsEnabled)
if err != nil {
log.Printf("createConnection error: %v", err)
return err
}
return nil
}
sigCtx, cancel := context.WithTimeout(ctx, client.ConnectTimeout)
defer cancel()
conn, err := grpc.DialContext(
sigCtx,
addr,
transportOption,
nbgrpc.WithCustomDialer(),
grpc.WithBlock(),
grpc.WithKeepaliveParams(keepalive.ClientParameters{
Time: 30 * time.Second,
Timeout: 10 * time.Second,
}))
err := backoff.Retry(operation, nbgrpc.Backoff(ctx))
if err != nil {
log.Errorf("failed to connect to the signalling server %v", err)
log.Errorf("failed to connect to the signalling server: %v", err)
return nil, err
}
@@ -408,7 +397,7 @@ func (c *GrpcClient) receive(stream proto.SignalExchange_ConnectStreamClient,
if err != nil {
log.Errorf("error while handling message of Peer [key: %s] error: [%s]", msg.Key, err.Error())
//todo send something??
// todo send something??
}
}
}

View File

@@ -2,15 +2,12 @@ package cmd
import (
"context"
"crypto/tls"
"errors"
"flag"
"fmt"
"io"
"io/fs"
"net"
"net/http"
"os"
"path"
"strings"
"time"
@@ -41,7 +38,8 @@ var (
signalLetsencryptDomain string
signalSSLDir string
defaultSignalSSLDir string
tlsEnabled bool
signalCertFile string
signalCertKey string
signalKaep = grpc.KeepaliveEnforcementPolicy(keepalive.EnforcementPolicy{
MinTime: 5 * time.Second,
@@ -56,12 +54,22 @@ var (
})
runCmd = &cobra.Command{
Use: "run",
Short: "start NetBird Signal Server daemon",
Use: "run",
Short: "start NetBird Signal Server daemon",
SilenceUsage: true,
PreRun: func(cmd *cobra.Command, args []string) {
err := util.InitLog(logLevel, logFile)
if err != nil {
log.Fatalf("failed initializing log %v", err)
}
flag.Parse()
// detect whether user specified a port
userPort := cmd.Flag("port").Changed
if signalLetsencryptDomain != "" {
tlsEnabled := false
if signalLetsencryptDomain != "" || (signalCertFile != "" && signalCertKey != "") {
tlsEnabled = true
}
@@ -77,33 +85,12 @@ var (
RunE: func(cmd *cobra.Command, args []string) error {
flag.Parse()
err := util.InitLog(logLevel, logFile)
opts, certManager, err := getTLSConfigurations()
if err != nil {
log.Fatalf("failed initializing log %v", err)
return err
}
if signalSSLDir == "" {
oldPath := "/var/lib/wiretrustee"
if migrateToNetbird(oldPath, defaultSignalSSLDir) {
if err := cpDir(oldPath, defaultSignalSSLDir); err != nil {
log.Fatal(err)
}
}
}
var opts []grpc.ServerOption
var certManager *autocert.Manager
if tlsEnabled {
// Let's encrypt enabled -> generate certificate automatically
certManager, err = encryption.CreateCertManager(signalSSLDir, signalLetsencryptDomain)
if err != nil {
return err
}
transportCredentials := credentials.NewTLS(certManager.TLSConfig())
opts = append(opts, grpc.Creds(transportCredentials))
}
metricsServer := metrics.NewServer(metricsPort, "")
metricsServer, err := metrics.NewServer(metricsPort, "")
if err != nil {
return fmt.Errorf("setup metrics: %v", err)
}
@@ -124,7 +111,25 @@ var (
}
proto.RegisterSignalExchangeServer(grpcServer, srv)
grpcRootHandler := grpcHandlerFunc(grpcServer)
if certManager != nil {
startServerWithCertManager(certManager, grpcRootHandler)
}
var compatListener net.Listener
var grpcListener net.Listener
var httpListener net.Listener
// If certManager is configured and signalPort == 443, then the gRPC server has already been started
if certManager == nil || signalPort != 443 {
grpcListener, err = serveGRPC(grpcServer, signalPort)
if err != nil {
return err
}
log.Infof("running gRPC server: %s", grpcListener.Addr().String())
}
if signalPort != 10000 {
// The Signal gRPC server was running on port 10000 previously. Old agents that are already connected to Signal
// are using port 10000. For compatibility purposes we keep running a 2nd gRPC server on port 10000.
@@ -135,28 +140,6 @@ var (
log.Infof("running gRPC backward compatibility server: %s", compatListener.Addr().String())
}
var grpcListener net.Listener
var httpListener net.Listener
if tlsEnabled {
httpListener = certManager.Listener()
if signalPort == 443 {
// running gRPC and HTTP cert manager on the same port
serveHTTP(httpListener, certManager.HTTPHandler(grpcHandlerFunc(grpcServer)))
log.Infof("running HTTP server (LetsEncrypt challenge handler) and gRPC server on the same port: %s", httpListener.Addr().String())
} else {
serveHTTP(httpListener, certManager.HTTPHandler(nil))
log.Infof("running HTTP server (LetsEncrypt challenge handler): %s", httpListener.Addr().String())
}
}
if signalPort != 443 || !tlsEnabled {
grpcListener, err = serveGRPC(grpcServer, signalPort)
if err != nil {
return err
}
log.Infof("running gRPC server: %s", grpcListener.Addr().String())
}
log.Infof("signal server version %s", version.NetbirdVersion())
log.Infof("started Signal Service")
@@ -190,6 +173,58 @@ var (
}
)
func getTLSConfigurations() ([]grpc.ServerOption, *autocert.Manager, error) {
var (
err error
certManager *autocert.Manager
tlsConfig *tls.Config
)
if signalLetsencryptDomain == "" && signalCertFile == "" && signalCertKey == "" {
log.Infof("running without TLS")
return nil, nil, nil
}
if signalLetsencryptDomain != "" {
certManager, err = encryption.CreateCertManager(signalSSLDir, signalLetsencryptDomain)
if err != nil {
return nil, certManager, err
}
tlsConfig = certManager.TLSConfig()
log.Infof("setting up TLS with LetsEncrypt.")
} else {
if signalCertFile == "" || signalCertKey == "" {
log.Errorf("both cert-file and cert-key must be provided when not using LetsEncrypt")
return nil, certManager, errors.New("both cert-file and cert-key must be provided when not using LetsEncrypt")
}
tlsConfig, err = loadTLSConfig(signalCertFile, signalCertKey)
if err != nil {
log.Errorf("cannot load TLS credentials: %v", err)
return nil, certManager, err
}
log.Infof("setting up TLS with custom certificates.")
}
transportCredentials := credentials.NewTLS(tlsConfig)
return []grpc.ServerOption{grpc.Creds(transportCredentials)}, certManager, err
}
func startServerWithCertManager(certManager *autocert.Manager, grpcRootHandler http.Handler) {
// a call to certManager.Listener() always creates a new listener so we do it once
httpListener := certManager.Listener()
if signalPort == 443 {
// running gRPC and HTTP cert manager on the same port
serveHTTP(httpListener, certManager.HTTPHandler(grpcRootHandler))
log.Infof("running HTTP server (LetsEncrypt challenge handler) and gRPC server on the same port: %s", httpListener.Addr().String())
} else {
// Start the HTTP cert manager server separately
serveHTTP(httpListener, certManager.HTTPHandler(nil))
log.Infof("running HTTP server (LetsEncrypt challenge handler): %s", httpListener.Addr().String())
}
}
func grpcHandlerFunc(grpcServer *grpc.Server) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
grpcHeader := strings.HasPrefix(r.Header.Get("Content-Type"), "application/grpc") ||
@@ -232,95 +267,29 @@ func serveGRPC(grpcServer *grpc.Server, port int) (net.Listener, error) {
return listener, nil
}
func cpFile(src, dst string) error {
var err error
var srcfd *os.File
var dstfd *os.File
var srcinfo os.FileInfo
if srcfd, err = os.Open(src); err != nil {
return err
}
defer srcfd.Close()
if dstfd, err = os.Create(dst); err != nil {
return err
}
defer dstfd.Close()
if _, err = io.Copy(dstfd, srcfd); err != nil {
return err
}
if srcinfo, err = os.Stat(src); err != nil {
return err
}
return os.Chmod(dst, srcinfo.Mode())
}
func copySymLink(source, dest string) error {
link, err := os.Readlink(source)
func loadTLSConfig(certFile string, certKey string) (*tls.Config, error) {
// Load server's certificate and private key
serverCert, err := tls.LoadX509KeyPair(certFile, certKey)
if err != nil {
return err
}
return os.Symlink(link, dest)
}
func cpDir(src string, dst string) error {
var err error
var fds []os.DirEntry
var srcinfo os.FileInfo
if srcinfo, err = os.Stat(src); err != nil {
return err
return nil, err
}
if err = os.MkdirAll(dst, srcinfo.Mode()); err != nil {
return err
// NewDefaultAppMetrics the credentials and return it
config := &tls.Config{
Certificates: []tls.Certificate{serverCert},
ClientAuth: tls.NoClientCert,
NextProtos: []string{
"h2", "http/1.1", // enable HTTP/2
},
}
if fds, err = os.ReadDir(src); err != nil {
return err
}
for _, fd := range fds {
srcfp := path.Join(src, fd.Name())
dstfp := path.Join(dst, fd.Name())
fileInfo, err := os.Stat(srcfp)
if err != nil {
log.Fatalf("Couldn't get fileInfo; %v", err)
}
switch fileInfo.Mode() & os.ModeType {
case os.ModeSymlink:
if err = copySymLink(srcfp, dstfp); err != nil {
log.Fatalf("Failed to copy from %s to %s; %v", srcfp, dstfp, err)
}
case os.ModeDir:
if err = cpDir(srcfp, dstfp); err != nil {
log.Fatalf("Failed to copy from %s to %s; %v", srcfp, dstfp, err)
}
default:
if err = cpFile(srcfp, dstfp); err != nil {
log.Fatalf("Failed to copy from %s to %s; %v", srcfp, dstfp, err)
}
}
}
return nil
}
func migrateToNetbird(oldPath, newPath string) bool {
_, errOld := os.Stat(oldPath)
_, errNew := os.Stat(newPath)
if errors.Is(errOld, fs.ErrNotExist) || errNew == nil {
return false
}
return true
return config, nil
}
func init() {
runCmd.PersistentFlags().IntVar(&signalPort, "port", 80, "Server port to listen on (defaults to 443 if TLS is enabled, 80 otherwise")
runCmd.Flags().StringVar(&signalSSLDir, "ssl-dir", defaultSignalSSLDir, "server ssl directory location. *Required only for Let's Encrypt certificates.")
runCmd.Flags().StringVar(&signalLetsencryptDomain, "letsencrypt-domain", "", "a domain to issue Let's Encrypt certificate for. Enables TLS using Let's Encrypt. Will fetch and renew certificate, and run the server with TLS")
runCmd.Flags().StringVar(&signalCertFile, "cert-file", "", "Location of your SSL certificate. Can be used when you have an existing certificate and don't want a new certificate be generated automatically. If letsencrypt-domain is specified this property has no effect")
runCmd.Flags().StringVar(&signalCertKey, "cert-key", "", "Location of your SSL certificate private key. Can be used when you have an existing certificate and don't want a new certificate be generated automatically. If letsencrypt-domain is specified this property has no effect")
}

View File

@@ -15,6 +15,7 @@ type AppMetrics struct {
Deregistrations metric.Int64Counter
RegistrationFailures metric.Int64Counter
RegistrationDelay metric.Float64Histogram
GetRegistrationDelay metric.Float64Histogram
MessagesForwarded metric.Int64Counter
MessageForwardFailures metric.Int64Counter
@@ -54,6 +55,12 @@ func NewAppMetrics(meter metric.Meter) (*AppMetrics, error) {
return nil, err
}
getRegistrationDelay, err := meter.Float64Histogram("get_registration_delay_milliseconds",
metric.WithExplicitBucketBoundaries(getStandardBucketBoundaries()...))
if err != nil {
return nil, err
}
messagesForwarded, err := meter.Int64Counter("messages_forwarded_total")
if err != nil {
return nil, err
@@ -80,6 +87,7 @@ func NewAppMetrics(meter metric.Meter) (*AppMetrics, error) {
Deregistrations: deregistrations,
RegistrationFailures: registrationFailures,
RegistrationDelay: registrationDelay,
GetRegistrationDelay: getRegistrationDelay,
MessagesForwarded: messagesForwarded,
MessageForwardFailures: messageForwardFailures,

View File

@@ -26,10 +26,10 @@ type Metrics struct {
}
// NewServer initializes and returns a new Metrics instance
func NewServer(port int, endpoint string) *Metrics {
func NewServer(port int, endpoint string) (*Metrics, error) {
exporter, err := prometheus.New()
if err != nil {
return nil
return nil, err
}
provider := metric.NewMeterProvider(metric.WithReader(exporter))
@@ -57,7 +57,7 @@ func NewServer(port int, endpoint string) *Metrics {
provider: provider,
Endpoint: endpoint,
Server: server,
}
}, nil
}
// Shutdown stops the metrics server

Some files were not shown because too many files have changed in this diff Show More