Compare commits

..

247 Commits

Author SHA1 Message Date
Diego Noguês
76fb153d76 fix: small conflicts 2026-02-12 18:41:57 +01:00
Diego Noguês
eee4d75932 chore: switching env var for NB_PROXY_DOMAIN 2026-02-12 18:35:30 +01:00
Diego Noguês
62b8875f67 chore: build 2026-02-12 17:42:03 +01:00
Diego Noguês
47a5478964 chore: remove unnecessary comment 2026-02-12 17:18:23 +01:00
Diego Noguês
9922d6f953 feat: switch to labels for traefik instead of static conf files 2026-02-12 17:18:23 +01:00
Diego Noguês
f9bab22f61 fix: remove change to peers group all 2026-02-12 17:18:23 +01:00
Diego Noguês
3d8fdb7a89 feat: adding IPAM settings to docker compose and setting static ip to traefik 2026-02-12 17:18:23 +01:00
Diego Noguês
fb10153ab8 feat: adding traefik and proxy component to getting-started 2026-02-12 17:17:30 +01:00
mlsmaycon
57d3ee5aac optimize the DeriveClusterFromDomain function
1. validate domain only for proxy urls
2. use registered target cluster for custom domain extraction
2026-02-12 17:10:32 +01:00
pascal
cfdfdecc14 return error if unable to derive cluster on service creation 2026-02-12 16:57:16 +01:00
Diego Noguês
b00babb8b1 feat: switch to labels for traefik instead of static conf files 2026-02-12 16:50:43 +01:00
mlsmaycon
ac995bae6d rename url flag to domain and update validation 2026-02-12 16:28:29 +01:00
Alisdair MacLeod
41a5509ce0 fix nil pointer error in roundtripper 2026-02-12 15:19:19 +00:00
pascal
db5e26db94 rename domain type 2026-02-12 16:15:02 +01:00
Viktor Liu
fe975fb834 Fix missing lang attribute 2026-02-12 23:03:50 +08:00
Viktor Liu
e368d2995b Fix test 2026-02-12 22:57:28 +08:00
Viktor Liu
a3241d8376 Fix swallowed response codes 2026-02-12 22:54:17 +08:00
Alisdair MacLeod
6dfc5772ba fix nil pointer error in roundtripper 2026-02-12 14:44:07 +00:00
Viktor Liu
f70925178c Handle TCP port reuse for TIME-WAIT connections 2026-02-12 22:06:29 +08:00
Viktor Liu
9554934b92 Validate trusted proxies in OAuth callback getClientIP 2026-02-12 22:06:29 +08:00
Viktor Liu
7fdb824a37 Remove write permissions from /var/lib/netbird in proxy Dockerfile 2026-02-12 22:06:29 +08:00
Viktor Liu
412407adc0 Add .dockerignore to exclude sensitive files from build context 2026-02-12 22:06:29 +08:00
Viktor Liu
e0874d7de7 Add noopener to window.open in ErrorPage 2026-02-12 22:06:29 +08:00
pascal
8df1536cbb Merge branch 'main' into prototype/reverse-proxy 2026-02-12 15:05:14 +01:00
pascal
fcbacc62ec clear userID from access logs if not oidc 2026-02-12 14:50:35 +01:00
pascal
ee2ae45653 add permissions validation to domain manager 2026-02-12 14:31:23 +01:00
Diego Noguês
3bc8cbb13f fix: remove change to peers group all 2026-02-12 14:06:23 +01:00
Diego Noguês
bf7bdf6c4f feat: adding IPAM settings to docker compose and setting static ip to traefik 2026-02-12 14:02:21 +01:00
pascal
6f2f0f9ae4 exclude proxy peers on peers api 2026-02-12 13:49:05 +01:00
Alisdair MacLeod
c37ebc6fb3 add more metrics, improve metrics, reduce metrics impact on other packages 2026-02-12 12:36:54 +00:00
Viktor Liu
23abb5743c Treated tombstoned conns as new 2026-02-12 20:11:12 +08:00
Diego Noguês
0a895ffc22 Merge branch 'dn-reverse-proxy' of github.com:netbirdio/netbird into dn-reverse-proxy 2026-02-12 12:00:17 +01:00
Viktor Liu
b87aa0bc15 Address linter issues 2026-02-12 18:41:20 +08:00
Viktor Liu
f1a65d732d Add proxy to license boundary check 2026-02-12 18:31:18 +08:00
Viktor Liu
a3c0ea3e71 Add proxy unit test workflow 2026-02-12 18:31:18 +08:00
Viktor Liu
abaf061c2a Skip nil client for health 2026-02-12 18:31:18 +08:00
pascal
e531fb54b1 ignore error 2026-02-12 11:20:22 +01:00
mlsmaycon
5fcfed5b16 add proxy tests 2026-02-12 11:19:10 +01:00
Diego Noguês
b81837a364 feat: adding traefik and proxy component to getting-started 2026-02-12 11:16:37 +01:00
pascal
5f43449f67 move linter exceptions 2026-02-12 10:45:21 +01:00
mlsmaycon
6796601aa6 Generate a random nonce to ensure each OIDC request gets a unique state 2026-02-12 10:45:13 +01:00
pascal
1fc25c301b move linter exceptions 2026-02-12 10:11:49 +01:00
Viktor Liu
08ae281b2d Fix network monitor restarting the client in netstack mode 2026-02-12 16:48:31 +08:00
Viktor Liu
bd47f44c63 Preload services targets 2026-02-12 16:04:55 +08:00
Viktor Liu
381260911b Create unique token per proxy 2026-02-12 15:48:35 +08:00
Viktor Liu
38db42e7d6 Fix initial sync complete on empty service list 2026-02-12 15:48:35 +08:00
Viktor Liu
5d606d909d Add TTL-based expiry and cleanup for PKCE verifiers to prevent unbounded memory growth 2026-02-12 15:12:41 +08:00
Viktor Liu
d689718b50 Improve logging and error handling 2026-02-12 15:12:41 +08:00
pascal
54a73c6649 move linter exceptions 2026-02-12 02:10:00 +01:00
pascal
418377842e fix tests 2026-02-12 02:00:22 +01:00
pascal
15ef56e03d fix typos 2026-02-12 01:54:14 +01:00
pascal
917035f8e8 fix tests 2026-02-12 01:52:30 +01:00
pascal
963e3f5457 fix linter issues 2026-02-12 01:15:36 +01:00
pascal
e20b969188 fix linter issues 2026-02-12 01:02:13 +01:00
pascal
1c7059ee67 fix some tests 2026-02-12 00:16:33 +01:00
pascal
22a3365658 fix rename errors and tests 2026-02-11 22:34:50 +01:00
pascal
08ab1e3478 rename reverse proxy to services 2026-02-11 21:39:51 +01:00
pascal
ebb1f4007d add id to request log search 2026-02-11 19:25:23 +01:00
pascal
acb53ece93 Merge branch 'prototype/reverse-proxy-logs-pagination' into prototype/reverse-proxy 2026-02-11 18:51:28 +01:00
pascal
e020950cfd concat host and path for search and add a status filter 2026-02-11 17:54:29 +01:00
pascal
9dba262a20 add index to access log entries 2026-02-11 17:07:15 +01:00
pascal
5bcdf36377 fix source_ip 2026-02-11 16:50:27 +01:00
pascal
1ffe8deb10 add general search filter 2026-02-11 16:38:31 +01:00
pascal
d069145bd1 add more filters 2026-02-11 16:23:52 +01:00
Alisdair MacLeod
f3493ee042 add basic metrics for stress testing 2026-02-11 14:56:39 +00:00
Diego Noguês
b782ac6f56 feat: adding traefik and proxy component to getting-started 2026-02-11 15:34:18 +01:00
pascal
bf48044e5c push filter files 2026-02-11 14:52:44 +01:00
pascal
fb4cc37a4a add pagination for access logs 2026-02-11 14:41:52 +01:00
pascal
55b8d89a79 add rate limiting for callback endpoint 2026-02-11 13:42:54 +01:00
pascal
6968a32a5a move to argon2id 2026-02-11 13:26:40 +01:00
pascal
cfe6753349 hash pin and password 2026-02-11 11:48:15 +01:00
Alisdair MacLeod
5ae15b3af3 add hotpath proxy and roundtripper benchmarks 2026-02-11 09:47:40 +00:00
pascal
b79adb706c add services to permissions list 2026-02-11 10:38:20 +01:00
mlsmaycon
f22497d5da remove query parameters on refresh 2026-02-10 21:53:18 +01:00
mlsmaycon
95d672c9df fix: capture auth method in access logs for failed authentication
- Add wasCredentialSubmitted helper to detect when credentials were
  submitted but authentication failed
