Compare commits

..

30 Commits

Author SHA1 Message Date
bcmmbaga
59541b5c20 Merge branch 'main' into update-api-spec 2026-03-25 23:56:33 +03:00
bcmmbaga
75125883df add missing types
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2026-03-25 14:38:17 +03:00
bcmmbaga
d8ca49b6de fix tags
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2026-03-25 14:10:10 +03:00
bcmmbaga
50fc47200a refactor
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2026-03-25 13:59:15 +03:00
bcmmbaga
b75274f7e7 fix casing
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2026-03-25 13:49:58 +03:00
bcmmbaga
fa9c1c3b24 add rest client
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2026-03-25 13:36:16 +03:00
bcmmbaga
86f3e9ec78 Refactor
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2026-03-25 13:20:24 +03:00
Viktor Liu
9aaa05e8ea Replace discontinued LocalStack image with MinIO in S3 test (#5680) 2026-03-25 15:51:29 +08:00
Bethuel Mmbaga
0af5a0441f [management] Fix DNS label uniqueness check on peer rename (#5679) 2026-03-24 20:25:29 +03:00
Viktor Liu
0fc63ea0ba [management] Allow multiple header auths with same header name (#5678) 2026-03-24 16:18:21 +01:00
bcmmbaga
c935639a96 Merge branch 'main' into update-api-spec 2026-03-24 16:17:41 +03:00
Bethuel Mmbaga
0b329f7881 [management] Replace JumpCloud SDK with direct HTTP calls (#5591) 2026-03-24 13:21:42 +03:00
Viktor Liu
5b85edb753 [management] Omit proxy_protocol from API response when false (#5656)
The internal Target model uses a plain bool for ProxyProtocol,
which was always serialized to the API response as false even
when not configured. Only set the API field when true so it
gets omitted via omitempty when unset.
2026-03-23 17:53:17 +01:00
Maycon Santos
17cfa5fe1e [misc] Set signing env only if not fork and set license (#5659)
* Add condition to GPG key decoding to handle pull requests

* Add license field to deb and rpm package configurations

* Add condition to GPG key decoding for external pull requests
2026-03-23 17:16:23 +01:00
Viktor Liu
2313494e0e [client] Don't abort debug for command when up/down fails (#5657) 2026-03-23 14:04:03 +01:00
Viktor Liu
fd9d430334 [client] Simplify entrypoint by running netbird up unconditionally (#5652) 2026-03-23 09:39:32 +01:00
Zoltan Papp
91f0d5cefd [client] Feature/client metrics (#5512)
* Add client metrics

* Add client metrics system with OpenTelemetry and VictoriaMetrics support

Implements a comprehensive client metrics system to track peer connection
stages and performance. The system supports multiple backend implementations
(OpenTelemetry, VictoriaMetrics, and no-op) and tracks detailed connection
stage durations from creation through WireGuard handshake.

Key changes:
- Add metrics package with pluggable backend implementations
- Implement OpenTelemetry metrics backend
- Implement VictoriaMetrics metrics backend
- Add no-op metrics implementation for disabled state
- Track connection stages: creation, semaphore, signaling, connection ready, and WireGuard handshake
- Move WireGuard watcher functionality to conn.go
- Refactor engine to integrate metrics tracking
- Add metrics export endpoint in debug server

* Add signaling metrics tracking for initial and reconnection attempts

* Reset connection stage timestamps during reconnections to exclude unnecessary metrics tracking

* Delete otel lib from client

* Update unit tests

* Invoke callback on handshake success in WireGuard watcher

* Add Netbird version tracking to client metrics

Integrate Netbird version into VictoriaMetrics backend and metrics labels. Update `ClientMetrics` constructor and metric name formatting to include version information.

* Add sync duration tracking to client metrics

Introduce `RecordSyncDuration` for measuring sync message processing time. Update all metrics implementations (VictoriaMetrics, no-op) to support the new method. Refactor `ClientMetrics` to use `AgentInfo` for static agent data.

* Remove no-op metrics implementation and simplify ClientMetrics constructor

Eliminate unused `noopMetrics` and refactor `ClientMetrics` to always use the VictoriaMetrics implementation. Update associated logic to reflect these changes.

* Add total duration tracking for connection attempts

Calculate total duration for both initial connections and reconnections, accounting for different timestamp scenarios. Update `Export` method to include Prometheus HELP comments.

* Add metrics push support to VictoriaMetrics integration

* [client] anchor connection metrics to first signal received

* Remove creation_to_semaphore connection stage metric

The semaphore queuing stage (Created → SemaphoreAcquired) is no longer
tracked. Connection metrics now start from SignalingReceived. Updated
docs and Grafana dashboard accordingly.

* [client] Add remote push config for metrics with version-based eligibility

Introduce remoteconfig.Manager that fetches a remote JSON config to control
metrics push interval and restrict pushing to a specific agent version
range. When NB_METRICS_INTERVAL is set, remote config is bypassed
entirely for local override.

* [client] Add WASM-compatible NewClientMetrics implementation

Replace NewClientMetrics in metrics.go with a WASM-specific stub in metrics_js.go, returning nil for compatibility with JS builds. Simplify method usage for WASM targets.

* Add missing file

* Update default case in DeploymentType.String to return "unknown" instead of "selfhosted"

* [client] Rework metrics to use timestamped samples instead of histograms

Replace cumulative Prometheus histograms with timestamped point-in-time
samples that are pushed once and cleared. This fixes metrics for sparse
events (connections/syncs that happen once at startup) where rate() and
increase() produced incorrect or empty results.

Changes:
- Switch from VictoriaMetrics histogram library to raw Prometheus text
  format with explicit millisecond timestamps
- Reset samples after successful push (no resending stale data)
- Rename connection_to_handshake → connection_to_wg_handshake
- Add netbird_peer_connection_count metric for ICE vs Relay tracking
- Simplify dashboard: point-based scatter plots, donut pie chart
- Add maxStalenessInterval=1m to VictoriaMetrics to prevent forward-fill
- Fix deployment_type Unknown returning "selfhosted" instead of "unknown"
- Fix inverted shouldPush condition in push.go

* [client] Add InfluxDB metrics backend alongside VictoriaMetrics

Add influxdb.go with timestamped line protocol export for sparse
one-shot events. Restore victoria.go to use proper Prometheus
histograms. Update Grafana dashboards, add InfluxDB datasource,
and update docs.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

* [client] Fix metrics issues and update dev docker setup

- Fix StopPush not clearing push state, preventing restart
- Fix race condition reading currentConnPriority without lock in recordConnectionMetrics
- Fix stale comment referencing old metrics server URL
- Update docker-compose for InfluxDB: add scoped tokens, .env config, init scripts
- Rename docker-compose.victoria.yml to docker-compose.yml

* [client] Add anonymised peer tracking to pushed metrics

Introduce peer_id and connection_pair_id tags to InfluxDB metrics.
Public keys are hashed (truncated SHA-256) for anonymisation. The
connection pair ID is deterministic regardless of which side computes
it, enabling deduplication of reconnections in the ICE vs Relay
dashboard. Also pin Grafana to v11.6.0 for file-based provisioning
and fix datasource UID references.

* Remove unused dependencies from go.mod and go.sum

* Refactor InfluxDB ingest pipeline: extract validation logic

- Move line validation logic to `validateLine` and `validateField` helper functions.
- Improve error handling with structured validation and clearer separation of concerns.
- Add stderr redirection for error messages in `create-tokens.sh`.

* Set non-root user in Dockerfile for Ingest service

* Fix Windows CI: command line too long

* Remove Victoria metrics

* Add hashed peer ID as Authorization header in metrics push

* Revert influxdb in docker compose

* Enable gzip compression and authorization validation for metrics push and ingest

* Reducate code of complexity

* Update debug documentation to include metrics.txt description

* Increase `maxBodySize` limit to 50 MB and update gzip reader wrapping logic

* Refactor deployment type detection to use URL parsing for improved accuracy

* Update readme

* Throttle remote config retries on fetch failure

* Preserve first WG handshake timestamp, ignore rekeys

* Skip adding empty metrics.txt to debug bundle in debug mode

* Update default metrics server URL to https://ingest.netbird.io

* Atomic metrics export-and-reset to prevent sample loss between Export and Reset calls

* Fix doc

* Refactor Push configuration to improve clarity and enforce minimum push interval

* Remove `minPushInterval` and update push interval validation logic

* Revert ExportAndReset, it is acceptable data loss

* Fix metrics review issues: rename env var, remove stale infra, add tests

- Rename NB_METRICS_ENABLED to NB_METRICS_PUSH_ENABLED to clarify that
  collection is always active (for debug bundles) and only push is opt-in
- Change default config URL from staging to production (ingest.netbird.io)
- Delete broken Prometheus dashboard (used non-existent metric names)
- Delete unused VictoriaMetrics datasource config
- Replace committed .env with .env.example containing placeholder values
- Wire Grafana admin credentials through env vars in docker-compose
- Make metricsStages a pointer to prevent reset-vs-write race on reconnect
- Fix typed-nil interface in debug bundle path (GetClientMetrics)
- Use deterministic field order in InfluxDB Export (sorted keys)
- Replace Authorization header with X-Peer-ID for metrics push
- Fix ingest server timeout to use time.Second instead of float
- Fix gzip double-close, stale comments, trim log levels
- Add tests for influxdb.go and MetricsStages

* Add login duration metric, ingest tag validation, and duration bounds

- Add netbird_login measurement recording login/auth duration to management
  server, with success/failure result tag
- Validate InfluxDB tags against per-measurement allowlists in ingest server
  to prevent arbitrary tag injection
- Cap all duration fields (*_seconds) at 300s instead of only total_seconds
- Add ingest server tests for tag/field validation, bounds, and auth

* Add arch tag to all metrics

* Fix Grafana dashboard: add arch to drop columns, add login panels

* Validate NB_METRICS_SERVER_URL is an absolute HTTP(S) URL

* Address review comments: fix README wording, update stale comments

* Clarify env var precedence does not bypass remote config eligibility

* Remove accidentally committed pprof files

---------

Co-authored-by: Viktor Liu <viktor@netbird.io>
2026-03-22 12:45:41 +01:00
Viktor Liu
82762280ee [client] Add health check flag to status command and expose daemon status in output (#5650) 2026-03-22 12:39:40 +01:00
Viktor Liu
b550a2face [management, proxy] Add require_subdomain capability for proxy clusters (#5628) 2026-03-20 11:29:50 +01:00
Viktor Liu
ab77508950 [client] Add env var for management gRPC max receive message size (#5622) 2026-03-19 17:33:50 +01:00
Viktor Liu
b9462f5c6b [client] Make raw table initialization non-fatal in firewall managers (#5621) 2026-03-19 17:33:38 +01:00
Viktor Liu
5ffaa5cdd6 [client] Fix duplicate log lines in containers (#5609) 2026-03-19 15:53:05 +01:00
Pascal Fischer
a1858a9cb7 [management] recover proxies after cleanup if heartbeat is still running (#5617) 2026-03-18 11:48:38 +01:00
Viktor Liu
212b34f639 [management] Add GET /reverse-proxies/clusters endpoint (#5611) 2026-03-18 11:15:56 +08:00
Viktor Liu
af8eaa23e2 [client] Restart engine when peer IP address changes (#5614) 2026-03-17 17:00:24 +01:00
bcmmbaga
afe383fa2d add missing edr and idp sync endpoints
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2026-03-17 18:47:29 +03:00
Viktor Liu
f0eed50678 [management] Accept domain target type for L4 reverse proxy services (#5612) 2026-03-17 16:29:03 +01:00
Wouter van Os
19d94c6158 [client] Allow setting DNSLabels on client embed (#5493) 2026-03-17 16:12:37 +01:00
Viktor Liu
628eb56073 [client] Update go-m1cpu to v0.2.0 to fix SIGSEGV on macOS Tahoe (#5613) 2026-03-17 16:10:38 +01:00
eason
a590c38d8b [client] Fix IPv6 address formatting in DNS address construction (#5603)
Replace fmt.Sprintf("%s:%d", ip, port) with net.JoinHostPort() to
properly handle IPv6 addresses that need bracket wrapping (e.g.,
[2606:4700:4700::1111]:53 instead of 2606:4700:4700::1111:53).

Without this fix, configuring IPv6 nameservers causes "too many colons
in address" errors because Go's net.Dial cannot parse the malformed
address string.

Fixes #5601
Related to #4074

Co-authored-by: easonysliu <easonysliu@tencent.com>
2026-03-17 06:27:47 +01:00
122 changed files with 7774 additions and 2296 deletions

View File

@@ -63,10 +63,15 @@ jobs:
- run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe env -w GOMODCACHE=${{ env.cache }}
- run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe env -w GOCACHE=${{ env.modcache }}
- run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe mod tidy
- run: echo "files=$(go list ./... | ForEach-Object { $_ } | Where-Object { $_ -notmatch '/management' } | Where-Object { $_ -notmatch '/relay' } | Where-Object { $_ -notmatch '/signal' } | Where-Object { $_ -notmatch '/proxy' } | Where-Object { $_ -notmatch '/combined' })" >> $env:GITHUB_ENV
- name: Generate test script
run: |
$packages = go list ./... | Where-Object { $_ -notmatch '/management' } | Where-Object { $_ -notmatch '/relay' } | Where-Object { $_ -notmatch '/signal' } | Where-Object { $_ -notmatch '/proxy' } | Where-Object { $_ -notmatch '/combined' }
$goExe = "C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe"
$cmd = "$goExe test -tags=devcert -timeout 10m -p 1 $($packages -join ' ') > test-out.txt 2>&1"
Set-Content -Path "${{ github.workspace }}\run-tests.cmd" -Value $cmd
- name: test
run: PsExec64 -s -w ${{ github.workspace }} cmd.exe /c "C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe test -tags=devcert -timeout 10m -p 1 ${{ env.files }} > test-out.txt 2>&1"
run: PsExec64 -s -w ${{ github.workspace }} cmd.exe /c "${{ github.workspace }}\run-tests.cmd"
- name: test output
if: ${{ always() }}
run: Get-Content test-out.txt

View File

@@ -170,6 +170,7 @@ jobs:
run: sudo apt update && sudo apt install -y -q gcc-arm-linux-gnueabihf gcc-aarch64-linux-gnu
- name: Decode GPG signing key
if: github.event_name != 'pull_request' || github.event.pull_request.head.repo.full_name == github.repository
env:
GPG_RPM_PRIVATE_KEY: ${{ secrets.GPG_RPM_PRIVATE_KEY }}
run: |
@@ -309,6 +310,7 @@ jobs:
run: sudo apt update && sudo apt install -y -q libappindicator3-dev gir1.2-appindicator3-0.1 libxxf86vm-dev gcc-mingw-w64-x86-64
- name: Decode GPG signing key
if: github.event_name != 'pull_request' || github.event.pull_request.head.repo.full_name == github.repository
env:
GPG_RPM_PRIVATE_KEY: ${{ secrets.GPG_RPM_PRIVATE_KEY }}
run: |

View File

@@ -171,6 +171,7 @@ nfpms:
- maintainer: Netbird <dev@netbird.io>
description: Netbird client.
homepage: https://netbird.io/
license: BSD-3-Clause
id: netbird_deb
bindir: /usr/bin
builds:
@@ -184,6 +185,7 @@ nfpms:
- maintainer: Netbird <dev@netbird.io>
description: Netbird client.
homepage: https://netbird.io/
license: BSD-3-Clause
id: netbird_rpm
bindir: /usr/bin
builds:

View File

@@ -17,8 +17,7 @@ ENV \
NETBIRD_BIN="/usr/local/bin/netbird" \
NB_LOG_FILE="console,/var/log/netbird/client.log" \
NB_DAEMON_ADDR="unix:///var/run/netbird.sock" \
NB_ENTRYPOINT_SERVICE_TIMEOUT="5" \
NB_ENTRYPOINT_LOGIN_TIMEOUT="5"
NB_ENTRYPOINT_SERVICE_TIMEOUT="30"
ENTRYPOINT [ "/usr/local/bin/netbird-entrypoint.sh" ]

View File

@@ -23,8 +23,7 @@ ENV \
NB_DAEMON_ADDR="unix:///var/lib/netbird/netbird.sock" \
NB_LOG_FILE="console,/var/lib/netbird/client.log" \
NB_DISABLE_DNS="true" \
NB_ENTRYPOINT_SERVICE_TIMEOUT="5" \
NB_ENTRYPOINT_LOGIN_TIMEOUT="1"
NB_ENTRYPOINT_SERVICE_TIMEOUT="30"
ENTRYPOINT [ "/usr/local/bin/netbird-entrypoint.sh" ]

View File

@@ -181,10 +181,11 @@ func runForDuration(cmd *cobra.Command, args []string) error {
if stateWasDown {
if _, err := client.Up(cmd.Context(), &proto.UpRequest{}); err != nil {
return fmt.Errorf("failed to up: %v", status.Convert(err).Message())
cmd.PrintErrf("Failed to bring service up: %v\n", status.Convert(err).Message())
} else {
cmd.Println("netbird up")
time.Sleep(time.Second * 10)
}
cmd.Println("netbird up")
time.Sleep(time.Second * 10)
}
initialLevelTrace := initialLogLevel.GetLevel() >= proto.LogLevel_TRACE
@@ -199,9 +200,10 @@ func runForDuration(cmd *cobra.Command, args []string) error {
}
if _, err := client.Down(cmd.Context(), &proto.DownRequest{}); err != nil {
return fmt.Errorf("failed to down: %v", status.Convert(err).Message())
cmd.PrintErrf("Failed to bring service down: %v\n", status.Convert(err).Message())
} else {
cmd.Println("netbird down")
}
cmd.Println("netbird down")
time.Sleep(1 * time.Second)
@@ -209,13 +211,14 @@ func runForDuration(cmd *cobra.Command, args []string) error {
if _, err := client.SetSyncResponsePersistence(cmd.Context(), &proto.SetSyncResponsePersistenceRequest{
Enabled: true,
}); err != nil {
return fmt.Errorf("failed to enable sync response persistence: %v", status.Convert(err).Message())
cmd.PrintErrf("Failed to enable sync response persistence: %v\n", status.Convert(err).Message())
}
if _, err := client.Up(cmd.Context(), &proto.UpRequest{}); err != nil {
return fmt.Errorf("failed to up: %v", status.Convert(err).Message())
cmd.PrintErrf("Failed to bring service up: %v\n", status.Convert(err).Message())
} else {
cmd.Println("netbird up")
}
cmd.Println("netbird up")
time.Sleep(3 * time.Second)
@@ -263,16 +266,18 @@ func runForDuration(cmd *cobra.Command, args []string) error {
if stateWasDown {
if _, err := client.Down(cmd.Context(), &proto.DownRequest{}); err != nil {
return fmt.Errorf("failed to down: %v", status.Convert(err).Message())
cmd.PrintErrf("Failed to restore service down state: %v\n", status.Convert(err).Message())
} else {
cmd.Println("netbird down")
}
cmd.Println("netbird down")
}
if !initialLevelTrace {
if _, err := client.SetLogLevel(cmd.Context(), &proto.SetLogLevelRequest{Level: initialLogLevel.GetLevel()}); err != nil {
return fmt.Errorf("failed to restore log level: %v", status.Convert(err).Message())
cmd.PrintErrf("Failed to restore log level: %v\n", status.Convert(err).Message())
} else {
cmd.Println("Log level restored to", initialLogLevel.GetLevel())
}
cmd.Println("Log level restored to", initialLogLevel.GetLevel())
}
cmd.Printf("Local file:\n%s\n", resp.GetPath())

View File

@@ -103,7 +103,7 @@ func (p *program) Stop(srv service.Service) error {
// Common setup for service control commands
func setupServiceControlCommand(cmd *cobra.Command, ctx context.Context, cancel context.CancelFunc) (service.Service, error) {
SetFlagsFromEnvVars(rootCmd)
// rootCmd env vars are already applied by PersistentPreRunE.
SetFlagsFromEnvVars(serviceCmd)
cmd.SetOut(cmd.OutOrStdout())

View File

@@ -28,6 +28,7 @@ var (
ipsFilterMap map[string]struct{}
prefixNamesFilterMap map[string]struct{}
connectionTypeFilter string
checkFlag string
)
var statusCmd = &cobra.Command{
@@ -49,6 +50,7 @@ func init() {
statusCmd.PersistentFlags().StringSliceVar(&prefixNamesFilter, "filter-by-names", []string{}, "filters the detailed output by a list of one or more peer FQDN or hostnames, e.g., --filter-by-names peer-a,peer-b.netbird.cloud")
statusCmd.PersistentFlags().StringVar(&statusFilter, "filter-by-status", "", "filters the detailed output by connection status(idle|connecting|connected), e.g., --filter-by-status connected")
statusCmd.PersistentFlags().StringVar(&connectionTypeFilter, "filter-by-connection-type", "", "filters the detailed output by connection type (P2P|Relayed), e.g., --filter-by-connection-type P2P")
statusCmd.PersistentFlags().StringVar(&checkFlag, "check", "", "run a health check and exit with code 0 on success, 1 on failure (live|ready|startup)")
}
func statusFunc(cmd *cobra.Command, args []string) error {
@@ -56,6 +58,10 @@ func statusFunc(cmd *cobra.Command, args []string) error {
cmd.SetOut(cmd.OutOrStdout())
if checkFlag != "" {
return runHealthCheck(cmd)
}
err := parseFilters()
if err != nil {
return err
@@ -68,15 +74,17 @@ func statusFunc(cmd *cobra.Command, args []string) error {
ctx := internal.CtxInitState(cmd.Context())
resp, err := getStatus(ctx, false)
resp, err := getStatus(ctx, true, false)
if err != nil {
return err
}
status := resp.GetStatus()
if status == string(internal.StatusNeedsLogin) || status == string(internal.StatusLoginFailed) ||
status == string(internal.StatusSessionExpired) {
needsAuth := status == string(internal.StatusNeedsLogin) || status == string(internal.StatusLoginFailed) ||
status == string(internal.StatusSessionExpired)
if needsAuth && !jsonFlag && !yamlFlag {
cmd.Printf("Daemon status: %s\n\n"+
"Run UP command to log in with SSO (interactive login):\n\n"+
" netbird up \n\n"+
@@ -99,7 +107,17 @@ func statusFunc(cmd *cobra.Command, args []string) error {
profName = activeProf.Name
}
var outputInformationHolder = nbstatus.ConvertToStatusOutputOverview(resp.GetFullStatus(), anonymizeFlag, resp.GetDaemonVersion(), statusFilter, prefixNamesFilter, prefixNamesFilterMap, ipsFilterMap, connectionTypeFilter, profName)
var outputInformationHolder = nbstatus.ConvertToStatusOutputOverview(resp.GetFullStatus(), nbstatus.ConvertOptions{
Anonymize: anonymizeFlag,
DaemonVersion: resp.GetDaemonVersion(),
DaemonStatus: nbstatus.ParseDaemonStatus(status),
StatusFilter: statusFilter,
PrefixNamesFilter: prefixNamesFilter,
PrefixNamesFilterMap: prefixNamesFilterMap,
IPsFilter: ipsFilterMap,
ConnectionTypeFilter: connectionTypeFilter,
ProfileName: profName,
})
var statusOutputString string
switch {
case detailFlag:
@@ -121,7 +139,7 @@ func statusFunc(cmd *cobra.Command, args []string) error {
return nil
}
func getStatus(ctx context.Context, shouldRunProbes bool) (*proto.StatusResponse, error) {
func getStatus(ctx context.Context, fullPeerStatus bool, shouldRunProbes bool) (*proto.StatusResponse, error) {
conn, err := DialClientGRPCServer(ctx, daemonAddr)
if err != nil {
//nolint
@@ -131,7 +149,7 @@ func getStatus(ctx context.Context, shouldRunProbes bool) (*proto.StatusResponse
}
defer conn.Close()
resp, err := proto.NewDaemonServiceClient(conn).Status(ctx, &proto.StatusRequest{GetFullPeerStatus: true, ShouldRunProbes: shouldRunProbes})
resp, err := proto.NewDaemonServiceClient(conn).Status(ctx, &proto.StatusRequest{GetFullPeerStatus: fullPeerStatus, ShouldRunProbes: shouldRunProbes})
if err != nil {
return nil, fmt.Errorf("status failed: %v", status.Convert(err).Message())
}
@@ -185,6 +203,83 @@ func enableDetailFlagWhenFilterFlag() {
}
}
func runHealthCheck(cmd *cobra.Command) error {
check := strings.ToLower(checkFlag)
switch check {
case "live", "ready", "startup":
default:
return fmt.Errorf("unknown check %q, must be one of: live, ready, startup", checkFlag)
}
if err := util.InitLog(logLevel, util.LogConsole); err != nil {
return fmt.Errorf("init log: %w", err)
}
ctx := internal.CtxInitState(cmd.Context())
isStartup := check == "startup"
resp, err := getStatus(ctx, isStartup, false)
if err != nil {
return err
}
switch check {
case "live":
return nil
case "ready":
return checkReadiness(resp)
case "startup":
return checkStartup(resp)
default:
return nil
}
}
func checkReadiness(resp *proto.StatusResponse) error {
daemonStatus := internal.StatusType(resp.GetStatus())
switch daemonStatus {
case internal.StatusIdle, internal.StatusConnecting, internal.StatusConnected:
return nil
case internal.StatusNeedsLogin, internal.StatusLoginFailed, internal.StatusSessionExpired:
return fmt.Errorf("readiness check: daemon status is %s", daemonStatus)
default:
return fmt.Errorf("readiness check: unexpected daemon status %q", daemonStatus)
}
}
func checkStartup(resp *proto.StatusResponse) error {
fullStatus := resp.GetFullStatus()
if fullStatus == nil {
return fmt.Errorf("startup check: no full status available")
}
if !fullStatus.GetManagementState().GetConnected() {
return fmt.Errorf("startup check: management not connected")
}
if !fullStatus.GetSignalState().GetConnected() {
return fmt.Errorf("startup check: signal not connected")
}
var relayCount, relaysConnected int
for _, r := range fullStatus.GetRelays() {
uri := r.GetURI()
if !strings.HasPrefix(uri, "rel://") && !strings.HasPrefix(uri, "rels://") {
continue
}
relayCount++
if r.GetAvailable() {
relaysConnected++
}
}
if relayCount > 0 && relaysConnected == 0 {
return fmt.Errorf("startup check: no relay servers available (0/%d connected)", relayCount)
}
return nil
}
func parseInterfaceIP(interfaceIP string) string {
ip, _, err := net.ParseCIDR(interfaceIP)
if err != nil {

View File

@@ -22,6 +22,7 @@ import (
"github.com/netbirdio/netbird/client/internal/profilemanager"
sshcommon "github.com/netbirdio/netbird/client/ssh"
"github.com/netbirdio/netbird/client/system"
"github.com/netbirdio/netbird/shared/management/domain"
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
)
@@ -88,6 +89,8 @@ type Options struct {
// If nil, the existing config MTU (if non-zero) is preserved; otherwise it defaults to 1280.
// Set to a higher value (e.g. 1400) if carrying QUIC or other protocols that require larger datagrams.
MTU *uint16
// DNSLabels defines additional DNS labels configured in the peer.
DNSLabels []string
}
// validateCredentials checks that exactly one credential type is provided
@@ -153,9 +156,14 @@ func New(opts Options) (*Client, error) {
}
}
var err error
var parsedLabels domain.List
if parsedLabels, err = domain.FromStringList(opts.DNSLabels); err != nil {
return nil, fmt.Errorf("invalid dns labels: %w", err)
}
t := true
var config *profilemanager.Config
var err error
input := profilemanager.ConfigInput{
ConfigPath: opts.ConfigPath,
ManagementURL: opts.ManagementURL,
@@ -165,6 +173,7 @@ func New(opts Options) (*Client, error) {
BlockInbound: &opts.BlockInbound,
WireguardPort: opts.WireguardPort,
MTU: opts.MTU,
DNSLabels: parsedLabels,
}
if opts.ConfigPath != "" {
config, err = profilemanager.UpdateOrCreateConfig(input)

View File

@@ -23,9 +23,10 @@ type Manager struct {
wgIface iFaceMapper
ipv4Client *iptables.IPTables
aclMgr *aclManager
router *router
ipv4Client *iptables.IPTables
aclMgr *aclManager
router *router
rawSupported bool
}
// iFaceMapper defines subset methods of interface required for manager
@@ -84,7 +85,7 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error {
}
if err := m.initNoTrackChain(); err != nil {
return fmt.Errorf("init notrack chain: %w", err)
log.Warnf("raw table not available, notrack rules will be disabled: %v", err)
}
// persist early to ensure cleanup of chains
@@ -318,6 +319,10 @@ func (m *Manager) SetupEBPFProxyNoTrack(proxyPort, wgPort uint16) error {
m.mutex.Lock()
defer m.mutex.Unlock()
if !m.rawSupported {
return fmt.Errorf("raw table not available")
}
wgPortStr := fmt.Sprintf("%d", wgPort)
proxyPortStr := fmt.Sprintf("%d", proxyPort)
@@ -375,12 +380,16 @@ func (m *Manager) initNoTrackChain() error {
return fmt.Errorf("add prerouting jump rule: %w", err)
}
m.rawSupported = true
return nil
}
func (m *Manager) cleanupNoTrackChain() error {
exists, err := m.ipv4Client.ChainExists(tableRaw, chainNameRaw)
if err != nil {
if !m.rawSupported {
return nil
}
return fmt.Errorf("check chain exists: %w", err)
}
if !exists {
@@ -401,6 +410,7 @@ func (m *Manager) cleanupNoTrackChain() error {
return fmt.Errorf("clear and delete chain: %w", err)
}
m.rawSupported = false
return nil
}

View File

@@ -95,7 +95,7 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error {
}
if err := m.initNoTrackChains(workTable); err != nil {
return fmt.Errorf("init notrack chains: %w", err)
log.Warnf("raw priority chains not available, notrack rules will be disabled: %v", err)
}
stateManager.RegisterState(&ShutdownState{})

View File

@@ -28,7 +28,7 @@ func Backoff(ctx context.Context) backoff.BackOff {
// CreateConnection creates a gRPC client connection with the appropriate transport options.
// The component parameter specifies the WebSocket proxy component path (e.g., "/management", "/signal").
func CreateConnection(ctx context.Context, addr string, tlsEnabled bool, component string) (*grpc.ClientConn, error) {
func CreateConnection(ctx context.Context, addr string, tlsEnabled bool, component string, extraOpts ...grpc.DialOption) (*grpc.ClientConn, error) {
transportOption := grpc.WithTransportCredentials(insecure.NewCredentials())
// for js, the outer websocket layer takes care of tls
if tlsEnabled && runtime.GOOS != "js" {
@@ -46,9 +46,7 @@ func CreateConnection(ctx context.Context, addr string, tlsEnabled bool, compone
connCtx, cancel := context.WithTimeout(ctx, 30*time.Second)
defer cancel()
conn, err := grpc.DialContext(
connCtx,
addr,
opts := []grpc.DialOption{
transportOption,
WithCustomDialer(tlsEnabled, component),
grpc.WithBlock(),
@@ -56,7 +54,10 @@ func CreateConnection(ctx context.Context, addr string, tlsEnabled bool, compone
Time: 30 * time.Second,
Timeout: 10 * time.Second,
}),
)
}
opts = append(opts, extraOpts...)
conn, err := grpc.DialContext(connCtx, addr, opts...)
if err != nil {
return nil, fmt.Errorf("dial context: %w", err)
}

View File

@@ -23,6 +23,7 @@ import (
"github.com/netbirdio/netbird/client/iface/netstack"
"github.com/netbirdio/netbird/client/internal/dns"
"github.com/netbirdio/netbird/client/internal/listener"
"github.com/netbirdio/netbird/client/internal/metrics"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/profilemanager"
"github.com/netbirdio/netbird/client/internal/statemanager"
@@ -50,6 +51,7 @@ type ConnectClient struct {
engine *Engine
engineMutex sync.Mutex
clientMetrics *metrics.ClientMetrics
updateManager *updater.Manager
persistSyncResponse bool
@@ -133,10 +135,34 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
}
}()
// Stop metrics push on exit
defer func() {
if c.clientMetrics != nil {
c.clientMetrics.StopPush()
}
}()
log.Infof("starting NetBird client version %s on %s/%s", version.NetbirdVersion(), runtime.GOOS, runtime.GOARCH)
nbnet.Init()
// Initialize metrics once at startup (always active for debug bundles)
if c.clientMetrics == nil {
agentInfo := metrics.AgentInfo{
DeploymentType: metrics.DeploymentTypeUnknown,
Version: version.NetbirdVersion(),
OS: runtime.GOOS,
Arch: runtime.GOARCH,
}
c.clientMetrics = metrics.NewClientMetrics(agentInfo)
log.Debugf("initialized client metrics")
// Start metrics push if enabled (uses daemon context, persists across engine restarts)
if metrics.IsMetricsPushEnabled() {
c.clientMetrics.StartPush(c.ctx, metrics.PushConfigFromEnv())
}
}
backOff := &backoff.ExponentialBackOff{
InitialInterval: time.Second,
RandomizationFactor: 1,
@@ -223,6 +249,16 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
mgmNotifier := statusRecorderToMgmConnStateNotifier(c.statusRecorder)
mgmClient.SetConnStateListener(mgmNotifier)
// Update metrics with actual deployment type after connection
deploymentType := metrics.DetermineDeploymentType(mgmClient.GetServerURL())
agentInfo := metrics.AgentInfo{
DeploymentType: deploymentType,
Version: version.NetbirdVersion(),
OS: runtime.GOOS,
Arch: runtime.GOARCH,
}
c.clientMetrics.UpdateAgentInfo(agentInfo, myPrivateKey.PublicKey().String())
log.Debugf("connected to the Management service %s", c.config.ManagementURL.Host)
defer func() {
if err = mgmClient.Close(); err != nil {
@@ -231,8 +267,10 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
}()
// connect (just a connection, no stream yet) and login to Management Service to get an initial global Netbird config
loginStarted := time.Now()
loginResp, err := loginToManagement(engineCtx, mgmClient, publicSSHKey, c.config)
if err != nil {
c.clientMetrics.RecordLoginDuration(engineCtx, time.Since(loginStarted), false)
log.Debug(err)
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.PermissionDenied) {
state.Set(StatusNeedsLogin)
@@ -241,6 +279,7 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
}
return wrapErr(err)
}
c.clientMetrics.RecordLoginDuration(engineCtx, time.Since(loginStarted), true)
c.statusRecorder.MarkManagementConnected()
localPeerState := peer.LocalPeerState{
@@ -317,6 +356,7 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
Checks: checks,
StateManager: stateManager,
UpdateManager: c.updateManager,
ClientMetrics: c.clientMetrics,
}, mobileDependency)
engine.SetSyncResponsePersistence(c.persistSyncResponse)
c.engine = engine

View File

@@ -31,7 +31,6 @@ import (
nbstatus "github.com/netbirdio/netbird/client/status"
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
"github.com/netbirdio/netbird/util"
"github.com/netbirdio/netbird/version"
)
const readmeContent = `Netbird debug bundle
@@ -53,6 +52,7 @@ resolved_domains.txt: Anonymized resolved domain IP addresses from the status re
config.txt: Anonymized configuration information of the NetBird client.
network_map.json: Anonymized sync response containing peer configurations, routes, DNS settings, and firewall rules.
state.json: Anonymized client state dump containing netbird states for the active profile.
metrics.txt: Buffered client metrics in InfluxDB line protocol format. Only present when metrics collection is enabled. Peer identifiers are anonymized.
mutex.prof: Mutex profiling information.
goroutine.prof: Goroutine profiling information.
block.prof: Block profiling information.
@@ -219,6 +219,11 @@ const (
darwinStdoutLogPath = "/var/log/netbird.err.log"
)
// MetricsExporter is an interface for exporting metrics
type MetricsExporter interface {
Export(w io.Writer) error
}
type BundleGenerator struct {
anonymizer *anonymize.Anonymizer
@@ -229,6 +234,7 @@ type BundleGenerator struct {
logPath string
cpuProfile []byte
refreshStatus func() // Optional callback to refresh status before bundle generation
clientMetrics MetricsExporter
anonymize bool
includeSystemInfo bool
@@ -250,6 +256,7 @@ type GeneratorDependencies struct {
LogPath string
CPUProfile []byte
RefreshStatus func() // Optional callback to refresh status before bundle generation
ClientMetrics MetricsExporter
}
func NewBundleGenerator(deps GeneratorDependencies, cfg BundleConfig) *BundleGenerator {
@@ -268,6 +275,7 @@ func NewBundleGenerator(deps GeneratorDependencies, cfg BundleConfig) *BundleGen
logPath: deps.LogPath,
cpuProfile: deps.CPUProfile,
refreshStatus: deps.RefreshStatus,
clientMetrics: deps.ClientMetrics,
anonymize: cfg.Anonymize,
includeSystemInfo: cfg.IncludeSystemInfo,
@@ -351,6 +359,10 @@ func (g *BundleGenerator) createArchive() error {
log.Errorf("failed to add corrupted state files to debug bundle: %v", err)
}
if err := g.addMetrics(); err != nil {
log.Errorf("failed to add metrics to debug bundle: %v", err)
}
if err := g.addWgShow(); err != nil {
log.Errorf("failed to add wg show output: %v", err)
}
@@ -418,7 +430,10 @@ func (g *BundleGenerator) addStatus() error {
fullStatus := g.statusRecorder.GetFullStatus()
protoFullStatus := nbstatus.ToProtoFullStatus(fullStatus)
protoFullStatus.Events = g.statusRecorder.GetEventHistory()
overview := nbstatus.ConvertToStatusOutputOverview(protoFullStatus, g.anonymize, version.NetbirdVersion(), "", nil, nil, nil, "", profName)
overview := nbstatus.ConvertToStatusOutputOverview(protoFullStatus, nbstatus.ConvertOptions{
Anonymize: g.anonymize,
ProfileName: profName,
})
statusOutput := overview.FullDetailSummary()
statusReader := strings.NewReader(statusOutput)
@@ -744,6 +759,30 @@ func (g *BundleGenerator) addCorruptedStateFiles() error {
return nil
}
func (g *BundleGenerator) addMetrics() error {
if g.clientMetrics == nil {
log.Debugf("skipping metrics in debug bundle: no metrics collector")
return nil
}
var buf bytes.Buffer
if err := g.clientMetrics.Export(&buf); err != nil {
return fmt.Errorf("export metrics: %w", err)
}
if buf.Len() == 0 {
log.Debugf("skipping metrics.txt in debug bundle: no metrics data")
return nil
}
if err := g.addFileToZip(&buf, "metrics.txt"); err != nil {
return fmt.Errorf("add metrics file to zip: %w", err)
}
log.Debugf("added metrics to debug bundle")
return nil
}
func (g *BundleGenerator) addLogfile() error {
if g.logPath == "" {
log.Debugf("skipping empty log file in debug bundle")

View File

@@ -1263,9 +1263,9 @@ func TestLocalResolver_AuthoritativeFlag(t *testing.T) {
})
}
// TestLocalResolver_Stop tests cleanup on GracefullyStop
// TestLocalResolver_Stop tests cleanup on Stop
func TestLocalResolver_Stop(t *testing.T) {
t.Run("GracefullyStop clears all state", func(t *testing.T) {
t.Run("Stop clears all state", func(t *testing.T) {
resolver := NewResolver()
resolver.Update([]nbdns.CustomZone{{
Domain: "example.com.",
@@ -1285,7 +1285,7 @@ func TestLocalResolver_Stop(t *testing.T) {
assert.False(t, resolver.isInManagedZone("host.example.com."))
})
t.Run("GracefullyStop is safe to call multiple times", func(t *testing.T) {
t.Run("Stop is safe to call multiple times", func(t *testing.T) {
resolver := NewResolver()
resolver.Update([]nbdns.CustomZone{{
Domain: "example.com.",
@@ -1299,7 +1299,7 @@ func TestLocalResolver_Stop(t *testing.T) {
resolver.Stop()
})
t.Run("GracefullyStop cancels in-flight external resolution", func(t *testing.T) {
t.Run("Stop cancels in-flight external resolution", func(t *testing.T) {
resolver := NewResolver()
lookupStarted := make(chan struct{})

View File

@@ -6,6 +6,7 @@ import (
"net"
"net/netip"
"runtime"
"strconv"
"sync"
"time"
@@ -69,7 +70,7 @@ func (s *serviceViaListener) Listen() error {
return fmt.Errorf("eval listen address: %w", err)
}
s.listenIP = s.listenIP.Unmap()
s.server.Addr = fmt.Sprintf("%s:%d", s.listenIP, s.listenPort)
s.server.Addr = net.JoinHostPort(s.listenIP.String(), strconv.Itoa(int(s.listenPort)))
log.Debugf("starting dns on %s", s.server.Addr)
go func() {
s.setListenerStatus(true)
@@ -186,7 +187,7 @@ func (s *serviceViaListener) testFreePort(port int) (netip.Addr, bool) {
}
func (s *serviceViaListener) tryToBind(ip netip.Addr, port int) bool {
addrString := fmt.Sprintf("%s:%d", ip, port)
addrString := net.JoinHostPort(ip.String(), strconv.Itoa(port))
udpAddr := net.UDPAddrFromAddrPort(netip.MustParseAddrPort(addrString))
probeListener, err := net.ListenUDP("udp", udpAddr)
if err != nil {

View File

@@ -38,6 +38,7 @@ import (
"github.com/netbirdio/netbird/client/internal/dnsfwd"
"github.com/netbirdio/netbird/client/internal/expose"
"github.com/netbirdio/netbird/client/internal/ingressgw"
"github.com/netbirdio/netbird/client/internal/metrics"
"github.com/netbirdio/netbird/client/internal/netflow"
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
"github.com/netbirdio/netbird/client/internal/networkmonitor"
@@ -45,7 +46,6 @@ import (
"github.com/netbirdio/netbird/client/internal/peer/guard"
icemaker "github.com/netbirdio/netbird/client/internal/peer/ice"
"github.com/netbirdio/netbird/client/internal/peerstore"
"github.com/netbirdio/netbird/client/internal/portforward"
"github.com/netbirdio/netbird/client/internal/profilemanager"
"github.com/netbirdio/netbird/client/internal/relay"
"github.com/netbirdio/netbird/client/internal/rosenpass"
@@ -150,6 +150,7 @@ type EngineServices struct {
Checks []*mgmProto.Checks
StateManager *statemanager.Manager
UpdateManager *updater.Manager
ClientMetrics *metrics.ClientMetrics
}
// Engine is a mechanism responsible for reacting on Signal and Management stream events and managing connections to the remote peers.
@@ -209,10 +210,9 @@ type Engine struct {
// checks are the client-applied posture checks that need to be evaluated on the client
checks []*mgmProto.Checks
relayManager *relayClient.Manager
stateManager *statemanager.Manager
portForwardManager *portforward.Manager
srWatcher *guard.SRWatcher
relayManager *relayClient.Manager
stateManager *statemanager.Manager
srWatcher *guard.SRWatcher
// Sync response persistence (protected by syncRespMux)
syncRespMux sync.RWMutex
@@ -231,6 +231,9 @@ type Engine struct {
probeStunTurn *relay.StunTurnProbe
// clientMetrics collects and pushes metrics
clientMetrics *metrics.ClientMetrics
jobExecutor *jobexec.Executor
jobExecutorWG sync.WaitGroup
@@ -256,25 +259,26 @@ func NewEngine(
mobileDep MobileDependency,
) *Engine {
engine := &Engine{
clientCtx: clientCtx,
clientCancel: clientCancel,
signal: services.SignalClient,
signaler: peer.NewSignaler(services.SignalClient, config.WgPrivateKey),
mgmClient: services.MgmClient,
relayManager: services.RelayManager,
peerStore: peerstore.NewConnStore(),
syncMsgMux: &sync.Mutex{},
config: config,
mobileDep: mobileDep,
STUNs: []*stun.URI{},
TURNs: []*stun.URI{},
networkSerial: 0,
statusRecorder: services.StatusRecorder,
stateManager: services.StateManager,
portForwardManager: portforward.NewManager(),
checks: services.Checks,
probeStunTurn: relay.NewStunTurnProbe(relay.DefaultCacheTTL),
jobExecutor: jobexec.NewExecutor(), updateManager: services.UpdateManager,
clientCtx: clientCtx,
clientCancel: clientCancel,
signal: services.SignalClient,
signaler: peer.NewSignaler(services.SignalClient, config.WgPrivateKey),
mgmClient: services.MgmClient,
relayManager: services.RelayManager,
peerStore: peerstore.NewConnStore(),
syncMsgMux: &sync.Mutex{},
config: config,
mobileDep: mobileDep,
STUNs: []*stun.URI{},
TURNs: []*stun.URI{},
networkSerial: 0,
statusRecorder: services.StatusRecorder,
stateManager: services.StateManager,
checks: services.Checks,
probeStunTurn: relay.NewStunTurnProbe(relay.DefaultCacheTTL),
jobExecutor: jobexec.NewExecutor(),
clientMetrics: services.ClientMetrics,
updateManager: services.UpdateManager,
}
log.Infof("I am: %s", config.WgPrivateKey.PublicKey().String())
@@ -517,13 +521,6 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
// conntrack entries from being created before the rules are in place
e.setupWGProxyNoTrack()
// Start after interface is up since port may have been resolved from 0 or changed if occupied
e.shutdownWg.Add(1)
go func() {
defer e.shutdownWg.Done()
e.portForwardManager.Start(e.ctx, uint16(e.config.WgPort))
}()
// Set the WireGuard interface for rosenpass after interface is up
if e.rpManager != nil {
e.rpManager.SetInterface(e.wgInterface)
@@ -822,7 +819,9 @@ func (e *Engine) handleAutoUpdateVersion(autoUpdateSettings *mgmProto.AutoUpdate
func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
started := time.Now()
defer func() {
log.Infof("sync finished in %s", time.Since(started))
duration := time.Since(started)
log.Infof("sync finished in %s", duration)
e.clientMetrics.RecordSyncDuration(e.ctx, duration)
}()
e.syncMsgMux.Lock()
defer e.syncMsgMux.Unlock()
@@ -998,10 +997,11 @@ func (e *Engine) updateConfig(conf *mgmProto.PeerConfig) error {
return errors.New("wireguard interface is not initialized")
}
// Cannot update the IP address without restarting the engine because
// the firewall, route manager, and other components cache the old address
if e.wgInterface.Address().String() != conf.Address {
log.Infof("peer IP address has changed from %s to %s", e.wgInterface.Address().String(), conf.Address)
log.Infof("peer IP address changed from %s to %s, restarting client", e.wgInterface.Address().String(), conf.Address)
_ = CtxGetState(e.ctx).Wrap(ErrResetConnection)
e.clientCancel()
return ErrResetConnection
}
if conf.GetSshConfig() != nil {
@@ -1069,6 +1069,7 @@ func (e *Engine) handleBundle(params *mgmProto.BundleParameters) (*mgmProto.JobR
StatusRecorder: e.statusRecorder,
SyncResponse: syncResponse,
LogPath: e.config.LogPath,
ClientMetrics: e.clientMetrics,
RefreshStatus: func() {
e.RunHealthProbes(true)
},
@@ -1523,12 +1524,12 @@ func (e *Engine) createPeerConn(pubKey string, allowedIPs []netip.Prefix, agentV
}
serviceDependencies := peer.ServiceDependencies{
StatusRecorder: e.statusRecorder,
Signaler: e.signaler,
IFaceDiscover: e.mobileDep.IFaceDiscover,
RelayManager: e.relayManager,
SrWatcher: e.srWatcher,
PortForwardManager: e.portForwardManager,
StatusRecorder: e.statusRecorder,
Signaler: e.signaler,
IFaceDiscover: e.mobileDep.IFaceDiscover,
RelayManager: e.relayManager,
SrWatcher: e.srWatcher,
MetricsRecorder: e.clientMetrics,
}
peerConn, err := peer.NewConn(config, serviceDependencies)
if err != nil {
@@ -1685,12 +1686,6 @@ func (e *Engine) close() {
if e.rpManager != nil {
_ = e.rpManager.Close()
}
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
if err := e.portForwardManager.GracefullyStop(ctx); err != nil {
log.Warnf("failed to gracefully stop port forwarding manager: %s", err)
}
}
func (e *Engine) readInitialSettings() ([]*route.Route, *nbdns.Config, bool, error) {
@@ -1831,6 +1826,11 @@ func (e *Engine) GetExposeManager() *expose.Manager {
return e.exposeManager
}
// GetClientMetrics returns the client metrics
func (e *Engine) GetClientMetrics() *metrics.ClientMetrics {
return e.clientMetrics
}
func findIPFromInterfaceName(ifaceName string) (net.IP, error) {
iface, err := net.InterfaceByName(ifaceName)
if err != nil {

View File

@@ -828,7 +828,7 @@ func TestEngine_UpdateNetworkMapWithRoutes(t *testing.T) {
WgPrivateKey: key,
WgPort: 33100,
MTU: iface.DefaultMTU,
}, EngineServices{
}, EngineServices{
SignalClient: &signal.MockClient{},
MgmClient: &mgmt.MockClient{},
RelayManager: relayMgr,
@@ -1035,7 +1035,7 @@ func TestEngine_UpdateNetworkMapWithDNSUpdate(t *testing.T) {
WgPrivateKey: key,
WgPort: 33100,
MTU: iface.DefaultMTU,
}, EngineServices{
}, EngineServices{
SignalClient: &signal.MockClient{},
MgmClient: &mgmt.MockClient{},
RelayManager: relayMgr,
@@ -1566,7 +1566,7 @@ func createEngine(ctx context.Context, cancel context.CancelFunc, setupKey strin
}
relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), iface.DefaultMTU)
e, err := NewEngine(ctx, cancel, conf, EngineServices{
e, err := NewEngine(ctx, cancel, conf, EngineServices{
SignalClient: signalClient,
MgmClient: mgmtClient,
RelayManager: relayMgr,

View File

@@ -0,0 +1,17 @@
package metrics
// ConnectionType represents the type of peer connection
type ConnectionType string
const (
// ConnectionTypeICE represents a direct peer-to-peer connection using ICE
ConnectionTypeICE ConnectionType = "ice"
// ConnectionTypeRelay represents a relayed connection
ConnectionTypeRelay ConnectionType = "relay"
)
// String returns the string representation of the connection type
func (c ConnectionType) String() string {
return string(c)
}

View File

@@ -0,0 +1,51 @@
package metrics
import (
"net/url"
"strings"
)
// DeploymentType represents the type of NetBird deployment
type DeploymentType int
const (
// DeploymentTypeUnknown represents an unknown or uninitialized deployment type
DeploymentTypeUnknown DeploymentType = iota
// DeploymentTypeCloud represents a cloud-hosted NetBird deployment
DeploymentTypeCloud
// DeploymentTypeSelfHosted represents a self-hosted NetBird deployment
DeploymentTypeSelfHosted
)
// String returns the string representation of the deployment type
func (d DeploymentType) String() string {
switch d {
case DeploymentTypeCloud:
return "cloud"
case DeploymentTypeSelfHosted:
return "selfhosted"
default:
return "unknown"
}
}
// DetermineDeploymentType determines if the deployment is cloud or self-hosted
// based on the management URL string
func DetermineDeploymentType(managementURL string) DeploymentType {
if managementURL == "" {
return DeploymentTypeUnknown
}
u, err := url.Parse(managementURL)
if err != nil {
return DeploymentTypeSelfHosted
}
if strings.ToLower(u.Hostname()) == "api.netbird.io" {
return DeploymentTypeCloud
}
return DeploymentTypeSelfHosted
}

View File

@@ -0,0 +1,93 @@
package metrics
import (
"net/url"
"os"
"strconv"
"time"
log "github.com/sirupsen/logrus"
)
const (
// EnvMetricsPushEnabled controls whether collected metrics are pushed to the backend.
// Metrics collection itself is always active (for debug bundles).
// Disabled by default. Set NB_METRICS_PUSH_ENABLED=true to enable push.
EnvMetricsPushEnabled = "NB_METRICS_PUSH_ENABLED"
// EnvMetricsForceSending if set to true, skips remote configuration fetch and forces metric sending
EnvMetricsForceSending = "NB_METRICS_FORCE_SENDING"
// EnvMetricsConfigURL is the environment variable to override the metrics push config ServerAddress
EnvMetricsConfigURL = "NB_METRICS_CONFIG_URL"
// EnvMetricsServerURL is the environment variable to override the metrics server address.
// When set, this takes precedence over the server_url from remote push config.
EnvMetricsServerURL = "NB_METRICS_SERVER_URL"
// EnvMetricsInterval overrides the push interval from the remote config.
// Only affects how often metrics are pushed; remote config availability
// and version range checks are still respected.
// Format: duration string like "1h", "30m", "4h"
EnvMetricsInterval = "NB_METRICS_INTERVAL"
defaultMetricsConfigURL = "https://ingest.netbird.io/config"
)
// IsMetricsPushEnabled returns true if metrics push is enabled via NB_METRICS_PUSH_ENABLED env var.
// Disabled by default. Metrics collection is always active for debug bundles.
func IsMetricsPushEnabled() bool {
enabled, _ := strconv.ParseBool(os.Getenv(EnvMetricsPushEnabled))
return enabled
}
// getMetricsInterval returns the metrics push interval from NB_METRICS_INTERVAL env var.
// Returns 0 if not set or invalid.
func getMetricsInterval() time.Duration {
intervalStr := os.Getenv(EnvMetricsInterval)
if intervalStr == "" {
return 0
}
interval, err := time.ParseDuration(intervalStr)
if err != nil {
log.Warnf("invalid metrics interval from env %q: %v", intervalStr, err)
return 0
}
if interval <= 0 {
log.Warnf("invalid metrics interval from env %q: must be positive", intervalStr)
return 0
}
return interval
}
func isForceSending() bool {
force, _ := strconv.ParseBool(os.Getenv(EnvMetricsForceSending))
return force
}
// getMetricsConfigURL returns the URL to fetch push configuration from
func getMetricsConfigURL() string {
if envURL := os.Getenv(EnvMetricsConfigURL); envURL != "" {
return envURL
}
return defaultMetricsConfigURL
}
// getMetricsServerURL returns the metrics server URL from NB_METRICS_SERVER_URL env var.
// Returns nil if not set or invalid.
func getMetricsServerURL() *url.URL {
envURL := os.Getenv(EnvMetricsServerURL)
if envURL == "" {
return nil
}
parsed, err := url.ParseRequestURI(envURL)
if err != nil || parsed.Host == "" {
log.Warnf("invalid metrics server URL %q: must be an absolute HTTP(S) URL", envURL)
return nil
}
if parsed.Scheme != "http" && parsed.Scheme != "https" {
log.Warnf("invalid metrics server URL %q: unsupported scheme %q", envURL, parsed.Scheme)
return nil
}
return parsed
}

View File

@@ -0,0 +1,219 @@
package metrics
import (
"context"
"fmt"
"io"
"maps"
"slices"
"sync"
"time"
log "github.com/sirupsen/logrus"
)
const (
maxSampleAge = 5 * 24 * time.Hour // drop samples older than 5 days
maxBufferSize = 5 * 1024 * 1024 // drop oldest samples when estimated size exceeds 5 MB
// estimatedSampleSize is a rough per-sample memory estimate (measurement + tags + fields + timestamp)
estimatedSampleSize = 256
)
// influxSample is a single InfluxDB line protocol entry.
type influxSample struct {
measurement string
tags string
fields map[string]float64
timestamp time.Time
}
// influxDBMetrics collects metric events as timestamped samples.
// Each event is recorded with its exact timestamp, pushed once, then cleared.
type influxDBMetrics struct {
mu sync.Mutex
samples []influxSample
}
func newInfluxDBMetrics() metricsImplementation {
return &influxDBMetrics{}
}
func (m *influxDBMetrics) RecordConnectionStages(
_ context.Context,
agentInfo AgentInfo,
connectionPairID string,
connectionType ConnectionType,
isReconnection bool,
timestamps ConnectionStageTimestamps,
) {
var signalingReceivedToConnection, connectionToWgHandshake, totalDuration float64
if !timestamps.SignalingReceived.IsZero() && !timestamps.ConnectionReady.IsZero() {
signalingReceivedToConnection = timestamps.ConnectionReady.Sub(timestamps.SignalingReceived).Seconds()
}
if !timestamps.ConnectionReady.IsZero() && !timestamps.WgHandshakeSuccess.IsZero() {
connectionToWgHandshake = timestamps.WgHandshakeSuccess.Sub(timestamps.ConnectionReady).Seconds()
}
if !timestamps.SignalingReceived.IsZero() && !timestamps.WgHandshakeSuccess.IsZero() {
totalDuration = timestamps.WgHandshakeSuccess.Sub(timestamps.SignalingReceived).Seconds()
}
attemptType := "initial"
if isReconnection {
attemptType = "reconnection"
}
connTypeStr := connectionType.String()
tags := fmt.Sprintf("deployment_type=%s,connection_type=%s,attempt_type=%s,version=%s,os=%s,arch=%s,peer_id=%s,connection_pair_id=%s",
agentInfo.DeploymentType.String(),
connTypeStr,
attemptType,
agentInfo.Version,
agentInfo.OS,
agentInfo.Arch,
agentInfo.peerID,
connectionPairID,
)
now := time.Now()
m.mu.Lock()
defer m.mu.Unlock()
m.samples = append(m.samples, influxSample{
measurement: "netbird_peer_connection",
tags: tags,
fields: map[string]float64{
"signaling_to_connection_seconds": signalingReceivedToConnection,
"connection_to_wg_handshake_seconds": connectionToWgHandshake,
"total_seconds": totalDuration,
},
timestamp: now,
})
m.trimLocked()
log.Tracef("peer connection metrics [%s, %s, %s]: signalingReceived→connection: %.3fs, connection→wg_handshake: %.3fs, total: %.3fs",
agentInfo.DeploymentType.String(), connTypeStr, attemptType, signalingReceivedToConnection, connectionToWgHandshake, totalDuration)
}
func (m *influxDBMetrics) RecordSyncDuration(_ context.Context, agentInfo AgentInfo, duration time.Duration) {
tags := fmt.Sprintf("deployment_type=%s,version=%s,os=%s,arch=%s,peer_id=%s",
agentInfo.DeploymentType.String(),
agentInfo.Version,
agentInfo.OS,
agentInfo.Arch,
agentInfo.peerID,
)
m.mu.Lock()
defer m.mu.Unlock()
m.samples = append(m.samples, influxSample{
measurement: "netbird_sync",
tags: tags,
fields: map[string]float64{
"duration_seconds": duration.Seconds(),
},
timestamp: time.Now(),
})
m.trimLocked()
}
func (m *influxDBMetrics) RecordLoginDuration(_ context.Context, agentInfo AgentInfo, duration time.Duration, success bool) {
result := "success"
if !success {
result = "failure"
}
tags := fmt.Sprintf("deployment_type=%s,result=%s,version=%s,os=%s,arch=%s,peer_id=%s",
agentInfo.DeploymentType.String(),
result,
agentInfo.Version,
agentInfo.OS,
agentInfo.Arch,
agentInfo.peerID,
)
m.mu.Lock()
defer m.mu.Unlock()
m.samples = append(m.samples, influxSample{
measurement: "netbird_login",
tags: tags,
fields: map[string]float64{
"duration_seconds": duration.Seconds(),
},
timestamp: time.Now(),
})
m.trimLocked()
log.Tracef("login metrics [%s, %s]: duration=%.3fs", agentInfo.DeploymentType.String(), result, duration.Seconds())
}
// Export writes pending samples in InfluxDB line protocol format.
// Format: measurement,tag=val,tag=val field=val,field=val timestamp_ns
func (m *influxDBMetrics) Export(w io.Writer) error {
m.mu.Lock()
samples := make([]influxSample, len(m.samples))
copy(samples, m.samples)
m.mu.Unlock()
for _, s := range samples {
if _, err := fmt.Fprintf(w, "%s,%s ", s.measurement, s.tags); err != nil {
return err
}
sortedKeys := slices.Sorted(maps.Keys(s.fields))
first := true
for _, k := range sortedKeys {
if !first {
if _, err := fmt.Fprint(w, ","); err != nil {
return err
}
}
if _, err := fmt.Fprintf(w, "%s=%g", k, s.fields[k]); err != nil {
return err
}
first = false
}
if _, err := fmt.Fprintf(w, " %d\n", s.timestamp.UnixNano()); err != nil {
return err
}
}
return nil
}
// Reset clears pending samples after a successful push
func (m *influxDBMetrics) Reset() {
m.mu.Lock()
defer m.mu.Unlock()
m.samples = m.samples[:0]
}
// trimLocked removes samples that exceed age or size limits.
// Must be called with m.mu held.
func (m *influxDBMetrics) trimLocked() {
now := time.Now()
// drop samples older than maxSampleAge
cutoff := 0
for cutoff < len(m.samples) && now.Sub(m.samples[cutoff].timestamp) > maxSampleAge {
cutoff++
}
if cutoff > 0 {
copy(m.samples, m.samples[cutoff:])
m.samples = m.samples[:len(m.samples)-cutoff]
log.Debugf("influxdb metrics: dropped %d samples older than %s", cutoff, maxSampleAge)
}
// drop oldest samples if estimated size exceeds maxBufferSize
maxSamples := maxBufferSize / estimatedSampleSize
if len(m.samples) > maxSamples {
drop := len(m.samples) - maxSamples
copy(m.samples, m.samples[drop:])
m.samples = m.samples[:maxSamples]
log.Debugf("influxdb metrics: dropped %d oldest samples to stay under %d MB size limit", drop, maxBufferSize/(1024*1024))
}
}

View File

@@ -0,0 +1,229 @@
package metrics
import (
"bytes"
"context"
"strings"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestInfluxDBMetrics_RecordAndExport(t *testing.T) {
m := newInfluxDBMetrics().(*influxDBMetrics)
agentInfo := AgentInfo{
DeploymentType: DeploymentTypeCloud,
Version: "1.0.0",
OS: "linux",
Arch: "amd64",
peerID: "abc123",
}
ts := ConnectionStageTimestamps{
SignalingReceived: time.Now().Add(-3 * time.Second),
ConnectionReady: time.Now().Add(-2 * time.Second),
WgHandshakeSuccess: time.Now().Add(-1 * time.Second),
}
m.RecordConnectionStages(context.Background(), agentInfo, "pair123", ConnectionTypeICE, false, ts)
var buf bytes.Buffer
err := m.Export(&buf)
require.NoError(t, err)
output := buf.String()
assert.Contains(t, output, "netbird_peer_connection,")
assert.Contains(t, output, "connection_to_wg_handshake_seconds=")
assert.Contains(t, output, "signaling_to_connection_seconds=")
assert.Contains(t, output, "total_seconds=")
}
func TestInfluxDBMetrics_ExportDeterministicFieldOrder(t *testing.T) {
m := newInfluxDBMetrics().(*influxDBMetrics)
agentInfo := AgentInfo{
DeploymentType: DeploymentTypeCloud,
Version: "1.0.0",
OS: "linux",
Arch: "amd64",
peerID: "abc123",
}
ts := ConnectionStageTimestamps{
SignalingReceived: time.Now().Add(-3 * time.Second),
ConnectionReady: time.Now().Add(-2 * time.Second),
WgHandshakeSuccess: time.Now().Add(-1 * time.Second),
}
// Record multiple times and verify consistent field order
for i := 0; i < 10; i++ {
m.RecordConnectionStages(context.Background(), agentInfo, "pair123", ConnectionTypeICE, false, ts)
}
var buf bytes.Buffer
err := m.Export(&buf)
require.NoError(t, err)
lines := strings.Split(strings.TrimSpace(buf.String()), "\n")
require.Len(t, lines, 10)
// Extract field portion from each line and verify they're all identical
var fieldSections []string
for _, line := range lines {
parts := strings.SplitN(line, " ", 3)
require.Len(t, parts, 3, "each line should have measurement, fields, timestamp")
fieldSections = append(fieldSections, parts[1])
}
for i := 1; i < len(fieldSections); i++ {
assert.Equal(t, fieldSections[0], fieldSections[i], "field order should be deterministic across samples")
}
// Fields should be alphabetically sorted
assert.True(t, strings.HasPrefix(fieldSections[0], "connection_to_wg_handshake_seconds="),
"fields should be sorted: connection_to_wg < signaling_to < total")
}
func TestInfluxDBMetrics_RecordSyncDuration(t *testing.T) {
m := newInfluxDBMetrics().(*influxDBMetrics)
agentInfo := AgentInfo{
DeploymentType: DeploymentTypeSelfHosted,
Version: "2.0.0",
OS: "darwin",
Arch: "arm64",
peerID: "def456",
}
m.RecordSyncDuration(context.Background(), agentInfo, 1500*time.Millisecond)
var buf bytes.Buffer
err := m.Export(&buf)
require.NoError(t, err)
output := buf.String()
assert.Contains(t, output, "netbird_sync,")
assert.Contains(t, output, "duration_seconds=1.5")
assert.Contains(t, output, "deployment_type=selfhosted")
}
func TestInfluxDBMetrics_Reset(t *testing.T) {
m := newInfluxDBMetrics().(*influxDBMetrics)
agentInfo := AgentInfo{
DeploymentType: DeploymentTypeCloud,
Version: "1.0.0",
OS: "linux",
Arch: "amd64",
peerID: "abc123",
}
m.RecordSyncDuration(context.Background(), agentInfo, time.Second)
var buf bytes.Buffer
err := m.Export(&buf)
require.NoError(t, err)
assert.NotEmpty(t, buf.String())
m.Reset()
buf.Reset()
err = m.Export(&buf)
require.NoError(t, err)
assert.Empty(t, buf.String(), "should be empty after reset")
}
func TestInfluxDBMetrics_ExportEmpty(t *testing.T) {
m := newInfluxDBMetrics().(*influxDBMetrics)
var buf bytes.Buffer
err := m.Export(&buf)
require.NoError(t, err)
assert.Empty(t, buf.String())
}
func TestInfluxDBMetrics_TrimByAge(t *testing.T) {
m := newInfluxDBMetrics().(*influxDBMetrics)
m.mu.Lock()
m.samples = append(m.samples, influxSample{
measurement: "old",
tags: "t=1",
fields: map[string]float64{"v": 1},
timestamp: time.Now().Add(-maxSampleAge - time.Hour),
})
m.trimLocked()
remaining := len(m.samples)
m.mu.Unlock()
assert.Equal(t, 0, remaining, "old samples should be trimmed")
}
func TestInfluxDBMetrics_RecordLoginDuration(t *testing.T) {
m := newInfluxDBMetrics().(*influxDBMetrics)
agentInfo := AgentInfo{
DeploymentType: DeploymentTypeCloud,
Version: "1.0.0",
OS: "linux",
Arch: "amd64",
peerID: "abc123",
}
m.RecordLoginDuration(context.Background(), agentInfo, 2500*time.Millisecond, true)
var buf bytes.Buffer
err := m.Export(&buf)
require.NoError(t, err)
output := buf.String()
assert.Contains(t, output, "netbird_login,")
assert.Contains(t, output, "duration_seconds=2.5")
assert.Contains(t, output, "result=success")
}
func TestInfluxDBMetrics_RecordLoginDurationFailure(t *testing.T) {
m := newInfluxDBMetrics().(*influxDBMetrics)
agentInfo := AgentInfo{
DeploymentType: DeploymentTypeSelfHosted,
Version: "1.0.0",
OS: "darwin",
Arch: "arm64",
peerID: "xyz789",
}
m.RecordLoginDuration(context.Background(), agentInfo, 5*time.Second, false)
var buf bytes.Buffer
err := m.Export(&buf)
require.NoError(t, err)
output := buf.String()
assert.Contains(t, output, "netbird_login,")
assert.Contains(t, output, "result=failure")
assert.Contains(t, output, "deployment_type=selfhosted")
}
func TestInfluxDBMetrics_TrimBySize(t *testing.T) {
m := newInfluxDBMetrics().(*influxDBMetrics)
maxSamples := maxBufferSize / estimatedSampleSize
m.mu.Lock()
for i := 0; i < maxSamples+100; i++ {
m.samples = append(m.samples, influxSample{
measurement: "test",
tags: "t=1",
fields: map[string]float64{"v": float64(i)},
timestamp: time.Now(),
})
}
m.trimLocked()
remaining := len(m.samples)
m.mu.Unlock()
assert.Equal(t, maxSamples, remaining, "should trim to max samples")
}

View File

@@ -0,0 +1,16 @@
# Copy to .env and adjust values before running docker compose
# InfluxDB admin (server-side only, never exposed to clients)
INFLUXDB_ADMIN_PASSWORD=changeme
INFLUXDB_ADMIN_TOKEN=changeme
# Grafana admin credentials
GRAFANA_ADMIN_USER=admin
GRAFANA_ADMIN_PASSWORD=changeme
# Remote config served by ingest at /config
# Set CONFIG_METRICS_SERVER_URL to the ingest server's public address to enable
CONFIG_METRICS_SERVER_URL=
CONFIG_VERSION_SINCE=0.0.0
CONFIG_VERSION_UNTIL=99.99.99
CONFIG_PERIOD_MINUTES=5

View File

@@ -0,0 +1 @@
.env

View File

@@ -0,0 +1,194 @@
# Client Metrics
Internal documentation for the NetBird client metrics system.
## Overview
Client metrics track connection performance and sync durations using InfluxDB line protocol (`influxdb.go`). Each event is pushed once then cleared.
Metrics collection is always active (for debug bundles). Push to backend is:
- Disabled by default (opt-in via `NB_METRICS_PUSH_ENABLED=true`)
- Managed at daemon layer (survives engine restarts)
## Architecture
### Layer Separation
```text
Daemon Layer (connect.go)
├─ Creates ClientMetrics instance once
├─ Starts/stops push lifecycle
└─ Updates AgentInfo on profile switch
Engine Layer (engine.go)
└─ Records metrics via ClientMetrics methods
```
### Ingest Server
Clients do not talk to InfluxDB directly. An ingest server sits between clients and InfluxDB:
```text
Client ──POST──▶ Ingest Server (:8087) ──▶ InfluxDB (internal)
├─ Validates line protocol
├─ Allowlists measurements, fields, and tags
├─ Rejects out-of-bound values
└─ Serves remote config at /config
```
- **No secret/token-based client auth** — the ingest server holds the InfluxDB token server-side. Clients must send a hashed peer ID via `X-Peer-ID` header.
- **InfluxDB is not exposed** — only accessible within the docker network
- Source: `ingest/main.go`
## Metrics Collected
### Connection Stage Timing
Measurement: `netbird_peer_connection`
| Field | Timestamps | Description |
|-------|-----------|-------------|
| `signaling_to_connection_seconds` | `SignalingReceived → ConnectionReady` | ICE/relay negotiation time after the first signal is received from the remote peer |
| `connection_to_wg_handshake_seconds` | `ConnectionReady → WgHandshakeSuccess` | WireGuard cryptographic handshake latency once the transport layer is ready |
| `total_seconds` | `SignalingReceived → WgHandshakeSuccess` | End-to-end connection time anchored at the first received signal |
Tags:
- `deployment_type`: "cloud" | "selfhosted" | "unknown"
- `connection_type`: "ice" | "relay"
- `attempt_type`: "initial" | "reconnection"
- `version`: NetBird version string
- `os`: Operating system (linux, darwin, windows, android, ios, etc.)
- `arch`: CPU architecture (amd64, arm64, etc.)
**Note:** `SignalingReceived` is set when the first offer or answer arrives from the remote peer (in both initial and reconnection paths). It excludes the potentially unbounded wait for the remote peer to come online.
### Sync Duration
Measurement: `netbird_sync`
| Field | Description |
|-------|-------------|
| `duration_seconds` | Time to process a sync message from management server |
Tags:
- `deployment_type`: "cloud" | "selfhosted" | "unknown"
- `version`: NetBird version string
- `os`: Operating system (linux, darwin, windows, android, ios, etc.)
- `arch`: CPU architecture (amd64, arm64, etc.)
### Login Duration
Measurement: `netbird_login`
| Field | Description |
|-------|-------------|
| `duration_seconds` | Time to complete the login/auth exchange with management server |
Tags:
- `deployment_type`: "cloud" | "selfhosted" | "unknown"
- `result`: "success" | "failure"
- `version`: NetBird version string
- `os`: Operating system (linux, darwin, windows, android, ios, etc.)
- `arch`: CPU architecture (amd64, arm64, etc.)
## Buffer Limits
The InfluxDB backend limits in-memory sample storage to prevent unbounded growth when pushes fail:
- **Max age:** Samples older than 5 days are dropped
- **Max size:** Estimated buffer size capped at 5 MB (~20k samples)
## Configuration
### Client Environment Variables
| Variable | Default | Description |
|----------|---------|-------------|
| `NB_METRICS_PUSH_ENABLED` | `false` | Enable metrics push to backend |
| `NB_METRICS_SERVER_URL` | *(from remote config)* | Ingest server URL (e.g., `https://ingest.netbird.io`) |
| `NB_METRICS_INTERVAL` | *(from remote config)* | Push interval (e.g., "1m", "30m", "4h") |
| `NB_METRICS_FORCE_SENDING` | `false` | Skip remote config, push unconditionally |
| `NB_METRICS_CONFIG_URL` | `https://ingest.netbird.io/config` | Remote push config URL |
`NB_METRICS_SERVER_URL` and `NB_METRICS_INTERVAL` override their respective values but do not bypass remote config eligibility checks (version range). Use `NB_METRICS_FORCE_SENDING=true` to skip all remote config gating.
### Ingest Server Environment Variables
| Variable | Default | Description |
|----------|---------|-------------|
| `INGEST_LISTEN_ADDR` | `:8087` | Listen address |
| `INFLUXDB_URL` | `http://influxdb:8086/api/v2/write?org=netbird&bucket=metrics&precision=ns` | InfluxDB write endpoint |
| `INFLUXDB_TOKEN` | *(required)* | InfluxDB auth token (server-side only) |
| `CONFIG_METRICS_SERVER_URL` | *(empty — disables /config)* | `server_url` in the remote config JSON (the URL clients push metrics to) |
| `CONFIG_VERSION_SINCE` | `0.0.0` | Minimum client version to push metrics |
| `CONFIG_VERSION_UNTIL` | `99.99.99` | Maximum client version to push metrics |
| `CONFIG_PERIOD_MINUTES` | `5` | Push interval in minutes |
The ingest server serves a remote config JSON at `GET /config` when `CONFIG_METRICS_SERVER_URL` is set. Clients can use `NB_METRICS_CONFIG_URL=http://<ingest>/config` to fetch it.
### Configuration Precedence
For URL and Interval, the precedence is:
1. **Environment variable** - `NB_METRICS_SERVER_URL` / `NB_METRICS_INTERVAL`
2. **Remote config** - fetched from `NB_METRICS_CONFIG_URL`
3. **Default** - 5 minute interval, URL from remote config
## Push Behavior
1. `StartPush()` spawns background goroutine with timer
2. First push happens immediately on startup
3. Periodically: `push()``Export()` → HTTP POST to ingest server
4. On failure: log error, continue (non-blocking)
5. On success: `Reset()` clears pushed samples
6. `StopPush()` cancels context and waits for goroutine
Samples are collected with exact timestamps, pushed once, then cleared. No data is resent.
## Local Development Setup
### 1. Configure and Start Services
```bash
# From this directory (client/internal/metrics/infra)
cp .env.example .env
# Edit .env to set INFLUXDB_ADMIN_PASSWORD, INFLUXDB_ADMIN_TOKEN, and GRAFANA_ADMIN_PASSWORD
docker compose up -d
```
This starts:
- **Ingest server** on http://localhost:8087 — accepts client metrics (requires `X-Peer-ID` header, no secret/token auth)
- **InfluxDB** — internal only, not exposed to host
- **Grafana** on http://localhost:3001
### 2. Configure Client
```bash
export NB_METRICS_PUSH_ENABLED=true
export NB_METRICS_FORCE_SENDING=true
export NB_METRICS_SERVER_URL=http://localhost:8087
export NB_METRICS_INTERVAL=1m
```
### 3. Run Client
```bash
cd ../../../..
go run ./client/ up
```
### 4. View in Grafana
- **InfluxDB dashboard:** http://localhost:3001/d/netbird-influxdb-metrics
### 5. Verify Data
```bash
# Query via InfluxDB (using admin token from .env)
docker compose exec influxdb influx query \
'from(bucket: "metrics") |> range(start: -1h)' \
--org netbird
# Check ingest server health
curl http://localhost:8087/health
```

View File

@@ -0,0 +1,69 @@
version: '3.8'
services:
ingest:
container_name: ingest
build:
context: ./ingest
ports:
- "8087:8087"
environment:
- INGEST_LISTEN_ADDR=:8087
- INFLUXDB_URL=http://influxdb:8086/api/v2/write?org=netbird&bucket=metrics&precision=ns
- INFLUXDB_TOKEN=${INFLUXDB_ADMIN_TOKEN:?required}
- CONFIG_METRICS_SERVER_URL=${CONFIG_METRICS_SERVER_URL:-}
- CONFIG_VERSION_SINCE=${CONFIG_VERSION_SINCE:-0.0.0}
- CONFIG_VERSION_UNTIL=${CONFIG_VERSION_UNTIL:-99.99.99}
- CONFIG_PERIOD_MINUTES=${CONFIG_PERIOD_MINUTES:-5}
depends_on:
- influxdb
restart: unless-stopped
networks:
- metrics
influxdb:
container_name: influxdb
image: influxdb:2
# No ports exposed — only accessible within the metrics network
volumes:
- influxdb-data:/var/lib/influxdb2
- ./influxdb/scripts:/docker-entrypoint-initdb.d
environment:
- DOCKER_INFLUXDB_INIT_MODE=setup
- DOCKER_INFLUXDB_INIT_USERNAME=admin
- DOCKER_INFLUXDB_INIT_PASSWORD=${INFLUXDB_ADMIN_PASSWORD:?required}
- DOCKER_INFLUXDB_INIT_ORG=netbird
- DOCKER_INFLUXDB_INIT_BUCKET=metrics
- DOCKER_INFLUXDB_INIT_RETENTION=365d
- DOCKER_INFLUXDB_INIT_ADMIN_TOKEN=${INFLUXDB_ADMIN_TOKEN:-}
restart: unless-stopped
networks:
- metrics
grafana:
container_name: grafana
image: grafana/grafana:11.6.0
ports:
- "3001:3000"
environment:
- GF_SECURITY_ADMIN_USER=${GRAFANA_ADMIN_USER:-admin}
- GF_SECURITY_ADMIN_PASSWORD=${GRAFANA_ADMIN_PASSWORD:?required}
- GF_USERS_ALLOW_SIGN_UP=false
- GF_INSTALL_PLUGINS=
- INFLUXDB_ADMIN_TOKEN=${INFLUXDB_ADMIN_TOKEN:-}
volumes:
- grafana-data:/var/lib/grafana
- ./grafana/provisioning:/etc/grafana/provisioning
depends_on:
- influxdb
restart: unless-stopped
networks:
- metrics
volumes:
influxdb-data:
grafana-data:
networks:
metrics:
driver: bridge

View File

@@ -0,0 +1,12 @@
apiVersion: 1
providers:
- name: 'NetBird Dashboards'
orgId: 1
folder: ''
type: file
disableDeletion: false
updateIntervalSeconds: 10
allowUiUpdates: true
options:
path: /etc/grafana/provisioning/dashboards/json

View File

@@ -0,0 +1,280 @@
{
"uid": "netbird-influxdb-metrics",
"title": "NetBird Client Metrics (InfluxDB)",
"tags": ["netbird", "connections", "influxdb"],
"timezone": "browser",
"panels": [
{
"id": 5,
"title": "Sync Duration Extremes",
"type": "stat",
"datasource": {
"type": "influxdb",
"uid": "influxdb"
},
"gridPos": {
"h": 8,
"w": 12,
"x": 0,
"y": 0
},
"targets": [
{
"query": "from(bucket: \"metrics\")\n |> range(start: v.timeRangeStart, stop: v.timeRangeStop)\n |> filter(fn: (r) => r._measurement == \"netbird_sync\" and r._field == \"duration_seconds\")\n |> map(fn: (r) => ({r with _value: r._value * 1000.0}))\n |> drop(columns: [\"deployment_type\", \"version\", \"os\", \"arch\", \"peer_id\"])\n |> min()\n |> set(key: \"_field\", value: \"Min\")",
"refId": "A"
},
{
"query": "from(bucket: \"metrics\")\n |> range(start: v.timeRangeStart, stop: v.timeRangeStop)\n |> filter(fn: (r) => r._measurement == \"netbird_sync\" and r._field == \"duration_seconds\")\n |> map(fn: (r) => ({r with _value: r._value * 1000.0}))\n |> drop(columns: [\"deployment_type\", \"version\", \"os\", \"arch\", \"peer_id\"])\n |> max()\n |> set(key: \"_field\", value: \"Max\")",
"refId": "B"
}
],
"fieldConfig": {
"defaults": {
"unit": "ms",
"min": 0
}
},
"options": {
"reduceOptions": {
"calcs": ["lastNotNull"]
},
"colorMode": "value",
"graphMode": "none",
"textMode": "auto"
}
},
{
"id": 6,
"title": "Total Connection Time Extremes",
"type": "stat",
"datasource": {
"type": "influxdb",
"uid": "influxdb"
},
"gridPos": {
"h": 8,
"w": 12,
"x": 12,
"y": 0
},
"targets": [
{
"query": "from(bucket: \"metrics\")\n |> range(start: v.timeRangeStart, stop: v.timeRangeStop)\n |> filter(fn: (r) => r._measurement == \"netbird_peer_connection\" and r._field == \"total_seconds\")\n |> map(fn: (r) => ({r with _value: r._value * 1000.0}))\n |> drop(columns: [\"deployment_type\", \"connection_type\", \"attempt_type\", \"version\", \"os\", \"arch\", \"peer_id\", \"connection_pair_id\"])\n |> min()\n |> set(key: \"_field\", value: \"Min\")",
"refId": "A"
},
{
"query": "from(bucket: \"metrics\")\n |> range(start: v.timeRangeStart, stop: v.timeRangeStop)\n |> filter(fn: (r) => r._measurement == \"netbird_peer_connection\" and r._field == \"total_seconds\")\n |> map(fn: (r) => ({r with _value: r._value * 1000.0}))\n |> drop(columns: [\"deployment_type\", \"connection_type\", \"attempt_type\", \"version\", \"os\", \"arch\", \"peer_id\", \"connection_pair_id\"])\n |> max()\n |> set(key: \"_field\", value: \"Max\")",
"refId": "B"
}
],
"fieldConfig": {
"defaults": {
"unit": "ms",
"min": 0
}
},
"options": {
"reduceOptions": {
"calcs": ["lastNotNull"]
},
"colorMode": "value",
"graphMode": "none",
"textMode": "auto"
}
},
{
"id": 1,
"title": "Sync Duration",
"type": "timeseries",
"datasource": {
"type": "influxdb",
"uid": "influxdb"
},
"gridPos": {
"h": 8,
"w": 12,
"x": 0,
"y": 8
},
"targets": [
{
"query": "from(bucket: \"metrics\")\n |> range(start: v.timeRangeStart, stop: v.timeRangeStop)\n |> filter(fn: (r) => r._measurement == \"netbird_sync\" and r._field == \"duration_seconds\")\n |> map(fn: (r) => ({r with _value: r._value * 1000.0}))\n |> drop(columns: [\"deployment_type\", \"version\", \"os\", \"arch\", \"peer_id\"])\n |> set(key: \"_field\", value: \"Sync Duration\")",
"refId": "A"
}
],
"fieldConfig": {
"defaults": {
"unit": "ms",
"min": 0,
"custom": {
"drawStyle": "points",
"pointSize": 5
}
}
}
},
{
"id": 4,
"title": "ICE vs Relay",
"type": "piechart",
"datasource": {
"type": "influxdb",
"uid": "influxdb"
},
"gridPos": {
"h": 8,
"w": 12,
"x": 12,
"y": 8
},
"targets": [
{
"query": "from(bucket: \"metrics\")\n |> range(start: v.timeRangeStart, stop: v.timeRangeStop)\n |> filter(fn: (r) => r._measurement == \"netbird_peer_connection\" and r._field == \"total_seconds\")\n |> drop(columns: [\"deployment_type\", \"attempt_type\", \"version\", \"os\", \"arch\", \"peer_id\"])\n |> group(columns: [\"connection_pair_id\"])\n |> last()\n |> group(columns: [\"connection_type\"])\n |> count()",
"refId": "A"
}
],
"options": {
"reduceOptions": {
"calcs": ["lastNotNull"]
},
"pieType": "donut",
"tooltip": {
"mode": "multi"
}
}
},
{
"id": 2,
"title": "Connection Stage Durations (avg)",
"type": "bargauge",
"datasource": {
"type": "influxdb",
"uid": "influxdb"
},
"gridPos": {
"h": 8,
"w": 12,
"x": 0,
"y": 16
},
"targets": [
{
"query": "from(bucket: \"metrics\")\n |> range(start: v.timeRangeStart, stop: v.timeRangeStop)\n |> filter(fn: (r) => r._measurement == \"netbird_peer_connection\" and r._field == \"signaling_to_connection_seconds\")\n |> map(fn: (r) => ({r with _value: r._value * 1000.0}))\n |> drop(columns: [\"deployment_type\", \"connection_type\", \"attempt_type\", \"version\", \"os\", \"arch\", \"peer_id\", \"connection_pair_id\"])\n |> mean()\n |> drop(columns: [\"_start\", \"_stop\", \"_measurement\", \"_time\", \"_field\"])\n |> rename(columns: {_value: \"Avg Signaling to Connection\"})",
"refId": "A"
},
{
"query": "from(bucket: \"metrics\")\n |> range(start: v.timeRangeStart, stop: v.timeRangeStop)\n |> filter(fn: (r) => r._measurement == \"netbird_peer_connection\" and r._field == \"connection_to_wg_handshake_seconds\")\n |> map(fn: (r) => ({r with _value: r._value * 1000.0}))\n |> drop(columns: [\"deployment_type\", \"connection_type\", \"attempt_type\", \"version\", \"os\", \"arch\", \"peer_id\", \"connection_pair_id\"])\n |> mean()\n |> drop(columns: [\"_start\", \"_stop\", \"_measurement\", \"_time\", \"_field\"])\n |> rename(columns: {_value: \"Avg Connection to WG Handshake\"})",
"refId": "B"
}
],
"fieldConfig": {
"defaults": {
"unit": "ms",
"min": 0
}
},
"options": {
"reduceOptions": {
"calcs": ["lastNotNull"]
},
"orientation": "horizontal",
"displayMode": "gradient"
}
},
{
"id": 3,
"title": "Total Connection Time",
"type": "timeseries",
"datasource": {
"type": "influxdb",
"uid": "influxdb"
},
"gridPos": {
"h": 8,
"w": 12,
"x": 12,
"y": 16
},
"targets": [
{
"query": "from(bucket: \"metrics\")\n |> range(start: v.timeRangeStart, stop: v.timeRangeStop)\n |> filter(fn: (r) => r._measurement == \"netbird_peer_connection\" and r._field == \"total_seconds\")\n |> map(fn: (r) => ({r with _value: r._value * 1000.0}))\n |> drop(columns: [\"deployment_type\", \"connection_type\", \"attempt_type\", \"version\", \"os\", \"arch\", \"peer_id\", \"connection_pair_id\"])\n |> set(key: \"_field\", value: \"Total Connection Time\")",
"refId": "A"
}
],
"fieldConfig": {
"defaults": {
"unit": "ms",
"min": 0,
"custom": {
"drawStyle": "points",
"pointSize": 5
}
}
}
},
{
"id": 7,
"title": "Login Duration",
"type": "timeseries",
"datasource": {
"type": "influxdb",
"uid": "influxdb"
},
"gridPos": {
"h": 8,
"w": 12,
"x": 0,
"y": 24
},
"targets": [
{
"query": "from(bucket: \"metrics\")\n |> range(start: v.timeRangeStart, stop: v.timeRangeStop)\n |> filter(fn: (r) => r._measurement == \"netbird_login\" and r._field == \"duration_seconds\")\n |> map(fn: (r) => ({r with _value: r._value * 1000.0}))\n |> drop(columns: [\"deployment_type\", \"version\", \"os\", \"arch\", \"peer_id\"])\n |> set(key: \"_field\", value: \"Login Duration\")",
"refId": "A"
}
],
"fieldConfig": {
"defaults": {
"unit": "ms",
"min": 0,
"custom": {
"drawStyle": "points",
"pointSize": 5
}
}
}
},
{
"id": 8,
"title": "Login Success vs Failure",
"type": "piechart",
"datasource": {
"type": "influxdb",
"uid": "influxdb"
},
"gridPos": {
"h": 8,
"w": 12,
"x": 12,
"y": 24
},
"targets": [
{
"query": "from(bucket: \"metrics\")\n |> range(start: v.timeRangeStart, stop: v.timeRangeStop)\n |> filter(fn: (r) => r._measurement == \"netbird_login\" and r._field == \"duration_seconds\")\n |> drop(columns: [\"deployment_type\", \"version\", \"os\", \"arch\", \"peer_id\"])\n |> group(columns: [\"result\"])\n |> count()",
"refId": "A"
}
],
"options": {
"reduceOptions": {
"calcs": ["lastNotNull"]
},
"pieType": "donut",
"tooltip": {
"mode": "multi"
}
}
}
],
"schemaVersion": 27,
"version": 2,
"refresh": "30s"
}

View File

@@ -0,0 +1,15 @@
apiVersion: 1
datasources:
- name: InfluxDB
uid: influxdb
type: influxdb
access: proxy
url: http://influxdb:8086
editable: true
jsonData:
version: Flux
organization: netbird
defaultBucket: metrics
secureJsonData:
token: ${INFLUXDB_ADMIN_TOKEN}

View File

@@ -0,0 +1,25 @@
#!/bin/bash
# Creates a scoped InfluxDB read-only token for Grafana.
# Clients do not need a token — they push via the ingest server.
BUCKET_ID=$(influx bucket list --org netbird --name metrics --json | grep -oP '"id"\s*:\s*"\K[^"]+' | head -1)
ORG_ID=$(influx org list --name netbird --json | grep -oP '"id"\s*:\s*"\K[^"]+' | head -1)
if [[ -z "$BUCKET_ID" ]] || [[ -z "$ORG_ID" ]]; then
echo "ERROR: Could not determine bucket or org ID" >&2
echo "BUCKET_ID=$BUCKET_ID ORG_ID=$ORG_ID" >&2
exit 1
fi
# Create read-only token for Grafana
READ_TOKEN=$(influx auth create \
--org netbird \
--read-bucket "$BUCKET_ID" \
--description "Grafana read-only token" \
--json | grep -oP '"token"\s*:\s*"\K[^"]+' | head -1)
echo ""
echo "============================================"
echo "GRAFANA READ-ONLY TOKEN:"
echo "$READ_TOKEN"
echo "============================================"

View File

@@ -0,0 +1,10 @@
FROM golang:1.25-alpine AS build
WORKDIR /app
COPY go.mod main.go ./
RUN CGO_ENABLED=0 go build -o ingest .
FROM alpine:3.20
RUN adduser -D -H ingest
COPY --from=build /app/ingest /usr/local/bin/ingest
USER ingest
ENTRYPOINT ["ingest"]

View File

@@ -0,0 +1,11 @@
module github.com/netbirdio/netbird/client/internal/metrics/infra/ingest
go 1.25
require github.com/stretchr/testify v1.11.1
require (
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)

View File

@@ -0,0 +1,10 @@
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=

View File

@@ -0,0 +1,355 @@
package main
import (
"bytes"
"compress/gzip"
"encoding/hex"
"encoding/json"
"fmt"
"io"
"log"
"net/http"
"os"
"strconv"
"strings"
"time"
)
const (
defaultListenAddr = ":8087"
defaultInfluxDBURL = "http://influxdb:8086/api/v2/write?org=netbird&bucket=metrics&precision=ns"
maxBodySize = 50 * 1024 * 1024 // 50 MB max request body
maxDurationSeconds = 300.0 // reject any duration field > 5 minutes
peerIDLength = 16 // truncated SHA-256: 8 bytes = 16 hex chars
maxTagValueLength = 64 // reject tag values longer than this
)
type measurementSpec struct {
allowedFields map[string]bool
allowedTags map[string]bool
}
var allowedMeasurements = map[string]measurementSpec{
"netbird_peer_connection": {
allowedFields: map[string]bool{
"signaling_to_connection_seconds": true,
"connection_to_wg_handshake_seconds": true,
"total_seconds": true,
},
allowedTags: map[string]bool{
"deployment_type": true,
"connection_type": true,
"attempt_type": true,
"version": true,
"os": true,
"arch": true,
"peer_id": true,
"connection_pair_id": true,
},
},
"netbird_sync": {
allowedFields: map[string]bool{
"duration_seconds": true,
},
allowedTags: map[string]bool{
"deployment_type": true,
"version": true,
"os": true,
"arch": true,
"peer_id": true,
},
},
"netbird_login": {
allowedFields: map[string]bool{
"duration_seconds": true,
},
allowedTags: map[string]bool{
"deployment_type": true,
"result": true,
"version": true,
"os": true,
"arch": true,
"peer_id": true,
},
},
}
func main() {
listenAddr := envOr("INGEST_LISTEN_ADDR", defaultListenAddr)
influxURL := envOr("INFLUXDB_URL", defaultInfluxDBURL)
influxToken := os.Getenv("INFLUXDB_TOKEN")
if influxToken == "" {
log.Fatal("INFLUXDB_TOKEN is required")
}
client := &http.Client{Timeout: 10 * time.Second}
http.HandleFunc("/", handleIngest(client, influxURL, influxToken))
// Build config JSON once at startup from env vars
configJSON := buildConfigJSON()
if configJSON != nil {
log.Printf("serving remote config at /config")
}
http.HandleFunc("/config", func(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodGet {
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
return
}
if configJSON == nil {
http.Error(w, "config not configured", http.StatusNotFound)
return
}
w.Header().Set("Content-Type", "application/json")
w.Write(configJSON) //nolint:errcheck
})
http.HandleFunc("/health", func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(http.StatusOK)
fmt.Fprint(w, "ok") //nolint:errcheck
})
log.Printf("ingest server listening on %s, forwarding to %s", listenAddr, influxURL)
if err := http.ListenAndServe(listenAddr, nil); err != nil { //nolint:gosec
log.Fatal(err)
}
}
func handleIngest(client *http.Client, influxURL, influxToken string) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
return
}
if err := validateAuth(r); err != nil {
http.Error(w, err.Error(), http.StatusUnauthorized)
return
}
body, err := readBody(r)
if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
if len(body) > maxBodySize {
http.Error(w, "body too large", http.StatusRequestEntityTooLarge)
return
}
validated, err := validateLineProtocol(body)
if err != nil {
log.Printf("WARN validation failed from %s: %v", r.RemoteAddr, err)
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
forwardToInflux(w, r, client, influxURL, influxToken, validated)
}
}
func forwardToInflux(w http.ResponseWriter, r *http.Request, client *http.Client, influxURL, influxToken string, body []byte) {
req, err := http.NewRequestWithContext(r.Context(), http.MethodPost, influxURL, bytes.NewReader(body))
if err != nil {
log.Printf("ERROR create request: %v", err)
http.Error(w, "internal error", http.StatusInternalServerError)
return
}
req.Header.Set("Content-Type", "text/plain; charset=utf-8")
req.Header.Set("Authorization", "Token "+influxToken)
resp, err := client.Do(req)
if err != nil {
log.Printf("ERROR forward to influxdb: %v", err)
http.Error(w, "upstream error", http.StatusBadGateway)
return
}
defer func(Body io.ReadCloser) {
_ = Body.Close()
}(resp.Body)
w.WriteHeader(resp.StatusCode)
io.Copy(w, resp.Body) //nolint:errcheck
}
// validateAuth checks that the X-Peer-ID header contains a valid hashed peer ID.
func validateAuth(r *http.Request) error {
peerID := r.Header.Get("X-Peer-ID")
if peerID == "" {
return fmt.Errorf("missing X-Peer-ID header")
}
if len(peerID) != peerIDLength {
return fmt.Errorf("invalid X-Peer-ID header length")
}
if _, err := hex.DecodeString(peerID); err != nil {
return fmt.Errorf("invalid X-Peer-ID header format")
}
return nil
}
// readBody reads the request body, decompressing gzip if Content-Encoding indicates it.
func readBody(r *http.Request) ([]byte, error) {
reader := io.LimitReader(r.Body, maxBodySize+1)
if r.Header.Get("Content-Encoding") == "gzip" {
gz, err := gzip.NewReader(reader)
if err != nil {
return nil, fmt.Errorf("invalid gzip: %w", err)
}
defer gz.Close()
reader = io.LimitReader(gz, maxBodySize+1)
}
return io.ReadAll(reader)
}
// validateLineProtocol parses InfluxDB line protocol lines,
// whitelists measurements and fields, and checks value bounds.
func validateLineProtocol(body []byte) ([]byte, error) {
lines := strings.Split(strings.TrimSpace(string(body)), "\n")
var valid []string
for _, line := range lines {
line = strings.TrimSpace(line)
if line == "" {
continue
}
if err := validateLine(line); err != nil {
return nil, err
}
valid = append(valid, line)
}
if len(valid) == 0 {
return nil, fmt.Errorf("no valid lines")
}
return []byte(strings.Join(valid, "\n") + "\n"), nil
}
func validateLine(line string) error {
// line protocol: measurement,tag=val,tag=val field=val,field=val timestamp
parts := strings.SplitN(line, " ", 3)
if len(parts) < 2 {
return fmt.Errorf("invalid line protocol: %q", truncate(line, 100))
}
// parts[0] is "measurement,tag=val,tag=val"
measurementAndTags := strings.Split(parts[0], ",")
measurement := measurementAndTags[0]
spec, ok := allowedMeasurements[measurement]
if !ok {
return fmt.Errorf("unknown measurement: %q", measurement)
}
// Validate tags (everything after measurement name in parts[0])
for _, tagPair := range measurementAndTags[1:] {
if err := validateTag(tagPair, measurement, spec.allowedTags); err != nil {
return err
}
}
// Validate fields
for _, pair := range strings.Split(parts[1], ",") {
if err := validateField(pair, measurement, spec.allowedFields); err != nil {
return err
}
}
return nil
}
func validateTag(pair, measurement string, allowedTags map[string]bool) error {
kv := strings.SplitN(pair, "=", 2)
if len(kv) != 2 {
return fmt.Errorf("invalid tag: %q", pair)
}
tagName := kv[0]
if !allowedTags[tagName] {
return fmt.Errorf("unknown tag %q in measurement %q", tagName, measurement)
}
if len(kv[1]) > maxTagValueLength {
return fmt.Errorf("tag value too long for %q: %d > %d", tagName, len(kv[1]), maxTagValueLength)
}
return nil
}
func validateField(pair, measurement string, allowedFields map[string]bool) error {
kv := strings.SplitN(pair, "=", 2)
if len(kv) != 2 {
return fmt.Errorf("invalid field: %q", pair)
}
fieldName := kv[0]
if !allowedFields[fieldName] {
return fmt.Errorf("unknown field %q in measurement %q", fieldName, measurement)
}
val, err := strconv.ParseFloat(kv[1], 64)
if err != nil {
return fmt.Errorf("invalid field value %q for %q", kv[1], fieldName)
}
if val < 0 {
return fmt.Errorf("negative value for %q: %g", fieldName, val)
}
if strings.HasSuffix(fieldName, "_seconds") && val > maxDurationSeconds {
return fmt.Errorf("%q too large: %g > %g", fieldName, val, maxDurationSeconds)
}
return nil
}
// buildConfigJSON builds the remote config JSON from env vars.
// Returns nil if required vars are not set.
func buildConfigJSON() []byte {
serverURL := os.Getenv("CONFIG_METRICS_SERVER_URL")
versionSince := envOr("CONFIG_VERSION_SINCE", "0.0.0")
versionUntil := envOr("CONFIG_VERSION_UNTIL", "99.99.99")
periodMinutes := envOr("CONFIG_PERIOD_MINUTES", "5")
if serverURL == "" {
return nil
}
period, err := strconv.Atoi(periodMinutes)
if err != nil || period <= 0 {
log.Printf("WARN invalid CONFIG_PERIOD_MINUTES: %q, using 5", periodMinutes)
period = 5
}
cfg := map[string]any{
"server_url": serverURL,
"version-since": versionSince,
"version-until": versionUntil,
"period_minutes": period,
}
data, err := json.Marshal(cfg)
if err != nil {
log.Printf("ERROR failed to marshal config: %v", err)
return nil
}
return data
}
func envOr(key, defaultVal string) string {
if v := os.Getenv(key); v != "" {
return v
}
return defaultVal
}
func truncate(s string, n int) string {
if len(s) <= n {
return s
}
return s[:n] + "..."
}

View File

@@ -0,0 +1,124 @@
package main
import (
"net/http"
"strings"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestValidateLine_ValidPeerConnection(t *testing.T) {
line := `netbird_peer_connection,deployment_type=cloud,connection_type=ice,attempt_type=initial,version=1.0.0,os=linux,arch=amd64,peer_id=abcdef0123456789,connection_pair_id=pair1234 signaling_to_connection_seconds=1.5,connection_to_wg_handshake_seconds=0.5,total_seconds=2 1234567890`
assert.NoError(t, validateLine(line))
}
func TestValidateLine_ValidSync(t *testing.T) {
line := `netbird_sync,deployment_type=selfhosted,version=2.0.0,os=darwin,arch=arm64,peer_id=abcdef0123456789 duration_seconds=1.5 1234567890`
assert.NoError(t, validateLine(line))
}
func TestValidateLine_ValidLogin(t *testing.T) {
line := `netbird_login,deployment_type=cloud,result=success,version=1.0.0,os=linux,arch=amd64,peer_id=abcdef0123456789 duration_seconds=3.2 1234567890`
assert.NoError(t, validateLine(line))
}
func TestValidateLine_UnknownMeasurement(t *testing.T) {
line := `unknown_metric,foo=bar value=1 1234567890`
err := validateLine(line)
require.Error(t, err)
assert.Contains(t, err.Error(), "unknown measurement")
}
func TestValidateLine_UnknownTag(t *testing.T) {
line := `netbird_sync,deployment_type=cloud,evil_tag=injected,version=1.0.0,os=linux,arch=amd64,peer_id=abc duration_seconds=1.5 1234567890`
err := validateLine(line)
require.Error(t, err)
assert.Contains(t, err.Error(), "unknown tag")
}
func TestValidateLine_UnknownField(t *testing.T) {
line := `netbird_sync,deployment_type=cloud,version=1.0.0,os=linux,arch=amd64,peer_id=abc injected_field=1 1234567890`
err := validateLine(line)
require.Error(t, err)
assert.Contains(t, err.Error(), "unknown field")
}
func TestValidateLine_NegativeValue(t *testing.T) {
line := `netbird_sync,deployment_type=cloud,version=1.0.0,os=linux,arch=amd64,peer_id=abc duration_seconds=-1.5 1234567890`
err := validateLine(line)
require.Error(t, err)
assert.Contains(t, err.Error(), "negative")
}
func TestValidateLine_DurationTooLarge(t *testing.T) {
line := `netbird_sync,deployment_type=cloud,version=1.0.0,os=linux,arch=amd64,peer_id=abc duration_seconds=999 1234567890`
err := validateLine(line)
require.Error(t, err)
assert.Contains(t, err.Error(), "too large")
}
func TestValidateLine_TotalSecondsTooLarge(t *testing.T) {
line := `netbird_peer_connection,deployment_type=cloud,connection_type=ice,attempt_type=initial,version=1.0.0,os=linux,arch=amd64,peer_id=abc,connection_pair_id=pair total_seconds=500 1234567890`
err := validateLine(line)
require.Error(t, err)
assert.Contains(t, err.Error(), "too large")
}
func TestValidateLine_TagValueTooLong(t *testing.T) {
longTag := strings.Repeat("a", maxTagValueLength+1)
line := `netbird_sync,deployment_type=` + longTag + `,version=1.0.0,os=linux,arch=amd64,peer_id=abc duration_seconds=1.5 1234567890`
err := validateLine(line)
require.Error(t, err)
assert.Contains(t, err.Error(), "tag value too long")
}
func TestValidateLineProtocol_MultipleLines(t *testing.T) {
body := []byte(
"netbird_sync,deployment_type=cloud,version=1.0.0,os=linux,arch=amd64,peer_id=abc duration_seconds=1.5 1234567890\n" +
"netbird_login,deployment_type=cloud,result=success,version=1.0.0,os=linux,arch=amd64,peer_id=abc duration_seconds=2.0 1234567890\n",
)
validated, err := validateLineProtocol(body)
require.NoError(t, err)
assert.Contains(t, string(validated), "netbird_sync")
assert.Contains(t, string(validated), "netbird_login")
}
func TestValidateLineProtocol_RejectsOnBadLine(t *testing.T) {
body := []byte(
"netbird_sync,deployment_type=cloud,version=1.0.0,os=linux,arch=amd64,peer_id=abc duration_seconds=1.5 1234567890\n" +
"evil_metric,foo=bar value=1 1234567890\n",
)
_, err := validateLineProtocol(body)
require.Error(t, err)
}
func TestValidateAuth(t *testing.T) {
tests := []struct {
name string
peerID string
wantErr bool
}{
{"valid hex", "abcdef0123456789", false},
{"empty", "", true},
{"too short", "abcdef01234567", true},
{"too long", "abcdef01234567890", true},
{"invalid hex", "ghijklmnopqrstuv", true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
r, _ := http.NewRequest(http.MethodPost, "/", nil)
if tt.peerID != "" {
r.Header.Set("X-Peer-ID", tt.peerID)
}
err := validateAuth(r)
if tt.wantErr {
require.Error(t, err)
} else {
require.NoError(t, err)
}
})
}
}

View File

@@ -0,0 +1,224 @@
package metrics
import (
"context"
"crypto/sha256"
"encoding/hex"
"fmt"
"io"
"sync"
"time"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/metrics/remoteconfig"
)
// AgentInfo holds static information about the agent
type AgentInfo struct {
DeploymentType DeploymentType
Version string
OS string // runtime.GOOS (linux, darwin, windows, etc.)
Arch string // runtime.GOARCH (amd64, arm64, etc.)
peerID string // anonymised peer identifier (SHA-256 of WireGuard public key)
}
// peerIDFromPublicKey returns a truncated SHA-256 hash (8 bytes / 16 hex chars) of the given WireGuard public key.
func peerIDFromPublicKey(pubKey string) string {
hash := sha256.Sum256([]byte(pubKey))
return hex.EncodeToString(hash[:8])
}
// connectionPairID returns a deterministic identifier for a connection between two peers.
// It sorts the two peer IDs before hashing so the same pair always produces the same ID
// regardless of which side computes it.
func connectionPairID(peerID1, peerID2 string) string {
a, b := peerID1, peerID2
if a > b {
a, b = b, a
}
hash := sha256.Sum256([]byte(a + b))
return hex.EncodeToString(hash[:8])
}
// metricsImplementation defines the internal interface for metrics implementations
type metricsImplementation interface {
// RecordConnectionStages records connection stage metrics from timestamps
RecordConnectionStages(
ctx context.Context,
agentInfo AgentInfo,
connectionPairID string,
connectionType ConnectionType,
isReconnection bool,
timestamps ConnectionStageTimestamps,
)
// RecordSyncDuration records how long it took to process a sync message
RecordSyncDuration(ctx context.Context, agentInfo AgentInfo, duration time.Duration)
// RecordLoginDuration records how long the login to management took
RecordLoginDuration(ctx context.Context, agentInfo AgentInfo, duration time.Duration, success bool)
// Export exports metrics in InfluxDB line protocol format
Export(w io.Writer) error
// Reset clears all collected metrics
Reset()
}
type ClientMetrics struct {
impl metricsImplementation
agentInfo AgentInfo
mu sync.RWMutex
push *Push
pushMu sync.Mutex
wg sync.WaitGroup
pushCancel context.CancelFunc
}
// ConnectionStageTimestamps holds timestamps for each connection stage
type ConnectionStageTimestamps struct {
SignalingReceived time.Time // First signal received from remote peer (both initial and reconnection)
ConnectionReady time.Time
WgHandshakeSuccess time.Time
}
// String returns a human-readable representation of the connection stage timestamps
func (c ConnectionStageTimestamps) String() string {
return fmt.Sprintf("ConnectionStageTimestamps{SignalingReceived=%v, ConnectionReady=%v, WgHandshakeSuccess=%v}",
c.SignalingReceived.Format(time.RFC3339Nano),
c.ConnectionReady.Format(time.RFC3339Nano),
c.WgHandshakeSuccess.Format(time.RFC3339Nano),
)
}
// RecordConnectionStages calculates stage durations from timestamps and records them.
// remotePubKey is the remote peer's WireGuard public key; it will be hashed for anonymisation.
func (c *ClientMetrics) RecordConnectionStages(
ctx context.Context,
remotePubKey string,
connectionType ConnectionType,
isReconnection bool,
timestamps ConnectionStageTimestamps,
) {
if c == nil {
return
}
c.mu.RLock()
agentInfo := c.agentInfo
c.mu.RUnlock()
remotePeerID := peerIDFromPublicKey(remotePubKey)
pairID := connectionPairID(agentInfo.peerID, remotePeerID)
c.impl.RecordConnectionStages(ctx, agentInfo, pairID, connectionType, isReconnection, timestamps)
}
// RecordSyncDuration records the duration of sync message processing
func (c *ClientMetrics) RecordSyncDuration(ctx context.Context, duration time.Duration) {
if c == nil {
return
}
c.mu.RLock()
agentInfo := c.agentInfo
c.mu.RUnlock()
c.impl.RecordSyncDuration(ctx, agentInfo, duration)
}
// RecordLoginDuration records how long the login to management server took
func (c *ClientMetrics) RecordLoginDuration(ctx context.Context, duration time.Duration, success bool) {
if c == nil {
return
}
c.mu.RLock()
agentInfo := c.agentInfo
c.mu.RUnlock()
c.impl.RecordLoginDuration(ctx, agentInfo, duration, success)
}
// UpdateAgentInfo updates the agent information (e.g., when switching profiles).
// publicKey is the WireGuard public key; it will be hashed for anonymisation.
func (c *ClientMetrics) UpdateAgentInfo(agentInfo AgentInfo, publicKey string) {
if c == nil {
return
}
agentInfo.peerID = peerIDFromPublicKey(publicKey)
c.mu.Lock()
c.agentInfo = agentInfo
c.mu.Unlock()
c.pushMu.Lock()
push := c.push
c.pushMu.Unlock()
if push != nil {
push.SetPeerID(agentInfo.peerID)
}
}
// Export exports metrics to the writer
func (c *ClientMetrics) Export(w io.Writer) error {
if c == nil {
return nil
}
return c.impl.Export(w)
}
// StartPush starts periodic pushing of metrics with the given configuration
// Precedence: PushConfig.ServerAddress > remote config server_url
func (c *ClientMetrics) StartPush(ctx context.Context, config PushConfig) {
if c == nil {
return
}
c.pushMu.Lock()
defer c.pushMu.Unlock()
if c.push != nil {
log.Warnf("metrics push already running")
return
}
c.mu.RLock()
agentVersion := c.agentInfo.Version
peerID := c.agentInfo.peerID
c.mu.RUnlock()
configManager := remoteconfig.NewManager(getMetricsConfigURL(), remoteconfig.DefaultMinRefreshInterval)
push, err := NewPush(c.impl, configManager, config, agentVersion)
if err != nil {
log.Errorf("failed to create metrics push: %v", err)
return
}
push.SetPeerID(peerID)
ctx, cancel := context.WithCancel(ctx)
c.pushCancel = cancel
c.wg.Add(1)
go func() {
defer c.wg.Done()
push.Start(ctx)
}()
c.push = push
}
func (c *ClientMetrics) StopPush() {
if c == nil {
return
}
c.pushMu.Lock()
defer c.pushMu.Unlock()
if c.push == nil {
return
}
c.pushCancel()
c.wg.Wait()
c.push = nil
}

View File

@@ -0,0 +1,11 @@
//go:build !js
package metrics
// NewClientMetrics creates a new ClientMetrics instance
func NewClientMetrics(agentInfo AgentInfo) *ClientMetrics {
return &ClientMetrics{
impl: newInfluxDBMetrics(),
agentInfo: agentInfo,
}
}

View File

@@ -0,0 +1,8 @@
//go:build js
package metrics
// NewClientMetrics returns nil on WASM builds — all ClientMetrics methods are nil-safe.
func NewClientMetrics(AgentInfo) *ClientMetrics {
return nil
}

View File

@@ -0,0 +1,289 @@
package metrics
import (
"bytes"
"compress/gzip"
"context"
"fmt"
"net/http"
"net/url"
"sync"
"time"
goversion "github.com/hashicorp/go-version"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/metrics/remoteconfig"
)
const (
// defaultPushInterval is the default interval for pushing metrics
defaultPushInterval = 5 * time.Minute
)
// defaultMetricsServerURL is used as fallback when NB_METRICS_FORCE_SENDING is true
var defaultMetricsServerURL *url.URL
func init() {
defaultMetricsServerURL, _ = url.Parse("https://ingest.netbird.io")
}
// PushConfig holds configuration for metrics push
type PushConfig struct {
// ServerAddress is the metrics server URL. If nil, uses remote config server_url.
ServerAddress *url.URL
// Interval is how often to push metrics. If 0, uses remote config interval or defaultPushInterval.
Interval time.Duration
// ForceSending skips remote configuration fetch and version checks, pushing unconditionally.
ForceSending bool
}
// PushConfigFromEnv builds a PushConfig from environment variables.
func PushConfigFromEnv() PushConfig {
config := PushConfig{}
config.ForceSending = isForceSending()
config.ServerAddress = getMetricsServerURL()
config.Interval = getMetricsInterval()
return config
}
// remoteConfigProvider abstracts remote push config fetching for testability
type remoteConfigProvider interface {
RefreshIfNeeded(ctx context.Context) *remoteconfig.Config
}
// Push handles periodic pushing of metrics
type Push struct {
metrics metricsImplementation
configManager remoteConfigProvider
agentVersion *goversion.Version
peerID string
peerMu sync.RWMutex
client *http.Client
cfgForceSending bool
cfgInterval time.Duration
cfgAddress *url.URL
}
// NewPush creates a new Push instance with configuration resolution
func NewPush(metrics metricsImplementation, configManager remoteConfigProvider, config PushConfig, agentVersion string) (*Push, error) {
var cfgInterval time.Duration
var cfgAddress *url.URL
if config.ForceSending {
cfgInterval = config.Interval
if config.Interval <= 0 {
cfgInterval = defaultPushInterval
}
cfgAddress = config.ServerAddress
if cfgAddress == nil {
cfgAddress = defaultMetricsServerURL
}
} else {
cfgAddress = config.ServerAddress
if config.Interval < 0 {
log.Warnf("negative metrics push interval %s", config.Interval)
} else {
cfgInterval = config.Interval
}
}
parsedVersion, err := goversion.NewVersion(agentVersion)
if err != nil {
if !config.ForceSending {
return nil, fmt.Errorf("parse agent version %q: %w", agentVersion, err)
}
}
return &Push{
metrics: metrics,
configManager: configManager,
agentVersion: parsedVersion,
cfgForceSending: config.ForceSending,
cfgInterval: cfgInterval,
cfgAddress: cfgAddress,
client: &http.Client{
Timeout: 10 * time.Second,
},
}, nil
}
// SetPeerID updates the hashed peer ID used for the Authorization header.
func (p *Push) SetPeerID(peerID string) {
p.peerMu.Lock()
p.peerID = peerID
p.peerMu.Unlock()
}
// Start starts the periodic push loop.
// The env interval override controls tick frequency but does not bypass remote config
// version gating. Use ForceSending to skip remote config entirely.
func (p *Push) Start(ctx context.Context) {
// Log initial state
switch {
case p.cfgForceSending:
log.Infof("started metrics push with force sending to %s, interval %s", p.cfgAddress, p.cfgInterval)
case p.cfgAddress != nil:
log.Infof("started metrics push with server URL override: %s", p.cfgAddress.String())
default:
log.Infof("started metrics push, server URL will be resolved from remote config")
}
timer := time.NewTimer(0) // fire immediately on first iteration
defer timer.Stop()
for {
select {
case <-ctx.Done():
log.Debug("stopping metrics push")
return
case <-timer.C:
}
pushURL, interval := p.resolve(ctx)
if pushURL != "" {
if err := p.push(ctx, pushURL); err != nil {
log.Errorf("failed to push metrics: %v", err)
}
}
if interval <= 0 {
interval = defaultPushInterval
}
timer.Reset(interval)
}
}
// resolve returns the push URL and interval for the next cycle.
// Returns empty pushURL to skip this cycle.
func (p *Push) resolve(ctx context.Context) (pushURL string, interval time.Duration) {
if p.cfgForceSending {
return p.resolveServerURL(nil), p.cfgInterval
}
config := p.configManager.RefreshIfNeeded(ctx)
if config == nil {
log.Debug("no metrics push config available, waiting to retry")
return "", defaultPushInterval
}
// prefer env variables instead of remote config
if p.cfgInterval > 0 {
interval = p.cfgInterval
} else {
interval = config.Interval
}
if !isVersionInRange(p.agentVersion, config.VersionSince, config.VersionUntil) {
log.Debugf("agent version %s not in range [%s, %s), skipping metrics push",
p.agentVersion, config.VersionSince, config.VersionUntil)
return "", interval
}
pushURL = p.resolveServerURL(&config.ServerURL)
if pushURL == "" {
log.Warn("no metrics server URL available, skipping push")
}
return pushURL, interval
}
// push exports metrics and sends them to the metrics server
func (p *Push) push(ctx context.Context, pushURL string) error {
// Export metrics without clearing
var buf bytes.Buffer
if err := p.metrics.Export(&buf); err != nil {
return fmt.Errorf("export metrics: %w", err)
}
// Don't push if there are no metrics
if buf.Len() == 0 {
log.Tracef("no metrics to push")
return nil
}
// Gzip compress the body
compressed, err := gzipCompress(buf.Bytes())
if err != nil {
return fmt.Errorf("gzip compress: %w", err)
}
// Create HTTP request
req, err := http.NewRequestWithContext(ctx, "POST", pushURL, compressed)
if err != nil {
return fmt.Errorf("create request: %w", err)
}
req.Header.Set("Content-Type", "text/plain; charset=utf-8")
req.Header.Set("Content-Encoding", "gzip")
p.peerMu.RLock()
peerID := p.peerID
p.peerMu.RUnlock()
if peerID != "" {
req.Header.Set("X-Peer-ID", peerID)
}
// Send request
resp, err := p.client.Do(req)
if err != nil {
return fmt.Errorf("send request: %w", err)
}
defer func() {
if resp.Body == nil {
return
}
if err := resp.Body.Close(); err != nil {
log.Warnf("failed to close response body: %v", err)
}
}()
// Check response status
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
return fmt.Errorf("push failed with status %d", resp.StatusCode)
}
log.Debugf("successfully pushed metrics to %s", pushURL)
p.metrics.Reset()
return nil
}
// resolveServerURL determines the push URL.
// Precedence: envAddress (env var) > remote config server_url
func (p *Push) resolveServerURL(remoteServerURL *url.URL) string {
var baseURL *url.URL
if p.cfgAddress != nil {
baseURL = p.cfgAddress
} else {
baseURL = remoteServerURL
}
if baseURL == nil {
return ""
}
return baseURL.String()
}
// gzipCompress compresses data using gzip and returns the compressed buffer.
func gzipCompress(data []byte) (*bytes.Buffer, error) {
var buf bytes.Buffer
gz := gzip.NewWriter(&buf)
if _, err := gz.Write(data); err != nil {
_ = gz.Close()
return nil, err
}
if err := gz.Close(); err != nil {
return nil, err
}
return &buf, nil
}
// isVersionInRange checks if current falls within [since, until)
func isVersionInRange(current, since, until *goversion.Version) bool {
return !current.LessThan(since) && current.LessThan(until)
}

View File

@@ -0,0 +1,343 @@
package metrics
import (
"context"
"io"
"net/http"
"net/http/httptest"
"net/url"
"sync/atomic"
"testing"
"time"
goversion "github.com/hashicorp/go-version"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/client/internal/metrics/remoteconfig"
)
func mustVersion(s string) *goversion.Version {
v, err := goversion.NewVersion(s)
if err != nil {
panic(err)
}
return v
}
func mustURL(s string) url.URL {
u, err := url.Parse(s)
if err != nil {
panic(err)
}
return *u
}
func parseURL(s string) *url.URL {
u, err := url.Parse(s)
if err != nil {
panic(err)
}
return u
}
func testConfig(serverURL, since, until string, period time.Duration) *remoteconfig.Config {
return &remoteconfig.Config{
ServerURL: mustURL(serverURL),
VersionSince: mustVersion(since),
VersionUntil: mustVersion(until),
Interval: period,
}
}
// mockConfigProvider implements remoteConfigProvider for testing
type mockConfigProvider struct {
config *remoteconfig.Config
}
func (m *mockConfigProvider) RefreshIfNeeded(_ context.Context) *remoteconfig.Config {
return m.config
}
// mockMetrics implements metricsImplementation for testing
type mockMetrics struct {
exportData string
}
func (m *mockMetrics) RecordConnectionStages(_ context.Context, _ AgentInfo, _ string, _ ConnectionType, _ bool, _ ConnectionStageTimestamps) {
}
func (m *mockMetrics) RecordSyncDuration(_ context.Context, _ AgentInfo, _ time.Duration) {
}
func (m *mockMetrics) RecordLoginDuration(_ context.Context, _ AgentInfo, _ time.Duration, _ bool) {
}
func (m *mockMetrics) Export(w io.Writer) error {
if m.exportData != "" {
_, err := w.Write([]byte(m.exportData))
return err
}
return nil
}
func (m *mockMetrics) Reset() {
}
func TestPush_OverrideIntervalPushes(t *testing.T) {
var pushCount atomic.Int32
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
pushCount.Add(1)
w.WriteHeader(http.StatusNoContent)
}))
defer server.Close()
metrics := &mockMetrics{exportData: "test_metric 1\n"}
configProvider := &mockConfigProvider{config: testConfig(server.URL, "1.0.0", "2.0.0", 60*time.Minute)}
push, err := NewPush(metrics, configProvider, PushConfig{
Interval: 50 * time.Millisecond,
ServerAddress: parseURL(server.URL),
}, "1.0.0")
require.NoError(t, err)
ctx, cancel := context.WithCancel(context.Background())
done := make(chan struct{})
go func() {
push.Start(ctx)
close(done)
}()
require.Eventually(t, func() bool {
return pushCount.Load() >= 3
}, 2*time.Second, 10*time.Millisecond)
cancel()
<-done
}
func TestPush_RemoteConfigVersionInRange(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusNoContent)
}))
defer server.Close()
metrics := &mockMetrics{exportData: "test_metric 1\n"}
configProvider := &mockConfigProvider{config: testConfig(server.URL, "1.0.0", "2.0.0", 1*time.Minute)}
push, err := NewPush(metrics, configProvider, PushConfig{}, "1.5.0")
require.NoError(t, err)
pushURL, interval := push.resolve(context.Background())
assert.NotEmpty(t, pushURL)
assert.Equal(t, 1*time.Minute, interval)
}
func TestPush_RemoteConfigVersionOutOfRange(t *testing.T) {
metrics := &mockMetrics{exportData: "test_metric 1\n"}
configProvider := &mockConfigProvider{config: testConfig("http://localhost", "1.0.0", "1.5.0", 1*time.Minute)}
push, err := NewPush(metrics, configProvider, PushConfig{}, "2.0.0")
require.NoError(t, err)
pushURL, interval := push.resolve(context.Background())
assert.Empty(t, pushURL)
assert.Equal(t, 1*time.Minute, interval)
}
func TestPush_NoConfigReturnsDefault(t *testing.T) {
metrics := &mockMetrics{}
configProvider := &mockConfigProvider{config: nil}
push, err := NewPush(metrics, configProvider, PushConfig{}, "1.0.0")
require.NoError(t, err)
pushURL, interval := push.resolve(context.Background())
assert.Empty(t, pushURL)
assert.Equal(t, defaultPushInterval, interval)
}
func TestPush_OverrideIntervalRespectsVersionCheck(t *testing.T) {
metrics := &mockMetrics{}
configProvider := &mockConfigProvider{config: testConfig("http://localhost", "3.0.0", "4.0.0", 60*time.Minute)}
push, err := NewPush(metrics, configProvider, PushConfig{
Interval: 30 * time.Second,
ServerAddress: parseURL("http://localhost"),
}, "1.0.0")
require.NoError(t, err)
pushURL, interval := push.resolve(context.Background())
assert.Empty(t, pushURL) // version out of range
assert.Equal(t, 30*time.Second, interval) // but uses override interval
}
func TestPush_OverrideIntervalUsedWhenVersionInRange(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusNoContent)
}))
defer server.Close()
metrics := &mockMetrics{}
configProvider := &mockConfigProvider{config: testConfig(server.URL, "1.0.0", "2.0.0", 60*time.Minute)}
push, err := NewPush(metrics, configProvider, PushConfig{
Interval: 30 * time.Second,
}, "1.5.0")
require.NoError(t, err)
pushURL, interval := push.resolve(context.Background())
assert.NotEmpty(t, pushURL)
assert.Equal(t, 30*time.Second, interval)
}
func TestPush_NoMetricsSkipsPush(t *testing.T) {
var pushCount atomic.Int32
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
pushCount.Add(1)
w.WriteHeader(http.StatusNoContent)
}))
defer server.Close()
metrics := &mockMetrics{exportData: ""} // no metrics to export
configProvider := &mockConfigProvider{config: nil}
push, err := NewPush(metrics, configProvider, PushConfig{}, "1.0.0")
require.NoError(t, err)
err = push.push(context.Background(), server.URL)
assert.NoError(t, err)
assert.Equal(t, int32(0), pushCount.Load())
}
func TestPush_ServerURLFromRemoteConfig(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusNoContent)
}))
defer server.Close()
metrics := &mockMetrics{exportData: "test_metric 1\n"}
configProvider := &mockConfigProvider{config: testConfig(server.URL, "1.0.0", "2.0.0", 1*time.Minute)}
push, err := NewPush(metrics, configProvider, PushConfig{}, "1.5.0")
require.NoError(t, err)
pushURL, interval := push.resolve(context.Background())
assert.Contains(t, pushURL, server.URL)
assert.Equal(t, 1*time.Minute, interval)
}
func TestPush_ServerAddressOverridesTakePrecedenceOverRemoteConfig(t *testing.T) {
overrideServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusNoContent)
}))
defer overrideServer.Close()
metrics := &mockMetrics{exportData: "test_metric 1\n"}
configProvider := &mockConfigProvider{config: testConfig("http://remote-config-server", "1.0.0", "2.0.0", 1*time.Minute)}
push, err := NewPush(metrics, configProvider, PushConfig{
ServerAddress: parseURL(overrideServer.URL),
}, "1.5.0")
require.NoError(t, err)
pushURL, _ := push.resolve(context.Background())
assert.Contains(t, pushURL, overrideServer.URL)
assert.NotContains(t, pushURL, "remote-config-server")
}
func TestPush_OverrideIntervalWithoutOverrideURL_UsesRemoteConfigURL(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusNoContent)
}))
defer server.Close()
metrics := &mockMetrics{exportData: "test_metric 1\n"}
configProvider := &mockConfigProvider{config: testConfig(server.URL, "1.0.0", "2.0.0", 60*time.Minute)}
push, err := NewPush(metrics, configProvider, PushConfig{
Interval: 30 * time.Second,
}, "1.0.0")
require.NoError(t, err)
pushURL, interval := push.resolve(context.Background())
assert.Contains(t, pushURL, server.URL)
assert.Equal(t, 30*time.Second, interval)
}
func TestPush_NoConfigSkipsPush(t *testing.T) {
metrics := &mockMetrics{exportData: "test_metric 1\n"}
configProvider := &mockConfigProvider{config: nil}
push, err := NewPush(metrics, configProvider, PushConfig{
Interval: 30 * time.Second,
}, "1.0.0")
require.NoError(t, err)
pushURL, interval := push.resolve(context.Background())
assert.Empty(t, pushURL)
assert.Equal(t, defaultPushInterval, interval) // no config available, use default retry interval
}
func TestPush_ForceSendingSkipsRemoteConfig(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusNoContent)
}))
defer server.Close()
metrics := &mockMetrics{exportData: "test_metric 1\n"}
configProvider := &mockConfigProvider{config: nil}
push, err := NewPush(metrics, configProvider, PushConfig{
ForceSending: true,
Interval: 1 * time.Minute,
ServerAddress: parseURL(server.URL),
}, "1.0.0")
require.NoError(t, err)
pushURL, interval := push.resolve(context.Background())
assert.NotEmpty(t, pushURL)
assert.Equal(t, 1*time.Minute, interval)
}
func TestPush_ForceSendingUsesDefaultInterval(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusNoContent)
}))
defer server.Close()
metrics := &mockMetrics{exportData: "test_metric 1\n"}
configProvider := &mockConfigProvider{config: nil}
push, err := NewPush(metrics, configProvider, PushConfig{
ForceSending: true,
ServerAddress: parseURL(server.URL),
}, "1.0.0")
require.NoError(t, err)
pushURL, interval := push.resolve(context.Background())
assert.NotEmpty(t, pushURL)
assert.Equal(t, defaultPushInterval, interval)
}
func TestIsVersionInRange(t *testing.T) {
tests := []struct {
name string
current string
since string
until string
expected bool
}{
{"at lower bound inclusive", "1.2.2", "1.2.2", "1.2.3", true},
{"in range", "1.2.2", "1.2.0", "1.3.0", true},
{"at upper bound exclusive", "1.2.3", "1.2.2", "1.2.3", false},
{"below range", "1.2.1", "1.2.2", "1.2.3", false},
{"above range", "1.3.0", "1.2.2", "1.2.3", false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
assert.Equal(t, tt.expected, isVersionInRange(mustVersion(tt.current), mustVersion(tt.since), mustVersion(tt.until)))
})
}
}

View File

@@ -0,0 +1,149 @@
package remoteconfig
import (
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"sync"
"time"
goversion "github.com/hashicorp/go-version"
log "github.com/sirupsen/logrus"
)
const (
DefaultMinRefreshInterval = 30 * time.Minute
)
// Config holds the parsed remote push configuration
type Config struct {
ServerURL url.URL
VersionSince *goversion.Version
VersionUntil *goversion.Version
Interval time.Duration
}
// rawConfig is the JSON wire format fetched from the remote server
type rawConfig struct {
ServerURL string `json:"server_url"`
VersionSince string `json:"version-since"`
VersionUntil string `json:"version-until"`
PeriodMinutes int `json:"period_minutes"`
}
// Manager handles fetching and caching remote push configuration
type Manager struct {
configURL string
minRefreshInterval time.Duration
client *http.Client
mu sync.Mutex
lastConfig *Config
lastFetched time.Time
}
func NewManager(configURL string, minRefreshInterval time.Duration) *Manager {
return &Manager{
configURL: configURL,
minRefreshInterval: minRefreshInterval,
client: &http.Client{
Timeout: 10 * time.Second,
},
}
}
// RefreshIfNeeded fetches new config if the cached one is stale.
// Returns the current config (possibly just fetched) or nil if unavailable.
func (m *Manager) RefreshIfNeeded(ctx context.Context) *Config {
m.mu.Lock()
defer m.mu.Unlock()
if m.isConfigFresh() {
return m.lastConfig
}
fetchedConfig, err := m.fetch(ctx)
m.lastFetched = time.Now()
if err != nil {
log.Warnf("failed to fetch metrics remote config: %v", err)
return m.lastConfig // return cached (may be nil)
}
m.lastConfig = fetchedConfig
log.Tracef("fetched metrics remote config: version-since=%s version-until=%s period=%s",
fetchedConfig.VersionSince, fetchedConfig.VersionUntil, fetchedConfig.Interval)
return fetchedConfig
}
func (m *Manager) isConfigFresh() bool {
if m.lastConfig == nil {
return false
}
return time.Since(m.lastFetched) < m.minRefreshInterval
}
func (m *Manager) fetch(ctx context.Context) (*Config, error) {
req, err := http.NewRequestWithContext(ctx, http.MethodGet, m.configURL, nil)
if err != nil {
return nil, fmt.Errorf("create request: %w", err)
}
resp, err := m.client.Do(req)
if err != nil {
return nil, fmt.Errorf("send request: %w", err)
}
defer func() {
if resp.Body != nil {
_ = resp.Body.Close()
}
}()
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
return nil, fmt.Errorf("unexpected status code: %d", resp.StatusCode)
}
body, err := io.ReadAll(io.LimitReader(resp.Body, 4096))
if err != nil {
return nil, fmt.Errorf("read body: %w", err)
}
var raw rawConfig
if err := json.Unmarshal(body, &raw); err != nil {
return nil, fmt.Errorf("parse config: %w", err)
}
if raw.PeriodMinutes <= 0 {
return nil, fmt.Errorf("invalid period_minutes: %d", raw.PeriodMinutes)
}
if raw.ServerURL == "" {
return nil, fmt.Errorf("server_url is required")
}
serverURL, err := url.Parse(raw.ServerURL)
if err != nil {
return nil, fmt.Errorf("parse server_url %q: %w", raw.ServerURL, err)
}
since, err := goversion.NewVersion(raw.VersionSince)
if err != nil {
return nil, fmt.Errorf("parse version-since %q: %w", raw.VersionSince, err)
}
until, err := goversion.NewVersion(raw.VersionUntil)
if err != nil {
return nil, fmt.Errorf("parse version-until %q: %w", raw.VersionUntil, err)
}
return &Config{
ServerURL: *serverURL,
VersionSince: since,
VersionUntil: until,
Interval: time.Duration(raw.PeriodMinutes) * time.Minute,
}, nil
}

View File

@@ -0,0 +1,197 @@
package remoteconfig
import (
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"sync/atomic"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
const testMinRefresh = 100 * time.Millisecond
func TestManager_FetchSuccess(t *testing.T) {
server := newConfigServer(t, rawConfig{
ServerURL: "https://ingest.example.com",
VersionSince: "1.0.0",
VersionUntil: "2.0.0",
PeriodMinutes: 60,
})
defer server.Close()
mgr := NewManager(server.URL, testMinRefresh)
config := mgr.RefreshIfNeeded(context.Background())
require.NotNil(t, config)
assert.Equal(t, "https://ingest.example.com", config.ServerURL.String())
assert.Equal(t, "1.0.0", config.VersionSince.String())
assert.Equal(t, "2.0.0", config.VersionUntil.String())
assert.Equal(t, 60*time.Minute, config.Interval)
}
func TestManager_CachesConfig(t *testing.T) {
var fetchCount atomic.Int32
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fetchCount.Add(1)
err := json.NewEncoder(w).Encode(rawConfig{
ServerURL: "https://ingest.example.com",
VersionSince: "1.0.0",
VersionUntil: "2.0.0",
PeriodMinutes: 60,
})
require.NoError(t, err)
}))
defer server.Close()
mgr := NewManager(server.URL, testMinRefresh)
// First call fetches
config1 := mgr.RefreshIfNeeded(context.Background())
require.NotNil(t, config1)
assert.Equal(t, int32(1), fetchCount.Load())
// Second call uses cache (within minRefreshInterval)
config2 := mgr.RefreshIfNeeded(context.Background())
require.NotNil(t, config2)
assert.Equal(t, int32(1), fetchCount.Load())
assert.Equal(t, config1, config2)
}
func TestManager_RefetchesWhenStale(t *testing.T) {
var fetchCount atomic.Int32
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fetchCount.Add(1)
err := json.NewEncoder(w).Encode(rawConfig{
ServerURL: "https://ingest.example.com",
VersionSince: "1.0.0",
VersionUntil: "2.0.0",
PeriodMinutes: 60,
})
require.NoError(t, err)
}))
defer server.Close()
mgr := NewManager(server.URL, testMinRefresh)
// First fetch
mgr.RefreshIfNeeded(context.Background())
assert.Equal(t, int32(1), fetchCount.Load())
// Wait for config to become stale
time.Sleep(testMinRefresh + 10*time.Millisecond)
// Should refetch
mgr.RefreshIfNeeded(context.Background())
assert.Equal(t, int32(2), fetchCount.Load())
}
func TestManager_FetchFailureReturnsNil(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusInternalServerError)
}))
defer server.Close()
mgr := NewManager(server.URL, testMinRefresh)
config := mgr.RefreshIfNeeded(context.Background())
assert.Nil(t, config)
}
func TestManager_FetchFailureReturnsCached(t *testing.T) {
var fetchCount atomic.Int32
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fetchCount.Add(1)
if fetchCount.Load() > 1 {
w.WriteHeader(http.StatusInternalServerError)
return
}
err := json.NewEncoder(w).Encode(rawConfig{
ServerURL: "https://ingest.example.com",
VersionSince: "1.0.0",
VersionUntil: "2.0.0",
PeriodMinutes: 60,
})
require.NoError(t, err)
}))
defer server.Close()
mgr := NewManager(server.URL, testMinRefresh)
// First call succeeds
config1 := mgr.RefreshIfNeeded(context.Background())
require.NotNil(t, config1)
// Wait for config to become stale
time.Sleep(testMinRefresh + 10*time.Millisecond)
// Second call fails but returns cached
config2 := mgr.RefreshIfNeeded(context.Background())
require.NotNil(t, config2)
assert.Equal(t, config1, config2)
}
func TestManager_RejectsInvalidPeriod(t *testing.T) {
tests := []struct {
name string
period int
}{
{"zero", 0},
{"negative", -5},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
server := newConfigServer(t, rawConfig{
ServerURL: "https://ingest.example.com",
VersionSince: "1.0.0",
VersionUntil: "2.0.0",
PeriodMinutes: tt.period,
})
defer server.Close()
mgr := NewManager(server.URL, testMinRefresh)
config := mgr.RefreshIfNeeded(context.Background())
assert.Nil(t, config)
})
}
}
func TestManager_RejectsEmptyServerURL(t *testing.T) {
server := newConfigServer(t, rawConfig{
ServerURL: "",
VersionSince: "1.0.0",
VersionUntil: "2.0.0",
PeriodMinutes: 60,
})
defer server.Close()
mgr := NewManager(server.URL, testMinRefresh)
config := mgr.RefreshIfNeeded(context.Background())
assert.Nil(t, config)
}
func TestManager_RejectsInvalidJSON(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, err := w.Write([]byte("not json"))
require.NoError(t, err)
}))
defer server.Close()
mgr := NewManager(server.URL, testMinRefresh)
config := mgr.RefreshIfNeeded(context.Background())
assert.Nil(t, config)
}
func newConfigServer(t *testing.T, config rawConfig) *httptest.Server {
t.Helper()
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
err := json.NewEncoder(w).Encode(config)
require.NoError(t, err)
}))
}