- Set auth method in CapturedData when wrong PIN/password is entered
- Set auth method for OAuth callback errors and token validation errors
- Add tests for failed auth method capture
2026-02-10 21:33:15 +01:00
mlsmaycon
7d08a609e6 fix: capture account/service/user IDs in access logs for auth requests
- Add accountID and serviceID to auth middleware DomainConfig
- Set account/service IDs in CapturedData when domain is matched
- Update AddDomain to accept accountID and serviceID parameters
- Skip access logging for internal proxy assets (/__netbird__/*)
- Return validationResult struct from validateSessionToken to preserve
  user ID even when access is denied
- Capture user ID and auth method in access logs for denied requests
2026-02-10 20:55:07 +01:00
mlsmaycon
eea6120cd0 refactor: add ValidateSession gRPC and streamline test setup
- Add ValidateSession gRPC method for proxy-side user validation
- Move group access validation from REST callback to gRPC layer
- Capture user info in access logs via CapturedData mutable pointer
- Create validate_session_test.go for gRPC validation tests
- Simplify auth_callback_integration_test.go to create accounts
  programmatically instead of using SQL file
- SQL test data file now only used by validate_session_test.go
2026-02-10 20:31:03 +01:00
pascal
0cb02bd906 fix path handling + extract targets to separate table + guard resource/peer deletion 2026-02-10 17:12:34 +01:00
mlsmaycon
08d3867f41 update error page 2026-02-10 16:54:05 +01:00
mlsmaycon
b16d63643c Add group-based access control for SSO reverse proxy authentication
Implement user group validation during OAuth callback to ensure users
belong to allowed distribution groups before granting access to reverse
proxies. This provides account isolation and fine-grained access control.

Key changes:
- Add ValidateUserGroupAccess to ProxyServiceServer for group membership checks
- Redirect denied users to error page with access_denied parameter
- Handle OAuth error responses in proxy middleware
- Add comprehensive integration tests for auth callback flow
2026-02-10 16:25:00 +01:00
Eduard Gert
940d01bdea Merge remote-tracking branch 'origin/prototype/reverse-proxy' into prototype/reverse-proxy 2026-02-10 14:39:48 +01:00
Eduard Gert
ba9158d159 Remove peer card from proxy error page 2026-02-10 14:39:25 +01:00
pascal
ca9a7e11ef continue on host lookup failure 2026-02-10 14:38:15 +01:00
pascal
a803f47685 add network map support for clustering 2026-02-10 14:29:20 +01:00
Viktor Liu
79fed32f01 Add wg port configuration 2026-02-10 19:55:48 +08:00
Viktor Liu
6b00bb0a66 Strip session_token on redirect 2026-02-10 18:27:31 +08:00
mlsmaycon
e2adef1eea add back notBefore and now to cert log 2026-02-09 20:37:20 +01:00
pascal
9e5fa11792 handle multiple path 2026-02-09 19:25:30 +01:00
pascal
1ff75acb31 handle default ports 2026-02-09 19:23:39 +01:00
pascal
1754160686 handle default ports 2026-02-09 19:21:43 +01:00
pascal
423f6266fb handle default ports 2026-02-09 18:18:53 +01:00
pascal
16d1b4a14a handle default ports 2026-02-09 18:15:26 +01:00
pascal
7c14056faf fix resource lookup 2026-02-09 17:58:28 +01:00
pascal
62e37dc2e2 fix host resolution 2026-02-09 17:56:38 +01:00
pascal
6a08695ee8 Merge branch 'main' into prototype/reverse-proxy 2026-02-09 17:16:00 +01:00
pascal
9a67a8e427 send updates on changes 2026-02-09 17:06:04 +01:00
Viktor Liu
73aa0785ba Add cert health info to checks 2026-02-09 22:55:12 +08:00
Viktor Liu
53c1016a8e Add graceful shutdown for Kubernetes 2026-02-09 22:55:12 +08:00
Viktor Liu
fd442138e6 Add cert hot reload and cert file locking
Adds file-watching certificate hot reload, cross-replica ACME
certificate lock coordination via flock (Unix) and Kubernetes lease
objects.
2026-02-09 22:55:12 +08:00
pascal
be5f30225a fix embedded exception 2026-02-09 15:28:48 +01:00
pascal
7467e9fb8c use portrange 2026-02-09 14:46:23 +01:00
pascal
2390c2e46e change network map calc to inject proxy policies 2026-02-09 14:41:22 +01:00
mlsmaycon
778c223176 fix api handler path 2026-02-09 02:30:06 +01:00
mlsmaycon
36cd0dd85c temp fix import cycle 2026-02-09 02:10:21 +01:00
mlsmaycon
09a1d5a02d rename endpoint 2026-02-09 01:48:51 +01:00
mlsmaycon
7c996ac9b5 add AuthCallbackURL 2026-02-09 01:18:49 +01:00
mlsmaycon
cf9fd5d960 add AuthClientID 2026-02-08 19:41:52 +01:00
mlsmaycon
1c5ab7cb8f add logger support to acme manager 2026-02-08 19:11:27 +01:00
Viktor Liu
aaad3b25a7 Increase client startup timeout
The client has to start mgmt, signal, relay and wireguard/netstack.
If this times out, the client shuts down and never manages to start.
2026-02-09 02:02:18 +08:00
Viktor Liu
9904235a2f Improve embed client error detection and reporting 2026-02-09 01:51:53 +08:00
Viktor Liu
780e9f57a5 Improve mgmt backoff 2026-02-09 01:51:53 +08:00
mlsmaycon
a8db73285b add issued time log and CT timestamp logs 2026-02-08 18:13:50 +01:00
Viktor Liu
3b43c00d12 Use unique static path for auth assets to avoid collision with routes 2026-02-09 01:10:50 +08:00
Viktor Liu
2f390e1794 Conflate default ports 2026-02-09 00:57:08 +08:00
Viktor Liu
3630ebb3ae Add option to rewrite redirects 2026-02-09 00:44:47 +08:00
Viktor Liu
260c46df04 Fix broken auth redirect 2026-02-09 00:02:54 +08:00
Viktor Liu
7f11e3205d Validate target id 2026-02-08 23:44:31 +08:00
Viktor Liu
1c8f92a96f Fix management nil pointer 2026-02-08 23:29:16 +08:00
Viktor Liu
7b6294b624 Refuse to service a service if auth setup failed 2026-02-08 23:24:43 +08:00
Viktor Liu
156d0b1fef Fix duplicate path 2026-02-08 21:41:32 +08:00
Viktor Liu
2cf00dba58 Fix missing route 2026-02-08 21:36:55 +08:00
Viktor Liu
d2a7f3ae36 Fix pass host header 2026-02-08 21:33:48 +08:00
Viktor Liu
6a64d4e4dd Remove test deployment specs 2026-02-08 21:13:22 +08:00
Viktor Liu
51e63c246b Add health status to debug 2026-02-08 21:04:46 +08:00
mlsmaycon
99e6b1eda4 attempt to trigger ssl before first request
1. When AddDomain() is called (when proxy receives a new mapping), it now spawns a goroutine to prefetch the certificate
  2. prefetchCertificate() creates a synthetic tls.ClientHelloInfo and calls GetCertificate() to trigger the ACME flow
  3. The certificate is cached by autocert.DirCache, so subsequent real requests will use the cached cert
  4. If the cert is already cached (e.g., proxy restart), GetCertificate just returns it without making ACME requests
2026-02-08 10:59:36 +01:00
Viktor Liu
dc26a5a436 Merge branch 'main' into prototype/reverse-proxy 2026-02-08 17:50:16 +08:00
Viktor Liu
3883b2fb41 Fix netbird_test.go 2026-02-08 17:49:03 +08:00
Viktor Liu
ed58659a01 Set forwarded headers from trusted proxies only 2026-02-08 17:49:03 +08:00
Viktor Liu
5190923c70 Improve logging requests 2026-02-08 17:49:03 +08:00
Viktor Liu
7c647dd160 Add peer firewall to the receiving peer 2026-02-08 17:49:03 +08:00
Viktor Liu
07e59b2708 Add reverse proxy header security and forwarding
- Rewrite Host header to backend target (configurable via pass_host_header per mapping)
- Strip and set X-Forwarded-For/X-Real-IP from direct connection (trust boundary)
- Set X-Forwarded-Host and X-Forwarded-Proto headers
- Strip nb_session cookie and session_token query param before forwarding
- Add --forwarded-proto flag (auto/http/https) for proto detection
- Fix OIDC redirect hardcoded https scheme
- Add pass_host_header to proto, API, and management model
2026-02-08 15:00:35 +08:00
Viktor Liu
0a3a9f977d Add proxy <-> management authentication 2026-02-08 14:33:27 +08:00
mlsmaycon
2f263bf7e6 fix cluster logic for domains and reverse proxy 2026-02-07 11:43:01 +01:00
mlsmaycon
f65f4fc280 fix some conflicts regression 2026-02-06 20:39:17 +01:00
pascal
adbd7ab4c3 send account updates on proxy change 2026-02-06 17:03:18 +01:00
pascal
0419834482 add routed exposed services support in nmap 2026-02-06 15:42:13 +01:00
pascal
f797d2d9cb fix cert dir name in docker file 2026-02-05 15:46:07 +01:00
pascal
5ae7efe8f7 Merge remote-tracking branch 'origin/prototype/reverse-proxy' into prototype/reverse-proxy 2026-02-05 15:22:39 +01:00
pascal
d6e35bd0fe fix merge conflicts 2026-02-05 15:22:23 +01:00
pascal
0e00f1c8f7 Merge remote-tracking branch 'origin/prototype/reverse-proxy-clusters' into prototype/reverse-proxy
# Conflicts:
#	management/internals/modules/reverseproxy/manager/manager.go
#	management/internals/modules/reverseproxy/reverseproxy.go
#	management/internals/server/modules.go
#	management/internals/shared/grpc/proxy.go
#	management/server/http/handler.go
#	management/server/http/testing/testing_tools/channel/channel.go
2026-02-05 15:19:57 +01:00
Eduard Gert
4433f44a12 Add some other errors 2026-02-05 14:30:55 +01:00
Eduard Gert
7504e718d7 Add better error page 2026-02-05 14:00:51 +01:00
Viktor Liu
9b0387e7ee Add /cert dir 2026-02-05 19:22:31 +08:00
mlsmaycon
5ccce1ab3f add debug logging for proxy connections and domain resolution
- Log proxy address and cluster info when proxy connects
  - Log connected proxy URLs when GetConnectedProxyURLs is called
  - Log proxy allow list when GetDomains is called
  - Helps debug issues with free domains not appearing in API response
2026-02-05 02:18:38 +01:00
pascal
e366fe340e add log when listener is ready 2026-02-04 23:32:19 +01:00
pascal
b01809f8e3 use logger 2026-02-04 23:10:01 +01:00
pascal
790ef39187 log on debug 2026-02-04 22:43:40 +01:00
pascal
3af16cf333 add trace logs 2026-02-04 22:26:29 +01:00
pascal
d09c69f303 fix scan sql 2026-02-04 21:05:25 +01:00
pascal
096d4ac529 rewrite peer creation and network map calc [WIP] 2026-02-04 20:01:00 +01:00
Alisdair MacLeod
8fafde614a Merge remote-tracking branch 'origin/prototype/reverse-proxy' into prototype/reverse-proxy 2026-02-04 16:52:42 +00:00
Alisdair MacLeod
694ae13418 add stateless proxy sessions 2026-02-04 16:52:35 +00:00
Eduard Gert
b5b7dd4f53 Add other error pages 2026-02-04 17:12:26 +01:00
Viktor Liu
476785b122 Remove health check addr override 2026-02-04 22:32:46 +08:00
Viktor Liu
907677f835 Set readiness false on disconnect right away 2026-02-04 22:28:53 +08:00
Viktor Liu
7d844b9410 Add health checks 2026-02-04 22:18:45 +08:00
Eduard Gert
eeabc64a73 Merge remote-tracking branch 'origin/prototype/reverse-proxy' into prototype/reverse-proxy 2026-02-04 15:11:33 +01:00
Eduard Gert
5da2b0fdcc Add error page 2026-02-04 15:11:22 +01:00
Alisdair MacLeod
a0005a604e fix minor potential security issues with OIDC 2026-02-04 12:25:19 +00:00
Alisdair MacLeod
a89bb807a6 fix protos after merge 2026-02-04 11:56:34 +00:00
Alisdair MacLeod
28f3354ffa Merge remote-tracking branch 'origin/prototype/reverse-proxy' into prototype/reverse-proxy
# Conflicts:
#	management/internals/modules/reverseproxy/reverseproxy.go
#	management/internals/server/boot.go
#	management/internals/shared/grpc/proxy.go
#	proxy/internal/auth/middleware.go
#	shared/management/proto/proxy_service.pb.go
#	shared/management/proto/proxy_service.proto
#	shared/management/proto/proxy_service_grpc.pb.go
2026-02-04 11:56:04 +00:00
Alisdair MacLeod
562923c600 management OIDC implementation using pkce 2026-02-04 11:51:46 +00:00
Alisdair MacLeod
0dd0c67b3b Revert "add management oidc configuration for proxies"
This reverts commit 146774860b.
2026-02-04 09:28:54 +00:00
Viktor Liu
ca33849f31 Use a 1:1 mapping of netbird client to netbird account
- Add debug endpoint for monitoring netbird clients
- Add types package with AccountID type
- Refactor netbird roundtrip to key clients by AccountID
- Multiple domains can share the same client per account
- Add status notifier for tunnel connection updates
- Add OIDC flags to CLI
- Add tests for netbird client management
2026-02-04 14:48:20 +08:00
Viktor Liu
18cd0f1480 Fix netstack detection and add wireguard port option
- Add WireguardPort option to embed.Options for custom port configuration
- Fix KernelInterface detection to account for netstack mode
- Skip SSH config updates when running in netstack mode
- Skip interface removal wait when running in netstack mode
- Use BindListener for netstack to avoid port conflicts on same host
2026-02-04 14:39:19 +08:00
mlsmaycon
b02982f6b1 add logs 2026-02-04 03:14:26 +01:00
mlsmaycon
4d89ae27ef add clusters logic 2026-02-04 02:16:57 +01:00
Eduard Gert
733ea77c5c Add proxy auth ui 2026-02-03 19:05:55 +01:00
pascal
92f72bfce6 add reverse proxy meta to api resp 2026-02-03 17:37:55 +01:00
pascal
bffb25bea7 add status confirmation for certs and tunnel creation 2026-02-03 16:58:14 +01:00
Alisdair MacLeod
3af4543e80 check for domain ownership via subdomain rather than naked domain 2026-02-03 12:50:25 +00:00
Alisdair MacLeod
146774860b add management oidc configuration for proxies 2026-02-03 12:39:16 +00:00
Alisdair MacLeod
5243481316 get OIDC configuration from proxy flags/env 2026-02-03 12:10:23 +00:00
Alisdair MacLeod
76a39c1dcb Revert "add management side of OIDC authentication"
This reverts commit 02ce918114.
2026-02-03 10:03:38 +00:00
Alisdair MacLeod
02ce918114 add management side of OIDC authentication 2026-02-03 09:42:40 +00:00
Alisdair MacLeod
30cfc22cb6 correct proto and proxy authentication for oidc 2026-02-03 09:01:39 +00:00
Alisdair MacLeod
3168afbfcb clean up proxy reported urls when using them for validation 2026-02-02 15:59:24 +00:00
Alisdair MacLeod
a73ee47557 ignore ports when performing proxy mapping lookups 2026-02-02 14:39:13 +00:00
Alisdair MacLeod
fa6ff005f2 add validation logging 2026-02-02 10:53:46 +00:00
Alisdair MacLeod
095379fa60 add logging to domain validation 2026-02-02 10:27:20 +00:00
Alisdair MacLeod
30572fe1b8 add domain validation using values from proxies 2026-02-02 09:53:49 +00:00
Alisdair MacLeod
3a6f364b03 use a defined logger
this should avoid issues with the embedded
client also attempting to use the same global logger
2026-01-30 16:31:32 +00:00
Alisdair MacLeod
5345d716ee Merge branch 'main' into prototype/reverse-proxy 2026-01-30 14:46:08 +00:00
Alisdair MacLeod
f882c36e0a simplify authentication 2026-01-30 14:08:52 +00:00
Alisdair MacLeod
e95cfa1a00 add support for some basic authentication methods 2026-01-29 16:34:52 +00:00
pascal
0d480071b6 pass accountID 2026-01-29 14:47:22 +01:00
pascal
8e0b7b6c25 add api for access log events 2026-01-29 14:27:57 +01:00
Alisdair MacLeod
f204da0d68 fix management reverseproxy proto mapping 2026-01-29 12:29:21 +00:00
Alisdair MacLeod
7d74904d62 add roundtripper debug log 2026-01-29 12:03:14 +00:00
Alisdair MacLeod
760ac5e07d use the netbird client transport directly 2026-01-29 11:11:28 +00:00
Alisdair MacLeod
4352228797 allow setting the proxy url for autocert server name 2026-01-29 10:03:48 +00:00
Alisdair MacLeod
74c770609c fix access log context cancelled 2026-01-29 09:33:23 +00:00
Alisdair MacLeod
f4ca36ed7e fix non-nil path assignment 2026-01-29 08:40:03 +00:00
mlsmaycon
c86da92fc6 update log init 2026-01-28 23:18:54 +01:00
mlsmaycon
3f0c577456 use util.InitLog 2026-01-28 22:59:08 +01:00
mlsmaycon
717da8c7b7 fix nil path 2026-01-28 22:40:39 +01:00
mlsmaycon
a0a61d4f47 add extra debug logs 2026-01-28 21:26:57 +01:00
Alisdair MacLeod
5b1fced872 Merge remote-tracking branch 'origin/prototype/reverse-proxy' into prototype/reverse-proxy 2026-01-28 16:55:12 +00:00
Alisdair MacLeod
c98dcf5ef9 get all proxy endpoints when a proxy connects 2026-01-28 16:55:05 +00:00
pascal
57cb6bfccb add log on broadcasting update 2026-01-28 17:52:38 +01:00
Alisdair MacLeod
95bf97dc3c add env var for debug logs 2026-01-28 16:38:24 +00:00
Alisdair MacLeod
3d116c9d33 add debug logs and switch to logrus for logs 2026-01-28 15:44:35 +00:00
Alisdair MacLeod
a9ce9f8d5a add grpc TLS with selection inferred from management URL 2026-01-28 13:44:47 +00:00
Alisdair MacLeod
10b981a855 fix gorm id failures 2026-01-28 13:16:47 +00:00
Alisdair MacLeod
7700b4333d correctly interpret custom domains from the database 2026-01-28 12:45:32 +00:00
Alisdair MacLeod
7d0131111e Merge remote-tracking branch 'origin/prototype/reverse-proxy' into prototype/reverse-proxy 2026-01-28 12:36:23 +00:00
Alisdair MacLeod
1daea35e4b remove scheme information from management address when connecting via grpc 2026-01-28 12:36:13 +00:00
pascal
f97544af0d go mod tidy 2026-01-28 13:02:22 +01:00
pascal
231e80cc15 Merge branch 'main' into prototype/reverse-proxy 2026-01-28 12:52:42 +01:00
Alisdair MacLeod
a4c1362bff pass proxy information to management on grpc connection 2026-01-28 10:50:35 +00:00
Alisdair MacLeod
b611d4a751 pass account manager in to proxy grpc server for setup key generation 2026-01-28 08:39:09 +00:00
Alisdair MacLeod
2c9decfa55 fix domain store slice retrieval 2026-01-27 17:27:16 +00:00
Alisdair MacLeod
3c5ac17e2f fix domain store nil pointer 2026-01-27 17:06:20 +00:00
Alisdair MacLeod
ae42bbb898 Merge remote-tracking branch 'origin/prototype/reverse-proxy' into prototype/reverse-proxy 2026-01-27 17:02:02 +00:00
Alisdair MacLeod
b86722394b fix domain api registration 2026-01-27 17:01:55 +00:00
pascal
a103f69767 remove basic auth scheme 2026-01-27 17:53:59 +01:00
pascal
73fbb3fc62 fix reverse proxy put and post 2026-01-27 17:38:55 +01:00
Alisdair MacLeod
7b3523e25e return empty domain list when none in database 2026-01-27 16:34:56 +00:00
pascal
6e4e1386e7 fix path variables 2026-01-27 17:13:42 +01:00
pascal
671e9af6eb create setup key and policy to send to reverse proxies 2026-01-27 17:05:32 +01:00
Alisdair MacLeod
50f42caf94 connect api to store and manager for domains 2026-01-27 15:43:54 +00:00
pascal
b7eeefc102 send proxy mapping updates 2026-01-27 16:34:00 +01:00
pascal
8dd22f3a4f move to reverse proxy and update api 2026-01-27 15:34:01 +01:00
pascal
4b89427447 Merge remote-tracking branch 'origin/prototype/reverse-proxy' into prototype/reverse-proxy
# Conflicts:
#	shared/management/http/api/types.gen.go
2026-01-27 15:31:15 +01:00
pascal
b71e2860cf Merge branch 'refs/heads/main' into prototype/reverse-proxy
# Conflicts:
#	management/server/activity/codes.go
#	management/server/http/handler.go
#	management/server/store/sql_store.go
#	management/server/store/store.go
#	shared/management/http/api/openapi.yml
#	shared/management/http/api/types.gen.go
#	shared/management/proto/management.pb.go
2026-01-27 15:21:55 +01:00
Alisdair MacLeod
160b27bc60 create reverse proxy domain manager and api 2026-01-27 14:18:52 +00:00
pascal
c084386b88 add docker file 2026-01-27 11:42:51 +01:00
Alisdair MacLeod
6889047350 Merge remote-tracking branch 'origin/prototype/reverse-proxy' into prototype/reverse-proxy 2026-01-27 09:58:28 +00:00
Alisdair MacLeod
245bbb4acf move domain validation to management 2026-01-27 09:58:14 +00:00
pascal
2b2fc02d83 update openapi specs 2026-01-27 10:42:19 +01:00
Alisdair MacLeod
703ef29199 start and stop netbird embedded clients in proxy 2026-01-27 08:33:44 +00:00
Alisdair MacLeod
b0b60b938a add initial setup key provisioning 2026-01-26 16:15:24 +00:00
Alisdair MacLeod
e3a026bf1c connect proxy grpc server to database 2026-01-26 15:28:50 +00:00
Alisdair MacLeod
94503465ee stub out management proxy server database connection 2026-01-26 14:47:49 +00:00
Alisdair MacLeod
8d959b0abc update management proxy gRPC server 2026-01-26 14:02:27 +00:00
Alisdair MacLeod
1d8390b935 refactor layout and structure 2026-01-26 09:28:46 +00:00
pascal
2851e38a1f add management API to store 2026-01-16 16:16:29 +01:00
pascal
51261fe7a9 proxy service proto 2026-01-16 14:48:33 +01:00
pascal
304321d019 put grpc endpoint on management and send test exposed service 2026-01-16 14:24:39 +01:00
pascal
f8c3295645 Merge branch 'main' into prototype/reverse-proxy 2026-01-16 13:07:52 +01:00
pascal
183619d1e1 cleanup 2026-01-16 12:01:52 +01:00
pascal
3b832d1f21 discard client logs 2026-01-15 17:59:07 +01:00
pascal
fcb849698f add cert manager with self signed cert support 2026-01-15 17:54:16 +01:00
pascal
7527e0ebdb use embedded netbird agent for tunneling 2026-01-15 17:03:27 +01:00
pascal
ed5f98da5b cleanup 2026-01-15 14:54:33 +01:00
pascal
12b38e25da using go http reverseproxy with OIDC auth 2026-01-14 23:53:55 +01:00
pascal
626e892e3b trying embedded caddy reverse proxy 2026-01-14 17:16:42 +01:00
422 changed files with 7701 additions and 53104 deletions

View File

@@ -1,14 +0,0 @@
blank_issues_enabled: true
contact_links:
- name: Community Support
url: https://forum.netbird.io/
about: Community support forum
- name: Cloud Support
url: https://docs.netbird.io/help/report-bug-issues
about: Contact us for support
- name: Client/Connection Troubleshooting
url: https://docs.netbird.io/help/troubleshooting-client
about: See our client troubleshooting guide for help addressing common issues
- name: Self-host Troubleshooting
url: https://docs.netbird.io/selfhosted/troubleshooting
about: See our self-host troubleshooting guide for help addressing common issues

View File

@@ -39,7 +39,7 @@ jobs:
else
echo "✓ No problematic dependencies found"
fi
done < <(find . -maxdepth 1 -type d -not -name "." -not -name "management" -not -name "signal" -not -name "relay" -not -name "proxy" -not -name "combined" -not -name ".git*" | sort)
done < <(find . -maxdepth 1 -type d -not -name "." -not -name "management" -not -name "signal" -not -name "relay" -not -name "proxy" -not -name ".git*" | sort)
echo ""
if [ $FOUND_ISSUES -eq 1 ]; then
@@ -88,7 +88,7 @@ jobs:
IMPORTERS=$(go list -json -deps ./... 2>/dev/null | jq -r "select(.Imports[]? == \"$package\") | .ImportPath")
# Check if any importer is NOT in management/signal/relay
BSD_IMPORTER=$(echo "$IMPORTERS" | grep -v "github.com/netbirdio/netbird/\(management\|signal\|relay\|proxy\|combined\)" | head -1)
BSD_IMPORTER=$(echo "$IMPORTERS" | grep -v "github.com/netbirdio/netbird/\(management\|signal\|relay\|proxy\)" | head -1)
if [ -n "$BSD_IMPORTER" ]; then
echo "❌ $package ($license) is imported by BSD-licensed code: $BSD_IMPORTER"

View File

@@ -43,5 +43,5 @@ jobs:
run: git --no-pager diff --exit-code
- name: Test
run: NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -tags=devcert -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 5m -p 1 $(go list ./... | grep -v -e /management -e /signal -e /relay -e /proxy -e /combined)
run: NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -tags=devcert -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 5m -p 1 $(go list ./... | grep -v -e /management -e /signal -e /relay -e /proxy)

View File

@@ -97,16 +97,6 @@ jobs:
working-directory: relay
run: CGO_ENABLED=1 GOARCH=386 go build -o relay-386 .
- name: Build combined
if: steps.cache.outputs.cache-hit != 'true'
working-directory: combined
run: CGO_ENABLED=1 go build .
- name: Build combined 386
if: steps.cache.outputs.cache-hit != 'true'
working-directory: combined
run: CGO_ENABLED=1 GOARCH=386 go build -o combined-386 .
test:
name: "Client / Unit"
needs: [build-cache]
@@ -154,7 +144,7 @@ jobs:
run: git --no-pager diff --exit-code
- name: Test
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} CI=true go test -tags devcert -exec 'sudo' -timeout 10m -p 1 $(go list ./... | grep -v -e /management -e /signal -e /relay -e /proxy -e /combined)
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} CI=true go test -tags devcert -exec 'sudo' -timeout 10m -p 1 $(go list ./... | grep -v -e /management -e /signal -e /relay -e /proxy)
test_client_on_docker:
name: "Client (Docker) / Unit"
@@ -214,7 +204,7 @@ jobs:
sh -c ' \
apk update; apk add --no-cache \
ca-certificates iptables ip6tables dbus dbus-dev libpcap-dev build-base; \
go test -buildvcs=false -tags devcert -v -timeout 10m -p 1 $(go list -buildvcs=false ./... | grep -v -e /management -e /signal -e /relay -e /proxy -e /combined -e /client/ui -e /upload-server)
go test -buildvcs=false -tags devcert -v -timeout 10m -p 1 $(go list -buildvcs=false ./... | grep -v -e /management -e /signal -e /relay -e /proxy -e /client/ui -e /upload-server)
'
test_relay:
@@ -409,19 +399,12 @@ jobs:
run: git --no-pager diff --exit-code
- name: Login to Docker hub
if: github.event.pull_request && github.event.pull_request.head.repo && github.event.pull_request.head.repo.full_name == '' || github.repository == github.event.pull_request.head.repo.full_name || !github.head_ref
uses: docker/login-action@v3
if: matrix.store == 'mysql' && (github.repository == github.head.repo.full_name || !github.head_ref)
uses: docker/login-action@v1
with:
username: ${{ secrets.DOCKER_USER }}
password: ${{ secrets.DOCKER_TOKEN }}
- name: docker login for root user
if: github.event.pull_request && github.event.pull_request.head.repo && github.event.pull_request.head.repo.full_name == '' || github.repository == github.event.pull_request.head.repo.full_name || !github.head_ref
env:
DOCKER_USER: ${{ secrets.DOCKER_USER }}
DOCKER_TOKEN: ${{ secrets.DOCKER_TOKEN }}
run: echo "$DOCKER_TOKEN" | sudo docker login --username "$DOCKER_USER" --password-stdin
- name: download mysql image
if: matrix.store == 'mysql'
run: docker pull mlsmaycon/warmed-mysql:8
@@ -504,18 +487,15 @@ jobs:
run: git --no-pager diff --exit-code
- name: Login to Docker hub
if: github.event.pull_request && github.event.pull_request.head.repo && github.event.pull_request.head.repo.full_name == '' || github.repository == github.event.pull_request.head.repo.full_name || !github.head_ref
uses: docker/login-action@v3
if: matrix.store == 'mysql' && (github.repository == github.head.repo.full_name || !github.head_ref)
uses: docker/login-action@v1
with:
username: ${{ secrets.DOCKER_USER }}
password: ${{ secrets.DOCKER_TOKEN }}
- name: docker login for root user
if: github.event.pull_request && github.event.pull_request.head.repo && github.event.pull_request.head.repo.full_name == '' || github.repository == github.event.pull_request.head.repo.full_name || !github.head_ref
env:
DOCKER_USER: ${{ secrets.DOCKER_USER }}
DOCKER_TOKEN: ${{ secrets.DOCKER_TOKEN }}
run: echo "$DOCKER_TOKEN" | sudo docker login --username "$DOCKER_USER" --password-stdin
- name: download mysql image
if: matrix.store == 'mysql'
run: docker pull mlsmaycon/warmed-mysql:8
- name: Test
run: |
@@ -596,18 +576,15 @@ jobs:
run: git --no-pager diff --exit-code
- name: Login to Docker hub
if: github.event.pull_request && github.event.pull_request.head.repo && github.event.pull_request.head.repo.full_name == '' || github.repository == github.event.pull_request.head.repo.full_name || !github.head_ref
uses: docker/login-action@v3
if: matrix.store == 'mysql' && (github.repository == github.head.repo.full_name || !github.head_ref)
uses: docker/login-action@v1
with:
username: ${{ secrets.DOCKER_USER }}
password: ${{ secrets.DOCKER_TOKEN }}
- name: docker login for root user
if: github.event.pull_request && github.event.pull_request.head.repo && github.event.pull_request.head.repo.full_name == '' || github.repository == github.event.pull_request.head.repo.full_name || !github.head_ref
env:
DOCKER_USER: ${{ secrets.DOCKER_USER }}
DOCKER_TOKEN: ${{ secrets.DOCKER_TOKEN }}
run: echo "$DOCKER_TOKEN" | sudo docker login --username "$DOCKER_USER" --password-stdin
- name: download mysql image
if: matrix.store == 'mysql'
run: docker pull mlsmaycon/warmed-mysql:8
- name: Test
run: |

View File

@@ -63,7 +63,7 @@ jobs:
- run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe env -w GOMODCACHE=${{ env.cache }}
- run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe env -w GOCACHE=${{ env.modcache }}
- run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe mod tidy
- run: echo "files=$(go list ./... | ForEach-Object { $_ } | Where-Object { $_ -notmatch '/management' } | Where-Object { $_ -notmatch '/relay' } | Where-Object { $_ -notmatch '/signal' } | Where-Object { $_ -notmatch '/proxy' } | Where-Object { $_ -notmatch '/combined' })" >> $env:GITHUB_ENV
- run: echo "files=$(go list ./... | ForEach-Object { $_ } | Where-Object { $_ -notmatch '/management' } | Where-Object { $_ -notmatch '/relay' } | Where-Object { $_ -notmatch '/signal' } | Where-Object { $_ -notmatch '/proxy' })" >> $env:GITHUB_ENV
- name: test
run: PsExec64 -s -w ${{ github.workspace }} cmd.exe /c "C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe test -tags=devcert -timeout 10m -p 1 ${{ env.files }} > test-out.txt 2>&1"

View File

@@ -19,7 +19,7 @@ jobs:
- name: codespell
uses: codespell-project/actions-codespell@v2
with:
ignore_words_list: erro,clienta,hastable,iif,groupd,testin,groupe,cros,ans,deriver,te
ignore_words_list: erro,clienta,hastable,iif,groupd,testin,groupe,cros,ans
skip: go.mod,go.sum,**/proxy/web/**
golangci:
strategy:

View File

@@ -1,51 +0,0 @@
name: PR Title Check
on:
pull_request:
types: [opened, edited, synchronize, reopened]
jobs:
check-title:
runs-on: ubuntu-latest
steps:
- name: Validate PR title prefix
uses: actions/github-script@v7
with:
script: |
const title = context.payload.pull_request.title;
const allowedTags = [
'management',
'client',
'signal',
'proxy',
'relay',
'misc',
'infrastructure',
'self-hosted',
'doc',
];
const pattern = /^\[([^\]]+)\]\s+.+/;
const match = title.match(pattern);
if (!match) {
core.setFailed(
`PR title must start with a tag in brackets.\n` +
`Example: [client] fix something\n` +
`Allowed tags: ${allowedTags.join(', ')}`
);
return;
}
const tags = match[1].split(',').map(t => t.trim().toLowerCase());
const invalid = tags.filter(t => !allowedTags.includes(t));
if (invalid.length > 0) {
core.setFailed(
`Invalid tag(s): ${invalid.join(', ')}\n` +
`Allowed tags: ${allowedTags.join(', ')}`
);
return;
}
console.log(`Valid PR title tags: [${tags.join(', ')}]`);

View File

@@ -10,7 +10,7 @@ on:
env:
SIGN_PIPE_VER: "v0.1.1"
GORELEASER_VER: "v2.14.3"
GORELEASER_VER: "v2.3.2"
PRODUCT_NAME: "NetBird"
COPYRIGHT: "NetBird GmbH"
@@ -160,7 +160,7 @@ jobs:
username: ${{ secrets.DOCKER_USER }}
password: ${{ secrets.DOCKER_TOKEN }}
- name: Log in to the GitHub container registry
if: github.event_name != 'pull_request' || github.event.pull_request.head.repo.full_name == github.repository
if: github.event_name != 'pull_request'
uses: docker/login-action@v3
with:
registry: ghcr.io
@@ -169,13 +169,6 @@ jobs:
- name: Install OS build dependencies
run: sudo apt update && sudo apt install -y -q gcc-arm-linux-gnueabihf gcc-aarch64-linux-gnu
- name: Decode GPG signing key
env:
GPG_RPM_PRIVATE_KEY: ${{ secrets.GPG_RPM_PRIVATE_KEY }}
run: |
echo "$GPG_RPM_PRIVATE_KEY" | base64 -d > /tmp/gpg-rpm-signing-key.asc
echo "GPG_RPM_KEY_FILE=/tmp/gpg-rpm-signing-key.asc" >> $GITHUB_ENV
- name: Install goversioninfo
run: go install github.com/josephspurrier/goversioninfo/cmd/goversioninfo@233067e
- name: Generate windows syso amd64
@@ -183,7 +176,6 @@ jobs:
- name: Generate windows syso arm64
run: goversioninfo -arm -64 -icon client/ui/assets/netbird.ico -manifest client/manifest.xml -product-name ${{ env.PRODUCT_NAME }} -copyright "${{ env.COPYRIGHT }}" -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -ver-build 0 -file-version ${{ steps.semver_parser.outputs.fullversion }}.0 -product-version ${{ steps.semver_parser.outputs.fullversion }}.0 -o client/resources_windows_arm64.syso
- name: Run GoReleaser
id: goreleaser
uses: goreleaser/goreleaser-action@v4
with:
version: ${{ env.GORELEASER_VER }}
@@ -193,55 +185,6 @@ jobs:
HOMEBREW_TAP_GITHUB_TOKEN: ${{ secrets.HOMEBREW_TAP_GITHUB_TOKEN }}
UPLOAD_DEBIAN_SECRET: ${{ secrets.PKG_UPLOAD_SECRET }}
UPLOAD_YUM_SECRET: ${{ secrets.PKG_UPLOAD_SECRET }}
GPG_RPM_KEY_FILE: ${{ env.GPG_RPM_KEY_FILE }}
NFPM_NETBIRD_RPM_PASSPHRASE: ${{ secrets.GPG_RPM_PASSPHRASE }}
- name: Verify RPM signatures
run: |
docker run --rm -v $(pwd)/dist:/dist fedora:41 bash -c '
dnf install -y -q rpm-sign curl >/dev/null 2>&1
curl -sSL https://pkgs.netbird.io/yum/repodata/repomd.xml.key -o /tmp/rpm-pub.key
rpm --import /tmp/rpm-pub.key
echo "=== Verifying RPM signatures ==="
for rpm_file in /dist/*amd64*.rpm; do
[ -f "$rpm_file" ] || continue
echo "--- $(basename $rpm_file) ---"
rpm -K "$rpm_file"
done
'
- name: Clean up GPG key
if: always()
run: rm -f /tmp/gpg-rpm-signing-key.asc
- name: Tag and push images (amd64 only)
if: |
(github.event_name == 'pull_request' && github.event.pull_request.head.repo.full_name == github.repository) ||
(github.event_name == 'push' && github.ref == 'refs/heads/main')
run: |
resolve_tags() {
if [[ "${{ github.event_name }}" == "pull_request" ]]; then
echo "pr-${{ github.event.pull_request.number }}"
else
echo "main sha-$(git rev-parse --short HEAD)"
fi
}
tag_and_push() {
local src="$1" img_name tag dst
img_name="${src%%:*}"
for tag in $(resolve_tags); do
dst="${img_name}:${tag}"
echo "Tagging ${src} -> ${dst}"
docker tag "$src" "$dst"
docker push "$dst"
done
}
export -f tag_and_push resolve_tags
echo '${{ steps.goreleaser.outputs.artifacts }}' | \
jq -r '.[] | select(.type == "Docker Image") | select(.goarch == "amd64") | .name' | \
grep '^ghcr.io/' | while read -r SRC; do
tag_and_push "$SRC"
done
- name: upload non tags for debug purposes
uses: actions/upload-artifact@v4
with:
@@ -308,13 +251,6 @@ jobs:
- name: Install dependencies
run: sudo apt update && sudo apt install -y -q libappindicator3-dev gir1.2-appindicator3-0.1 libxxf86vm-dev gcc-mingw-w64-x86-64
- name: Decode GPG signing key
env:
GPG_RPM_PRIVATE_KEY: ${{ secrets.GPG_RPM_PRIVATE_KEY }}
run: |
echo "$GPG_RPM_PRIVATE_KEY" | base64 -d > /tmp/gpg-rpm-signing-key.asc
echo "GPG_RPM_KEY_FILE=/tmp/gpg-rpm-signing-key.asc" >> $GITHUB_ENV
- name: Install LLVM-MinGW for ARM64 cross-compilation
run: |
cd /tmp
@@ -339,24 +275,6 @@ jobs:
HOMEBREW_TAP_GITHUB_TOKEN: ${{ secrets.HOMEBREW_TAP_GITHUB_TOKEN }}
UPLOAD_DEBIAN_SECRET: ${{ secrets.PKG_UPLOAD_SECRET }}
UPLOAD_YUM_SECRET: ${{ secrets.PKG_UPLOAD_SECRET }}
GPG_RPM_KEY_FILE: ${{ env.GPG_RPM_KEY_FILE }}
NFPM_NETBIRD_UI_RPM_PASSPHRASE: ${{ secrets.GPG_RPM_PASSPHRASE }}
- name: Verify RPM signatures
run: |
docker run --rm -v $(pwd)/dist:/dist fedora:41 bash -c '
dnf install -y -q rpm-sign curl >/dev/null 2>&1
curl -sSL https://pkgs.netbird.io/yum/repodata/repomd.xml.key -o /tmp/rpm-pub.key
rpm --import /tmp/rpm-pub.key
echo "=== Verifying RPM signatures ==="
for rpm_file in /dist/*.rpm; do
[ -f "$rpm_file" ] || continue
echo "--- $(basename $rpm_file) ---"
rpm -K "$rpm_file"
done
'
- name: Clean up GPG key
if: always()
run: rm -f /tmp/gpg-rpm-signing-key.asc
- name: upload non tags for debug purposes
uses: actions/upload-artifact@v4
with:

View File

@@ -106,26 +106,6 @@ builds:
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
mod_timestamp: "{{ .CommitTimestamp }}"
- id: netbird-server
dir: combined
env:
- CGO_ENABLED=1
- >-
{{- if eq .Runtime.Goos "linux" }}
{{- if eq .Arch "arm64"}}CC=aarch64-linux-gnu-gcc{{- end }}
{{- if eq .Arch "arm"}}CC=arm-linux-gnueabihf-gcc{{- end }}
{{- end }}
binary: netbird-server
goos:
- linux
goarch:
- amd64
- arm64
- arm
ldflags:
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
mod_timestamp: "{{ .CommitTimestamp }}"
- id: netbird-upload
dir: upload-server
env: [CGO_ENABLED=0]
@@ -140,20 +120,6 @@ builds:
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
mod_timestamp: "{{ .CommitTimestamp }}"
- id: netbird-proxy
dir: proxy/cmd/proxy
env: [CGO_ENABLED=0]
binary: netbird-proxy
goos:
- linux
goarch:
- amd64
- arm64
- arm
ldflags:
- -s -w -X main.Version={{.Version}} -X main.Commit={{.Commit}} -X main.BuildDate={{.CommitDate}}
mod_timestamp: "{{ .CommitTimestamp }}"
universal_binaries:
- id: netbird
@@ -171,12 +137,13 @@ nfpms:
- maintainer: Netbird <dev@netbird.io>
description: Netbird client.
homepage: https://netbird.io/
id: netbird_deb
id: netbird-deb
bindir: /usr/bin
builds:
- netbird
formats:
- deb
scripts:
postinstall: "release_files/post_install.sh"
preremove: "release_files/pre_remove.sh"
@@ -184,18 +151,16 @@ nfpms:
- maintainer: Netbird <dev@netbird.io>
description: Netbird client.
homepage: https://netbird.io/
id: netbird_rpm
id: netbird-rpm
bindir: /usr/bin
builds:
- netbird
formats:
- rpm
scripts:
postinstall: "release_files/post_install.sh"
preremove: "release_files/pre_remove.sh"
rpm:
signature:
key_file: '{{ if index .Env "GPG_RPM_KEY_FILE" }}{{ .Env.GPG_RPM_KEY_FILE }}{{ end }}'
dockers:
- image_templates:
- netbirdio/netbird:{{ .Version }}-amd64
@@ -555,104 +520,6 @@ dockers:
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
- "--label=maintainer=dev@netbird.io"
- image_templates:
- netbirdio/netbird-server:{{ .Version }}-amd64
- ghcr.io/netbirdio/netbird-server:{{ .Version }}-amd64
ids:
- netbird-server
goarch: amd64
use: buildx
dockerfile: combined/Dockerfile
build_flag_templates:
- "--platform=linux/amd64"
- "--label=org.opencontainers.image.created={{.Date}}"
- "--label=org.opencontainers.image.title={{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
- "--label=maintainer=dev@netbird.io"
- image_templates:
- netbirdio/netbird-server:{{ .Version }}-arm64v8
- ghcr.io/netbirdio/netbird-server:{{ .Version }}-arm64v8
ids:
- netbird-server
goarch: arm64
use: buildx
dockerfile: combined/Dockerfile
build_flag_templates:
- "--platform=linux/arm64"
- "--label=org.opencontainers.image.created={{.Date}}"
- "--label=org.opencontainers.image.title={{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
- "--label=maintainer=dev@netbird.io"
- image_templates:
- netbirdio/netbird-server:{{ .Version }}-arm
- ghcr.io/netbirdio/netbird-server:{{ .Version }}-arm
ids:
- netbird-server
goarch: arm
goarm: 6
use: buildx
dockerfile: combined/Dockerfile
build_flag_templates:
- "--platform=linux/arm"
- "--label=org.opencontainers.image.created={{.Date}}"
- "--label=org.opencontainers.image.title={{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
- "--label=maintainer=dev@netbird.io"
- image_templates:
- netbirdio/reverse-proxy:{{ .Version }}-amd64
- ghcr.io/netbirdio/reverse-proxy:{{ .Version }}-amd64
ids:
- netbird-proxy
goarch: amd64
use: buildx
dockerfile: proxy/Dockerfile
build_flag_templates:
- "--platform=linux/amd64"
- "--label=org.opencontainers.image.created={{.Date}}"
- "--label=org.opencontainers.image.title={{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
- "--label=maintainer=dev@netbird.io"
- image_templates:
- netbirdio/reverse-proxy:{{ .Version }}-arm64v8
- ghcr.io/netbirdio/reverse-proxy:{{ .Version }}-arm64v8
ids:
- netbird-proxy
goarch: arm64
use: buildx
dockerfile: proxy/Dockerfile
build_flag_templates:
- "--platform=linux/arm64"
- "--label=org.opencontainers.image.created={{.Date}}"
- "--label=org.opencontainers.image.title={{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
- "--label=maintainer=dev@netbird.io"
- image_templates:
- netbirdio/reverse-proxy:{{ .Version }}-arm
- ghcr.io/netbirdio/reverse-proxy:{{ .Version }}-arm
ids:
- netbird-proxy
goarch: arm
goarm: 6
use: buildx
dockerfile: proxy/Dockerfile
build_flag_templates:
- "--platform=linux/arm"
- "--label=org.opencontainers.image.created={{.Date}}"
- "--label=org.opencontainers.image.title={{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
- "--label=maintainer=dev@netbird.io"
docker_manifests:
- name_template: netbirdio/netbird:{{ .Version }}
image_templates:
@@ -731,18 +598,6 @@ docker_manifests:
- netbirdio/upload:{{ .Version }}-arm
- netbirdio/upload:{{ .Version }}-amd64
- name_template: netbirdio/netbird-server:{{ .Version }}
image_templates:
- netbirdio/netbird-server:{{ .Version }}-arm64v8
- netbirdio/netbird-server:{{ .Version }}-arm
- netbirdio/netbird-server:{{ .Version }}-amd64
- name_template: netbirdio/netbird-server:latest
image_templates:
- netbirdio/netbird-server:{{ .Version }}-arm64v8
- netbirdio/netbird-server:{{ .Version }}-arm
- netbirdio/netbird-server:{{ .Version }}-amd64
- name_template: ghcr.io/netbirdio/netbird:{{ .Version }}
image_templates:
- ghcr.io/netbirdio/netbird:{{ .Version }}-arm64v8
@@ -820,43 +675,6 @@ docker_manifests:
- ghcr.io/netbirdio/upload:{{ .Version }}-arm64v8
- ghcr.io/netbirdio/upload:{{ .Version }}-arm
- ghcr.io/netbirdio/upload:{{ .Version }}-amd64
- name_template: ghcr.io/netbirdio/netbird-server:{{ .Version }}
image_templates:
- ghcr.io/netbirdio/netbird-server:{{ .Version }}-arm64v8
- ghcr.io/netbirdio/netbird-server:{{ .Version }}-arm
- ghcr.io/netbirdio/netbird-server:{{ .Version }}-amd64
- name_template: ghcr.io/netbirdio/netbird-server:latest
image_templates:
- ghcr.io/netbirdio/netbird-server:{{ .Version }}-arm64v8
- ghcr.io/netbirdio/netbird-server:{{ .Version }}-arm
- ghcr.io/netbirdio/netbird-server:{{ .Version }}-amd64
- name_template: netbirdio/reverse-proxy:{{ .Version }}
image_templates:
- netbirdio/reverse-proxy:{{ .Version }}-arm64v8
- netbirdio/reverse-proxy:{{ .Version }}-arm
- netbirdio/reverse-proxy:{{ .Version }}-amd64
- name_template: netbirdio/reverse-proxy:latest
image_templates:
- netbirdio/reverse-proxy:{{ .Version }}-arm64v8
- netbirdio/reverse-proxy:{{ .Version }}-arm
- netbirdio/reverse-proxy:{{ .Version }}-amd64
- name_template: ghcr.io/netbirdio/reverse-proxy:{{ .Version }}
image_templates:
- ghcr.io/netbirdio/reverse-proxy:{{ .Version }}-arm64v8
- ghcr.io/netbirdio/reverse-proxy:{{ .Version }}-arm
- ghcr.io/netbirdio/reverse-proxy:{{ .Version }}-amd64
- name_template: ghcr.io/netbirdio/reverse-proxy:latest
image_templates:
- ghcr.io/netbirdio/reverse-proxy:{{ .Version }}-arm64v8
- ghcr.io/netbirdio/reverse-proxy:{{ .Version }}-arm
- ghcr.io/netbirdio/reverse-proxy:{{ .Version }}-amd64
brews:
- ids:
- default
@@ -877,7 +695,7 @@ brews:
uploads:
- name: debian
ids:
- netbird_deb
- netbird-deb
mode: archive
target: https://pkgs.wiretrustee.com/debian/pool/{{ .ArtifactName }};deb.distribution=stable;deb.component=main;deb.architecture={{ if .Arm }}armhf{{ else }}{{ .Arch }}{{ end }};deb.package=
username: dev@wiretrustee.com
@@ -885,7 +703,7 @@ uploads:
- name: yum
ids:
- netbird_rpm
- netbird-rpm
mode: archive
target: https://pkgs.wiretrustee.com/yum/{{ .Arch }}{{ if .Arm }}{{ .Arm }}{{ end }}
username: dev@wiretrustee.com

View File

@@ -61,7 +61,7 @@ nfpms:
- maintainer: Netbird <dev@netbird.io>
description: Netbird client UI.
homepage: https://netbird.io/
id: netbird_ui_deb
id: netbird-ui-deb
package_name: netbird-ui
builds:
- netbird-ui
@@ -80,7 +80,7 @@ nfpms:
- maintainer: Netbird <dev@netbird.io>
description: Netbird client UI.
homepage: https://netbird.io/
id: netbird_ui_rpm
id: netbird-ui-rpm
package_name: netbird-ui
builds:
- netbird-ui
@@ -95,14 +95,11 @@ nfpms:
dst: /usr/share/pixmaps/netbird.png
dependencies:
- netbird
rpm:
signature:
key_file: '{{ if index .Env "GPG_RPM_KEY_FILE" }}{{ .Env.GPG_RPM_KEY_FILE }}{{ end }}'
uploads:
- name: debian
ids:
- netbird_ui_deb
- netbird-ui-deb
mode: archive
target: https://pkgs.wiretrustee.com/debian/pool/{{ .ArtifactName }};deb.distribution=stable;deb.component=main;deb.architecture={{ if .Arm }}armhf{{ else }}{{ .Arch }}{{ end }};deb.package=
username: dev@wiretrustee.com
@@ -110,7 +107,7 @@ uploads:
- name: yum
ids:
- netbird_ui_rpm
- netbird-ui-rpm
mode: archive
target: https://pkgs.wiretrustee.com/yum/{{ .Arch }}{{ if .Arm }}{{ .Arm }}{{ end }}
username: dev@wiretrustee.com

View File

@@ -1,4 +1,4 @@
This BSD3Clause license applies to all parts of the repository except for the directories management/, signal/, relay/ and combined/.
This BSD3Clause license applies to all parts of the repository except for the directories management/, signal/ and relay/.
Those directories are licensed under the GNU Affero General Public License version 3.0 (AGPLv3). See the respective LICENSE files inside each directory.
BSD 3-Clause License

View File

@@ -126,7 +126,6 @@ See a complete [architecture overview](https://docs.netbird.io/about-netbird/how
### Community projects
- [NetBird installer script](https://github.com/physk/netbird-installer)
- [NetBird ansible collection by Dominion Solutions](https://galaxy.ansible.com/ui/repo/published/dominion_solutions/netbird/)
- [netbird-tui](https://github.com/n0pashkov/netbird-tui) — terminal UI for managing NetBird peers, routes, and settings
**Note**: The `main` branch may be in an *unstable or even broken state* during development.
For stable versions, see [releases](https://github.com/netbirdio/netbird/releases).

View File

@@ -4,7 +4,7 @@
# sudo podman build -t localhost/netbird:latest -f client/Dockerfile --ignorefile .dockerignore-client .
# sudo podman run --rm -it --cap-add={BPF,NET_ADMIN,NET_RAW} localhost/netbird:latest
FROM alpine:3.23.3
FROM alpine:3.23.2
# iproute2: busybox doesn't display ip rules properly
RUN apk add --no-cache \
bash \

View File

@@ -124,7 +124,7 @@ func (c *Client) Run(platformFiles PlatformFiles, urlOpener URLOpener, isAndroid
// todo do not throw error in case of cancelled context
ctx = internal.CtxInitState(ctx)
c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder)
c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder, false)
return c.connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, slices.Clone(dns.items), dnsReadyListener, stateFile)
}
@@ -157,7 +157,7 @@ func (c *Client) RunWithoutLogin(platformFiles PlatformFiles, dns *DNSList, dnsR
// todo do not throw error in case of cancelled context
ctx = internal.CtxInitState(ctx)
c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder)
c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder, false)
return c.connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, slices.Clone(dns.items), dnsReadyListener, stateFile)
}

View File

@@ -1,19 +1,10 @@
package android
import (
"github.com/netbirdio/netbird/client/internal/lazyconn"
"github.com/netbirdio/netbird/client/internal/peer"
)
import "github.com/netbirdio/netbird/client/internal/peer"
var (
// EnvKeyNBForceRelay Exported for Android java client to force relay connections
// EnvKeyNBForceRelay Exported for Android java client
EnvKeyNBForceRelay = peer.EnvKeyNBForceRelay
// EnvKeyNBLazyConn Exported for Android java client to configure lazy connection
EnvKeyNBLazyConn = lazyconn.EnvEnableLazyConn
// EnvKeyNBInactivityThreshold Exported for Android java client to configure connection inactivity threshold
EnvKeyNBInactivityThreshold = lazyconn.EnvInactivityThreshold
)
// EnvList wraps a Go map for export to Java

View File

@@ -1,280 +0,0 @@
package cmd
import (
"context"
"errors"
"fmt"
"io"
"os"
"os/signal"
"regexp"
"strconv"
"strings"
"syscall"
log "github.com/sirupsen/logrus"
"github.com/spf13/cobra"
"github.com/netbirdio/netbird/client/proto"
"github.com/netbirdio/netbird/util"
)
var pinRegexp = regexp.MustCompile(`^\d{6}$`)
var (
exposePin string
exposePassword string
exposeUserGroups []string
exposeDomain string
exposeNamePrefix string
exposeProtocol string
exposeExternalPort uint16
)
var exposeCmd = &cobra.Command{
Use: "expose <port>",
Short: "Expose a local port via the NetBird reverse proxy",
Args: cobra.ExactArgs(1),
Example: ` netbird expose --with-password safe-pass 8080
netbird expose --protocol tcp 5432
netbird expose --protocol tcp --with-external-port 5433 5432
netbird expose --protocol tls --with-custom-domain tls.example.com 4443`,
RunE: exposeFn,
}
func init() {
exposeCmd.Flags().StringVar(&exposePin, "with-pin", "", "Protect the exposed service with a 6-digit PIN (e.g. --with-pin 123456)")
exposeCmd.Flags().StringVar(&exposePassword, "with-password", "", "Protect the exposed service with a password (e.g. --with-password my-secret)")
exposeCmd.Flags().StringSliceVar(&exposeUserGroups, "with-user-groups", nil, "Restrict access to specific user groups with SSO (e.g. --with-user-groups devops,Backend)")
exposeCmd.Flags().StringVar(&exposeDomain, "with-custom-domain", "", "Custom domain for the exposed service, must be configured to your account (e.g. --with-custom-domain myapp.example.com)")
exposeCmd.Flags().StringVar(&exposeNamePrefix, "with-name-prefix", "", "Prefix for the generated service name (e.g. --with-name-prefix my-app)")
exposeCmd.Flags().StringVar(&exposeProtocol, "protocol", "http", "Protocol to use: http, https, tcp, udp, or tls (e.g. --protocol tcp)")
exposeCmd.Flags().Uint16Var(&exposeExternalPort, "with-external-port", 0, "Public-facing external port on the proxy cluster (defaults to the target port for L4)")
}
// isClusterProtocol returns true for L4/TLS protocols that reject HTTP-style auth flags.
func isClusterProtocol(protocol string) bool {
switch strings.ToLower(protocol) {
case "tcp", "udp", "tls":
return true
default:
return false
}
}
// isPortBasedProtocol returns true for pure port-based protocols (TCP/UDP)
// where domain display doesn't apply. TLS uses SNI so it has a domain.
func isPortBasedProtocol(protocol string) bool {
switch strings.ToLower(protocol) {
case "tcp", "udp":
return true
default:
return false
}
}
// extractPort returns the port portion of a URL like "tcp://host:12345", or
// falls back to the given default formatted as a string.
func extractPort(serviceURL string, fallback uint16) string {
u := serviceURL
if idx := strings.Index(u, "://"); idx != -1 {
u = u[idx+3:]
}
if i := strings.LastIndex(u, ":"); i != -1 {
if p := u[i+1:]; p != "" {
return p
}
}
return strconv.FormatUint(uint64(fallback), 10)
}
// resolveExternalPort returns the effective external port, defaulting to the target port.
func resolveExternalPort(targetPort uint64) uint16 {
if exposeExternalPort != 0 {
return exposeExternalPort
}
return uint16(targetPort)
}
func validateExposeFlags(cmd *cobra.Command, portStr string) (uint64, error) {
port, err := strconv.ParseUint(portStr, 10, 32)
if err != nil {
return 0, fmt.Errorf("invalid port number: %s", portStr)
}
if port == 0 || port > 65535 {
return 0, fmt.Errorf("invalid port number: must be between 1 and 65535")
}
if !isProtocolValid(exposeProtocol) {
return 0, fmt.Errorf("unsupported protocol %q: must be http, https, tcp, udp, or tls", exposeProtocol)
}
if isClusterProtocol(exposeProtocol) {
if exposePin != "" || exposePassword != "" || len(exposeUserGroups) > 0 {
return 0, fmt.Errorf("auth flags (--with-pin, --with-password, --with-user-groups) are not supported for %s protocol", exposeProtocol)
}
} else if cmd.Flags().Changed("with-external-port") {
return 0, fmt.Errorf("--with-external-port is not supported for %s protocol", exposeProtocol)
}
if exposePin != "" && !pinRegexp.MatchString(exposePin) {
return 0, fmt.Errorf("invalid pin: must be exactly 6 digits")
}
if cmd.Flags().Changed("with-password") && exposePassword == "" {
return 0, fmt.Errorf("password cannot be empty")
}
if cmd.Flags().Changed("with-user-groups") && len(exposeUserGroups) == 0 {
return 0, fmt.Errorf("user groups cannot be empty")
}
return port, nil
}
func isProtocolValid(exposeProtocol string) bool {
switch strings.ToLower(exposeProtocol) {
case "http", "https", "tcp", "udp", "tls":
return true
default:
return false
}
}
func exposeFn(cmd *cobra.Command, args []string) error {
SetFlagsFromEnvVars(rootCmd)
if err := util.InitLog(logLevel, util.LogConsole); err != nil {
log.Errorf("failed initializing log %v", err)
return err
}
cmd.Root().SilenceUsage = false
port, err := validateExposeFlags(cmd, args[0])
if err != nil {
return err
}
cmd.Root().SilenceUsage = true
ctx, cancel := context.WithCancel(cmd.Context())
defer cancel()
sigCh := make(chan os.Signal, 1)
signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM)
go func() {
<-sigCh
cancel()
}()
conn, err := DialClientGRPCServer(ctx, daemonAddr)
if err != nil {
return fmt.Errorf("connect to daemon: %w", err)
}
defer func() {
if err := conn.Close(); err != nil {
log.Debugf("failed to close daemon connection: %v", err)
}
}()
client := proto.NewDaemonServiceClient(conn)
protocol, err := toExposeProtocol(exposeProtocol)
if err != nil {
return err
}
req := &proto.ExposeServiceRequest{
Port: uint32(port),
Protocol: protocol,
Pin: exposePin,
Password: exposePassword,
UserGroups: exposeUserGroups,
Domain: exposeDomain,
NamePrefix: exposeNamePrefix,
}
if isClusterProtocol(exposeProtocol) {
req.ListenPort = uint32(resolveExternalPort(port))
}
stream, err := client.ExposeService(ctx, req)
if err != nil {
return fmt.Errorf("expose service: %w", err)
}
if err := handleExposeReady(cmd, stream, port); err != nil {
return err
}
return waitForExposeEvents(cmd, ctx, stream)
}
func toExposeProtocol(exposeProtocol string) (proto.ExposeProtocol, error) {
switch strings.ToLower(exposeProtocol) {
case "http":
return proto.ExposeProtocol_EXPOSE_HTTP, nil
case "https":
return proto.ExposeProtocol_EXPOSE_HTTPS, nil
case "tcp":
return proto.ExposeProtocol_EXPOSE_TCP, nil
case "udp":
return proto.ExposeProtocol_EXPOSE_UDP, nil
case "tls":
return proto.ExposeProtocol_EXPOSE_TLS, nil
default:
return 0, fmt.Errorf("unsupported protocol %q: must be http, https, tcp, udp, or tls", exposeProtocol)
}
}
func handleExposeReady(cmd *cobra.Command, stream proto.DaemonService_ExposeServiceClient, port uint64) error {
event, err := stream.Recv()
if err != nil {
return fmt.Errorf("receive expose event: %w", err)
}
ready, ok := event.Event.(*proto.ExposeServiceEvent_Ready)
if !ok {
return fmt.Errorf("unexpected expose event: %T", event.Event)
}
printExposeReady(cmd, ready.Ready, port)
return nil
}
func printExposeReady(cmd *cobra.Command, r *proto.ExposeServiceReady, port uint64) {
cmd.Println("Service exposed successfully!")
cmd.Printf(" Name: %s\n", r.ServiceName)
if r.ServiceUrl != "" {
cmd.Printf(" URL: %s\n", r.ServiceUrl)
}
if r.Domain != "" && !isPortBasedProtocol(exposeProtocol) {
cmd.Printf(" Domain: %s\n", r.Domain)
}
cmd.Printf(" Protocol: %s\n", exposeProtocol)
cmd.Printf(" Internal: %d\n", port)
if isClusterProtocol(exposeProtocol) {
cmd.Printf(" External: %s\n", extractPort(r.ServiceUrl, resolveExternalPort(port)))
}
if r.PortAutoAssigned && exposeExternalPort != 0 {
cmd.Printf("\n Note: requested port %d was reassigned\n", exposeExternalPort)
}
cmd.Println()
cmd.Println("Press Ctrl+C to stop exposing.")
}
func waitForExposeEvents(cmd *cobra.Command, ctx context.Context, stream proto.DaemonService_ExposeServiceClient) error {
for {
_, err := stream.Recv()
if err != nil {
if ctx.Err() != nil {
cmd.Println("\nService stopped.")
//nolint:nilerr
return nil
}
if errors.Is(err, io.EOF) {
return fmt.Errorf("connection to daemon closed unexpectedly")
}
return fmt.Errorf("stream error: %w", err)
}
}
}

View File

@@ -22,7 +22,6 @@ import (
"google.golang.org/grpc"
"google.golang.org/grpc/credentials/insecure"
daddr "github.com/netbirdio/netbird/client/internal/daemonaddr"
"github.com/netbirdio/netbird/client/internal/profilemanager"
)
@@ -81,15 +80,6 @@ var (
Short: "",
Long: "",
SilenceUsage: true,
PersistentPreRunE: func(cmd *cobra.Command, args []string) error {
SetFlagsFromEnvVars(cmd.Root())
// Don't resolve for service commands — they create the socket, not connect to it.
if !isServiceCmd(cmd) {
daemonAddr = daddr.ResolveUnixDaemonAddr(daemonAddr)
}
return nil
},
}
)
@@ -154,7 +144,6 @@ func init() {
rootCmd.AddCommand(forwardingRulesCmd)
rootCmd.AddCommand(debugCmd)
rootCmd.AddCommand(profileCmd)
rootCmd.AddCommand(exposeCmd)
networksCMD.AddCommand(routesListCmd)
networksCMD.AddCommand(routesSelectCmd, routesDeselectCmd)
@@ -396,6 +385,7 @@ func migrateToNetbird(oldPath, newPath string) bool {
}
func getClient(cmd *cobra.Command) (*grpc.ClientConn, error) {
SetFlagsFromEnvVars(rootCmd)
cmd.SetOut(cmd.OutOrStdout())
conn, err := DialClientGRPCServer(cmd.Context(), daemonAddr)
@@ -408,13 +398,3 @@ func getClient(cmd *cobra.Command) (*grpc.ClientConn, error) {
return conn, nil
}
// isServiceCmd returns true if cmd is the "service" command or a child of it.
func isServiceCmd(cmd *cobra.Command) bool {
for c := cmd; c != nil; c = c.Parent() {
if c.Name() == "service" {
return true
}
}
return false
}

View File

@@ -7,7 +7,7 @@ import (
"github.com/spf13/cobra"
"github.com/netbirdio/netbird/client/internal/updater/reposign"
"github.com/netbirdio/netbird/client/internal/updatemanager/reposign"
)
var (

View File

@@ -6,7 +6,7 @@ import (
"github.com/spf13/cobra"
"github.com/netbirdio/netbird/client/internal/updater/reposign"
"github.com/netbirdio/netbird/client/internal/updatemanager/reposign"
)
const (

View File

@@ -7,7 +7,7 @@ import (
"github.com/spf13/cobra"
"github.com/netbirdio/netbird/client/internal/updater/reposign"
"github.com/netbirdio/netbird/client/internal/updatemanager/reposign"
)
const (

View File

@@ -7,7 +7,7 @@ import (
"github.com/spf13/cobra"
"github.com/netbirdio/netbird/client/internal/updater/reposign"
"github.com/netbirdio/netbird/client/internal/updatemanager/reposign"
)
var (

View File

@@ -197,7 +197,7 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command, activeProf *pr
r := peer.NewRecorder(config.ManagementURL.String())
r.GetFullStatus()
connectClient := internal.NewConnectClient(ctx, config, r)
connectClient := internal.NewConnectClient(ctx, config, r, false)
SetupDebugHandler(ctx, config, r, connectClient, "")
return connectClient.Run(nil, util.FindFirstLogPath(logFiles))

View File

@@ -11,7 +11,7 @@ import (
log "github.com/sirupsen/logrus"
"github.com/spf13/cobra"
"github.com/netbirdio/netbird/client/internal/updater/installer"
"github.com/netbirdio/netbird/client/internal/updatemanager/installer"
"github.com/netbirdio/netbird/util"
)

View File

@@ -14,7 +14,6 @@ import (
"github.com/sirupsen/logrus"
wgnetstack "golang.zx2c4.com/wireguard/tun/netstack"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/netstack"
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/auth"
@@ -82,12 +81,6 @@ type Options struct {
BlockInbound bool
// WireguardPort is the port for the WireGuard interface. Use 0 for a random port.
WireguardPort *int
// MTU is the MTU for the WireGuard interface.
// Valid values are in the range 576..8192 bytes.
// If non-nil, this value overrides any value stored in the config file.
// If nil, the existing config MTU (if non-zero) is preserved; otherwise it defaults to 1280.
// Set to a higher value (e.g. 1400) if carrying QUIC or other protocols that require larger datagrams.
MTU *uint16
}
// validateCredentials checks that exactly one credential type is provided
@@ -119,12 +112,6 @@ func New(opts Options) (*Client, error) {
return nil, err
}
if opts.MTU != nil {
if err := iface.ValidateMTU(*opts.MTU); err != nil {
return nil, fmt.Errorf("invalid MTU: %w", err)
}
}
if opts.LogOutput != nil {
logrus.SetOutput(opts.LogOutput)
}
@@ -164,7 +151,6 @@ func New(opts Options) (*Client, error) {
DisableClientRoutes: &opts.DisableClientRoutes,
BlockInbound: &opts.BlockInbound,
WireguardPort: opts.WireguardPort,
MTU: opts.MTU,
}
if opts.ConfigPath != "" {
config, err = profilemanager.UpdateOrCreateConfig(input)
@@ -216,7 +202,7 @@ func (c *Client) Start(startCtx context.Context) error {
if err, _ := authClient.Login(ctx, c.setupKey, c.jwtToken); err != nil {
return fmt.Errorf("login: %w", err)
}
client := internal.NewConnectClient(ctx, c.config, c.recorder)
client := internal.NewConnectClient(ctx, c.config, c.recorder, false)
client.SetSyncResponsePersistence(true)
// either startup error (permanent backoff err) or nil err (successful engine up)

View File

@@ -358,9 +358,9 @@ func incrementalUpdate(oldChecksum uint16, oldBytes, newBytes []byte) uint16 {
// Fast path for IPv4 addresses (4 bytes) - most common case
if len(oldBytes) == 4 && len(newBytes) == 4 {
sum += uint32(^binary.BigEndian.Uint16(oldBytes[0:2]))
sum += uint32(^binary.BigEndian.Uint16(oldBytes[2:4])) //nolint:gosec // length checked above
sum += uint32(^binary.BigEndian.Uint16(oldBytes[2:4]))
sum += uint32(binary.BigEndian.Uint16(newBytes[0:2]))
sum += uint32(binary.BigEndian.Uint16(newBytes[2:4])) //nolint:gosec // length checked above
sum += uint32(binary.BigEndian.Uint16(newBytes[2:4]))
} else {
// Fallback for other lengths
for i := 0; i < len(oldBytes)-1; i += 2 {

View File

@@ -5,18 +5,20 @@ package configurer
import (
"net"
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/ipc"
)
func openUAPI(deviceName string) (net.Listener, error) {
uapiSock, err := ipc.UAPIOpen(deviceName)
if err != nil {
log.Errorf("failed to open uapi socket: %v", err)
return nil, err
}
listener, err := ipc.UAPIListen(deviceName, uapiSock)
if err != nil {
_ = uapiSock.Close()
log.Errorf("failed to listen on uapi socket: %v", err)
return nil, err
}

View File

@@ -54,14 +54,6 @@ func NewUSPConfigurer(device *device.Device, deviceName string, activityRecorder
return wgCfg
}
func NewUSPConfigurerNoUAPI(device *device.Device, deviceName string, activityRecorder *bind.ActivityRecorder) *WGUSPConfigurer {
return &WGUSPConfigurer{
device: device,
deviceName: deviceName,
activityRecorder: activityRecorder,
}
}
func (c *WGUSPConfigurer) ConfigureInterface(privateKey string, port int) error {
log.Debugf("adding Wireguard private key")
key, err := wgtypes.ParseKey(privateKey)

View File

@@ -79,7 +79,7 @@ func (t *TunNetstackDevice) create() (WGConfigurer, error) {
device.NewLogger(wgLogLevel(), "[netbird] "),
)
t.configurer = configurer.NewUSPConfigurerNoUAPI(t.device, t.name, t.bind.ActivityRecorder())
t.configurer = configurer.NewUSPConfigurer(t.device, t.name, t.bind.ActivityRecorder())
err = t.configurer.ConfigureInterface(t.key, t.port)
if err != nil {
if cErr := tunIface.Close(); cErr != nil {

View File

@@ -27,8 +27,8 @@ import (
"github.com/netbirdio/netbird/client/internal/profilemanager"
"github.com/netbirdio/netbird/client/internal/statemanager"
"github.com/netbirdio/netbird/client/internal/stdnet"
"github.com/netbirdio/netbird/client/internal/updater"
"github.com/netbirdio/netbird/client/internal/updater/installer"
"github.com/netbirdio/netbird/client/internal/updatemanager"
"github.com/netbirdio/netbird/client/internal/updatemanager/installer"
nbnet "github.com/netbirdio/netbird/client/net"
cProto "github.com/netbirdio/netbird/client/proto"
"github.com/netbirdio/netbird/client/ssh"
@@ -44,13 +44,13 @@ import (
)
type ConnectClient struct {
ctx context.Context
config *profilemanager.Config
statusRecorder *peer.Status
ctx context.Context
config *profilemanager.Config
statusRecorder *peer.Status
doInitialAutoUpdate bool
engine *Engine
engineMutex sync.Mutex
updateManager *updater.Manager
engine *Engine
engineMutex sync.Mutex
persistSyncResponse bool
}
@@ -59,19 +59,17 @@ func NewConnectClient(
ctx context.Context,
config *profilemanager.Config,
statusRecorder *peer.Status,
doInitalAutoUpdate bool,
) *ConnectClient {
return &ConnectClient{
ctx: ctx,
config: config,
statusRecorder: statusRecorder,
engineMutex: sync.Mutex{},
ctx: ctx,
config: config,
statusRecorder: statusRecorder,
doInitialAutoUpdate: doInitalAutoUpdate,
engineMutex: sync.Mutex{},
}
}
func (c *ConnectClient) SetUpdateManager(um *updater.Manager) {
c.updateManager = um
}
// Run with main logic.
func (c *ConnectClient) Run(runningChan chan struct{}, logPath string) error {
return c.run(MobileDependency{}, runningChan, logPath)
@@ -189,13 +187,14 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
stateManager := statemanager.New(path)
stateManager.RegisterState(&sshconfig.ShutdownState{})
if c.updateManager != nil {
c.updateManager.CheckUpdateSuccess(c.ctx)
}
updateManager, err := updatemanager.NewManager(c.statusRecorder, stateManager)
if err == nil {
updateManager.CheckUpdateSuccess(c.ctx)
inst := installer.New()
if err := inst.CleanUpInstallerFiles(); err != nil {
log.Errorf("failed to clean up temporary installer file: %v", err)
inst := installer.New()
if err := inst.CleanUpInstallerFiles(); err != nil {
log.Errorf("failed to clean up temporary installer file: %v", err)
}
}
defer c.statusRecorder.ClientStop()
@@ -309,15 +308,7 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
checks := loginResp.GetChecks()
c.engineMutex.Lock()
engine := NewEngine(engineCtx, cancel, engineConfig, EngineServices{
SignalClient: signalClient,
MgmClient: mgmClient,
RelayManager: relayManager,
StatusRecorder: c.statusRecorder,
Checks: checks,
StateManager: stateManager,
UpdateManager: c.updateManager,
}, mobileDependency)
engine := NewEngine(engineCtx, cancel, signalClient, mgmClient, relayManager, engineConfig, mobileDependency, c.statusRecorder, checks, stateManager)
engine.SetSyncResponsePersistence(c.persistSyncResponse)
c.engine = engine
c.engineMutex.Unlock()
@@ -327,15 +318,21 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
return wrapErr(err)
}
if loginResp.PeerConfig != nil && loginResp.PeerConfig.AutoUpdate != nil {
// AutoUpdate will be true when the user click on "Connect" menu on the UI
if c.doInitialAutoUpdate {
log.Infof("start engine by ui, run auto-update check")
c.engine.InitialUpdateHandling(loginResp.PeerConfig.AutoUpdate)
c.doInitialAutoUpdate = false
}
}
log.Infof("Netbird engine started, the IP is: %s", peerConfig.GetAddress())
state.Set(StatusConnected)
if runningChan != nil {
select {
case <-runningChan:
default:
close(runningChan)
}
close(runningChan)
runningChan = nil
}
<-engineCtx.Done()

View File

@@ -1,60 +0,0 @@
//go:build !windows && !ios && !android
package daemonaddr
import (
"os"
"path/filepath"
"strings"
log "github.com/sirupsen/logrus"
)
var scanDir = "/var/run/netbird"
// setScanDir overrides the scan directory (used by tests).
func setScanDir(dir string) {
scanDir = dir
}
// ResolveUnixDaemonAddr checks whether the default Unix socket exists and, if not,
// scans /var/run/netbird/ for a single .sock file to use instead. This handles the
// mismatch between the netbird@.service template (which places the socket under
// /var/run/netbird/<instance>.sock) and the CLI default (/var/run/netbird.sock).
func ResolveUnixDaemonAddr(addr string) string {
if !strings.HasPrefix(addr, "unix://") {
return addr
}
sockPath := strings.TrimPrefix(addr, "unix://")
if _, err := os.Stat(sockPath); err == nil {
return addr
}
entries, err := os.ReadDir(scanDir)
if err != nil {
return addr
}
var found []string
for _, e := range entries {
if e.IsDir() {
continue
}
if strings.HasSuffix(e.Name(), ".sock") {
found = append(found, filepath.Join(scanDir, e.Name()))
}
}
switch len(found) {
case 1:
resolved := "unix://" + found[0]
log.Debugf("Default daemon socket not found, using discovered socket: %s", resolved)
return resolved
case 0:
return addr
default:
log.Warnf("Default daemon socket not found and multiple sockets discovered in %s; pass --daemon-addr explicitly", scanDir)
return addr
}
}

View File

@@ -1,8 +0,0 @@
//go:build windows || ios || android
package daemonaddr
// ResolveUnixDaemonAddr is a no-op on platforms that don't use Unix sockets.
func ResolveUnixDaemonAddr(addr string) string {
return addr
}

View File

@@ -1,121 +0,0 @@
//go:build !windows && !ios && !android
package daemonaddr
import (
"os"
"path/filepath"
"testing"
)
// createSockFile creates a regular file with a .sock extension.
// ResolveUnixDaemonAddr uses os.Stat (not net.Dial), so a regular file is
// sufficient and avoids Unix socket path-length limits on macOS.
func createSockFile(t *testing.T, path string) {
t.Helper()
if err := os.WriteFile(path, nil, 0o600); err != nil {
t.Fatalf("failed to create test sock file at %s: %v", path, err)
}
}
func TestResolveUnixDaemonAddr_DefaultExists(t *testing.T) {
tmp := t.TempDir()
sock := filepath.Join(tmp, "netbird.sock")
createSockFile(t, sock)
addr := "unix://" + sock
got := ResolveUnixDaemonAddr(addr)
if got != addr {
t.Errorf("expected %s, got %s", addr, got)
}
}
func TestResolveUnixDaemonAddr_SingleDiscovered(t *testing.T) {
tmp := t.TempDir()
// Default socket does not exist
defaultAddr := "unix://" + filepath.Join(tmp, "netbird.sock")
// Create a scan dir with one socket
sd := filepath.Join(tmp, "netbird")
if err := os.MkdirAll(sd, 0o755); err != nil {
t.Fatal(err)
}
instanceSock := filepath.Join(sd, "main.sock")
createSockFile(t, instanceSock)
origScanDir := scanDir
setScanDir(sd)
t.Cleanup(func() { setScanDir(origScanDir) })
got := ResolveUnixDaemonAddr(defaultAddr)
expected := "unix://" + instanceSock
if got != expected {
t.Errorf("expected %s, got %s", expected, got)
}
}
func TestResolveUnixDaemonAddr_MultipleDiscovered(t *testing.T) {
tmp := t.TempDir()
defaultAddr := "unix://" + filepath.Join(tmp, "netbird.sock")
sd := filepath.Join(tmp, "netbird")
if err := os.MkdirAll(sd, 0o755); err != nil {
t.Fatal(err)
}
createSockFile(t, filepath.Join(sd, "main.sock"))
createSockFile(t, filepath.Join(sd, "other.sock"))
origScanDir := scanDir
setScanDir(sd)
t.Cleanup(func() { setScanDir(origScanDir) })
got := ResolveUnixDaemonAddr(defaultAddr)
if got != defaultAddr {
t.Errorf("expected original %s, got %s", defaultAddr, got)
}
}
func TestResolveUnixDaemonAddr_NoSocketsFound(t *testing.T) {
tmp := t.TempDir()
defaultAddr := "unix://" + filepath.Join(tmp, "netbird.sock")
sd := filepath.Join(tmp, "netbird")
if err := os.MkdirAll(sd, 0o755); err != nil {
t.Fatal(err)
}
origScanDir := scanDir
setScanDir(sd)
t.Cleanup(func() { setScanDir(origScanDir) })
got := ResolveUnixDaemonAddr(defaultAddr)
if got != defaultAddr {
t.Errorf("expected original %s, got %s", defaultAddr, got)
}
}
func TestResolveUnixDaemonAddr_NonUnixAddr(t *testing.T) {
addr := "tcp://127.0.0.1:41731"
got := ResolveUnixDaemonAddr(addr)
if got != addr {
t.Errorf("expected %s, got %s", addr, got)
}
}
func TestResolveUnixDaemonAddr_ScanDirMissing(t *testing.T) {
tmp := t.TempDir()
defaultAddr := "unix://" + filepath.Join(tmp, "netbird.sock")
origScanDir := scanDir
setScanDir(filepath.Join(tmp, "nonexistent"))
t.Cleanup(func() { setScanDir(origScanDir) })
got := ResolveUnixDaemonAddr(defaultAddr)
if got != defaultAddr {
t.Errorf("expected original %s, got %s", defaultAddr, got)
}
}

View File

@@ -27,7 +27,7 @@ import (
"github.com/netbirdio/netbird/client/anonymize"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/profilemanager"
"github.com/netbirdio/netbird/client/internal/updater/installer"
"github.com/netbirdio/netbird/client/internal/updatemanager/installer"
nbstatus "github.com/netbirdio/netbird/client/status"
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
"github.com/netbirdio/netbird/util"

View File

@@ -14,8 +14,6 @@ import (
"strings"
"sync"
"github.com/hashicorp/go-multierror"
nberrors "github.com/netbirdio/netbird/client/errors"
log "github.com/sirupsen/logrus"
"golang.org/x/exp/maps"
@@ -24,7 +22,6 @@ import (
const (
netbirdDNSStateKeyFormat = "State:/Network/Service/NetBird-%s/DNS"
netbirdDNSStateKeyIndexedFormat = "State:/Network/Service/NetBird-%s-%d/DNS"
globalIPv4State = "State:/Network/Global/IPv4"
primaryServiceStateKeyFormat = "State:/Network/Service/%s/DNS"
keySupplementalMatchDomains = "SupplementalMatchDomains"
@@ -38,14 +35,6 @@ const (
searchSuffix = "Search"
matchSuffix = "Match"
localSuffix = "Local"
// maxDomainsPerResolverEntry is the max number of domains per scutil resolver key.
// scutil's d.add has maxArgs=101 (key + * + 99 values), so 99 is the hard cap.
maxDomainsPerResolverEntry = 50
// maxDomainBytesPerResolverEntry is the max total bytes of domain strings per key.
// scutil has an undocumented ~2048 byte value buffer; we stay well under it.
maxDomainBytesPerResolverEntry = 1500
)
type systemConfigurator struct {
@@ -95,23 +84,28 @@ func (s *systemConfigurator) applyDNSConfig(config HostDNSConfig, stateManager *
searchDomains = append(searchDomains, strings.TrimSuffix(""+dConf.Domain, "."))
}
if err := s.removeKeysContaining(matchSuffix); err != nil {
log.Warnf("failed to remove old match keys: %v", err)
}
matchKey := getKeyWithInput(netbirdDNSStateKeyFormat, matchSuffix)
var err error
if len(matchDomains) != 0 {
if err := s.addBatchedDomains(matchSuffix, matchDomains, config.ServerIP, config.ServerPort, false); err != nil {
return fmt.Errorf("add match domains: %w", err)
}
err = s.addMatchDomains(matchKey, strings.Join(matchDomains, " "), config.ServerIP, config.ServerPort)
} else {
log.Infof("removing match domains from the system")
err = s.removeKeyFromSystemConfig(matchKey)
}
if err != nil {
return fmt.Errorf("add match domains: %w", err)
}
s.updateState(stateManager)
if err := s.removeKeysContaining(searchSuffix); err != nil {
log.Warnf("failed to remove old search keys: %v", err)
}
searchKey := getKeyWithInput(netbirdDNSStateKeyFormat, searchSuffix)
if len(searchDomains) != 0 {
if err := s.addBatchedDomains(searchSuffix, searchDomains, config.ServerIP, config.ServerPort, true); err != nil {
return fmt.Errorf("add search domains: %w", err)
}
err = s.addSearchDomains(searchKey, strings.Join(searchDomains, " "), config.ServerIP, config.ServerPort)
} else {
log.Infof("removing search domains from the system")
err = s.removeKeyFromSystemConfig(searchKey)
}
if err != nil {
return fmt.Errorf("add search domains: %w", err)
}
s.updateState(stateManager)
@@ -155,7 +149,8 @@ func (s *systemConfigurator) restoreHostDNS() error {
func (s *systemConfigurator) getRemovableKeysWithDefaults() []string {
if len(s.createdKeys) == 0 {
return s.discoverExistingKeys()
// return defaults for startup calls
return []string{getKeyWithInput(netbirdDNSStateKeyFormat, searchSuffix), getKeyWithInput(netbirdDNSStateKeyFormat, matchSuffix)}
}
keys := make([]string, 0, len(s.createdKeys))
@@ -165,47 +160,6 @@ func (s *systemConfigurator) getRemovableKeysWithDefaults() []string {
return keys
}
// discoverExistingKeys probes scutil for all NetBird DNS keys that may exist.
// This handles the case where createdKeys is empty (e.g., state file lost after unclean shutdown).
func (s *systemConfigurator) discoverExistingKeys() []string {
dnsKeys, err := getSystemDNSKeys()
if err != nil {
log.Errorf("failed to get system DNS keys: %v", err)
return nil
}
var keys []string
for _, suffix := range []string{searchSuffix, matchSuffix, localSuffix} {
key := getKeyWithInput(netbirdDNSStateKeyFormat, suffix)
if strings.Contains(dnsKeys, key) {
keys = append(keys, key)
}
}
for _, suffix := range []string{searchSuffix, matchSuffix} {
for i := 0; ; i++ {
key := fmt.Sprintf(netbirdDNSStateKeyIndexedFormat, suffix, i)
if !strings.Contains(dnsKeys, key) {
break
}
keys = append(keys, key)
}
}
return keys
}
// getSystemDNSKeys gets all DNS keys
func getSystemDNSKeys() (string, error) {
command := "list .*DNS\nquit\n"
out, err := runSystemConfigCommand(command)
if err != nil {
return "", err
}
return string(out), nil
}
func (s *systemConfigurator) removeKeyFromSystemConfig(key string) error {
line := buildRemoveKeyOperation(key)
_, err := runSystemConfigCommand(wrapCommand(line))
@@ -230,11 +184,12 @@ func (s *systemConfigurator) addLocalDNS() error {
return nil
}
domainsStr := strings.Join(s.systemDNSSettings.Domains, " ")
if err := s.addDNSState(localKey, domainsStr, s.systemDNSSettings.ServerIP, s.systemDNSSettings.ServerPort, true); err != nil {
return fmt.Errorf("add local dns state: %w", err)
if err := s.addSearchDomains(
localKey,
strings.Join(s.systemDNSSettings.Domains, " "), s.systemDNSSettings.ServerIP, s.systemDNSSettings.ServerPort,
); err != nil {
return fmt.Errorf("add search domains: %w", err)
}
s.createdKeys[localKey] = struct{}{}
return nil
}
@@ -325,77 +280,28 @@ func (s *systemConfigurator) getOriginalNameservers() []netip.Addr {
return slices.Clone(s.origNameservers)
}
// splitDomainsIntoBatches splits domains into batches respecting both element count and byte size limits.
func splitDomainsIntoBatches(domains []string) [][]string {
if len(domains) == 0 {
return nil
func (s *systemConfigurator) addSearchDomains(key, domains string, ip netip.Addr, port int) error {
err := s.addDNSState(key, domains, ip, port, true)
if err != nil {
return fmt.Errorf("add dns state: %w", err)
}
var batches [][]string
var current []string
currentBytes := 0
log.Infof("added %d search domains to the state. Domain list: %s", len(strings.Split(domains, " ")), domains)
for _, d := range domains {
domainLen := len(d)
newBytes := currentBytes + domainLen
if currentBytes > 0 {
newBytes++ // space separator
}
s.createdKeys[key] = struct{}{}
if len(current) > 0 && (len(current) >= maxDomainsPerResolverEntry || newBytes > maxDomainBytesPerResolverEntry) {
batches = append(batches, current)
current = nil
currentBytes = 0
}
current = append(current, d)
if currentBytes > 0 {
currentBytes += 1 + domainLen
} else {
currentBytes = domainLen
}
}
if len(current) > 0 {
batches = append(batches, current)
}
return batches
return nil
}
// removeKeysContaining removes all created keys that contain the given substring.
func (s *systemConfigurator) removeKeysContaining(suffix string) error {
var toRemove []string
for key := range s.createdKeys {
if strings.Contains(key, suffix) {
toRemove = append(toRemove, key)
}
}
var multiErr *multierror.Error
for _, key := range toRemove {
if err := s.removeKeyFromSystemConfig(key); err != nil {
multiErr = multierror.Append(multiErr, fmt.Errorf("couldn't remove key %s: %w", key, err))
}
}
return nberrors.FormatErrorOrNil(multiErr)
}
// addBatchedDomains splits domains into batches and creates indexed scutil keys for each batch.
func (s *systemConfigurator) addBatchedDomains(suffix string, domains []string, ip netip.Addr, port int, enableSearch bool) error {
batches := splitDomainsIntoBatches(domains)
for i, batch := range batches {
key := fmt.Sprintf(netbirdDNSStateKeyIndexedFormat, suffix, i)
domainsStr := strings.Join(batch, " ")
if err := s.addDNSState(key, domainsStr, ip, port, enableSearch); err != nil {
return fmt.Errorf("add dns state for batch %d: %w", i, err)
}
s.createdKeys[key] = struct{}{}
func (s *systemConfigurator) addMatchDomains(key, domains string, dnsServer netip.Addr, port int) error {
err := s.addDNSState(key, domains, dnsServer, port, false)
if err != nil {
return fmt.Errorf("add dns state: %w", err)
}
log.Infof("added %d %s domains across %d resolver entries", len(domains), suffix, len(batches))
log.Infof("added %d match domains to the state. Domain list: %s", len(strings.Split(domains, " ")), domains)
s.createdKeys[key] = struct{}{}
return nil
}
@@ -458,6 +364,7 @@ func (s *systemConfigurator) flushDNSCache() error {
if out, err := cmd.CombinedOutput(); err != nil {
return fmt.Errorf("restart mDNSResponder: %w, output: %s", err, out)
}
log.Info("flushed DNS cache")
return nil
}

View File

@@ -3,10 +3,7 @@
package dns
import (
"bufio"
"bytes"
"context"
"fmt"
"net/netip"
"os/exec"
"path/filepath"
@@ -52,22 +49,17 @@ func TestDarwinDNSUncleanShutdownCleanup(t *testing.T) {
require.NoError(t, sm.PersistState(context.Background()))
searchKey := getKeyWithInput(netbirdDNSStateKeyFormat, searchSuffix)
matchKey := getKeyWithInput(netbirdDNSStateKeyFormat, matchSuffix)
localKey := getKeyWithInput(netbirdDNSStateKeyFormat, localSuffix)
// Collect all created keys for cleanup verification
createdKeys := make([]string, 0, len(configurator.createdKeys))
for key := range configurator.createdKeys {
createdKeys = append(createdKeys, key)
}
defer func() {
for _, key := range createdKeys {
for _, key := range []string{searchKey, matchKey, localKey} {
_ = removeTestDNSKey(key)
}
_ = removeTestDNSKey(localKey)
}()
for _, key := range createdKeys {
for _, key := range []string{searchKey, matchKey, localKey} {
exists, err := checkDNSKeyExists(key)
require.NoError(t, err)
if exists {
@@ -91,223 +83,13 @@ func TestDarwinDNSUncleanShutdownCleanup(t *testing.T) {
err = shutdownState.Cleanup()
require.NoError(t, err)
for _, key := range createdKeys {
for _, key := range []string{searchKey, matchKey, localKey} {
exists, err := checkDNSKeyExists(key)
require.NoError(t, err)
assert.False(t, exists, "Key %s should NOT exist after cleanup", key)
}
}
// generateShortDomains generates domains like a.com, b.com, ..., aa.com, ab.com, etc.
func generateShortDomains(count int) []string {
domains := make([]string, 0, count)
for i := range count {
label := ""
n := i
for {
label = string(rune('a'+n%26)) + label
n = n/26 - 1
if n < 0 {
break
}
}
domains = append(domains, label+".com")
}
return domains
}
// generateLongDomains generates domains like subdomain-000.department.organization-name.example.com
func generateLongDomains(count int) []string {
domains := make([]string, 0, count)
for i := range count {
domains = append(domains, fmt.Sprintf("subdomain-%03d.department.organization-name.example.com", i))
}
return domains
}
// readDomainsFromKey reads the SupplementalMatchDomains array back from scutil for a given key.
func readDomainsFromKey(t *testing.T, key string) []string {
t.Helper()
cmd := exec.Command(scutilPath)
cmd.Stdin = strings.NewReader(fmt.Sprintf("open\nshow %s\nquit\n", key))
out, err := cmd.Output()
require.NoError(t, err, "scutil show should succeed")
var domains []string
inArray := false
scanner := bufio.NewScanner(bytes.NewReader(out))
for scanner.Scan() {
line := strings.TrimSpace(scanner.Text())
if strings.HasPrefix(line, "SupplementalMatchDomains") && strings.Contains(line, "<array>") {
inArray = true
continue
}
if inArray {
if line == "}" {
break
}
// lines look like: "0 : a.com"
parts := strings.SplitN(line, " : ", 2)
if len(parts) == 2 {
domains = append(domains, parts[1])
}
}
}
require.NoError(t, scanner.Err())
return domains
}
func TestSplitDomainsIntoBatches(t *testing.T) {
tests := []struct {
name string
domains []string
expectedCount int
checkAllPresent bool
}{
{
name: "empty",
domains: nil,
expectedCount: 0,
},
{
name: "under_limit",
domains: generateShortDomains(10),
expectedCount: 1,
checkAllPresent: true,
},
{
name: "at_element_limit",
domains: generateShortDomains(50),
expectedCount: 1,
checkAllPresent: true,
},
{
name: "over_element_limit",
domains: generateShortDomains(51),
expectedCount: 2,
checkAllPresent: true,
},
{
name: "triple_element_limit",
domains: generateShortDomains(150),
expectedCount: 3,
checkAllPresent: true,
},
{
name: "long_domains_hit_byte_limit",
domains: generateLongDomains(50),
checkAllPresent: true,
},
{
name: "500_short_domains",
domains: generateShortDomains(500),
expectedCount: 10,
checkAllPresent: true,
},
{
name: "500_long_domains",
domains: generateLongDomains(500),
checkAllPresent: true,
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
batches := splitDomainsIntoBatches(tc.domains)
if tc.expectedCount > 0 {
assert.Len(t, batches, tc.expectedCount, "expected %d batches", tc.expectedCount)
}
// Verify each batch respects limits
for i, batch := range batches {
assert.LessOrEqual(t, len(batch), maxDomainsPerResolverEntry,
"batch %d exceeds element limit", i)
totalBytes := 0
for j, d := range batch {
if j > 0 {
totalBytes++
}
totalBytes += len(d)
}
assert.LessOrEqual(t, totalBytes, maxDomainBytesPerResolverEntry,
"batch %d exceeds byte limit (%d bytes)", i, totalBytes)
}
if tc.checkAllPresent {
var all []string
for _, batch := range batches {
all = append(all, batch...)
}
assert.Equal(t, tc.domains, all, "all domains should be present in order")
}
})
}
}
// TestMatchDomainBatching writes increasing numbers of domains via the batching mechanism
// and verifies all domains are readable across multiple scutil keys.
func TestMatchDomainBatching(t *testing.T) {
if testing.Short() {
t.Skip("skipping scutil integration test in short mode")
}
testCases := []struct {
name string
count int
generator func(int) []string
}{
{"short_10", 10, generateShortDomains},
{"short_50", 50, generateShortDomains},
{"short_100", 100, generateShortDomains},
{"short_200", 200, generateShortDomains},
{"short_500", 500, generateShortDomains},
{"long_10", 10, generateLongDomains},
{"long_50", 50, generateLongDomains},
{"long_100", 100, generateLongDomains},
{"long_200", 200, generateLongDomains},
{"long_500", 500, generateLongDomains},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
configurator := &systemConfigurator{
createdKeys: make(map[string]struct{}),
}
defer func() {
for key := range configurator.createdKeys {
_ = removeTestDNSKey(key)
}
}()
domains := tc.generator(tc.count)
err := configurator.addBatchedDomains(matchSuffix, domains, netip.MustParseAddr("100.64.0.1"), 53, false)
require.NoError(t, err)
batches := splitDomainsIntoBatches(domains)
t.Logf("wrote %d domains across %d batched keys", tc.count, len(batches))
// Read back all domains from all batched keys
var got []string
for i := range batches {
key := fmt.Sprintf(netbirdDNSStateKeyIndexedFormat, matchSuffix, i)
exists, err := checkDNSKeyExists(key)
require.NoError(t, err)
require.True(t, exists, "key %s should exist", key)
got = append(got, readDomainsFromKey(t, key)...)
}
t.Logf("read back %d/%d domains from %d keys", len(got), tc.count, len(batches))
assert.Equal(t, tc.count, len(got), "all domains should be readable")
assert.Equal(t, domains, got, "domains should match in order")
})
}
}
func checkDNSKeyExists(key string) (bool, error) {
cmd := exec.Command(scutilPath)
cmd.Stdin = strings.NewReader("show " + key + "\nquit\n")
@@ -376,15 +158,15 @@ func setupTestConfigurator(t *testing.T) (*systemConfigurator, *statemanager.Man
createdKeys: make(map[string]struct{}),
}
searchKey := getKeyWithInput(netbirdDNSStateKeyFormat, searchSuffix)
matchKey := getKeyWithInput(netbirdDNSStateKeyFormat, matchSuffix)
localKey := getKeyWithInput(netbirdDNSStateKeyFormat, localSuffix)
cleanup := func() {
_ = sm.Stop(context.Background())
for key := range configurator.createdKeys {
for _, key := range []string{searchKey, matchKey, localKey} {
_ = removeTestDNSKey(key)
}
// Also clean up old-format keys and local key in case they exist
_ = removeTestDNSKey(getKeyWithInput(netbirdDNSStateKeyFormat, searchSuffix))
_ = removeTestDNSKey(getKeyWithInput(netbirdDNSStateKeyFormat, matchSuffix))
_ = removeTestDNSKey(getKeyWithInput(netbirdDNSStateKeyFormat, localSuffix))
}
return configurator, sm, cleanup

View File

@@ -42,8 +42,6 @@ const (
dnsPolicyConfigConfigOptionsKey = "ConfigOptions"
dnsPolicyConfigConfigOptionsValue = 0x8
nrptMaxDomainsPerRule = 50
interfaceConfigPath = `SYSTEM\CurrentControlSet\Services\Tcpip\Parameters\Interfaces`
interfaceConfigNameServerKey = "NameServer"
interfaceConfigSearchListKey = "SearchList"
@@ -200,11 +198,10 @@ func (r *registryConfigurator) applyDNSConfig(config HostDNSConfig, stateManager
if len(matchDomains) != 0 {
count, err := r.addDNSMatchPolicy(matchDomains, config.ServerIP)
// Update count even on error to ensure cleanup covers partially created rules
r.nrptEntryCount = count
if err != nil {
return fmt.Errorf("add dns match policy: %w", err)
}
r.nrptEntryCount = count
} else {
r.nrptEntryCount = 0
}
@@ -242,33 +239,23 @@ func (r *registryConfigurator) addDNSSetupForAll(ip netip.Addr) error {
func (r *registryConfigurator) addDNSMatchPolicy(domains []string, ip netip.Addr) (int, error) {
// if the gpo key is present, we need to put our DNS settings there, otherwise our config might be ignored
// see https://learn.microsoft.com/en-us/openspecs/windows_protocols/ms-gpnrpt/8cc31cb9-20cb-4140-9e85-3e08703b4745
for i, domain := range domains {
localPath := fmt.Sprintf("%s-%d", dnsPolicyConfigMatchPath, i)
gpoPath := fmt.Sprintf("%s-%d", gpoDnsPolicyConfigMatchPath, i)
// We need to batch domains into chunks and create one NRPT rule per batch.
ruleIndex := 0
for i := 0; i < len(domains); i += nrptMaxDomainsPerRule {
end := i + nrptMaxDomainsPerRule
if end > len(domains) {
end = len(domains)
singleDomain := []string{domain}
if err := r.configureDNSPolicy(localPath, singleDomain, ip); err != nil {
return i, fmt.Errorf("configure DNS Local policy for domain %s: %w", domain, err)
}
batchDomains := domains[i:end]
localPath := fmt.Sprintf("%s-%d", dnsPolicyConfigMatchPath, ruleIndex)
gpoPath := fmt.Sprintf("%s-%d", gpoDnsPolicyConfigMatchPath, ruleIndex)
if err := r.configureDNSPolicy(localPath, batchDomains, ip); err != nil {
return ruleIndex, fmt.Errorf("configure DNS Local policy for rule %d: %w", ruleIndex, err)
}
// Increment immediately so the caller's cleanup path knows about this rule
ruleIndex++
if r.gpo {
if err := r.configureDNSPolicy(gpoPath, batchDomains, ip); err != nil {
return ruleIndex, fmt.Errorf("configure gpo DNS policy for rule %d: %w", ruleIndex-1, err)
if err := r.configureDNSPolicy(gpoPath, singleDomain, ip); err != nil {
return i, fmt.Errorf("configure gpo DNS policy: %w", err)
}
}
log.Debugf("added NRPT rule %d with %d domains", ruleIndex-1, len(batchDomains))
log.Debugf("added NRPT entry for domain: %s", domain)
}
if r.gpo {
@@ -277,8 +264,8 @@ func (r *registryConfigurator) addDNSMatchPolicy(domains []string, ip netip.Addr
}
}
log.Infof("added %d NRPT rules for %d domains", ruleIndex, len(domains))
return ruleIndex, nil
log.Infof("added %d separate NRPT entries. Domain list: %s", len(domains), domains)
return len(domains), nil
}
func (r *registryConfigurator) configureDNSPolicy(policyPath string, domains []string, ip netip.Addr) error {

View File

@@ -12,7 +12,6 @@ import (
// TestNRPTEntriesCleanupOnConfigChange tests that old NRPT entries are properly cleaned up
// when the number of match domains decreases between configuration changes.
// With batching enabled (50 domains per rule), we need enough domains to create multiple rules.
func TestNRPTEntriesCleanupOnConfigChange(t *testing.T) {
if testing.Short() {
t.Skip("skipping registry integration test in short mode")
@@ -38,60 +37,51 @@ func TestNRPTEntriesCleanupOnConfigChange(t *testing.T) {
gpo: false,
}
// Create 125 domains which will result in 3 NRPT rules (50+50+25)
domains125 := make([]DomainConfig, 125)
for i := 0; i < 125; i++ {
domains125[i] = DomainConfig{
Domain: fmt.Sprintf("domain%d.com", i+1),
MatchOnly: true,
}
}
config125 := HostDNSConfig{
config5 := HostDNSConfig{
ServerIP: testIP,
Domains: domains125,
Domains: []DomainConfig{
{Domain: "domain1.com", MatchOnly: true},
{Domain: "domain2.com", MatchOnly: true},
{Domain: "domain3.com", MatchOnly: true},
{Domain: "domain4.com", MatchOnly: true},
{Domain: "domain5.com", MatchOnly: true},
},
}
err = cfg.applyDNSConfig(config125, nil)
err = cfg.applyDNSConfig(config5, nil)
require.NoError(t, err)
// Verify 3 NRPT rules exist
assert.Equal(t, 3, cfg.nrptEntryCount, "Should create 3 NRPT rules for 125 domains")
for i := 0; i < 3; i++ {
// Verify all 5 entries exist
for i := 0; i < 5; i++ {
exists, err := registryKeyExists(fmt.Sprintf("%s-%d", dnsPolicyConfigMatchPath, i))
require.NoError(t, err)
assert.True(t, exists, "NRPT rule %d should exist after first config", i)
assert.True(t, exists, "Entry %d should exist after first config", i)
}
// Reduce to 75 domains which will result in 2 NRPT rules (50+25)
domains75 := make([]DomainConfig, 75)
for i := 0; i < 75; i++ {
domains75[i] = DomainConfig{
Domain: fmt.Sprintf("domain%d.com", i+1),
MatchOnly: true,
}
}
config75 := HostDNSConfig{
config2 := HostDNSConfig{
ServerIP: testIP,
Domains: domains75,
Domains: []DomainConfig{
{Domain: "domain1.com", MatchOnly: true},
{Domain: "domain2.com", MatchOnly: true},
},
}
err = cfg.applyDNSConfig(config75, nil)
err = cfg.applyDNSConfig(config2, nil)
require.NoError(t, err)
// Verify first 2 NRPT rules exist
assert.Equal(t, 2, cfg.nrptEntryCount, "Should create 2 NRPT rules for 75 domains")
// Verify first 2 entries exist
for i := 0; i < 2; i++ {
exists, err := registryKeyExists(fmt.Sprintf("%s-%d", dnsPolicyConfigMatchPath, i))
require.NoError(t, err)
assert.True(t, exists, "NRPT rule %d should exist after second config", i)
assert.True(t, exists, "Entry %d should exist after second config", i)
}
// Verify rule 2 is cleaned up
exists, err := registryKeyExists(fmt.Sprintf("%s-%d", dnsPolicyConfigMatchPath, 2))
require.NoError(t, err)
assert.False(t, exists, "NRPT rule 2 should NOT exist after reducing to 75 domains")
// Verify entries 2-4 are cleaned up
for i := 2; i < 5; i++ {
exists, err := registryKeyExists(fmt.Sprintf("%s-%d", dnsPolicyConfigMatchPath, i))
require.NoError(t, err)
assert.False(t, exists, "Entry %d should NOT exist after reducing to 2 domains", i)
}
}
func registryKeyExists(path string) (bool, error) {
@@ -107,106 +97,6 @@ func registryKeyExists(path string) (bool, error) {
}
func cleanupRegistryKeys(*testing.T) {
// Clean up more entries to account for batching tests with many domains
cfg := &registryConfigurator{nrptEntryCount: 20}
cfg := &registryConfigurator{nrptEntryCount: 10}
_ = cfg.removeDNSMatchPolicies()
}
// TestNRPTDomainBatching verifies that domains are correctly batched into NRPT rules.
func TestNRPTDomainBatching(t *testing.T) {
if testing.Short() {
t.Skip("skipping registry integration test in short mode")
}
defer cleanupRegistryKeys(t)
cleanupRegistryKeys(t)
testIP := netip.MustParseAddr("100.64.0.1")
// Create a test interface registry key so updateSearchDomains doesn't fail
testGUID := "{12345678-1234-1234-1234-123456789ABC}"
interfacePath := `SYSTEM\CurrentControlSet\Services\Tcpip\Parameters\Interfaces\` + testGUID
testKey, _, err := registry.CreateKey(registry.LOCAL_MACHINE, interfacePath, registry.SET_VALUE)
require.NoError(t, err, "Should create test interface registry key")
testKey.Close()
defer func() {
_ = registry.DeleteKey(registry.LOCAL_MACHINE, interfacePath)
}()
cfg := &registryConfigurator{
guid: testGUID,
gpo: false,
}
testCases := []struct {
name string
domainCount int
expectedRuleCount int
}{
{
name: "Less than 50 domains (single rule)",
domainCount: 30,
expectedRuleCount: 1,
},
{
name: "Exactly 50 domains (single rule)",
domainCount: 50,
expectedRuleCount: 1,
},
{
name: "51 domains (two rules)",
domainCount: 51,
expectedRuleCount: 2,
},
{
name: "100 domains (two rules)",
domainCount: 100,
expectedRuleCount: 2,
},
{
name: "125 domains (three rules: 50+50+25)",
domainCount: 125,
expectedRuleCount: 3,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
// Clean up before each subtest
cleanupRegistryKeys(t)
// Generate domains
domains := make([]DomainConfig, tc.domainCount)
for i := 0; i < tc.domainCount; i++ {
domains[i] = DomainConfig{
Domain: fmt.Sprintf("domain%d.com", i+1),
MatchOnly: true,
}
}
config := HostDNSConfig{
ServerIP: testIP,
Domains: domains,
}
err := cfg.applyDNSConfig(config, nil)
require.NoError(t, err)
// Verify that exactly expectedRuleCount rules were created
assert.Equal(t, tc.expectedRuleCount, cfg.nrptEntryCount,
"Should create %d NRPT rules for %d domains", tc.expectedRuleCount, tc.domainCount)
// Verify all expected rules exist
for i := 0; i < tc.expectedRuleCount; i++ {
exists, err := registryKeyExists(fmt.Sprintf("%s-%d", dnsPolicyConfigMatchPath, i))
require.NoError(t, err)
assert.True(t, exists, "NRPT rule %d should exist", i)
}
// Verify no extra rules were created
exists, err := registryKeyExists(fmt.Sprintf("%s-%d", dnsPolicyConfigMatchPath, tc.expectedRuleCount))
require.NoError(t, err)
assert.False(t, exists, "No NRPT rule should exist at index %d", tc.expectedRuleCount)
})
}
}

View File

@@ -77,7 +77,7 @@ func (d *Resolver) ID() types.HandlerID {
return "local-resolver"
}
func (d *Resolver) ProbeAvailability(context.Context) {}
func (d *Resolver) ProbeAvailability() {}
// ServeDNS handles a DNS request
func (d *Resolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {

View File

@@ -1263,9 +1263,9 @@ func TestLocalResolver_AuthoritativeFlag(t *testing.T) {
})
}
// TestLocalResolver_Stop tests cleanup on GracefullyStop
// TestLocalResolver_Stop tests cleanup on Stop
func TestLocalResolver_Stop(t *testing.T) {
t.Run("GracefullyStop clears all state", func(t *testing.T) {
t.Run("Stop clears all state", func(t *testing.T) {
resolver := NewResolver()
resolver.Update([]nbdns.CustomZone{{
Domain: "example.com.",
@@ -1285,7 +1285,7 @@ func TestLocalResolver_Stop(t *testing.T) {
assert.False(t, resolver.isInManagedZone("host.example.com."))
})
t.Run("GracefullyStop is safe to call multiple times", func(t *testing.T) {
t.Run("Stop is safe to call multiple times", func(t *testing.T) {
resolver := NewResolver()
resolver.Update([]nbdns.CustomZone{{
Domain: "example.com.",
@@ -1299,7 +1299,7 @@ func TestLocalResolver_Stop(t *testing.T) {
resolver.Stop()
})
t.Run("GracefullyStop cancels in-flight external resolution", func(t *testing.T) {
t.Run("Stop cancels in-flight external resolution", func(t *testing.T) {
resolver := NewResolver()
lookupStarted := make(chan struct{})

View File

@@ -376,9 +376,9 @@ func (m *Resolver) extractDomainsFromServerDomains(serverDomains dnsconfig.Serve
}
}
// Flow receiver domain is intentionally excluded from caching.
// Cloud providers may rotate the IP behind this domain; a stale cached record
// causes TLS certificate verification failures on reconnect.
if serverDomains.Flow != "" {
domains = append(domains, serverDomains.Flow)
}
for _, stun := range serverDomains.Stuns {
if stun != "" {

View File

@@ -391,8 +391,7 @@ func TestResolver_PartialUpdateAddsNewTypePreservesExisting(t *testing.T) {
}
assert.Len(t, resolver.GetCachedDomains(), 3)
// Update with partial ServerDomains (only flow domain - flow is intentionally excluded from
// caching to prevent TLS failures from stale records, so all existing domains are preserved)
// Update with partial ServerDomains (only flow domain - new type, should preserve all existing)
partialDomains := dnsconfig.ServerDomains{
Flow: "github.com",
}
@@ -401,10 +400,10 @@ func TestResolver_PartialUpdateAddsNewTypePreservesExisting(t *testing.T) {
t.Skipf("Skipping test due to DNS resolution failure: %v", err)
}
assert.Len(t, removedDomains, 0, "Should not remove any domains when only flow domain is provided")
assert.Len(t, removedDomains, 0, "Should not remove any domains when adding new type")
finalDomains := resolver.GetCachedDomains()
assert.Len(t, finalDomains, 3, "Flow domain is not cached; all original domains should be preserved")
assert.Len(t, finalDomains, 4, "Should have all original domains plus new flow domain")
domainStrings := make([]string, len(finalDomains))
for i, d := range finalDomains {
@@ -413,5 +412,5 @@ func TestResolver_PartialUpdateAddsNewTypePreservesExisting(t *testing.T) {
assert.Contains(t, domainStrings, "example.org")
assert.Contains(t, domainStrings, "google.com")
assert.Contains(t, domainStrings, "cloudflare.com")
assert.NotContains(t, domainStrings, "github.com")
assert.Contains(t, domainStrings, "github.com")
}

View File

@@ -84,18 +84,3 @@ func (m *MockServer) UpdateServerConfig(domains dnsconfig.ServerDomains) error {
func (m *MockServer) PopulateManagementDomain(mgmtURL *url.URL) error {
return nil
}
// BeginBatch mock implementation of BeginBatch from Server interface
func (m *MockServer) BeginBatch() {
// Mock implementation - no-op
}
// EndBatch mock implementation of EndBatch from Server interface
func (m *MockServer) EndBatch() {
// Mock implementation - no-op
}
// CancelBatch mock implementation of CancelBatch from Server interface
func (m *MockServer) CancelBatch() {
// Mock implementation - no-op
}

View File

@@ -45,9 +45,6 @@ type IosDnsManager interface {
type Server interface {
RegisterHandler(domains domain.List, handler dns.Handler, priority int)
DeregisterHandler(domains domain.List, priority int)
BeginBatch()
EndBatch()
CancelBatch()
Initialize() error
Stop()
DnsIP() netip.Addr
@@ -90,7 +87,6 @@ type DefaultServer struct {
currentConfigHash uint64
handlerChain *HandlerChain
extraDomains map[domain.Domain]int
batchMode bool
mgmtCacheResolver *mgmt.Resolver
@@ -104,16 +100,12 @@ type DefaultServer struct {
statusRecorder *peer.Status
stateManager *statemanager.Manager
probeMu sync.Mutex
probeCancel context.CancelFunc
probeWg sync.WaitGroup
}
type handlerWithStop interface {
dns.Handler
Stop()
ProbeAvailability(context.Context)
ProbeAvailability()
ID() types.HandlerID
}
@@ -242,9 +234,7 @@ func (s *DefaultServer) RegisterHandler(domains domain.List, handler dns.Handler
// convert to zone with simple ref counter
s.extraDomains[toZone(domain)]++
}
if !s.batchMode {
s.applyHostConfig()
}
s.applyHostConfig()
}
func (s *DefaultServer) registerHandler(domains []string, handler dns.Handler, priority int) {
@@ -273,41 +263,9 @@ func (s *DefaultServer) DeregisterHandler(domains domain.List, priority int) {
delete(s.extraDomains, zone)
}
}
if !s.batchMode {
s.applyHostConfig()
}
}
// BeginBatch starts batch mode for DNS handler registration/deregistration.
// In batch mode, applyHostConfig() is not called after each handler operation,
// allowing multiple handlers to be registered/deregistered efficiently.
// Must be followed by EndBatch() to apply the accumulated changes.
func (s *DefaultServer) BeginBatch() {
s.mux.Lock()
defer s.mux.Unlock()
log.Debugf("DNS batch mode enabled")
s.batchMode = true
}
// EndBatch ends batch mode and applies all accumulated DNS configuration changes.
func (s *DefaultServer) EndBatch() {
s.mux.Lock()
defer s.mux.Unlock()
log.Debugf("DNS batch mode disabled, applying accumulated changes")
s.batchMode = false
s.applyHostConfig()
}
// CancelBatch cancels batch mode without applying accumulated changes.
// This is useful when operations fail partway through and you want to
// discard partial state rather than applying it.
func (s *DefaultServer) CancelBatch() {
s.mux.Lock()
defer s.mux.Unlock()
log.Debugf("DNS batch mode cancelled, discarding accumulated changes")
s.batchMode = false
}
func (s *DefaultServer) deregisterHandler(domains []string, priority int) {
log.Debugf("deregistering handler with priority %d for %v", priority, domains)
@@ -366,13 +324,7 @@ func (s *DefaultServer) DnsIP() netip.Addr {
// Stop stops the server
func (s *DefaultServer) Stop() {
s.probeMu.Lock()
if s.probeCancel != nil {
s.probeCancel()
}
s.ctxCancel()
s.probeMu.Unlock()
s.probeWg.Wait()
s.shutdownWg.Wait()
s.mux.Lock()
@@ -489,8 +441,7 @@ func (s *DefaultServer) SearchDomains() []string {
}
// ProbeAvailability tests each upstream group's servers for availability
// and deactivates the group if no server responds.
// If a previous probe is still running, it will be cancelled before starting a new one.
// and deactivates the group if no server responds
func (s *DefaultServer) ProbeAvailability() {
if val := os.Getenv(envSkipDNSProbe); val != "" {
skipProbe, err := strconv.ParseBool(val)
@@ -503,52 +454,15 @@ func (s *DefaultServer) ProbeAvailability() {
}
}
s.probeMu.Lock()
// don't start probes on a stopped server
if s.ctx.Err() != nil {
s.probeMu.Unlock()
return
}
// cancel any running probe
if s.probeCancel != nil {
s.probeCancel()
s.probeCancel = nil
}
// wait for the previous probe goroutines to finish while holding
// the mutex so no other caller can start a new probe concurrently
s.probeWg.Wait()
// start a new probe
probeCtx, probeCancel := context.WithCancel(s.ctx)
s.probeCancel = probeCancel
s.probeWg.Add(1)
defer s.probeWg.Done()
// Snapshot handlers under s.mux to avoid racing with updateMux/dnsMuxMap writers.
s.mux.Lock()
handlers := make([]handlerWithStop, 0, len(s.dnsMuxMap))
for _, mux := range s.dnsMuxMap {
handlers = append(handlers, mux.handler)
}
s.mux.Unlock()
var wg sync.WaitGroup
for _, handler := range handlers {
for _, mux := range s.dnsMuxMap {
wg.Add(1)
go func(h handlerWithStop) {
go func(mux handlerWithStop) {
defer wg.Done()
h.ProbeAvailability(probeCtx)
}(handler)
mux.ProbeAvailability()
}(mux.handler)
}
s.probeMu.Unlock()
wg.Wait()
probeCancel()
}
func (s *DefaultServer) UpdateServerConfig(domains dnsconfig.ServerDomains) error {
@@ -609,7 +523,6 @@ func (s *DefaultServer) applyConfiguration(update nbdns.Config) error {
s.currentConfig.RouteAll = false
}
// Always apply host config for management updates, regardless of batch mode
s.applyHostConfig()
s.shutdownWg.Add(1)
@@ -974,7 +887,6 @@ func (s *DefaultServer) upstreamCallbacks(
}
}
// Always apply host config when nameserver goes down, regardless of batch mode
s.applyHostConfig()
go func() {
@@ -1010,7 +922,6 @@ func (s *DefaultServer) upstreamCallbacks(
s.registerHandler([]string{nbdns.RootZone}, handler, priority)
}
// Always apply host config when nameserver reactivates, regardless of batch mode
s.applyHostConfig()
s.updateNSState(nsGroup, nil, true)

View File

@@ -18,12 +18,7 @@ func TestGetServerDns(t *testing.T) {
t.Errorf("invalid dns server instance: %s", err)
}
mockSrvB, ok := srvB.(*MockServer)
if !ok {
t.Errorf("returned server is not a MockServer")
}
if mockSrvB != srv {
if srvB != srv {
t.Errorf("mismatch dns instances")
}
}

View File

@@ -1065,7 +1065,7 @@ type mockHandler struct {
func (m *mockHandler) ServeDNS(dns.ResponseWriter, *dns.Msg) {}
func (m *mockHandler) Stop() {}
func (m *mockHandler) ProbeAvailability(context.Context) {}
func (m *mockHandler) ProbeAvailability() {}
func (m *mockHandler) ID() types.HandlerID { return types.HandlerID(m.Id) }
type mockService struct{}

View File

@@ -65,7 +65,6 @@ type upstreamResolverBase struct {
mutex sync.Mutex
reactivatePeriod time.Duration
upstreamTimeout time.Duration
wg sync.WaitGroup
deactivate func(error)
reactivate func()
@@ -116,11 +115,6 @@ func (u *upstreamResolverBase) MatchSubdomains() bool {
func (u *upstreamResolverBase) Stop() {
log.Debugf("stopping serving DNS for upstreams %s", u.upstreamServers)
u.cancel()
u.mutex.Lock()
u.wg.Wait()
u.mutex.Unlock()
}
// ServeDNS handles a DNS request
@@ -266,10 +260,16 @@ func formatFailures(failures []upstreamFailure) string {
// ProbeAvailability tests all upstream servers simultaneously and
// disables the resolver if none work
func (u *upstreamResolverBase) ProbeAvailability(ctx context.Context) {
func (u *upstreamResolverBase) ProbeAvailability() {
u.mutex.Lock()
defer u.mutex.Unlock()
select {
case <-u.ctx.Done():
return
default:
}
// avoid probe if upstreams could resolve at least one query
if u.successCount.Load() > 0 {
return
@@ -279,39 +279,31 @@ func (u *upstreamResolverBase) ProbeAvailability(ctx context.Context) {
var mu sync.Mutex
var wg sync.WaitGroup
var errs *multierror.Error
var errors *multierror.Error
for _, upstream := range u.upstreamServers {
upstream := upstream
wg.Add(1)
go func(upstream netip.AddrPort) {
go func() {
defer wg.Done()
err := u.testNameserver(u.ctx, ctx, upstream, 500*time.Millisecond)
err := u.testNameserver(upstream, 500*time.Millisecond)
if err != nil {
mu.Lock()
errs = multierror.Append(errs, err)
mu.Unlock()
errors = multierror.Append(errors, err)
log.Warnf("probing upstream nameserver %s: %s", upstream, err)
return
}
mu.Lock()
defer mu.Unlock()
success = true
mu.Unlock()
}(upstream)
}()
}
wg.Wait()
select {
case <-ctx.Done():
return
case <-u.ctx.Done():
return
default:
}
// didn't find a working upstream server, let's disable and try later
if !success {
u.disable(errs.ErrorOrNil())
u.disable(errors.ErrorOrNil())
if u.statusRecorder == nil {
return
@@ -347,7 +339,7 @@ func (u *upstreamResolverBase) waitUntilResponse() {
}
for _, upstream := range u.upstreamServers {
if err := u.testNameserver(u.ctx, nil, upstream, probeTimeout); err != nil {
if err := u.testNameserver(upstream, probeTimeout); err != nil {
log.Tracef("upstream check for %s: %s", upstream, err)
} else {
// at least one upstream server is available, stop probing
@@ -359,22 +351,16 @@ func (u *upstreamResolverBase) waitUntilResponse() {
return fmt.Errorf("upstream check call error")
}
err := backoff.Retry(operation, backoff.WithContext(exponentialBackOff, u.ctx))
err := backoff.Retry(operation, exponentialBackOff)
if err != nil {
if errors.Is(err, context.Canceled) {
log.Debugf("upstream retry loop exited for upstreams %s", u.upstreamServersString())
} else {
log.Warnf("upstream retry loop exited for upstreams %s: %v", u.upstreamServersString(), err)
}
log.Warn(err)
return
}
log.Infof("upstreams %s are responsive again. Adding them back to system", u.upstreamServersString())
u.successCount.Add(1)
u.reactivate()
u.mutex.Lock()
u.disabled = false
u.mutex.Unlock()
}
// isTimeout returns true if the given error is a network timeout error.
@@ -397,11 +383,7 @@ func (u *upstreamResolverBase) disable(err error) {
u.successCount.Store(0)
u.deactivate(err)
u.disabled = true
u.wg.Add(1)
go func() {
defer u.wg.Done()
u.waitUntilResponse()
}()
go u.waitUntilResponse()
}
func (u *upstreamResolverBase) upstreamServersString() string {
@@ -412,18 +394,13 @@ func (u *upstreamResolverBase) upstreamServersString() string {
return strings.Join(servers, ", ")
}
func (u *upstreamResolverBase) testNameserver(baseCtx context.Context, externalCtx context.Context, server netip.AddrPort, timeout time.Duration) error {
mergedCtx, cancel := context.WithTimeout(baseCtx, timeout)
func (u *upstreamResolverBase) testNameserver(server netip.AddrPort, timeout time.Duration) error {
ctx, cancel := context.WithTimeout(u.ctx, timeout)
defer cancel()
if externalCtx != nil {
stop2 := context.AfterFunc(externalCtx, cancel)
defer stop2()
}
r := new(dns.Msg).SetQuestion(testRecord, dns.TypeSOA)
_, _, err := u.upstreamClient.exchange(mergedCtx, server.String(), r)
_, _, err := u.upstreamClient.exchange(ctx, server.String(), r)
return err
}

View File

@@ -188,7 +188,7 @@ func TestUpstreamResolver_DeactivationReactivation(t *testing.T) {
reactivated = true
}
resolver.ProbeAvailability(context.TODO())
resolver.ProbeAvailability()
if !failed {
t.Errorf("expected that resolving was deactivated")

View File

@@ -28,15 +28,14 @@ import (
"github.com/netbirdio/netbird/client/firewall"
firewallManager "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/device"
nbnetstack "github.com/netbirdio/netbird/client/iface/netstack"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/udpmux"
"github.com/netbirdio/netbird/client/internal/acl"
"github.com/netbirdio/netbird/client/internal/debug"
"github.com/netbirdio/netbird/client/internal/dns"
dnsconfig "github.com/netbirdio/netbird/client/internal/dns/config"
"github.com/netbirdio/netbird/client/internal/dnsfwd"
"github.com/netbirdio/netbird/client/internal/expose"
"github.com/netbirdio/netbird/client/internal/ingressgw"
"github.com/netbirdio/netbird/client/internal/netflow"
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
@@ -45,21 +44,22 @@ import (
"github.com/netbirdio/netbird/client/internal/peer/guard"
icemaker "github.com/netbirdio/netbird/client/internal/peer/ice"
"github.com/netbirdio/netbird/client/internal/peerstore"
"github.com/netbirdio/netbird/client/internal/portforward"
"github.com/netbirdio/netbird/client/internal/profilemanager"
"github.com/netbirdio/netbird/client/internal/relay"
"github.com/netbirdio/netbird/client/internal/rosenpass"
"github.com/netbirdio/netbird/client/internal/routemanager"
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
"github.com/netbirdio/netbird/client/internal/statemanager"
"github.com/netbirdio/netbird/client/internal/updater"
"github.com/netbirdio/netbird/client/internal/updatemanager"
"github.com/netbirdio/netbird/client/jobexec"
cProto "github.com/netbirdio/netbird/client/proto"
"github.com/netbirdio/netbird/shared/management/domain"
semaphoregroup "github.com/netbirdio/netbird/util/semaphore-group"
"github.com/netbirdio/netbird/client/system"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/route"
mgm "github.com/netbirdio/netbird/shared/management/client"
"github.com/netbirdio/netbird/shared/management/domain"
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
auth "github.com/netbirdio/netbird/shared/relay/auth/hmac"
relayClient "github.com/netbirdio/netbird/shared/relay/client"
@@ -75,11 +75,13 @@ import (
const (
PeerConnectionTimeoutMax = 45000 // ms
PeerConnectionTimeoutMin = 30000 // ms
connInitLimit = 200
disableAutoUpdate = "disabled"
)
var ErrResetConnection = fmt.Errorf("reset connection")
// EngineConfig is a config for the Engine
type EngineConfig struct {
WgPort int
WgIfaceName string
@@ -141,17 +143,6 @@ type EngineConfig struct {
LogPath string
}
// EngineServices holds the external service dependencies required by the Engine.
type EngineServices struct {
SignalClient signal.Client
MgmClient mgm.Client
RelayManager *relayClient.Manager
StatusRecorder *peer.Status
Checks []*mgmProto.Checks
StateManager *statemanager.Manager
UpdateManager *updater.Manager
}
// Engine is a mechanism responsible for reacting on Signal and Management stream events and managing connections to the remote peers.
type Engine struct {
// signal is a Signal Service client
@@ -209,19 +200,19 @@ type Engine struct {
// checks are the client-applied posture checks that need to be evaluated on the client
checks []*mgmProto.Checks
relayManager *relayClient.Manager
stateManager *statemanager.Manager
portForwardManager *portforward.Manager
srWatcher *guard.SRWatcher
relayManager *relayClient.Manager
stateManager *statemanager.Manager
srWatcher *guard.SRWatcher
// Sync response persistence (protected by syncRespMux)
syncRespMux sync.RWMutex
persistSyncResponse bool
latestSyncResponse *mgmProto.SyncResponse
connSemaphore *semaphoregroup.SemaphoreGroup
flowManager nftypes.FlowManager
// auto-update
updateManager *updater.Manager
updateManager *updatemanager.Manager
// WireGuard interface monitor
wgIfaceMonitor *WGIfaceMonitor
@@ -233,8 +224,6 @@ type Engine struct {
jobExecutor *jobexec.Executor
jobExecutorWG sync.WaitGroup
exposeManager *expose.Manager
}
// Peer is an instance of the Connection Peer
@@ -251,30 +240,35 @@ type localIpUpdater interface {
func NewEngine(
clientCtx context.Context,
clientCancel context.CancelFunc,
signalClient signal.Client,
mgmClient mgm.Client,
relayManager *relayClient.Manager,
config *EngineConfig,
services EngineServices,
mobileDep MobileDependency,
statusRecorder *peer.Status,
checks []*mgmProto.Checks,
stateManager *statemanager.Manager,
) *Engine {
engine := &Engine{
clientCtx: clientCtx,
clientCancel: clientCancel,
signal: services.SignalClient,
signaler: peer.NewSignaler(services.SignalClient, config.WgPrivateKey),
mgmClient: services.MgmClient,
relayManager: services.RelayManager,
peerStore: peerstore.NewConnStore(),
syncMsgMux: &sync.Mutex{},
config: config,
mobileDep: mobileDep,
STUNs: []*stun.URI{},
TURNs: []*stun.URI{},
networkSerial: 0,
statusRecorder: services.StatusRecorder,
stateManager: services.StateManager,
portForwardManager: portforward.NewManager(),
checks: services.Checks,
probeStunTurn: relay.NewStunTurnProbe(relay.DefaultCacheTTL),
jobExecutor: jobexec.NewExecutor(), updateManager: services.UpdateManager,
clientCtx: clientCtx,
clientCancel: clientCancel,
signal: signalClient,
signaler: peer.NewSignaler(signalClient, config.WgPrivateKey),
mgmClient: mgmClient,
relayManager: relayManager,
peerStore: peerstore.NewConnStore(),
syncMsgMux: &sync.Mutex{},
config: config,
mobileDep: mobileDep,
STUNs: []*stun.URI{},
TURNs: []*stun.URI{},
networkSerial: 0,
statusRecorder: statusRecorder,
stateManager: stateManager,
checks: checks,
connSemaphore: semaphoregroup.NewSemaphoreGroup(connInitLimit),
probeStunTurn: relay.NewStunTurnProbe(relay.DefaultCacheTTL),
jobExecutor: jobexec.NewExecutor(),
}
log.Infof("I am: %s", config.WgPrivateKey.PublicKey().String())
@@ -317,7 +311,7 @@ func (e *Engine) Stop() error {
}
if e.updateManager != nil {
e.updateManager.SetDownloadOnly()
e.updateManager.Stop()
}
log.Info("cleaning up status recorder states")
@@ -425,7 +419,6 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
e.cancel()
}
e.ctx, e.cancel = context.WithCancel(e.clientCtx)
e.exposeManager = expose.NewManager(e.ctx, e.mgmClient)
wgIface, err := e.newWgIface()
if err != nil {
@@ -517,13 +510,6 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
// conntrack entries from being created before the rules are in place
e.setupWGProxyNoTrack()
// Start after interface is up since port may have been resolved from 0 or changed if occupied
e.shutdownWg.Add(1)
go func() {
defer e.shutdownWg.Done()
e.portForwardManager.Start(e.ctx, uint16(e.config.WgPort))
}()
// Set the WireGuard interface for rosenpass after interface is up
if e.rpManager != nil {
e.rpManager.SetInterface(e.wgInterface)
@@ -574,6 +560,13 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
return nil
}
func (e *Engine) InitialUpdateHandling(autoUpdateSettings *mgmProto.AutoUpdateSettings) {
e.syncMsgMux.Lock()
defer e.syncMsgMux.Unlock()
e.handleAutoUpdateVersion(autoUpdateSettings, true)
}
func (e *Engine) createFirewall() error {
if e.config.DisableFirewall {
log.Infof("firewall is disabled")
@@ -801,22 +794,39 @@ func (e *Engine) PopulateNetbirdConfig(netbirdConfig *mgmProto.NetbirdConfig, mg
return nil
}
func (e *Engine) handleAutoUpdateVersion(autoUpdateSettings *mgmProto.AutoUpdateSettings) {
if e.updateManager == nil {
return
}
func (e *Engine) handleAutoUpdateVersion(autoUpdateSettings *mgmProto.AutoUpdateSettings, initialCheck bool) {
if autoUpdateSettings == nil {
return
}
if autoUpdateSettings.Version == disableAutoUpdate {
log.Infof("auto-update is disabled")
e.updateManager.SetDownloadOnly()
disabled := autoUpdateSettings.Version == disableAutoUpdate
// Stop and cleanup if disabled
if e.updateManager != nil && disabled {
log.Infof("auto-update is disabled, stopping update manager")
e.updateManager.Stop()
e.updateManager = nil
return
}
e.updateManager.SetVersion(autoUpdateSettings.Version, autoUpdateSettings.AlwaysUpdate)
// Skip check unless AlwaysUpdate is enabled or this is the initial check at startup
if !autoUpdateSettings.AlwaysUpdate && !initialCheck {
log.Debugf("skipping auto-update check, AlwaysUpdate is false and this is not the initial check")
return
}
// Start manager if needed
if e.updateManager == nil {
log.Infof("starting auto-update manager")
updateManager, err := updatemanager.NewManager(e.statusRecorder, e.stateManager)
if err != nil {
return
}
e.updateManager = updateManager
e.updateManager.Start(e.ctx)
}
log.Infof("handling auto-update version: %s", autoUpdateSettings.Version)
e.updateManager.SetVersion(autoUpdateSettings.Version)
}
func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
@@ -833,7 +843,7 @@ func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
}
if update.NetworkMap != nil && update.NetworkMap.PeerConfig != nil {
e.handleAutoUpdateVersion(update.NetworkMap.PeerConfig.AutoUpdate)
e.handleAutoUpdateVersion(update.NetworkMap.PeerConfig.AutoUpdate, false)
}
if update.GetNetbirdConfig() != nil {
@@ -1306,7 +1316,8 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
// Test received (upstream) servers for availability right away instead of upon usage.
// If no server of a server group responds this will disable the respective handler and retry later.
go e.dnsServer.ProbeAvailability()
e.dnsServer.ProbeAvailability()
return nil
}
@@ -1523,12 +1534,12 @@ func (e *Engine) createPeerConn(pubKey string, allowedIPs []netip.Prefix, agentV
}
serviceDependencies := peer.ServiceDependencies{
StatusRecorder: e.statusRecorder,
Signaler: e.signaler,
IFaceDiscover: e.mobileDep.IFaceDiscover,
RelayManager: e.relayManager,
SrWatcher: e.srWatcher,
PortForwardManager: e.portForwardManager,
StatusRecorder: e.statusRecorder,
Signaler: e.signaler,
IFaceDiscover: e.mobileDep.IFaceDiscover,
RelayManager: e.relayManager,
SrWatcher: e.srWatcher,
Semaphore: e.connSemaphore,
}
peerConn, err := peer.NewConn(config, serviceDependencies)
if err != nil {
@@ -1551,10 +1562,8 @@ func (e *Engine) receiveSignalEvents() {
defer e.shutdownWg.Done()
// connect to a stream of messages coming from the signal server
err := e.signal.Receive(e.ctx, func(msg *sProto.Message) error {
start := time.Now()
e.syncMsgMux.Lock()
defer e.syncMsgMux.Unlock()
gotLock := time.Since(start)
// Check context INSIDE lock to ensure atomicity with shutdown
if e.ctx.Err() != nil {
@@ -1578,8 +1587,6 @@ func (e *Engine) receiveSignalEvents() {
return err
}
log.Debugf("receiveMSG: took %s to get lock for peer %s with session id %s", gotLock, msg.Key, offerAnswer.SessionID)
if msg.Body.Type == sProto.Body_OFFER {
conn.OnRemoteOffer(*offerAnswer)
} else {
@@ -1685,12 +1692,6 @@ func (e *Engine) close() {
if e.rpManager != nil {
_ = e.rpManager.Close()
}
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
if err := e.portForwardManager.GracefullyStop(ctx); err != nil {
log.Warnf("failed to gracefully stop port forwarding manager: %s", err)
}
}
func (e *Engine) readInitialSettings() ([]*route.Route, *nbdns.Config, bool, error) {
@@ -1819,18 +1820,11 @@ func (e *Engine) GetRouteManager() routemanager.Manager {
return e.routeManager
}
// GetFirewallManager returns the firewall manager.
// GetFirewallManager returns the firewall manager
func (e *Engine) GetFirewallManager() firewallManager.Manager {
return e.firewall
}
// GetExposeManager returns the expose session manager.
func (e *Engine) GetExposeManager() *expose.Manager {
e.syncMsgMux.Lock()
defer e.syncMsgMux.Unlock()
return e.exposeManager
}
func findIPFromInterfaceName(ifaceName string) (net.IP, error) {
iface, err := net.InterfaceByName(ifaceName)
if err != nil {

View File

@@ -251,6 +251,9 @@ func TestEngine_SSH(t *testing.T) {
relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), iface.DefaultMTU)
engine := NewEngine(
ctx, cancel,
&signal.MockClient{},
&mgmt.MockClient{},
relayMgr,
&EngineConfig{
WgIfaceName: "utun101",
WgAddr: "100.64.0.1/24",
@@ -260,13 +263,10 @@ func TestEngine_SSH(t *testing.T) {
MTU: iface.DefaultMTU,
SSHKey: sshKey,
},
EngineServices{
SignalClient: &signal.MockClient{},
MgmClient: &mgmt.MockClient{},
RelayManager: relayMgr,
StatusRecorder: peer.NewRecorder("https://mgm"),
},
MobileDependency{},
peer.NewRecorder("https://mgm"),
nil,
nil,
)
engine.dnsServer = &dns.MockServer{
@@ -428,18 +428,13 @@ func TestEngine_UpdateNetworkMap(t *testing.T) {
defer cancel()
relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), iface.DefaultMTU)
engine := NewEngine(ctx, cancel, &EngineConfig{
engine := NewEngine(ctx, cancel, &signal.MockClient{}, &mgmt.MockClient{}, relayMgr, &EngineConfig{
WgIfaceName: "utun102",
WgAddr: "100.64.0.1/24",
WgPrivateKey: key,
WgPort: 33100,
MTU: iface.DefaultMTU,
}, EngineServices{
SignalClient: &signal.MockClient{},
MgmClient: &mgmt.MockClient{},
RelayManager: relayMgr,
StatusRecorder: peer.NewRecorder("https://mgm"),
}, MobileDependency{})
}, MobileDependency{}, peer.NewRecorder("https://mgm"), nil, nil)
wgIface := &MockWGIface{
NameFunc: func() string { return "utun102" },
@@ -652,18 +647,13 @@ func TestEngine_Sync(t *testing.T) {
return nil
}
relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), iface.DefaultMTU)
engine := NewEngine(ctx, cancel, &EngineConfig{
engine := NewEngine(ctx, cancel, &signal.MockClient{}, &mgmt.MockClient{SyncFunc: syncFunc}, relayMgr, &EngineConfig{
WgIfaceName: "utun103",
WgAddr: "100.64.0.1/24",
WgPrivateKey: key,
WgPort: 33100,
MTU: iface.DefaultMTU,
}, EngineServices{
SignalClient: &signal.MockClient{},
MgmClient: &mgmt.MockClient{SyncFunc: syncFunc},
RelayManager: relayMgr,
StatusRecorder: peer.NewRecorder("https://mgm"),
}, MobileDependency{})
}, MobileDependency{}, peer.NewRecorder("https://mgm"), nil, nil)
engine.ctx = ctx
engine.dnsServer = &dns.MockServer{
@@ -822,18 +812,13 @@ func TestEngine_UpdateNetworkMapWithRoutes(t *testing.T) {
wgAddr := fmt.Sprintf("100.66.%d.1/24", n)
relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), iface.DefaultMTU)
engine := NewEngine(ctx, cancel, &EngineConfig{
engine := NewEngine(ctx, cancel, &signal.MockClient{}, &mgmt.MockClient{}, relayMgr, &EngineConfig{
WgIfaceName: wgIfaceName,
WgAddr: wgAddr,
WgPrivateKey: key,
WgPort: 33100,
MTU: iface.DefaultMTU,
}, EngineServices{
SignalClient: &signal.MockClient{},
MgmClient: &mgmt.MockClient{},
RelayManager: relayMgr,
StatusRecorder: peer.NewRecorder("https://mgm"),
}, MobileDependency{})
}, MobileDependency{}, peer.NewRecorder("https://mgm"), nil, nil)
engine.ctx = ctx
newNet, err := stdnet.NewNet(context.Background(), nil)
if err != nil {
@@ -1029,18 +1014,13 @@ func TestEngine_UpdateNetworkMapWithDNSUpdate(t *testing.T) {
wgAddr := fmt.Sprintf("100.66.%d.1/24", n)
relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), iface.DefaultMTU)
engine := NewEngine(ctx, cancel, &EngineConfig{
engine := NewEngine(ctx, cancel, &signal.MockClient{}, &mgmt.MockClient{}, relayMgr, &EngineConfig{
WgIfaceName: wgIfaceName,
WgAddr: wgAddr,
WgPrivateKey: key,
WgPort: 33100,
MTU: iface.DefaultMTU,
}, EngineServices{
SignalClient: &signal.MockClient{},
MgmClient: &mgmt.MockClient{},
RelayManager: relayMgr,
StatusRecorder: peer.NewRecorder("https://mgm"),
}, MobileDependency{})
}, MobileDependency{}, peer.NewRecorder("https://mgm"), nil, nil)
engine.ctx = ctx
newNet, err := stdnet.NewNet(context.Background(), nil)
@@ -1566,12 +1546,7 @@ func createEngine(ctx context.Context, cancel context.CancelFunc, setupKey strin
}
relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), iface.DefaultMTU)
e, err := NewEngine(ctx, cancel, conf, EngineServices{
SignalClient: signalClient,
MgmClient: mgmtClient,
RelayManager: relayMgr,
StatusRecorder: peer.NewRecorder("https://mgm"),
}, MobileDependency{}), nil
e, err := NewEngine(ctx, cancel, signalClient, mgmtClient, relayMgr, conf, MobileDependency{}, peer.NewRecorder("https://mgm"), nil, nil), nil
e.ctx = ctx
return e, err
}

View File

@@ -1,97 +0,0 @@
package expose
import (
"context"
"time"
mgm "github.com/netbirdio/netbird/shared/management/client"
log "github.com/sirupsen/logrus"
)
const renewTimeout = 10 * time.Second
// Response holds the response from exposing a service.
type Response struct {
ServiceName string
ServiceURL string
Domain string
PortAutoAssigned bool
}
type Request struct {
NamePrefix string
Domain string
Port uint16
Protocol int
Pin string
Password string
UserGroups []string
ListenPort uint16
}
type ManagementClient interface {
CreateExpose(ctx context.Context, req mgm.ExposeRequest) (*mgm.ExposeResponse, error)
RenewExpose(ctx context.Context, domain string) error
StopExpose(ctx context.Context, domain string) error
}
// Manager handles expose session lifecycle via the management client.
type Manager struct {
mgmClient ManagementClient
ctx context.Context
}
// NewManager creates a new expose Manager using the given management client.
func NewManager(ctx context.Context, mgmClient ManagementClient) *Manager {
return &Manager{mgmClient: mgmClient, ctx: ctx}
}
// Expose creates a new expose session via the management server.
func (m *Manager) Expose(ctx context.Context, req Request) (*Response, error) {
log.Infof("exposing service on port %d", req.Port)
resp, err := m.mgmClient.CreateExpose(ctx, toClientExposeRequest(req))
if err != nil {
return nil, err
}
log.Infof("expose session created for %s", resp.Domain)
return fromClientExposeResponse(resp), nil
}
func (m *Manager) KeepAlive(ctx context.Context, domain string) error {
ticker := time.NewTicker(30 * time.Second)
defer ticker.Stop()
defer m.stop(domain)
for {
select {
case <-ctx.Done():
log.Infof("context canceled, stopping keep alive for %s", domain)
return nil
case <-ticker.C:
if err := m.renew(ctx, domain); err != nil {
log.Errorf("renewing expose session for %s: %v", domain, err)
return err
}
}
}
}
// renew extends the TTL of an active expose session.
func (m *Manager) renew(ctx context.Context, domain string) error {
renewCtx, cancel := context.WithTimeout(ctx, renewTimeout)
defer cancel()
return m.mgmClient.RenewExpose(renewCtx, domain)
}
// stop terminates an active expose session.
func (m *Manager) stop(domain string) {
stopCtx, cancel := context.WithTimeout(m.ctx, renewTimeout)
defer cancel()
err := m.mgmClient.StopExpose(stopCtx, domain)
if err != nil {
log.Warnf("Failed stopping expose session for %s: %v", domain, err)
}
}

View File

@@ -1,95 +0,0 @@
package expose
import (
"context"
"errors"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
daemonProto "github.com/netbirdio/netbird/client/proto"
mgm "github.com/netbirdio/netbird/shared/management/client"
)
func TestManager_Expose_Success(t *testing.T) {
mock := &mgm.MockClient{
CreateExposeFunc: func(ctx context.Context, req mgm.ExposeRequest) (*mgm.ExposeResponse, error) {
return &mgm.ExposeResponse{
ServiceName: "my-service",
ServiceURL: "https://my-service.example.com",
Domain: "my-service.example.com",
}, nil
},
}
m := NewManager(context.Background(), mock)
result, err := m.Expose(context.Background(), Request{Port: 8080})
require.NoError(t, err)
assert.Equal(t, "my-service", result.ServiceName, "service name should match")
assert.Equal(t, "https://my-service.example.com", result.ServiceURL, "service URL should match")
assert.Equal(t, "my-service.example.com", result.Domain, "domain should match")
}
func TestManager_Expose_Error(t *testing.T) {
mock := &mgm.MockClient{
CreateExposeFunc: func(ctx context.Context, req mgm.ExposeRequest) (*mgm.ExposeResponse, error) {
return nil, errors.New("permission denied")
},
}
m := NewManager(context.Background(), mock)
_, err := m.Expose(context.Background(), Request{Port: 8080})
require.Error(t, err)
assert.Contains(t, err.Error(), "permission denied", "error should propagate")
}
func TestManager_Renew_Success(t *testing.T) {
mock := &mgm.MockClient{
RenewExposeFunc: func(ctx context.Context, domain string) error {
assert.Equal(t, "my-service.example.com", domain, "domain should be passed through")
return nil
},
}
m := NewManager(context.Background(), mock)
err := m.renew(context.Background(), "my-service.example.com")
require.NoError(t, err)
}
func TestManager_Renew_Timeout(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
cancel()
mock := &mgm.MockClient{
RenewExposeFunc: func(ctx context.Context, domain string) error {
return ctx.Err()
},
}
m := NewManager(ctx, mock)
err := m.renew(ctx, "my-service.example.com")
require.Error(t, err)
}
func TestNewRequest(t *testing.T) {
req := &daemonProto.ExposeServiceRequest{
Port: 8080,
Protocol: daemonProto.ExposeProtocol_EXPOSE_HTTPS,
Pin: "123456",
Password: "secret",
UserGroups: []string{"group1", "group2"},
Domain: "custom.example.com",
NamePrefix: "my-prefix",
}
exposeReq := NewRequest(req)
assert.Equal(t, uint16(8080), exposeReq.Port, "port should match")
assert.Equal(t, int(daemonProto.ExposeProtocol_EXPOSE_HTTPS), exposeReq.Protocol, "protocol should match")
assert.Equal(t, "123456", exposeReq.Pin, "pin should match")
assert.Equal(t, "secret", exposeReq.Password, "password should match")
assert.Equal(t, []string{"group1", "group2"}, exposeReq.UserGroups, "user groups should match")
assert.Equal(t, "custom.example.com", exposeReq.Domain, "domain should match")
assert.Equal(t, "my-prefix", exposeReq.NamePrefix, "name prefix should match")
}

View File

@@ -1,42 +0,0 @@
package expose
import (
daemonProto "github.com/netbirdio/netbird/client/proto"
mgm "github.com/netbirdio/netbird/shared/management/client"
)
// NewRequest converts a daemon ExposeServiceRequest to a management ExposeServiceRequest.
func NewRequest(req *daemonProto.ExposeServiceRequest) *Request {
return &Request{
Port: uint16(req.Port),
Protocol: int(req.Protocol),
Pin: req.Pin,
Password: req.Password,
UserGroups: req.UserGroups,
Domain: req.Domain,
NamePrefix: req.NamePrefix,
ListenPort: uint16(req.ListenPort),
}
}
func toClientExposeRequest(req Request) mgm.ExposeRequest {
return mgm.ExposeRequest{
NamePrefix: req.NamePrefix,
Domain: req.Domain,
Port: req.Port,
Protocol: req.Protocol,
Pin: req.Pin,
Password: req.Password,
UserGroups: req.UserGroups,
ListenPort: req.ListenPort,
}
}
func fromClientExposeResponse(response *mgm.ExposeResponse) *Response {
return &Response{
ServiceName: response.ServiceName,
Domain: response.Domain,
ServiceURL: response.ServiceURL,
PortAutoAssigned: response.PortAutoAssigned,
}
}

View File

@@ -22,56 +22,51 @@ func prepareFd() (int, error) {
func routeCheck(ctx context.Context, fd int, nexthopv4, nexthopv6 systemops.Nexthop) error {
for {
// Wait until fd is readable or context is cancelled, to avoid a busy-loop
// when the routing socket returns EAGAIN (e.g. immediately after wakeup).
if err := waitReadable(ctx, fd); err != nil {
return err
}
buf := make([]byte, 2048)
n, err := unix.Read(fd, buf)
if err != nil {
if errors.Is(err, unix.EAGAIN) || errors.Is(err, unix.EINTR) {
continue
}
if errors.Is(err, unix.EBADF) || errors.Is(err, unix.EINVAL) {
return fmt.Errorf("routing socket closed: %w", err)
}
return fmt.Errorf("read routing socket: %w", err)
}
if n < unix.SizeofRtMsghdr {
log.Debugf("Network monitor: read from routing socket returned less than expected: %d bytes", n)
continue
}
msg := (*unix.RtMsghdr)(unsafe.Pointer(&buf[0]))
switch msg.Type {
// handle route changes
case unix.RTM_ADD, syscall.RTM_DELETE:
route, err := parseRouteMessage(buf[:n])
select {
case <-ctx.Done():
return ctx.Err()
default:
buf := make([]byte, 2048)
n, err := unix.Read(fd, buf)
if err != nil {
log.Debugf("Network monitor: error parsing routing message: %v", err)
if !errors.Is(err, unix.EBADF) && !errors.Is(err, unix.EINVAL) {
log.Warnf("Network monitor: failed to read from routing socket: %v", err)
}
continue
}
if n < unix.SizeofRtMsghdr {
log.Debugf("Network monitor: read from routing socket returned less than expected: %d bytes", n)
continue
}
if route.Dst.Bits() != 0 {
continue
}
msg := (*unix.RtMsghdr)(unsafe.Pointer(&buf[0]))
intf := "<nil>"
if route.Interface != nil {
intf = route.Interface.Name
}
switch msg.Type {
case unix.RTM_ADD:
log.Infof("Network monitor: default route changed: via %s, interface %s", route.Gw, intf)
return nil
case unix.RTM_DELETE:
if nexthopv4.Intf != nil && route.Gw.Compare(nexthopv4.IP) == 0 || nexthopv6.Intf != nil && route.Gw.Compare(nexthopv6.IP) == 0 {
log.Infof("Network monitor: default route removed: via %s, interface %s", route.Gw, intf)
// handle route changes
case unix.RTM_ADD, syscall.RTM_DELETE:
route, err := parseRouteMessage(buf[:n])
if err != nil {
log.Debugf("Network monitor: error parsing routing message: %v", err)
continue
}
if route.Dst.Bits() != 0 {
continue
}
intf := "<nil>"
if route.Interface != nil {
intf = route.Interface.Name
}
switch msg.Type {
case unix.RTM_ADD:
log.Infof("Network monitor: default route changed: via %s, interface %s", route.Gw, intf)
return nil
case unix.RTM_DELETE:
if nexthopv4.Intf != nil && route.Gw.Compare(nexthopv4.IP) == 0 || nexthopv6.Intf != nil && route.Gw.Compare(nexthopv6.IP) == 0 {
log.Infof("Network monitor: default route removed: via %s, interface %s", route.Gw, intf)
return nil
}
}
}
}
@@ -95,33 +90,3 @@ func parseRouteMessage(buf []byte) (*systemops.Route, error) {
return systemops.MsgToRoute(msg)
}
// waitReadable blocks until fd has data to read, or ctx is cancelled.
func waitReadable(ctx context.Context, fd int) error {
var fdset unix.FdSet
if fd < 0 || fd/unix.NFDBITS >= len(fdset.Bits) {
return fmt.Errorf("fd %d out of range for FdSet", fd)
}
for {
if err := ctx.Err(); err != nil {
return err
}
fdset = unix.FdSet{}
fdset.Set(fd)
// Use a 1-second timeout so we can re-check ctx periodically.
tv := unix.Timeval{Sec: 1}
n, err := unix.Select(fd+1, &fdset, nil, nil, &tv)
if err != nil {
if errors.Is(err, unix.EINTR) {
continue
}
return fmt.Errorf("select on routing socket: %w", err)
}
if n > 0 {
return nil
}
// timeout — loop back and re-check ctx
}
}

View File

@@ -3,6 +3,7 @@ package peer
import (
"context"
"fmt"
"math/rand"
"net"
"net/netip"
"runtime"
@@ -21,10 +22,10 @@ import (
icemaker "github.com/netbirdio/netbird/client/internal/peer/ice"
"github.com/netbirdio/netbird/client/internal/peer/id"
"github.com/netbirdio/netbird/client/internal/peer/worker"
"github.com/netbirdio/netbird/client/internal/portforward"
"github.com/netbirdio/netbird/client/internal/stdnet"
"github.com/netbirdio/netbird/route"
relayClient "github.com/netbirdio/netbird/shared/relay/client"
semaphoregroup "github.com/netbirdio/netbird/util/semaphore-group"
)
type ServiceDependencies struct {
@@ -33,8 +34,8 @@ type ServiceDependencies struct {
IFaceDiscover stdnet.ExternalIFaceDiscover
RelayManager *relayClient.Manager
SrWatcher *guard.SRWatcher
Semaphore *semaphoregroup.SemaphoreGroup
PeerConnDispatcher *dispatcher.ConnectionDispatcher
PortForwardManager *portforward.Manager
}
type WgConfig struct {
@@ -76,17 +77,16 @@ type ConnConfig struct {
}
type Conn struct {
Log *log.Entry
mu sync.Mutex
ctx context.Context
ctxCancel context.CancelFunc
config ConnConfig
statusRecorder *Status
signaler *Signaler
iFaceDiscover stdnet.ExternalIFaceDiscover
relayManager *relayClient.Manager
srWatcher *guard.SRWatcher
portForwardManager *portforward.Manager
Log *log.Entry
mu sync.Mutex
ctx context.Context
ctxCancel context.CancelFunc
config ConnConfig
statusRecorder *Status
signaler *Signaler
iFaceDiscover stdnet.ExternalIFaceDiscover
relayManager *relayClient.Manager
srWatcher *guard.SRWatcher
onConnected func(remoteWireGuardKey string, remoteRosenpassPubKey []byte, wireGuardIP string, remoteRosenpassAddr string)
onDisconnected func(remotePeer string)
@@ -111,8 +111,9 @@ type Conn struct {
wgProxyRelay wgproxy.Proxy
handshaker *Handshaker
guard *guard.Guard
wg sync.WaitGroup
guard *guard.Guard
semaphore *semaphoregroup.SemaphoreGroup
wg sync.WaitGroup
// debug purpose
dumpState *stateDump
@@ -131,19 +132,19 @@ func NewConn(config ConnConfig, services ServiceDependencies) (*Conn, error) {
dumpState := newStateDump(config.Key, connLog, services.StatusRecorder)
var conn = &Conn{
Log: connLog,
config: config,
statusRecorder: services.StatusRecorder,
signaler: services.Signaler,
iFaceDiscover: services.IFaceDiscover,
relayManager: services.RelayManager,
srWatcher: services.SrWatcher,
portForwardManager: services.PortForwardManager,
statusRelay: worker.NewAtomicStatus(),
statusICE: worker.NewAtomicStatus(),
dumpState: dumpState,
endpointUpdater: NewEndpointUpdater(connLog, config.WgConfig, isController(config)),
wgWatcher: NewWGWatcher(connLog, config.WgConfig.WgInterface, config.Key, dumpState),
Log: connLog,
config: config,
statusRecorder: services.StatusRecorder,
signaler: services.Signaler,
iFaceDiscover: services.IFaceDiscover,
relayManager: services.RelayManager,
srWatcher: services.SrWatcher,
semaphore: services.Semaphore,
statusRelay: worker.NewAtomicStatus(),
statusICE: worker.NewAtomicStatus(),
dumpState: dumpState,
endpointUpdater: NewEndpointUpdater(connLog, config.WgConfig, isController(config)),
wgWatcher: NewWGWatcher(connLog, config.WgConfig.WgInterface, config.Key, dumpState),
}
return conn, nil
@@ -153,10 +154,15 @@ func NewConn(config ConnConfig, services ServiceDependencies) (*Conn, error) {
// It will try to establish a connection using ICE and in parallel with relay. The higher priority connection type will
// be used.
func (conn *Conn) Open(engineCtx context.Context) error {
if err := conn.semaphore.Add(engineCtx); err != nil {
return err
}
conn.mu.Lock()
defer conn.mu.Unlock()
if conn.opened {
conn.semaphore.Done()
return nil
}
@@ -167,6 +173,7 @@ func (conn *Conn) Open(engineCtx context.Context) error {
relayIsSupportedLocally := conn.workerRelay.RelayIsSupportedLocally()
workerICE, err := NewWorkerICE(conn.ctx, conn.Log, conn.config, conn, conn.signaler, conn.iFaceDiscover, conn.statusRecorder, relayIsSupportedLocally)
if err != nil {
conn.semaphore.Done()
return err
}
conn.workerICE = workerICE
@@ -200,6 +207,10 @@ func (conn *Conn) Open(engineCtx context.Context) error {
conn.wg.Add(1)
go func() {
defer conn.wg.Done()
conn.waitInitialRandomSleepTime(conn.ctx)
conn.semaphore.Done()
conn.guard.Start(conn.ctx, conn.onGuardEvent)
}()
conn.opened = true
@@ -399,7 +410,7 @@ func (conn *Conn) onICEConnectionIsReady(priority conntype.ConnPriority, iceConn
conn.doOnConnected(iceConnInfo.RosenpassPubKey, iceConnInfo.RosenpassAddr)
}
func (conn *Conn) onICEStateDisconnected(sessionChanged bool) {
func (conn *Conn) onICEStateDisconnected() {
conn.mu.Lock()
defer conn.mu.Unlock()
@@ -419,18 +430,14 @@ func (conn *Conn) onICEStateDisconnected(sessionChanged bool) {
if conn.isReadyToUpgrade() {
conn.Log.Infof("ICE disconnected, set Relay to active connection")
conn.dumpState.SwitchToRelay()
if sessionChanged {
conn.resetEndpoint()
}
// todo consider to move after the ConfigureWGEndpoint
conn.wgProxyRelay.Work()
presharedKey := conn.presharedKey(conn.rosenpassRemoteKey)
if err := conn.endpointUpdater.SwitchWGEndpoint(conn.wgProxyRelay.EndpointAddr(), presharedKey); err != nil {
if err := conn.endpointUpdater.ConfigureWGEndpoint(conn.wgProxyRelay.EndpointAddr(), presharedKey); err != nil {
conn.Log.Errorf("failed to switch to relay conn: %v", err)
}
conn.wgProxyRelay.Work()
conn.currentConnPriority = conntype.Relay
} else {
conn.Log.Infof("ICE disconnected, do not switch to Relay. Reset priority to: %s", conntype.None.String())
@@ -492,22 +499,20 @@ func (conn *Conn) onRelayConnectionIsReady(rci RelayConnInfo) {
return
}
controller := isController(conn.config)
wgProxy.Work()
presharedKey := conn.presharedKey(rci.rosenpassPubKey)
if controller {
wgProxy.Work()
}
conn.enableWgWatcherIfNeeded()
if err := conn.endpointUpdater.ConfigureWGEndpoint(wgProxy.EndpointAddr(), conn.presharedKey(rci.rosenpassPubKey)); err != nil {
if err := conn.endpointUpdater.ConfigureWGEndpoint(wgProxy.EndpointAddr(), presharedKey); err != nil {
if err := wgProxy.CloseConn(); err != nil {
conn.Log.Warnf("Failed to close relay connection: %v", err)
}
conn.Log.Errorf("Failed to update WireGuard peer configuration: %v", err)
return
}
if !controller {
wgProxy.Work()
}
wgConfigWorkaround()
conn.rosenpassRemoteKey = rci.rosenpassPubKey
conn.currentConnPriority = conntype.Relay
conn.statusRelay.SetConnected()
@@ -659,6 +664,19 @@ func (conn *Conn) doOnConnected(remoteRosenpassPubKey []byte, remoteRosenpassAdd
}
}
func (conn *Conn) waitInitialRandomSleepTime(ctx context.Context) {
maxWait := 300
duration := time.Duration(rand.Intn(maxWait)) * time.Millisecond
timeout := time.NewTimer(duration)
defer timeout.Stop()
select {
case <-ctx.Done():
case <-timeout.C:
}
}
func (conn *Conn) isRelayed() bool {
switch conn.currentConnPriority {
case conntype.Relay, conntype.ICETurn:
@@ -739,17 +757,6 @@ func (conn *Conn) newProxy(remoteConn net.Conn) (wgproxy.Proxy, error) {
return wgProxy, nil
}
func (conn *Conn) resetEndpoint() {
if !isController(conn.config) {
return
}
conn.Log.Infof("reset wg endpoint")
conn.wgWatcher.Reset()
if err := conn.endpointUpdater.RemoveEndpointAddress(); err != nil {
conn.Log.Warnf("failed to remove endpoint address before update: %v", err)
}
}
func (conn *Conn) isReadyToUpgrade() bool {
return conn.wgProxyRelay != nil && conn.currentConnPriority != conntype.Relay
}
@@ -855,3 +862,9 @@ func isController(config ConnConfig) bool {
func isRosenpassEnabled(remoteRosenpassPubKey []byte) bool {
return remoteRosenpassPubKey != nil
}
// wgConfigWorkaround is a workaround for the issue with WireGuard configuration update
// When update a peer configuration in near to each other time, the second update can be ignored by WireGuard
func wgConfigWorkaround() {
time.Sleep(100 * time.Millisecond)
}

View File

@@ -15,6 +15,7 @@ import (
"github.com/netbirdio/netbird/client/internal/peer/ice"
"github.com/netbirdio/netbird/client/internal/stdnet"
"github.com/netbirdio/netbird/util"
semaphoregroup "github.com/netbirdio/netbird/util/semaphore-group"
)
var testDispatcher = dispatcher.NewConnectionDispatcher()
@@ -52,6 +53,7 @@ func TestConn_GetKey(t *testing.T) {
sd := ServiceDependencies{
SrWatcher: swWatcher,
Semaphore: semaphoregroup.NewSemaphoreGroup(1),
PeerConnDispatcher: testDispatcher,
}
conn, err := NewConn(connConf, sd)
@@ -69,6 +71,7 @@ func TestConn_OnRemoteOffer(t *testing.T) {
sd := ServiceDependencies{
StatusRecorder: NewRecorder("https://mgm"),
SrWatcher: swWatcher,
Semaphore: semaphoregroup.NewSemaphoreGroup(1),
PeerConnDispatcher: testDispatcher,
}
conn, err := NewConn(connConf, sd)
@@ -107,6 +110,7 @@ func TestConn_OnRemoteAnswer(t *testing.T) {
sd := ServiceDependencies{
StatusRecorder: NewRecorder("https://mgm"),
SrWatcher: swWatcher,
Semaphore: semaphoregroup.NewSemaphoreGroup(1),
PeerConnDispatcher: testDispatcher,
}
conn, err := NewConn(connConf, sd)

View File

@@ -34,27 +34,28 @@ func NewEndpointUpdater(log *logrus.Entry, wgConfig WgConfig, initiator bool) *E
}
}
// ConfigureWGEndpoint sets up the WireGuard endpoint configuration.
// The initiator immediately configures the endpoint, while the non-initiator
// waits for a fallback period before configuring to avoid handshake congestion.
func (e *EndpointUpdater) ConfigureWGEndpoint(addr *net.UDPAddr, presharedKey *wgtypes.Key) error {
e.mu.Lock()
defer e.mu.Unlock()
if e.initiator {
e.log.Debugf("configure up WireGuard as initiator")
return e.configureAsInitiator(addr, presharedKey)
e.log.Debugf("configure up WireGuard as initiatr")
return e.updateWireGuardPeer(addr, presharedKey)
}
e.log.Debugf("configure up WireGuard as responder")
return e.configureAsResponder(addr, presharedKey)
}
func (e *EndpointUpdater) SwitchWGEndpoint(addr *net.UDPAddr, presharedKey *wgtypes.Key) error {
e.mu.Lock()
defer e.mu.Unlock()
// prevent to run new update while cancel the previous update
e.waitForCloseTheDelayedUpdate()
return e.updateWireGuardPeer(addr, presharedKey)
var ctx context.Context
ctx, e.cancelFunc = context.WithCancel(context.Background())
e.updateWg.Add(1)
go e.scheduleDelayedUpdate(ctx, addr, presharedKey)
e.log.Debugf("configure up WireGuard and wait for handshake")
return e.updateWireGuardPeer(nil, presharedKey)
}
func (e *EndpointUpdater) RemoveWgPeer() error {
@@ -65,38 +66,6 @@ func (e *EndpointUpdater) RemoveWgPeer() error {
return e.wgConfig.WgInterface.RemovePeer(e.wgConfig.RemoteKey)
}
func (e *EndpointUpdater) RemoveEndpointAddress() error {
e.mu.Lock()
defer e.mu.Unlock()
e.waitForCloseTheDelayedUpdate()
return e.wgConfig.WgInterface.RemoveEndpointAddress(e.wgConfig.RemoteKey)
}
func (e *EndpointUpdater) configureAsInitiator(addr *net.UDPAddr, presharedKey *wgtypes.Key) error {
if err := e.updateWireGuardPeer(addr, presharedKey); err != nil {
return err
}
return nil
}
func (e *EndpointUpdater) configureAsResponder(addr *net.UDPAddr, presharedKey *wgtypes.Key) error {
// prevent to run new update while cancel the previous update
e.waitForCloseTheDelayedUpdate()
e.log.Debugf("configure up WireGuard and wait for handshake")
var ctx context.Context
ctx, e.cancelFunc = context.WithCancel(context.Background())
e.updateWg.Add(1)
go e.scheduleDelayedUpdate(ctx, addr, presharedKey)
if err := e.updateWireGuardPeer(nil, presharedKey); err != nil {
e.waitForCloseTheDelayedUpdate()
return err
}
return nil
}
func (e *EndpointUpdater) waitForCloseTheDelayedUpdate() {
if e.cancelFunc == nil {
return
@@ -132,9 +101,3 @@ func (e *EndpointUpdater) updateWireGuardPeer(endpoint *net.UDPAddr, presharedKe
presharedKey,
)
}
// wgConfigWorkaround is a workaround for the issue with WireGuard configuration update
// When update a peer configuration in near to each other time, the second update can be ignored by WireGuard
func wgConfigWorkaround() {
time.Sleep(100 * time.Millisecond)
}

View File

@@ -32,8 +32,6 @@ type WGWatcher struct {
enabled bool
muEnabled sync.RWMutex
resetCh chan struct{}
}
func NewWGWatcher(log *log.Entry, wgIfaceStater WGInterfaceStater, peerKey string, stateDump *stateDump) *WGWatcher {
@@ -42,7 +40,6 @@ func NewWGWatcher(log *log.Entry, wgIfaceStater WGInterfaceStater, peerKey strin
wgIfaceStater: wgIfaceStater,
peerKey: peerKey,
stateDump: stateDump,
resetCh: make(chan struct{}, 1),
}
}
@@ -79,15 +76,6 @@ func (w *WGWatcher) IsEnabled() bool {
return w.enabled
}
// Reset signals the watcher that the WireGuard peer has been reset and a new
// handshake is expected. This restarts the handshake timeout from scratch.
func (w *WGWatcher) Reset() {
select {
case w.resetCh <- struct{}{}:
default:
}
}
// wgStateCheck help to check the state of the WireGuard handshake and relay connection
func (w *WGWatcher) periodicHandshakeCheck(ctx context.Context, onDisconnectedFn func(), enabledTime time.Time, initialHandshake time.Time) {
w.log.Infof("WireGuard watcher started")
@@ -117,12 +105,6 @@ func (w *WGWatcher) periodicHandshakeCheck(ctx context.Context, onDisconnectedFn
w.stateDump.WGcheckSuccess()
w.log.Debugf("WireGuard watcher reset timer: %v", resetTime)
case <-w.resetCh:
w.log.Infof("WireGuard watcher received peer reset, restarting handshake timeout")
lastHandshake = time.Time{}
enabledTime = time.Now()
timer.Stop()
timer.Reset(wgHandshakeOvertime)
case <-ctx.Done():
w.log.Infof("WireGuard watcher stopped")
return

View File

@@ -16,7 +16,6 @@ import (
"github.com/netbirdio/netbird/client/iface/udpmux"
"github.com/netbirdio/netbird/client/internal/peer/conntype"
icemaker "github.com/netbirdio/netbird/client/internal/peer/ice"
"github.com/netbirdio/netbird/client/internal/portforward"
"github.com/netbirdio/netbird/client/internal/stdnet"
"github.com/netbirdio/netbird/route"
)
@@ -53,18 +52,14 @@ type WorkerICE struct {
// increase by one when disconnecting the agent
// with it the remote peer can discard the already deprecated offer/answer
// Without it the remote peer may recreate a workable ICE connection
sessionID ICESessionID
remoteSessionChanged bool
muxAgent sync.Mutex
sessionID ICESessionID
muxAgent sync.Mutex
localUfrag string
localPwd string
// we record the last known state of the ICE agent to avoid duplicate on disconnected events
lastKnownState ice.ConnectionState
// portForwardAttempted tracks if we've already tried port forwarding this session
portForwardAttempted bool
}
func NewWorkerICE(ctx context.Context, log *log.Entry, config ConnConfig, conn *Conn, signaler *Signaler, ifaceDiscover stdnet.ExternalIFaceDiscover, statusRecorder *Status, hasRelayOnLocally bool) (*WorkerICE, error) {
@@ -111,7 +106,6 @@ func (w *WorkerICE) OnNewOffer(remoteOfferAnswer *OfferAnswer) {
return
}
w.log.Debugf("agent already exists, recreate the connection")
w.remoteSessionChanged = true
w.agentDialerCancel()
if w.agent != nil {
if err := w.agent.Close(); err != nil {
@@ -218,8 +212,6 @@ func (w *WorkerICE) Close() {
}
func (w *WorkerICE) reCreateAgent(dialerCancel context.CancelFunc, candidates []ice.CandidateType) (*icemaker.ThreadSafeAgent, error) {
w.portForwardAttempted = false
agent, err := icemaker.NewAgent(w.ctx, w.iFaceDiscover, w.config.ICEConfig, candidates, w.localUfrag, w.localPwd)
if err != nil {
return nil, fmt.Errorf("create agent: %w", err)
@@ -314,17 +306,13 @@ func (w *WorkerICE) connect(ctx context.Context, agent *icemaker.ThreadSafeAgent
w.conn.onICEConnectionIsReady(selectedPriority(pair), ci)
}
func (w *WorkerICE) closeAgent(agent *icemaker.ThreadSafeAgent, cancel context.CancelFunc) bool {
func (w *WorkerICE) closeAgent(agent *icemaker.ThreadSafeAgent, cancel context.CancelFunc) {
cancel()
if err := agent.Close(); err != nil {
w.log.Warnf("failed to close ICE agent: %s", err)
}
w.muxAgent.Lock()
defer w.muxAgent.Unlock()
sessionChanged := w.remoteSessionChanged
w.remoteSessionChanged = false
if w.agent == agent {
// consider to remove from here and move to the OnNewOffer
@@ -337,7 +325,7 @@ func (w *WorkerICE) closeAgent(agent *icemaker.ThreadSafeAgent, cancel context.C
w.agentConnecting = false
w.remoteSessionID = ""
}
return sessionChanged
w.muxAgent.Unlock()
}
func (w *WorkerICE) punchRemoteWGPort(pair *ice.CandidatePair, remoteWgPort int) {
@@ -376,93 +364,6 @@ func (w *WorkerICE) onICECandidate(candidate ice.Candidate) {
w.log.Errorf("failed signaling candidate to the remote peer %s %s", w.config.Key, err)
}
}()
if candidate.Type() == ice.CandidateTypeServerReflexive {
w.injectPortForwardedCandidate(candidate)
}
}
// injectPortForwardedCandidate signals an additional candidate using the pre-created port mapping.
func (w *WorkerICE) injectPortForwardedCandidate(srflxCandidate ice.Candidate) {
pfManager := w.conn.portForwardManager
if pfManager == nil {
return
}
mapping := pfManager.GetMapping()
if mapping == nil {
return
}
w.muxAgent.Lock()
if w.portForwardAttempted {
w.muxAgent.Unlock()
return
}
w.portForwardAttempted = true
w.muxAgent.Unlock()
forwardedCandidate, err := w.createForwardedCandidate(srflxCandidate, mapping)
if err != nil {
w.log.Warnf("create forwarded candidate: %v", err)
return
}
w.log.Debugf("injecting port-forwarded candidate: %s (mapping: %d -> %d via %s, priority: %d)",
forwardedCandidate.String(), mapping.InternalPort, mapping.ExternalPort, mapping.NATType, forwardedCandidate.Priority())
go func() {
if err := w.signaler.SignalICECandidate(forwardedCandidate, w.config.Key); err != nil {
w.log.Errorf("signal port-forwarded candidate: %v", err)
}
}()
}
// createForwardedCandidate creates a new server reflexive candidate with the forwarded port.
// It uses the NAT gateway's external IP with the forwarded port.
func (w *WorkerICE) createForwardedCandidate(srflxCandidate ice.Candidate, mapping *portforward.Mapping) (ice.Candidate, error) {
var externalIP string
if mapping.ExternalIP != nil && !mapping.ExternalIP.IsUnspecified() {
externalIP = mapping.ExternalIP.String()
} else {
// Fallback to STUN-discovered address if NAT didn't provide external IP
externalIP = srflxCandidate.Address()
}
// Per RFC 8445, the related address for srflx is the base (host candidate address).
// If the original srflx has unspecified related address, use its own address as base.
relAddr := srflxCandidate.RelatedAddress().Address
if relAddr == "" || relAddr == "0.0.0.0" || relAddr == "::" {
relAddr = srflxCandidate.Address()
}
// Arbitrary +1000 boost on top of RFC 8445 priority to favor port-forwarded candidates
// over regular srflx during ICE connectivity checks.
priority := srflxCandidate.Priority() + 1000
candidate, err := ice.NewCandidateServerReflexive(&ice.CandidateServerReflexiveConfig{
Network: srflxCandidate.NetworkType().String(),
Address: externalIP,
Port: int(mapping.ExternalPort),
Component: srflxCandidate.Component(),
Priority: priority,
RelAddr: relAddr,
RelPort: int(mapping.InternalPort),
})
if err != nil {
return nil, fmt.Errorf("create candidate: %w", err)
}
for _, e := range srflxCandidate.Extensions() {
if e.Key == ice.ExtensionKeyCandidateID {
e.Value = srflxCandidate.ID()
}
if err := candidate.AddExtension(e); err != nil {
return nil, fmt.Errorf("add extension: %w", err)
}
}
return candidate, nil
}
func (w *WorkerICE) onICESelectedCandidatePair(agent *icemaker.ThreadSafeAgent, c1, c2 ice.Candidate) {
@@ -504,10 +405,10 @@ func (w *WorkerICE) logSuccessfulPaths(agent *icemaker.ThreadSafeAgent) {
if !lok || !rok {
continue
}
w.log.Debugf("successful ICE path %s: [%s %s %s:%d] <-> [%s %s %s:%d] rtt=%.3fms",
w.log.Debugf("successful ICE path %s: [%s %s %s] <-> [%s %s %s] rtt=%.3fms",
sessionID,
local.NetworkType(), local.Type(), local.Address(), local.Port(),
remote.NetworkType(), remote.Type(), remote.Address(), remote.Port(),
local.NetworkType(), local.Type(), local.Address(),
remote.NetworkType(), remote.Type(), remote.Address(),
stat.CurrentRoundTripTime*1000)
}
}
@@ -525,11 +426,11 @@ func (w *WorkerICE) onConnectionStateChange(agent *icemaker.ThreadSafeAgent, dia
// ice.ConnectionStateClosed happens when we recreate the agent. For the P2P to TURN switch important to
// notify the conn.onICEStateDisconnected changes to update the current used priority
sessionChanged := w.closeAgent(agent, dialerCancel)
w.closeAgent(agent, dialerCancel)
if w.lastKnownState == ice.ConnectionStateConnected {
w.lastKnownState = ice.ConnectionStateDisconnected
w.conn.onICEStateDisconnected(sessionChanged)
w.conn.onICEStateDisconnected()
}
default:
return

View File

@@ -1,35 +0,0 @@
package portforward
import (
"os"
"strconv"
log "github.com/sirupsen/logrus"
)
const (
envDisableNATMapper = "NB_DISABLE_NAT_MAPPER"
envDisablePCPHealthCheck = "NB_DISABLE_PCP_HEALTH_CHECK"
)
func isDisabledByEnv() bool {
return parseBoolEnv(envDisableNATMapper)
}
func isHealthCheckDisabled() bool {
return parseBoolEnv(envDisablePCPHealthCheck)
}
func parseBoolEnv(key string) bool {
val := os.Getenv(key)
if val == "" {
return false
}
disabled, err := strconv.ParseBool(val)
if err != nil {
log.Warnf("failed to parse %s: %v", key, err)
return false
}
return disabled
}

View File

@@ -1,298 +0,0 @@
//go:build !js
package portforward
import (
"context"
"fmt"
"net"
"sync"
"time"
"github.com/libp2p/go-nat"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/portforward/pcp"
)
const (
defaultMappingTTL = 2 * time.Hour
renewalInterval = defaultMappingTTL / 2
healthCheckInterval = 1 * time.Minute
discoveryTimeout = 10 * time.Second
mappingDescription = "NetBird"
)
type Mapping struct {
Protocol string
InternalPort uint16
ExternalPort uint16
ExternalIP net.IP
NATType string
}
type Manager struct {
cancel context.CancelFunc
mapping *Mapping
mappingLock sync.Mutex
wgPort uint16
done chan struct{}
stopCtx chan context.Context
// protect exported functions
mu sync.Mutex
}
func NewManager() *Manager {
return &Manager{
stopCtx: make(chan context.Context, 1),
}
}
func (m *Manager) Start(ctx context.Context, wgPort uint16) {
m.mu.Lock()
if m.cancel != nil {
m.mu.Unlock()
return
}
if isDisabledByEnv() {
log.Infof("NAT port mapper disabled via %s", envDisableNATMapper)
m.mu.Unlock()
return
}
if wgPort == 0 {
log.Warnf("invalid WireGuard port 0; NAT mapping disabled")
m.mu.Unlock()
return
}
m.wgPort = wgPort
m.done = make(chan struct{})
defer close(m.done)
ctx, m.cancel = context.WithCancel(ctx)
m.mu.Unlock()
gateway, mapping, err := m.setup(ctx)
if err != nil {
log.Errorf("failed to setup NAT port mapping: %v", err)
return
}
m.mappingLock.Lock()
m.mapping = mapping
m.mappingLock.Unlock()
m.renewLoop(ctx, gateway)
select {
case cleanupCtx := <-m.stopCtx:
// block the Start while cleaned up gracefully
m.cleanup(cleanupCtx, gateway)
default:
// return Start immediately and cleanup in background
cleanupCtx, cleanupCancel := context.WithTimeout(context.Background(), 10*time.Second)
go func() {
defer cleanupCancel()
m.cleanup(cleanupCtx, gateway)
}()
}
}
// GetMapping returns the current mapping if ready, nil otherwise
func (m *Manager) GetMapping() *Mapping {
m.mappingLock.Lock()
defer m.mappingLock.Unlock()
if m.mapping == nil {
return nil
}
mapping := *m.mapping
return &mapping
}
// GracefullyStop cancels the manager and attempts to delete the port mapping.
// After GracefullyStop returns, the manager cannot be restarted.
func (m *Manager) GracefullyStop(ctx context.Context) error {
m.mu.Lock()
defer m.mu.Unlock()
if m.cancel == nil {
return nil
}
// Send cleanup context before cancelling, so Start picks it up after renewLoop exits.
m.startTearDown(ctx)
m.cancel()
m.cancel = nil
select {
case <-ctx.Done():
return ctx.Err()
case <-m.done:
return nil
}
}
func (m *Manager) setup(ctx context.Context) (nat.NAT, *Mapping, error) {
discoverCtx, discoverCancel := context.WithTimeout(ctx, discoveryTimeout)
defer discoverCancel()
gateway, err := discoverGateway(discoverCtx)
if err != nil {
log.Infof("NAT gateway discovery failed: %v (port forwarding disabled)", err)
return nil, nil, err
}
log.Infof("discovered NAT gateway: %s", gateway.Type())
mapping, err := m.createMapping(ctx, gateway)
if err != nil {
log.Warnf("failed to create port mapping: %v", err)
return nil, nil, err
}
return gateway, mapping, nil
}
func (m *Manager) createMapping(ctx context.Context, gateway nat.NAT) (*Mapping, error) {
ctx, cancel := context.WithTimeout(ctx, 30*time.Second)
defer cancel()
externalPort, err := gateway.AddPortMapping(ctx, "udp", int(m.wgPort), mappingDescription, defaultMappingTTL)
if err != nil {
return nil, err
}
externalIP, err := gateway.GetExternalAddress()
if err != nil {
log.Debugf("failed to get external address: %v", err)
}
mapping := &Mapping{
Protocol: "udp",
InternalPort: m.wgPort,
ExternalPort: uint16(externalPort),
ExternalIP: externalIP,
NATType: gateway.Type(),
}
log.Infof("created port mapping: %d -> %d via %s (external IP: %s)",
m.wgPort, externalPort, gateway.Type(), externalIP)
return mapping, nil
}
func (m *Manager) renewLoop(ctx context.Context, gateway nat.NAT) {
renewTicker := time.NewTicker(renewalInterval)
healthTicker := time.NewTicker(healthCheckInterval)
defer renewTicker.Stop()
defer healthTicker.Stop()
for {
select {
case <-ctx.Done():
return
case <-renewTicker.C:
if err := m.renewMapping(ctx, gateway); err != nil {
log.Warnf("failed to renew port mapping: %v", err)
continue
}
case <-healthTicker.C:
if m.checkHealthAndRecreate(ctx, gateway) {
renewTicker.Reset(renewalInterval)
}
}
}
}
func (m *Manager) checkHealthAndRecreate(ctx context.Context, gateway nat.NAT) bool {
if isHealthCheckDisabled() {
return false
}
m.mappingLock.Lock()
hasMapping := m.mapping != nil
m.mappingLock.Unlock()
if !hasMapping {
return false
}
pcpNAT, ok := gateway.(*pcp.NAT)
if !ok {
return false
}
ctx, cancel := context.WithTimeout(ctx, 10*time.Second)
defer cancel()
epoch, serverRestarted, err := pcpNAT.CheckServerHealth(ctx)
if err != nil {
log.Debugf("PCP health check failed: %v", err)
return false
}
if serverRestarted {
log.Warnf("PCP server restart detected (epoch=%d), recreating port mapping", epoch)
if err := m.renewMapping(ctx, gateway); err != nil {
log.Errorf("failed to recreate port mapping after server restart: %v", err)
return false
}
return true
}
return false
}
func (m *Manager) renewMapping(ctx context.Context, gateway nat.NAT) error {
ctx, cancel := context.WithTimeout(ctx, 30*time.Second)
defer cancel()
externalPort, err := gateway.AddPortMapping(ctx, m.mapping.Protocol, int(m.mapping.InternalPort), mappingDescription, defaultMappingTTL)
if err != nil {
return fmt.Errorf("add port mapping: %w", err)
}
if uint16(externalPort) != m.mapping.ExternalPort {
log.Warnf("external port changed on renewal: %d -> %d (candidate may be stale)", m.mapping.ExternalPort, externalPort)
m.mappingLock.Lock()
m.mapping.ExternalPort = uint16(externalPort)
m.mappingLock.Unlock()
}
log.Debugf("renewed port mapping: %d -> %d", m.mapping.InternalPort, m.mapping.ExternalPort)
return nil
}
func (m *Manager) cleanup(ctx context.Context, gateway nat.NAT) {
m.mappingLock.Lock()
mapping := m.mapping
m.mapping = nil
m.mappingLock.Unlock()
if mapping == nil {
return
}
if err := gateway.DeletePortMapping(ctx, mapping.Protocol, int(mapping.InternalPort)); err != nil {
log.Warnf("delete port mapping on stop: %v", err)
return
}
log.Infof("deleted port mapping for port %d", mapping.InternalPort)
}
func (m *Manager) startTearDown(ctx context.Context) {
select {
case m.stopCtx <- ctx:
default:
}
}

View File

@@ -1,36 +0,0 @@
package portforward
import (
"context"
"net"
)
// Mapping represents port mapping information.
type Mapping struct {
Protocol string
InternalPort uint16
ExternalPort uint16
ExternalIP net.IP
NATType string
}
// Manager is a stub for js/wasm builds where NAT-PMP/UPnP is not supported.
type Manager struct{}
// NewManager returns a stub manager for js/wasm builds.
func NewManager() *Manager {
return &Manager{}
}
// Start is a no-op on js/wasm: NAT-PMP/UPnP is not available in browser environments.
func (m *Manager) Start(context.Context, uint16) {
// no NAT traversal in wasm
}
// GracefullyStop is a no-op on js/wasm.
func (m *Manager) GracefullyStop(context.Context) error { return nil }
// GetMapping always returns nil on js/wasm.
func (m *Manager) GetMapping() *Mapping {
return nil
}

View File

@@ -1,159 +0,0 @@
//go:build !js
package portforward
import (
"context"
"net"
"testing"
"time"
"github.com/libp2p/go-nat"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
type mockNAT struct {
natType string
deviceAddr net.IP
externalAddr net.IP
internalAddr net.IP
mappings map[int]int
addMappingErr error
deleteMappingErr error
}
func newMockNAT() *mockNAT {
return &mockNAT{
natType: "Mock-NAT",
deviceAddr: net.ParseIP("192.168.1.1"),
externalAddr: net.ParseIP("203.0.113.50"),
internalAddr: net.ParseIP("192.168.1.100"),
mappings: make(map[int]int),
}
}
func (m *mockNAT) Type() string {
return m.natType
}
func (m *mockNAT) GetDeviceAddress() (net.IP, error) {
return m.deviceAddr, nil
}
func (m *mockNAT) GetExternalAddress() (net.IP, error) {
return m.externalAddr, nil
}
func (m *mockNAT) GetInternalAddress() (net.IP, error) {
return m.internalAddr, nil
}
func (m *mockNAT) AddPortMapping(ctx context.Context, protocol string, internalPort int, description string, timeout time.Duration) (int, error) {
if m.addMappingErr != nil {
return 0, m.addMappingErr
}
externalPort := internalPort
m.mappings[internalPort] = externalPort
return externalPort, nil
}
func (m *mockNAT) DeletePortMapping(ctx context.Context, protocol string, internalPort int) error {
if m.deleteMappingErr != nil {
return m.deleteMappingErr
}
delete(m.mappings, internalPort)
return nil
}
func TestManager_CreateMapping(t *testing.T) {
m := NewManager()
m.wgPort = 51820
gateway := newMockNAT()
mapping, err := m.createMapping(context.Background(), gateway)
require.NoError(t, err)
require.NotNil(t, mapping)
assert.Equal(t, "udp", mapping.Protocol)
assert.Equal(t, uint16(51820), mapping.InternalPort)
assert.Equal(t, uint16(51820), mapping.ExternalPort)
assert.Equal(t, "Mock-NAT", mapping.NATType)
assert.Equal(t, net.ParseIP("203.0.113.50").To4(), mapping.ExternalIP.To4())
}
func TestManager_GetMapping_ReturnsNilWhenNotReady(t *testing.T) {
m := NewManager()
assert.Nil(t, m.GetMapping())
}
func TestManager_GetMapping_ReturnsCopy(t *testing.T) {
m := NewManager()
m.mapping = &Mapping{
Protocol: "udp",
InternalPort: 51820,
ExternalPort: 51820,
}
mapping := m.GetMapping()
require.NotNil(t, mapping)
assert.Equal(t, uint16(51820), mapping.InternalPort)
// Mutating the returned copy should not affect the manager's mapping.
mapping.ExternalPort = 9999
assert.Equal(t, uint16(51820), m.GetMapping().ExternalPort)
}
func TestManager_Cleanup_DeletesMapping(t *testing.T) {
m := NewManager()
m.mapping = &Mapping{
Protocol: "udp",
InternalPort: 51820,
ExternalPort: 51820,
}
gateway := newMockNAT()
// Seed the mock so we can verify deletion.
gateway.mappings[51820] = 51820
m.cleanup(context.Background(), gateway)
_, exists := gateway.mappings[51820]
assert.False(t, exists, "mapping should be deleted from gateway")
assert.Nil(t, m.GetMapping(), "in-memory mapping should be cleared")
}
func TestManager_Cleanup_NilMapping(t *testing.T) {
m := NewManager()
gateway := newMockNAT()
// Should not panic or call gateway.
m.cleanup(context.Background(), gateway)
}
func TestState_Cleanup(t *testing.T) {
origDiscover := discoverGateway
defer func() { discoverGateway = origDiscover }()
mockGateway := newMockNAT()
mockGateway.mappings[51820] = 51820
discoverGateway = func(ctx context.Context) (nat.NAT, error) {
return mockGateway, nil
}
state := &State{
Protocol: "udp",
InternalPort: 51820,
}
err := state.Cleanup()
assert.NoError(t, err)
_, exists := mockGateway.mappings[51820]
assert.False(t, exists, "mapping should be deleted after cleanup")
}
func TestState_Name(t *testing.T) {
state := &State{}
assert.Equal(t, "port_forward_state", state.Name())
}

View File

@@ -1,408 +0,0 @@
package pcp
import (
"context"
"crypto/rand"
"errors"
"fmt"
"net"
"net/netip"
"sync"
"time"
log "github.com/sirupsen/logrus"
)
const (
defaultTimeout = 3 * time.Second
responseBufferSize = 128
// RFC 6887 Section 8.1.1 retry timing
initialRetryDelay = 3 * time.Second
maxRetryDelay = 1024 * time.Second
maxRetries = 4 // 3s + 6s + 12s + 24s = 45s total worst case
)
// Client is a PCP protocol client.
// All methods are safe for concurrent use.
type Client struct {
gateway netip.Addr
timeout time.Duration
mu sync.Mutex
// localIP caches the resolved local IP address.
localIP netip.Addr
// lastEpoch is the last observed server epoch value.
lastEpoch uint32
// epochTime tracks when lastEpoch was received for state loss detection.
epochTime time.Time
// externalIP caches the external IP from the last successful MAP response.
externalIP netip.Addr
// epochStateLost is set when epoch indicates server restart.
epochStateLost bool
}
// NewClient creates a new PCP client for the gateway at the given IP.
func NewClient(gateway net.IP) *Client {
addr, ok := netip.AddrFromSlice(gateway)
if !ok {
log.Debugf("invalid gateway IP: %v", gateway)
}
return &Client{
gateway: addr.Unmap(),
timeout: defaultTimeout,
}
}
// NewClientWithTimeout creates a new PCP client with a custom timeout.
func NewClientWithTimeout(gateway net.IP, timeout time.Duration) *Client {
addr, ok := netip.AddrFromSlice(gateway)
if !ok {
log.Debugf("invalid gateway IP: %v", gateway)
}
return &Client{
gateway: addr.Unmap(),
timeout: timeout,
}
}
// SetLocalIP sets the local IP address to use in PCP requests.
func (c *Client) SetLocalIP(ip net.IP) {
addr, ok := netip.AddrFromSlice(ip)
if !ok {
log.Debugf("invalid local IP: %v", ip)
}
c.mu.Lock()
c.localIP = addr.Unmap()
c.mu.Unlock()
}
// Gateway returns the gateway IP address.
func (c *Client) Gateway() net.IP {
return c.gateway.AsSlice()
}
// Announce sends a PCP ANNOUNCE request to discover PCP support.
// Returns the server's epoch time on success.
func (c *Client) Announce(ctx context.Context) (epoch uint32, err error) {
localIP, err := c.getLocalIP()
if err != nil {
return 0, fmt.Errorf("get local IP: %w", err)
}
req := buildAnnounceRequest(localIP)
resp, err := c.sendRequest(ctx, req)
if err != nil {
return 0, fmt.Errorf("send announce: %w", err)
}
parsed, err := parseResponse(resp)
if err != nil {
return 0, fmt.Errorf("parse announce response: %w", err)
}
if parsed.ResultCode != ResultSuccess {
return 0, fmt.Errorf("PCP ANNOUNCE failed: %s", ResultCodeString(parsed.ResultCode))
}
c.mu.Lock()
if c.updateEpochLocked(parsed.Epoch) {
log.Warnf("PCP server epoch indicates state loss - mappings may need refresh")
}
c.mu.Unlock()
return parsed.Epoch, nil
}
// AddPortMapping requests a port mapping from the PCP server.
func (c *Client) AddPortMapping(ctx context.Context, protocol string, internalPort int, lifetime time.Duration) (*MapResponse, error) {
return c.addPortMappingWithHint(ctx, protocol, internalPort, internalPort, netip.Addr{}, lifetime)
}
// AddPortMappingWithHint requests a port mapping with suggested external port and IP.
// Use lifetime <= 0 to delete a mapping.
func (c *Client) AddPortMappingWithHint(ctx context.Context, protocol string, internalPort, suggestedExtPort int, suggestedExtIP net.IP, lifetime time.Duration) (*MapResponse, error) {
var extIP netip.Addr
if suggestedExtIP != nil {
var ok bool
extIP, ok = netip.AddrFromSlice(suggestedExtIP)
if !ok {
log.Debugf("invalid suggested external IP: %v", suggestedExtIP)
}
extIP = extIP.Unmap()
}
return c.addPortMappingWithHint(ctx, protocol, internalPort, suggestedExtPort, extIP, lifetime)
}
func (c *Client) addPortMappingWithHint(ctx context.Context, protocol string, internalPort, suggestedExtPort int, suggestedExtIP netip.Addr, lifetime time.Duration) (*MapResponse, error) {
localIP, err := c.getLocalIP()
if err != nil {
return nil, fmt.Errorf("get local IP: %w", err)
}
proto, err := protocolNumber(protocol)
if err != nil {
return nil, fmt.Errorf("parse protocol: %w", err)
}
var nonce [12]byte
if _, err := rand.Read(nonce[:]); err != nil {
return nil, fmt.Errorf("generate nonce: %w", err)
}
// Convert lifetime to seconds. Lifetime 0 means delete, so only apply
// default for positive durations that round to 0 seconds.
var lifetimeSec uint32
if lifetime > 0 {
lifetimeSec = uint32(lifetime.Seconds())
if lifetimeSec == 0 {
lifetimeSec = DefaultLifetime
}
}
req := buildMapRequest(localIP, nonce, proto, uint16(internalPort), uint16(suggestedExtPort), suggestedExtIP, lifetimeSec)
resp, err := c.sendRequest(ctx, req)
if err != nil {
return nil, fmt.Errorf("send map request: %w", err)
}
mapResp, err := parseMapResponse(resp)
if err != nil {
return nil, fmt.Errorf("parse map response: %w", err)
}
if mapResp.Nonce != nonce {
return nil, fmt.Errorf("nonce mismatch in response")
}
if mapResp.Protocol != proto {
return nil, fmt.Errorf("protocol mismatch: requested %d, got %d", proto, mapResp.Protocol)
}
if mapResp.InternalPort != uint16(internalPort) {
return nil, fmt.Errorf("internal port mismatch: requested %d, got %d", internalPort, mapResp.InternalPort)
}
if mapResp.ResultCode != ResultSuccess {
return nil, &Error{
Code: mapResp.ResultCode,
Message: ResultCodeString(mapResp.ResultCode),
}
}
c.mu.Lock()
if c.updateEpochLocked(mapResp.Epoch) {
log.Warnf("PCP server epoch indicates state loss - mappings may need refresh")
}
c.cacheExternalIPLocked(mapResp.ExternalIP)
c.mu.Unlock()
return mapResp, nil
}
// DeletePortMapping removes a port mapping by requesting zero lifetime.
func (c *Client) DeletePortMapping(ctx context.Context, protocol string, internalPort int) error {
if _, err := c.addPortMappingWithHint(ctx, protocol, internalPort, 0, netip.Addr{}, 0); err != nil {
var pcpErr *Error
if errors.As(err, &pcpErr) && pcpErr.Code == ResultNotAuthorized {
return nil
}
return fmt.Errorf("delete mapping: %w", err)
}
return nil
}
// GetExternalAddress returns the external IP address.
// First checks for a cached value from previous MAP responses.
// If not cached, creates a short-lived mapping to discover the external IP.
func (c *Client) GetExternalAddress(ctx context.Context) (net.IP, error) {
c.mu.Lock()
if c.externalIP.IsValid() {
ip := c.externalIP.AsSlice()
c.mu.Unlock()
return ip, nil
}
c.mu.Unlock()
// Use an ephemeral port in the dynamic range (49152-65535).
// Port 0 is not valid with UDP/TCP protocols per RFC 6887.
ephemeralPort := 49152 + int(uint16(time.Now().UnixNano()))%(65535-49152)
// Use minimal lifetime (1 second) for discovery.
resp, err := c.AddPortMapping(ctx, "udp", ephemeralPort, time.Second)
if err != nil {
return nil, fmt.Errorf("create temporary mapping: %w", err)
}
if err := c.DeletePortMapping(ctx, "udp", ephemeralPort); err != nil {
log.Debugf("cleanup temporary PCP mapping: %v", err)
}
return resp.ExternalIP.AsSlice(), nil
}
// LastEpoch returns the last observed server epoch value.
// A decrease in epoch indicates the server may have restarted and mappings may be lost.
func (c *Client) LastEpoch() uint32 {
c.mu.Lock()
defer c.mu.Unlock()
return c.lastEpoch
}
// EpochStateLost returns true if epoch state loss was detected and clears the flag.
func (c *Client) EpochStateLost() bool {
c.mu.Lock()
defer c.mu.Unlock()
lost := c.epochStateLost
c.epochStateLost = false
return lost
}
// updateEpoch updates the epoch tracking and detects potential state loss.
// Returns true if state loss was detected (server likely restarted).
// Caller must hold c.mu.
func (c *Client) updateEpochLocked(newEpoch uint32) bool {
now := time.Now()
stateLost := false
// RFC 6887 Section 8.5: Detect invalid epoch indicating server state loss.
// client_delta = time since last response
// server_delta = epoch change since last response
// Invalid if: client_delta+2 < server_delta - server_delta/16
// OR: server_delta+2 < client_delta - client_delta/16
// The +2 handles quantization, /16 (6.25%) handles clock drift.
if !c.epochTime.IsZero() && c.lastEpoch > 0 {
clientDelta := uint32(now.Sub(c.epochTime).Seconds())
serverDelta := newEpoch - c.lastEpoch
// Check for epoch going backwards or jumping unexpectedly.
// Subtraction is safe: serverDelta/16 is always <= serverDelta.
if clientDelta+2 < serverDelta-(serverDelta/16) ||
serverDelta+2 < clientDelta-(clientDelta/16) {
stateLost = true
c.epochStateLost = true
}
}
c.lastEpoch = newEpoch
c.epochTime = now
return stateLost
}
// cacheExternalIP stores the external IP from a successful MAP response.
// Caller must hold c.mu.
func (c *Client) cacheExternalIPLocked(ip netip.Addr) {
if ip.IsValid() && !ip.IsUnspecified() {
c.externalIP = ip
}
}
// sendRequest sends a PCP request with retries per RFC 6887 Section 8.1.1.
func (c *Client) sendRequest(ctx context.Context, req []byte) ([]byte, error) {
addr := &net.UDPAddr{IP: c.gateway.AsSlice(), Port: Port}
var lastErr error
delay := initialRetryDelay
for range maxRetries {
resp, err := c.sendOnce(ctx, addr, req)
if err == nil {
return resp, nil
}
lastErr = err
if ctx.Err() != nil {
return nil, ctx.Err()
}
// RFC 6887 Section 8.1.1: RT = (1 + RAND) * MIN(2 * RTprev, MRT)
// RAND is random between -0.1 and +0.1
select {
case <-ctx.Done():
return nil, ctx.Err()
case <-time.After(retryDelayWithJitter(delay)):
}
delay = min(delay*2, maxRetryDelay)
}
return nil, fmt.Errorf("PCP request failed after %d retries: %w", maxRetries, lastErr)
}
// retryDelayWithJitter applies RFC 6887 jitter: multiply by (1 + RAND) where RAND is [-0.1, +0.1].
func retryDelayWithJitter(d time.Duration) time.Duration {
var b [1]byte
_, _ = rand.Read(b[:])
// Convert byte to range [-0.1, +0.1]: (b/255 * 0.2) - 0.1
jitter := (float64(b[0])/255.0)*0.2 - 0.1
return time.Duration(float64(d) * (1 + jitter))
}
func (c *Client) sendOnce(ctx context.Context, addr *net.UDPAddr, req []byte) ([]byte, error) {
// Use ListenUDP instead of DialUDP to validate response source address per RFC 6887 §8.3.
conn, err := net.ListenUDP("udp", nil)
if err != nil {
return nil, fmt.Errorf("listen: %w", err)
}
defer func() {
if err := conn.Close(); err != nil {
log.Debugf("close UDP connection: %v", err)
}
}()
timeout := c.timeout
if deadline, ok := ctx.Deadline(); ok {
if remaining := time.Until(deadline); remaining < timeout {
timeout = remaining
}
}
if err := conn.SetDeadline(time.Now().Add(timeout)); err != nil {
return nil, fmt.Errorf("set deadline: %w", err)
}
if _, err := conn.WriteToUDP(req, addr); err != nil {
return nil, fmt.Errorf("write: %w", err)
}
resp := make([]byte, responseBufferSize)
n, from, err := conn.ReadFromUDP(resp)
if err != nil {
return nil, fmt.Errorf("read: %w", err)
}
// RFC 6887 §8.3: Validate response came from expected PCP server.
if !from.IP.Equal(addr.IP) {
return nil, fmt.Errorf("response from unexpected source %s (expected %s)", from.IP, addr.IP)
}
return resp[:n], nil
}
func (c *Client) getLocalIP() (netip.Addr, error) {
c.mu.Lock()
defer c.mu.Unlock()
if !c.localIP.IsValid() {
return netip.Addr{}, fmt.Errorf("local IP not set for gateway %s", c.gateway)
}
return c.localIP, nil
}
func protocolNumber(protocol string) (uint8, error) {
switch protocol {
case "udp", "UDP":
return ProtoUDP, nil
case "tcp", "TCP":
return ProtoTCP, nil
default:
return 0, fmt.Errorf("unsupported protocol: %s", protocol)
}
}
// Error represents a PCP error response.
type Error struct {
Code uint8
Message string
}
func (e *Error) Error() string {
return fmt.Sprintf("PCP error: %s (%d)", e.Message, e.Code)
}

View File

@@ -1,187 +0,0 @@
package pcp
import (
"context"
"net"
"net/netip"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestAddrConversion(t *testing.T) {
tests := []struct {
name string
addr netip.Addr
}{
{"IPv4", netip.MustParseAddr("192.168.1.100")},
{"IPv4 loopback", netip.MustParseAddr("127.0.0.1")},
{"IPv6", netip.MustParseAddr("2001:db8::1")},
{"IPv6 loopback", netip.MustParseAddr("::1")},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
b16 := addrTo16(tt.addr)
recovered := addrFrom16(b16)
assert.Equal(t, tt.addr, recovered, "address should round-trip")
})
}
}
func TestBuildAnnounceRequest(t *testing.T) {
clientIP := netip.MustParseAddr("192.168.1.100")
req := buildAnnounceRequest(clientIP)
require.Len(t, req, headerSize)
assert.Equal(t, byte(Version), req[0], "version")
assert.Equal(t, byte(OpAnnounce), req[1], "opcode")
// Check client IP is properly encoded as IPv4-mapped IPv6
assert.Equal(t, byte(0xff), req[18], "IPv4-mapped prefix byte 10")
assert.Equal(t, byte(0xff), req[19], "IPv4-mapped prefix byte 11")
assert.Equal(t, byte(192), req[20], "IP octet 1")
assert.Equal(t, byte(168), req[21], "IP octet 2")
assert.Equal(t, byte(1), req[22], "IP octet 3")
assert.Equal(t, byte(100), req[23], "IP octet 4")
}
func TestBuildMapRequest(t *testing.T) {
clientIP := netip.MustParseAddr("192.168.1.100")
nonce := [12]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}
req := buildMapRequest(clientIP, nonce, ProtoUDP, 51820, 51820, netip.Addr{}, 3600)
require.Len(t, req, mapRequestSize)
assert.Equal(t, byte(Version), req[0], "version")
assert.Equal(t, byte(OpMap), req[1], "opcode")
// Lifetime at bytes 4-7
assert.Equal(t, uint32(3600), (uint32(req[4])<<24)|(uint32(req[5])<<16)|(uint32(req[6])<<8)|uint32(req[7]), "lifetime")
// Nonce at bytes 24-35
assert.Equal(t, nonce[:], req[24:36], "nonce")
// Protocol at byte 36
assert.Equal(t, byte(ProtoUDP), req[36], "protocol")
// Internal port at bytes 40-41
assert.Equal(t, uint16(51820), (uint16(req[40])<<8)|uint16(req[41]), "internal port")
// External port at bytes 42-43
assert.Equal(t, uint16(51820), (uint16(req[42])<<8)|uint16(req[43]), "external port")
}
func TestParseResponse(t *testing.T) {
// Construct a valid ANNOUNCE response
resp := make([]byte, headerSize)
resp[0] = Version
resp[1] = OpAnnounce | OpReply
// Result code = 0 (success)
// Lifetime = 0
// Epoch = 12345
resp[8] = 0
resp[9] = 0
resp[10] = 0x30
resp[11] = 0x39
parsed, err := parseResponse(resp)
require.NoError(t, err)
assert.Equal(t, uint8(Version), parsed.Version)
assert.Equal(t, uint8(OpAnnounce|OpReply), parsed.Opcode)
assert.Equal(t, uint8(ResultSuccess), parsed.ResultCode)
assert.Equal(t, uint32(12345), parsed.Epoch)
}
func TestParseResponseErrors(t *testing.T) {
t.Run("too short", func(t *testing.T) {
_, err := parseResponse([]byte{1, 2, 3})
assert.Error(t, err)
})
t.Run("wrong version", func(t *testing.T) {
resp := make([]byte, headerSize)
resp[0] = 1 // Wrong version
resp[1] = OpReply
_, err := parseResponse(resp)
assert.Error(t, err)
})
t.Run("missing reply bit", func(t *testing.T) {
resp := make([]byte, headerSize)
resp[0] = Version
resp[1] = OpAnnounce // Missing OpReply bit
_, err := parseResponse(resp)
assert.Error(t, err)
})
}
func TestResultCodeString(t *testing.T) {
assert.Equal(t, "SUCCESS", ResultCodeString(ResultSuccess))
assert.Equal(t, "NOT_AUTHORIZED", ResultCodeString(ResultNotAuthorized))
assert.Equal(t, "ADDRESS_MISMATCH", ResultCodeString(ResultAddressMismatch))
assert.Contains(t, ResultCodeString(255), "UNKNOWN")
}
func TestProtocolNumber(t *testing.T) {
proto, err := protocolNumber("udp")
require.NoError(t, err)
assert.Equal(t, uint8(ProtoUDP), proto)
proto, err = protocolNumber("tcp")
require.NoError(t, err)
assert.Equal(t, uint8(ProtoTCP), proto)
proto, err = protocolNumber("UDP")
require.NoError(t, err)
assert.Equal(t, uint8(ProtoUDP), proto)
_, err = protocolNumber("icmp")
assert.Error(t, err)
}
func TestClientCreation(t *testing.T) {
gateway := netip.MustParseAddr("192.168.1.1").AsSlice()
client := NewClient(gateway)
assert.Equal(t, net.IP(gateway), client.Gateway())
assert.Equal(t, defaultTimeout, client.timeout)
clientWithTimeout := NewClientWithTimeout(gateway, 5*time.Second)
assert.Equal(t, 5*time.Second, clientWithTimeout.timeout)
}
func TestNATType(t *testing.T) {
n := NewNAT(netip.MustParseAddr("192.168.1.1").AsSlice(), netip.MustParseAddr("192.168.1.100").AsSlice())
assert.Equal(t, "PCP", n.Type())
}
// Integration test - skipped unless PCP_TEST_GATEWAY env is set
func TestClientIntegration(t *testing.T) {
t.Skip("Integration test - run manually with PCP_TEST_GATEWAY=<gateway-ip>")
gateway := netip.MustParseAddr("10.0.1.1").AsSlice() // Change to your test gateway
localIP := netip.MustParseAddr("10.0.1.100").AsSlice() // Change to your local IP
client := NewClient(gateway)
client.SetLocalIP(localIP)
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
// Test ANNOUNCE
epoch, err := client.Announce(ctx)
require.NoError(t, err)
t.Logf("Server epoch: %d", epoch)
// Test MAP
resp, err := client.AddPortMapping(ctx, "udp", 51820, 1*time.Hour)
require.NoError(t, err)
t.Logf("Mapping: internal=%d external=%d externalIP=%s",
resp.InternalPort, resp.ExternalPort, resp.ExternalIP)
// Cleanup
err = client.DeletePortMapping(ctx, "udp", 51820)
require.NoError(t, err)
}

View File

@@ -1,209 +0,0 @@
package pcp
import (
"context"
"fmt"
"net"
"net/netip"
"sync"
"time"
log "github.com/sirupsen/logrus"
"github.com/libp2p/go-nat"
"github.com/libp2p/go-netroute"
)
var _ nat.NAT = (*NAT)(nil)
// NAT implements the go-nat NAT interface using PCP.
// Supports dual-stack (IPv4 and IPv6) when available.
// All methods are safe for concurrent use.
//
// TODO: IPv6 pinholes use the local IPv6 address. If the address changes
// (e.g., due to SLAAC rotation or network change), the pinhole becomes stale
// and needs to be recreated with the new address.
type NAT struct {
client *Client
mu sync.RWMutex
// client6 is the IPv6 PCP client, nil if IPv6 is unavailable.
client6 *Client
// localIP6 caches the local IPv6 address used for PCP requests.
localIP6 netip.Addr
}
// NewNAT creates a new NAT instance backed by PCP.
func NewNAT(gateway, localIP net.IP) *NAT {
client := NewClient(gateway)
client.SetLocalIP(localIP)
return &NAT{
client: client,
}
}
// Type returns "PCP" as the NAT type.
func (n *NAT) Type() string {
return "PCP"
}
// GetDeviceAddress returns the gateway IP address.
func (n *NAT) GetDeviceAddress() (net.IP, error) {
return n.client.Gateway(), nil
}
// GetExternalAddress returns the external IP address.
func (n *NAT) GetExternalAddress() (net.IP, error) {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
return n.client.GetExternalAddress(ctx)
}
// GetInternalAddress returns the local IP address used to communicate with the gateway.
func (n *NAT) GetInternalAddress() (net.IP, error) {
addr, err := n.client.getLocalIP()
if err != nil {
return nil, err
}
return addr.AsSlice(), nil
}
// AddPortMapping creates a port mapping on both IPv4 and IPv6 (if available).
func (n *NAT) AddPortMapping(ctx context.Context, protocol string, internalPort int, _ string, timeout time.Duration) (int, error) {
resp, err := n.client.AddPortMapping(ctx, protocol, internalPort, timeout)
if err != nil {
return 0, fmt.Errorf("add mapping: %w", err)
}
n.mu.RLock()
client6 := n.client6
localIP6 := n.localIP6
n.mu.RUnlock()
if client6 == nil {
return int(resp.ExternalPort), nil
}
if _, err := client6.AddPortMapping(ctx, protocol, internalPort, timeout); err != nil {
log.Warnf("IPv6 PCP mapping failed (continuing with IPv4): %v", err)
return int(resp.ExternalPort), nil
}
log.Infof("created IPv6 PCP pinhole: %s:%d", localIP6, internalPort)
return int(resp.ExternalPort), nil
}
// DeletePortMapping removes a port mapping from both IPv4 and IPv6.
func (n *NAT) DeletePortMapping(ctx context.Context, protocol string, internalPort int) error {
err := n.client.DeletePortMapping(ctx, protocol, internalPort)
n.mu.RLock()
client6 := n.client6
n.mu.RUnlock()
if client6 != nil {
if err6 := client6.DeletePortMapping(ctx, protocol, internalPort); err6 != nil {
log.Warnf("IPv6 PCP delete mapping failed: %v", err6)
}
}
if err != nil {
return fmt.Errorf("delete mapping: %w", err)
}
return nil
}
// CheckServerHealth sends an ANNOUNCE to verify the server is still responsive.
// Returns the current epoch and whether the server may have restarted (epoch state loss detected).
func (n *NAT) CheckServerHealth(ctx context.Context) (epoch uint32, serverRestarted bool, err error) {
epoch, err = n.client.Announce(ctx)
if err != nil {
return 0, false, fmt.Errorf("announce: %w", err)
}
return epoch, n.client.EpochStateLost(), nil
}
// DiscoverPCP attempts to discover a PCP-capable gateway.
// Returns a NAT interface if PCP is supported, or an error otherwise.
// Discovers both IPv4 and IPv6 gateways when available.
func DiscoverPCP(ctx context.Context) (nat.NAT, error) {
gateway, localIP, err := getDefaultGateway()
if err != nil {
return nil, fmt.Errorf("get default gateway: %w", err)
}
client := NewClient(gateway)
client.SetLocalIP(localIP)
if _, err := client.Announce(ctx); err != nil {
return nil, fmt.Errorf("PCP announce: %w", err)
}
result := &NAT{client: client}
discoverIPv6(ctx, result)
return result, nil
}
func discoverIPv6(ctx context.Context, result *NAT) {
gateway6, localIP6, err := getDefaultGateway6()
if err != nil {
log.Debugf("IPv6 gateway discovery failed: %v", err)
return
}
client6 := NewClient(gateway6)
client6.SetLocalIP(localIP6)
if _, err := client6.Announce(ctx); err != nil {
log.Debugf("PCP IPv6 announce failed: %v", err)
return
}
addr, ok := netip.AddrFromSlice(localIP6)
if !ok {
log.Debugf("invalid IPv6 local IP: %v", localIP6)
return
}
result.mu.Lock()
result.client6 = client6
result.localIP6 = addr
result.mu.Unlock()
log.Debugf("PCP IPv6 gateway discovered: %s (local: %s)", gateway6, localIP6)
}
// getDefaultGateway returns the default IPv4 gateway and local IP using the system routing table.
func getDefaultGateway() (gateway net.IP, localIP net.IP, err error) {
router, err := netroute.New()
if err != nil {
return nil, nil, err
}
_, gateway, localIP, err = router.Route(net.IPv4zero)
if err != nil {
return nil, nil, err
}
if gateway == nil {
return nil, nil, nat.ErrNoNATFound
}
return gateway, localIP, nil
}
// getDefaultGateway6 returns the default IPv6 gateway IP address using the system routing table.
func getDefaultGateway6() (gateway net.IP, localIP net.IP, err error) {
router, err := netroute.New()
if err != nil {
return nil, nil, err
}
_, gateway, localIP, err = router.Route(net.IPv6zero)
if err != nil {
return nil, nil, err
}
if gateway == nil {
return nil, nil, nat.ErrNoNATFound
}
return gateway, localIP, nil
}

View File

@@ -1,225 +0,0 @@
// Package pcp implements the Port Control Protocol (RFC 6887).
//
// # Implemented Features
//
// - ANNOUNCE opcode: Discovers PCP server support
// - MAP opcode: Creates/deletes port mappings (IPv4 NAT) and firewall pinholes (IPv6)
// - Dual-stack: Simultaneous IPv4 and IPv6 support via separate clients
// - Nonce validation: Prevents response spoofing
// - Epoch tracking: Detects server restarts per Section 8.5
// - RFC-compliant retry timing: 3s initial, exponential backoff to 1024s max (Section 8.1.1)
//
// # Not Implemented
//
// - PEER opcode: For outbound peer connections (not needed for inbound NAT traversal)
// - THIRD_PARTY option: For managing mappings on behalf of other devices
// - PREFER_FAILURE option: Requires exact external port or fail (IPv4 NAT only, not needed for IPv6 pinholing)
// - FILTER option: To restrict remote peer addresses
//
// These optional features are omitted because the primary use case is simple
// port forwarding for WireGuard, which only requires MAP with default behavior.
package pcp
import (
"encoding/binary"
"fmt"
"net/netip"
)
const (
// Version is the PCP protocol version (RFC 6887).
Version = 2
// Port is the standard PCP server port.
Port = 5351
// DefaultLifetime is the default requested mapping lifetime in seconds.
DefaultLifetime = 7200 // 2 hours
// Header sizes
headerSize = 24
mapPayloadSize = 36
mapRequestSize = headerSize + mapPayloadSize // 60 bytes
)
// Opcodes
const (
OpAnnounce = 0
OpMap = 1
OpPeer = 2
OpReply = 0x80 // OR'd with opcode in responses
)
// Protocol numbers for MAP requests
const (
ProtoUDP = 17
ProtoTCP = 6
)
// Result codes (RFC 6887 Section 7.4)
const (
ResultSuccess = 0
ResultUnsuppVersion = 1
ResultNotAuthorized = 2
ResultMalformedRequest = 3
ResultUnsuppOpcode = 4
ResultUnsuppOption = 5
ResultMalformedOption = 6
ResultNetworkFailure = 7
ResultNoResources = 8
ResultUnsuppProtocol = 9
ResultUserExQuota = 10
ResultCannotProvideExt = 11
ResultAddressMismatch = 12
ResultExcessiveRemotePeers = 13
)
// ResultCodeString returns a human-readable string for a result code.
func ResultCodeString(code uint8) string {
switch code {
case ResultSuccess:
return "SUCCESS"
case ResultUnsuppVersion:
return "UNSUPP_VERSION"
case ResultNotAuthorized:
return "NOT_AUTHORIZED"
case ResultMalformedRequest:
return "MALFORMED_REQUEST"
case ResultUnsuppOpcode:
return "UNSUPP_OPCODE"
case ResultUnsuppOption:
return "UNSUPP_OPTION"
case ResultMalformedOption:
return "MALFORMED_OPTION"
case ResultNetworkFailure:
return "NETWORK_FAILURE"
case ResultNoResources:
return "NO_RESOURCES"
case ResultUnsuppProtocol:
return "UNSUPP_PROTOCOL"
case ResultUserExQuota:
return "USER_EX_QUOTA"
case ResultCannotProvideExt:
return "CANNOT_PROVIDE_EXTERNAL"
case ResultAddressMismatch:
return "ADDRESS_MISMATCH"
case ResultExcessiveRemotePeers:
return "EXCESSIVE_REMOTE_PEERS"
default:
return fmt.Sprintf("UNKNOWN(%d)", code)
}
}
// Response represents a parsed PCP response header.
type Response struct {
Version uint8
Opcode uint8
ResultCode uint8
Lifetime uint32
Epoch uint32
}
// MapResponse contains the full response to a MAP request.
type MapResponse struct {
Response
Nonce [12]byte
Protocol uint8
InternalPort uint16
ExternalPort uint16
ExternalIP netip.Addr
}
// addrTo16 converts an address to its 16-byte IPv4-mapped IPv6 representation.
func addrTo16(addr netip.Addr) [16]byte {
if addr.Is4() {
return netip.AddrFrom4(addr.As4()).As16()
}
return addr.As16()
}
// addrFrom16 extracts an address from a 16-byte representation, unmapping IPv4.
func addrFrom16(b [16]byte) netip.Addr {
return netip.AddrFrom16(b).Unmap()
}
// buildAnnounceRequest creates a PCP ANNOUNCE request packet.
func buildAnnounceRequest(clientIP netip.Addr) []byte {
req := make([]byte, headerSize)
req[0] = Version
req[1] = OpAnnounce
mapped := addrTo16(clientIP)
copy(req[8:24], mapped[:])
return req
}
// buildMapRequest creates a PCP MAP request packet.
func buildMapRequest(clientIP netip.Addr, nonce [12]byte, protocol uint8, internalPort, suggestedExtPort uint16, suggestedExtIP netip.Addr, lifetime uint32) []byte {
req := make([]byte, mapRequestSize)
// Header
req[0] = Version
req[1] = OpMap
binary.BigEndian.PutUint32(req[4:8], lifetime)
mapped := addrTo16(clientIP)
copy(req[8:24], mapped[:])
// MAP payload
copy(req[24:36], nonce[:])
req[36] = protocol
binary.BigEndian.PutUint16(req[40:42], internalPort)
binary.BigEndian.PutUint16(req[42:44], suggestedExtPort)
if suggestedExtIP.IsValid() {
extMapped := addrTo16(suggestedExtIP)
copy(req[44:60], extMapped[:])
}
return req
}
// parseResponse parses the common PCP response header.
func parseResponse(data []byte) (*Response, error) {
if len(data) < headerSize {
return nil, fmt.Errorf("response too short: %d bytes", len(data))
}
resp := &Response{
Version: data[0],
Opcode: data[1],
ResultCode: data[3], // Byte 2 is reserved, byte 3 is result code (RFC 6887 §7.2)
Lifetime: binary.BigEndian.Uint32(data[4:8]),
Epoch: binary.BigEndian.Uint32(data[8:12]),
}
if resp.Version != Version {
return nil, fmt.Errorf("unsupported PCP version: %d", resp.Version)
}
if resp.Opcode&OpReply == 0 {
return nil, fmt.Errorf("response missing reply bit: opcode=0x%02x", resp.Opcode)
}
return resp, nil
}
// parseMapResponse parses a complete MAP response.
func parseMapResponse(data []byte) (*MapResponse, error) {
if len(data) < mapRequestSize {
return nil, fmt.Errorf("MAP response too short: %d bytes", len(data))
}
resp, err := parseResponse(data)
if err != nil {
return nil, fmt.Errorf("parse header: %w", err)
}
mapResp := &MapResponse{
Response: *resp,
Protocol: data[36],
InternalPort: binary.BigEndian.Uint16(data[40:42]),
ExternalPort: binary.BigEndian.Uint16(data[42:44]),
ExternalIP: addrFrom16([16]byte(data[44:60])),
}
copy(mapResp.Nonce[:], data[24:36])
return mapResp, nil
}

View File

@@ -1,63 +0,0 @@
//go:build !js
package portforward
import (
"context"
"fmt"
"github.com/libp2p/go-nat"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/portforward/pcp"
)
// discoverGateway is the function used for NAT gateway discovery.
// It can be replaced in tests to avoid real network operations.
// Tries PCP first, then falls back to NAT-PMP/UPnP.
var discoverGateway = defaultDiscoverGateway
func defaultDiscoverGateway(ctx context.Context) (nat.NAT, error) {
pcpGateway, err := pcp.DiscoverPCP(ctx)
if err == nil {
return pcpGateway, nil
}
log.Debugf("PCP discovery failed: %v, trying NAT-PMP/UPnP", err)
return nat.DiscoverGateway(ctx)
}
// State is persisted only for crash recovery cleanup
type State struct {
InternalPort uint16 `json:"internal_port,omitempty"`
Protocol string `json:"protocol,omitempty"`
}
func (s *State) Name() string {
return "port_forward_state"
}
// Cleanup implements statemanager.CleanableState for crash recovery
func (s *State) Cleanup() error {
if s.InternalPort == 0 {
return nil
}
log.Infof("cleaning up stale port mapping for port %d", s.InternalPort)
ctx, cancel := context.WithTimeout(context.Background(), discoveryTimeout)
defer cancel()
gateway, err := discoverGateway(ctx)
if err != nil {
// Discovery failure is not an error - gateway may not exist
log.Debugf("cleanup: no gateway found: %v", err)
return nil
}
if err := gateway.DeletePortMapping(ctx, s.Protocol, int(s.InternalPort)); err != nil {
return fmt.Errorf("delete port mapping: %w", err)
}
return nil
}

View File

@@ -198,7 +198,7 @@ func getConfigDirForUser(username string) (string, error) {
configDir := filepath.Join(DefaultConfigPathDir, username)
if _, err := os.Stat(configDir); os.IsNotExist(err) {
if err := os.MkdirAll(configDir, 0700); err != nil {
if err := os.MkdirAll(configDir, 0600); err != nil {
return "", err
}
}
@@ -206,15 +206,9 @@ func getConfigDirForUser(username string) (string, error) {
return configDir, nil
}
func fileExists(path string) (bool, error) {
func fileExists(path string) bool {
_, err := os.Stat(path)
if err == nil {
return true, nil
}
if os.IsNotExist(err) {
return false, nil
}
return false, err
return !os.IsNotExist(err)
}
// createNewConfig creates a new config generating a new Wireguard key and saving to file
@@ -641,11 +635,7 @@ func isPreSharedKeyHidden(preSharedKey *string) bool {
// UpdateConfig update existing configuration according to input configuration and return with the configuration
func UpdateConfig(input ConfigInput) (*Config, error) {
configExists, err := fileExists(input.ConfigPath)
if err != nil {
return nil, fmt.Errorf("failed to check if config file exists: %w", err)
}
if !configExists {
if !fileExists(input.ConfigPath) {
return nil, fmt.Errorf("config file %s does not exist", input.ConfigPath)
}
@@ -654,11 +644,7 @@ func UpdateConfig(input ConfigInput) (*Config, error) {
// UpdateOrCreateConfig reads existing config or generates a new one
func UpdateOrCreateConfig(input ConfigInput) (*Config, error) {
configExists, err := fileExists(input.ConfigPath)
if err != nil {
return nil, fmt.Errorf("failed to check if config file exists: %w", err)
}
if !configExists {
if !fileExists(input.ConfigPath) {
log.Infof("generating new config %s", input.ConfigPath)
cfg, err := createNewConfig(input)
if err != nil {
@@ -671,7 +657,7 @@ func UpdateOrCreateConfig(input ConfigInput) (*Config, error) {
if isPreSharedKeyHidden(input.PreSharedKey) {
input.PreSharedKey = nil
}
err = util.EnforcePermission(input.ConfigPath)
err := util.EnforcePermission(input.ConfigPath)
if err != nil {
log.Errorf("failed to enforce permission on config dir: %v", err)
}
@@ -798,12 +784,7 @@ func ReadConfig(configPath string) (*Config, error) {
// ReadConfig read config file and return with Config. If it is not exists create a new with default values
func readConfig(configPath string, createIfMissing bool) (*Config, error) {
configExists, err := fileExists(configPath)
if err != nil {
return nil, fmt.Errorf("failed to check if config file exists: %w", err)
}
if configExists {
if fileExists(configPath) {
err := util.EnforcePermission(configPath)
if err != nil {
log.Errorf("failed to enforce permission on config dir: %v", err)
@@ -850,11 +831,7 @@ func DirectWriteOutConfig(path string, config *Config) error {
// DirectUpdateOrCreateConfig is like UpdateOrCreateConfig but uses direct (non-atomic) writes.
// Use this on platforms where atomic writes are blocked (e.g., tvOS sandbox).
func DirectUpdateOrCreateConfig(input ConfigInput) (*Config, error) {
configExists, err := fileExists(input.ConfigPath)
if err != nil {
return nil, fmt.Errorf("failed to check if config file exists: %w", err)
}
if !configExists {
if !fileExists(input.ConfigPath) {
log.Infof("generating new config %s", input.ConfigPath)
cfg, err := createNewConfig(input)
if err != nil {

View File

@@ -256,11 +256,7 @@ func (s *ServiceManager) AddProfile(profileName, username string) error {
}
profPath := filepath.Join(configDir, profileName+".json")
profileExists, err := fileExists(profPath)
if err != nil {
return fmt.Errorf("failed to check if profile exists: %w", err)
}
if profileExists {
if fileExists(profPath) {
return ErrProfileAlreadyExists
}
@@ -289,11 +285,7 @@ func (s *ServiceManager) RemoveProfile(profileName, username string) error {
return fmt.Errorf("cannot remove profile with reserved name: %s", defaultProfileName)
}
profPath := filepath.Join(configDir, profileName+".json")
profileExists, err := fileExists(profPath)
if err != nil {
return fmt.Errorf("failed to check if profile exists: %w", err)
}
if !profileExists {
if !fileExists(profPath) {
return ErrProfileNotFound
}

View File

@@ -20,11 +20,7 @@ func (pm *ProfileManager) GetProfileState(profileName string) (*ProfileState, er
}
stateFile := filepath.Join(configDir, profileName+".state.json")
stateFileExists, err := fileExists(stateFile)
if err != nil {
return nil, fmt.Errorf("failed to check if profile state file exists: %w", err)
}
if !stateFileExists {
if !fileExists(stateFile) {
return nil, errors.New("profile state file does not exist")
}

View File

@@ -263,14 +263,8 @@ func (w *Watcher) watchPeerStatusChanges(ctx context.Context, peerKey string, pe
case <-closer:
return
case routerStates := <-subscription.Events():
select {
case peerStateUpdate <- routerStates:
log.Debugf("triggered route state update for Peer: %s", peerKey)
case <-ctx.Done():
return
case <-closer:
return
}
peerStateUpdate <- routerStates
log.Debugf("triggered route state update for Peer: %s", peerKey)
}
}
}

View File

@@ -351,11 +351,6 @@ func (d *DnsInterceptor) writeMsg(w dns.ResponseWriter, r *dns.Msg, logger *log.
logger.Errorf("failed to update domain prefixes: %v", err)
}
// Allow time for route changes to be applied before sending
// the DNS response (relevant on iOS where setTunnelNetworkSettings
// is asynchronous).
waitForRouteSettlement(logger)
d.replaceIPsInDNSResponse(r, newPrefixes, logger)
}
}

View File

@@ -1,20 +0,0 @@
//go:build ios
package dnsinterceptor
import (
"time"
log "github.com/sirupsen/logrus"
)
const routeSettleDelay = 500 * time.Millisecond
// waitForRouteSettlement introduces a short delay on iOS to allow
// setTunnelNetworkSettings to apply route changes before the DNS
// response reaches the application. Without this, the first request
// to a newly resolved domain may bypass the tunnel.
func waitForRouteSettlement(logger *log.Entry) {
logger.Tracef("waiting %v for iOS route settlement", routeSettleDelay)
time.Sleep(routeSettleDelay)
}

View File

@@ -1,12 +0,0 @@
//go:build !ios
package dnsinterceptor
import log "github.com/sirupsen/logrus"
func waitForRouteSettlement(_ *log.Entry) {
// No-op on non-iOS platforms: route changes are applied synchronously by
// the kernel, so no settlement delay is needed before the DNS response
// reaches the application. The delay is only required on iOS where
// setTunnelNetworkSettings applies routes asynchronously.
}

View File

@@ -346,23 +346,6 @@ func (m *DefaultManager) updateSystemRoutes(newRoutes route.HAMap) error {
}
var merr *multierror.Error
// Begin batch mode to avoid calling applyHostConfig() after each DNS handler operation
batchStarted := false
if m.dnsServer != nil {
m.dnsServer.BeginBatch()
batchStarted = true
defer func() {
if merr != nil {
// On error, cancel batch to discard partial DNS state
m.dnsServer.CancelBatch()
} else {
// On success, apply accumulated DNS changes
m.dnsServer.EndBatch()
}
}()
}
for id, handler := range toRemove {
if err := handler.RemoveRoute(); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove route %s: %w", handler.String(), err))
@@ -393,7 +376,6 @@ func (m *DefaultManager) updateSystemRoutes(newRoutes route.HAMap) error {
m.activeRoutes[id] = handler
}
_ = batchStarted // Mark as used
return nberrors.FormatErrorOrNil(merr)
}

View File

@@ -1,80 +0,0 @@
package handler
import (
"context"
"sync"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal"
)
type Agent interface {
Up(ctx context.Context) error
Down(ctx context.Context) error
Status() (internal.StatusType, error)
}
type SleepHandler struct {
agent Agent
mu sync.Mutex
// sleepTriggeredDown indicates whether the sleep handler triggered the last client down, to avoid unnecessary up on wake
sleepTriggeredDown bool
}
func New(agent Agent) *SleepHandler {
return &SleepHandler{
agent: agent,
}
}
func (s *SleepHandler) HandleWakeUp(ctx context.Context) error {
s.mu.Lock()
defer s.mu.Unlock()
if !s.sleepTriggeredDown {
log.Info("skipping up because wasn't sleep down")
return nil
}
// avoid other wakeup runs if sleep didn't make the computer sleep
s.sleepTriggeredDown = false
log.Info("running up after wake up")
err := s.agent.Up(ctx)
if err != nil {
log.Errorf("running up failed: %v", err)
return err
}
log.Info("running up command executed successfully")
return nil
}
func (s *SleepHandler) HandleSleep(ctx context.Context) error {
s.mu.Lock()
defer s.mu.Unlock()
status, err := s.agent.Status()
if err != nil {
return err
}
if status != internal.StatusConnecting && status != internal.StatusConnected {
log.Infof("skipping setting the agent down because status is %s", status)
return nil
}
log.Info("running down after system started sleeping")
if err = s.agent.Down(ctx); err != nil {
log.Errorf("running down failed: %v", err)
return err
}
s.sleepTriggeredDown = true
log.Info("running down executed successfully")
return nil
}

View File

@@ -1,153 +0,0 @@
package handler
import (
"context"
"errors"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/client/internal"
)
type mockAgent struct {
upErr error
downErr error
statusErr error
status internal.StatusType
upCalls int
}
func (m *mockAgent) Up(_ context.Context) error {
m.upCalls++
return m.upErr
}
func (m *mockAgent) Down(_ context.Context) error {
return m.downErr
}
func (m *mockAgent) Status() (internal.StatusType, error) {
return m.status, m.statusErr
}
func newHandler(status internal.StatusType) (*SleepHandler, *mockAgent) {
agent := &mockAgent{status: status}
return New(agent), agent
}
func TestHandleWakeUp_SkipsWhenFlagFalse(t *testing.T) {
h, agent := newHandler(internal.StatusIdle)
err := h.HandleWakeUp(context.Background())
require.NoError(t, err)
assert.Equal(t, 0, agent.upCalls, "Up should not be called when flag is false")
}
func TestHandleWakeUp_ResetsFlagBeforeUp(t *testing.T) {
h, _ := newHandler(internal.StatusIdle)
h.sleepTriggeredDown = true
// Even if Up fails, flag should be reset
_ = h.HandleWakeUp(context.Background())
assert.False(t, h.sleepTriggeredDown, "flag must be reset before calling Up")
}
func TestHandleWakeUp_CallsUpWhenFlagSet(t *testing.T) {
h, agent := newHandler(internal.StatusIdle)
h.sleepTriggeredDown = true
err := h.HandleWakeUp(context.Background())
require.NoError(t, err)
assert.Equal(t, 1, agent.upCalls)
assert.False(t, h.sleepTriggeredDown)
}
func TestHandleWakeUp_ReturnsErrorFromUp(t *testing.T) {
h, agent := newHandler(internal.StatusIdle)
h.sleepTriggeredDown = true
agent.upErr = errors.New("up failed")
err := h.HandleWakeUp(context.Background())
assert.ErrorIs(t, err, agent.upErr)
assert.False(t, h.sleepTriggeredDown, "flag should still be reset even when Up fails")
}
func TestHandleWakeUp_SecondCallIsNoOp(t *testing.T) {
h, agent := newHandler(internal.StatusIdle)
h.sleepTriggeredDown = true
_ = h.HandleWakeUp(context.Background())
err := h.HandleWakeUp(context.Background())
require.NoError(t, err)
assert.Equal(t, 1, agent.upCalls, "second wakeup should be no-op")
}
func TestHandleSleep_SkipsForNonActiveStates(t *testing.T) {
tests := []struct {
name string
status internal.StatusType
}{
{"Idle", internal.StatusIdle},
{"NeedsLogin", internal.StatusNeedsLogin},
{"LoginFailed", internal.StatusLoginFailed},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
h, _ := newHandler(tt.status)
err := h.HandleSleep(context.Background())
require.NoError(t, err)
assert.False(t, h.sleepTriggeredDown)
})
}
}
func TestHandleSleep_ProceedsForActiveStates(t *testing.T) {
tests := []struct {
name string
status internal.StatusType
}{
{"Connecting", internal.StatusConnecting},
{"Connected", internal.StatusConnected},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
h, _ := newHandler(tt.status)
err := h.HandleSleep(context.Background())
require.NoError(t, err)
assert.True(t, h.sleepTriggeredDown)
})
}
}
func TestHandleSleep_ReturnsErrorFromStatus(t *testing.T) {
agent := &mockAgent{statusErr: errors.New("status error")}
h := New(agent)
err := h.HandleSleep(context.Background())
assert.ErrorIs(t, err, agent.statusErr)
assert.False(t, h.sleepTriggeredDown)
}
func TestHandleSleep_ReturnsErrorFromDown(t *testing.T) {
agent := &mockAgent{status: internal.StatusConnected, downErr: errors.New("down failed")}
h := New(agent)
err := h.HandleSleep(context.Background())
assert.ErrorIs(t, err, agent.downErr)
assert.False(t, h.sleepTriggeredDown, "flag should not be set when Down fails")
}

View File

@@ -1,4 +1,4 @@
// Package updater provides automatic update management for the NetBird client.
// Package updatemanager provides automatic update management for the NetBird client.
// It monitors for new versions, handles update triggers from management server directives,
// and orchestrates the download and installation of client updates.
//
@@ -32,4 +32,4 @@
//
// This enables verification of successful updates and appropriate user notification
// after the client restarts with the new version.
package updater
package updatemanager

View File

@@ -16,8 +16,8 @@ import (
goversion "github.com/hashicorp/go-version"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/updater/downloader"
"github.com/netbirdio/netbird/client/internal/updater/reposign"
"github.com/netbirdio/netbird/client/internal/updatemanager/downloader"
"github.com/netbirdio/netbird/client/internal/updatemanager/reposign"
)
type Installer struct {

View File

@@ -203,10 +203,7 @@ func (rh *ResultHandler) write(result Result) error {
func (rh *ResultHandler) cleanup() error {
err := os.Remove(rh.resultFile)
if err != nil {
if os.IsNotExist(err) {
return nil
}
if err != nil && !os.IsNotExist(err) {
return err
}
log.Debugf("delete installer result file: %s", rh.resultFile)

View File

@@ -1,9 +1,12 @@
package updater
//go:build windows || darwin
package updatemanager
import (
"context"
"errors"
"fmt"
"runtime"
"sync"
"time"
@@ -12,7 +15,7 @@ import (
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/statemanager"
"github.com/netbirdio/netbird/client/internal/updater/installer"
"github.com/netbirdio/netbird/client/internal/updatemanager/installer"
cProto "github.com/netbirdio/netbird/client/proto"
"github.com/netbirdio/netbird/version"
)
@@ -38,9 +41,6 @@ type Manager struct {
statusRecorder *peer.Status
stateManager *statemanager.Manager
downloadOnly bool // true when no enforcement from management; notifies UI to download latest
forceUpdate bool // true when management sets AlwaysUpdate; skips UI interaction and installs directly
lastTrigger time.Time
mgmUpdateChan chan struct{}
updateChannel chan struct{}
@@ -53,38 +53,37 @@ type Manager struct {
expectedVersion *v.Version
updateToLatestVersion bool
pendingVersion *v.Version
// updateMutex protects update, expectedVersion, updateToLatestVersion,
// downloadOnly, forceUpdate, pendingVersion, and lastTrigger fields
// updateMutex protect update and expectedVersion fields
updateMutex sync.Mutex
// installMutex and installing guard against concurrent installation attempts
installMutex sync.Mutex
installing bool
// protect to start the service multiple times
mu sync.Mutex
autoUpdateSupported func() bool
triggerUpdateFn func(context.Context, string) error
}
// NewManager creates a new update manager. The manager is single-use: once Stop() is called, it cannot be restarted.
func NewManager(statusRecorder *peer.Status, stateManager *statemanager.Manager) *Manager {
manager := &Manager{
statusRecorder: statusRecorder,
stateManager: stateManager,
mgmUpdateChan: make(chan struct{}, 1),
updateChannel: make(chan struct{}, 1),
currentVersion: version.NetbirdVersion(),
update: version.NewUpdate("nb/client"),
downloadOnly: true,
autoUpdateSupported: isAutoUpdateSupported,
func NewManager(statusRecorder *peer.Status, stateManager *statemanager.Manager) (*Manager, error) {
if runtime.GOOS == "darwin" {
isBrew := !installer.TypeOfInstaller(context.Background()).Downloadable()
if isBrew {
log.Warnf("auto-update disabled on Home Brew installation")
return nil, fmt.Errorf("auto-update not supported on Home Brew installation yet")
}
}
return newManager(statusRecorder, stateManager)
}
func newManager(statusRecorder *peer.Status, stateManager *statemanager.Manager) (*Manager, error) {
manager := &Manager{
statusRecorder: statusRecorder,
stateManager: stateManager,
mgmUpdateChan: make(chan struct{}, 1),
updateChannel: make(chan struct{}, 1),
currentVersion: version.NetbirdVersion(),
update: version.NewUpdate("nb/client"),
}
manager.triggerUpdateFn = manager.triggerUpdate
stateManager.RegisterState(&UpdateState{})
return manager
return manager, nil
}
// CheckUpdateSuccess checks if the update was successful and send a notification.
@@ -125,10 +124,8 @@ func (m *Manager) CheckUpdateSuccess(ctx context.Context) {
}
func (m *Manager) Start(ctx context.Context) {
log.Infof("starting update manager")
m.mu.Lock()
defer m.mu.Unlock()
if m.cancel != nil {
log.Errorf("Manager already started")
return
}
@@ -145,32 +142,13 @@ func (m *Manager) Start(ctx context.Context) {
m.cancel = cancel
m.wg.Add(1)
go func() {
defer m.wg.Done()
m.updateLoop(ctx)
}()
go m.updateLoop(ctx)
}
func (m *Manager) SetDownloadOnly() {
m.updateMutex.Lock()
m.downloadOnly = true
m.forceUpdate = false
m.expectedVersion = nil
m.updateToLatestVersion = false
m.lastTrigger = time.Time{}
m.updateMutex.Unlock()
select {
case m.mgmUpdateChan <- struct{}{}:
default:
}
}
func (m *Manager) SetVersion(expectedVersion string, forceUpdate bool) {
log.Infof("expected version changed to %s, force update: %t", expectedVersion, forceUpdate)
if !m.autoUpdateSupported() {
log.Warnf("auto-update not supported on this platform")
func (m *Manager) SetVersion(expectedVersion string) {
log.Infof("set expected agent version for upgrade: %s", expectedVersion)
if m.cancel == nil {
log.Errorf("manager not started")
return
}
@@ -181,7 +159,6 @@ func (m *Manager) SetVersion(expectedVersion string, forceUpdate bool) {
log.Errorf("empty expected version provided")
m.expectedVersion = nil
m.updateToLatestVersion = false
m.downloadOnly = true
return
}
@@ -201,97 +178,12 @@ func (m *Manager) SetVersion(expectedVersion string, forceUpdate bool) {
m.updateToLatestVersion = false
}
m.lastTrigger = time.Time{}
m.downloadOnly = false
m.forceUpdate = forceUpdate
select {
case m.mgmUpdateChan <- struct{}{}:
default:
}
}
// Install triggers the installation of the pending version. It is called when the user clicks the install button in the UI.
func (m *Manager) Install(ctx context.Context) error {
if !m.autoUpdateSupported() {
return fmt.Errorf("auto-update not supported on this platform")
}
m.updateMutex.Lock()
pending := m.pendingVersion
m.updateMutex.Unlock()
if pending == nil {
return fmt.Errorf("no pending version to install")
}
return m.tryInstall(ctx, pending)
}
// tryInstall ensures only one installation runs at a time. Concurrent callers
// receive an error immediately rather than queuing behind a running install.
func (m *Manager) tryInstall(ctx context.Context, targetVersion *v.Version) error {
m.installMutex.Lock()
if m.installing {
m.installMutex.Unlock()
return fmt.Errorf("installation already in progress")
}
m.installing = true
m.installMutex.Unlock()
defer func() {
m.installMutex.Lock()
m.installing = false
m.installMutex.Unlock()
}()
return m.install(ctx, targetVersion)
}
// NotifyUI re-publishes the current update state to a newly connected UI client.
// Only needed for download-only mode where the latest version is already cached
// NotifyUI re-publishes the current update state so a newly connected UI gets the info.
func (m *Manager) NotifyUI() {
m.updateMutex.Lock()
if m.update == nil {
m.updateMutex.Unlock()
return
}
downloadOnly := m.downloadOnly
pendingVersion := m.pendingVersion
latestVersion := m.update.LatestVersion()
m.updateMutex.Unlock()
if downloadOnly {
if latestVersion == nil {
return
}
currentVersion, err := v.NewVersion(m.currentVersion)
if err != nil || currentVersion.GreaterThanOrEqual(latestVersion) {
return
}
m.statusRecorder.PublishEvent(
cProto.SystemEvent_INFO,
cProto.SystemEvent_SYSTEM,
"New version available",
"",
map[string]string{"new_version_available": latestVersion.String()},
)
return
}
if pendingVersion != nil {
m.statusRecorder.PublishEvent(
cProto.SystemEvent_INFO,
cProto.SystemEvent_SYSTEM,
"New version available",
"",
map[string]string{"new_version_available": pendingVersion.String(), "enforced": "true"},
)
}
}
// Stop is not used at the moment because it fully depends on the daemon. In a future refactor it may make sense to use it.
func (m *Manager) Stop() {
if m.cancel == nil {
return
@@ -322,6 +214,8 @@ func (m *Manager) onContextCancel() {
}
func (m *Manager) updateLoop(ctx context.Context) {
defer m.wg.Done()
for {
select {
case <-ctx.Done():
@@ -345,89 +239,55 @@ func (m *Manager) handleUpdate(ctx context.Context) {
return
}
downloadOnly := m.downloadOnly
forceUpdate := m.forceUpdate
expectedVersion := m.expectedVersion
useLatest := m.updateToLatestVersion
curLatestVersion := m.update.LatestVersion()
m.updateMutex.Unlock()
switch {
// Download-only mode or resolve "latest" to actual version
case downloadOnly, m.updateToLatestVersion:
// Resolve "latest" to actual version
case useLatest:
if curLatestVersion == nil {
log.Tracef("latest version not fetched yet")
m.updateMutex.Unlock()
return
}
updateVersion = curLatestVersion
// Install to specific version
case m.expectedVersion != nil:
updateVersion = m.expectedVersion
// Update to specific version
case expectedVersion != nil:
updateVersion = expectedVersion
default:
log.Debugf("no expected version information set")
m.updateMutex.Unlock()
return
}
log.Debugf("checking update option, current version: %s, target version: %s", m.currentVersion, updateVersion)
if !m.shouldUpdate(updateVersion, forceUpdate) {
m.updateMutex.Unlock()
if !m.shouldUpdate(updateVersion) {
return
}
m.lastTrigger = time.Now()
log.Infof("new version available: %s", updateVersion)
if !downloadOnly && !forceUpdate {
m.pendingVersion = updateVersion
}
m.updateMutex.Unlock()
if downloadOnly {
m.statusRecorder.PublishEvent(
cProto.SystemEvent_INFO,
cProto.SystemEvent_SYSTEM,
"New version available",
"",
map[string]string{"new_version_available": updateVersion.String()},
)
return
}
if forceUpdate {
if err := m.tryInstall(ctx, updateVersion); err != nil {
log.Errorf("force update failed: %v", err)
}
return
}
m.statusRecorder.PublishEvent(
cProto.SystemEvent_INFO,
cProto.SystemEvent_SYSTEM,
"New version available",
"",
map[string]string{"new_version_available": updateVersion.String(), "enforced": "true"},
)
}
func (m *Manager) install(ctx context.Context, pendingVersion *v.Version) error {
log.Infof("Auto-update triggered, current version: %s, target version: %s", m.currentVersion, updateVersion)
m.statusRecorder.PublishEvent(
cProto.SystemEvent_CRITICAL,
cProto.SystemEvent_SYSTEM,
"Updating client",
"Installing update now.",
"Automatically updating client",
"Your client version is older than auto-update version set in Management, updating client now.",
nil,
)
m.statusRecorder.PublishEvent(
cProto.SystemEvent_CRITICAL,
cProto.SystemEvent_SYSTEM,
"",
"",
map[string]string{"progress_window": "show", "version": pendingVersion.String()},
map[string]string{"progress_window": "show", "version": updateVersion.String()},
)
updateState := UpdateState{
PreUpdateVersion: m.currentVersion,
TargetVersion: pendingVersion.String(),
TargetVersion: updateVersion.String(),
}
if err := m.stateManager.UpdateState(updateState); err != nil {
log.Warnf("failed to update state: %v", err)
} else {
@@ -436,9 +296,8 @@ func (m *Manager) install(ctx context.Context, pendingVersion *v.Version) error
}
}
inst := installer.New()
if err := inst.RunInstallation(ctx, pendingVersion.String()); err != nil {
log.Errorf("error triggering update: %v", err)
if err := m.triggerUpdateFn(ctx, updateVersion.String()); err != nil {
log.Errorf("Error triggering auto-update: %v", err)
m.statusRecorder.PublishEvent(
cProto.SystemEvent_ERROR,
cProto.SystemEvent_SYSTEM,
@@ -446,9 +305,7 @@ func (m *Manager) install(ctx context.Context, pendingVersion *v.Version) error
fmt.Sprintf("Auto-update failed: %v", err),
nil,
)
return err
}
return nil
}
// loadAndDeleteUpdateState loads the update state, deletes it from storage, and returns it.
@@ -482,7 +339,7 @@ func (m *Manager) loadAndDeleteUpdateState(ctx context.Context) (*UpdateState, e
return updateState, nil
}
func (m *Manager) shouldUpdate(updateVersion *v.Version, forceUpdate bool) bool {
func (m *Manager) shouldUpdate(updateVersion *v.Version) bool {
if m.currentVersion == developmentVersion {
log.Debugf("skipping auto-update, running development version")
return false
@@ -497,8 +354,8 @@ func (m *Manager) shouldUpdate(updateVersion *v.Version, forceUpdate bool) bool
return false
}
if forceUpdate && time.Since(m.lastTrigger) < 3*time.Minute {
log.Infof("skipping auto-update, last update was %s ago", time.Since(m.lastTrigger))
if time.Since(m.lastTrigger) < 5*time.Minute {
log.Debugf("skipping auto-update, last update was %s ago", time.Since(m.lastTrigger))
return false
}
@@ -510,3 +367,8 @@ func (m *Manager) lastResultErrReason() string {
result := installer.NewResultHandler(inst.TempDir())
return result.GetErrorResultReason()
}
func (m *Manager) triggerUpdate(ctx context.Context, targetVersion string) error {
inst := installer.New()
return inst.RunInstallation(ctx, targetVersion)
}

View File

@@ -0,0 +1,214 @@
//go:build windows || darwin
package updatemanager
import (
"context"
"fmt"
"path"
"testing"
"time"
v "github.com/hashicorp/go-version"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/statemanager"
)
type versionUpdateMock struct {
latestVersion *v.Version
onUpdate func()
}
func (v versionUpdateMock) StopWatch() {}
func (v versionUpdateMock) SetDaemonVersion(newVersion string) bool {
return false
}
func (v *versionUpdateMock) SetOnUpdateListener(updateFn func()) {
v.onUpdate = updateFn
}
func (v versionUpdateMock) LatestVersion() *v.Version {
return v.latestVersion
}
func (v versionUpdateMock) StartFetcher() {}
func Test_LatestVersion(t *testing.T) {
testMatrix := []struct {
name string
daemonVersion string
initialLatestVersion *v.Version
latestVersion *v.Version
shouldUpdateInit bool
shouldUpdateLater bool
}{
{
name: "Should only trigger update once due to time between triggers being < 5 Minutes",
daemonVersion: "1.0.0",
initialLatestVersion: v.Must(v.NewSemver("1.0.1")),
latestVersion: v.Must(v.NewSemver("1.0.2")),
shouldUpdateInit: true,
shouldUpdateLater: false,
},
{
name: "Shouldn't update initially, but should update as soon as latest version is fetched",
daemonVersion: "1.0.0",
initialLatestVersion: nil,
latestVersion: v.Must(v.NewSemver("1.0.1")),
shouldUpdateInit: false,
shouldUpdateLater: true,
},
}
for idx, c := range testMatrix {
mockUpdate := &versionUpdateMock{latestVersion: c.initialLatestVersion}
tmpFile := path.Join(t.TempDir(), fmt.Sprintf("update-test-%d.json", idx))
m, _ := newManager(peer.NewRecorder(""), statemanager.New(tmpFile))
m.update = mockUpdate
targetVersionChan := make(chan string, 1)
m.triggerUpdateFn = func(ctx context.Context, targetVersion string) error {
targetVersionChan <- targetVersion
return nil
}
m.currentVersion = c.daemonVersion
m.Start(context.Background())
m.SetVersion("latest")
var triggeredInit bool
select {
case targetVersion := <-targetVersionChan:
if targetVersion != c.initialLatestVersion.String() {
t.Errorf("%s: Initial update version mismatch, expected %v, got %v", c.name, c.initialLatestVersion.String(), targetVersion)
}
triggeredInit = true
case <-time.After(10 * time.Millisecond):
triggeredInit = false
}
if triggeredInit != c.shouldUpdateInit {
t.Errorf("%s: Initial update trigger mismatch, expected %v, got %v", c.name, c.shouldUpdateInit, triggeredInit)
}
mockUpdate.latestVersion = c.latestVersion
mockUpdate.onUpdate()
var triggeredLater bool
select {
case targetVersion := <-targetVersionChan:
if targetVersion != c.latestVersion.String() {
t.Errorf("%s: Update version mismatch, expected %v, got %v", c.name, c.latestVersion.String(), targetVersion)
}
triggeredLater = true
case <-time.After(10 * time.Millisecond):
triggeredLater = false
}
if triggeredLater != c.shouldUpdateLater {
t.Errorf("%s: Update trigger mismatch, expected %v, got %v", c.name, c.shouldUpdateLater, triggeredLater)
}
m.Stop()
}
}
func Test_HandleUpdate(t *testing.T) {
testMatrix := []struct {
name string
daemonVersion string
latestVersion *v.Version
expectedVersion string
shouldUpdate bool
}{
{
name: "Update to a specific version should update regardless of if latestVersion is available yet",
daemonVersion: "0.55.0",
latestVersion: nil,
expectedVersion: "0.56.0",
shouldUpdate: true,
},
{
name: "Update to specific version should not update if version matches",
daemonVersion: "0.55.0",
latestVersion: nil,
expectedVersion: "0.55.0",
shouldUpdate: false,
},
{
name: "Update to specific version should not update if current version is newer",
daemonVersion: "0.55.0",
latestVersion: nil,
expectedVersion: "0.54.0",
shouldUpdate: false,
},
{
name: "Update to latest version should update if latest is newer",
daemonVersion: "0.55.0",
latestVersion: v.Must(v.NewSemver("0.56.0")),
expectedVersion: "latest",
shouldUpdate: true,
},
{
name: "Update to latest version should not update if latest == current",
daemonVersion: "0.56.0",
latestVersion: v.Must(v.NewSemver("0.56.0")),
expectedVersion: "latest",
shouldUpdate: false,
},
{
name: "Should not update if daemon version is invalid",
daemonVersion: "development",
latestVersion: v.Must(v.NewSemver("1.0.0")),
expectedVersion: "latest",
shouldUpdate: false,
},
{
name: "Should not update if expecting latest and latest version is unavailable",
daemonVersion: "0.55.0",
latestVersion: nil,
expectedVersion: "latest",
shouldUpdate: false,
},
{
name: "Should not update if expected version is invalid",
daemonVersion: "0.55.0",
latestVersion: nil,
expectedVersion: "development",
shouldUpdate: false,
},
}
for idx, c := range testMatrix {
tmpFile := path.Join(t.TempDir(), fmt.Sprintf("update-test-%d.json", idx))
m, _ := newManager(peer.NewRecorder(""), statemanager.New(tmpFile))
m.update = &versionUpdateMock{latestVersion: c.latestVersion}
targetVersionChan := make(chan string, 1)
m.triggerUpdateFn = func(ctx context.Context, targetVersion string) error {
targetVersionChan <- targetVersion
return nil
}
m.currentVersion = c.daemonVersion
m.Start(context.Background())
m.SetVersion(c.expectedVersion)
var updateTriggered bool
select {
case targetVersion := <-targetVersionChan:
if c.expectedVersion == "latest" && targetVersion != c.latestVersion.String() {
t.Errorf("%s: Update version mismatch, expected %v, got %v", c.name, c.latestVersion.String(), targetVersion)
} else if c.expectedVersion != "latest" && targetVersion != c.expectedVersion {
t.Errorf("%s: Update version mismatch, expected %v, got %v", c.name, c.expectedVersion, targetVersion)
}
updateTriggered = true
case <-time.After(10 * time.Millisecond):
updateTriggered = false
}
if updateTriggered != c.shouldUpdate {
t.Errorf("%s: Update trigger mismatch, expected %v, got %v", c.name, c.shouldUpdate, updateTriggered)
}
m.Stop()
}
}

Some files were not shown because too many files have changed in this diff Show More