View File

@@ -15,18 +15,29 @@ import (
"github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/wgproxy"
"github.com/netbirdio/netbird/client/internal/metrics"
"github.com/netbirdio/netbird/client/internal/peer/conntype"
"github.com/netbirdio/netbird/client/internal/peer/dispatcher"
"github.com/netbirdio/netbird/client/internal/peer/guard"
icemaker "github.com/netbirdio/netbird/client/internal/peer/ice"
"github.com/netbirdio/netbird/client/internal/peer/id"
"github.com/netbirdio/netbird/client/internal/peer/worker"
"github.com/netbirdio/netbird/client/internal/portforward"
"github.com/netbirdio/netbird/client/internal/stdnet"
"github.com/netbirdio/netbird/route"
relayClient "github.com/netbirdio/netbird/shared/relay/client"
)
// MetricsRecorder is an interface for recording peer connection metrics
type MetricsRecorder interface {
RecordConnectionStages(
ctx context.Context,
remotePubKey string,
connectionType metrics.ConnectionType,
isReconnection bool,
timestamps metrics.ConnectionStageTimestamps,
)
}
type ServiceDependencies struct {
StatusRecorder *Status
Signaler *Signaler
@@ -34,7 +45,7 @@ type ServiceDependencies struct {
RelayManager *relayClient.Manager
SrWatcher *guard.SRWatcher
PeerConnDispatcher *dispatcher.ConnectionDispatcher
PortForwardManager *portforward.Manager
MetricsRecorder MetricsRecorder
}
type WgConfig struct {
@@ -76,17 +87,16 @@ type ConnConfig struct {
}
type Conn struct {
Log *log.Entry
mu sync.Mutex
ctx context.Context
ctxCancel context.CancelFunc
config ConnConfig
statusRecorder *Status
signaler *Signaler
iFaceDiscover stdnet.ExternalIFaceDiscover
relayManager *relayClient.Manager
srWatcher *guard.SRWatcher
portForwardManager *portforward.Manager
Log *log.Entry
mu sync.Mutex
ctx context.Context
ctxCancel context.CancelFunc
config ConnConfig
statusRecorder *Status
signaler *Signaler
iFaceDiscover stdnet.ExternalIFaceDiscover
relayManager *relayClient.Manager
srWatcher *guard.SRWatcher
onConnected func(remoteWireGuardKey string, remoteRosenpassPubKey []byte, wireGuardIP string, remoteRosenpassAddr string)
onDisconnected func(remotePeer string)
@@ -118,6 +128,10 @@ type Conn struct {
dumpState *stateDump
endpointUpdater *EndpointUpdater
// Connection stage timestamps for metrics
metricsRecorder MetricsRecorder
metricsStages *MetricsStages
}
// NewConn creates a new not opened Conn to the remote peer.
@@ -131,19 +145,19 @@ func NewConn(config ConnConfig, services ServiceDependencies) (*Conn, error) {
dumpState := newStateDump(config.Key, connLog, services.StatusRecorder)
var conn = &Conn{
Log: connLog,
config: config,
statusRecorder: services.StatusRecorder,
signaler: services.Signaler,
iFaceDiscover: services.IFaceDiscover,
relayManager: services.RelayManager,
srWatcher: services.SrWatcher,
portForwardManager: services.PortForwardManager,
statusRelay: worker.NewAtomicStatus(),
statusICE: worker.NewAtomicStatus(),
dumpState: dumpState,
endpointUpdater: NewEndpointUpdater(connLog, config.WgConfig, isController(config)),
wgWatcher: NewWGWatcher(connLog, config.WgConfig.WgInterface, config.Key, dumpState),
Log: connLog,
config: config,
statusRecorder: services.StatusRecorder,
signaler: services.Signaler,
iFaceDiscover: services.IFaceDiscover,
relayManager: services.RelayManager,
srWatcher: services.SrWatcher,
statusRelay: worker.NewAtomicStatus(),
statusICE: worker.NewAtomicStatus(),
dumpState: dumpState,
endpointUpdater: NewEndpointUpdater(connLog, config.WgConfig, isController(config)),
wgWatcher: NewWGWatcher(connLog, config.WgConfig.WgInterface, config.Key, dumpState),
metricsRecorder: services.MetricsRecorder,
}
return conn, nil
@@ -160,6 +174,9 @@ func (conn *Conn) Open(engineCtx context.Context) error {
return nil
}
// Allocate new metrics stages so old goroutines don't corrupt new state
conn.metricsStages = &MetricsStages{}
conn.ctx, conn.ctxCancel = context.WithCancel(engineCtx)
conn.workerRelay = NewWorkerRelay(conn.ctx, conn.Log, isController(conn.config), conn.config, conn, conn.relayManager)
@@ -171,7 +188,7 @@ func (conn *Conn) Open(engineCtx context.Context) error {
}
conn.workerICE = workerICE
conn.handshaker = NewHandshaker(conn.Log, conn.config, conn.signaler, conn.workerICE, conn.workerRelay)
conn.handshaker = NewHandshaker(conn.Log, conn.config, conn.signaler, conn.workerICE, conn.workerRelay, conn.metricsStages)
conn.handshaker.AddRelayListener(conn.workerRelay.OnNewOffer)
if !isForceRelayed() {
@@ -339,7 +356,7 @@ func (conn *Conn) onICEConnectionIsReady(priority conntype.ConnPriority, iceConn
if conn.currentConnPriority > priority {
conn.Log.Infof("current connection priority (%s) is higher than the new one (%s), do not upgrade connection", conn.currentConnPriority, priority)
conn.statusICE.SetConnected()
conn.updateIceState(iceConnInfo)
conn.updateIceState(iceConnInfo, time.Now())
return
}
@@ -379,7 +396,8 @@ func (conn *Conn) onICEConnectionIsReady(priority conntype.ConnPriority, iceConn
}
conn.Log.Infof("configure WireGuard endpoint to: %s", ep.String())
conn.enableWgWatcherIfNeeded()
updateTime := time.Now()
conn.enableWgWatcherIfNeeded(updateTime)
presharedKey := conn.presharedKey(iceConnInfo.RosenpassPubKey)
if err = conn.endpointUpdater.ConfigureWGEndpoint(ep, presharedKey); err != nil {
@@ -395,8 +413,8 @@ func (conn *Conn) onICEConnectionIsReady(priority conntype.ConnPriority, iceConn
conn.currentConnPriority = priority
conn.statusICE.SetConnected()
conn.updateIceState(iceConnInfo)
conn.doOnConnected(iceConnInfo.RosenpassPubKey, iceConnInfo.RosenpassAddr)
conn.updateIceState(iceConnInfo, updateTime)
conn.doOnConnected(iceConnInfo.RosenpassPubKey, iceConnInfo.RosenpassAddr, updateTime)
}
func (conn *Conn) onICEStateDisconnected(sessionChanged bool) {
@@ -448,6 +466,10 @@ func (conn *Conn) onICEStateDisconnected(sessionChanged bool) {
conn.disableWgWatcherIfNeeded()
if conn.currentConnPriority == conntype.None {
conn.metricsStages.Disconnected()
}
peerState := State{
PubKey: conn.config.Key,
ConnStatus: conn.evalStatus(),
@@ -488,7 +510,7 @@ func (conn *Conn) onRelayConnectionIsReady(rci RelayConnInfo) {
conn.Log.Debugf("do not switch to relay because current priority is: %s", conn.currentConnPriority.String())
conn.setRelayedProxy(wgProxy)
conn.statusRelay.SetConnected()
conn.updateRelayStatus(rci.relayedConn.RemoteAddr().String(), rci.rosenpassPubKey)
conn.updateRelayStatus(rci.relayedConn.RemoteAddr().String(), rci.rosenpassPubKey, time.Now())
return
}
@@ -497,7 +519,8 @@ func (conn *Conn) onRelayConnectionIsReady(rci RelayConnInfo) {
if controller {
wgProxy.Work()
}
conn.enableWgWatcherIfNeeded()
updateTime := time.Now()
conn.enableWgWatcherIfNeeded(updateTime)
if err := conn.endpointUpdater.ConfigureWGEndpoint(wgProxy.EndpointAddr(), conn.presharedKey(rci.rosenpassPubKey)); err != nil {
if err := wgProxy.CloseConn(); err != nil {
conn.Log.Warnf("Failed to close relay connection: %v", err)
@@ -508,13 +531,16 @@ func (conn *Conn) onRelayConnectionIsReady(rci RelayConnInfo) {
if !controller {
wgProxy.Work()
}
wgConfigWorkaround()
conn.rosenpassRemoteKey = rci.rosenpassPubKey
conn.currentConnPriority = conntype.Relay
conn.statusRelay.SetConnected()
conn.setRelayedProxy(wgProxy)
conn.updateRelayStatus(rci.relayedConn.RemoteAddr().String(), rci.rosenpassPubKey)
conn.updateRelayStatus(rci.relayedConn.RemoteAddr().String(), rci.rosenpassPubKey, updateTime)
conn.Log.Infof("start to communicate with peer via relay")
conn.doOnConnected(rci.rosenpassPubKey, rci.rosenpassAddr)
conn.doOnConnected(rci.rosenpassPubKey, rci.rosenpassAddr, updateTime)
}
func (conn *Conn) onRelayDisconnected() {
@@ -552,6 +578,10 @@ func (conn *Conn) handleRelayDisconnectedLocked() {
conn.disableWgWatcherIfNeeded()
if conn.currentConnPriority == conntype.None {
conn.metricsStages.Disconnected()
}
peerState := State{
PubKey: conn.config.Key,
ConnStatus: conn.evalStatus(),
@@ -592,10 +622,10 @@ func (conn *Conn) onWGDisconnected() {
}
}
func (conn *Conn) updateRelayStatus(relayServerAddr string, rosenpassPubKey []byte) {
func (conn *Conn) updateRelayStatus(relayServerAddr string, rosenpassPubKey []byte, updateTime time.Time) {
peerState := State{
PubKey: conn.config.Key,
ConnStatusUpdate: time.Now(),
ConnStatusUpdate: updateTime,
ConnStatus: conn.evalStatus(),
Relayed: conn.isRelayed(),
RelayServerAddress: relayServerAddr,
@@ -608,10 +638,10 @@ func (conn *Conn) updateRelayStatus(relayServerAddr string, rosenpassPubKey []by
}
}
func (conn *Conn) updateIceState(iceConnInfo ICEConnInfo) {
func (conn *Conn) updateIceState(iceConnInfo ICEConnInfo, updateTime time.Time) {
peerState := State{
PubKey: conn.config.Key,
ConnStatusUpdate: time.Now(),
ConnStatusUpdate: updateTime,
ConnStatus: conn.evalStatus(),
Relayed: iceConnInfo.Relayed,
LocalIceCandidateType: iceConnInfo.LocalIceCandidateType,
@@ -649,11 +679,13 @@ func (conn *Conn) setStatusToDisconnected() {
}
}
func (conn *Conn) doOnConnected(remoteRosenpassPubKey []byte, remoteRosenpassAddr string) {
func (conn *Conn) doOnConnected(remoteRosenpassPubKey []byte, remoteRosenpassAddr string, updateTime time.Time) {
if runtime.GOOS == "ios" {
runtime.GC()
}
conn.metricsStages.RecordConnectionReady(updateTime)
if conn.onConnected != nil {
conn.onConnected(conn.config.Key, remoteRosenpassPubKey, conn.config.WgConfig.AllowedIps[0].Addr().String(), remoteRosenpassAddr)
}
@@ -705,14 +737,14 @@ func (conn *Conn) isConnectedOnAllWay() (connected bool) {
return true
}
func (conn *Conn) enableWgWatcherIfNeeded() {
func (conn *Conn) enableWgWatcherIfNeeded(enabledTime time.Time) {
if !conn.wgWatcher.IsEnabled() {
wgWatcherCtx, wgWatcherCancel := context.WithCancel(conn.ctx)
conn.wgWatcherCancel = wgWatcherCancel
conn.wgWatcherWg.Add(1)
go func() {
defer conn.wgWatcherWg.Done()
conn.wgWatcher.EnableWgWatcher(wgWatcherCtx, conn.onWGDisconnected)
conn.wgWatcher.EnableWgWatcher(wgWatcherCtx, enabledTime, conn.onWGDisconnected, conn.onWGHandshakeSuccess)
}()
}
}
@@ -787,6 +819,41 @@ func (conn *Conn) setRelayedProxy(proxy wgproxy.Proxy) {
conn.wgProxyRelay = proxy
}
// onWGHandshakeSuccess is called when the first WireGuard handshake is detected
func (conn *Conn) onWGHandshakeSuccess(when time.Time) {
conn.metricsStages.RecordWGHandshakeSuccess(when)
conn.recordConnectionMetrics()
}
// recordConnectionMetrics records connection stage timestamps as metrics
func (conn *Conn) recordConnectionMetrics() {
if conn.metricsRecorder == nil {
return
}
// Determine connection type based on current priority
conn.mu.Lock()
priority := conn.currentConnPriority
conn.mu.Unlock()
var connType metrics.ConnectionType
switch priority {
case conntype.Relay:
connType = metrics.ConnectionTypeRelay
default:
connType = metrics.ConnectionTypeICE
}
// Record metrics with timestamps - duration calculation happens in metrics package
conn.metricsRecorder.RecordConnectionStages(
context.Background(),
conn.config.Key,
connType,
conn.metricsStages.IsReconnection(),
conn.metricsStages.GetTimestamps(),
)
}
// AllowedIP returns the allowed IP of the remote peer
func (conn *Conn) AllowedIP() netip.Addr {
return conn.config.WgConfig.AllowedIps[0].Addr()

View File

@@ -44,12 +44,13 @@ type OfferAnswer struct {
}
type Handshaker struct {
mu sync.Mutex
log *log.Entry
config ConnConfig
signaler *Signaler
ice *WorkerICE
relay *WorkerRelay
mu sync.Mutex
log *log.Entry
config ConnConfig
signaler *Signaler
ice *WorkerICE
relay *WorkerRelay
metricsStages *MetricsStages
// relayListener is not blocking because the listener is using a goroutine to process the messages
// and it will only keep the latest message if multiple offers are received in a short time
// this is to avoid blocking the handshaker if the listener is doing some heavy processing
@@ -64,13 +65,14 @@ type Handshaker struct {
remoteAnswerCh chan OfferAnswer
}
func NewHandshaker(log *log.Entry, config ConnConfig, signaler *Signaler, ice *WorkerICE, relay *WorkerRelay) *Handshaker {
func NewHandshaker(log *log.Entry, config ConnConfig, signaler *Signaler, ice *WorkerICE, relay *WorkerRelay, metricsStages *MetricsStages) *Handshaker {
return &Handshaker{
log: log,
config: config,
signaler: signaler,
ice: ice,
relay: relay,
metricsStages: metricsStages,
remoteOffersCh: make(chan OfferAnswer),
remoteAnswerCh: make(chan OfferAnswer),
}
@@ -89,6 +91,12 @@ func (h *Handshaker) Listen(ctx context.Context) {
select {
case remoteOfferAnswer := <-h.remoteOffersCh:
h.log.Infof("received offer, running version %s, remote WireGuard listen port %d, session id: %s", remoteOfferAnswer.Version, remoteOfferAnswer.WgListenPort, remoteOfferAnswer.SessionIDString())
// Record signaling received for reconnection attempts
if h.metricsStages != nil {
h.metricsStages.RecordSignalingReceived()
}
if h.relayListener != nil {
h.relayListener.Notify(&remoteOfferAnswer)
}
@@ -103,6 +111,12 @@ func (h *Handshaker) Listen(ctx context.Context) {
}
case remoteOfferAnswer := <-h.remoteAnswerCh:
h.log.Infof("received answer, running version %s, remote WireGuard listen port %d, session id: %s", remoteOfferAnswer.Version, remoteOfferAnswer.WgListenPort, remoteOfferAnswer.SessionIDString())
// Record signaling received for reconnection attempts
if h.metricsStages != nil {
h.metricsStages.RecordSignalingReceived()
}
if h.relayListener != nil {
h.relayListener.Notify(&remoteOfferAnswer)
}

View File

@@ -0,0 +1,73 @@
package peer
import (
"sync"
"time"
"github.com/netbirdio/netbird/client/internal/metrics"
)
type MetricsStages struct {
isReconnectionAttempt bool // Track if current attempt is a reconnection
stageTimestamps metrics.ConnectionStageTimestamps
mu sync.Mutex
}
// RecordSignalingReceived records when the first signal is received from the remote peer.
// Used as the base for all subsequent stage durations to avoid inflating metrics when
// the remote peer was offline.
func (s *MetricsStages) RecordSignalingReceived() {
s.mu.Lock()
defer s.mu.Unlock()
if s.stageTimestamps.SignalingReceived.IsZero() {
s.stageTimestamps.SignalingReceived = time.Now()
}
}
func (s *MetricsStages) RecordConnectionReady(when time.Time) {
s.mu.Lock()
defer s.mu.Unlock()
if s.stageTimestamps.ConnectionReady.IsZero() {
s.stageTimestamps.ConnectionReady = when
}
}
func (s *MetricsStages) RecordWGHandshakeSuccess(handshakeTime time.Time) {
s.mu.Lock()
defer s.mu.Unlock()
if !s.stageTimestamps.ConnectionReady.IsZero() && s.stageTimestamps.WgHandshakeSuccess.IsZero() {
// WireGuard only reports handshake times with second precision, but ConnectionReady
// is captured with microsecond precision. If handshake appears before ConnectionReady
// due to truncation (e.g., handshake at 6.042s truncated to 6.000s), normalize to
// ConnectionReady to avoid negative duration metrics.
if handshakeTime.Before(s.stageTimestamps.ConnectionReady) {
s.stageTimestamps.WgHandshakeSuccess = s.stageTimestamps.ConnectionReady
} else {
s.stageTimestamps.WgHandshakeSuccess = handshakeTime
}
}
}
// Disconnected sets the mode to reconnection. It is called only when both ICE and Relay have been disconnected at the same time.
func (s *MetricsStages) Disconnected() {
s.mu.Lock()
defer s.mu.Unlock()
// Reset all timestamps for reconnection
s.stageTimestamps = metrics.ConnectionStageTimestamps{}
s.isReconnectionAttempt = true
}
func (s *MetricsStages) IsReconnection() bool {
s.mu.Lock()
defer s.mu.Unlock()
return s.isReconnectionAttempt
}
func (s *MetricsStages) GetTimestamps() metrics.ConnectionStageTimestamps {
s.mu.Lock()
defer s.mu.Unlock()
return s.stageTimestamps
}

View File

@@ -0,0 +1,125 @@
package peer
import (
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/client/internal/metrics"
)
func TestMetricsStages_RecordSignalingReceived(t *testing.T) {
s := &MetricsStages{}
s.RecordSignalingReceived()
ts := s.GetTimestamps()
require.False(t, ts.SignalingReceived.IsZero())
// Second call should not overwrite
first := ts.SignalingReceived
time.Sleep(time.Millisecond)
s.RecordSignalingReceived()
ts = s.GetTimestamps()
assert.Equal(t, first, ts.SignalingReceived, "should keep the first signaling timestamp")
}
func TestMetricsStages_RecordConnectionReady(t *testing.T) {
s := &MetricsStages{}
now := time.Now()
s.RecordConnectionReady(now)
ts := s.GetTimestamps()
assert.Equal(t, now, ts.ConnectionReady)
// Second call should not overwrite
later := now.Add(time.Second)
s.RecordConnectionReady(later)
ts = s.GetTimestamps()
assert.Equal(t, now, ts.ConnectionReady, "should keep the first connection ready timestamp")
}
func TestMetricsStages_RecordWGHandshakeSuccess(t *testing.T) {
s := &MetricsStages{}
connReady := time.Now()
s.RecordConnectionReady(connReady)
handshake := connReady.Add(500 * time.Millisecond)
s.RecordWGHandshakeSuccess(handshake)
ts := s.GetTimestamps()
assert.Equal(t, handshake, ts.WgHandshakeSuccess)
}
func TestMetricsStages_HandshakeBeforeConnectionReady_Normalizes(t *testing.T) {
s := &MetricsStages{}
connReady := time.Now()
s.RecordConnectionReady(connReady)
// WG handshake appears before ConnectionReady due to second-precision truncation
handshake := connReady.Add(-100 * time.Millisecond)
s.RecordWGHandshakeSuccess(handshake)
ts := s.GetTimestamps()
assert.Equal(t, connReady, ts.WgHandshakeSuccess, "should normalize to ConnectionReady when handshake appears earlier")
}
func TestMetricsStages_HandshakeIgnoredWithoutConnectionReady(t *testing.T) {
s := &MetricsStages{}
s.RecordWGHandshakeSuccess(time.Now())
ts := s.GetTimestamps()
assert.True(t, ts.WgHandshakeSuccess.IsZero(), "should not record handshake without connection ready")
}
func TestMetricsStages_HandshakeRecordedOnce(t *testing.T) {
s := &MetricsStages{}
connReady := time.Now()
s.RecordConnectionReady(connReady)
first := connReady.Add(time.Second)
s.RecordWGHandshakeSuccess(first)
// Second call (rekey) should be ignored
second := connReady.Add(2 * time.Second)
s.RecordWGHandshakeSuccess(second)
ts := s.GetTimestamps()
assert.Equal(t, first, ts.WgHandshakeSuccess, "should preserve first handshake, ignore rekeys")
}
func TestMetricsStages_Disconnected(t *testing.T) {
s := &MetricsStages{}
s.RecordSignalingReceived()
s.RecordConnectionReady(time.Now())
assert.False(t, s.IsReconnection())
s.Disconnected()
assert.True(t, s.IsReconnection())
ts := s.GetTimestamps()
assert.True(t, ts.SignalingReceived.IsZero(), "timestamps should be reset after disconnect")
assert.True(t, ts.ConnectionReady.IsZero(), "timestamps should be reset after disconnect")
assert.True(t, ts.WgHandshakeSuccess.IsZero(), "timestamps should be reset after disconnect")
}
func TestMetricsStages_GetTimestamps(t *testing.T) {
s := &MetricsStages{}
ts := s.GetTimestamps()
assert.Equal(t, metrics.ConnectionStageTimestamps{}, ts)
now := time.Now()
s.RecordSignalingReceived()
s.RecordConnectionReady(now)
ts = s.GetTimestamps()
assert.False(t, ts.SignalingReceived.IsZero())
assert.Equal(t, now, ts.ConnectionReady)
assert.True(t, ts.WgHandshakeSuccess.IsZero())
}

View File

@@ -48,7 +48,7 @@ func NewWGWatcher(log *log.Entry, wgIfaceStater WGInterfaceStater, peerKey strin
// EnableWgWatcher starts the WireGuard watcher. If it is already enabled, it will return immediately and do nothing.
// The watcher runs until ctx is cancelled. Caller is responsible for context lifecycle management.
func (w *WGWatcher) EnableWgWatcher(ctx context.Context, onDisconnectedFn func()) {
func (w *WGWatcher) EnableWgWatcher(ctx context.Context, enabledTime time.Time, onDisconnectedFn func(), onHandshakeSuccessFn func(when time.Time)) {
w.muEnabled.Lock()
if w.enabled {
w.muEnabled.Unlock()
@@ -56,7 +56,6 @@ func (w *WGWatcher) EnableWgWatcher(ctx context.Context, onDisconnectedFn func()
}
w.log.Debugf("enable WireGuard watcher")
enabledTime := time.Now()
w.enabled = true
w.muEnabled.Unlock()
@@ -65,7 +64,7 @@ func (w *WGWatcher) EnableWgWatcher(ctx context.Context, onDisconnectedFn func()
w.log.Warnf("failed to read initial wg stats: %v", err)
}
w.periodicHandshakeCheck(ctx, onDisconnectedFn, enabledTime, initialHandshake)
w.periodicHandshakeCheck(ctx, onDisconnectedFn, onHandshakeSuccessFn, enabledTime, initialHandshake)
w.muEnabled.Lock()
w.enabled = false
@@ -89,7 +88,7 @@ func (w *WGWatcher) Reset() {
}
// wgStateCheck help to check the state of the WireGuard handshake and relay connection
func (w *WGWatcher) periodicHandshakeCheck(ctx context.Context, onDisconnectedFn func(), enabledTime time.Time, initialHandshake time.Time) {
func (w *WGWatcher) periodicHandshakeCheck(ctx context.Context, onDisconnectedFn func(), onHandshakeSuccessFn func(when time.Time), enabledTime time.Time, initialHandshake time.Time) {
w.log.Infof("WireGuard watcher started")
timer := time.NewTimer(wgHandshakeOvertime)
@@ -108,6 +107,9 @@ func (w *WGWatcher) periodicHandshakeCheck(ctx context.Context, onDisconnectedFn
if lastHandshake.IsZero() {
elapsed := calcElapsed(enabledTime, *handshake)
w.log.Infof("first wg handshake detected within: %.2fsec, (%s)", elapsed, handshake)
if onHandshakeSuccessFn != nil {
onHandshakeSuccessFn(*handshake)
}
}
lastHandshake = *handshake

View File

@@ -35,9 +35,11 @@ func TestWGWatcher_EnableWgWatcher(t *testing.T) {
defer cancel()
onDisconnected := make(chan struct{}, 1)
go watcher.EnableWgWatcher(ctx, func() {
go watcher.EnableWgWatcher(ctx, time.Now(), func() {
mlog.Infof("onDisconnectedFn")
onDisconnected <- struct{}{}
}, func(when time.Time) {
mlog.Infof("onHandshakeSuccess: %v", when)
})
// wait for initial reading
@@ -64,7 +66,7 @@ func TestWGWatcher_ReEnable(t *testing.T) {
wg.Add(1)
go func() {
defer wg.Done()
watcher.EnableWgWatcher(ctx, func() {})
watcher.EnableWgWatcher(ctx, time.Now(), func() {}, func(when time.Time) {})
}()
cancel()
@@ -75,9 +77,9 @@ func TestWGWatcher_ReEnable(t *testing.T) {
defer cancel()
onDisconnected := make(chan struct{}, 1)
go watcher.EnableWgWatcher(ctx, func() {
go watcher.EnableWgWatcher(ctx, time.Now(), func() {
onDisconnected <- struct{}{}
})
}, func(when time.Time) {})
time.Sleep(2 * time.Second)
mocWgIface.disconnect()

View File

@@ -16,7 +16,6 @@ import (
"github.com/netbirdio/netbird/client/iface/udpmux"
"github.com/netbirdio/netbird/client/internal/peer/conntype"
icemaker "github.com/netbirdio/netbird/client/internal/peer/ice"
"github.com/netbirdio/netbird/client/internal/portforward"
"github.com/netbirdio/netbird/client/internal/stdnet"
"github.com/netbirdio/netbird/route"
)
@@ -62,9 +61,6 @@ type WorkerICE struct {
// we record the last known state of the ICE agent to avoid duplicate on disconnected events
lastKnownState ice.ConnectionState
// portForwardAttempted tracks if we've already tried port forwarding this session
portForwardAttempted bool
}
func NewWorkerICE(ctx context.Context, log *log.Entry, config ConnConfig, conn *Conn, signaler *Signaler, ifaceDiscover stdnet.ExternalIFaceDiscover, statusRecorder *Status, hasRelayOnLocally bool) (*WorkerICE, error) {
@@ -218,8 +214,6 @@ func (w *WorkerICE) Close() {
}
func (w *WorkerICE) reCreateAgent(dialerCancel context.CancelFunc, candidates []ice.CandidateType) (*icemaker.ThreadSafeAgent, error) {
w.portForwardAttempted = false
agent, err := icemaker.NewAgent(w.ctx, w.iFaceDiscover, w.config.ICEConfig, candidates, w.localUfrag, w.localPwd)
if err != nil {
return nil, fmt.Errorf("create agent: %w", err)
@@ -376,93 +370,6 @@ func (w *WorkerICE) onICECandidate(candidate ice.Candidate) {
w.log.Errorf("failed signaling candidate to the remote peer %s %s", w.config.Key, err)
}
}()
if candidate.Type() == ice.CandidateTypeServerReflexive {
w.injectPortForwardedCandidate(candidate)
}
}
// injectPortForwardedCandidate signals an additional candidate using the pre-created port mapping.
func (w *WorkerICE) injectPortForwardedCandidate(srflxCandidate ice.Candidate) {
pfManager := w.conn.portForwardManager
if pfManager == nil {
return
}
mapping := pfManager.GetMapping()
if mapping == nil {
return
}
w.muxAgent.Lock()
if w.portForwardAttempted {
w.muxAgent.Unlock()
return
}
w.portForwardAttempted = true
w.muxAgent.Unlock()
forwardedCandidate, err := w.createForwardedCandidate(srflxCandidate, mapping)
if err != nil {
w.log.Warnf("create forwarded candidate: %v", err)
return
}
w.log.Debugf("injecting port-forwarded candidate: %s (mapping: %d -> %d via %s, priority: %d)",
forwardedCandidate.String(), mapping.InternalPort, mapping.ExternalPort, mapping.NATType, forwardedCandidate.Priority())
go func() {
if err := w.signaler.SignalICECandidate(forwardedCandidate, w.config.Key); err != nil {
w.log.Errorf("signal port-forwarded candidate: %v", err)
}
}()
}
// createForwardedCandidate creates a new server reflexive candidate with the forwarded port.
// It uses the NAT gateway's external IP with the forwarded port.
func (w *WorkerICE) createForwardedCandidate(srflxCandidate ice.Candidate, mapping *portforward.Mapping) (ice.Candidate, error) {
var externalIP string
if mapping.ExternalIP != nil && !mapping.ExternalIP.IsUnspecified() {
externalIP = mapping.ExternalIP.String()
} else {
// Fallback to STUN-discovered address if NAT didn't provide external IP
externalIP = srflxCandidate.Address()
}
// Per RFC 8445, the related address for srflx is the base (host candidate address).
// If the original srflx has unspecified related address, use its own address as base.
relAddr := srflxCandidate.RelatedAddress().Address
if relAddr == "" || relAddr == "0.0.0.0" || relAddr == "::" {
relAddr = srflxCandidate.Address()
}
// Arbitrary +1000 boost on top of RFC 8445 priority to favor port-forwarded candidates
// over regular srflx during ICE connectivity checks.
priority := srflxCandidate.Priority() + 1000
candidate, err := ice.NewCandidateServerReflexive(&ice.CandidateServerReflexiveConfig{
Network: srflxCandidate.NetworkType().String(),
Address: externalIP,
Port: int(mapping.ExternalPort),
Component: srflxCandidate.Component(),
Priority: priority,
RelAddr: relAddr,
RelPort: int(mapping.InternalPort),
})
if err != nil {
return nil, fmt.Errorf("create candidate: %w", err)
}
for _, e := range srflxCandidate.Extensions() {
if e.Key == ice.ExtensionKeyCandidateID {
e.Value = srflxCandidate.ID()
}
if err := candidate.AddExtension(e); err != nil {
return nil, fmt.Errorf("add extension: %w", err)
}
}
return candidate, nil
}
func (w *WorkerICE) onICESelectedCandidatePair(agent *icemaker.ThreadSafeAgent, c1, c2 ice.Candidate) {
@@ -504,10 +411,10 @@ func (w *WorkerICE) logSuccessfulPaths(agent *icemaker.ThreadSafeAgent) {
if !lok || !rok {
continue
}
w.log.Debugf("successful ICE path %s: [%s %s %s:%d] <-> [%s %s %s:%d] rtt=%.3fms",
w.log.Debugf("successful ICE path %s: [%s %s %s] <-> [%s %s %s] rtt=%.3fms",
sessionID,
local.NetworkType(), local.Type(), local.Address(), local.Port(),
remote.NetworkType(), remote.Type(), remote.Address(), remote.Port(),
local.NetworkType(), local.Type(), local.Address(),
remote.NetworkType(), remote.Type(), remote.Address(),
stat.CurrentRoundTripTime*1000)
}
}

View File

@@ -1,35 +0,0 @@
package portforward
import (
"os"
"strconv"
log "github.com/sirupsen/logrus"
)
const (
envDisableNATMapper = "NB_DISABLE_NAT_MAPPER"
envDisablePCPHealthCheck = "NB_DISABLE_PCP_HEALTH_CHECK"
)
func isDisabledByEnv() bool {
return parseBoolEnv(envDisableNATMapper)
}
func isHealthCheckDisabled() bool {
return parseBoolEnv(envDisablePCPHealthCheck)
}
func parseBoolEnv(key string) bool {
val := os.Getenv(key)
if val == "" {
return false
}
disabled, err := strconv.ParseBool(val)
if err != nil {
log.Warnf("failed to parse %s: %v", key, err)
return false
}
return disabled
}

View File

@@ -1,298 +0,0 @@
//go:build !js
package portforward
import (
"context"
"fmt"
"net"
"sync"
"time"
"github.com/libp2p/go-nat"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/portforward/pcp"
)
const (
defaultMappingTTL = 2 * time.Hour
renewalInterval = defaultMappingTTL / 2
healthCheckInterval = 1 * time.Minute
discoveryTimeout = 10 * time.Second
mappingDescription = "NetBird"
)
type Mapping struct {
Protocol string
InternalPort uint16
ExternalPort uint16
ExternalIP net.IP
NATType string
}
type Manager struct {
cancel context.CancelFunc
mapping *Mapping
mappingLock sync.Mutex
wgPort uint16
done chan struct{}
stopCtx chan context.Context
// protect exported functions
mu sync.Mutex
}
func NewManager() *Manager {
return &Manager{
stopCtx: make(chan context.Context, 1),
}
}
func (m *Manager) Start(ctx context.Context, wgPort uint16) {
m.mu.Lock()
if m.cancel != nil {
m.mu.Unlock()
return
}
if isDisabledByEnv() {
log.Infof("NAT port mapper disabled via %s", envDisableNATMapper)
m.mu.Unlock()
return
}
if wgPort == 0 {
log.Warnf("invalid WireGuard port 0; NAT mapping disabled")
m.mu.Unlock()
return
}
m.wgPort = wgPort
m.done = make(chan struct{})
defer close(m.done)
ctx, m.cancel = context.WithCancel(ctx)
m.mu.Unlock()
gateway, mapping, err := m.setup(ctx)
if err != nil {
log.Errorf("failed to setup NAT port mapping: %v", err)
return
}
m.mappingLock.Lock()
m.mapping = mapping
m.mappingLock.Unlock()
m.renewLoop(ctx, gateway)
select {
case cleanupCtx := <-m.stopCtx:
// block the Start while cleaned up gracefully
m.cleanup(cleanupCtx, gateway)
default:
// return Start immediately and cleanup in background
cleanupCtx, cleanupCancel := context.WithTimeout(context.Background(), 10*time.Second)
go func() {
defer cleanupCancel()
m.cleanup(cleanupCtx, gateway)
}()
}
}
// GetMapping returns the current mapping if ready, nil otherwise
func (m *Manager) GetMapping() *Mapping {
m.mappingLock.Lock()
defer m.mappingLock.Unlock()
if m.mapping == nil {
return nil
}
mapping := *m.mapping
return &mapping
}
// GracefullyStop cancels the manager and attempts to delete the port mapping.
// After GracefullyStop returns, the manager cannot be restarted.
func (m *Manager) GracefullyStop(ctx context.Context) error {
m.mu.Lock()
defer m.mu.Unlock()
if m.cancel == nil {
return nil
}
// Send cleanup context before cancelling, so Start picks it up after renewLoop exits.
m.startTearDown(ctx)
m.cancel()
m.cancel = nil
select {
case <-ctx.Done():
return ctx.Err()
case <-m.done:
return nil
}
}
func (m *Manager) setup(ctx context.Context) (nat.NAT, *Mapping, error) {
discoverCtx, discoverCancel := context.WithTimeout(ctx, discoveryTimeout)
defer discoverCancel()
gateway, err := discoverGateway(discoverCtx)
if err != nil {
log.Infof("NAT gateway discovery failed: %v (port forwarding disabled)", err)
return nil, nil, err
}
log.Infof("discovered NAT gateway: %s", gateway.Type())
mapping, err := m.createMapping(ctx, gateway)
if err != nil {
log.Warnf("failed to create port mapping: %v", err)
return nil, nil, err
}
return gateway, mapping, nil
}
func (m *Manager) createMapping(ctx context.Context, gateway nat.NAT) (*Mapping, error) {
ctx, cancel := context.WithTimeout(ctx, 30*time.Second)
defer cancel()
externalPort, err := gateway.AddPortMapping(ctx, "udp", int(m.wgPort), mappingDescription, defaultMappingTTL)
if err != nil {
return nil, err
}
externalIP, err := gateway.GetExternalAddress()
if err != nil {
log.Debugf("failed to get external address: %v", err)
}
mapping := &Mapping{
Protocol: "udp",
InternalPort: m.wgPort,
ExternalPort: uint16(externalPort),
ExternalIP: externalIP,
NATType: gateway.Type(),
}
log.Infof("created port mapping: %d -> %d via %s (external IP: %s)",
m.wgPort, externalPort, gateway.Type(), externalIP)
return mapping, nil
}
func (m *Manager) renewLoop(ctx context.Context, gateway nat.NAT) {
renewTicker := time.NewTicker(renewalInterval)
healthTicker := time.NewTicker(healthCheckInterval)
defer renewTicker.Stop()
defer healthTicker.Stop()
for {
select {
case <-ctx.Done():
return
case <-renewTicker.C:
if err := m.renewMapping(ctx, gateway); err != nil {
log.Warnf("failed to renew port mapping: %v", err)
continue
}
case <-healthTicker.C:
if m.checkHealthAndRecreate(ctx, gateway) {
renewTicker.Reset(renewalInterval)
}
}
}
}
func (m *Manager) checkHealthAndRecreate(ctx context.Context, gateway nat.NAT) bool {
if isHealthCheckDisabled() {
return false
}
m.mappingLock.Lock()
hasMapping := m.mapping != nil
m.mappingLock.Unlock()
if !hasMapping {
return false
}
pcpNAT, ok := gateway.(*pcp.NAT)
if !ok {
return false
}
ctx, cancel := context.WithTimeout(ctx, 10*time.Second)
defer cancel()
epoch, serverRestarted, err := pcpNAT.CheckServerHealth(ctx)
if err != nil {
log.Debugf("PCP health check failed: %v", err)
return false
}
if serverRestarted {
log.Warnf("PCP server restart detected (epoch=%d), recreating port mapping", epoch)
if err := m.renewMapping(ctx, gateway); err != nil {
log.Errorf("failed to recreate port mapping after server restart: %v", err)
return false
}
return true
}
return false
}
func (m *Manager) renewMapping(ctx context.Context, gateway nat.NAT) error {
ctx, cancel := context.WithTimeout(ctx, 30*time.Second)
defer cancel()
externalPort, err := gateway.AddPortMapping(ctx, m.mapping.Protocol, int(m.mapping.InternalPort), mappingDescription, defaultMappingTTL)
if err != nil {
return fmt.Errorf("add port mapping: %w", err)
}
if uint16(externalPort) != m.mapping.ExternalPort {
log.Warnf("external port changed on renewal: %d -> %d (candidate may be stale)", m.mapping.ExternalPort, externalPort)
m.mappingLock.Lock()
m.mapping.ExternalPort = uint16(externalPort)
m.mappingLock.Unlock()
}
log.Debugf("renewed port mapping: %d -> %d", m.mapping.InternalPort, m.mapping.ExternalPort)
return nil
}
func (m *Manager) cleanup(ctx context.Context, gateway nat.NAT) {
m.mappingLock.Lock()
mapping := m.mapping
m.mapping = nil
m.mappingLock.Unlock()
if mapping == nil {
return
}
if err := gateway.DeletePortMapping(ctx, mapping.Protocol, int(mapping.InternalPort)); err != nil {
log.Warnf("delete port mapping on stop: %v", err)
return
}
log.Infof("deleted port mapping for port %d", mapping.InternalPort)
}
func (m *Manager) startTearDown(ctx context.Context) {
select {
case m.stopCtx <- ctx:
default:
}
}

View File

@@ -1,36 +0,0 @@
package portforward
import (
"context"
"net"
)
// Mapping represents port mapping information.
type Mapping struct {
Protocol string
InternalPort uint16
ExternalPort uint16
ExternalIP net.IP
NATType string
}
// Manager is a stub for js/wasm builds where NAT-PMP/UPnP is not supported.
type Manager struct{}
// NewManager returns a stub manager for js/wasm builds.
func NewManager() *Manager {
return &Manager{}
}
// Start is a no-op on js/wasm: NAT-PMP/UPnP is not available in browser environments.
func (m *Manager) Start(context.Context, uint16) {
// no NAT traversal in wasm
}
// GracefullyStop is a no-op on js/wasm.
func (m *Manager) GracefullyStop(context.Context) error { return nil }
// GetMapping always returns nil on js/wasm.
func (m *Manager) GetMapping() *Mapping {
return nil
}

View File

@@ -1,159 +0,0 @@
//go:build !js
package portforward
import (
"context"
"net"
"testing"
"time"
"github.com/libp2p/go-nat"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
type mockNAT struct {
natType string
deviceAddr net.IP
externalAddr net.IP
internalAddr net.IP
mappings map[int]int
addMappingErr error
deleteMappingErr error
}
func newMockNAT() *mockNAT {
return &mockNAT{
natType: "Mock-NAT",
deviceAddr: net.ParseIP("192.168.1.1"),
externalAddr: net.ParseIP("203.0.113.50"),
internalAddr: net.ParseIP("192.168.1.100"),
mappings: make(map[int]int),
}
}
func (m *mockNAT) Type() string {
return m.natType
}
func (m *mockNAT) GetDeviceAddress() (net.IP, error) {
return m.deviceAddr, nil
}
func (m *mockNAT) GetExternalAddress() (net.IP, error) {
return m.externalAddr, nil
}
func (m *mockNAT) GetInternalAddress() (net.IP, error) {
return m.internalAddr, nil
}
func (m *mockNAT) AddPortMapping(ctx context.Context, protocol string, internalPort int, description string, timeout time.Duration) (int, error) {
if m.addMappingErr != nil {
return 0, m.addMappingErr
}
externalPort := internalPort
m.mappings[internalPort] = externalPort
return externalPort, nil
}
func (m *mockNAT) DeletePortMapping(ctx context.Context, protocol string, internalPort int) error {
if m.deleteMappingErr != nil {
return m.deleteMappingErr
}
delete(m.mappings, internalPort)
return nil
}
func TestManager_CreateMapping(t *testing.T) {
m := NewManager()
m.wgPort = 51820
gateway := newMockNAT()
mapping, err := m.createMapping(context.Background(), gateway)
require.NoError(t, err)
require.NotNil(t, mapping)
assert.Equal(t, "udp", mapping.Protocol)
assert.Equal(t, uint16(51820), mapping.InternalPort)
assert.Equal(t, uint16(51820), mapping.ExternalPort)
assert.Equal(t, "Mock-NAT", mapping.NATType)
assert.Equal(t, net.ParseIP("203.0.113.50").To4(), mapping.ExternalIP.To4())
}
func TestManager_GetMapping_ReturnsNilWhenNotReady(t *testing.T) {
m := NewManager()
assert.Nil(t, m.GetMapping())
}
func TestManager_GetMapping_ReturnsCopy(t *testing.T) {
m := NewManager()
m.mapping = &Mapping{
Protocol: "udp",
InternalPort: 51820,
ExternalPort: 51820,
}
mapping := m.GetMapping()
require.NotNil(t, mapping)
assert.Equal(t, uint16(51820), mapping.InternalPort)
// Mutating the returned copy should not affect the manager's mapping.
mapping.ExternalPort = 9999
assert.Equal(t, uint16(51820), m.GetMapping().ExternalPort)
}
func TestManager_Cleanup_DeletesMapping(t *testing.T) {
m := NewManager()
m.mapping = &Mapping{
Protocol: "udp",
InternalPort: 51820,
ExternalPort: 51820,
}
gateway := newMockNAT()
// Seed the mock so we can verify deletion.
gateway.mappings[51820] = 51820
m.cleanup(context.Background(), gateway)
_, exists := gateway.mappings[51820]
assert.False(t, exists, "mapping should be deleted from gateway")
assert.Nil(t, m.GetMapping(), "in-memory mapping should be cleared")
}
func TestManager_Cleanup_NilMapping(t *testing.T) {
m := NewManager()
gateway := newMockNAT()
// Should not panic or call gateway.
m.cleanup(context.Background(), gateway)
}
func TestState_Cleanup(t *testing.T) {
origDiscover := discoverGateway
defer func() { discoverGateway = origDiscover }()
mockGateway := newMockNAT()
mockGateway.mappings[51820] = 51820
discoverGateway = func(ctx context.Context) (nat.NAT, error) {
return mockGateway, nil
}
state := &State{
Protocol: "udp",
InternalPort: 51820,
}
err := state.Cleanup()
assert.NoError(t, err)
_, exists := mockGateway.mappings[51820]
assert.False(t, exists, "mapping should be deleted after cleanup")
}
func TestState_Name(t *testing.T) {
state := &State{}
assert.Equal(t, "port_forward_state", state.Name())
}

View File

@@ -1,408 +0,0 @@
package pcp
import (
"context"
"crypto/rand"
"errors"
"fmt"
"net"
"net/netip"
"sync"
"time"
log "github.com/sirupsen/logrus"
)
const (
defaultTimeout = 3 * time.Second
responseBufferSize = 128
// RFC 6887 Section 8.1.1 retry timing
initialRetryDelay = 3 * time.Second
maxRetryDelay = 1024 * time.Second
maxRetries = 4 // 3s + 6s + 12s + 24s = 45s total worst case
)
// Client is a PCP protocol client.
// All methods are safe for concurrent use.
type Client struct {
gateway netip.Addr
timeout time.Duration
mu sync.Mutex
// localIP caches the resolved local IP address.
localIP netip.Addr
// lastEpoch is the last observed server epoch value.
lastEpoch uint32
// epochTime tracks when lastEpoch was received for state loss detection.
epochTime time.Time
// externalIP caches the external IP from the last successful MAP response.
externalIP netip.Addr
// epochStateLost is set when epoch indicates server restart.
epochStateLost bool
}
// NewClient creates a new PCP client for the gateway at the given IP.
func NewClient(gateway net.IP) *Client {
addr, ok := netip.AddrFromSlice(gateway)
if !ok {
log.Debugf("invalid gateway IP: %v", gateway)
}
return &Client{
gateway: addr.Unmap(),
timeout: defaultTimeout,
}
}
// NewClientWithTimeout creates a new PCP client with a custom timeout.
func NewClientWithTimeout(gateway net.IP, timeout time.Duration) *Client {
addr, ok := netip.AddrFromSlice(gateway)
if !ok {
log.Debugf("invalid gateway IP: %v", gateway)
}
return &Client{
gateway: addr.Unmap(),
timeout: timeout,
}
}
// SetLocalIP sets the local IP address to use in PCP requests.
func (c *Client) SetLocalIP(ip net.IP) {
addr, ok := netip.AddrFromSlice(ip)
if !ok {
log.Debugf("invalid local IP: %v", ip)
}
c.mu.Lock()
c.localIP = addr.Unmap()
c.mu.Unlock()
}
// Gateway returns the gateway IP address.
func (c *Client) Gateway() net.IP {
return c.gateway.AsSlice()
}
// Announce sends a PCP ANNOUNCE request to discover PCP support.
// Returns the server's epoch time on success.
func (c *Client) Announce(ctx context.Context) (epoch uint32, err error) {
localIP, err := c.getLocalIP()
if err != nil {
return 0, fmt.Errorf("get local IP: %w", err)
}
req := buildAnnounceRequest(localIP)
resp, err := c.sendRequest(ctx, req)
if err != nil {
return 0, fmt.Errorf("send announce: %w", err)
}
parsed, err := parseResponse(resp)
if err != nil {
return 0, fmt.Errorf("parse announce response: %w", err)
}
if parsed.ResultCode != ResultSuccess {
return 0, fmt.Errorf("PCP ANNOUNCE failed: %s", ResultCodeString(parsed.ResultCode))
}
c.mu.Lock()
if c.updateEpochLocked(parsed.Epoch) {
log.Warnf("PCP server epoch indicates state loss - mappings may need refresh")
}
c.mu.Unlock()
return parsed.Epoch, nil
}
// AddPortMapping requests a port mapping from the PCP server.
func (c *Client) AddPortMapping(ctx context.Context, protocol string, internalPort int, lifetime time.Duration) (*MapResponse, error) {
return c.addPortMappingWithHint(ctx, protocol, internalPort, internalPort, netip.Addr{}, lifetime)
}
// AddPortMappingWithHint requests a port mapping with suggested external port and IP.
// Use lifetime <= 0 to delete a mapping.
func (c *Client) AddPortMappingWithHint(ctx context.Context, protocol string, internalPort, suggestedExtPort int, suggestedExtIP net.IP, lifetime time.Duration) (*MapResponse, error) {
var extIP netip.Addr
if suggestedExtIP != nil {
var ok bool
extIP, ok = netip.AddrFromSlice(suggestedExtIP)
if !ok {
log.Debugf("invalid suggested external IP: %v", suggestedExtIP)
}
extIP = extIP.Unmap()
}
return c.addPortMappingWithHint(ctx, protocol, internalPort, suggestedExtPort, extIP, lifetime)
}
func (c *Client) addPortMappingWithHint(ctx context.Context, protocol string, internalPort, suggestedExtPort int, suggestedExtIP netip.Addr, lifetime time.Duration) (*MapResponse, error) {
localIP, err := c.getLocalIP()
if err != nil {
return nil, fmt.Errorf("get local IP: %w", err)
}
proto, err := protocolNumber(protocol)
if err != nil {
return nil, fmt.Errorf("parse protocol: %w", err)
}
var nonce [12]byte
if _, err := rand.Read(nonce[:]); err != nil {
return nil, fmt.Errorf("generate nonce: %w", err)
}
// Convert lifetime to seconds. Lifetime 0 means delete, so only apply
// default for positive durations that round to 0 seconds.
var lifetimeSec uint32
if lifetime > 0 {
lifetimeSec = uint32(lifetime.Seconds())
if lifetimeSec == 0 {
lifetimeSec = DefaultLifetime
}
}
req := buildMapRequest(localIP, nonce, proto, uint16(internalPort), uint16(suggestedExtPort), suggestedExtIP, lifetimeSec)
resp, err := c.sendRequest(ctx, req)
if err != nil {
return nil, fmt.Errorf("send map request: %w", err)
}
mapResp, err := parseMapResponse(resp)
if err != nil {
return nil, fmt.Errorf("parse map response: %w", err)
}
if mapResp.Nonce != nonce {
return nil, fmt.Errorf("nonce mismatch in response")
}
if mapResp.Protocol != proto {
return nil, fmt.Errorf("protocol mismatch: requested %d, got %d", proto, mapResp.Protocol)
}
if mapResp.InternalPort != uint16(internalPort) {
return nil, fmt.Errorf("internal port mismatch: requested %d, got %d", internalPort, mapResp.InternalPort)
}
if mapResp.ResultCode != ResultSuccess {
return nil, &Error{
Code: mapResp.ResultCode,
Message: ResultCodeString(mapResp.ResultCode),
}
}
c.mu.Lock()
if c.updateEpochLocked(mapResp.Epoch) {
log.Warnf("PCP server epoch indicates state loss - mappings may need refresh")
}
c.cacheExternalIPLocked(mapResp.ExternalIP)
c.mu.Unlock()
return mapResp, nil
}
// DeletePortMapping removes a port mapping by requesting zero lifetime.
func (c *Client) DeletePortMapping(ctx context.Context, protocol string, internalPort int) error {
if _, err := c.addPortMappingWithHint(ctx, protocol, internalPort, 0, netip.Addr{}, 0); err != nil {
var pcpErr *Error
if errors.As(err, &pcpErr) && pcpErr.Code == ResultNotAuthorized {
return nil
}
return fmt.Errorf("delete mapping: %w", err)
}
return nil
}
// GetExternalAddress returns the external IP address.
// First checks for a cached value from previous MAP responses.
// If not cached, creates a short-lived mapping to discover the external IP.
func (c *Client) GetExternalAddress(ctx context.Context) (net.IP, error) {
c.mu.Lock()
if c.externalIP.IsValid() {
ip := c.externalIP.AsSlice()
c.mu.Unlock()
return ip, nil
}
c.mu.Unlock()
// Use an ephemeral port in the dynamic range (49152-65535).
// Port 0 is not valid with UDP/TCP protocols per RFC 6887.
ephemeralPort := 49152 + int(uint16(time.Now().UnixNano()))%(65535-49152)
// Use minimal lifetime (1 second) for discovery.
resp, err := c.AddPortMapping(ctx, "udp", ephemeralPort, time.Second)
if err != nil {
return nil, fmt.Errorf("create temporary mapping: %w", err)
}
if err := c.DeletePortMapping(ctx, "udp", ephemeralPort); err != nil {
log.Debugf("cleanup temporary PCP mapping: %v", err)
}
return resp.ExternalIP.AsSlice(), nil
}
// LastEpoch returns the last observed server epoch value.
// A decrease in epoch indicates the server may have restarted and mappings may be lost.
func (c *Client) LastEpoch() uint32 {
c.mu.Lock()
defer c.mu.Unlock()
return c.lastEpoch
}
// EpochStateLost returns true if epoch state loss was detected and clears the flag.
func (c *Client) EpochStateLost() bool {
c.mu.Lock()
defer c.mu.Unlock()
lost := c.epochStateLost
c.epochStateLost = false
return lost
}
// updateEpoch updates the epoch tracking and detects potential state loss.
// Returns true if state loss was detected (server likely restarted).
// Caller must hold c.mu.
func (c *Client) updateEpochLocked(newEpoch uint32) bool {
now := time.Now()
stateLost := false
// RFC 6887 Section 8.5: Detect invalid epoch indicating server state loss.
// client_delta = time since last response
// server_delta = epoch change since last response
// Invalid if: client_delta+2 < server_delta - server_delta/16
// OR: server_delta+2 < client_delta - client_delta/16
// The +2 handles quantization, /16 (6.25%) handles clock drift.
if !c.epochTime.IsZero() && c.lastEpoch > 0 {
clientDelta := uint32(now.Sub(c.epochTime).Seconds())
serverDelta := newEpoch - c.lastEpoch
// Check for epoch going backwards or jumping unexpectedly.
// Subtraction is safe: serverDelta/16 is always <= serverDelta.
if clientDelta+2 < serverDelta-(serverDelta/16) ||
serverDelta+2 < clientDelta-(clientDelta/16) {
stateLost = true
c.epochStateLost = true
}
}
c.lastEpoch = newEpoch
c.epochTime = now
return stateLost
}
// cacheExternalIP stores the external IP from a successful MAP response.
// Caller must hold c.mu.
func (c *Client) cacheExternalIPLocked(ip netip.Addr) {
if ip.IsValid() && !ip.IsUnspecified() {
c.externalIP = ip
}
}
// sendRequest sends a PCP request with retries per RFC 6887 Section 8.1.1.
func (c *Client) sendRequest(ctx context.Context, req []byte) ([]byte, error) {
addr := &net.UDPAddr{IP: c.gateway.AsSlice(), Port: Port}
var lastErr error
delay := initialRetryDelay
for range maxRetries {
resp, err := c.sendOnce(ctx, addr, req)
if err == nil {
return resp, nil
}
lastErr = err
if ctx.Err() != nil {
return nil, ctx.Err()
}
// RFC 6887 Section 8.1.1: RT = (1 + RAND) * MIN(2 * RTprev, MRT)
// RAND is random between -0.1 and +0.1
select {
case <-ctx.Done():
return nil, ctx.Err()
case <-time.After(retryDelayWithJitter(delay)):
}
delay = min(delay*2, maxRetryDelay)
}
return nil, fmt.Errorf("PCP request failed after %d retries: %w", maxRetries, lastErr)
}
// retryDelayWithJitter applies RFC 6887 jitter: multiply by (1 + RAND) where RAND is [-0.1, +0.1].
func retryDelayWithJitter(d time.Duration) time.Duration {
var b [1]byte
_, _ = rand.Read(b[:])
// Convert byte to range [-0.1, +0.1]: (b/255 * 0.2) - 0.1
jitter := (float64(b[0])/255.0)*0.2 - 0.1
return time.Duration(float64(d) * (1 + jitter))
}
func (c *Client) sendOnce(ctx context.Context, addr *net.UDPAddr, req []byte) ([]byte, error) {
// Use ListenUDP instead of DialUDP to validate response source address per RFC 6887 §8.3.
conn, err := net.ListenUDP("udp", nil)
if err != nil {
return nil, fmt.Errorf("listen: %w", err)
}
defer func() {
if err := conn.Close(); err != nil {
log.Debugf("close UDP connection: %v", err)
}
}()
timeout := c.timeout
if deadline, ok := ctx.Deadline(); ok {
if remaining := time.Until(deadline); remaining < timeout {
timeout = remaining
}
}
if err := conn.SetDeadline(time.Now().Add(timeout)); err != nil {
return nil, fmt.Errorf("set deadline: %w", err)
}
if _, err := conn.WriteToUDP(req, addr); err != nil {
return nil, fmt.Errorf("write: %w", err)
}
resp := make([]byte, responseBufferSize)
n, from, err := conn.ReadFromUDP(resp)
if err != nil {
return nil, fmt.Errorf("read: %w", err)
}
// RFC 6887 §8.3: Validate response came from expected PCP server.
if !from.IP.Equal(addr.IP) {
return nil, fmt.Errorf("response from unexpected source %s (expected %s)", from.IP, addr.IP)
}
return resp[:n], nil
}
func (c *Client) getLocalIP() (netip.Addr, error) {
c.mu.Lock()
defer c.mu.Unlock()
if !c.localIP.IsValid() {
return netip.Addr{}, fmt.Errorf("local IP not set for gateway %s", c.gateway)
}
return c.localIP, nil
}
func protocolNumber(protocol string) (uint8, error) {
switch protocol {
case "udp", "UDP":
return ProtoUDP, nil
case "tcp", "TCP":
return ProtoTCP, nil
default:
return 0, fmt.Errorf("unsupported protocol: %s", protocol)
}
}
// Error represents a PCP error response.
type Error struct {
Code uint8
Message string
}
func (e *Error) Error() string {
return fmt.Sprintf("PCP error: %s (%d)", e.Message, e.Code)
}

View File

@@ -1,187 +0,0 @@
package pcp
import (
"context"
"net"
"net/netip"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestAddrConversion(t *testing.T) {
tests := []struct {
name string
addr netip.Addr
}{
{"IPv4", netip.MustParseAddr("192.168.1.100")},
{"IPv4 loopback", netip.MustParseAddr("127.0.0.1")},
{"IPv6", netip.MustParseAddr("2001:db8::1")},
{"IPv6 loopback", netip.MustParseAddr("::1")},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
b16 := addrTo16(tt.addr)
recovered := addrFrom16(b16)
assert.Equal(t, tt.addr, recovered, "address should round-trip")
})
}
}
func TestBuildAnnounceRequest(t *testing.T) {
clientIP := netip.MustParseAddr("192.168.1.100")
req := buildAnnounceRequest(clientIP)
require.Len(t, req, headerSize)
assert.Equal(t, byte(Version), req[0], "version")
assert.Equal(t, byte(OpAnnounce), req[1], "opcode")
// Check client IP is properly encoded as IPv4-mapped IPv6
assert.Equal(t, byte(0xff), req[18], "IPv4-mapped prefix byte 10")
assert.Equal(t, byte(0xff), req[19], "IPv4-mapped prefix byte 11")
assert.Equal(t, byte(192), req[20], "IP octet 1")
assert.Equal(t, byte(168), req[21], "IP octet 2")
assert.Equal(t, byte(1), req[22], "IP octet 3")
assert.Equal(t, byte(100), req[23], "IP octet 4")
}
func TestBuildMapRequest(t *testing.T) {
clientIP := netip.MustParseAddr("192.168.1.100")
nonce := [12]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}
req := buildMapRequest(clientIP, nonce, ProtoUDP, 51820, 51820, netip.Addr{}, 3600)
require.Len(t, req, mapRequestSize)
assert.Equal(t, byte(Version), req[0], "version")
assert.Equal(t, byte(OpMap), req[1], "opcode")
// Lifetime at bytes 4-7
assert.Equal(t, uint32(3600), (uint32(req[4])<<24)|(uint32(req[5])<<16)|(uint32(req[6])<<8)|uint32(req[7]), "lifetime")
// Nonce at bytes 24-35
assert.Equal(t, nonce[:], req[24:36], "nonce")
// Protocol at byte 36
assert.Equal(t, byte(ProtoUDP), req[36], "protocol")
// Internal port at bytes 40-41
assert.Equal(t, uint16(51820), (uint16(req[40])<<8)|uint16(req[41]), "internal port")
// External port at bytes 42-43
assert.Equal(t, uint16(51820), (uint16(req[42])<<8)|uint16(req[43]), "external port")
}
func TestParseResponse(t *testing.T) {
// Construct a valid ANNOUNCE response
resp := make([]byte, headerSize)
resp[0] = Version
resp[1] = OpAnnounce | OpReply
// Result code = 0 (success)
// Lifetime = 0
// Epoch = 12345
resp[8] = 0
resp[9] = 0
resp[10] = 0x30
resp[11] = 0x39
parsed, err := parseResponse(resp)
require.NoError(t, err)
assert.Equal(t, uint8(Version), parsed.Version)
assert.Equal(t, uint8(OpAnnounce|OpReply), parsed.Opcode)
assert.Equal(t, uint8(ResultSuccess), parsed.ResultCode)
assert.Equal(t, uint32(12345), parsed.Epoch)
}
func TestParseResponseErrors(t *testing.T) {
t.Run("too short", func(t *testing.T) {
_, err := parseResponse([]byte{1, 2, 3})
assert.Error(t, err)
})
t.Run("wrong version", func(t *testing.T) {
resp := make([]byte, headerSize)
resp[0] = 1 // Wrong version
resp[1] = OpReply
_, err := parseResponse(resp)
assert.Error(t, err)
})
t.Run("missing reply bit", func(t *testing.T) {
resp := make([]byte, headerSize)
resp[0] = Version
resp[1] = OpAnnounce // Missing OpReply bit
_, err := parseResponse(resp)
assert.Error(t, err)
})
}
func TestResultCodeString(t *testing.T) {
assert.Equal(t, "SUCCESS", ResultCodeString(ResultSuccess))
assert.Equal(t, "NOT_AUTHORIZED", ResultCodeString(ResultNotAuthorized))
assert.Equal(t, "ADDRESS_MISMATCH", ResultCodeString(ResultAddressMismatch))
assert.Contains(t, ResultCodeString(255), "UNKNOWN")
}
func TestProtocolNumber(t *testing.T) {
proto, err := protocolNumber("udp")
require.NoError(t, err)
assert.Equal(t, uint8(ProtoUDP), proto)
proto, err = protocolNumber("tcp")
require.NoError(t, err)
assert.Equal(t, uint8(ProtoTCP), proto)
proto, err = protocolNumber("UDP")
require.NoError(t, err)
assert.Equal(t, uint8(ProtoUDP), proto)
_, err = protocolNumber("icmp")
assert.Error(t, err)
}
func TestClientCreation(t *testing.T) {
gateway := netip.MustParseAddr("192.168.1.1").AsSlice()
client := NewClient(gateway)
assert.Equal(t, net.IP(gateway), client.Gateway())
assert.Equal(t, defaultTimeout, client.timeout)
clientWithTimeout := NewClientWithTimeout(gateway, 5*time.Second)
assert.Equal(t, 5*time.Second, clientWithTimeout.timeout)
}
func TestNATType(t *testing.T) {
n := NewNAT(netip.MustParseAddr("192.168.1.1").AsSlice(), netip.MustParseAddr("192.168.1.100").AsSlice())
assert.Equal(t, "PCP", n.Type())
}
// Integration test - skipped unless PCP_TEST_GATEWAY env is set
func TestClientIntegration(t *testing.T) {
t.Skip("Integration test - run manually with PCP_TEST_GATEWAY=<gateway-ip>")
gateway := netip.MustParseAddr("10.0.1.1").AsSlice() // Change to your test gateway
localIP := netip.MustParseAddr("10.0.1.100").AsSlice() // Change to your local IP
client := NewClient(gateway)
client.SetLocalIP(localIP)
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
// Test ANNOUNCE
epoch, err := client.Announce(ctx)
require.NoError(t, err)
t.Logf("Server epoch: %d", epoch)
// Test MAP
resp, err := client.AddPortMapping(ctx, "udp", 51820, 1*time.Hour)
require.NoError(t, err)
t.Logf("Mapping: internal=%d external=%d externalIP=%s",
resp.InternalPort, resp.ExternalPort, resp.ExternalIP)
// Cleanup
err = client.DeletePortMapping(ctx, "udp", 51820)
require.NoError(t, err)
}

View File

@@ -1,209 +0,0 @@
package pcp
import (
"context"
"fmt"
"net"
"net/netip"
"sync"
"time"
log "github.com/sirupsen/logrus"
"github.com/libp2p/go-nat"
"github.com/libp2p/go-netroute"
)
var _ nat.NAT = (*NAT)(nil)
// NAT implements the go-nat NAT interface using PCP.
// Supports dual-stack (IPv4 and IPv6) when available.
// All methods are safe for concurrent use.
//
// TODO: IPv6 pinholes use the local IPv6 address. If the address changes
// (e.g., due to SLAAC rotation or network change), the pinhole becomes stale
// and needs to be recreated with the new address.
type NAT struct {
client *Client
mu sync.RWMutex
// client6 is the IPv6 PCP client, nil if IPv6 is unavailable.
client6 *Client
// localIP6 caches the local IPv6 address used for PCP requests.
localIP6 netip.Addr
}
// NewNAT creates a new NAT instance backed by PCP.
func NewNAT(gateway, localIP net.IP) *NAT {
client := NewClient(gateway)
client.SetLocalIP(localIP)
return &NAT{
client: client,
}
}
// Type returns "PCP" as the NAT type.
func (n *NAT) Type() string {
return "PCP"
}
// GetDeviceAddress returns the gateway IP address.
func (n *NAT) GetDeviceAddress() (net.IP, error) {
return n.client.Gateway(), nil
}
// GetExternalAddress returns the external IP address.
func (n *NAT) GetExternalAddress() (net.IP, error) {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
return n.client.GetExternalAddress(ctx)
}
// GetInternalAddress returns the local IP address used to communicate with the gateway.
func (n *NAT) GetInternalAddress() (net.IP, error) {
addr, err := n.client.getLocalIP()
if err != nil {
return nil, err
}
return addr.AsSlice(), nil
}
// AddPortMapping creates a port mapping on both IPv4 and IPv6 (if available).
func (n *NAT) AddPortMapping(ctx context.Context, protocol string, internalPort int, _ string, timeout time.Duration) (int, error) {
resp, err := n.client.AddPortMapping(ctx, protocol, internalPort, timeout)
if err != nil {
return 0, fmt.Errorf("add mapping: %w", err)
}
n.mu.RLock()
client6 := n.client6
localIP6 := n.localIP6
n.mu.RUnlock()
if client6 == nil {
return int(resp.ExternalPort), nil
}
if _, err := client6.AddPortMapping(ctx, protocol, internalPort, timeout); err != nil {
log.Warnf("IPv6 PCP mapping failed (continuing with IPv4): %v", err)
return int(resp.ExternalPort), nil
}
log.Infof("created IPv6 PCP pinhole: %s:%d", localIP6, internalPort)
return int(resp.ExternalPort), nil
}
// DeletePortMapping removes a port mapping from both IPv4 and IPv6.
func (n *NAT) DeletePortMapping(ctx context.Context, protocol string, internalPort int) error {
err := n.client.DeletePortMapping(ctx, protocol, internalPort)
n.mu.RLock()
client6 := n.client6
n.mu.RUnlock()
if client6 != nil {
if err6 := client6.DeletePortMapping(ctx, protocol, internalPort); err6 != nil {
log.Warnf("IPv6 PCP delete mapping failed: %v", err6)
}
}
if err != nil {
return fmt.Errorf("delete mapping: %w", err)
}
return nil
}
// CheckServerHealth sends an ANNOUNCE to verify the server is still responsive.
// Returns the current epoch and whether the server may have restarted (epoch state loss detected).
func (n *NAT) CheckServerHealth(ctx context.Context) (epoch uint32, serverRestarted bool, err error) {
epoch, err = n.client.Announce(ctx)
if err != nil {
return 0, false, fmt.Errorf("announce: %w", err)
}
return epoch, n.client.EpochStateLost(), nil
}
// DiscoverPCP attempts to discover a PCP-capable gateway.
// Returns a NAT interface if PCP is supported, or an error otherwise.
// Discovers both IPv4 and IPv6 gateways when available.
func DiscoverPCP(ctx context.Context) (nat.NAT, error) {
gateway, localIP, err := getDefaultGateway()
if err != nil {
return nil, fmt.Errorf("get default gateway: %w", err)
}
client := NewClient(gateway)
client.SetLocalIP(localIP)
if _, err := client.Announce(ctx); err != nil {
return nil, fmt.Errorf("PCP announce: %w", err)
}
result := &NAT{client: client}
discoverIPv6(ctx, result)
return result, nil
}
func discoverIPv6(ctx context.Context, result *NAT) {
gateway6, localIP6, err := getDefaultGateway6()
if err != nil {
log.Debugf("IPv6 gateway discovery failed: %v", err)
return
}
client6 := NewClient(gateway6)
client6.SetLocalIP(localIP6)
if _, err := client6.Announce(ctx); err != nil {
log.Debugf("PCP IPv6 announce failed: %v", err)
return
}
addr, ok := netip.AddrFromSlice(localIP6)
if !ok {
log.Debugf("invalid IPv6 local IP: %v", localIP6)
return
}
result.mu.Lock()
result.client6 = client6
result.localIP6 = addr
result.mu.Unlock()
log.Debugf("PCP IPv6 gateway discovered: %s (local: %s)", gateway6, localIP6)
}
// getDefaultGateway returns the default IPv4 gateway and local IP using the system routing table.
func getDefaultGateway() (gateway net.IP, localIP net.IP, err error) {
router, err := netroute.New()
if err != nil {
return nil, nil, err
}
_, gateway, localIP, err = router.Route(net.IPv4zero)
if err != nil {
return nil, nil, err
}
if gateway == nil {
return nil, nil, nat.ErrNoNATFound
}
return gateway, localIP, nil
}
// getDefaultGateway6 returns the default IPv6 gateway IP address using the system routing table.
func getDefaultGateway6() (gateway net.IP, localIP net.IP, err error) {
router, err := netroute.New()
if err != nil {
return nil, nil, err
}
_, gateway, localIP, err = router.Route(net.IPv6zero)
if err != nil {
return nil, nil, err
}
if gateway == nil {
return nil, nil, nat.ErrNoNATFound
}
return gateway, localIP, nil
}

View File

@@ -1,225 +0,0 @@
// Package pcp implements the Port Control Protocol (RFC 6887).
//
// # Implemented Features
//
// - ANNOUNCE opcode: Discovers PCP server support
// - MAP opcode: Creates/deletes port mappings (IPv4 NAT) and firewall pinholes (IPv6)
// - Dual-stack: Simultaneous IPv4 and IPv6 support via separate clients
// - Nonce validation: Prevents response spoofing
// - Epoch tracking: Detects server restarts per Section 8.5
// - RFC-compliant retry timing: 3s initial, exponential backoff to 1024s max (Section 8.1.1)
//
// # Not Implemented
//
// - PEER opcode: For outbound peer connections (not needed for inbound NAT traversal)
// - THIRD_PARTY option: For managing mappings on behalf of other devices
// - PREFER_FAILURE option: Requires exact external port or fail (IPv4 NAT only, not needed for IPv6 pinholing)
// - FILTER option: To restrict remote peer addresses
//
// These optional features are omitted because the primary use case is simple
// port forwarding for WireGuard, which only requires MAP with default behavior.
package pcp
import (
"encoding/binary"
"fmt"
"net/netip"
)
const (
// Version is the PCP protocol version (RFC 6887).
Version = 2
// Port is the standard PCP server port.
Port = 5351
// DefaultLifetime is the default requested mapping lifetime in seconds.
DefaultLifetime = 7200 // 2 hours
// Header sizes
headerSize = 24
mapPayloadSize = 36
mapRequestSize = headerSize + mapPayloadSize // 60 bytes
)
// Opcodes
const (
OpAnnounce = 0
OpMap = 1
OpPeer = 2
OpReply = 0x80 // OR'd with opcode in responses
)
// Protocol numbers for MAP requests
const (
ProtoUDP = 17
ProtoTCP = 6
)
// Result codes (RFC 6887 Section 7.4)
const (
ResultSuccess = 0
ResultUnsuppVersion = 1
ResultNotAuthorized = 2
ResultMalformedRequest = 3
ResultUnsuppOpcode = 4
ResultUnsuppOption = 5
ResultMalformedOption = 6
ResultNetworkFailure = 7
ResultNoResources = 8
ResultUnsuppProtocol = 9
ResultUserExQuota = 10
ResultCannotProvideExt = 11
ResultAddressMismatch = 12
ResultExcessiveRemotePeers = 13
)
// ResultCodeString returns a human-readable string for a result code.
func ResultCodeString(code uint8) string {
switch code {
case ResultSuccess:
return "SUCCESS"
case ResultUnsuppVersion:
return "UNSUPP_VERSION"
case ResultNotAuthorized:
return "NOT_AUTHORIZED"
case ResultMalformedRequest:
return "MALFORMED_REQUEST"
case ResultUnsuppOpcode:
return "UNSUPP_OPCODE"
case ResultUnsuppOption:
return "UNSUPP_OPTION"
case ResultMalformedOption:
return "MALFORMED_OPTION"
case ResultNetworkFailure:
return "NETWORK_FAILURE"
case ResultNoResources:
return "NO_RESOURCES"
case ResultUnsuppProtocol:
return "UNSUPP_PROTOCOL"
case ResultUserExQuota:
return "USER_EX_QUOTA"
case ResultCannotProvideExt:
return "CANNOT_PROVIDE_EXTERNAL"
case ResultAddressMismatch:
return "ADDRESS_MISMATCH"
case ResultExcessiveRemotePeers:
return "EXCESSIVE_REMOTE_PEERS"
default:
return fmt.Sprintf("UNKNOWN(%d)", code)
}
}
// Response represents a parsed PCP response header.
type Response struct {
Version uint8
Opcode uint8
ResultCode uint8
Lifetime uint32
Epoch uint32
}
// MapResponse contains the full response to a MAP request.
type MapResponse struct {
Response
Nonce [12]byte
Protocol uint8
InternalPort uint16
ExternalPort uint16
ExternalIP netip.Addr
}
// addrTo16 converts an address to its 16-byte IPv4-mapped IPv6 representation.
func addrTo16(addr netip.Addr) [16]byte {
if addr.Is4() {
return netip.AddrFrom4(addr.As4()).As16()
}
return addr.As16()
}
// addrFrom16 extracts an address from a 16-byte representation, unmapping IPv4.
func addrFrom16(b [16]byte) netip.Addr {
return netip.AddrFrom16(b).Unmap()
}
// buildAnnounceRequest creates a PCP ANNOUNCE request packet.
func buildAnnounceRequest(clientIP netip.Addr) []byte {
req := make([]byte, headerSize)
req[0] = Version
req[1] = OpAnnounce
mapped := addrTo16(clientIP)
copy(req[8:24], mapped[:])
return req
}
// buildMapRequest creates a PCP MAP request packet.
func buildMapRequest(clientIP netip.Addr, nonce [12]byte, protocol uint8, internalPort, suggestedExtPort uint16, suggestedExtIP netip.Addr, lifetime uint32) []byte {
req := make([]byte, mapRequestSize)
// Header
req[0] = Version
req[1] = OpMap
binary.BigEndian.PutUint32(req[4:8], lifetime)
mapped := addrTo16(clientIP)
copy(req[8:24], mapped[:])
// MAP payload
copy(req[24:36], nonce[:])
req[36] = protocol
binary.BigEndian.PutUint16(req[40:42], internalPort)
binary.BigEndian.PutUint16(req[42:44], suggestedExtPort)
if suggestedExtIP.IsValid() {
extMapped := addrTo16(suggestedExtIP)
copy(req[44:60], extMapped[:])
}
return req
}
// parseResponse parses the common PCP response header.
func parseResponse(data []byte) (*Response, error) {
if len(data) < headerSize {
return nil, fmt.Errorf("response too short: %d bytes", len(data))
}
resp := &Response{
Version: data[0],
Opcode: data[1],
ResultCode: data[3], // Byte 2 is reserved, byte 3 is result code (RFC 6887 §7.2)
Lifetime: binary.BigEndian.Uint32(data[4:8]),
Epoch: binary.BigEndian.Uint32(data[8:12]),
}
if resp.Version != Version {
return nil, fmt.Errorf("unsupported PCP version: %d", resp.Version)
}
if resp.Opcode&OpReply == 0 {
return nil, fmt.Errorf("response missing reply bit: opcode=0x%02x", resp.Opcode)
}
return resp, nil
}
// parseMapResponse parses a complete MAP response.
func parseMapResponse(data []byte) (*MapResponse, error) {
if len(data) < mapRequestSize {
return nil, fmt.Errorf("MAP response too short: %d bytes", len(data))
}
resp, err := parseResponse(data)
if err != nil {
return nil, fmt.Errorf("parse header: %w", err)
}
mapResp := &MapResponse{
Response: *resp,
Protocol: data[36],
InternalPort: binary.BigEndian.Uint16(data[40:42]),
ExternalPort: binary.BigEndian.Uint16(data[42:44]),
ExternalIP: addrFrom16([16]byte(data[44:60])),
}
copy(mapResp.Nonce[:], data[24:36])
return mapResp, nil
}

View File

@@ -1,63 +0,0 @@
//go:build !js
package portforward
import (
"context"
"fmt"
"github.com/libp2p/go-nat"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/portforward/pcp"
)
// discoverGateway is the function used for NAT gateway discovery.
// It can be replaced in tests to avoid real network operations.
// Tries PCP first, then falls back to NAT-PMP/UPnP.
var discoverGateway = defaultDiscoverGateway
func defaultDiscoverGateway(ctx context.Context) (nat.NAT, error) {
pcpGateway, err := pcp.DiscoverPCP(ctx)
if err == nil {
return pcpGateway, nil
}
log.Debugf("PCP discovery failed: %v, trying NAT-PMP/UPnP", err)
return nat.DiscoverGateway(ctx)
}
// State is persisted only for crash recovery cleanup
type State struct {
InternalPort uint16 `json:"internal_port,omitempty"`
Protocol string `json:"protocol,omitempty"`
}
func (s *State) Name() string {
return "port_forward_state"
}
// Cleanup implements statemanager.CleanableState for crash recovery
func (s *State) Cleanup() error {
if s.InternalPort == 0 {
return nil
}
log.Infof("cleaning up stale port mapping for port %d", s.InternalPort)
ctx, cancel := context.WithTimeout(context.Background(), discoveryTimeout)
defer cancel()
gateway, err := discoverGateway(ctx)
if err != nil {
// Discovery failure is not an error - gateway may not exist
log.Debugf("cleanup: no gateway found: %v", err)
return nil
}
if err := gateway.DeletePortMapping(ctx, s.Protocol, int(s.InternalPort)); err != nil {
return fmt.Errorf("delete port mapping: %w", err)
}
return nil
}

View File

@@ -3,7 +3,9 @@ package client
import (
"context"
"fmt"
"net"
"reflect"
"strconv"
"time"
log "github.com/sirupsen/logrus"
@@ -564,7 +566,7 @@ func HandlerFromRoute(params common.HandlerParams) RouteHandler {
return dnsinterceptor.New(params)
case handlerTypeDynamic:
dns := nbdns.NewServiceViaMemory(params.WgInterface)
dnsAddr := fmt.Sprintf("%s:%d", dns.RuntimeIP(), dns.RuntimePort())
dnsAddr := net.JoinHostPort(dns.RuntimeIP().String(), strconv.Itoa(dns.RuntimePort()))
return dynamic.NewRoute(params, dnsAddr)
default:
return static.NewRoute(params)

View File

@@ -4,8 +4,10 @@ import (
"context"
"errors"
"fmt"
"net"
"net/netip"
"runtime"
"strconv"
"strings"
"sync"
"sync/atomic"
@@ -249,7 +251,7 @@ func (d *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
r.MsgHdr.AuthenticatedData = true
}
upstream := fmt.Sprintf("%s:%d", upstreamIP.String(), uint16(d.forwarderPort.Load()))
upstream := net.JoinHostPort(upstreamIP.String(), strconv.FormatUint(uint64(d.forwarderPort.Load()), 10))
ctx, cancel := context.WithTimeout(context.Background(), dnsTimeout)
defer cancel()

View File

@@ -1,12 +1,10 @@
#!/usr/bin/env bash
set -eEuo pipefail
: ${NB_ENTRYPOINT_SERVICE_TIMEOUT:="5"}
: ${NB_ENTRYPOINT_LOGIN_TIMEOUT:="5"}
: ${NB_ENTRYPOINT_SERVICE_TIMEOUT:="30"}
NETBIRD_BIN="${NETBIRD_BIN:-"netbird"}"
export NB_LOG_FILE="${NB_LOG_FILE:-"console,/var/log/netbird/client.log"}"
service_pids=()
log_file_path=""
_log() {
# mimic Go logger's output for easier parsing
@@ -33,60 +31,29 @@ on_exit() {
fi
}
wait_for_message() {
local timeout="${1}" message="${2}"
if test "${timeout}" -eq 0; then
info "not waiting for log line ${message@Q} due to zero timeout."
elif test -n "${log_file_path}"; then
info "waiting for log line ${message@Q} for ${timeout} seconds..."
grep -E -q "${message}" <(timeout "${timeout}" tail -F "${log_file_path}" 2>/dev/null)
else
info "log file unsupported, sleeping for ${timeout} seconds..."
sleep "${timeout}"
fi
}
locate_log_file() {
local log_files_string="${1}"
while read -r log_file; do
case "${log_file}" in
console | syslog) ;;
*)
log_file_path="${log_file}"
return
;;
esac
done < <(sed 's#,#\n#g' <<<"${log_files_string}")
warn "log files parsing for ${log_files_string@Q} is not supported by debug bundles"
warn "please consider removing the \$NB_LOG_FILE or setting it to real file, before gathering debug bundles."
}
wait_for_daemon_startup() {
local timeout="${1}"
if test -n "${log_file_path}"; then
if ! wait_for_message "${timeout}" "started daemon server"; then
warn "log line containing 'started daemon server' not found after ${timeout} seconds"
warn "daemon failed to start, exiting..."
exit 1
fi
else
warn "daemon service startup not discovered, sleeping ${timeout} instead"
sleep "${timeout}"
if [[ "${timeout}" -eq 0 ]]; then
info "not waiting for daemon startup due to zero timeout."
return
fi
local deadline=$((SECONDS + timeout))
while [[ "${SECONDS}" -lt "${deadline}" ]]; do
if "${NETBIRD_BIN}" status --check live 2>/dev/null; then
return
fi
sleep 1
done
warn "daemon did not become responsive after ${timeout} seconds, exiting..."
exit 1
}
login_if_needed() {
local timeout="${1}"
if test -n "${log_file_path}" && wait_for_message "${timeout}" 'peer has been successfully registered|management connection state READY'; then
info "already logged in, skipping 'netbird up'..."
else
info "logging in..."
"${NETBIRD_BIN}" up
fi
connect() {
info "running 'netbird up'..."
"${NETBIRD_BIN}" up
return $?
}
main() {
@@ -95,9 +62,8 @@ main() {
service_pids+=("$!")
info "registered new service process 'netbird service run', currently running: ${service_pids[@]@Q}"
locate_log_file "${NB_LOG_FILE}"
wait_for_daemon_startup "${NB_ENTRYPOINT_SERVICE_TIMEOUT}"
login_if_needed "${NB_ENTRYPOINT_LOGIN_TIMEOUT}"
connect
wait "${service_pids[@]}"
}

View File

@@ -26,6 +26,15 @@ func (s *Server) DebugBundle(_ context.Context, req *proto.DebugBundleRequest) (
log.Warnf("failed to get latest sync response: %v", err)
}
var clientMetrics debug.MetricsExporter
if s.connectClient != nil {
if engine := s.connectClient.Engine(); engine != nil {
if cm := engine.GetClientMetrics(); cm != nil {
clientMetrics = cm
}
}
}
var cpuProfileData []byte
if s.cpuProfileBuf != nil && !s.cpuProfiling {
cpuProfileData = s.cpuProfileBuf.Bytes()
@@ -54,6 +63,7 @@ func (s *Server) DebugBundle(_ context.Context, req *proto.DebugBundleRequest) (
LogPath: s.logFile,
CPUProfile: cpuProfileData,
RefreshStatus: refreshStatus,
ClientMetrics: clientMetrics,
},
debug.BundleConfig{
Anonymize: req.GetAnonymize(),

View File

@@ -9,11 +9,6 @@ import (
"github.com/netbirdio/netbird/client/ssh/config"
)
// registerStates registers all states that need crash recovery cleanup.
// Note: portforward.State is intentionally NOT registered here to avoid blocking startup
// for up to 10 seconds during NAT gateway discovery when no gateway is present.
// The gateway reference cannot be persisted across restarts, so cleanup requires re-discovery.
// Port forward cleanup is handled by the Manager during normal operation instead.
func registerStates(mgr *statemanager.Manager) {
mgr.RegisterState(&dns.ShutdownState{})
mgr.RegisterState(&systemops.ShutdownState{})

View File

@@ -11,11 +11,6 @@ import (
"github.com/netbirdio/netbird/client/ssh/config"
)
// registerStates registers all states that need crash recovery cleanup.
// Note: portforward.State is intentionally NOT registered here to avoid blocking startup
// for up to 10 seconds during NAT gateway discovery when no gateway is present.
// The gateway reference cannot be persisted across restarts, so cleanup requires re-discovery.
// Port forward cleanup is handled by the Manager during normal operation instead.
func registerStates(mgr *statemanager.Manager) {
mgr.RegisterState(&dns.ShutdownState{})
mgr.RegisterState(&systemops.ShutdownState{})

View File

@@ -25,6 +25,38 @@ import (
"github.com/netbirdio/netbird/version"
)
// DaemonStatus represents the current state of the NetBird daemon.
// These values mirror internal.StatusType but are defined here to avoid an import cycle.
type DaemonStatus string
const (
DaemonStatusIdle DaemonStatus = "Idle"
DaemonStatusConnecting DaemonStatus = "Connecting"
DaemonStatusConnected DaemonStatus = "Connected"
DaemonStatusNeedsLogin DaemonStatus = "NeedsLogin"
DaemonStatusLoginFailed DaemonStatus = "LoginFailed"
DaemonStatusSessionExpired DaemonStatus = "SessionExpired"
)
// ParseDaemonStatus converts a raw status string to DaemonStatus.
// Unrecognized values are preserved as-is to remain visible during version skew.
func ParseDaemonStatus(s string) DaemonStatus {
return DaemonStatus(s)
}
// ConvertOptions holds parameters for ConvertToStatusOutputOverview.
type ConvertOptions struct {
Anonymize bool
DaemonVersion string
DaemonStatus DaemonStatus
StatusFilter string
PrefixNamesFilter []string
PrefixNamesFilterMap map[string]struct{}
IPsFilter map[string]struct{}
ConnectionTypeFilter string
ProfileName string
}
type PeerStateDetailOutput struct {
FQDN string `json:"fqdn" yaml:"fqdn"`
IP string `json:"netbirdIp" yaml:"netbirdIp"`
@@ -102,6 +134,7 @@ type OutputOverview struct {
Peers PeersStateOutput `json:"peers" yaml:"peers"`
CliVersion string `json:"cliVersion" yaml:"cliVersion"`
DaemonVersion string `json:"daemonVersion" yaml:"daemonVersion"`
DaemonStatus DaemonStatus `json:"daemonStatus" yaml:"daemonStatus"`
ManagementState ManagementStateOutput `json:"management" yaml:"management"`
SignalState SignalStateOutput `json:"signal" yaml:"signal"`
Relays RelayStateOutput `json:"relays" yaml:"relays"`
@@ -120,7 +153,8 @@ type OutputOverview struct {
SSHServerState SSHServerStateOutput `json:"sshServer" yaml:"sshServer"`
}
func ConvertToStatusOutputOverview(pbFullStatus *proto.FullStatus, anon bool, daemonVersion string, statusFilter string, prefixNamesFilter []string, prefixNamesFilterMap map[string]struct{}, ipsFilter map[string]struct{}, connectionTypeFilter string, profName string) OutputOverview {
// ConvertToStatusOutputOverview converts protobuf status to the output overview.
func ConvertToStatusOutputOverview(pbFullStatus *proto.FullStatus, opts ConvertOptions) OutputOverview {
managementState := pbFullStatus.GetManagementState()
managementOverview := ManagementStateOutput{
URL: managementState.GetURL(),
@@ -137,12 +171,13 @@ func ConvertToStatusOutputOverview(pbFullStatus *proto.FullStatus, anon bool, da
relayOverview := mapRelays(pbFullStatus.GetRelays())
sshServerOverview := mapSSHServer(pbFullStatus.GetSshServerState())
peersOverview := mapPeers(pbFullStatus.GetPeers(), statusFilter, prefixNamesFilter, prefixNamesFilterMap, ipsFilter, connectionTypeFilter)
peersOverview := mapPeers(pbFullStatus.GetPeers(), opts.StatusFilter, opts.PrefixNamesFilter, opts.PrefixNamesFilterMap, opts.IPsFilter, opts.ConnectionTypeFilter)
overview := OutputOverview{
Peers: peersOverview,
CliVersion: version.NetbirdVersion(),
DaemonVersion: daemonVersion,
DaemonVersion: opts.DaemonVersion,
DaemonStatus: opts.DaemonStatus,
ManagementState: managementOverview,
SignalState: signalOverview,
Relays: relayOverview,
@@ -157,11 +192,11 @@ func ConvertToStatusOutputOverview(pbFullStatus *proto.FullStatus, anon bool, da
NSServerGroups: mapNSGroups(pbFullStatus.GetDnsServers()),
Events: mapEvents(pbFullStatus.GetEvents()),
LazyConnectionEnabled: pbFullStatus.GetLazyConnectionEnabled(),
ProfileName: profName,
ProfileName: opts.ProfileName,
SSHServerState: sshServerOverview,
}
if anon {
if opts.Anonymize {
anonymizer := anonymize.NewAnonymizer(anonymize.DefaultAddresses())
anonymizeOverview(anonymizer, &overview)
}

View File

@@ -176,6 +176,7 @@ var overview = OutputOverview{
Events: []SystemEventOutput{},
CliVersion: version.NetbirdVersion(),
DaemonVersion: "0.14.1",
DaemonStatus: DaemonStatusConnected,
ManagementState: ManagementStateOutput{
URL: "my-awesome-management.com:443",
Connected: true,
@@ -238,7 +239,10 @@ var overview = OutputOverview{
}
func TestConversionFromFullStatusToOutputOverview(t *testing.T) {
convertedResult := ConvertToStatusOutputOverview(resp.GetFullStatus(), false, resp.GetDaemonVersion(), "", nil, nil, nil, "", "")
convertedResult := ConvertToStatusOutputOverview(resp.GetFullStatus(), ConvertOptions{
DaemonVersion: resp.GetDaemonVersion(),
DaemonStatus: ParseDaemonStatus(resp.GetStatus()),
})
assert.Equal(t, overview, convertedResult)
}
@@ -329,6 +333,7 @@ func TestParsingToJSON(t *testing.T) {
},
"cliVersion": "development",
"daemonVersion": "0.14.1",
"daemonStatus": "Connected",
"management": {
"url": "my-awesome-management.com:443",
"connected": true,
@@ -452,6 +457,7 @@ func TestParsingToYAML(t *testing.T) {
networks: []
cliVersion: development
daemonVersion: 0.14.1
daemonStatus: Connected
management:
url: my-awesome-management.com:443
connected: true

View File

@@ -18,7 +18,6 @@ import (
"github.com/netbirdio/netbird/client/wasm/internal/rdp"
"github.com/netbirdio/netbird/client/wasm/internal/ssh"
"github.com/netbirdio/netbird/util"
"github.com/netbirdio/netbird/version"
)
const (
@@ -350,7 +349,7 @@ func getStatusOverview(client *netbird.Client) (nbstatus.OutputOverview, error)
pbFullStatus := fullStatus.ToProto()
return nbstatus.ConvertToStatusOutputOverview(pbFullStatus, false, version.NetbirdVersion(), "", nil, nil, nil, "", ""), nil
return nbstatus.ConvertToStatusOutputOverview(pbFullStatus, nbstatus.ConvertOptions{}), nil
}
// createStatusMethod creates the status method that returns JSON

9
go.mod
View File

@@ -30,10 +30,10 @@ require (
require (
fyne.io/fyne/v2 v2.7.0
fyne.io/systray v1.12.1-0.20260116214250-81f8e1a496f9
github.com/TheJumpCloud/jcapi-go v3.0.0+incompatible
github.com/awnumar/memguard v0.23.0
github.com/aws/aws-sdk-go-v2 v1.36.3
github.com/aws/aws-sdk-go-v2/config v1.29.14
github.com/aws/aws-sdk-go-v2/credentials v1.17.67
github.com/aws/aws-sdk-go-v2/service/s3 v1.79.2
github.com/c-robinson/iplib v1.0.3
github.com/caddyserver/certmagic v0.21.3
@@ -62,7 +62,6 @@ require (
github.com/hashicorp/go-version v1.6.0
github.com/jackc/pgx/v5 v5.5.5
github.com/libdns/route53 v1.5.0
github.com/libp2p/go-nat v0.2.0
github.com/libp2p/go-netroute v0.2.1
github.com/lrh3321/ipset-go v0.0.0-20250619021614-54a0a98ace81
github.com/mdlayher/socket v0.5.1
@@ -145,7 +144,6 @@ require (
github.com/apapsch/go-jsonmerge/v2 v2.0.0 // indirect
github.com/awnumar/memcall v0.4.0 // indirect
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.10 // indirect
github.com/aws/aws-sdk-go-v2/credentials v1.17.67 // indirect
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.30 // indirect
github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.34 // indirect
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.34 // indirect
@@ -202,12 +200,10 @@ require (
github.com/hashicorp/errwrap v1.1.0 // indirect
github.com/hashicorp/go-uuid v1.0.3 // indirect
github.com/huandu/xstrings v1.5.0 // indirect
github.com/huin/goupnp v1.2.0 // indirect
github.com/inconshreveable/mousetrap v1.1.0 // indirect
github.com/jackc/pgpassfile v1.0.0 // indirect
github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect
github.com/jackc/puddle/v2 v2.2.1 // indirect
github.com/jackpal/go-nat-pmp v1.0.2 // indirect
github.com/jeandeaual/go-locale v0.0.0-20250612000132-0ef82f21eade // indirect
github.com/jinzhu/inflection v1.0.0 // indirect
github.com/jinzhu/now v1.1.5 // indirect
@@ -217,7 +213,6 @@ require (
github.com/kelseyhightower/envconfig v1.4.0 // indirect
github.com/klauspost/compress v1.18.0 // indirect
github.com/klauspost/cpuid/v2 v2.2.7 // indirect
github.com/koron/go-ssdp v0.0.4 // indirect
github.com/kr/fs v0.1.0 // indirect
github.com/lib/pq v1.10.9 // indirect
github.com/libdns/libdns v0.2.2 // indirect
@@ -259,7 +254,7 @@ require (
github.com/russellhaering/goxmldsig v1.5.0 // indirect
github.com/rymdport/portal v0.4.2 // indirect
github.com/shirou/gopsutil/v4 v4.25.1 // indirect
github.com/shoenig/go-m1cpu v0.1.6 // indirect
github.com/shoenig/go-m1cpu v0.2.0 // indirect
github.com/shopspring/decimal v1.4.0 // indirect
github.com/spf13/cast v1.7.0 // indirect
github.com/srwiley/oksvg v0.0.0-20221011165216-be6e8873101c // indirect

16
go.sum
View File

@@ -34,8 +34,6 @@ github.com/Masterminds/sprig/v3 v3.3.0/go.mod h1:Zy1iXRYNqNLUolqCpL4uhk6SHUMAOSC
github.com/Microsoft/go-winio v0.6.2 h1:F2VQgta7ecxGYO8k3ZZz3RS8fVIXVxONVUPlNERoyfY=
github.com/Microsoft/go-winio v0.6.2/go.mod h1:yd8OoFMLzJbo9gZq8j5qaps8bJ9aShtEA8Ipt1oGCvU=
github.com/RaveNoX/go-jsoncommentstrip v1.0.0/go.mod h1:78ihd09MekBnJnxpICcwzCMzGrKSKYe4AqU6PDYYpjk=
github.com/TheJumpCloud/jcapi-go v3.0.0+incompatible h1:hqcTK6ZISdip65SR792lwYJTa/axESA0889D3UlZbLo=
github.com/TheJumpCloud/jcapi-go v3.0.0+incompatible/go.mod h1:6B1nuc1MUs6c62ODZDl7hVE5Pv7O2XGSkgg2olnq34I=
github.com/alexbrainman/sspi v0.0.0-20250919150558-7d374ff0d59e h1:4dAU9FXIyQktpoUAgOJK3OTFc/xug0PCXYCqU0FgDKI=
github.com/alexbrainman/sspi v0.0.0-20250919150558-7d374ff0d59e/go.mod h1:cEWa1LVoE5KvSD9ONXsZrj0z6KqySlCCNKHlLzbqAt4=
github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be h1:9AeTilPcZAjCFIImctFaOjnTIavg87rW78vTPkQqLI8=
@@ -283,8 +281,6 @@ github.com/hashicorp/go-version v1.6.0/go.mod h1:fltr4n8CU8Ke44wwGCBoEymUuxUHl09
github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU=
github.com/huandu/xstrings v1.5.0 h1:2ag3IFq9ZDANvthTwTiqSSZLjDc+BedvHPAp5tJy2TI=
github.com/huandu/xstrings v1.5.0/go.mod h1:y5/lhBue+AyNmUVz9RLU9xbLR0o4KIIExikq4ovT0aE=
github.com/huin/goupnp v1.2.0 h1:uOKW26NG1hsSSbXIZ1IR7XP9Gjd1U8pnLaCMgntmkmY=
github.com/huin/goupnp v1.2.0/go.mod h1:gnGPsThkYa7bFi/KWmEysQRf48l2dvR5bxr2OFckNX8=
github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8=
github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw=
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
@@ -295,8 +291,6 @@ github.com/jackc/pgx/v5 v5.5.5 h1:amBjrZVmksIdNjxGW/IiIMzxMKZFelXbUoPNb+8sjQw=
github.com/jackc/pgx/v5 v5.5.5/go.mod h1:ez9gk+OAat140fv9ErkZDYFWmXLfV+++K0uAOiwgm1A=
github.com/jackc/puddle/v2 v2.2.1 h1:RhxXJtFG022u4ibrCSMSiu5aOq1i77R3OHKNJj77OAk=
github.com/jackc/puddle/v2 v2.2.1/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4=
github.com/jackpal/go-nat-pmp v1.0.2 h1:KzKSgb7qkJvOUTqYl9/Hg/me3pWgBmERKrTGD7BdWus=
github.com/jackpal/go-nat-pmp v1.0.2/go.mod h1:QPH045xvCAeXUZOxsnwmrtiCoxIr9eob+4orBN1SBKc=
github.com/jcmturner/aescts/v2 v2.0.0 h1:9YKLH6ey7H4eDBXW8khjYslgyqG2xZikXP0EQFKrle8=
github.com/jcmturner/aescts/v2 v2.0.0/go.mod h1:AiaICIRyfYg35RUkr8yESTqvSy7csK90qZ5xfvvsoNs=
github.com/jcmturner/dnsutils/v2 v2.0.0 h1:lltnkeZGL0wILNvrNiVCR6Ro5PGU/SeBvVO/8c/iPbo=
@@ -334,8 +328,6 @@ github.com/klauspost/compress v1.18.0/go.mod h1:2Pp+KzxcywXVXMr50+X0Q/Lsb43OQHYW
github.com/klauspost/cpuid/v2 v2.0.12/go.mod h1:g2LTdtYhdyuGPqyWyv7qRAmj1WBqxuObKfj5c0PQa7c=
github.com/klauspost/cpuid/v2 v2.2.7 h1:ZWSB3igEs+d0qvnxR/ZBzXVmxkgt8DdzP6m9pfuVLDM=
github.com/klauspost/cpuid/v2 v2.2.7/go.mod h1:Lcz8mBdAVJIBVzewtcLocK12l3Y+JytZYpaMropDUws=
github.com/koron/go-ssdp v0.0.4 h1:1IDwrghSKYM7yLf7XCzbByg2sJ/JcNOZRXS2jczTwz0=
github.com/koron/go-ssdp v0.0.4/go.mod h1:oDXq+E5IL5q0U8uSBcoAXzTzInwy5lEgC91HoKtbmZk=
github.com/kr/fs v0.1.0 h1:Jskdu9ieNAYnjxsi0LbQp1ulIKZV1LAFgK1tWhpZgl8=
github.com/kr/fs v0.1.0/go.mod h1:FFnZGqtBN9Gxj7eW1uZ42v5BccTP0vu6NEaFoC2HwRg=
github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=
@@ -354,8 +346,6 @@ github.com/libdns/libdns v0.2.2 h1:O6ws7bAfRPaBsgAYt8MDe2HcNBGC29hkZ9MX2eUSX3s=
github.com/libdns/libdns v0.2.2/go.mod h1:4Bj9+5CQiNMVGf87wjX4CY3HQJypUHRuLvlsfsZqLWQ=
github.com/libdns/route53 v1.5.0 h1:2SKdpPFl/qgWsXQvsLNJJAoX7rSxlk7zgoL4jnWdXVA=
github.com/libdns/route53 v1.5.0/go.mod h1:joT4hKmaTNKHEwb7GmZ65eoDz1whTu7KKYPS8ZqIh6Q=
github.com/libp2p/go-nat v0.2.0 h1:Tyz+bUFAYqGyJ/ppPPymMGbIgNRH+WqC5QrT5fKrrGk=
github.com/libp2p/go-nat v0.2.0/go.mod h1:3MJr+GRpRkyT65EpVPBstXLvOlAPzUVlG6Pwg9ohLJk=
github.com/lrh3321/ipset-go v0.0.0-20250619021614-54a0a98ace81 h1:J56rFEfUTFT9j9CiRXhi1r8lUJ4W5idG3CiaBZGojNU=
github.com/lrh3321/ipset-go v0.0.0-20250619021614-54a0a98ace81/go.mod h1:RD8ML/YdXctQ7qbcizZkw5mZ6l8Ogrl1dodBzVJduwI=
github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0/go.mod h1:zJYVVT2jmtg6P3p1VtQj7WsuWi/y4VnjVBn7F8KPB3I=
@@ -520,10 +510,12 @@ github.com/shirou/gopsutil/v3 v3.24.4 h1:dEHgzZXt4LMNm+oYELpzl9YCqV65Yr/6SfrvgRB
github.com/shirou/gopsutil/v3 v3.24.4/go.mod h1:lTd2mdiOspcqLgAnr9/nGi71NkeMpWKdmhuxm9GusH8=
github.com/shirou/gopsutil/v4 v4.25.1 h1:QSWkTc+fu9LTAWfkZwZ6j8MSUk4A2LV7rbH0ZqmLjXs=
github.com/shirou/gopsutil/v4 v4.25.1/go.mod h1:RoUCUpndaJFtT+2zsZzzmhvbfGoDCJ7nFXKJf8GqJbI=
github.com/shoenig/go-m1cpu v0.1.6 h1:nxdKQNcEB6vzgA2E2bvzKIYRuNj7XNJ4S/aRSwKzFtM=
github.com/shoenig/go-m1cpu v0.1.6/go.mod h1:1JJMcUBvfNwpq05QDQVAnx3gUHr9IYF7GNg9SUEw2VQ=
github.com/shoenig/test v0.6.4 h1:kVTaSd7WLz5WZ2IaoM0RSzRsUD+m8wRR+5qvntpn4LU=
github.com/shoenig/go-m1cpu v0.2.0 h1:t4GNqvPZ84Vjtpboo/kT3pIkbaK3vc+JIlD/Wz1zSFY=
github.com/shoenig/go-m1cpu v0.2.0/go.mod h1:KkDOw6m3ZJQAPHbrzkZki4hnx+pDRR1Lo+ldA56wD5w=
github.com/shoenig/test v0.6.4/go.mod h1:byHiCGXqrVaflBLAMq/srcZIHynQPQgeyvkvXnjqq0k=
github.com/shoenig/test v1.7.0 h1:eWcHtTXa6QLnBvm0jgEabMRN/uJ4DMV3M8xUGgRkZmk=
github.com/shoenig/test v1.7.0/go.mod h1:UxJ6u/x2v/TNs/LoLxBNJRV9DiwBBKYxXSyczsBHFoI=
github.com/shopspring/decimal v1.4.0 h1:bxl37RwXBklmTi0C79JfXCEBD1cqqHt0bbgBAGFp81k=
github.com/shopspring/decimal v1.4.0/go.mod h1:gawqmDU56v4yIKSwfBSFip1HdCCXN8/+DMd9qYNcwME=
github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ=

View File

@@ -17,6 +17,9 @@ type Domain struct {
// SupportsCustomPorts is populated at query time for free domains from the
// proxy cluster capabilities. Not persisted.
SupportsCustomPorts *bool `gorm:"-"`
// RequireSubdomain is populated at query time. When true, the domain
// cannot be used bare and a subdomain label must be prepended. Not persisted.
RequireSubdomain *bool `gorm:"-"`
}
// EventMeta returns activity event metadata for a domain

View File

@@ -47,6 +47,7 @@ func domainToApi(d *domain.Domain) api.ReverseProxyDomain {
Type: domainTypeToApi(d.Type),
Validated: d.Validated,
SupportsCustomPorts: d.SupportsCustomPorts,
RequireSubdomain: d.RequireSubdomain,
}
if d.TargetCluster != "" {
resp.TargetCluster = &d.TargetCluster

View File

@@ -0,0 +1,172 @@
package manager
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain"
)
func TestExtractClusterFromFreeDomain(t *testing.T) {
clusters := []string{"eu1.proxy.netbird.io", "us1.proxy.netbird.io"}
tests := []struct {
name string
domain string
wantOK bool
wantVal string
}{
{
name: "subdomain of cluster matches",
domain: "myapp.eu1.proxy.netbird.io",
wantOK: true,
wantVal: "eu1.proxy.netbird.io",
},
{
name: "deep subdomain of cluster matches",
domain: "foo.bar.eu1.proxy.netbird.io",
wantOK: true,
wantVal: "eu1.proxy.netbird.io",
},
{
name: "bare cluster domain matches",
domain: "eu1.proxy.netbird.io",
wantOK: true,
wantVal: "eu1.proxy.netbird.io",
},
{
name: "unrelated domain does not match",
domain: "example.com",
wantOK: false,
},
{
name: "partial suffix does not match",
domain: "fakeu1.proxy.netbird.io",
wantOK: false,
},
{
name: "second cluster matches",
domain: "app.us1.proxy.netbird.io",
wantOK: true,
wantVal: "us1.proxy.netbird.io",
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
cluster, ok := ExtractClusterFromFreeDomain(tc.domain, clusters)
assert.Equal(t, tc.wantOK, ok)
if ok {
assert.Equal(t, tc.wantVal, cluster)
}
})
}
}
func TestExtractClusterFromCustomDomains(t *testing.T) {
customDomains := []*domain.Domain{
{Domain: "example.com", TargetCluster: "eu1.proxy.netbird.io"},
{Domain: "proxy.corp.io", TargetCluster: "us1.proxy.netbird.io"},
}
tests := []struct {
name string
domain string
wantOK bool
wantVal string
}{
{
name: "subdomain of custom domain matches",
domain: "app.example.com",
wantOK: true,
wantVal: "eu1.proxy.netbird.io",
},
{
name: "bare custom domain matches",
domain: "example.com",
wantOK: true,
wantVal: "eu1.proxy.netbird.io",
},
{
name: "deep subdomain of custom domain matches",
domain: "a.b.example.com",
wantOK: true,
wantVal: "eu1.proxy.netbird.io",
},
{
name: "subdomain of multi-level custom domain matches",
domain: "app.proxy.corp.io",
wantOK: true,
wantVal: "us1.proxy.netbird.io",
},
{
name: "bare multi-level custom domain matches",
domain: "proxy.corp.io",
wantOK: true,
wantVal: "us1.proxy.netbird.io",
},
{
name: "unrelated domain does not match",
domain: "other.com",
wantOK: false,
},
{
name: "partial suffix does not match custom domain",
domain: "fakeexample.com",
wantOK: false,
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
cluster, ok := extractClusterFromCustomDomains(tc.domain, customDomains)
assert.Equal(t, tc.wantOK, ok)
if ok {
assert.Equal(t, tc.wantVal, cluster)
}
})
}
}
func TestExtractClusterFromCustomDomains_OverlappingDomains(t *testing.T) {
customDomains := []*domain.Domain{
{Domain: "example.com", TargetCluster: "cluster-generic"},
{Domain: "app.example.com", TargetCluster: "cluster-app"},
}
tests := []struct {
name string
domain string
wantVal string
}{
{
name: "exact match on more specific domain",
domain: "app.example.com",
wantVal: "cluster-app",
},
{
name: "subdomain of more specific domain",
domain: "api.app.example.com",
wantVal: "cluster-app",
},
{
name: "subdomain of generic domain",
domain: "other.example.com",
wantVal: "cluster-generic",
},
{
name: "bare generic domain",
domain: "example.com",
wantVal: "cluster-generic",
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
cluster, ok := extractClusterFromCustomDomains(tc.domain, customDomains)
assert.True(t, ok)
assert.Equal(t, tc.wantVal, cluster)
})
}
}

View File

@@ -35,6 +35,7 @@ type proxyManager interface {
type clusterCapabilities interface {
ClusterSupportsCustomPorts(clusterAddr string) *bool
ClusterRequireSubdomain(clusterAddr string) *bool
}
type Manager struct {
@@ -98,6 +99,7 @@ func (m Manager) GetDomains(ctx context.Context, accountID, userID string) ([]*d
}
if m.clusterCapabilities != nil {
d.SupportsCustomPorts = m.clusterCapabilities.ClusterSupportsCustomPorts(cluster)
d.RequireSubdomain = m.clusterCapabilities.ClusterRequireSubdomain(cluster)
}
ret = append(ret, d)
}
@@ -115,6 +117,8 @@ func (m Manager) GetDomains(ctx context.Context, accountID, userID string) ([]*d
if m.clusterCapabilities != nil && d.TargetCluster != "" {
cd.SupportsCustomPorts = m.clusterCapabilities.ClusterSupportsCustomPorts(d.TargetCluster)
}
// Custom domains never require a subdomain by default since
// the account owns them and should be able to use the bare domain.
ret = append(ret, cd)
}
@@ -302,13 +306,19 @@ func (m Manager) DeriveClusterFromDomain(ctx context.Context, accountID, domain
return "", fmt.Errorf("domain %s does not match any available proxy cluster", domain)
}
func extractClusterFromCustomDomains(domain string, customDomains []*domain.Domain) (string, bool) {
for _, customDomain := range customDomains {
if strings.HasSuffix(domain, "."+customDomain.Domain) {
return customDomain.TargetCluster, true
func extractClusterFromCustomDomains(serviceDomain string, customDomains []*domain.Domain) (string, bool) {
bestCluster := ""
bestLen := -1
for _, cd := range customDomains {
if serviceDomain != cd.Domain && !strings.HasSuffix(serviceDomain, "."+cd.Domain) {
continue
}
if l := len(cd.Domain); l > bestLen {
bestLen = l
bestCluster = cd.TargetCluster
}
}
return "", false
return bestCluster, bestLen >= 0
}
// ExtractClusterFromFreeDomain extracts the cluster address from a free domain.

View File

@@ -13,8 +13,9 @@ import (
type Manager interface {
Connect(ctx context.Context, proxyID, clusterAddress, ipAddress string) error
Disconnect(ctx context.Context, proxyID string) error
Heartbeat(ctx context.Context, proxyID string) error
Heartbeat(ctx context.Context, proxyID, clusterAddress, ipAddress string) error
GetActiveClusterAddresses(ctx context.Context) ([]string, error)
GetActiveClusters(ctx context.Context) ([]Cluster, error)
CleanupStale(ctx context.Context, inactivityDuration time.Duration) error
}
@@ -34,4 +35,5 @@ type Controller interface {
UnregisterProxyFromCluster(ctx context.Context, clusterAddr, proxyID string) error
GetProxiesForCluster(clusterAddr string) []string
ClusterSupportsCustomPorts(clusterAddr string) *bool
ClusterRequireSubdomain(clusterAddr string) *bool
}

View File

@@ -77,6 +77,12 @@ func (c *GRPCController) ClusterSupportsCustomPorts(clusterAddr string) *bool {
return c.proxyGRPCServer.ClusterSupportsCustomPorts(clusterAddr)
}
// ClusterRequireSubdomain returns whether the cluster requires a subdomain label.
// Returns nil when no proxy has reported the capability (defaults to false).
func (c *GRPCController) ClusterRequireSubdomain(clusterAddr string) *bool {
return c.proxyGRPCServer.ClusterRequireSubdomain(clusterAddr)
}
// GetProxiesForCluster returns all proxy IDs registered for a specific cluster.
func (c *GRPCController) GetProxiesForCluster(clusterAddr string) []string {
proxySet, ok := c.clusterProxies.Load(clusterAddr)

View File

@@ -13,8 +13,9 @@ import (
// store defines the interface for proxy persistence operations
type store interface {
SaveProxy(ctx context.Context, p *proxy.Proxy) error
UpdateProxyHeartbeat(ctx context.Context, proxyID string) error
UpdateProxyHeartbeat(ctx context.Context, proxyID, clusterAddress, ipAddress string) error
GetActiveProxyClusterAddresses(ctx context.Context) ([]string, error)
GetActiveProxyClusters(ctx context.Context) ([]proxy.Cluster, error)
CleanupStaleProxies(ctx context.Context, inactivityDuration time.Duration) error
}
@@ -86,11 +87,13 @@ func (m Manager) Disconnect(ctx context.Context, proxyID string) error {
}
// Heartbeat updates the proxy's last seen timestamp
func (m Manager) Heartbeat(ctx context.Context, proxyID string) error {
if err := m.store.UpdateProxyHeartbeat(ctx, proxyID); err != nil {
func (m Manager) Heartbeat(ctx context.Context, proxyID, clusterAddress, ipAddress string) error {
if err := m.store.UpdateProxyHeartbeat(ctx, proxyID, clusterAddress, ipAddress); err != nil {
log.WithContext(ctx).Debugf("failed to update proxy %s heartbeat: %v", proxyID, err)
return err
}
log.WithContext(ctx).Tracef("updated heartbeat for proxy %s", proxyID)
m.metrics.IncrementProxyHeartbeatCount()
return nil
}
@@ -105,6 +108,16 @@ func (m Manager) GetActiveClusterAddresses(ctx context.Context) ([]string, error
return addresses, nil
}
// GetActiveClusters returns all active proxy clusters with their connected proxy count.
func (m Manager) GetActiveClusters(ctx context.Context) ([]proxy.Cluster, error) {
clusters, err := m.store.GetActiveProxyClusters(ctx)
if err != nil {
log.WithContext(ctx).Errorf("failed to get active proxy clusters: %v", err)
return nil, err
}
return clusters, nil
}
// CleanupStale removes proxies that haven't sent heartbeat in the specified duration
func (m Manager) CleanupStale(ctx context.Context, inactivityDuration time.Duration) error {
if err := m.store.CleanupStaleProxies(ctx, inactivityDuration); err != nil {

View File

@@ -93,18 +93,33 @@ func (mr *MockManagerMockRecorder) GetActiveClusterAddresses(ctx interface{}) *g
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetActiveClusterAddresses", reflect.TypeOf((*MockManager)(nil).GetActiveClusterAddresses), ctx)
}
// Heartbeat mocks base method.
func (m *MockManager) Heartbeat(ctx context.Context, proxyID string) error {
// GetActiveClusters mocks base method.
func (m *MockManager) GetActiveClusters(ctx context.Context) ([]Cluster, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Heartbeat", ctx, proxyID)
ret := m.ctrl.Call(m, "GetActiveClusters", ctx)
ret0, _ := ret[0].([]Cluster)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetActiveClusters indicates an expected call of GetActiveClusters.
func (mr *MockManagerMockRecorder) GetActiveClusters(ctx interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetActiveClusters", reflect.TypeOf((*MockManager)(nil).GetActiveClusters), ctx)
}
// Heartbeat mocks base method.
func (m *MockManager) Heartbeat(ctx context.Context, proxyID, clusterAddress, ipAddress string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Heartbeat", ctx, proxyID, clusterAddress, ipAddress)
ret0, _ := ret[0].(error)
return ret0
}
// Heartbeat indicates an expected call of Heartbeat.
func (mr *MockManagerMockRecorder) Heartbeat(ctx, proxyID interface{}) *gomock.Call {
func (mr *MockManagerMockRecorder) Heartbeat(ctx, proxyID, clusterAddress, ipAddress interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Heartbeat", reflect.TypeOf((*MockManager)(nil).Heartbeat), ctx, proxyID)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Heartbeat", reflect.TypeOf((*MockManager)(nil).Heartbeat), ctx, proxyID, clusterAddress, ipAddress)
}
// MockController is a mock of Controller interface.
@@ -130,20 +145,6 @@ func (m *MockController) EXPECT() *MockControllerMockRecorder {
return m.recorder
}
// GetOIDCValidationConfig mocks base method.
func (m *MockController) GetOIDCValidationConfig() OIDCValidationConfig {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetOIDCValidationConfig")
ret0, _ := ret[0].(OIDCValidationConfig)
return ret0
}
// GetOIDCValidationConfig indicates an expected call of GetOIDCValidationConfig.
func (mr *MockControllerMockRecorder) GetOIDCValidationConfig() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetOIDCValidationConfig", reflect.TypeOf((*MockController)(nil).GetOIDCValidationConfig))
}
// ClusterSupportsCustomPorts mocks base method.
func (m *MockController) ClusterSupportsCustomPorts(clusterAddr string) *bool {
m.ctrl.T.Helper()
@@ -158,6 +159,34 @@ func (mr *MockControllerMockRecorder) ClusterSupportsCustomPorts(clusterAddr int
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ClusterSupportsCustomPorts", reflect.TypeOf((*MockController)(nil).ClusterSupportsCustomPorts), clusterAddr)
}
// ClusterRequireSubdomain mocks base method.
func (m *MockController) ClusterRequireSubdomain(clusterAddr string) *bool {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "ClusterRequireSubdomain", clusterAddr)
ret0, _ := ret[0].(*bool)
return ret0
}
// ClusterRequireSubdomain indicates an expected call of ClusterRequireSubdomain.
func (mr *MockControllerMockRecorder) ClusterRequireSubdomain(clusterAddr interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ClusterRequireSubdomain", reflect.TypeOf((*MockController)(nil).ClusterRequireSubdomain), clusterAddr)
}
// GetOIDCValidationConfig mocks base method.
func (m *MockController) GetOIDCValidationConfig() OIDCValidationConfig {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetOIDCValidationConfig")
ret0, _ := ret[0].(OIDCValidationConfig)
return ret0
}
// GetOIDCValidationConfig indicates an expected call of GetOIDCValidationConfig.
func (mr *MockControllerMockRecorder) GetOIDCValidationConfig() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetOIDCValidationConfig", reflect.TypeOf((*MockController)(nil).GetOIDCValidationConfig))
}
// GetProxiesForCluster mocks base method.
func (m *MockController) GetProxiesForCluster(clusterAddr string) []string {
m.ctrl.T.Helper()

View File

@@ -18,3 +18,9 @@ type Proxy struct {
func (Proxy) TableName() string {
return "proxies"
}
// Cluster represents a group of proxy nodes serving the same address.
type Cluster struct {
Address string
ConnectedProxies int
}

View File

@@ -4,9 +4,12 @@ package service
import (
"context"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
)
type Manager interface {
GetActiveClusters(ctx context.Context, accountID, userID string) ([]proxy.Cluster, error)
GetAllServices(ctx context.Context, accountID, userID string) ([]*Service, error)
GetService(ctx context.Context, accountID, userID, serviceID string) (*Service, error)
CreateService(ctx context.Context, accountID, userID string, service *Service) (*Service, error)

View File

@@ -9,6 +9,7 @@ import (
reflect "reflect"
gomock "github.com/golang/mock/gomock"
proxy "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
)
// MockManager is a mock of Manager interface.
@@ -107,6 +108,21 @@ func (mr *MockManagerMockRecorder) GetAccountServices(ctx, accountID interface{}
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAccountServices", reflect.TypeOf((*MockManager)(nil).GetAccountServices), ctx, accountID)
}
// GetActiveClusters mocks base method.
func (m *MockManager) GetActiveClusters(ctx context.Context, accountID, userID string) ([]proxy.Cluster, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetActiveClusters", ctx, accountID, userID)
ret0, _ := ret[0].([]proxy.Cluster)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetActiveClusters indicates an expected call of GetActiveClusters.
func (mr *MockManagerMockRecorder) GetActiveClusters(ctx, accountID, userID interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetActiveClusters", reflect.TypeOf((*MockManager)(nil).GetActiveClusters), ctx, accountID, userID)
}
// GetAllServices mocks base method.
func (m *MockManager) GetAllServices(ctx context.Context, accountID, userID string) ([]*Service, error) {
m.ctrl.T.Helper()

View File

@@ -34,6 +34,7 @@ func RegisterEndpoints(manager rpservice.Manager, domainManager domainmanager.Ma
accesslogsmanager.RegisterEndpoints(router, accessLogsManager)
router.HandleFunc("/reverse-proxies/clusters", h.getClusters).Methods("GET", "OPTIONS")
router.HandleFunc("/reverse-proxies/services", h.getAllServices).Methods("GET", "OPTIONS")
router.HandleFunc("/reverse-proxies/services", h.createService).Methods("POST", "OPTIONS")
router.HandleFunc("/reverse-proxies/services/{serviceId}", h.getService).Methods("GET", "OPTIONS")
@@ -177,3 +178,27 @@ func (h *handler) deleteService(w http.ResponseWriter, r *http.Request) {
util.WriteJSONObject(r.Context(), w, util.EmptyObject{})
}
func (h *handler) getClusters(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
clusters, err := h.manager.GetActiveClusters(r.Context(), userAuth.AccountId, userAuth.UserId)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
apiClusters := make([]api.ProxyCluster, 0, len(clusters))
for _, c := range clusters {
apiClusters = append(apiClusters, api.ProxyCluster{
Address: c.Address,
ConnectedProxies: c.ConnectedProxies,
})
}
util.WriteJSONObject(r.Context(), w, apiClusters)
}

View File

@@ -76,6 +76,7 @@ func setupL4Test(t *testing.T, customPortsSupported *bool) (*Manager, store.Stor
mockCtrl := proxy.NewMockController(ctrl)
mockCtrl.EXPECT().ClusterSupportsCustomPorts(gomock.Any()).Return(customPortsSupported).AnyTimes()
mockCtrl.EXPECT().ClusterRequireSubdomain(gomock.Any()).Return((*bool)(nil)).AnyTimes()
mockCtrl.EXPECT().SendServiceUpdateToCluster(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes()
mockCtrl.EXPECT().GetOIDCValidationConfig().Return(proxy.OIDCValidationConfig{}).AnyTimes()

View File

@@ -14,6 +14,8 @@ import (
nbpeer "github.com/netbirdio/netbird/management/server/peer"
resourcetypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/sessionkey"
@@ -100,6 +102,19 @@ func (m *Manager) StartExposeReaper(ctx context.Context) {
m.exposeReaper.StartExposeReaper(ctx)
}
// GetActiveClusters returns all active proxy clusters with their connected proxy count.
func (m *Manager) GetActiveClusters(ctx context.Context, accountID, userID string) ([]proxy.Cluster, error) {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Read)
if err != nil {
return nil, status.NewPermissionValidationError(err)
}
if !ok {
return nil, status.NewPermissionDeniedError()
}
return m.store.GetActiveProxyClusters(ctx)
}
func (m *Manager) GetAllServices(ctx context.Context, accountID, userID string) ([]*service.Service, error) {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Read)
if err != nil {
@@ -221,6 +236,10 @@ func (m *Manager) initializeServiceForCreate(ctx context.Context, accountID stri
return status.Errorf(status.PreconditionFailed, "could not derive cluster from domain %s: %v", service.Domain, err)
}
service.ProxyCluster = proxyCluster
if err := m.validateSubdomainRequirement(service.Domain, proxyCluster); err != nil {
return err
}
}
service.AccountID = accountID
@@ -246,6 +265,20 @@ func (m *Manager) initializeServiceForCreate(ctx context.Context, accountID stri
return nil
}
// validateSubdomainRequirement checks whether the domain can be used bare
// (without a subdomain label) on the given cluster. If the cluster reports
// require_subdomain=true and the domain equals the cluster domain, it rejects.
func (m *Manager) validateSubdomainRequirement(domain, cluster string) error {
if domain != cluster {
return nil
}
requireSub := m.proxyController.ClusterRequireSubdomain(cluster)
if requireSub != nil && *requireSub {
return status.Errorf(status.InvalidArgument, "domain %s requires a subdomain label", domain)
}
return nil
}
func (m *Manager) persistNewService(ctx context.Context, accountID string, svc *service.Service) error {
return m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
if svc.Domain != "" {
@@ -474,53 +507,61 @@ func (m *Manager) persistServiceUpdate(ctx context.Context, accountID string, se
var updateInfo serviceUpdateInfo
err := m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
existingService, err := transaction.GetServiceByID(ctx, store.LockingStrengthUpdate, accountID, service.ID)
if err != nil {
return err
}
if err := validateProtocolChange(existingService.Mode, service.Mode); err != nil {
return err
}
updateInfo.oldCluster = existingService.ProxyCluster
updateInfo.domainChanged = existingService.Domain != service.Domain
if updateInfo.domainChanged {
if err := m.handleDomainChange(ctx, transaction, accountID, service); err != nil {
return err
}
} else {
service.ProxyCluster = existingService.ProxyCluster
}
m.preserveExistingAuthSecrets(service, existingService)
if err := validateHeaderAuthValues(service.Auth.HeaderAuths); err != nil {
return err
}
m.preserveServiceMetadata(service, existingService)
m.preserveListenPort(service, existingService)
updateInfo.serviceEnabledChanged = existingService.Enabled != service.Enabled
if err := m.ensureL4Port(ctx, transaction, service); err != nil {
return err
}
if err := m.checkPortConflict(ctx, transaction, service); err != nil {
return err
}
if err := validateTargetReferences(ctx, transaction, accountID, service.Targets); err != nil {
return err
}
if err := transaction.UpdateService(ctx, service); err != nil {
return fmt.Errorf("update service: %w", err)
}
return nil
return m.executeServiceUpdate(ctx, transaction, accountID, service, &updateInfo)
})
return &updateInfo, err
}
func (m *Manager) executeServiceUpdate(ctx context.Context, transaction store.Store, accountID string, service *service.Service, updateInfo *serviceUpdateInfo) error {
existingService, err := transaction.GetServiceByID(ctx, store.LockingStrengthUpdate, accountID, service.ID)
if err != nil {
return err
}
if err := validateProtocolChange(existingService.Mode, service.Mode); err != nil {
return err
}
updateInfo.oldCluster = existingService.ProxyCluster
updateInfo.domainChanged = existingService.Domain != service.Domain
if updateInfo.domainChanged {
if err := m.handleDomainChange(ctx, transaction, accountID, service); err != nil {
return err
}
} else {
service.ProxyCluster = existingService.ProxyCluster
}
if err := m.validateSubdomainRequirement(service.Domain, service.ProxyCluster); err != nil {
return err
}
m.preserveExistingAuthSecrets(service, existingService)
if err := validateHeaderAuthValues(service.Auth.HeaderAuths); err != nil {
return err
}
m.preserveServiceMetadata(service, existingService)
m.preserveListenPort(service, existingService)
updateInfo.serviceEnabledChanged = existingService.Enabled != service.Enabled
if err := m.ensureL4Port(ctx, transaction, service); err != nil {
return err
}
if err := m.checkPortConflict(ctx, transaction, service); err != nil {
return err
}
if err := validateTargetReferences(ctx, transaction, accountID, service.Targets); err != nil {
return err
}
if err := transaction.UpdateService(ctx, service); err != nil {
return fmt.Errorf("update service: %w", err)
}
return nil
}
func (m *Manager) handleDomainChange(ctx context.Context, transaction store.Store, accountID string, svc *service.Service) error {
if err := m.checkDomainAvailable(ctx, transaction, svc.Domain, svc.ID); err != nil {
return err
@@ -636,18 +677,12 @@ func validateTargetReferences(ctx context.Context, transaction store.Store, acco
for _, target := range targets {
switch target.TargetType {
case service.TargetTypePeer:
if _, err := transaction.GetPeerByID(ctx, store.LockingStrengthShare, accountID, target.TargetId); err != nil {
if sErr, ok := status.FromError(err); ok && sErr.Type() == status.NotFound {
return status.Errorf(status.InvalidArgument, "peer target %q not found in account", target.TargetId)
}
return fmt.Errorf("look up peer target %q: %w", target.TargetId, err)
if err := validatePeerTarget(ctx, transaction, accountID, target); err != nil {
return err
}
case service.TargetTypeHost, service.TargetTypeSubnet, service.TargetTypeDomain:
if _, err := transaction.GetNetworkResourceByID(ctx, store.LockingStrengthShare, accountID, target.TargetId); err != nil {
if sErr, ok := status.FromError(err); ok && sErr.Type() == status.NotFound {
return status.Errorf(status.InvalidArgument, "resource target %q not found in account", target.TargetId)
}
return fmt.Errorf("look up resource target %q: %w", target.TargetId, err)
if err := validateResourceTarget(ctx, transaction, accountID, target); err != nil {
return err
}
default:
return status.Errorf(status.InvalidArgument, "unknown target type %q for target %q", target.TargetType, target.TargetId)
@@ -656,6 +691,39 @@ func validateTargetReferences(ctx context.Context, transaction store.Store, acco
return nil
}
func validatePeerTarget(ctx context.Context, transaction store.Store, accountID string, target *service.Target) error {
if _, err := transaction.GetPeerByID(ctx, store.LockingStrengthShare, accountID, target.TargetId); err != nil {
if sErr, ok := status.FromError(err); ok && sErr.Type() == status.NotFound {
return status.Errorf(status.InvalidArgument, "peer target %q not found in account", target.TargetId)
}
return fmt.Errorf("look up peer target %q: %w", target.TargetId, err)
}
return nil
}
func validateResourceTarget(ctx context.Context, transaction store.Store, accountID string, target *service.Target) error {
resource, err := transaction.GetNetworkResourceByID(ctx, store.LockingStrengthShare, accountID, target.TargetId)
if err != nil {
if sErr, ok := status.FromError(err); ok && sErr.Type() == status.NotFound {
return status.Errorf(status.InvalidArgument, "resource target %q not found in account", target.TargetId)
}
return fmt.Errorf("look up resource target %q: %w", target.TargetId, err)
}
return validateResourceTargetType(target, resource)
}
// validateResourceTargetType checks that target_type matches the actual network resource type.
func validateResourceTargetType(target *service.Target, resource *resourcetypes.NetworkResource) error {
expected := resourcetypes.NetworkResourceType(target.TargetType)
if resource.Type != expected {
return status.Errorf(status.InvalidArgument,
"target %q has target_type %q but resource is of type %q",
target.TargetId, target.TargetType, resource.Type,
)
}
return nil
}
func (m *Manager) DeleteService(ctx context.Context, accountID, userID, serviceID string) error {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Delete)
if err != nil {

View File

@@ -19,6 +19,7 @@ import (
"github.com/netbirdio/netbird/management/server/account"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/mock_server"
resourcetypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/permissions/modules"
@@ -1214,3 +1215,126 @@ func TestValidateProtocolChange(t *testing.T) {
})
}
}
func TestValidateTargetReferences_ResourceTypeMismatch(t *testing.T) {
ctx := context.Background()
ctrl := gomock.NewController(t)
mockStore := store.NewMockStore(ctrl)
accountID := "test-account"
tests := []struct {
name string
targetType rpservice.TargetType
resourceType resourcetypes.NetworkResourceType
wantErr bool
}{
{"host matches host", rpservice.TargetTypeHost, resourcetypes.Host, false},
{"domain matches domain", rpservice.TargetTypeDomain, resourcetypes.Domain, false},
{"subnet matches subnet", rpservice.TargetTypeSubnet, resourcetypes.Subnet, false},
{"host but resource is domain", rpservice.TargetTypeHost, resourcetypes.Domain, true},
{"domain but resource is host", rpservice.TargetTypeDomain, resourcetypes.Host, true},
{"host but resource is subnet", rpservice.TargetTypeHost, resourcetypes.Subnet, true},
{"subnet but resource is domain", rpservice.TargetTypeSubnet, resourcetypes.Domain, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
mockStore.EXPECT().
GetNetworkResourceByID(gomock.Any(), store.LockingStrengthShare, accountID, "resource-1").
Return(&resourcetypes.NetworkResource{Type: tt.resourceType}, nil)
targets := []*rpservice.Target{
{TargetId: "resource-1", TargetType: tt.targetType, Host: "10.0.0.1"},
}
err := validateTargetReferences(ctx, mockStore, accountID, targets)
if tt.wantErr {
require.Error(t, err)
assert.Contains(t, err.Error(), "target_type")
} else {
require.NoError(t, err)
}
})
}
}
func TestValidateTargetReferences_PeerValid(t *testing.T) {
ctx := context.Background()
ctrl := gomock.NewController(t)
mockStore := store.NewMockStore(ctrl)
accountID := "test-account"
mockStore.EXPECT().
GetPeerByID(gomock.Any(), store.LockingStrengthShare, accountID, "peer-1").
Return(&nbpeer.Peer{}, nil)
targets := []*rpservice.Target{
{TargetId: "peer-1", TargetType: rpservice.TargetTypePeer},
}
require.NoError(t, validateTargetReferences(ctx, mockStore, accountID, targets))
}
func TestValidateSubdomainRequirement(t *testing.T) {
ptrBool := func(b bool) *bool { return &b }
tests := []struct {
name string
domain string
cluster string
requireSubdomain *bool
wantErr bool
}{
{
name: "subdomain present, require_subdomain true",
domain: "app.eu1.proxy.netbird.io",
cluster: "eu1.proxy.netbird.io",
requireSubdomain: ptrBool(true),
wantErr: false,
},
{
name: "bare cluster domain, require_subdomain true",
domain: "eu1.proxy.netbird.io",
cluster: "eu1.proxy.netbird.io",
requireSubdomain: ptrBool(true),
wantErr: true,
},
{
name: "bare cluster domain, require_subdomain false",
domain: "eu1.proxy.netbird.io",
cluster: "eu1.proxy.netbird.io",
requireSubdomain: ptrBool(false),
wantErr: false,
},
{
name: "bare cluster domain, require_subdomain nil (default)",
domain: "eu1.proxy.netbird.io",
cluster: "eu1.proxy.netbird.io",
requireSubdomain: nil,
wantErr: false,
},
{
name: "custom domain apex is not the cluster",
domain: "example.com",
cluster: "eu1.proxy.netbird.io",
requireSubdomain: ptrBool(true),
wantErr: false,
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
ctrl := gomock.NewController(t)
mockCtrl := proxy.NewMockController(ctrl)
mockCtrl.EXPECT().ClusterRequireSubdomain(tc.cluster).Return(tc.requireSubdomain).AnyTimes()
mgr := &Manager{proxyController: mockCtrl}
err := mgr.validateSubdomainRequirement(tc.domain, tc.cluster)
if tc.wantErr {
require.Error(t, err)
assert.Contains(t, err.Error(), "requires a subdomain label")
} else {
require.NoError(t, err)
}
})
}
}

View File

@@ -262,7 +262,9 @@ func (s *Service) ToAPIResponse() *api.Service {
if opts == nil {
opts = &api.ServiceTargetOptions{}
}
opts.ProxyProtocol = &target.ProxyProtocol
if target.ProxyProtocol {
opts.ProxyProtocol = &target.ProxyProtocol
}
st.Options = opts
apiTargets = append(apiTargets, st)
}
@@ -790,7 +792,7 @@ func (s *Service) validateL4Target(target *Target) error {
return errors.New("target_id is required for L4 services")
}
switch target.TargetType {
case TargetTypePeer, TargetTypeHost:
case TargetTypePeer, TargetTypeHost, TargetTypeDomain:
// OK
case TargetTypeSubnet:
if target.Host == "" {
@@ -848,7 +850,7 @@ func IsPortBasedProtocol(mode string) bool {
}
const (
maxCustomHeaders = 16
maxCustomHeaders = 16
maxHeaderKeyLen = 128
maxHeaderValueLen = 4096
)
@@ -945,7 +947,6 @@ func containsCRLF(s string) bool {
}
func validateHeaderAuths(headers []*HeaderAuthConfig) error {
seen := make(map[string]struct{})
for i, h := range headers {
if h == nil || !h.Enabled {
continue
@@ -966,10 +967,6 @@ func validateHeaderAuths(headers []*HeaderAuthConfig) error {
if canonical == "Host" {
return fmt.Errorf("header_auths[%d]: Host header cannot be used for auth", i)
}
if _, dup := seen[canonical]; dup {
return fmt.Errorf("header_auths[%d]: duplicate header %q (same canonical form already configured)", i, h.Header)
}
seen[canonical] = struct{}{}
if len(h.Value) > maxHeaderValueLen {
return fmt.Errorf("header_auths[%d]: value exceeds maximum length of %d", i, maxHeaderValueLen)
}

View File

@@ -847,6 +847,32 @@ func TestValidate_TLSSubnetValid(t *testing.T) {
require.NoError(t, rp.Validate())
}
func TestValidate_L4DomainTargetValid(t *testing.T) {
modes := []struct {
mode string
port uint16
proto string
}{
{"tcp", 5432, "tcp"},
{"tls", 443, "tcp"},
{"udp", 5432, "udp"},
}
for _, m := range modes {
t.Run(m.mode, func(t *testing.T) {
rp := &Service{
Name: m.mode + "-domain",
Mode: m.mode,
Domain: "cluster.test",
ListenPort: m.port,
Targets: []*Target{
{TargetId: "resource-1", TargetType: TargetTypeDomain, Protocol: m.proto, Port: m.port, Enabled: true},
},
}
require.NoError(t, rp.Validate())
})
}
}
func TestValidate_HTTPProxyProtocolRejected(t *testing.T) {
rp := validProxy()
rp.Targets[0].ProxyProtocol = true
@@ -909,3 +935,107 @@ func TestExposeServiceRequest_Validate_HTTPAllowsAuth(t *testing.T) {
req := ExposeServiceRequest{Port: 8080, Mode: "http", Pin: "123456"}
require.NoError(t, req.Validate())
}
func TestValidate_HeaderAuths(t *testing.T) {
t.Run("single valid header", func(t *testing.T) {
rp := validProxy()
rp.Auth = AuthConfig{
HeaderAuths: []*HeaderAuthConfig{
{Enabled: true, Header: "X-API-Key", Value: "secret"},
},
}
require.NoError(t, rp.Validate())
})
t.Run("multiple headers same canonical name allowed", func(t *testing.T) {
rp := validProxy()
rp.Auth = AuthConfig{
HeaderAuths: []*HeaderAuthConfig{
{Enabled: true, Header: "Authorization", Value: "Bearer token-1"},
{Enabled: true, Header: "Authorization", Value: "Bearer token-2"},
},
}
require.NoError(t, rp.Validate())
})
t.Run("multiple headers different case same canonical allowed", func(t *testing.T) {
rp := validProxy()
rp.Auth = AuthConfig{
HeaderAuths: []*HeaderAuthConfig{
{Enabled: true, Header: "x-api-key", Value: "key-1"},
{Enabled: true, Header: "X-Api-Key", Value: "key-2"},
},
}
require.NoError(t, rp.Validate())
})
t.Run("multiple different headers allowed", func(t *testing.T) {
rp := validProxy()
rp.Auth = AuthConfig{
HeaderAuths: []*HeaderAuthConfig{
{Enabled: true, Header: "Authorization", Value: "Bearer tok"},
{Enabled: true, Header: "X-API-Key", Value: "key"},
},
}
require.NoError(t, rp.Validate())
})
t.Run("empty header name rejected", func(t *testing.T) {
rp := validProxy()
rp.Auth = AuthConfig{
HeaderAuths: []*HeaderAuthConfig{
{Enabled: true, Header: "", Value: "val"},
},
}
err := rp.Validate()
require.Error(t, err)
assert.Contains(t, err.Error(), "header name is required")
})
t.Run("hop-by-hop header rejected", func(t *testing.T) {
rp := validProxy()
rp.Auth = AuthConfig{
HeaderAuths: []*HeaderAuthConfig{
{Enabled: true, Header: "Connection", Value: "val"},
},
}
err := rp.Validate()
require.Error(t, err)
assert.Contains(t, err.Error(), "hop-by-hop")
})
t.Run("host header rejected", func(t *testing.T) {
rp := validProxy()
rp.Auth = AuthConfig{
HeaderAuths: []*HeaderAuthConfig{
{Enabled: true, Header: "Host", Value: "val"},
},
}
err := rp.Validate()
require.Error(t, err)
assert.Contains(t, err.Error(), "Host header cannot be used")
})
t.Run("disabled entries skipped", func(t *testing.T) {
rp := validProxy()
rp.Auth = AuthConfig{
HeaderAuths: []*HeaderAuthConfig{
{Enabled: false, Header: "", Value: ""},
{Enabled: true, Header: "X-Key", Value: "val"},
},
}
require.NoError(t, rp.Validate())
})
t.Run("value too long rejected", func(t *testing.T) {
rp := validProxy()
rp.Auth = AuthConfig{
HeaderAuths: []*HeaderAuthConfig{
{Enabled: true, Header: "X-Key", Value: strings.Repeat("a", maxHeaderValueLen+1)},
},
}
err := rp.Validate()
require.Error(t, err)
assert.Contains(t, err.Error(), "exceeds maximum length")
})
}

View File

@@ -123,7 +123,7 @@ func (s *ProxyServiceServer) cleanupStaleProxies(ctx context.Context) {
case <-ctx.Done():
return
case <-ticker.C:
if err := s.proxyManager.CleanupStale(ctx, 10*time.Minute); err != nil {
if err := s.proxyManager.CleanupStale(ctx, 1*time.Hour); err != nil {
log.WithContext(ctx).Debugf("Failed to cleanup stale proxies: %v", err)
}
}
@@ -215,7 +215,7 @@ func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest
go s.sender(conn, errChan)
// Start heartbeat goroutine
go s.heartbeat(connCtx, proxyID)
go s.heartbeat(connCtx, proxyID, proxyAddress, peerInfo)
select {
case err := <-errChan:
@@ -226,14 +226,14 @@ func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest
}
// heartbeat updates the proxy's last_seen timestamp every minute
func (s *ProxyServiceServer) heartbeat(ctx context.Context, proxyID string) {
func (s *ProxyServiceServer) heartbeat(ctx context.Context, proxyID, clusterAddress, ipAddress string) {
ticker := time.NewTicker(1 * time.Minute)
defer ticker.Stop()
for {
select {
case <-ticker.C:
if err := s.proxyManager.Heartbeat(ctx, proxyID); err != nil {
if err := s.proxyManager.Heartbeat(ctx, proxyID, clusterAddress, ipAddress); err != nil {
log.WithContext(ctx).Debugf("Failed to update proxy %s heartbeat: %v", proxyID, err)
}
case <-ctx.Done():
@@ -537,6 +537,35 @@ func (s *ProxyServiceServer) ClusterSupportsCustomPorts(clusterAddr string) *boo
return nil
}
// ClusterRequireSubdomain returns whether any connected proxy in the given
// cluster reports that a subdomain is required. Returns nil if no proxy has
// reported the capability (defaults to not required).
func (s *ProxyServiceServer) ClusterRequireSubdomain(clusterAddr string) *bool {
if s.proxyController == nil {
return nil
}
var hasCapabilities bool
for _, pid := range s.proxyController.GetProxiesForCluster(clusterAddr) {
connVal, ok := s.connectedProxies.Load(pid)
if !ok {
continue
}
conn := connVal.(*proxyConnection)
if conn.capabilities == nil || conn.capabilities.RequireSubdomain == nil {
continue
}
if *conn.capabilities.RequireSubdomain {
return ptr(true)
}
hasCapabilities = true
}
if hasCapabilities {
return ptr(false)
}
return nil
}
func (s *ProxyServiceServer) Authenticate(ctx context.Context, req *proto.AuthenticateRequest) (*proto.AuthenticateResponse, error) {
service, err := s.serviceManager.GetServiceByID(ctx, req.GetAccountId(), req.GetId())
if err != nil {

View File

@@ -8,6 +8,7 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
"github.com/netbirdio/netbird/management/server/types"
)
@@ -90,6 +91,10 @@ func (m *mockReverseProxyManager) StopServiceFromPeer(_ context.Context, _, _, _
func (m *mockReverseProxyManager) StartExposeReaper(_ context.Context) {}
func (m *mockReverseProxyManager) GetActiveClusters(_ context.Context, _, _ string) ([]proxy.Cluster, error) {
return nil, nil
}
type mockUsersManager struct {
users map[string]*types.User
err error

View File

@@ -57,6 +57,10 @@ func (c *testProxyController) ClusterSupportsCustomPorts(_ string) *bool {
return ptr(true)
}
func (c *testProxyController) ClusterRequireSubdomain(_ string) *bool {
return nil
}
func (c *testProxyController) GetProxiesForCluster(clusterAddr string) []string {
c.mu.Lock()
defer c.mu.Unlock()

View File

@@ -13,6 +13,7 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/sessionkey"
"github.com/netbirdio/netbird/management/server/store"
@@ -320,6 +321,10 @@ func (m *testValidateSessionServiceManager) StopServiceFromPeer(_ context.Contex
func (m *testValidateSessionServiceManager) StartExposeReaper(_ context.Context) {}
func (m *testValidateSessionServiceManager) GetActiveClusters(_ context.Context, _, _ string) ([]proxy.Cluster, error) {
return nil, nil
}
type testValidateSessionProxyManager struct{}
func (m *testValidateSessionProxyManager) Connect(_ context.Context, _, _, _ string) error {
@@ -338,6 +343,10 @@ func (m *testValidateSessionProxyManager) GetActiveClusterAddresses(_ context.Co
return nil, nil
}
func (m *testValidateSessionProxyManager) GetActiveClusters(_ context.Context) ([]proxy.Cluster, error) {
return nil, nil
}
func (m *testValidateSessionProxyManager) CleanupStale(_ context.Context, _ time.Duration) error {
return nil
}

View File

@@ -19,6 +19,7 @@ import (
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs"
nbproxy "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
"github.com/netbirdio/netbird/management/server/store"
@@ -433,6 +434,10 @@ func (m *testServiceManager) StopServiceFromPeer(_ context.Context, _, _, _ stri
func (m *testServiceManager) StartExposeReaper(_ context.Context) {}
func (m *testServiceManager) GetActiveClusters(_ context.Context, _, _ string) ([]nbproxy.Cluster, error) {
return nil, nil
}
func createTestState(t *testing.T, ps *nbgrpc.ProxyServiceServer, redirectURL string) string {
t.Helper()

View File

@@ -197,6 +197,7 @@ func NewManager(ctx context.Context, config Config, appMetrics telemetry.AppMetr
case "jumpcloud":
return NewJumpCloudManager(JumpCloudClientConfig{
APIToken: config.ExtraConfig["ApiToken"],
ApiUrl: config.ExtraConfig["ApiUrl"],
}, appMetrics)
case "pocketid":
return NewPocketIdManager(PocketIdClientConfig{

View File

@@ -1,24 +1,40 @@
package idp
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"strings"
v1 "github.com/TheJumpCloud/jcapi-go/v1"
"github.com/netbirdio/netbird/management/server/telemetry"
)
const (
contentType = "application/json"
accept = "application/json"
jumpCloudDefaultApiUrl = "https://console.jumpcloud.com"
jumpCloudSearchPageSize = 100
)
// jumpCloudUser represents a JumpCloud V1 API system user.
type jumpCloudUser struct {
ID string `json:"_id"`
Email string `json:"email"`
Firstname string `json:"firstname"`
Middlename string `json:"middlename"`
Lastname string `json:"lastname"`
}
// jumpCloudUserList represents the response from the JumpCloud search endpoint.
type jumpCloudUserList struct {
Results []jumpCloudUser `json:"results"`
TotalCount int `json:"totalCount"`
}
// JumpCloudManager JumpCloud manager client instance.
type JumpCloudManager struct {
client *v1.APIClient
apiBase string
apiToken string
httpClient ManagerHTTPClient
credentials ManagerCredentials
@@ -29,6 +45,7 @@ type JumpCloudManager struct {
// JumpCloudClientConfig JumpCloud manager client configurations.
type JumpCloudClientConfig struct {
APIToken string
ApiUrl string
}
// JumpCloudCredentials JumpCloud authentication information.
@@ -55,7 +72,15 @@ func NewJumpCloudManager(config JumpCloudClientConfig, appMetrics telemetry.AppM
return nil, fmt.Errorf("jumpCloud IdP configuration is incomplete, ApiToken is missing")
}
client := v1.NewAPIClient(v1.NewConfiguration())
apiBase := config.ApiUrl
if apiBase == "" {
apiBase = jumpCloudDefaultApiUrl
}
apiBase = strings.TrimSuffix(apiBase, "/")
if !strings.HasSuffix(apiBase, "/api") {
apiBase += "/api"
}
credentials := &JumpCloudCredentials{
clientConfig: config,
httpClient: httpClient,
@@ -64,7 +89,7 @@ func NewJumpCloudManager(config JumpCloudClientConfig, appMetrics telemetry.AppM
}
return &JumpCloudManager{
client: client,
apiBase: apiBase,
apiToken: config.APIToken,
httpClient: httpClient,
credentials: credentials,
@@ -78,37 +103,58 @@ func (jc *JumpCloudCredentials) Authenticate(_ context.Context) (JWTToken, error
return JWTToken{}, nil
}
func (jm *JumpCloudManager) authenticationContext() context.Context {
return context.WithValue(context.Background(), v1.ContextAPIKey, v1.APIKey{
Key: jm.apiToken,
})
}
// UpdateUserAppMetadata updates user app metadata based on userID and metadata map.
func (jm *JumpCloudManager) UpdateUserAppMetadata(_ context.Context, _ string, _ AppMetadata) error {
return nil
}
// GetUserDataByID requests user data from JumpCloud via ID.
func (jm *JumpCloudManager) GetUserDataByID(_ context.Context, userID string, appMetadata AppMetadata) (*UserData, error) {
authCtx := jm.authenticationContext()
user, resp, err := jm.client.SystemusersApi.SystemusersGet(authCtx, userID, contentType, accept, nil)
// doRequest executes an HTTP request against the JumpCloud V1 API.
func (jm *JumpCloudManager) doRequest(ctx context.Context, method, path string, body io.Reader) ([]byte, error) {
reqURL := jm.apiBase + path
req, err := http.NewRequestWithContext(ctx, method, reqURL, body)
if err != nil {
return nil, err
}
req.Header.Set("x-api-key", jm.apiToken)
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Accept", "application/json")
resp, err := jm.httpClient.Do(req)
if err != nil {
if jm.appMetrics != nil {
jm.appMetrics.IDPMetrics().CountRequestError()
}
return nil, err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
if jm.appMetrics != nil {
jm.appMetrics.IDPMetrics().CountRequestStatusError()
}
return nil, fmt.Errorf("unable to get user %s, statusCode %d", userID, resp.StatusCode)
return nil, fmt.Errorf("JumpCloud API request %s %s failed with status %d", method, path, resp.StatusCode)
}
return io.ReadAll(resp.Body)
}
// UpdateUserAppMetadata updates user app metadata based on userID and metadata map.
func (jm *JumpCloudManager) UpdateUserAppMetadata(_ context.Context, _ string, _ AppMetadata) error {
return nil
}
// GetUserDataByID requests user data from JumpCloud via ID.
func (jm *JumpCloudManager) GetUserDataByID(ctx context.Context, userID string, appMetadata AppMetadata) (*UserData, error) {
body, err := jm.doRequest(ctx, http.MethodGet, "/systemusers/"+userID, nil)
if err != nil {
return nil, err
}
if jm.appMetrics != nil {
jm.appMetrics.IDPMetrics().CountGetUserDataByID()
}
var user jumpCloudUser
if err = jm.helper.Unmarshal(body, &user); err != nil {
return nil, err
}
userData := parseJumpCloudUser(user)
userData.AppMetadata = appMetadata
@@ -116,30 +162,20 @@ func (jm *JumpCloudManager) GetUserDataByID(_ context.Context, userID string, ap
}
// GetAccount returns all the users for a given profile.
func (jm *JumpCloudManager) GetAccount(_ context.Context, accountID string) ([]*UserData, error) {
authCtx := jm.authenticationContext()
userList, resp, err := jm.client.SearchApi.SearchSystemusersPost(authCtx, contentType, accept, nil)
func (jm *JumpCloudManager) GetAccount(ctx context.Context, accountID string) ([]*UserData, error) {
allUsers, err := jm.searchAllUsers(ctx)
if err != nil {
return nil, err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
if jm.appMetrics != nil {
jm.appMetrics.IDPMetrics().CountRequestStatusError()
}
return nil, fmt.Errorf("unable to get account %s users, statusCode %d", accountID, resp.StatusCode)
}
if jm.appMetrics != nil {
jm.appMetrics.IDPMetrics().CountGetAccount()
}
users := make([]*UserData, 0)
for _, user := range userList.Results {
users := make([]*UserData, 0, len(allUsers))
for _, user := range allUsers {
userData := parseJumpCloudUser(user)
userData.AppMetadata.WTAccountID = accountID
users = append(users, userData)
}
@@ -148,27 +184,18 @@ func (jm *JumpCloudManager) GetAccount(_ context.Context, accountID string) ([]*
// GetAllAccounts gets all registered accounts with corresponding user data.
// It returns a list of users indexed by accountID.
func (jm *JumpCloudManager) GetAllAccounts(_ context.Context) (map[string][]*UserData, error) {
authCtx := jm.authenticationContext()
userList, resp, err := jm.client.SearchApi.SearchSystemusersPost(authCtx, contentType, accept, nil)
func (jm *JumpCloudManager) GetAllAccounts(ctx context.Context) (map[string][]*UserData, error) {
allUsers, err := jm.searchAllUsers(ctx)
if err != nil {
return nil, err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
if jm.appMetrics != nil {
jm.appMetrics.IDPMetrics().CountRequestStatusError()
}
return nil, fmt.Errorf("unable to get all accounts, statusCode %d", resp.StatusCode)
}
if jm.appMetrics != nil {
jm.appMetrics.IDPMetrics().CountGetAllAccounts()
}
indexedUsers := make(map[string][]*UserData)
for _, user := range userList.Results {
for _, user := range allUsers {
userData := parseJumpCloudUser(user)
indexedUsers[UnsetAccountID] = append(indexedUsers[UnsetAccountID], userData)
}
@@ -176,6 +203,41 @@ func (jm *JumpCloudManager) GetAllAccounts(_ context.Context) (map[string][]*Use
return indexedUsers, nil
}
// searchAllUsers paginates through all system users using limit/skip.
func (jm *JumpCloudManager) searchAllUsers(ctx context.Context) ([]jumpCloudUser, error) {
var allUsers []jumpCloudUser
for skip := 0; ; skip += jumpCloudSearchPageSize {
searchReq := map[string]int{
"limit": jumpCloudSearchPageSize,
"skip": skip,
}
payload, err := json.Marshal(searchReq)
if err != nil {
return nil, err
}
body, err := jm.doRequest(ctx, http.MethodPost, "/search/systemusers", bytes.NewReader(payload))
if err != nil {
return nil, err
}
var userList jumpCloudUserList
if err = jm.helper.Unmarshal(body, &userList); err != nil {
return nil, err
}
allUsers = append(allUsers, userList.Results...)
if skip+len(userList.Results) >= userList.TotalCount {
break
}
}
return allUsers, nil
}
// CreateUser creates a new user in JumpCloud Idp and sends an invitation.
func (jm *JumpCloudManager) CreateUser(_ context.Context, _, _, _, _ string) (*UserData, error) {
return nil, fmt.Errorf("method CreateUser not implemented")
@@ -183,7 +245,7 @@ func (jm *JumpCloudManager) CreateUser(_ context.Context, _, _, _, _ string) (*U
// GetUserByEmail searches users with a given email.
// If no users have been found, this function returns an empty list.
func (jm *JumpCloudManager) GetUserByEmail(_ context.Context, email string) ([]*UserData, error) {
func (jm *JumpCloudManager) GetUserByEmail(ctx context.Context, email string) ([]*UserData, error) {
searchFilter := map[string]interface{}{
"searchFilter": map[string]interface{}{
"filter": []string{email},
@@ -191,25 +253,26 @@ func (jm *JumpCloudManager) GetUserByEmail(_ context.Context, email string) ([]*
},
}
authCtx := jm.authenticationContext()
userList, resp, err := jm.client.SearchApi.SearchSystemusersPost(authCtx, contentType, accept, searchFilter)
payload, err := json.Marshal(searchFilter)
if err != nil {
return nil, err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
if jm.appMetrics != nil {
jm.appMetrics.IDPMetrics().CountRequestStatusError()
}
return nil, fmt.Errorf("unable to get user %s, statusCode %d", email, resp.StatusCode)
body, err := jm.doRequest(ctx, http.MethodPost, "/search/systemusers", bytes.NewReader(payload))
if err != nil {
return nil, err
}
if jm.appMetrics != nil {
jm.appMetrics.IDPMetrics().CountGetUserByEmail()
}
usersData := make([]*UserData, 0)
var userList jumpCloudUserList
if err = jm.helper.Unmarshal(body, &userList); err != nil {
return nil, err
}
usersData := make([]*UserData, 0, len(userList.Results))
for _, user := range userList.Results {
usersData = append(usersData, parseJumpCloudUser(user))
}
@@ -224,20 +287,11 @@ func (jm *JumpCloudManager) InviteUserByID(_ context.Context, _ string) error {
}
// DeleteUser from jumpCloud directory
func (jm *JumpCloudManager) DeleteUser(_ context.Context, userID string) error {
authCtx := jm.authenticationContext()
_, resp, err := jm.client.SystemusersApi.SystemusersDelete(authCtx, userID, contentType, accept, nil)
func (jm *JumpCloudManager) DeleteUser(ctx context.Context, userID string) error {
_, err := jm.doRequest(ctx, http.MethodDelete, "/systemusers/"+userID, nil)
if err != nil {
return err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
if jm.appMetrics != nil {
jm.appMetrics.IDPMetrics().CountRequestStatusError()
}
return fmt.Errorf("unable to delete user, statusCode %d", resp.StatusCode)
}
if jm.appMetrics != nil {
jm.appMetrics.IDPMetrics().CountDeleteUser()
@@ -247,11 +301,11 @@ func (jm *JumpCloudManager) DeleteUser(_ context.Context, userID string) error {
}
// parseJumpCloudUser parse JumpCloud system user returned from API V1 to UserData.
func parseJumpCloudUser(user v1.Systemuserreturn) *UserData {
func parseJumpCloudUser(user jumpCloudUser) *UserData {
names := []string{user.Firstname, user.Middlename, user.Lastname}
return &UserData{
Email: user.Email,
Name: strings.Join(names, " "),
ID: user.Id,
ID: user.ID,
}
}

View File

@@ -1,8 +1,15 @@
package idp
import (
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"net/http/httptest"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/management/server/telemetry"
@@ -44,3 +51,212 @@ func TestNewJumpCloudManager(t *testing.T) {
})
}
}
func TestJumpCloudGetUserDataByID(t *testing.T) {
userResponse := jumpCloudUser{
ID: "user123",
Email: "test@example.com",
Firstname: "John",
Middlename: "",
Lastname: "Doe",
}
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "/systemusers/user123", r.URL.Path)
assert.Equal(t, http.MethodGet, r.Method)
assert.Equal(t, "test-api-key", r.Header.Get("x-api-key"))
w.Header().Set("Content-Type", "application/json")
_ = json.NewEncoder(w).Encode(userResponse)
}))
defer server.Close()
manager := newTestJumpCloudManager(t, server.URL)
userData, err := manager.GetUserDataByID(context.Background(), "user123", AppMetadata{WTAccountID: "acc1"})
require.NoError(t, err)
assert.Equal(t, "user123", userData.ID)
assert.Equal(t, "test@example.com", userData.Email)
assert.Equal(t, "John Doe", userData.Name)
assert.Equal(t, "acc1", userData.AppMetadata.WTAccountID)
}
func TestJumpCloudGetAccount(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "/search/systemusers", r.URL.Path)
assert.Equal(t, http.MethodPost, r.Method)
var reqBody map[string]any
assert.NoError(t, json.NewDecoder(r.Body).Decode(&reqBody))
assert.Contains(t, reqBody, "limit")
assert.Contains(t, reqBody, "skip")
resp := jumpCloudUserList{
Results: []jumpCloudUser{
{ID: "u1", Email: "a@test.com", Firstname: "Alice", Lastname: "Smith"},
{ID: "u2", Email: "b@test.com", Firstname: "Bob", Lastname: "Jones"},
},
TotalCount: 2,
}
w.Header().Set("Content-Type", "application/json")
_ = json.NewEncoder(w).Encode(resp)
}))
defer server.Close()
manager := newTestJumpCloudManager(t, server.URL)
users, err := manager.GetAccount(context.Background(), "testAccount")
require.NoError(t, err)
assert.Len(t, users, 2)
assert.Equal(t, "testAccount", users[0].AppMetadata.WTAccountID)
assert.Equal(t, "testAccount", users[1].AppMetadata.WTAccountID)
}
func TestJumpCloudGetAllAccounts(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
resp := jumpCloudUserList{
Results: []jumpCloudUser{
{ID: "u1", Email: "a@test.com", Firstname: "Alice"},
{ID: "u2", Email: "b@test.com", Firstname: "Bob"},
},
TotalCount: 2,
}
w.Header().Set("Content-Type", "application/json")
_ = json.NewEncoder(w).Encode(resp)
}))
defer server.Close()
manager := newTestJumpCloudManager(t, server.URL)
indexedUsers, err := manager.GetAllAccounts(context.Background())
require.NoError(t, err)
assert.Len(t, indexedUsers[UnsetAccountID], 2)
}
func TestJumpCloudGetAllAccountsPagination(t *testing.T) {
totalUsers := 250
allUsers := make([]jumpCloudUser, totalUsers)
for i := range allUsers {
allUsers[i] = jumpCloudUser{
ID: fmt.Sprintf("u%d", i),
Email: fmt.Sprintf("user%d@test.com", i),
Firstname: fmt.Sprintf("User%d", i),
}
}
requestCount := 0
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
var reqBody map[string]int
assert.NoError(t, json.NewDecoder(r.Body).Decode(&reqBody))
limit := reqBody["limit"]
skip := reqBody["skip"]
requestCount++
end := skip + limit
if end > totalUsers {
end = totalUsers
}
resp := jumpCloudUserList{
Results: allUsers[skip:end],
TotalCount: totalUsers,
}
w.Header().Set("Content-Type", "application/json")
_ = json.NewEncoder(w).Encode(resp)
}))
defer server.Close()
manager := newTestJumpCloudManager(t, server.URL)
indexedUsers, err := manager.GetAllAccounts(context.Background())
require.NoError(t, err)
assert.Len(t, indexedUsers[UnsetAccountID], totalUsers)
assert.Equal(t, 3, requestCount, "should require 3 pages for 250 users at page size 100")
}
func TestJumpCloudGetUserByEmail(t *testing.T) {
searchResponse := jumpCloudUserList{
Results: []jumpCloudUser{
{ID: "u1", Email: "alice@test.com", Firstname: "Alice", Lastname: "Smith"},
},
TotalCount: 1,
}
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "/search/systemusers", r.URL.Path)
assert.Equal(t, http.MethodPost, r.Method)
body, err := io.ReadAll(r.Body)
assert.NoError(t, err)
assert.Contains(t, string(body), "alice@test.com")
w.Header().Set("Content-Type", "application/json")
_ = json.NewEncoder(w).Encode(searchResponse)
}))
defer server.Close()
manager := newTestJumpCloudManager(t, server.URL)
users, err := manager.GetUserByEmail(context.Background(), "alice@test.com")
require.NoError(t, err)
assert.Len(t, users, 1)
assert.Equal(t, "alice@test.com", users[0].Email)
}
func TestJumpCloudDeleteUser(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "/systemusers/user123", r.URL.Path)
assert.Equal(t, http.MethodDelete, r.Method)
assert.Equal(t, "test-api-key", r.Header.Get("x-api-key"))
w.Header().Set("Content-Type", "application/json")
_ = json.NewEncoder(w).Encode(map[string]string{"_id": "user123"})
}))
defer server.Close()
manager := newTestJumpCloudManager(t, server.URL)
err := manager.DeleteUser(context.Background(), "user123")
require.NoError(t, err)
}
func TestJumpCloudAPIError(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusUnauthorized)
}))
defer server.Close()
manager := newTestJumpCloudManager(t, server.URL)
_, err := manager.GetUserDataByID(context.Background(), "user123", AppMetadata{})
require.Error(t, err)
assert.Contains(t, err.Error(), "401")
}
func TestParseJumpCloudUser(t *testing.T) {
user := jumpCloudUser{
ID: "abc123",
Email: "test@example.com",
Firstname: "John",
Middlename: "M",
Lastname: "Doe",
}
userData := parseJumpCloudUser(user)
assert.Equal(t, "abc123", userData.ID)
assert.Equal(t, "test@example.com", userData.Email)
assert.Equal(t, "John M Doe", userData.Name)
}
func newTestJumpCloudManager(t *testing.T, apiBase string) *JumpCloudManager {
t.Helper()
return &JumpCloudManager{
apiBase: apiBase,
apiToken: "test-api-key",
httpClient: http.DefaultClient,
helper: JsonParser{},
appMetrics: nil,
}
}

View File

@@ -249,7 +249,7 @@ func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, user
if err != nil {
newLabel = ""
} else {
_, err := transaction.GetPeerIdByLabel(ctx, store.LockingStrengthNone, accountID, update.Name)
_, err := transaction.GetPeerIdByLabel(ctx, store.LockingStrengthNone, accountID, newLabel)
if err == nil {
newLabel = ""
}

View File

@@ -37,6 +37,7 @@ import (
"github.com/netbirdio/netbird/management/server/job"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/settings"
"github.com/netbirdio/netbird/shared/auth"
"github.com/netbirdio/netbird/shared/management/status"
"github.com/netbirdio/netbird/management/server/util"
@@ -2738,3 +2739,70 @@ func TestProcessPeerAddAuth(t *testing.T) {
assert.Empty(t, config.GroupsToAdd)
})
}
func TestUpdatePeer_DnsLabelCollisionWithFQDN(t *testing.T) {
manager, _, err := createManager(t)
require.NoError(t, err, "unable to create account manager")
accountID, err := manager.GetAccountIDByUserID(context.Background(), auth.UserAuth{UserId: userID})
require.NoError(t, err, "unable to create an account")
// Add first peer with hostname that produces DNS label "netbird1"
key1, err := wgtypes.GenerateKey()
require.NoError(t, err)
peer1, _, _, err := manager.AddPeer(context.Background(), "", "", userID, &nbpeer.Peer{
Key: key1.PublicKey().String(),
Meta: nbpeer.PeerSystemMeta{Hostname: "netbird1.netbird.cloud"},
}, false)
require.NoError(t, err, "unable to add first peer")
assert.Equal(t, "netbird1", peer1.DNSLabel)
// Add second peer with a different hostname
key2, err := wgtypes.GenerateKey()
require.NoError(t, err)
peer2, _, _, err := manager.AddPeer(context.Background(), "", "", userID, &nbpeer.Peer{
Key: key2.PublicKey().String(),
Meta: nbpeer.PeerSystemMeta{Hostname: "ip-10-29-5-130"},
}, false)
require.NoError(t, err)
update := peer2.Copy()
update.Name = "netbird1.demo.netbird.cloud"
updated, err := manager.UpdatePeer(context.Background(), accountID, userID, update)
require.NoError(t, err, "renaming peer should not fail with duplicate DNS label error")
assert.Equal(t, "netbird1.demo.netbird.cloud", updated.Name)
assert.NotEqual(t, "netbird1", updated.DNSLabel, "DNS label should not collide with existing peer")
assert.Contains(t, updated.DNSLabel, "netbird1-", "DNS label should be IP-based fallback")
}
func TestUpdatePeer_DnsLabelUniqueName(t *testing.T) {
manager, _, err := createManager(t)
require.NoError(t, err, "unable to create account manager")
accountID, err := manager.GetAccountIDByUserID(context.Background(), auth.UserAuth{UserId: userID})
require.NoError(t, err, "unable to create an account")
key1, err := wgtypes.GenerateKey()
require.NoError(t, err)
peer1, _, _, err := manager.AddPeer(context.Background(), "", "", userID, &nbpeer.Peer{
Key: key1.PublicKey().String(),
Meta: nbpeer.PeerSystemMeta{Hostname: "web-server"},
}, false)
require.NoError(t, err)
assert.Equal(t, "web-server", peer1.DNSLabel)
// Add second peer and rename it to a unique FQDN whose first label doesn't collide
key2, err := wgtypes.GenerateKey()
require.NoError(t, err)
peer2, _, _, err := manager.AddPeer(context.Background(), "", "", userID, &nbpeer.Peer{
Key: key2.PublicKey().String(),
Meta: nbpeer.PeerSystemMeta{Hostname: "old-name"},
}, false)
require.NoError(t, err)
update := peer2.Copy()
update.Name = "api-server.example.com"
updated, err := manager.UpdatePeer(context.Background(), accountID, userID, update)
require.NoError(t, err, "renaming to unique FQDN should succeed")
assert.Equal(t, "api-server", updated.DNSLabel, "DNS label should be first label of FQDN")
}

View File

@@ -4997,7 +4997,6 @@ func (s *SqlStore) GetServiceByDomain(ctx context.Context, domain string) (*rpse
return service, nil
}
func (s *SqlStore) GetServices(ctx context.Context, lockStrength LockingStrength) ([]*rpservice.Service, error) {
tx := s.db.Preload("Targets")
if lockStrength != LockingStrengthNone {
@@ -5408,17 +5407,35 @@ func (s *SqlStore) SaveProxy(ctx context.Context, p *proxy.Proxy) error {
return nil
}
// UpdateProxyHeartbeat updates the last_seen timestamp for a proxy
func (s *SqlStore) UpdateProxyHeartbeat(ctx context.Context, proxyID string) error {
// UpdateProxyHeartbeat updates the last_seen timestamp for a proxy or creates a new entry if it doesn't exist
func (s *SqlStore) UpdateProxyHeartbeat(ctx context.Context, proxyID, clusterAddress, ipAddress string) error {
now := time.Now()
result := s.db.WithContext(ctx).
Model(&proxy.Proxy{}).
Where("id = ? AND status = ?", proxyID, "connected").
Update("last_seen", time.Now())
Update("last_seen", now)
if result.Error != nil {
log.WithContext(ctx).Errorf("failed to update proxy heartbeat: %v", result.Error)
return status.Errorf(status.Internal, "failed to update proxy heartbeat")
}
if result.RowsAffected == 0 {
p := &proxy.Proxy{
ID: proxyID,
ClusterAddress: clusterAddress,
IPAddress: ipAddress,
LastSeen: now,
ConnectedAt: &now,
Status: "connected",
}
if err := s.db.WithContext(ctx).Save(p).Error; err != nil {
log.WithContext(ctx).Errorf("failed to create proxy on heartbeat: %v", err)
return status.Errorf(status.Internal, "failed to create proxy on heartbeat")
}
}
return nil
}
@@ -5440,6 +5457,24 @@ func (s *SqlStore) GetActiveProxyClusterAddresses(ctx context.Context) ([]string
return addresses, nil
}
// GetActiveProxyClusters returns all active proxy clusters with their connected proxy count.
func (s *SqlStore) GetActiveProxyClusters(ctx context.Context) ([]proxy.Cluster, error) {
var clusters []proxy.Cluster
result := s.db.Model(&proxy.Proxy{}).
Select("cluster_address as address, COUNT(*) as connected_proxies").
Where("status = ? AND last_seen > ?", "connected", time.Now().Add(-2*time.Minute)).
Group("cluster_address").
Scan(&clusters)
if result.Error != nil {
log.WithContext(ctx).Errorf("failed to get active proxy clusters: %v", result.Error)
return nil, status.Errorf(status.Internal, "get active proxy clusters")
}
return clusters, nil
}
// CleanupStaleProxies deletes proxies that haven't sent heartbeat in the specified duration
func (s *SqlStore) CleanupStaleProxies(ctx context.Context, inactivityDuration time.Duration) error {
cutoffTime := time.Now().Add(-inactivityDuration)

View File

@@ -284,8 +284,9 @@ type Store interface {
DeleteServiceTargets(ctx context.Context, accountID string, serviceID string) error
SaveProxy(ctx context.Context, proxy *proxy.Proxy) error
UpdateProxyHeartbeat(ctx context.Context, proxyID string) error
UpdateProxyHeartbeat(ctx context.Context, proxyID, clusterAddress, ipAddress string) error
GetActiveProxyClusterAddresses(ctx context.Context) ([]string, error)
GetActiveProxyClusters(ctx context.Context) ([]proxy.Cluster, error)
CleanupStaleProxies(ctx context.Context, inactivityDuration time.Duration) error
GetCustomDomainsCounts(ctx context.Context) (total int64, validated int64, err error)

View File

@@ -1287,6 +1287,21 @@ func (mr *MockStoreMockRecorder) GetActiveProxyClusterAddresses(ctx interface{})
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetActiveProxyClusterAddresses", reflect.TypeOf((*MockStore)(nil).GetActiveProxyClusterAddresses), ctx)
}
// GetActiveProxyClusters mocks base method.
func (m *MockStore) GetActiveProxyClusters(ctx context.Context) ([]proxy.Cluster, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetActiveProxyClusters", ctx)
ret0, _ := ret[0].([]proxy.Cluster)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetActiveProxyClusters indicates an expected call of GetActiveProxyClusters.
func (mr *MockStoreMockRecorder) GetActiveProxyClusters(ctx interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetActiveProxyClusters", reflect.TypeOf((*MockStore)(nil).GetActiveProxyClusters), ctx)
}
// GetAllAccounts mocks base method.
func (m *MockStore) GetAllAccounts(ctx context.Context) []*types2.Account {
m.ctrl.T.Helper()
@@ -2924,17 +2939,17 @@ func (mr *MockStoreMockRecorder) UpdateGroups(ctx, accountID, groups interface{}
}
// UpdateProxyHeartbeat mocks base method.
func (m *MockStore) UpdateProxyHeartbeat(ctx context.Context, proxyID string) error {
func (m *MockStore) UpdateProxyHeartbeat(ctx context.Context, proxyID, clusterAddress, ipAddress string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "UpdateProxyHeartbeat", ctx, proxyID)
ret := m.ctrl.Call(m, "UpdateProxyHeartbeat", ctx, proxyID, clusterAddress, ipAddress)
ret0, _ := ret[0].(error)
return ret0
}
// UpdateProxyHeartbeat indicates an expected call of UpdateProxyHeartbeat.
func (mr *MockStoreMockRecorder) UpdateProxyHeartbeat(ctx, proxyID interface{}) *gomock.Call {
func (mr *MockStoreMockRecorder) UpdateProxyHeartbeat(ctx, proxyID, clusterAddress, ipAddress interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateProxyHeartbeat", reflect.TypeOf((*MockStore)(nil).UpdateProxyHeartbeat), ctx, proxyID)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateProxyHeartbeat", reflect.TypeOf((*MockStore)(nil).UpdateProxyHeartbeat), ctx, proxyID, clusterAddress, ipAddress)
}
// UpdateService mocks base method.

Some files were not shown because too many files have changed in this diff Show More