Compare commits
5 Commits
feature/ad
...
improve-ip
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
7ebee7c126 | ||
|
|
24cf83b862 | ||
|
|
54958100f5 | ||
|
|
449036b66d | ||
|
|
17d6fca49c |
1
.github/workflows/golang-test-linux.yml
vendored
1
.github/workflows/golang-test-linux.yml
vendored
@@ -13,7 +13,6 @@ concurrency:
|
||||
jobs:
|
||||
test:
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
arch: [ '386','amd64' ]
|
||||
store: [ 'sqlite', 'postgres']
|
||||
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
"github.com/pion/transport/v3"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/net/ipv4"
|
||||
"golang.org/x/net/ipv6"
|
||||
wgConn "golang.zx2c4.com/wireguard/conn"
|
||||
)
|
||||
|
||||
@@ -24,8 +25,8 @@ type receiverCreator struct {
|
||||
iceBind *ICEBind
|
||||
}
|
||||
|
||||
func (rc receiverCreator) CreateIPv4ReceiverFn(msgPool *sync.Pool, pc *ipv4.PacketConn, conn *net.UDPConn) wgConn.ReceiveFunc {
|
||||
return rc.iceBind.createIPv4ReceiverFn(msgPool, pc, conn)
|
||||
func (rc receiverCreator) CreateIPv4ReceiverFn(pc *ipv4.PacketConn, conn *net.UDPConn, rxOffload bool, msgPool *sync.Pool) wgConn.ReceiveFunc {
|
||||
return rc.iceBind.createIPv4ReceiverFn(pc, conn, rxOffload, msgPool)
|
||||
}
|
||||
|
||||
// ICEBind is a bind implementation with two main features:
|
||||
@@ -82,6 +83,10 @@ func (s *ICEBind) Open(uport uint16) ([]wgConn.ReceiveFunc, uint16, error) {
|
||||
return nil, 0, err
|
||||
}
|
||||
fns = append(fns, s.receiveRelayed)
|
||||
|
||||
s.muUDPMux.Lock()
|
||||
s.udpMux = s.createUDPMux()
|
||||
s.muUDPMux.Unlock()
|
||||
return fns, port, nil
|
||||
}
|
||||
|
||||
@@ -154,28 +159,32 @@ func (b *ICEBind) Send(bufs [][]byte, ep wgConn.Endpoint) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *ICEBind) createIPv4ReceiverFn(ipv4MsgsPool *sync.Pool, pc *ipv4.PacketConn, conn *net.UDPConn) wgConn.ReceiveFunc {
|
||||
s.muUDPMux.Lock()
|
||||
defer s.muUDPMux.Unlock()
|
||||
|
||||
s.udpMux = NewUniversalUDPMuxDefault(
|
||||
UniversalUDPMuxParams{
|
||||
UDPConn: conn,
|
||||
Net: s.transportNet,
|
||||
FilterFn: s.filterFn,
|
||||
},
|
||||
)
|
||||
func (s *ICEBind) createIPv4ReceiverFn(pc *ipv4.PacketConn, conn *net.UDPConn, rxOffload bool, msgsPool *sync.Pool) wgConn.ReceiveFunc {
|
||||
return func(bufs [][]byte, sizes []int, eps []wgConn.Endpoint) (n int, err error) {
|
||||
msgs := ipv4MsgsPool.Get().(*[]ipv4.Message)
|
||||
defer ipv4MsgsPool.Put(msgs)
|
||||
msgs := getMessages(msgsPool)
|
||||
for i := range bufs {
|
||||
(*msgs)[i].Buffers[0] = bufs[i]
|
||||
(*msgs)[i].OOB = (*msgs)[i].OOB[:cap((*msgs)[i].OOB)]
|
||||
}
|
||||
defer putMessages(msgs, msgsPool)
|
||||
var numMsgs int
|
||||
if runtime.GOOS == "linux" {
|
||||
numMsgs, err = pc.ReadBatch(*msgs, 0)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
if runtime.GOOS == "linux" || runtime.GOOS == "android" {
|
||||
if rxOffload {
|
||||
readAt := len(*msgs) - (wgConn.IdealBatchSize / wgConn.UdpSegmentMaxDatagrams)
|
||||
//nolint
|
||||
numMsgs, err = pc.ReadBatch((*msgs)[readAt:], 0)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
numMsgs, err = wgConn.SplitCoalescedMessages(*msgs, readAt, wgConn.GetGSOSize)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
} else {
|
||||
numMsgs, err = pc.ReadBatch(*msgs, 0)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
}
|
||||
} else {
|
||||
msg := &(*msgs)[0]
|
||||
@@ -191,11 +200,12 @@ func (s *ICEBind) createIPv4ReceiverFn(ipv4MsgsPool *sync.Pool, pc *ipv4.PacketC
|
||||
// todo: handle err
|
||||
ok, _ := s.filterOutStunMessages(msg.Buffers, msg.N, msg.Addr)
|
||||
if ok {
|
||||
sizes[i] = 0
|
||||
} else {
|
||||
sizes[i] = msg.N
|
||||
continue
|
||||
}
|
||||
sizes[i] = msg.N
|
||||
if sizes[i] == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
addrPort := msg.Addr.(*net.UDPAddr).AddrPort()
|
||||
ep := &wgConn.StdNetEndpoint{AddrPort: addrPort} // TODO: remove allocation
|
||||
wgConn.GetSrcFromControl(msg.OOB[:msg.NN], ep)
|
||||
@@ -241,14 +251,14 @@ func (s *ICEBind) parseSTUNMessage(raw []byte) (*stun.Message, error) {
|
||||
|
||||
// receiveRelayed is a receive function that is used to receive packets from the relayed connection and forward to the
|
||||
// WireGuard. Critical part is do not block if the Closed() has been called.
|
||||
func (c *ICEBind) receiveRelayed(buffs [][]byte, sizes []int, eps []wgConn.Endpoint) (int, error) {
|
||||
c.closedChanMu.RLock()
|
||||
defer c.closedChanMu.RUnlock()
|
||||
func (s *ICEBind) receiveRelayed(buffs [][]byte, sizes []int, eps []wgConn.Endpoint) (int, error) {
|
||||
s.closedChanMu.RLock()
|
||||
defer s.closedChanMu.RUnlock()
|
||||
|
||||
select {
|
||||
case <-c.closedChan:
|
||||
case <-s.closedChan:
|
||||
return 0, net.ErrClosed
|
||||
case msg, ok := <-c.RecvChan:
|
||||
case msg, ok := <-s.RecvChan:
|
||||
if !ok {
|
||||
return 0, net.ErrClosed
|
||||
}
|
||||
@@ -259,6 +269,16 @@ func (c *ICEBind) receiveRelayed(buffs [][]byte, sizes []int, eps []wgConn.Endpo
|
||||
}
|
||||
}
|
||||
|
||||
func (s *ICEBind) createUDPMux() *UniversalUDPMuxDefault {
|
||||
return NewUniversalUDPMuxDefault(
|
||||
UniversalUDPMuxParams{
|
||||
UDPConn: s.StdNetBind.IPv4Conn(),
|
||||
Net: s.transportNet,
|
||||
FilterFn: s.filterFn,
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
// fakeAddress returns a fake address that is used to as an identifier for the peer.
|
||||
// The fake address is in the format of 127.1.x.x where x.x is the last two octets of the peer address.
|
||||
func fakeAddress(peerAddress *net.UDPAddr) (*net.UDPAddr, error) {
|
||||
@@ -273,3 +293,15 @@ func fakeAddress(peerAddress *net.UDPAddr) (*net.UDPAddr, error) {
|
||||
}
|
||||
return newAddr, nil
|
||||
}
|
||||
|
||||
func getMessages(msgsPool *sync.Pool) *[]ipv6.Message {
|
||||
return msgsPool.Get().(*[]ipv6.Message)
|
||||
}
|
||||
|
||||
func putMessages(msgs *[]ipv6.Message, msgsPool *sync.Pool) {
|
||||
for i := range *msgs {
|
||||
(*msgs)[i].OOB = (*msgs)[i].OOB[:0]
|
||||
(*msgs)[i] = ipv6.Message{Buffers: (*msgs)[i].Buffers, OOB: (*msgs)[i].OOB}
|
||||
}
|
||||
msgsPool.Put(msgs)
|
||||
}
|
||||
|
||||
@@ -104,8 +104,8 @@ func (p *ProxyBind) proxyToLocal(ctx context.Context) {
|
||||
}
|
||||
}()
|
||||
|
||||
buf := make([]byte, 1500)
|
||||
for {
|
||||
buf := make([]byte, 1500)
|
||||
n, err := p.remoteConn.Read(buf)
|
||||
if err != nil {
|
||||
if ctx.Err() != nil {
|
||||
|
||||
@@ -309,11 +309,6 @@ func (conn *Conn) iCEConnectionIsReady(priority ConnPriority, iceConnInfo ICECon
|
||||
return
|
||||
}
|
||||
|
||||
if remoteConnNil(conn.log, iceConnInfo.RemoteConn) {
|
||||
conn.log.Errorf("remote ICE connection is nil")
|
||||
return
|
||||
}
|
||||
|
||||
conn.log.Debugf("ICE connection is ready")
|
||||
|
||||
if conn.currentConnPriority > priority {
|
||||
|
||||
@@ -1,21 +0,0 @@
|
||||
package peer
|
||||
|
||||
import (
|
||||
"net"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
func remoteConnNil(log *log.Entry, conn net.Conn) bool {
|
||||
if conn == nil {
|
||||
log.Errorf("ice conn is nil")
|
||||
return true
|
||||
}
|
||||
|
||||
if conn.RemoteAddr() == nil {
|
||||
log.Errorf("ICE remote address is nil")
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
@@ -1,7 +0,0 @@
|
||||
//go:build !windows
|
||||
|
||||
package server
|
||||
|
||||
func handlePanicLog() error {
|
||||
return nil
|
||||
}
|
||||
@@ -1,83 +0,0 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"syscall"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/util"
|
||||
)
|
||||
|
||||
const (
|
||||
windowsPanicLogEnvVar = "NB_WINDOWS_PANIC_LOG"
|
||||
// STD_ERROR_HANDLE ((DWORD)-12) = 4294967284
|
||||
stdErrorHandle = ^uintptr(11)
|
||||
)
|
||||
|
||||
var (
|
||||
kernel32 = syscall.NewLazyDLL("kernel32.dll")
|
||||
|
||||
// https://learn.microsoft.com/en-us/windows/console/setstdhandle
|
||||
setStdHandleFn = kernel32.NewProc("SetStdHandle")
|
||||
)
|
||||
|
||||
func handlePanicLog() error {
|
||||
logPath := os.Getenv(windowsPanicLogEnvVar)
|
||||
if logPath == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Ensure the directory exists
|
||||
logDir := filepath.Dir(logPath)
|
||||
if err := os.MkdirAll(logDir, 0750); err != nil {
|
||||
return fmt.Errorf("create panic log directory: %w", err)
|
||||
}
|
||||
if err := util.EnforcePermission(logPath); err != nil {
|
||||
return fmt.Errorf("enforce permission on panic log file: %w", err)
|
||||
}
|
||||
|
||||
// Open log file with append mode
|
||||
f, err := os.OpenFile(logPath, os.O_WRONLY|os.O_CREATE|os.O_APPEND, 0644)
|
||||
if err != nil {
|
||||
return fmt.Errorf("open panic log file: %w", err)
|
||||
}
|
||||
|
||||
// Redirect stderr to the file
|
||||
if err = redirectStderr(f); err != nil {
|
||||
if closeErr := f.Close(); closeErr != nil {
|
||||
log.Warnf("failed to close file after redirect error: %v", closeErr)
|
||||
}
|
||||
return fmt.Errorf("redirect stderr: %w", err)
|
||||
}
|
||||
|
||||
log.Infof("successfully configured panic logging to: %s", logPath)
|
||||
return nil
|
||||
}
|
||||
|
||||
// redirectStderr redirects stderr to the provided file
|
||||
func redirectStderr(f *os.File) error {
|
||||
// Get the current process's stderr handle
|
||||
if err := setStdHandle(f); err != nil {
|
||||
return fmt.Errorf("failed to set stderr handle: %w", err)
|
||||
}
|
||||
|
||||
// Also set os.Stderr for Go's standard library
|
||||
os.Stderr = f
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func setStdHandle(f *os.File) error {
|
||||
handle := f.Fd()
|
||||
r0, _, e1 := setStdHandleFn.Call(stdErrorHandle, handle)
|
||||
if r0 == 0 {
|
||||
if e1 != nil {
|
||||
return e1
|
||||
}
|
||||
return syscall.EINVAL
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -97,10 +97,6 @@ func (s *Server) Start() error {
|
||||
defer s.mutex.Unlock()
|
||||
state := internal.CtxGetState(s.rootCtx)
|
||||
|
||||
if err := handlePanicLog(); err != nil {
|
||||
log.Warnf("failed to redirect stderr: %v", err)
|
||||
}
|
||||
|
||||
if err := restoreResidualState(s.rootCtx); err != nil {
|
||||
log.Warnf(errRestoreResidualState, err)
|
||||
}
|
||||
@@ -626,8 +622,6 @@ func (s *Server) Down(ctx context.Context, _ *proto.DownRequest) (*proto.DownRes
|
||||
s.mutex.Lock()
|
||||
defer s.mutex.Unlock()
|
||||
|
||||
s.oauthAuthFlow = oauthAuthFlow{}
|
||||
|
||||
if s.actCancel == nil {
|
||||
return nil, fmt.Errorf("service is not up")
|
||||
}
|
||||
|
||||
5
go.mod
5
go.mod
@@ -71,6 +71,7 @@ require (
|
||||
github.com/pion/transport/v3 v3.0.1
|
||||
github.com/pion/turn/v3 v3.0.1
|
||||
github.com/prometheus/client_golang v1.19.1
|
||||
github.com/r3labs/diff/v3 v3.0.1
|
||||
github.com/rs/xid v1.3.0
|
||||
github.com/shirou/gopsutil/v3 v3.24.4
|
||||
github.com/skratchdot/open-golang v0.0.0-20200116055534-eef842397966
|
||||
@@ -210,6 +211,8 @@ require (
|
||||
github.com/tklauser/go-sysconf v0.3.14 // indirect
|
||||
github.com/tklauser/numcpus v0.8.0 // indirect
|
||||
github.com/vishvananda/netns v0.0.4 // indirect
|
||||
github.com/vmihailenco/msgpack/v5 v5.3.5 // indirect
|
||||
github.com/vmihailenco/tagparser/v2 v2.0.0 // indirect
|
||||
github.com/yuin/goldmark v1.7.1 // indirect
|
||||
github.com/zeebo/blake3 v0.2.3 // indirect
|
||||
go.opencensus.io v0.24.0 // indirect
|
||||
@@ -236,7 +239,7 @@ replace github.com/kardianos/service => github.com/netbirdio/service v0.0.0-2024
|
||||
|
||||
replace github.com/getlantern/systray => github.com/netbirdio/systray v0.0.0-20231030152038-ef1ed2a27949
|
||||
|
||||
replace golang.zx2c4.com/wireguard => github.com/netbirdio/wireguard-go v0.0.0-20240105182236-6c340dd55aed
|
||||
replace golang.zx2c4.com/wireguard => github.com/netbirdio/wireguard-go v0.0.0-20241101085246-a698cd316cd6
|
||||
|
||||
replace github.com/cloudflare/circl => github.com/cunicu/circl v0.0.0-20230801113412-fec58fc7b5f6
|
||||
|
||||
|
||||
10
go.sum
10
go.sum
@@ -527,8 +527,8 @@ github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 h1:3tHlFmhTdX9ax
|
||||
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM=
|
||||
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20241010133937-e0df50df217d h1:bRq5TKgC7Iq20pDiuC54yXaWnAVeS5PdGpSokFTlR28=
|
||||
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20241010133937-e0df50df217d/go.mod h1:5/sjFmLb8O96B5737VCqhHyGRzNFIaN/Bu7ZodXc3qQ=
|
||||
github.com/netbirdio/wireguard-go v0.0.0-20240105182236-6c340dd55aed h1:t0UADZUJDaaZgfKrt8JUPrOLL9Mg/ryjP85RAH53qgs=
|
||||
github.com/netbirdio/wireguard-go v0.0.0-20240105182236-6c340dd55aed/go.mod h1:tkCQ4FQXmpAgYVh++1cq16/dH4QJtmvpRv19DWGAHSA=
|
||||
github.com/netbirdio/wireguard-go v0.0.0-20241101085246-a698cd316cd6 h1:+FOZ3vbuyH5kaD37IJya06xNMuEtWEMeBNdPtIUVwbc=
|
||||
github.com/netbirdio/wireguard-go v0.0.0-20241101085246-a698cd316cd6/go.mod h1:tkCQ4FQXmpAgYVh++1cq16/dH4QJtmvpRv19DWGAHSA=
|
||||
github.com/nicksnyder/go-i18n/v2 v2.4.0 h1:3IcvPOAvnCKwNm0TB0dLDTuawWEj+ax/RERNC+diLMM=
|
||||
github.com/nicksnyder/go-i18n/v2 v2.4.0/go.mod h1:nxYSZE9M0bf3Y70gPQjN9ha7XNHX7gMc814+6wVyEI4=
|
||||
github.com/nxadm/tail v1.4.4/go.mod h1:kenIhsEOeOJmVchQTgglprH7qJGnHDVpk1VPCcaMI8A=
|
||||
@@ -605,6 +605,8 @@ github.com/prometheus/common v0.53.0 h1:U2pL9w9nmJwJDa4qqLQ3ZaePJ6ZTwt7cMD3AG3+a
|
||||
github.com/prometheus/common v0.53.0/go.mod h1:BrxBKv3FWBIGXw89Mg1AeBq7FSyRzXWI3l3e7W3RN5U=
|
||||
github.com/prometheus/procfs v0.15.0 h1:A82kmvXJq2jTu5YUhSGNlYoxh85zLnKgPz4bMZgI5Ek=
|
||||
github.com/prometheus/procfs v0.15.0/go.mod h1:Y0RJ/Y5g5wJpkTisOtqwDSo4HwhGmLB4VQSw2sQJLHk=
|
||||
github.com/r3labs/diff/v3 v3.0.1 h1:CBKqf3XmNRHXKmdU7mZP1w7TV0pDyVCis1AUHtA4Xtg=
|
||||
github.com/r3labs/diff/v3 v3.0.1/go.mod h1:f1S9bourRbiM66NskseyUdo0fTmEE0qKrikYJX63dgo=
|
||||
github.com/rogpeppe/fastuuid v1.2.0/go.mod h1:jVj6XXZzXRy/MSR5jhDC/2q6DgLz+nrA6LYCDYWNEvQ=
|
||||
github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4=
|
||||
github.com/rogpeppe/go-internal v1.11.0 h1:cWPaGQEPrBb5/AsnsZesgZZ9yb1OQ+GOISoDNXVBh4M=
|
||||
@@ -697,6 +699,10 @@ github.com/vishvananda/netlink v1.2.1-beta.2/go.mod h1:twkDnbuQxJYemMlGd4JFIcuhg
|
||||
github.com/vishvananda/netns v0.0.0-20200728191858-db3c7e526aae/go.mod h1:DD4vA1DwXk04H54A1oHXtwZmA0grkVMdPxx/VGLCah0=
|
||||
github.com/vishvananda/netns v0.0.4 h1:Oeaw1EM2JMxD51g9uhtC0D7erkIjgmj8+JZc26m1YX8=
|
||||
github.com/vishvananda/netns v0.0.4/go.mod h1:SpkAiCQRtJ6TvvxPnOSyH3BMl6unz3xZlaprSwhNNJM=
|
||||
github.com/vmihailenco/msgpack/v5 v5.3.5 h1:5gO0H1iULLWGhs2H5tbAHIZTV8/cYafcFOr9znI5mJU=
|
||||
github.com/vmihailenco/msgpack/v5 v5.3.5/go.mod h1:7xyJ9e+0+9SaZT0Wt1RGleJXzli6Q/V5KbhBonMG9jc=
|
||||
github.com/vmihailenco/tagparser/v2 v2.0.0 h1:y09buUbR+b5aycVFQs/g70pqKVZNBmxwAhO7/IwNM9g=
|
||||
github.com/vmihailenco/tagparser/v2 v2.0.0/go.mod h1:Wri+At7QHww0WTrCBeu4J6bNtoV6mEfg5OIWRZA9qds=
|
||||
github.com/yuin/goldmark v1.1.25/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
|
||||
github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
|
||||
github.com/yuin/goldmark v1.1.32/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
|
||||
|
||||
@@ -1147,14 +1147,14 @@ func TestAccountManager_NetworkUpdates_SaveGroup(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
|
||||
defer manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID, updMsg.sessionID)
|
||||
defer manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID)
|
||||
|
||||
wg := sync.WaitGroup{}
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
|
||||
message := <-updMsg.channel
|
||||
message := <-updMsg
|
||||
networkMap := message.Update.GetNetworkMap()
|
||||
if len(networkMap.RemotePeers) != 2 {
|
||||
t.Errorf("mismatch peers count: 2 expected, got %v", len(networkMap.RemotePeers))
|
||||
@@ -1174,14 +1174,14 @@ func TestAccountManager_NetworkUpdates_DeletePolicy(t *testing.T) {
|
||||
manager, account, peer1, _, _ := setupNetworkMapTest(t)
|
||||
|
||||
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
|
||||
defer manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID, updMsg.sessionID)
|
||||
defer manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID)
|
||||
|
||||
wg := sync.WaitGroup{}
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
|
||||
message := <-updMsg.channel
|
||||
message := <-updMsg
|
||||
networkMap := message.Update.GetNetworkMap()
|
||||
if len(networkMap.RemotePeers) != 0 {
|
||||
t.Errorf("mismatch peers count: 0 expected, got %v", len(networkMap.RemotePeers))
|
||||
@@ -1210,7 +1210,7 @@ func TestAccountManager_NetworkUpdates_SavePolicy(t *testing.T) {
|
||||
}
|
||||
|
||||
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
|
||||
defer manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID, updMsg.sessionID)
|
||||
defer manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID)
|
||||
|
||||
policy := Policy{
|
||||
Enabled: true,
|
||||
@@ -1230,7 +1230,7 @@ func TestAccountManager_NetworkUpdates_SavePolicy(t *testing.T) {
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
|
||||
message := <-updMsg.channel
|
||||
message := <-updMsg
|
||||
networkMap := message.Update.GetNetworkMap()
|
||||
if len(networkMap.RemotePeers) != 2 {
|
||||
t.Errorf("mismatch peers count: 2 expected, got %v", len(networkMap.RemotePeers))
|
||||
@@ -1277,14 +1277,14 @@ func TestAccountManager_NetworkUpdates_DeletePeer(t *testing.T) {
|
||||
}
|
||||
|
||||
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
|
||||
defer manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID, updMsg.sessionID)
|
||||
defer manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID)
|
||||
|
||||
wg := sync.WaitGroup{}
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
|
||||
message := <-updMsg.channel
|
||||
message := <-updMsg
|
||||
networkMap := message.Update.GetNetworkMap()
|
||||
if len(networkMap.RemotePeers) != 1 {
|
||||
t.Errorf("mismatch peers count: 1 expected, got %v", len(networkMap.RemotePeers))
|
||||
@@ -1303,7 +1303,7 @@ func TestAccountManager_NetworkUpdates_DeleteGroup(t *testing.T) {
|
||||
manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t)
|
||||
|
||||
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
|
||||
defer manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID, updMsg.sessionID)
|
||||
defer manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID)
|
||||
|
||||
group := group.Group{
|
||||
ID: "groupA",
|
||||
@@ -1339,7 +1339,7 @@ func TestAccountManager_NetworkUpdates_DeleteGroup(t *testing.T) {
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
|
||||
message := <-updMsg.channel
|
||||
message := <-updMsg
|
||||
networkMap := message.Update.GetNetworkMap()
|
||||
if len(networkMap.RemotePeers) != 0 {
|
||||
t.Errorf("mismatch peers count: 0 expected, got %v", len(networkMap.RemotePeers))
|
||||
|
||||
82
management/server/differs/netip.go
Normal file
82
management/server/differs/netip.go
Normal file
@@ -0,0 +1,82 @@
|
||||
package differs
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"reflect"
|
||||
|
||||
"github.com/r3labs/diff/v3"
|
||||
)
|
||||
|
||||
// NetIPAddr is a custom differ for netip.Addr
|
||||
type NetIPAddr struct {
|
||||
DiffFunc func(path []string, a, b reflect.Value, p interface{}) error
|
||||
}
|
||||
|
||||
func (differ NetIPAddr) Match(a, b reflect.Value) bool {
|
||||
return diff.AreType(a, b, reflect.TypeOf(netip.Addr{}))
|
||||
}
|
||||
|
||||
func (differ NetIPAddr) Diff(_ diff.DiffType, _ diff.DiffFunc, cl *diff.Changelog, path []string, a, b reflect.Value, _ interface{}) error {
|
||||
if a.Kind() == reflect.Invalid {
|
||||
cl.Add(diff.CREATE, path, nil, b.Interface())
|
||||
return nil
|
||||
}
|
||||
|
||||
if b.Kind() == reflect.Invalid {
|
||||
cl.Add(diff.DELETE, path, a.Interface(), nil)
|
||||
return nil
|
||||
}
|
||||
|
||||
fromAddr, ok1 := a.Interface().(netip.Addr)
|
||||
toAddr, ok2 := b.Interface().(netip.Addr)
|
||||
if !ok1 || !ok2 {
|
||||
return fmt.Errorf("invalid type for netip.Addr")
|
||||
}
|
||||
|
||||
if fromAddr.String() != toAddr.String() {
|
||||
cl.Add(diff.UPDATE, path, fromAddr.String(), toAddr.String())
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (differ NetIPAddr) InsertParentDiffer(dfunc func(path []string, a, b reflect.Value, p interface{}) error) {
|
||||
differ.DiffFunc = dfunc //nolint
|
||||
}
|
||||
|
||||
// NetIPPrefix is a custom differ for netip.Prefix
|
||||
type NetIPPrefix struct {
|
||||
DiffFunc func(path []string, a, b reflect.Value, p interface{}) error
|
||||
}
|
||||
|
||||
func (differ NetIPPrefix) Match(a, b reflect.Value) bool {
|
||||
return diff.AreType(a, b, reflect.TypeOf(netip.Prefix{}))
|
||||
}
|
||||
|
||||
func (differ NetIPPrefix) Diff(_ diff.DiffType, _ diff.DiffFunc, cl *diff.Changelog, path []string, a, b reflect.Value, _ interface{}) error {
|
||||
if a.Kind() == reflect.Invalid {
|
||||
cl.Add(diff.CREATE, path, nil, b.Interface())
|
||||
return nil
|
||||
}
|
||||
if b.Kind() == reflect.Invalid {
|
||||
cl.Add(diff.DELETE, path, a.Interface(), nil)
|
||||
return nil
|
||||
}
|
||||
|
||||
fromPrefix, ok1 := a.Interface().(netip.Prefix)
|
||||
toPrefix, ok2 := b.Interface().(netip.Prefix)
|
||||
if !ok1 || !ok2 {
|
||||
return fmt.Errorf("invalid type for netip.Addr")
|
||||
}
|
||||
|
||||
if fromPrefix.String() != toPrefix.String() {
|
||||
cl.Add(diff.UPDATE, path, fromPrefix.String(), toPrefix.String())
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (differ NetIPPrefix) InsertParentDiffer(dfunc func(path []string, a, b reflect.Value, p interface{}) error) {
|
||||
differ.DiffFunc = dfunc //nolint
|
||||
}
|
||||
@@ -8,10 +8,9 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
@@ -499,14 +498,14 @@ func TestDNSAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
|
||||
t.Cleanup(func() {
|
||||
manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID, updMsg.sessionID)
|
||||
manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID)
|
||||
})
|
||||
|
||||
// Saving DNS settings with groups that have no peers should not trigger updates to account peers or send peer updates
|
||||
t.Run("saving dns setting with unused groups", func(t *testing.T) {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldNotReceiveUpdate(t, updMsg.channel)
|
||||
peerShouldNotReceiveUpdate(t, updMsg)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
@@ -522,70 +521,29 @@ func TestDNSAccountPeersUpdate(t *testing.T) {
|
||||
}
|
||||
})
|
||||
|
||||
// Creating DNS settings with groups that have no peers should not update account peers or send peer update
|
||||
t.Run("creating dns setting with unused groups", func(t *testing.T) {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldNotReceiveUpdate(t, updMsg.channel)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
_, err = manager.CreateNameServerGroup(
|
||||
context.Background(), account.Id, "ns-group", "ns-group", []dns.NameServer{{
|
||||
IP: netip.MustParseAddr(peer1.IP.String()),
|
||||
NSType: dns.UDPNameServerType,
|
||||
Port: dns.DefaultDNSPort,
|
||||
}},
|
||||
[]string{"groupB"},
|
||||
true, []string{}, true, userID, false,
|
||||
)
|
||||
assert.NoError(t, err)
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
|
||||
}
|
||||
err = manager.SaveGroup(context.Background(), account.Id, userID, &group.Group{
|
||||
ID: "groupA",
|
||||
Name: "GroupA",
|
||||
Peers: []string{peer1.ID, peer2.ID, peer3.ID},
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Creating DNS settings with groups that have peers should update account peers and send peer update
|
||||
t.Run("creating dns setting with used groups", func(t *testing.T) {
|
||||
err = manager.SaveGroup(context.Background(), account.Id, userID, &group.Group{
|
||||
ID: "groupA",
|
||||
Name: "GroupA",
|
||||
Peers: []string{peer1.ID, peer2.ID, peer3.ID},
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, updMsg.channel)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
_, err = manager.CreateNameServerGroup(
|
||||
context.Background(), account.Id, "ns-group-1", "ns-group-1", []dns.NameServer{{
|
||||
IP: netip.MustParseAddr(peer1.IP.String()),
|
||||
NSType: dns.UDPNameServerType,
|
||||
Port: dns.DefaultDNSPort,
|
||||
}},
|
||||
[]string{"groupA"},
|
||||
true, []string{}, true, userID, false,
|
||||
)
|
||||
assert.NoError(t, err)
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||
}
|
||||
})
|
||||
_, err = manager.CreateNameServerGroup(
|
||||
context.Background(), account.Id, "ns-group-1", "ns-group-1", []dns.NameServer{{
|
||||
IP: netip.MustParseAddr(peer1.IP.String()),
|
||||
NSType: dns.UDPNameServerType,
|
||||
Port: dns.DefaultDNSPort,
|
||||
}},
|
||||
[]string{"groupA"},
|
||||
true, []string{}, true, userID, false,
|
||||
)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Saving DNS settings with groups that have peers should update account peers and send peer update
|
||||
t.Run("saving dns setting with used groups", func(t *testing.T) {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, updMsg.channel)
|
||||
peerShouldReceiveUpdate(t, updMsg)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
@@ -601,11 +559,32 @@ func TestDNSAccountPeersUpdate(t *testing.T) {
|
||||
}
|
||||
})
|
||||
|
||||
// Saving unchanged DNS settings with used groups should update account peers and not send peer update
|
||||
// since there is no change in the network map
|
||||
t.Run("saving unchanged dns setting with used groups", func(t *testing.T) {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldNotReceiveUpdate(t, updMsg)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
err := manager.SaveDNSSettings(context.Background(), account.Id, userID, &DNSSettings{
|
||||
DisabledManagementGroups: []string{"groupA", "groupB"},
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
|
||||
}
|
||||
})
|
||||
|
||||
// Removing group with no peers from DNS settings should not trigger updates to account peers or send peer updates
|
||||
t.Run("removing group with no peers from dns settings", func(t *testing.T) {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldNotReceiveUpdate(t, updMsg.channel)
|
||||
peerShouldNotReceiveUpdate(t, updMsg)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
@@ -625,7 +604,7 @@ func TestDNSAccountPeersUpdate(t *testing.T) {
|
||||
t.Run("removing group with peers from dns settings", func(t *testing.T) {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, updMsg.channel)
|
||||
peerShouldReceiveUpdate(t, updMsg)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
|
||||
@@ -8,13 +8,12 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
nbgroup "github.com/netbirdio/netbird/management/server/group"
|
||||
"github.com/netbirdio/netbird/management/server/status"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -418,14 +417,14 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
|
||||
t.Cleanup(func() {
|
||||
manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID, updMsg.sessionID)
|
||||
manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID)
|
||||
})
|
||||
|
||||
// Saving a group that is not linked to any resource should not update account peers
|
||||
t.Run("saving unlinked group", func(t *testing.T) {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldNotReceiveUpdate(t, updMsg.channel)
|
||||
peerShouldNotReceiveUpdate(t, updMsg)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
@@ -448,7 +447,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
|
||||
t.Run("adding peer to unlinked group", func(t *testing.T) {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldNotReceiveUpdate(t, updMsg.channel)
|
||||
peerShouldNotReceiveUpdate(t, updMsg)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
@@ -467,7 +466,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
|
||||
t.Run("removing peer from unliked group", func(t *testing.T) {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldNotReceiveUpdate(t, updMsg.channel)
|
||||
peerShouldNotReceiveUpdate(t, updMsg)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
@@ -485,7 +484,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
|
||||
t.Run("deleting group", func(t *testing.T) {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldNotReceiveUpdate(t, updMsg.channel)
|
||||
peerShouldNotReceiveUpdate(t, updMsg)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
@@ -519,7 +518,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
|
||||
t.Run("saving linked group to policy", func(t *testing.T) {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, updMsg.channel)
|
||||
peerShouldReceiveUpdate(t, updMsg)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
@@ -537,11 +536,34 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
|
||||
}
|
||||
})
|
||||
|
||||
// Saving an unchanged group should trigger account peers update and not send peer update
|
||||
// since there is no change in the network map
|
||||
t.Run("saving unchanged group", func(t *testing.T) {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldNotReceiveUpdate(t, updMsg)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
err := manager.SaveGroup(context.Background(), account.Id, userID, &nbgroup.Group{
|
||||
ID: "groupA",
|
||||
Name: "GroupA",
|
||||
Peers: []string{peer1.ID, peer2.ID},
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
|
||||
}
|
||||
})
|
||||
|
||||
// adding peer to a used group should update account peers and send peer update
|
||||
t.Run("adding peer to linked group", func(t *testing.T) {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, updMsg.channel)
|
||||
peerShouldReceiveUpdate(t, updMsg)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
@@ -559,7 +581,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
|
||||
t.Run("removing peer from linked group", func(t *testing.T) {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, updMsg.channel)
|
||||
peerShouldReceiveUpdate(t, updMsg)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
@@ -588,7 +610,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, updMsg.channel)
|
||||
peerShouldReceiveUpdate(t, updMsg)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
@@ -629,7 +651,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, updMsg.channel)
|
||||
peerShouldReceiveUpdate(t, updMsg)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
@@ -656,7 +678,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, updMsg.channel)
|
||||
peerShouldReceiveUpdate(t, updMsg)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
|
||||
@@ -194,31 +194,31 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi
|
||||
}
|
||||
|
||||
// handleUpdates sends updates to the connected peer until the updates channel is closed.
|
||||
func (s *GRPCServer) handleUpdates(ctx context.Context, accountID string, peerKey wgtypes.Key, peer *nbpeer.Peer, peerUpdates *PeerUpdateChannel, srv proto.ManagementService_SyncServer) error {
|
||||
func (s *GRPCServer) handleUpdates(ctx context.Context, accountID string, peerKey wgtypes.Key, peer *nbpeer.Peer, updates chan *UpdateMessage, srv proto.ManagementService_SyncServer) error {
|
||||
for {
|
||||
select {
|
||||
// condition when there are some updates
|
||||
case update, open := <-peerUpdates.channel:
|
||||
case update, open := <-updates:
|
||||
if s.appMetrics != nil {
|
||||
s.appMetrics.GRPCMetrics().UpdateChannelQueueLength(len(peerUpdates.channel) + 1)
|
||||
s.appMetrics.GRPCMetrics().UpdateChannelQueueLength(len(updates) + 1)
|
||||
}
|
||||
|
||||
if !open {
|
||||
log.WithContext(ctx).Debugf("updates channel for peer %s was closed", peerKey.String())
|
||||
s.cancelPeerRoutines(ctx, accountID, peer, peerUpdates.sessionID)
|
||||
s.cancelPeerRoutines(ctx, accountID, peer)
|
||||
return nil
|
||||
}
|
||||
log.WithContext(ctx).Debugf("received an update for peer %s", peerKey.String())
|
||||
|
||||
if err := s.sendUpdate(ctx, accountID, peerKey, peer, peerUpdates.sessionID, update, srv); err != nil {
|
||||
if err := s.sendUpdate(ctx, accountID, peerKey, peer, update, srv); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// condition when client <-> server connection has been terminated
|
||||
case <-srv.Context().Done():
|
||||
// happens when connection drops, e.g. client disconnects
|
||||
log.WithContext(ctx).Debugf("stream of peer %s with session %s has been closed", peerKey.String(), peerUpdates.sessionID)
|
||||
s.cancelPeerRoutines(ctx, accountID, peer, peerUpdates.sessionID)
|
||||
log.WithContext(ctx).Debugf("stream of peer %s has been closed", peerKey.String())
|
||||
s.cancelPeerRoutines(ctx, accountID, peer)
|
||||
return srv.Context().Err()
|
||||
}
|
||||
}
|
||||
@@ -226,10 +226,10 @@ func (s *GRPCServer) handleUpdates(ctx context.Context, accountID string, peerKe
|
||||
|
||||
// sendUpdate encrypts the update message using the peer key and the server's wireguard key,
|
||||
// then sends the encrypted message to the connected peer via the sync server.
|
||||
func (s *GRPCServer) sendUpdate(ctx context.Context, accountID string, peerKey wgtypes.Key, peer *nbpeer.Peer, sessionID string, update *UpdateMessage, srv proto.ManagementService_SyncServer) error {
|
||||
func (s *GRPCServer) sendUpdate(ctx context.Context, accountID string, peerKey wgtypes.Key, peer *nbpeer.Peer, update *UpdateMessage, srv proto.ManagementService_SyncServer) error {
|
||||
encryptedResp, err := encryption.EncryptMessage(peerKey, s.wgKey, update.Update)
|
||||
if err != nil {
|
||||
s.cancelPeerRoutines(ctx, accountID, peer, sessionID)
|
||||
s.cancelPeerRoutines(ctx, accountID, peer)
|
||||
return status.Errorf(codes.Internal, "failed processing update message")
|
||||
}
|
||||
err = srv.SendMsg(&proto.EncryptedMessage{
|
||||
@@ -237,20 +237,18 @@ func (s *GRPCServer) sendUpdate(ctx context.Context, accountID string, peerKey w
|
||||
Body: encryptedResp,
|
||||
})
|
||||
if err != nil {
|
||||
s.cancelPeerRoutines(ctx, accountID, peer, sessionID)
|
||||
s.cancelPeerRoutines(ctx, accountID, peer)
|
||||
return status.Errorf(codes.Internal, "failed sending update message")
|
||||
}
|
||||
log.WithContext(ctx).Debugf("sent an update to peer %s", peerKey.String())
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *GRPCServer) cancelPeerRoutines(ctx context.Context, accountID string, peer *nbpeer.Peer, sessionID string) {
|
||||
ok := s.peersUpdateManager.CloseChannel(ctx, peer.ID, sessionID)
|
||||
if ok {
|
||||
_ = s.accountManager.OnPeerDisconnected(ctx, accountID, peer.Key)
|
||||
s.secretsManager.CancelRefresh(peer.ID)
|
||||
s.ephemeralManager.OnPeerDisconnected(ctx, peer)
|
||||
}
|
||||
func (s *GRPCServer) cancelPeerRoutines(ctx context.Context, accountID string, peer *nbpeer.Peer) {
|
||||
s.peersUpdateManager.CloseChannel(ctx, peer.ID)
|
||||
s.secretsManager.CancelRefresh(peer.ID)
|
||||
_ = s.accountManager.OnPeerDisconnected(ctx, accountID, peer.Key)
|
||||
s.ephemeralManager.OnPeerDisconnected(ctx, peer)
|
||||
}
|
||||
|
||||
func (s *GRPCServer) validateToken(ctx context.Context, jwtToken string) (string, error) {
|
||||
|
||||
@@ -960,7 +960,7 @@ func TestNameServerAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
|
||||
t.Cleanup(func() {
|
||||
manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID, updMsg.sessionID)
|
||||
manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID)
|
||||
})
|
||||
|
||||
// Creating a nameserver group with a distribution group no peers should not update account peers
|
||||
@@ -968,7 +968,7 @@ func TestNameServerAccountPeersUpdate(t *testing.T) {
|
||||
t.Run("creating nameserver group with distribution group no peers", func(t *testing.T) {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldNotReceiveUpdate(t, updMsg.channel)
|
||||
peerShouldNotReceiveUpdate(t, updMsg)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
@@ -995,7 +995,7 @@ func TestNameServerAccountPeersUpdate(t *testing.T) {
|
||||
t.Run("saving nameserver group with distribution group no peers", func(t *testing.T) {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldNotReceiveUpdate(t, updMsg.channel)
|
||||
peerShouldNotReceiveUpdate(t, updMsg)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
@@ -1013,7 +1013,7 @@ func TestNameServerAccountPeersUpdate(t *testing.T) {
|
||||
t.Run("creating nameserver group with distribution group has peers", func(t *testing.T) {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, updMsg.channel)
|
||||
peerShouldReceiveUpdate(t, updMsg)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
@@ -1039,7 +1039,7 @@ func TestNameServerAccountPeersUpdate(t *testing.T) {
|
||||
t.Run("saving nameserver group with distribution group has peers", func(t *testing.T) {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, updMsg.channel)
|
||||
peerShouldReceiveUpdate(t, updMsg)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
@@ -1065,11 +1065,41 @@ func TestNameServerAccountPeersUpdate(t *testing.T) {
|
||||
}
|
||||
})
|
||||
|
||||
// saving unchanged nameserver group should update account peers and not send peer update
|
||||
t.Run("saving unchanged nameserver group", func(t *testing.T) {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldNotReceiveUpdate(t, updMsg)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
newNameServerGroupB.NameServers = []nbdns.NameServer{
|
||||
{
|
||||
IP: netip.MustParseAddr("1.1.1.2"),
|
||||
NSType: nbdns.UDPNameServerType,
|
||||
Port: nbdns.DefaultDNSPort,
|
||||
},
|
||||
{
|
||||
IP: netip.MustParseAddr("8.8.8.8"),
|
||||
NSType: nbdns.UDPNameServerType,
|
||||
Port: nbdns.DefaultDNSPort,
|
||||
},
|
||||
}
|
||||
err = manager.SaveNameServerGroup(context.Background(), account.Id, userID, newNameServerGroupB)
|
||||
assert.NoError(t, err)
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
|
||||
}
|
||||
})
|
||||
|
||||
// Deleting a nameserver group should update account peers and send peer update
|
||||
t.Run("deleting nameserver group", func(t *testing.T) {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, updMsg.channel)
|
||||
peerShouldReceiveUpdate(t, updMsg)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
|
||||
@@ -41,9 +41,9 @@ type Network struct {
|
||||
Dns string
|
||||
// Serial is an ID that increments by 1 when any change to the network happened (e.g. new peer has been added).
|
||||
// Used to synchronize state to the client apps.
|
||||
Serial uint64
|
||||
Serial uint64 `diff:"-"`
|
||||
|
||||
mu sync.Mutex `json:"-" gorm:"-"`
|
||||
mu sync.Mutex `json:"-" gorm:"-" diff:"-"`
|
||||
}
|
||||
|
||||
// NewNetwork creates a new Network initializing it with a Serial=0
|
||||
|
||||
@@ -313,7 +313,7 @@ func (am *DefaultAccountManager) deletePeers(ctx context.Context, account *Accou
|
||||
},
|
||||
NetworkMap: &NetworkMap{},
|
||||
})
|
||||
am.peersUpdateManager.CloseChannel(ctx, peer.ID, SessionIdForceOverwrite)
|
||||
am.peersUpdateManager.CloseChannel(ctx, peer.ID)
|
||||
am.StoreEvent(ctx, userID, peer.ID, account.Id, activity.PeerRemovedByUser, peer.EventMeta(am.GetDNSDomain()))
|
||||
}
|
||||
|
||||
@@ -589,12 +589,6 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
|
||||
return nil, nil, nil, fmt.Errorf("error getting account: %w", err)
|
||||
}
|
||||
|
||||
allGroup, err := account.GetGroupAll()
|
||||
if err != nil {
|
||||
return nil, nil, nil, fmt.Errorf("error getting all group ID: %w", err)
|
||||
}
|
||||
|
||||
groupsToAdd = append(groupsToAdd, allGroup.ID)
|
||||
if areGroupChangesAffectPeers(account, groupsToAdd) {
|
||||
am.updateAccountPeers(ctx, account)
|
||||
}
|
||||
|
||||
@@ -20,33 +20,33 @@ type Peer struct {
|
||||
// IP address of the Peer
|
||||
IP net.IP `gorm:"serializer:json"`
|
||||
// Meta is a Peer system meta data
|
||||
Meta PeerSystemMeta `gorm:"embedded;embeddedPrefix:meta_"`
|
||||
Meta PeerSystemMeta `gorm:"embedded;embeddedPrefix:meta_" diff:"-"`
|
||||
// Name is peer's name (machine name)
|
||||
Name string
|
||||
// DNSLabel is the parsed peer name for domain resolution. It is used to form an FQDN by appending the account's
|
||||
// domain to the peer label. e.g. peer-dns-label.netbird.cloud
|
||||
DNSLabel string
|
||||
// Status peer's management connection status
|
||||
Status *PeerStatus `gorm:"embedded;embeddedPrefix:peer_status_"`
|
||||
Status *PeerStatus `gorm:"embedded;embeddedPrefix:peer_status_" diff:"-"`
|
||||
// The user ID that registered the peer
|
||||
UserID string
|
||||
UserID string `diff:"-"`
|
||||
// SSHKey is a public SSH key of the peer
|
||||
SSHKey string
|
||||
// SSHEnabled indicates whether SSH server is enabled on the peer
|
||||
SSHEnabled bool
|
||||
// LoginExpirationEnabled indicates whether peer's login expiration is enabled and once expired the peer has to re-login.
|
||||
// Works with LastLogin
|
||||
LoginExpirationEnabled bool
|
||||
LoginExpirationEnabled bool `diff:"-"`
|
||||
|
||||
InactivityExpirationEnabled bool
|
||||
InactivityExpirationEnabled bool `diff:"-"`
|
||||
// LastLogin the time when peer performed last login operation
|
||||
LastLogin time.Time
|
||||
LastLogin time.Time `diff:"-"`
|
||||
// CreatedAt records the time the peer was created
|
||||
CreatedAt time.Time
|
||||
CreatedAt time.Time `diff:"-"`
|
||||
// Indicate ephemeral peer attribute
|
||||
Ephemeral bool
|
||||
Ephemeral bool `diff:"-"`
|
||||
// Geo location based on connection IP
|
||||
Location Location `gorm:"embedded;embeddedPrefix:location_"`
|
||||
Location Location `gorm:"embedded;embeddedPrefix:location_" diff:"-"`
|
||||
}
|
||||
|
||||
type PeerStatus struct { //nolint:revive
|
||||
|
||||
@@ -864,14 +864,10 @@ func BenchmarkUpdateAccountPeers(b *testing.B) {
|
||||
b.Fatalf("Failed to get account: %v", err)
|
||||
}
|
||||
|
||||
peerChannels := make(map[string]*PeerUpdateChannel)
|
||||
peerChannels := make(map[string]chan *UpdateMessage)
|
||||
|
||||
for peerID := range account.Peers {
|
||||
peerChannels[peerID] = &PeerUpdateChannel{
|
||||
peerID: peerID,
|
||||
channel: make(chan *UpdateMessage, channelBufferSize),
|
||||
sessionID: xid.New().String(),
|
||||
}
|
||||
peerChannels[peerID] = make(chan *UpdateMessage, channelBufferSize)
|
||||
}
|
||||
|
||||
manager.peersUpdateManager.peerChannels = peerChannels
|
||||
@@ -1319,14 +1315,14 @@ func TestPeerAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
|
||||
t.Cleanup(func() {
|
||||
manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID, updMsg.sessionID)
|
||||
manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID)
|
||||
})
|
||||
|
||||
// Updating not expired peer and peer expiration is enabled should not update account peers and not send peer update
|
||||
t.Run("updating not expired peer and peer expiration is enabled", func(t *testing.T) {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldNotReceiveUpdate(t, updMsg.channel)
|
||||
peerShouldNotReceiveUpdate(t, updMsg)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
@@ -1344,7 +1340,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) {
|
||||
t.Run("adding peer to unlinked group", func(t *testing.T) {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldNotReceiveUpdate(t, updMsg.channel)
|
||||
peerShouldNotReceiveUpdate(t, updMsg)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
@@ -1369,7 +1365,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) {
|
||||
t.Run("deleting peer with unlinked group", func(t *testing.T) {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldNotReceiveUpdate(t, updMsg.channel)
|
||||
peerShouldNotReceiveUpdate(t, updMsg)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
@@ -1387,7 +1383,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) {
|
||||
t.Run("updating peer label", func(t *testing.T) {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, updMsg.channel)
|
||||
peerShouldReceiveUpdate(t, updMsg)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
@@ -1421,7 +1417,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, updMsg.channel)
|
||||
peerShouldReceiveUpdate(t, updMsg)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
@@ -1447,7 +1443,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) {
|
||||
t.Run("deleting peer with linked group to policy", func(t *testing.T) {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, updMsg.channel)
|
||||
peerShouldReceiveUpdate(t, updMsg)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
@@ -1485,7 +1481,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, updMsg.channel)
|
||||
peerShouldReceiveUpdate(t, updMsg)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
@@ -1511,7 +1507,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) {
|
||||
t.Run("deleting peer with linked group to route", func(t *testing.T) {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, updMsg.channel)
|
||||
peerShouldReceiveUpdate(t, updMsg)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
@@ -1540,7 +1536,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, updMsg.channel)
|
||||
peerShouldReceiveUpdate(t, updMsg)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
@@ -1566,7 +1562,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) {
|
||||
t.Run("deleting peer with linked group to route", func(t *testing.T) {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, updMsg.channel)
|
||||
peerShouldReceiveUpdate(t, updMsg)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
|
||||
@@ -405,9 +405,7 @@ func (am *DefaultAccountManager) DeletePolicy(ctx context.Context, accountID, po
|
||||
|
||||
am.StoreEvent(ctx, userID, policy.ID, accountID, activity.PolicyRemoved, policy.EventMeta())
|
||||
|
||||
if anyGroupHasPeers(account, policy.ruleGroups()) {
|
||||
am.updateAccountPeers(ctx, account)
|
||||
}
|
||||
am.updateAccountPeers(ctx, account)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -854,9 +854,14 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
|
||||
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
|
||||
updMsg1 := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
|
||||
t.Cleanup(func() {
|
||||
manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID, updMsg.sessionID)
|
||||
manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID)
|
||||
})
|
||||
|
||||
updMsg2 := manager.peersUpdateManager.CreateChannel(context.Background(), peer2.ID)
|
||||
t.Cleanup(func() {
|
||||
manager.peersUpdateManager.CloseChannel(context.Background(), peer2.ID)
|
||||
})
|
||||
|
||||
// Saving policy with rule groups with no peers should not update account's peers and not send peer update
|
||||
@@ -878,7 +883,7 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldNotReceiveUpdate(t, updMsg.channel)
|
||||
peerShouldNotReceiveUpdate(t, updMsg1)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
@@ -913,7 +918,7 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, updMsg.channel)
|
||||
peerShouldReceiveUpdate(t, updMsg1)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
@@ -948,7 +953,7 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, updMsg.channel)
|
||||
peerShouldReceiveUpdate(t, updMsg2)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
@@ -982,7 +987,7 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, updMsg.channel)
|
||||
peerShouldReceiveUpdate(t, updMsg1)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
@@ -1016,7 +1021,7 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, updMsg.channel)
|
||||
peerShouldReceiveUpdate(t, updMsg1)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
@@ -1051,7 +1056,7 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldNotReceiveUpdate(t, updMsg.channel)
|
||||
peerShouldNotReceiveUpdate(t, updMsg1)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
@@ -1085,7 +1090,7 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, updMsg.channel)
|
||||
peerShouldReceiveUpdate(t, updMsg1)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
@@ -1099,13 +1104,46 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
|
||||
}
|
||||
})
|
||||
|
||||
// Saving unchanged policy should trigger account peers update but not send peer update
|
||||
t.Run("saving unchanged policy", func(t *testing.T) {
|
||||
policy := Policy{
|
||||
ID: "policy-source-destination-peers",
|
||||
Enabled: true,
|
||||
Rules: []*PolicyRule{
|
||||
{
|
||||
ID: xid.New().String(),
|
||||
Enabled: true,
|
||||
Sources: []string{"groupA"},
|
||||
Destinations: []string{"groupD"},
|
||||
Bidirectional: true,
|
||||
Action: PolicyTrafficActionAccept,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldNotReceiveUpdate(t, updMsg1)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, true)
|
||||
assert.NoError(t, err)
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
|
||||
}
|
||||
})
|
||||
|
||||
// Deleting policy should trigger account peers update and send peer update
|
||||
t.Run("deleting policy with source and destination groups with peers", func(t *testing.T) {
|
||||
policyID := "policy-source-destination-peers"
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, updMsg.channel)
|
||||
peerShouldReceiveUpdate(t, updMsg1)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
@@ -1126,7 +1164,7 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
|
||||
policyID := "policy-destination-has-peers-source-none"
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, updMsg.channel)
|
||||
peerShouldReceiveUpdate(t, updMsg2)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
@@ -1142,10 +1180,10 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
// Deleting policy with no peers in groups should not update account's peers and not send peer update
|
||||
t.Run("deleting policy with no peers in groups", func(t *testing.T) {
|
||||
policyID := "policy-rule-groups-no-peers"
|
||||
policyID := "policy-rule-groups-no-peers" // Deleting the policy created in Case 2
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldNotReceiveUpdate(t, updMsg.channel)
|
||||
peerShouldNotReceiveUpdate(t, updMsg1)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
|
||||
@@ -5,11 +5,10 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/group"
|
||||
"github.com/rs/xid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/group"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/posture"
|
||||
)
|
||||
|
||||
@@ -147,7 +146,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
|
||||
t.Cleanup(func() {
|
||||
manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID, updMsg.sessionID)
|
||||
manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID)
|
||||
})
|
||||
|
||||
postureCheck := posture.Checks{
|
||||
@@ -165,7 +164,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
|
||||
t.Run("saving unused posture check", func(t *testing.T) {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldNotReceiveUpdate(t, updMsg.channel)
|
||||
peerShouldNotReceiveUpdate(t, updMsg)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
@@ -183,7 +182,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
|
||||
t.Run("updating unused posture check", func(t *testing.T) {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldNotReceiveUpdate(t, updMsg.channel)
|
||||
peerShouldNotReceiveUpdate(t, updMsg)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
@@ -222,7 +221,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
|
||||
t.Run("linking posture check to policy with peers", func(t *testing.T) {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, updMsg.channel)
|
||||
peerShouldReceiveUpdate(t, updMsg)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
@@ -251,7 +250,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, updMsg.channel)
|
||||
peerShouldReceiveUpdate(t, updMsg)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
@@ -265,11 +264,30 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
|
||||
}
|
||||
})
|
||||
|
||||
// Saving unchanged posture check should not trigger account peers update and not send peer update
|
||||
// since there is no change in the network map
|
||||
t.Run("saving unchanged posture check", func(t *testing.T) {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldNotReceiveUpdate(t, updMsg)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
err := manager.SavePostureChecks(context.Background(), account.Id, userID, &postureCheck)
|
||||
assert.NoError(t, err)
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
|
||||
}
|
||||
})
|
||||
|
||||
// Removing posture check from policy should trigger account peers update and send peer update
|
||||
t.Run("removing posture check from policy", func(t *testing.T) {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, updMsg.channel)
|
||||
peerShouldReceiveUpdate(t, updMsg)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
@@ -289,7 +307,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
|
||||
t.Run("deleting unused posture check", func(t *testing.T) {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldNotReceiveUpdate(t, updMsg.channel)
|
||||
peerShouldNotReceiveUpdate(t, updMsg)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
@@ -328,7 +346,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldNotReceiveUpdate(t, updMsg.channel)
|
||||
peerShouldNotReceiveUpdate(t, updMsg)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
@@ -352,7 +370,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
|
||||
t.Run("updating linked posture check to policy where destination has peers but source does not", func(t *testing.T) {
|
||||
updMsg1 := manager.peersUpdateManager.CreateChannel(context.Background(), peer2.ID)
|
||||
t.Cleanup(func() {
|
||||
manager.peersUpdateManager.CloseChannel(context.Background(), peer2.ID, updMsg1.sessionID)
|
||||
manager.peersUpdateManager.CloseChannel(context.Background(), peer2.ID)
|
||||
})
|
||||
policy = Policy{
|
||||
ID: "policyB",
|
||||
@@ -375,7 +393,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, updMsg1.channel)
|
||||
peerShouldReceiveUpdate(t, updMsg1)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
@@ -394,8 +412,8 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
|
||||
}
|
||||
})
|
||||
|
||||
// Updating linked client posture check to policy where source has peers but destination does not,
|
||||
// should trigger account peers update and send peer update
|
||||
// Updating linked posture check to policy where source has peers but destination does not,
|
||||
// should not trigger account peers update or send peer update
|
||||
t.Run("updating linked posture check to policy where source has peers but destination does not", func(t *testing.T) {
|
||||
policy = Policy{
|
||||
ID: "policyB",
|
||||
@@ -416,7 +434,48 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, updMsg.channel)
|
||||
peerShouldNotReceiveUpdate(t, updMsg)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
postureCheck.Checks = posture.ChecksDefinition{
|
||||
NBVersionCheck: &posture.NBVersionCheck{
|
||||
MinVersion: "0.29.0",
|
||||
},
|
||||
}
|
||||
err := manager.SavePostureChecks(context.Background(), account.Id, userID, &postureCheck)
|
||||
assert.NoError(t, err)
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
|
||||
}
|
||||
})
|
||||
|
||||
// Updating linked client posture check to policy where source has peers but destination does not,
|
||||
// should trigger account peers update and send peer update
|
||||
t.Run("updating linked client posture check to policy where source has peers but destination does not", func(t *testing.T) {
|
||||
policy = Policy{
|
||||
ID: "policyB",
|
||||
Enabled: true,
|
||||
Rules: []*PolicyRule{
|
||||
{
|
||||
Enabled: true,
|
||||
Sources: []string{"groupA"},
|
||||
Destinations: []string{"groupB"},
|
||||
Bidirectional: true,
|
||||
Action: PolicyTrafficActionAccept,
|
||||
},
|
||||
},
|
||||
SourcePostureChecks: []string{postureCheck.ID},
|
||||
}
|
||||
err = manager.SavePolicy(context.Background(), account.Id, userID, &policy, true)
|
||||
assert.NoError(t, err)
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, updMsg)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
|
||||
@@ -1807,7 +1807,7 @@ func TestRouteAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1ID)
|
||||
t.Cleanup(func() {
|
||||
manager.peersUpdateManager.CloseChannel(context.Background(), peer1ID, updMsg.sessionID)
|
||||
manager.peersUpdateManager.CloseChannel(context.Background(), peer1ID)
|
||||
})
|
||||
|
||||
// Creating a route with no routing peer and no peers in PeerGroups or Groups should not update account peers and not send peer update
|
||||
@@ -1827,7 +1827,7 @@ func TestRouteAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldNotReceiveUpdate(t, updMsg.channel)
|
||||
peerShouldNotReceiveUpdate(t, updMsg)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
@@ -1863,7 +1863,7 @@ func TestRouteAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, updMsg.channel)
|
||||
peerShouldReceiveUpdate(t, updMsg)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
@@ -1899,7 +1899,7 @@ func TestRouteAccountPeersUpdate(t *testing.T) {
|
||||
t.Run("creating route with a routing peer", func(t *testing.T) {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, updMsg.channel)
|
||||
peerShouldReceiveUpdate(t, updMsg)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
@@ -1924,7 +1924,7 @@ func TestRouteAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, updMsg.channel)
|
||||
peerShouldReceiveUpdate(t, updMsg)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
@@ -1938,11 +1938,31 @@ func TestRouteAccountPeersUpdate(t *testing.T) {
|
||||
}
|
||||
})
|
||||
|
||||
// Updating unchanged route should update account peers and not send peer update
|
||||
t.Run("updating unchanged route", func(t *testing.T) {
|
||||
baseRoute.Groups = []string{routeGroup1, routeGroup2}
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldNotReceiveUpdate(t, updMsg)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
err := manager.SaveRoute(context.Background(), account.Id, userID, &baseRoute)
|
||||
require.NoError(t, err)
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
|
||||
}
|
||||
})
|
||||
|
||||
// Deleting the route should update account peers and send peer update
|
||||
t.Run("deleting route", func(t *testing.T) {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, updMsg.channel)
|
||||
peerShouldReceiveUpdate(t, updMsg)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
@@ -1978,7 +1998,7 @@ func TestRouteAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, updMsg.channel)
|
||||
peerShouldReceiveUpdate(t, updMsg)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
@@ -2018,7 +2038,7 @@ func TestRouteAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, updMsg.channel)
|
||||
peerShouldReceiveUpdate(t, updMsg)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
|
||||
@@ -408,7 +408,7 @@ func TestSetupKeyAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
|
||||
t.Cleanup(func() {
|
||||
manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID, updMsg.sessionID)
|
||||
manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID)
|
||||
})
|
||||
|
||||
var setupKey *SetupKey
|
||||
@@ -417,7 +417,7 @@ func TestSetupKeyAccountPeersUpdate(t *testing.T) {
|
||||
t.Run("creating setup key", func(t *testing.T) {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldNotReceiveUpdate(t, updMsg.channel)
|
||||
peerShouldNotReceiveUpdate(t, updMsg)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
@@ -435,7 +435,7 @@ func TestSetupKeyAccountPeersUpdate(t *testing.T) {
|
||||
t.Run("saving setup key", func(t *testing.T) {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldNotReceiveUpdate(t, updMsg.channel)
|
||||
peerShouldNotReceiveUpdate(t, updMsg)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
|
||||
@@ -292,8 +292,6 @@ func (s *SqlStore) GetInstallationID() string {
|
||||
}
|
||||
|
||||
func (s *SqlStore) SavePeer(ctx context.Context, accountID string, peer *nbpeer.Peer) error {
|
||||
startTime := time.Now()
|
||||
|
||||
// To maintain data integrity, we create a copy of the peer's to prevent unintended updates to other fields.
|
||||
peerCopy := peer.Copy()
|
||||
peerCopy.AccountID = accountID
|
||||
@@ -319,9 +317,6 @@ func (s *SqlStore) SavePeer(ctx context.Context, accountID string, peer *nbpeer.
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
if errors.Is(err, context.Canceled) {
|
||||
return status.NewStoreContextCanceledError(time.Since(startTime))
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -329,8 +324,6 @@ func (s *SqlStore) SavePeer(ctx context.Context, accountID string, peer *nbpeer.
|
||||
}
|
||||
|
||||
func (s *SqlStore) UpdateAccountDomainAttributes(ctx context.Context, accountID string, domain string, category string, isPrimaryDomain bool) error {
|
||||
startTime := time.Now()
|
||||
|
||||
accountCopy := Account{
|
||||
Domain: domain,
|
||||
DomainCategory: category,
|
||||
@@ -343,9 +336,6 @@ func (s *SqlStore) UpdateAccountDomainAttributes(ctx context.Context, accountID
|
||||
Where(idQueryCondition, accountID).
|
||||
Updates(&accountCopy)
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, context.Canceled) {
|
||||
return status.NewStoreContextCanceledError(time.Since(startTime))
|
||||
}
|
||||
return result.Error
|
||||
}
|
||||
|
||||
@@ -357,8 +347,6 @@ func (s *SqlStore) UpdateAccountDomainAttributes(ctx context.Context, accountID
|
||||
}
|
||||
|
||||
func (s *SqlStore) SavePeerStatus(accountID, peerID string, peerStatus nbpeer.PeerStatus) error {
|
||||
startTime := time.Now()
|
||||
|
||||
var peerCopy nbpeer.Peer
|
||||
peerCopy.Status = &peerStatus
|
||||
|
||||
@@ -371,9 +359,6 @@ func (s *SqlStore) SavePeerStatus(accountID, peerID string, peerStatus nbpeer.Pe
|
||||
Where(accountAndIDQueryCondition, accountID, peerID).
|
||||
Updates(&peerCopy)
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, context.Canceled) {
|
||||
return status.NewStoreContextCanceledError(time.Since(startTime))
|
||||
}
|
||||
return result.Error
|
||||
}
|
||||
|
||||
@@ -385,8 +370,6 @@ func (s *SqlStore) SavePeerStatus(accountID, peerID string, peerStatus nbpeer.Pe
|
||||
}
|
||||
|
||||
func (s *SqlStore) SavePeerLocation(accountID string, peerWithLocation *nbpeer.Peer) error {
|
||||
startTime := time.Now()
|
||||
|
||||
// To maintain data integrity, we create a copy of the peer's location to prevent unintended updates to other fields.
|
||||
var peerCopy nbpeer.Peer
|
||||
// Since the location field has been migrated to JSON serialization,
|
||||
@@ -398,9 +381,6 @@ func (s *SqlStore) SavePeerLocation(accountID string, peerWithLocation *nbpeer.P
|
||||
Updates(peerCopy)
|
||||
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, context.Canceled) {
|
||||
return status.NewStoreContextCanceledError(time.Since(startTime))
|
||||
}
|
||||
return result.Error
|
||||
}
|
||||
|
||||
@@ -414,8 +394,6 @@ func (s *SqlStore) SavePeerLocation(accountID string, peerWithLocation *nbpeer.P
|
||||
// SaveUsers saves the given list of users to the database.
|
||||
// It updates existing users if a conflict occurs.
|
||||
func (s *SqlStore) SaveUsers(accountID string, users map[string]*User) error {
|
||||
startTime := time.Now()
|
||||
|
||||
usersToSave := make([]User, 0, len(users))
|
||||
for _, user := range users {
|
||||
user.AccountID = accountID
|
||||
@@ -425,28 +403,15 @@ func (s *SqlStore) SaveUsers(accountID string, users map[string]*User) error {
|
||||
}
|
||||
usersToSave = append(usersToSave, *user)
|
||||
}
|
||||
err := s.db.Session(&gorm.Session{FullSaveAssociations: true}).
|
||||
return s.db.Session(&gorm.Session{FullSaveAssociations: true}).
|
||||
Clauses(clause.OnConflict{UpdateAll: true}).
|
||||
Create(&usersToSave).Error
|
||||
if err != nil {
|
||||
if errors.Is(err, context.Canceled) {
|
||||
return status.NewStoreContextCanceledError(time.Since(startTime))
|
||||
}
|
||||
return status.Errorf(status.Internal, "failed to save users to store: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// SaveUser saves the given user to the database.
|
||||
func (s *SqlStore) SaveUser(ctx context.Context, lockStrength LockingStrength, user *User) error {
|
||||
startTime := time.Now()
|
||||
|
||||
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Save(user)
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, context.Canceled) {
|
||||
return status.NewStoreContextCanceledError(time.Since(startTime))
|
||||
}
|
||||
return status.Errorf(status.Internal, "failed to save user to store: %v", result.Error)
|
||||
}
|
||||
return nil
|
||||
@@ -454,17 +419,12 @@ func (s *SqlStore) SaveUser(ctx context.Context, lockStrength LockingStrength, u
|
||||
|
||||
// SaveGroups saves the given list of groups to the database.
|
||||
func (s *SqlStore) SaveGroups(ctx context.Context, lockStrength LockingStrength, groups []*nbgroup.Group) error {
|
||||
startTime := time.Now()
|
||||
|
||||
if len(groups) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Save(&groups)
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, context.Canceled) {
|
||||
return status.NewStoreContextCanceledError(time.Since(startTime))
|
||||
}
|
||||
return status.Errorf(status.Internal, "failed to save groups to store: %v", result.Error)
|
||||
}
|
||||
return nil
|
||||
@@ -491,8 +451,6 @@ func (s *SqlStore) GetAccountByPrivateDomain(ctx context.Context, domain string)
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetAccountIDByPrivateDomain(ctx context.Context, lockStrength LockingStrength, domain string) (string, error) {
|
||||
startTime := time.Now()
|
||||
|
||||
var accountID string
|
||||
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}).Select("id").
|
||||
Where("domain = ? and is_domain_primary_account = ? and domain_category = ?",
|
||||
@@ -502,9 +460,6 @@ func (s *SqlStore) GetAccountIDByPrivateDomain(ctx context.Context, lockStrength
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return "", status.Errorf(status.NotFound, "account not found: provided domain is not registered or is not private")
|
||||
}
|
||||
if errors.Is(result.Error, context.Canceled) {
|
||||
return "", status.NewStoreContextCanceledError(time.Since(startTime))
|
||||
}
|
||||
log.WithContext(ctx).Errorf("error when getting account from the store: %s", result.Error)
|
||||
return "", status.NewGetAccountFromStoreError(result.Error)
|
||||
}
|
||||
@@ -513,17 +468,12 @@ func (s *SqlStore) GetAccountIDByPrivateDomain(ctx context.Context, lockStrength
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetAccountBySetupKey(ctx context.Context, setupKey string) (*Account, error) {
|
||||
startTime := time.Now()
|
||||
|
||||
var key SetupKey
|
||||
result := s.db.WithContext(ctx).Select("account_id").First(&key, keyQueryCondition, setupKey)
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return nil, status.Errorf(status.NotFound, "account not found: index lookup failed")
|
||||
}
|
||||
if errors.Is(result.Error, context.Canceled) {
|
||||
return nil, status.NewStoreContextCanceledError(time.Since(startTime))
|
||||
}
|
||||
return nil, status.NewSetupKeyNotFoundError(result.Error)
|
||||
}
|
||||
|
||||
@@ -535,17 +485,12 @@ func (s *SqlStore) GetAccountBySetupKey(ctx context.Context, setupKey string) (*
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetTokenIDByHashedToken(ctx context.Context, hashedToken string) (string, error) {
|
||||
startTime := time.Now()
|
||||
|
||||
var token PersonalAccessToken
|
||||
result := s.db.First(&token, "hashed_token = ?", hashedToken)
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return "", status.Errorf(status.NotFound, "account not found: index lookup failed")
|
||||
}
|
||||
if errors.Is(result.Error, context.Canceled) {
|
||||
return "", status.NewStoreContextCanceledError(time.Since(startTime))
|
||||
}
|
||||
log.WithContext(ctx).Errorf("error when getting token from the store: %s", result.Error)
|
||||
return "", status.NewGetAccountFromStoreError(result.Error)
|
||||
}
|
||||
@@ -554,17 +499,12 @@ func (s *SqlStore) GetTokenIDByHashedToken(ctx context.Context, hashedToken stri
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetUserByTokenID(ctx context.Context, tokenID string) (*User, error) {
|
||||
startTime := time.Now()
|
||||
|
||||
var token PersonalAccessToken
|
||||
result := s.db.First(&token, idQueryCondition, tokenID)
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return nil, status.Errorf(status.NotFound, "account not found: index lookup failed")
|
||||
}
|
||||
if errors.Is(result.Error, context.Canceled) {
|
||||
return nil, status.NewStoreContextCanceledError(time.Since(startTime))
|
||||
}
|
||||
log.WithContext(ctx).Errorf("error when getting token from the store: %s", result.Error)
|
||||
return nil, status.NewGetAccountFromStoreError(result.Error)
|
||||
}
|
||||
@@ -588,8 +528,6 @@ func (s *SqlStore) GetUserByTokenID(ctx context.Context, tokenID string) (*User,
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*User, error) {
|
||||
startTime := time.Now()
|
||||
|
||||
var user User
|
||||
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).
|
||||
Preload(clause.Associations).First(&user, idQueryCondition, userID)
|
||||
@@ -597,9 +535,6 @@ func (s *SqlStore) GetUserByUserID(ctx context.Context, lockStrength LockingStre
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return nil, status.NewUserNotFoundError(userID)
|
||||
}
|
||||
if errors.Is(result.Error, context.Canceled) {
|
||||
return nil, status.NewStoreContextCanceledError(time.Since(startTime))
|
||||
}
|
||||
return nil, status.NewGetUserFromStoreError()
|
||||
}
|
||||
|
||||
@@ -607,17 +542,12 @@ func (s *SqlStore) GetUserByUserID(ctx context.Context, lockStrength LockingStre
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetAccountUsers(ctx context.Context, accountID string) ([]*User, error) {
|
||||
startTime := time.Now()
|
||||
|
||||
var users []*User
|
||||
result := s.db.Find(&users, accountIDCondition, accountID)
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return nil, status.Errorf(status.NotFound, "accountID not found: index lookup failed")
|
||||
}
|
||||
if errors.Is(result.Error, context.Canceled) {
|
||||
return nil, status.NewStoreContextCanceledError(time.Since(startTime))
|
||||
}
|
||||
log.WithContext(ctx).Errorf("error when getting users from the store: %s", result.Error)
|
||||
return nil, status.Errorf(status.Internal, "issue getting users from store")
|
||||
}
|
||||
@@ -626,17 +556,12 @@ func (s *SqlStore) GetAccountUsers(ctx context.Context, accountID string) ([]*Us
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetAccountGroups(ctx context.Context, accountID string) ([]*nbgroup.Group, error) {
|
||||
startTime := time.Now()
|
||||
|
||||
var groups []*nbgroup.Group
|
||||
result := s.db.Find(&groups, accountIDCondition, accountID)
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return nil, status.Errorf(status.NotFound, "accountID not found: index lookup failed")
|
||||
}
|
||||
if errors.Is(result.Error, context.Canceled) {
|
||||
return nil, status.NewStoreContextCanceledError(time.Since(startTime))
|
||||
}
|
||||
log.WithContext(ctx).Errorf("error when getting groups from the store: %s", result.Error)
|
||||
return nil, status.Errorf(status.Internal, "issue getting groups from store")
|
||||
}
|
||||
@@ -736,17 +661,12 @@ func (s *SqlStore) GetAccount(ctx context.Context, accountID string) (*Account,
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetAccountByUser(ctx context.Context, userID string) (*Account, error) {
|
||||
startTime := time.Now()
|
||||
|
||||
var user User
|
||||
result := s.db.WithContext(ctx).Select("account_id").First(&user, idQueryCondition, userID)
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return nil, status.Errorf(status.NotFound, "account not found: index lookup failed")
|
||||
}
|
||||
if errors.Is(result.Error, context.Canceled) {
|
||||
return nil, status.NewStoreContextCanceledError(time.Since(startTime))
|
||||
}
|
||||
return nil, status.NewGetAccountFromStoreError(result.Error)
|
||||
}
|
||||
|
||||
@@ -758,17 +678,12 @@ func (s *SqlStore) GetAccountByUser(ctx context.Context, userID string) (*Accoun
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetAccountByPeerID(ctx context.Context, peerID string) (*Account, error) {
|
||||
startTime := time.Now()
|
||||
|
||||
var peer nbpeer.Peer
|
||||
result := s.db.WithContext(ctx).Select("account_id").First(&peer, idQueryCondition, peerID)
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return nil, status.Errorf(status.NotFound, "account not found: index lookup failed")
|
||||
}
|
||||
if errors.Is(result.Error, context.Canceled) {
|
||||
return nil, status.NewStoreContextCanceledError(time.Since(startTime))
|
||||
}
|
||||
return nil, status.NewGetAccountFromStoreError(result.Error)
|
||||
}
|
||||
|
||||
@@ -780,17 +695,13 @@ func (s *SqlStore) GetAccountByPeerID(ctx context.Context, peerID string) (*Acco
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetAccountByPeerPubKey(ctx context.Context, peerKey string) (*Account, error) {
|
||||
startTime := time.Now()
|
||||
|
||||
var peer nbpeer.Peer
|
||||
|
||||
result := s.db.WithContext(ctx).Select("account_id").First(&peer, keyQueryCondition, peerKey)
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return nil, status.Errorf(status.NotFound, "account not found: index lookup failed")
|
||||
}
|
||||
if errors.Is(result.Error, context.Canceled) {
|
||||
return nil, status.NewStoreContextCanceledError(time.Since(startTime))
|
||||
}
|
||||
return nil, status.NewGetAccountFromStoreError(result.Error)
|
||||
}
|
||||
|
||||
@@ -802,8 +713,6 @@ func (s *SqlStore) GetAccountByPeerPubKey(ctx context.Context, peerKey string) (
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetAccountIDByPeerPubKey(ctx context.Context, peerKey string) (string, error) {
|
||||
startTime := time.Now()
|
||||
|
||||
var peer nbpeer.Peer
|
||||
var accountID string
|
||||
result := s.db.WithContext(ctx).Model(&peer).Select("account_id").Where(keyQueryCondition, peerKey).First(&accountID)
|
||||
@@ -811,9 +720,6 @@ func (s *SqlStore) GetAccountIDByPeerPubKey(ctx context.Context, peerKey string)
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return "", status.Errorf(status.NotFound, "account not found: index lookup failed")
|
||||
}
|
||||
if errors.Is(result.Error, context.Canceled) {
|
||||
return "", status.NewStoreContextCanceledError(time.Since(startTime))
|
||||
}
|
||||
return "", status.NewGetAccountFromStoreError(result.Error)
|
||||
}
|
||||
|
||||
@@ -821,17 +727,12 @@ func (s *SqlStore) GetAccountIDByPeerPubKey(ctx context.Context, peerKey string)
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetAccountIDByUserID(userID string) (string, error) {
|
||||
startTime := time.Now()
|
||||
|
||||
var accountID string
|
||||
result := s.db.Model(&User{}).Select("account_id").Where(idQueryCondition, userID).First(&accountID)
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return "", status.Errorf(status.NotFound, "account not found: index lookup failed")
|
||||
}
|
||||
if errors.Is(result.Error, context.Canceled) {
|
||||
return "", status.NewStoreContextCanceledError(time.Since(startTime))
|
||||
}
|
||||
return "", status.NewGetAccountFromStoreError(result.Error)
|
||||
}
|
||||
|
||||
@@ -839,17 +740,12 @@ func (s *SqlStore) GetAccountIDByUserID(userID string) (string, error) {
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetAccountIDBySetupKey(ctx context.Context, setupKey string) (string, error) {
|
||||
startTime := time.Now()
|
||||
|
||||
var accountID string
|
||||
result := s.db.WithContext(ctx).Model(&SetupKey{}).Select("account_id").Where(keyQueryCondition, setupKey).First(&accountID)
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return "", status.Errorf(status.NotFound, "account not found: index lookup failed")
|
||||
}
|
||||
if errors.Is(result.Error, context.Canceled) {
|
||||
return "", status.NewStoreContextCanceledError(time.Since(startTime))
|
||||
}
|
||||
return "", status.NewSetupKeyNotFoundError(result.Error)
|
||||
}
|
||||
|
||||
@@ -861,8 +757,6 @@ func (s *SqlStore) GetAccountIDBySetupKey(ctx context.Context, setupKey string)
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetTakenIPs(ctx context.Context, lockStrength LockingStrength, accountID string) ([]net.IP, error) {
|
||||
startTime := time.Now()
|
||||
|
||||
var ipJSONStrings []string
|
||||
|
||||
// Fetch the IP addresses as JSON strings
|
||||
@@ -873,9 +767,6 @@ func (s *SqlStore) GetTakenIPs(ctx context.Context, lockStrength LockingStrength
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return nil, status.Errorf(status.NotFound, "no peers found for the account")
|
||||
}
|
||||
if errors.Is(result.Error, context.Canceled) {
|
||||
return nil, status.NewStoreContextCanceledError(time.Since(startTime))
|
||||
}
|
||||
return nil, status.Errorf(status.Internal, "issue getting IPs from store: %s", result.Error)
|
||||
}
|
||||
|
||||
@@ -893,9 +784,8 @@ func (s *SqlStore) GetTakenIPs(ctx context.Context, lockStrength LockingStrength
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetPeerLabelsInAccount(ctx context.Context, lockStrength LockingStrength, accountID string) ([]string, error) {
|
||||
startTime := time.Now()
|
||||
|
||||
var labels []string
|
||||
|
||||
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&nbpeer.Peer{}).
|
||||
Where("account_id = ?", accountID).
|
||||
Pluck("dns_label", &labels)
|
||||
@@ -904,9 +794,6 @@ func (s *SqlStore) GetPeerLabelsInAccount(ctx context.Context, lockStrength Lock
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return nil, status.Errorf(status.NotFound, "no peers found for the account")
|
||||
}
|
||||
if errors.Is(result.Error, context.Canceled) {
|
||||
return nil, status.NewStoreContextCanceledError(time.Since(startTime))
|
||||
}
|
||||
log.WithContext(ctx).Errorf("error when getting dns labels from the store: %s", result.Error)
|
||||
return nil, status.Errorf(status.Internal, "issue getting dns labels from store: %s", result.Error)
|
||||
}
|
||||
@@ -915,33 +802,24 @@ func (s *SqlStore) GetPeerLabelsInAccount(ctx context.Context, lockStrength Lock
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetAccountNetwork(ctx context.Context, lockStrength LockingStrength, accountID string) (*Network, error) {
|
||||
startTime := time.Now()
|
||||
|
||||
var accountNetwork AccountNetwork
|
||||
|
||||
if err := s.db.WithContext(ctx).Model(&Account{}).Where(idQueryCondition, accountID).First(&accountNetwork).Error; err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, status.NewAccountNotFoundError(accountID)
|
||||
}
|
||||
if errors.Is(err, context.Canceled) {
|
||||
return nil, status.NewStoreContextCanceledError(time.Since(startTime))
|
||||
}
|
||||
return nil, status.Errorf(status.Internal, "issue getting network from store: %s", err)
|
||||
}
|
||||
return accountNetwork.Network, nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetPeerByPeerPubKey(ctx context.Context, lockStrength LockingStrength, peerKey string) (*nbpeer.Peer, error) {
|
||||
startTime := time.Now()
|
||||
|
||||
var peer nbpeer.Peer
|
||||
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).First(&peer, keyQueryCondition, peerKey)
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return nil, status.Errorf(status.NotFound, "peer not found")
|
||||
}
|
||||
if errors.Is(result.Error, context.Canceled) {
|
||||
return nil, status.NewStoreContextCanceledError(time.Since(startTime))
|
||||
}
|
||||
return nil, status.Errorf(status.Internal, "issue getting peer from store: %s", result.Error)
|
||||
}
|
||||
|
||||
@@ -949,16 +827,11 @@ func (s *SqlStore) GetPeerByPeerPubKey(ctx context.Context, lockStrength Locking
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetAccountSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*Settings, error) {
|
||||
startTime := time.Now()
|
||||
|
||||
var accountSettings AccountSettings
|
||||
if err := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}).Where(idQueryCondition, accountID).First(&accountSettings).Error; err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, status.Errorf(status.NotFound, "settings not found")
|
||||
}
|
||||
if errors.Is(err, context.Canceled) {
|
||||
return nil, status.NewStoreContextCanceledError(time.Since(startTime))
|
||||
}
|
||||
return nil, status.Errorf(status.Internal, "issue getting settings from store: %s", err)
|
||||
}
|
||||
return accountSettings.Settings, nil
|
||||
@@ -966,17 +839,13 @@ func (s *SqlStore) GetAccountSettings(ctx context.Context, lockStrength LockingS
|
||||
|
||||
// SaveUserLastLogin stores the last login time for a user in DB.
|
||||
func (s *SqlStore) SaveUserLastLogin(ctx context.Context, accountID, userID string, lastLogin time.Time) error {
|
||||
startTime := time.Now()
|
||||
|
||||
var user User
|
||||
|
||||
result := s.db.WithContext(ctx).First(&user, accountAndIDQueryCondition, accountID, userID)
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return status.NewUserNotFoundError(userID)
|
||||
}
|
||||
if errors.Is(result.Error, context.Canceled) {
|
||||
return status.NewStoreContextCanceledError(time.Since(startTime))
|
||||
}
|
||||
return status.NewGetUserFromStoreError()
|
||||
}
|
||||
user.LastLogin = lastLogin
|
||||
@@ -985,8 +854,6 @@ func (s *SqlStore) SaveUserLastLogin(ctx context.Context, accountID, userID stri
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetPostureCheckByChecksDefinition(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error) {
|
||||
startTime := time.Now()
|
||||
|
||||
definitionJSON, err := json.Marshal(checks)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -995,9 +862,6 @@ func (s *SqlStore) GetPostureCheckByChecksDefinition(accountID string, checks *p
|
||||
var postureCheck posture.Checks
|
||||
err = s.db.Where("account_id = ? AND checks = ?", accountID, string(definitionJSON)).First(&postureCheck).Error
|
||||
if err != nil {
|
||||
if errors.Is(err, context.Canceled) {
|
||||
return nil, status.NewStoreContextCanceledError(time.Since(startTime))
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -1107,8 +971,6 @@ func NewPostgresqlStoreFromSqlStore(ctx context.Context, sqliteStore *SqlStore,
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetSetupKeyBySecret(ctx context.Context, lockStrength LockingStrength, key string) (*SetupKey, error) {
|
||||
startTime := time.Now()
|
||||
|
||||
var setupKey SetupKey
|
||||
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).
|
||||
First(&setupKey, keyQueryCondition, key)
|
||||
@@ -1116,17 +978,12 @@ func (s *SqlStore) GetSetupKeyBySecret(ctx context.Context, lockStrength Locking
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return nil, status.Errorf(status.NotFound, "setup key not found")
|
||||
}
|
||||
if errors.Is(result.Error, context.Canceled) {
|
||||
return nil, status.NewStoreContextCanceledError(time.Since(startTime))
|
||||
}
|
||||
return nil, status.NewSetupKeyNotFoundError(result.Error)
|
||||
}
|
||||
return &setupKey, nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) IncrementSetupKeyUsage(ctx context.Context, setupKeyID string) error {
|
||||
startTime := time.Now()
|
||||
|
||||
result := s.db.WithContext(ctx).Model(&SetupKey{}).
|
||||
Where(idQueryCondition, setupKeyID).
|
||||
Updates(map[string]interface{}{
|
||||
@@ -1135,9 +992,6 @@ func (s *SqlStore) IncrementSetupKeyUsage(ctx context.Context, setupKeyID string
|
||||
})
|
||||
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, context.Canceled) {
|
||||
return status.NewStoreContextCanceledError(time.Since(startTime))
|
||||
}
|
||||
return status.Errorf(status.Internal, "issue incrementing setup key usage count: %s", result.Error)
|
||||
}
|
||||
|
||||
@@ -1149,17 +1003,13 @@ func (s *SqlStore) IncrementSetupKeyUsage(ctx context.Context, setupKeyID string
|
||||
}
|
||||
|
||||
func (s *SqlStore) AddPeerToAllGroup(ctx context.Context, accountID string, peerID string) error {
|
||||
startTime := time.Now()
|
||||
|
||||
var group nbgroup.Group
|
||||
|
||||
result := s.db.WithContext(ctx).Where("account_id = ? AND name = ?", accountID, "All").First(&group)
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return status.Errorf(status.NotFound, "group 'All' not found for account")
|
||||
}
|
||||
if errors.Is(result.Error, context.Canceled) {
|
||||
return status.NewStoreContextCanceledError(time.Since(startTime))
|
||||
}
|
||||
return status.Errorf(status.Internal, "issue finding group 'All': %s", result.Error)
|
||||
}
|
||||
|
||||
@@ -1172,9 +1022,6 @@ func (s *SqlStore) AddPeerToAllGroup(ctx context.Context, accountID string, peer
|
||||
group.Peers = append(group.Peers, peerID)
|
||||
|
||||
if err := s.db.Save(&group).Error; err != nil {
|
||||
if errors.Is(result.Error, context.Canceled) {
|
||||
return status.NewStoreContextCanceledError(time.Since(startTime))
|
||||
}
|
||||
return status.Errorf(status.Internal, "issue updating group 'All': %s", err)
|
||||
}
|
||||
|
||||
@@ -1182,17 +1029,13 @@ func (s *SqlStore) AddPeerToAllGroup(ctx context.Context, accountID string, peer
|
||||
}
|
||||
|
||||
func (s *SqlStore) AddPeerToGroup(ctx context.Context, accountId string, peerId string, groupID string) error {
|
||||
startTime := time.Now()
|
||||
|
||||
var group nbgroup.Group
|
||||
|
||||
result := s.db.WithContext(ctx).Where(accountAndIDQueryCondition, accountId, groupID).First(&group)
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return status.Errorf(status.NotFound, "group not found for account")
|
||||
}
|
||||
if errors.Is(result.Error, context.Canceled) {
|
||||
return status.NewStoreContextCanceledError(time.Since(startTime))
|
||||
}
|
||||
return status.Errorf(status.Internal, "issue finding group: %s", result.Error)
|
||||
}
|
||||
|
||||
@@ -1205,9 +1048,6 @@ func (s *SqlStore) AddPeerToGroup(ctx context.Context, accountId string, peerId
|
||||
group.Peers = append(group.Peers, peerId)
|
||||
|
||||
if err := s.db.Save(&group).Error; err != nil {
|
||||
if errors.Is(result.Error, context.Canceled) {
|
||||
return status.NewStoreContextCanceledError(time.Since(startTime))
|
||||
}
|
||||
return status.Errorf(status.Internal, "issue updating group: %s", err)
|
||||
}
|
||||
|
||||
@@ -1220,12 +1060,7 @@ func (s *SqlStore) GetUserPeers(ctx context.Context, lockStrength LockingStrengt
|
||||
}
|
||||
|
||||
func (s *SqlStore) AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) error {
|
||||
startTime := time.Now()
|
||||
|
||||
if err := s.db.WithContext(ctx).Create(peer).Error; err != nil {
|
||||
if errors.Is(err, context.Canceled) {
|
||||
return status.NewStoreContextCanceledError(time.Since(startTime))
|
||||
}
|
||||
return status.Errorf(status.Internal, "issue adding peer to account: %s", err)
|
||||
}
|
||||
|
||||
@@ -1233,13 +1068,8 @@ func (s *SqlStore) AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) erro
|
||||
}
|
||||
|
||||
func (s *SqlStore) IncrementNetworkSerial(ctx context.Context, accountId string) error {
|
||||
startTime := time.Now()
|
||||
|
||||
result := s.db.WithContext(ctx).Model(&Account{}).Where(idQueryCondition, accountId).Update("network_serial", gorm.Expr("network_serial + 1"))
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, context.Canceled) {
|
||||
return status.NewStoreContextCanceledError(time.Since(startTime))
|
||||
}
|
||||
return status.Errorf(status.Internal, "issue incrementing network serial count: %s", result.Error)
|
||||
}
|
||||
return nil
|
||||
@@ -1270,18 +1100,14 @@ func (s *SqlStore) GetDB() *gorm.DB {
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetAccountDNSSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*DNSSettings, error) {
|
||||
startTime := time.Now()
|
||||
|
||||
var accountDNSSettings AccountDNSSettings
|
||||
|
||||
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}).
|
||||
First(&accountDNSSettings, idQueryCondition, accountID)
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return nil, status.Errorf(status.NotFound, "dns settings not found")
|
||||
}
|
||||
if errors.Is(result.Error, context.Canceled) {
|
||||
return nil, status.NewStoreContextCanceledError(time.Since(startTime))
|
||||
}
|
||||
return nil, status.Errorf(status.Internal, "failed to get dns settings from store: %v", result.Error)
|
||||
}
|
||||
return &accountDNSSettings.DNSSettings, nil
|
||||
@@ -1289,18 +1115,14 @@ func (s *SqlStore) GetAccountDNSSettings(ctx context.Context, lockStrength Locki
|
||||
|
||||
// AccountExists checks whether an account exists by the given ID.
|
||||
func (s *SqlStore) AccountExists(ctx context.Context, lockStrength LockingStrength, id string) (bool, error) {
|
||||
startTime := time.Now()
|
||||
|
||||
var accountID string
|
||||
|
||||
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}).
|
||||
Select("id").First(&accountID, idQueryCondition, id)
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return false, nil
|
||||
}
|
||||
if errors.Is(result.Error, context.Canceled) {
|
||||
return false, status.NewStoreContextCanceledError(time.Since(startTime))
|
||||
}
|
||||
return false, result.Error
|
||||
}
|
||||
|
||||
@@ -1309,18 +1131,14 @@ func (s *SqlStore) AccountExists(ctx context.Context, lockStrength LockingStreng
|
||||
|
||||
// GetAccountDomainAndCategory retrieves the Domain and DomainCategory fields for an account based on the given accountID.
|
||||
func (s *SqlStore) GetAccountDomainAndCategory(ctx context.Context, lockStrength LockingStrength, accountID string) (string, string, error) {
|
||||
startTime := time.Now()
|
||||
|
||||
var account Account
|
||||
|
||||
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}).Select("domain", "domain_category").
|
||||
Where(idQueryCondition, accountID).First(&account)
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return "", "", status.Errorf(status.NotFound, "account not found")
|
||||
}
|
||||
if errors.Is(result.Error, context.Canceled) {
|
||||
return "", "", status.NewStoreContextCanceledError(time.Since(startTime))
|
||||
}
|
||||
return "", "", status.Errorf(status.Internal, "failed to get domain category from store: %v", result.Error)
|
||||
}
|
||||
|
||||
|
||||
@@ -3,7 +3,6 @@ package status
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -116,11 +115,6 @@ func NewGetUserFromStoreError() error {
|
||||
return Errorf(Internal, "issue getting user from store")
|
||||
}
|
||||
|
||||
// NewStoreContextCanceledError creates a new Error with Internal type for a canceled store context
|
||||
func NewStoreContextCanceledError(duration time.Duration) error {
|
||||
return Errorf(Internal, "store access: context canceled after %v", duration)
|
||||
}
|
||||
|
||||
// NewInvalidKeyIDError creates a new Error with InvalidArgument type for an issue getting a setup key
|
||||
func NewInvalidKeyIDError() error {
|
||||
return Errorf(InvalidArgument, "invalid key ID")
|
||||
|
||||
@@ -18,6 +18,7 @@ type UpdateChannelMetrics struct {
|
||||
getAllConnectedPeersDurationMicro metric.Int64Histogram
|
||||
getAllConnectedPeers metric.Int64Histogram
|
||||
hasChannelDurationMicro metric.Int64Histogram
|
||||
networkMapDiffDurationMicro metric.Int64Histogram
|
||||
ctx context.Context
|
||||
}
|
||||
|
||||
@@ -63,6 +64,11 @@ func NewUpdateChannelMetrics(ctx context.Context, meter metric.Meter) (*UpdateCh
|
||||
return nil, err
|
||||
}
|
||||
|
||||
networkMapDiffDurationMicro, err := meter.Int64Histogram("management.updatechannel.networkmap.diff.duration.micro")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &UpdateChannelMetrics{
|
||||
createChannelDurationMicro: createChannelDurationMicro,
|
||||
closeChannelDurationMicro: closeChannelDurationMicro,
|
||||
@@ -72,6 +78,7 @@ func NewUpdateChannelMetrics(ctx context.Context, meter metric.Meter) (*UpdateCh
|
||||
getAllConnectedPeersDurationMicro: getAllConnectedPeersDurationMicro,
|
||||
getAllConnectedPeers: getAllConnectedPeers,
|
||||
hasChannelDurationMicro: hasChannelDurationMicro,
|
||||
networkMapDiffDurationMicro: networkMapDiffDurationMicro,
|
||||
ctx: ctx,
|
||||
}, nil
|
||||
}
|
||||
@@ -111,3 +118,8 @@ func (metrics *UpdateChannelMetrics) CountGetAllConnectedPeersDuration(duration
|
||||
func (metrics *UpdateChannelMetrics) CountHasChannelDuration(duration time.Duration) {
|
||||
metrics.hasChannelDurationMicro.Record(metrics.ctx, duration.Microseconds())
|
||||
}
|
||||
|
||||
// CountNetworkMapDiffDurationMicro counts the duration of the NetworkMapDiff method
|
||||
func (metrics *UpdateChannelMetrics) CountNetworkMapDiffDurationMicro(duration time.Duration) {
|
||||
metrics.networkMapDiffDurationMicro.Record(metrics.ctx, duration.Microseconds())
|
||||
}
|
||||
|
||||
@@ -104,7 +104,7 @@ func TestTimeBasedAuthSecretsManager_SetupRefresh(t *testing.T) {
|
||||
loop:
|
||||
for timeout := time.After(5 * time.Second); ; {
|
||||
select {
|
||||
case update := <-updateChannel.channel:
|
||||
case update := <-updateChannel:
|
||||
updates = append(updates, update)
|
||||
case <-timeout:
|
||||
break loop
|
||||
|
||||
@@ -2,33 +2,31 @@ package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"runtime/debug"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/r3labs/diff/v3"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/management/proto"
|
||||
"github.com/netbirdio/netbird/management/server/differs"
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
)
|
||||
|
||||
const channelBufferSize = 100
|
||||
const SessionIdForceOverwrite = "FORCE"
|
||||
|
||||
type UpdateMessage struct {
|
||||
Update *proto.SyncResponse
|
||||
NetworkMap *NetworkMap
|
||||
}
|
||||
|
||||
type PeerUpdateChannel struct {
|
||||
peerID string
|
||||
sessionID string
|
||||
channel chan *UpdateMessage
|
||||
}
|
||||
|
||||
type PeersUpdateManager struct {
|
||||
// peerChannels is a map of peerID to the channel used to deliver updates relevant to the peer
|
||||
peerChannels map[string]*PeerUpdateChannel
|
||||
// peerChannels is an update channel indexed by Peer.ID
|
||||
peerChannels map[string]chan *UpdateMessage
|
||||
// peerNetworkMaps is the UpdateMessage indexed by Peer.ID.
|
||||
peerUpdateMessage map[string]*UpdateMessage
|
||||
// channelsMux keeps the mutex to access peerChannels
|
||||
channelsMux *sync.RWMutex
|
||||
// metrics provides method to collect application metrics
|
||||
@@ -38,9 +36,10 @@ type PeersUpdateManager struct {
|
||||
// NewPeersUpdateManager returns a new instance of PeersUpdateManager
|
||||
func NewPeersUpdateManager(metrics telemetry.AppMetrics) *PeersUpdateManager {
|
||||
return &PeersUpdateManager{
|
||||
peerChannels: make(map[string]*PeerUpdateChannel),
|
||||
channelsMux: &sync.RWMutex{},
|
||||
metrics: metrics,
|
||||
peerChannels: make(map[string]chan *UpdateMessage),
|
||||
peerUpdateMessage: make(map[string]*UpdateMessage),
|
||||
channelsMux: &sync.RWMutex{},
|
||||
metrics: metrics,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -49,6 +48,15 @@ func (p *PeersUpdateManager) SendUpdate(ctx context.Context, peerID string, upda
|
||||
start := time.Now()
|
||||
var found, dropped bool
|
||||
|
||||
// skip sending sync update to the peer if there is no change in update message,
|
||||
// it will not check on turn credential refresh as we do not send network map or client posture checks
|
||||
if update.NetworkMap != nil {
|
||||
updated := p.handlePeerMessageUpdate(ctx, peerID, update)
|
||||
if !updated {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
p.channelsMux.Lock()
|
||||
|
||||
defer func() {
|
||||
@@ -58,14 +66,24 @@ func (p *PeersUpdateManager) SendUpdate(ctx context.Context, peerID string, upda
|
||||
}
|
||||
}()
|
||||
|
||||
if peerUpdates, ok := p.peerChannels[peerID]; ok {
|
||||
if update.NetworkMap != nil {
|
||||
lastSentUpdate := p.peerUpdateMessage[peerID]
|
||||
if lastSentUpdate != nil && lastSentUpdate.Update.NetworkMap.GetSerial() > update.Update.NetworkMap.GetSerial() {
|
||||
log.WithContext(ctx).Debugf("peer %s new network map serial: %d not greater than last sent: %d, skip sending update",
|
||||
peerID, update.Update.NetworkMap.GetSerial(), lastSentUpdate.Update.NetworkMap.GetSerial())
|
||||
return
|
||||
}
|
||||
p.peerUpdateMessage[peerID] = update
|
||||
}
|
||||
|
||||
if channel, ok := p.peerChannels[peerID]; ok {
|
||||
found = true
|
||||
select {
|
||||
case peerUpdates.channel <- update:
|
||||
case channel <- update:
|
||||
log.WithContext(ctx).Debugf("update was sent to channel for peer %s", peerID)
|
||||
default:
|
||||
dropped = true
|
||||
log.WithContext(ctx).Warnf("channel for peer %s is %d full or closed", peerID, len(peerUpdates.channel))
|
||||
log.WithContext(ctx).Warnf("channel for peer %s is %d full or closed", peerID, len(channel))
|
||||
}
|
||||
} else {
|
||||
log.WithContext(ctx).Debugf("peer %s has no channel", peerID)
|
||||
@@ -73,7 +91,7 @@ func (p *PeersUpdateManager) SendUpdate(ctx context.Context, peerID string, upda
|
||||
}
|
||||
|
||||
// CreateChannel creates a go channel for a given peer used to deliver updates relevant to the peer.
|
||||
func (p *PeersUpdateManager) CreateChannel(ctx context.Context, peerID string) *PeerUpdateChannel {
|
||||
func (p *PeersUpdateManager) CreateChannel(ctx context.Context, peerID string) chan *UpdateMessage {
|
||||
start := time.Now()
|
||||
|
||||
closed := false
|
||||
@@ -89,39 +107,26 @@ func (p *PeersUpdateManager) CreateChannel(ctx context.Context, peerID string) *
|
||||
if channel, ok := p.peerChannels[peerID]; ok {
|
||||
closed = true
|
||||
delete(p.peerChannels, peerID)
|
||||
close(channel.channel)
|
||||
log.WithContext(ctx).Debugf("overwriting existing channel for peer %s", peerID)
|
||||
close(channel)
|
||||
delete(p.peerUpdateMessage, peerID)
|
||||
}
|
||||
// mbragin: todo shouldn't it be more? or configurable?
|
||||
channel := make(chan *UpdateMessage, channelBufferSize)
|
||||
p.peerChannels[peerID] = channel
|
||||
|
||||
peerUpdateChannel := &PeerUpdateChannel{
|
||||
peerID: peerID,
|
||||
sessionID: uuid.New().String(),
|
||||
// mbragin: todo shouldn't it be more? or configurable?
|
||||
channel: make(chan *UpdateMessage, channelBufferSize),
|
||||
}
|
||||
log.WithContext(ctx).Debugf("opened updates channel for a peer %s", peerID)
|
||||
|
||||
p.peerChannels[peerID] = peerUpdateChannel
|
||||
|
||||
log.WithContext(ctx).Debugf("opened updates channel for a peer %s and session %s", peerID, peerUpdateChannel.sessionID)
|
||||
|
||||
return peerUpdateChannel
|
||||
return channel
|
||||
}
|
||||
|
||||
func (p *PeersUpdateManager) closeChannel(ctx context.Context, peerID string, sessionID string) bool {
|
||||
if peerUpdates, ok := p.peerChannels[peerID]; ok {
|
||||
if peerUpdates.sessionID == sessionID || sessionID == SessionIdForceOverwrite {
|
||||
delete(p.peerChannels, peerID)
|
||||
close(peerUpdates.channel)
|
||||
log.WithContext(ctx).Debugf("closed updates channel of a peer %s and session %s", peerID, sessionID)
|
||||
return true
|
||||
}
|
||||
log.WithContext(ctx).Warnf("tried to close updates channel of a peer %s with session %s, but current session is %s", peerID, sessionID, peerUpdates.sessionID)
|
||||
return false
|
||||
func (p *PeersUpdateManager) closeChannel(ctx context.Context, peerID string) {
|
||||
if channel, ok := p.peerChannels[peerID]; ok {
|
||||
delete(p.peerChannels, peerID)
|
||||
close(channel)
|
||||
delete(p.peerUpdateMessage, peerID)
|
||||
}
|
||||
|
||||
log.WithContext(ctx).Warnf("tried to close updates channel of a peer %s with session %s, but no channel found", peerID, sessionID)
|
||||
|
||||
return true
|
||||
log.WithContext(ctx).Debugf("closed updates channel of a peer %s", peerID)
|
||||
}
|
||||
|
||||
// CloseChannels closes updates channel for each given peer
|
||||
@@ -137,12 +142,12 @@ func (p *PeersUpdateManager) CloseChannels(ctx context.Context, peerIDs []string
|
||||
}()
|
||||
|
||||
for _, id := range peerIDs {
|
||||
p.closeChannel(ctx, id, SessionIdForceOverwrite)
|
||||
p.closeChannel(ctx, id)
|
||||
}
|
||||
}
|
||||
|
||||
// CloseChannel closes updates channel of a given peer
|
||||
func (p *PeersUpdateManager) CloseChannel(ctx context.Context, peerID string, sessionID string) bool {
|
||||
func (p *PeersUpdateManager) CloseChannel(ctx context.Context, peerID string) {
|
||||
start := time.Now()
|
||||
|
||||
p.channelsMux.Lock()
|
||||
@@ -153,7 +158,7 @@ func (p *PeersUpdateManager) CloseChannel(ctx context.Context, peerID string, se
|
||||
}
|
||||
}()
|
||||
|
||||
return p.closeChannel(ctx, peerID, sessionID)
|
||||
p.closeChannel(ctx, peerID)
|
||||
}
|
||||
|
||||
// GetAllConnectedPeers returns a copy of the connected peers map
|
||||
@@ -195,3 +200,79 @@ func (p *PeersUpdateManager) HasChannel(peerID string) bool {
|
||||
|
||||
return ok
|
||||
}
|
||||
|
||||
// handlePeerMessageUpdate checks if the update message for a peer is new and should be sent.
|
||||
func (p *PeersUpdateManager) handlePeerMessageUpdate(ctx context.Context, peerID string, update *UpdateMessage) bool {
|
||||
p.channelsMux.RLock()
|
||||
lastSentUpdate := p.peerUpdateMessage[peerID]
|
||||
p.channelsMux.RUnlock()
|
||||
|
||||
if lastSentUpdate != nil {
|
||||
updated, err := isNewPeerUpdateMessage(ctx, lastSentUpdate, update, p.metrics)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("error checking for SyncResponse updates: %v", err)
|
||||
return true
|
||||
}
|
||||
if !updated {
|
||||
log.WithContext(ctx).Debugf("peer %s network map is not updated, skip sending update", peerID)
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// isNewPeerUpdateMessage checks if the given current update message is a new update that should be sent.
|
||||
func isNewPeerUpdateMessage(ctx context.Context, lastSentUpdate, currUpdateToSend *UpdateMessage, metric telemetry.AppMetrics) (isNew bool, err error) {
|
||||
startTime := time.Now()
|
||||
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
log.WithContext(ctx).Panicf("comparing peer update messages. Trace: %s", debug.Stack())
|
||||
isNew, err = true, nil
|
||||
}
|
||||
}()
|
||||
|
||||
if lastSentUpdate.Update.NetworkMap.GetSerial() > currUpdateToSend.Update.NetworkMap.GetSerial() {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
differ, err := diff.NewDiffer(
|
||||
diff.CustomValueDiffers(&differs.NetIPAddr{}),
|
||||
diff.CustomValueDiffers(&differs.NetIPPrefix{}),
|
||||
)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("failed to create differ: %v", err)
|
||||
}
|
||||
|
||||
lastSentFiles := getChecksFiles(lastSentUpdate.Update.Checks)
|
||||
currFiles := getChecksFiles(currUpdateToSend.Update.Checks)
|
||||
|
||||
changelog, err := differ.Diff(lastSentFiles, currFiles)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("failed to diff checks: %v", err)
|
||||
}
|
||||
if len(changelog) > 0 {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
changelog, err = differ.Diff(lastSentUpdate.NetworkMap, currUpdateToSend.NetworkMap)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("failed to diff network map: %v", err)
|
||||
}
|
||||
|
||||
if metric != nil {
|
||||
metric.UpdateChannelMetrics().CountNetworkMapDiffDurationMicro(time.Since(startTime))
|
||||
}
|
||||
|
||||
return len(changelog) > 0, nil
|
||||
}
|
||||
|
||||
// getChecksFiles returns a list of files from the given checks.
|
||||
func getChecksFiles(checks []*proto.Checks) []string {
|
||||
files := make([]string, 0, len(checks))
|
||||
for _, check := range checks {
|
||||
files = append(files, check.GetFiles()...)
|
||||
}
|
||||
return files
|
||||
}
|
||||
|
||||
@@ -2,12 +2,21 @@ package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"net/netip"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/management/domain"
|
||||
"github.com/netbirdio/netbird/management/proto"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/posture"
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
nbroute "github.com/netbirdio/netbird/route"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
)
|
||||
|
||||
// var peersUpdater *PeersUpdateManager
|
||||
@@ -15,7 +24,7 @@ import (
|
||||
func TestCreateChannel(t *testing.T) {
|
||||
peer := "test-create"
|
||||
peersUpdater := NewPeersUpdateManager(nil)
|
||||
defer peersUpdater.CloseChannel(context.Background(), peer, "sessionID")
|
||||
defer peersUpdater.CloseChannel(context.Background(), peer)
|
||||
|
||||
_ = peersUpdater.CreateChannel(context.Background(), peer)
|
||||
if _, ok := peersUpdater.peerChannels[peer]; !ok {
|
||||
@@ -37,7 +46,7 @@ func TestSendUpdate(t *testing.T) {
|
||||
}
|
||||
peersUpdater.SendUpdate(context.Background(), peer, update1)
|
||||
select {
|
||||
case <-peersUpdater.peerChannels[peer].channel:
|
||||
case <-peersUpdater.peerChannels[peer]:
|
||||
default:
|
||||
t.Error("Update wasn't send")
|
||||
}
|
||||
@@ -58,7 +67,7 @@ func TestSendUpdate(t *testing.T) {
|
||||
select {
|
||||
case <-timeout:
|
||||
t.Error("timed out reading previously sent updates")
|
||||
case updateReader := <-peersUpdater.peerChannels[peer].channel:
|
||||
case updateReader := <-peersUpdater.peerChannels[peer]:
|
||||
if updateReader.Update.NetworkMap.Serial == update2.Update.NetworkMap.Serial {
|
||||
t.Error("got the update that shouldn't have been sent")
|
||||
}
|
||||
@@ -67,50 +76,486 @@ func TestSendUpdate(t *testing.T) {
|
||||
|
||||
}
|
||||
|
||||
func TestCloseChannel_WithCorrectSessionID(t *testing.T) {
|
||||
func TestCloseChannel(t *testing.T) {
|
||||
peer := "test-close"
|
||||
peersUpdater := NewPeersUpdateManager(nil)
|
||||
peerUpdates := peersUpdater.CreateChannel(context.Background(), peer)
|
||||
_ = peersUpdater.CreateChannel(context.Background(), peer)
|
||||
if _, ok := peersUpdater.peerChannels[peer]; !ok {
|
||||
t.Error("Error creating the channel")
|
||||
}
|
||||
|
||||
updateDB := peersUpdater.CloseChannel(context.Background(), peer, peerUpdates.sessionID)
|
||||
peersUpdater.CloseChannel(context.Background(), peer)
|
||||
if _, ok := peersUpdater.peerChannels[peer]; ok {
|
||||
t.Error("Error closing the channel")
|
||||
}
|
||||
|
||||
assert.Equal(t, true, updateDB)
|
||||
}
|
||||
|
||||
func TestCloseChannel_WithWrongSessionID(t *testing.T) {
|
||||
peer := "test-close"
|
||||
peersUpdater := NewPeersUpdateManager(nil)
|
||||
peersUpdater.CreateChannel(context.Background(), peer)
|
||||
if _, ok := peersUpdater.peerChannels[peer]; !ok {
|
||||
t.Error("Error creating the channel")
|
||||
func TestHandlePeerMessageUpdate(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
peerID string
|
||||
existingUpdate *UpdateMessage
|
||||
newUpdate *UpdateMessage
|
||||
expectedResult bool
|
||||
}{
|
||||
{
|
||||
name: "update message with turn credentials update",
|
||||
peerID: "peer",
|
||||
newUpdate: &UpdateMessage{
|
||||
Update: &proto.SyncResponse{
|
||||
WiretrusteeConfig: &proto.WiretrusteeConfig{},
|
||||
},
|
||||
},
|
||||
expectedResult: true,
|
||||
},
|
||||
{
|
||||
name: "update message for peer without existing update",
|
||||
peerID: "peer1",
|
||||
newUpdate: &UpdateMessage{
|
||||
Update: &proto.SyncResponse{
|
||||
NetworkMap: &proto.NetworkMap{Serial: 1},
|
||||
},
|
||||
NetworkMap: &NetworkMap{Network: &Network{Serial: 2}},
|
||||
},
|
||||
expectedResult: true,
|
||||
},
|
||||
{
|
||||
name: "update message with no changes in update",
|
||||
peerID: "peer2",
|
||||
existingUpdate: &UpdateMessage{
|
||||
Update: &proto.SyncResponse{
|
||||
NetworkMap: &proto.NetworkMap{Serial: 1},
|
||||
},
|
||||
NetworkMap: &NetworkMap{Network: &Network{Serial: 1}},
|
||||
},
|
||||
newUpdate: &UpdateMessage{
|
||||
Update: &proto.SyncResponse{
|
||||
NetworkMap: &proto.NetworkMap{Serial: 1},
|
||||
},
|
||||
NetworkMap: &NetworkMap{Network: &Network{Serial: 1}},
|
||||
},
|
||||
expectedResult: false,
|
||||
},
|
||||
{
|
||||
name: "update message with changes in checks",
|
||||
peerID: "peer3",
|
||||
existingUpdate: &UpdateMessage{
|
||||
Update: &proto.SyncResponse{
|
||||
NetworkMap: &proto.NetworkMap{Serial: 1},
|
||||
},
|
||||
NetworkMap: &NetworkMap{Network: &Network{Serial: 1}},
|
||||
},
|
||||
newUpdate: &UpdateMessage{
|
||||
Update: &proto.SyncResponse{
|
||||
NetworkMap: &proto.NetworkMap{Serial: 2},
|
||||
Checks: []*proto.Checks{
|
||||
{
|
||||
Files: []string{"/usr/bin/netbird"},
|
||||
},
|
||||
},
|
||||
},
|
||||
NetworkMap: &NetworkMap{Network: &Network{Serial: 2}},
|
||||
},
|
||||
expectedResult: true,
|
||||
},
|
||||
{
|
||||
name: "update message with lower serial number",
|
||||
peerID: "peer4",
|
||||
existingUpdate: &UpdateMessage{
|
||||
Update: &proto.SyncResponse{
|
||||
NetworkMap: &proto.NetworkMap{Serial: 2},
|
||||
},
|
||||
NetworkMap: &NetworkMap{Network: &Network{Serial: 2}},
|
||||
},
|
||||
newUpdate: &UpdateMessage{
|
||||
Update: &proto.SyncResponse{
|
||||
NetworkMap: &proto.NetworkMap{Serial: 1},
|
||||
},
|
||||
NetworkMap: &NetworkMap{Network: &Network{Serial: 1}},
|
||||
},
|
||||
expectedResult: false,
|
||||
},
|
||||
}
|
||||
|
||||
updateDB := peersUpdater.CloseChannel(context.Background(), peer, "wrongSessionID")
|
||||
if _, ok := peersUpdater.peerChannels[peer]; !ok {
|
||||
t.Error("Should not close channel with wrong session id")
|
||||
}
|
||||
for _, tt := range tests {
|
||||
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
p := NewPeersUpdateManager(metrics)
|
||||
ctx := context.Background()
|
||||
|
||||
assert.Equal(t, false, updateDB)
|
||||
if tt.existingUpdate != nil {
|
||||
p.peerUpdateMessage[tt.peerID] = tt.existingUpdate
|
||||
}
|
||||
|
||||
result := p.handlePeerMessageUpdate(ctx, tt.peerID, tt.newUpdate)
|
||||
assert.Equal(t, tt.expectedResult, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCloseChannel_WithForceOverwrite(t *testing.T) {
|
||||
peer := "test-close"
|
||||
peersUpdater := NewPeersUpdateManager(nil)
|
||||
peersUpdater.CreateChannel(context.Background(), peer)
|
||||
if _, ok := peersUpdater.peerChannels[peer]; !ok {
|
||||
t.Error("Error creating the channel")
|
||||
}
|
||||
func TestIsNewPeerUpdateMessage(t *testing.T) {
|
||||
t.Run("Unchanged value", func(t *testing.T) {
|
||||
newUpdateMessage1 := createMockUpdateMessage(t)
|
||||
newUpdateMessage2 := createMockUpdateMessage(t)
|
||||
|
||||
updateDB := peersUpdater.CloseChannel(context.Background(), peer, SessionIdForceOverwrite)
|
||||
if _, ok := peersUpdater.peerChannels[peer]; ok {
|
||||
t.Error("Should close channel if forced")
|
||||
}
|
||||
message, err := isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage2, nil)
|
||||
assert.NoError(t, err)
|
||||
assert.False(t, message)
|
||||
})
|
||||
|
||||
t.Run("Unchanged value with serial incremented", func(t *testing.T) {
|
||||
newUpdateMessage1 := createMockUpdateMessage(t)
|
||||
newUpdateMessage2 := createMockUpdateMessage(t)
|
||||
|
||||
newUpdateMessage2.Update.NetworkMap.Serial++
|
||||
|
||||
message, err := isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage2, nil)
|
||||
assert.NoError(t, err)
|
||||
assert.False(t, message)
|
||||
})
|
||||
|
||||
t.Run("Updating routes network", func(t *testing.T) {
|
||||
newUpdateMessage1 := createMockUpdateMessage(t)
|
||||
newUpdateMessage2 := createMockUpdateMessage(t)
|
||||
|
||||
newUpdateMessage2.NetworkMap.Routes[0].Network = netip.MustParsePrefix("1.1.1.1/32")
|
||||
newUpdateMessage2.Update.NetworkMap.Serial++
|
||||
|
||||
message, err := isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage2, nil)
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, message)
|
||||
|
||||
})
|
||||
|
||||
t.Run("Updating routes groups", func(t *testing.T) {
|
||||
newUpdateMessage1 := createMockUpdateMessage(t)
|
||||
newUpdateMessage2 := createMockUpdateMessage(t)
|
||||
|
||||
newUpdateMessage2.NetworkMap.Routes[0].Groups = []string{"randomGroup1"}
|
||||
newUpdateMessage2.Update.NetworkMap.Serial++
|
||||
|
||||
message, err := isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage2, nil)
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, message)
|
||||
})
|
||||
|
||||
t.Run("Updating network map peers", func(t *testing.T) {
|
||||
newUpdateMessage1 := createMockUpdateMessage(t)
|
||||
newUpdateMessage2 := createMockUpdateMessage(t)
|
||||
|
||||
newPeer := &nbpeer.Peer{
|
||||
IP: net.ParseIP("192.168.1.4"),
|
||||
SSHEnabled: true,
|
||||
Key: "peer4-key",
|
||||
DNSLabel: "peer4",
|
||||
SSHKey: "peer4-ssh-key",
|
||||
}
|
||||
newUpdateMessage2.NetworkMap.Peers = append(newUpdateMessage2.NetworkMap.Peers, newPeer)
|
||||
newUpdateMessage2.Update.NetworkMap.Serial++
|
||||
|
||||
message, err := isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage2, nil)
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, message)
|
||||
})
|
||||
|
||||
t.Run("Updating process check", func(t *testing.T) {
|
||||
newUpdateMessage1 := createMockUpdateMessage(t)
|
||||
|
||||
newUpdateMessage2 := createMockUpdateMessage(t)
|
||||
newUpdateMessage2.Update.NetworkMap.Serial++
|
||||
message, err := isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage2, nil)
|
||||
assert.NoError(t, err)
|
||||
assert.False(t, message)
|
||||
|
||||
newUpdateMessage3 := createMockUpdateMessage(t)
|
||||
newUpdateMessage3.Update.Checks = []*proto.Checks{}
|
||||
newUpdateMessage3.Update.NetworkMap.Serial++
|
||||
message, err = isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage3, nil)
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, message)
|
||||
|
||||
newUpdateMessage4 := createMockUpdateMessage(t)
|
||||
check := &posture.Checks{
|
||||
Checks: posture.ChecksDefinition{
|
||||
ProcessCheck: &posture.ProcessCheck{
|
||||
Processes: []posture.Process{
|
||||
{
|
||||
LinuxPath: "/usr/local/netbird",
|
||||
MacPath: "/usr/bin/netbird",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
newUpdateMessage4.Update.Checks = []*proto.Checks{toProtocolCheck(check)}
|
||||
newUpdateMessage4.Update.NetworkMap.Serial++
|
||||
message, err = isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage4, nil)
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, message)
|
||||
|
||||
newUpdateMessage5 := createMockUpdateMessage(t)
|
||||
check = &posture.Checks{
|
||||
Checks: posture.ChecksDefinition{
|
||||
ProcessCheck: &posture.ProcessCheck{
|
||||
Processes: []posture.Process{
|
||||
{
|
||||
LinuxPath: "/usr/bin/netbird",
|
||||
WindowsPath: "C:\\Program Files\\netbird\\netbird.exe",
|
||||
MacPath: "/usr/local/netbird",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
newUpdateMessage5.Update.Checks = []*proto.Checks{toProtocolCheck(check)}
|
||||
newUpdateMessage5.Update.NetworkMap.Serial++
|
||||
message, err = isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage5, nil)
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, message)
|
||||
})
|
||||
|
||||
t.Run("Updating DNS configuration", func(t *testing.T) {
|
||||
newUpdateMessage1 := createMockUpdateMessage(t)
|
||||
newUpdateMessage2 := createMockUpdateMessage(t)
|
||||
|
||||
newDomain := "newexample.com"
|
||||
newUpdateMessage2.NetworkMap.DNSConfig.NameServerGroups[0].Domains = append(
|
||||
newUpdateMessage2.NetworkMap.DNSConfig.NameServerGroups[0].Domains,
|
||||
newDomain,
|
||||
)
|
||||
newUpdateMessage2.Update.NetworkMap.Serial++
|
||||
|
||||
message, err := isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage2, nil)
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, message)
|
||||
})
|
||||
|
||||
t.Run("Updating peer IP", func(t *testing.T) {
|
||||
newUpdateMessage1 := createMockUpdateMessage(t)
|
||||
newUpdateMessage2 := createMockUpdateMessage(t)
|
||||
|
||||
newUpdateMessage2.NetworkMap.Peers[0].IP = net.ParseIP("192.168.1.10")
|
||||
newUpdateMessage2.Update.NetworkMap.Serial++
|
||||
|
||||
message, err := isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage2, nil)
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, message)
|
||||
})
|
||||
|
||||
t.Run("Updating firewall rule", func(t *testing.T) {
|
||||
newUpdateMessage1 := createMockUpdateMessage(t)
|
||||
newUpdateMessage2 := createMockUpdateMessage(t)
|
||||
|
||||
newUpdateMessage2.NetworkMap.FirewallRules[0].Port = "443"
|
||||
newUpdateMessage2.Update.NetworkMap.Serial++
|
||||
|
||||
message, err := isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage2, nil)
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, message)
|
||||
})
|
||||
|
||||
t.Run("Add new firewall rule", func(t *testing.T) {
|
||||
newUpdateMessage1 := createMockUpdateMessage(t)
|
||||
newUpdateMessage2 := createMockUpdateMessage(t)
|
||||
|
||||
newRule := &FirewallRule{
|
||||
PeerIP: "192.168.1.3",
|
||||
Direction: firewallRuleDirectionOUT,
|
||||
Action: string(PolicyTrafficActionDrop),
|
||||
Protocol: string(PolicyRuleProtocolUDP),
|
||||
Port: "53",
|
||||
}
|
||||
newUpdateMessage2.NetworkMap.FirewallRules = append(newUpdateMessage2.NetworkMap.FirewallRules, newRule)
|
||||
newUpdateMessage2.Update.NetworkMap.Serial++
|
||||
|
||||
message, err := isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage2, nil)
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, message)
|
||||
})
|
||||
|
||||
t.Run("Removing nameserver", func(t *testing.T) {
|
||||
newUpdateMessage1 := createMockUpdateMessage(t)
|
||||
newUpdateMessage2 := createMockUpdateMessage(t)
|
||||
|
||||
newUpdateMessage2.NetworkMap.DNSConfig.NameServerGroups[0].NameServers = make([]nbdns.NameServer, 0)
|
||||
newUpdateMessage2.Update.NetworkMap.Serial++
|
||||
|
||||
message, err := isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage2, nil)
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, message)
|
||||
})
|
||||
|
||||
t.Run("Updating name server IP", func(t *testing.T) {
|
||||
newUpdateMessage1 := createMockUpdateMessage(t)
|
||||
newUpdateMessage2 := createMockUpdateMessage(t)
|
||||
|
||||
newUpdateMessage2.NetworkMap.DNSConfig.NameServerGroups[0].NameServers[0].IP = netip.MustParseAddr("8.8.4.4")
|
||||
newUpdateMessage2.Update.NetworkMap.Serial++
|
||||
|
||||
message, err := isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage2, nil)
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, message)
|
||||
})
|
||||
|
||||
t.Run("Updating custom DNS zone", func(t *testing.T) {
|
||||
newUpdateMessage1 := createMockUpdateMessage(t)
|
||||
newUpdateMessage2 := createMockUpdateMessage(t)
|
||||
|
||||
newUpdateMessage2.NetworkMap.DNSConfig.CustomZones[0].Records[0].RData = "100.64.0.2"
|
||||
newUpdateMessage2.Update.NetworkMap.Serial++
|
||||
|
||||
message, err := isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage2, nil)
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, message)
|
||||
})
|
||||
|
||||
assert.Equal(t, true, updateDB)
|
||||
}
|
||||
|
||||
func createMockUpdateMessage(t *testing.T) *UpdateMessage {
|
||||
t.Helper()
|
||||
|
||||
_, ipNet, err := net.ParseCIDR("192.168.1.0/24")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
domainList, err := domain.FromStringList([]string{"example.com"})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
config := &Config{
|
||||
Signal: &Host{
|
||||
Proto: "https",
|
||||
URI: "signal.uri",
|
||||
Username: "",
|
||||
Password: "",
|
||||
},
|
||||
Stuns: []*Host{{URI: "stun.uri", Proto: UDP}},
|
||||
TURNConfig: &TURNConfig{
|
||||
Turns: []*Host{{URI: "turn.uri", Proto: UDP, Username: "turn-user", Password: "turn-pass"}},
|
||||
},
|
||||
}
|
||||
peer := &nbpeer.Peer{
|
||||
IP: net.ParseIP("192.168.1.1"),
|
||||
SSHEnabled: true,
|
||||
Key: "peer-key",
|
||||
DNSLabel: "peer1",
|
||||
SSHKey: "peer1-ssh-key",
|
||||
}
|
||||
|
||||
secretManager := NewTimeBasedAuthSecretsManager(
|
||||
NewPeersUpdateManager(nil),
|
||||
&TURNConfig{
|
||||
TimeBasedCredentials: false,
|
||||
CredentialsTTL: util.Duration{
|
||||
Duration: defaultDuration,
|
||||
},
|
||||
Secret: "secret",
|
||||
Turns: []*Host{TurnTestHost},
|
||||
},
|
||||
&Relay{
|
||||
Addresses: []string{"localhost:0"},
|
||||
CredentialsTTL: util.Duration{Duration: time.Hour},
|
||||
Secret: "secret",
|
||||
},
|
||||
)
|
||||
|
||||
networkMap := &NetworkMap{
|
||||
Network: &Network{Net: *ipNet, Serial: 1000},
|
||||
Peers: []*nbpeer.Peer{{IP: net.ParseIP("192.168.1.2"), Key: "peer2-key", DNSLabel: "peer2", SSHEnabled: true, SSHKey: "peer2-ssh-key"}},
|
||||
OfflinePeers: []*nbpeer.Peer{{IP: net.ParseIP("192.168.1.3"), Key: "peer3-key", DNSLabel: "peer3", SSHEnabled: true, SSHKey: "peer3-ssh-key"}},
|
||||
Routes: []*nbroute.Route{
|
||||
{
|
||||
ID: "route1",
|
||||
Network: netip.MustParsePrefix("10.0.0.0/24"),
|
||||
KeepRoute: true,
|
||||
NetID: "route1",
|
||||
Peer: "peer1",
|
||||
NetworkType: 1,
|
||||
Masquerade: true,
|
||||
Metric: 9999,
|
||||
Enabled: true,
|
||||
Groups: []string{"test1", "test2"},
|
||||
},
|
||||
{
|
||||
ID: "route2",
|
||||
Domains: domainList,
|
||||
KeepRoute: true,
|
||||
NetID: "route2",
|
||||
Peer: "peer1",
|
||||
NetworkType: 1,
|
||||
Masquerade: true,
|
||||
Metric: 9999,
|
||||
Enabled: true,
|
||||
Groups: []string{"test1", "test2"},
|
||||
},
|
||||
},
|
||||
DNSConfig: nbdns.Config{
|
||||
ServiceEnable: true,
|
||||
NameServerGroups: []*nbdns.NameServerGroup{
|
||||
{
|
||||
NameServers: []nbdns.NameServer{{
|
||||
IP: netip.MustParseAddr("8.8.8.8"),
|
||||
NSType: nbdns.UDPNameServerType,
|
||||
Port: nbdns.DefaultDNSPort,
|
||||
}},
|
||||
Primary: true,
|
||||
Domains: []string{"example.com"},
|
||||
Enabled: true,
|
||||
SearchDomainsEnabled: true,
|
||||
},
|
||||
{
|
||||
ID: "ns1",
|
||||
NameServers: []nbdns.NameServer{{
|
||||
IP: netip.MustParseAddr("1.1.1.1"),
|
||||
NSType: nbdns.UDPNameServerType,
|
||||
Port: nbdns.DefaultDNSPort,
|
||||
}},
|
||||
Groups: []string{"group1"},
|
||||
Primary: true,
|
||||
Domains: []string{"example.com"},
|
||||
Enabled: true,
|
||||
SearchDomainsEnabled: true,
|
||||
},
|
||||
},
|
||||
CustomZones: []nbdns.CustomZone{{Domain: "example.com", Records: []nbdns.SimpleRecord{{Name: "example.com", Type: 1, Class: "IN", TTL: 60, RData: "100.64.0.1"}}}},
|
||||
},
|
||||
FirewallRules: []*FirewallRule{
|
||||
{PeerIP: "192.168.1.2", Direction: firewallRuleDirectionIN, Action: string(PolicyTrafficActionAccept), Protocol: string(PolicyRuleProtocolTCP), Port: "80"},
|
||||
},
|
||||
}
|
||||
dnsName := "example.com"
|
||||
checks := []*posture.Checks{
|
||||
{
|
||||
Checks: posture.ChecksDefinition{
|
||||
ProcessCheck: &posture.ProcessCheck{
|
||||
Processes: []posture.Process{
|
||||
{
|
||||
LinuxPath: "/usr/bin/netbird",
|
||||
WindowsPath: "C:\\Program Files\\netbird\\netbird.exe",
|
||||
MacPath: "/usr/bin/netbird",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
dnsCache := &DNSConfigCache{}
|
||||
|
||||
turnToken, err := secretManager.GenerateTurnToken()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
relayToken, err := secretManager.GenerateRelayToken()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
return &UpdateMessage{
|
||||
Update: toSyncResponse(context.Background(), config, peer, turnToken, relayToken, networkMap, dnsName, checks, dnsCache),
|
||||
NetworkMap: networkMap,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -10,14 +10,13 @@ import (
|
||||
"github.com/eko/gocache/v3/cache"
|
||||
cacheStore "github.com/eko/gocache/v3/store"
|
||||
"github.com/google/go-cmp/cmp"
|
||||
nbgroup "github.com/netbirdio/netbird/management/server/group"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
gocache "github.com/patrickmn/go-cache"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
|
||||
nbgroup "github.com/netbirdio/netbird/management/server/group"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
"github.com/netbirdio/netbird/management/server/idp"
|
||||
"github.com/netbirdio/netbird/management/server/integration_reference"
|
||||
@@ -1298,14 +1297,14 @@ func TestUserAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
|
||||
t.Cleanup(func() {
|
||||
manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID, updMsg.sessionID)
|
||||
manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID)
|
||||
})
|
||||
|
||||
// Creating a new regular user should not update account peers and not send peer update
|
||||
t.Run("creating new regular user with no groups", func(t *testing.T) {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldNotReceiveUpdate(t, updMsg.channel)
|
||||
peerShouldNotReceiveUpdate(t, updMsg)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
@@ -1328,7 +1327,7 @@ func TestUserAccountPeersUpdate(t *testing.T) {
|
||||
t.Run("updating user with no linked peers", func(t *testing.T) {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldNotReceiveUpdate(t, updMsg.channel)
|
||||
peerShouldNotReceiveUpdate(t, updMsg)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
@@ -1351,7 +1350,7 @@ func TestUserAccountPeersUpdate(t *testing.T) {
|
||||
t.Run("deleting user with no linked peers", func(t *testing.T) {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldNotReceiveUpdate(t, updMsg.channel)
|
||||
peerShouldNotReceiveUpdate(t, updMsg)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
@@ -1388,7 +1387,7 @@ func TestUserAccountPeersUpdate(t *testing.T) {
|
||||
t.Run("updating user with linked peers", func(t *testing.T) {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, updMsg.channel)
|
||||
peerShouldReceiveUpdate(t, updMsg)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
@@ -1409,14 +1408,14 @@ func TestUserAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
peer4UpdMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer4.ID)
|
||||
t.Cleanup(func() {
|
||||
manager.peersUpdateManager.CloseChannel(context.Background(), peer4.ID, peer4UpdMsg.sessionID)
|
||||
manager.peersUpdateManager.CloseChannel(context.Background(), peer4.ID)
|
||||
})
|
||||
|
||||
// deleting user with linked peers should update account peers and send peer update
|
||||
t.Run("deleting user with linked peers", func(t *testing.T) {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, peer4UpdMsg.channel)
|
||||
peerShouldReceiveUpdate(t, peer4UpdMsg)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
|
||||
@@ -16,7 +16,6 @@ import (
|
||||
|
||||
var (
|
||||
relayCleanupInterval = 60 * time.Second
|
||||
keepUnusedServerTime = 5 * time.Second
|
||||
|
||||
ErrRelayClientNotConnected = fmt.Errorf("relay client not connected")
|
||||
)
|
||||
@@ -28,13 +27,10 @@ type RelayTrack struct {
|
||||
sync.RWMutex
|
||||
relayClient *Client
|
||||
err error
|
||||
created time.Time
|
||||
}
|
||||
|
||||
func NewRelayTrack() *RelayTrack {
|
||||
return &RelayTrack{
|
||||
created: time.Now(),
|
||||
}
|
||||
return &RelayTrack{}
|
||||
}
|
||||
|
||||
type OnServerCloseListener func()
|
||||
@@ -306,18 +302,6 @@ func (m *Manager) cleanUpUnusedRelays() {
|
||||
|
||||
for addr, rt := range m.relayClients {
|
||||
rt.Lock()
|
||||
// if the connection failed to the server the relay client will be nil
|
||||
// but the instance will be kept in the relayClients until the next locking
|
||||
if rt.err != nil {
|
||||
rt.Unlock()
|
||||
continue
|
||||
}
|
||||
|
||||
if time.Since(rt.created) <= keepUnusedServerTime {
|
||||
rt.Unlock()
|
||||
continue
|
||||
}
|
||||
|
||||
if rt.relayClient.HasConns() {
|
||||
rt.Unlock()
|
||||
continue
|
||||
|
||||
@@ -288,9 +288,8 @@ func TestForeginAutoClose(t *testing.T) {
|
||||
t.Fatalf("failed to close connection: %s", err)
|
||||
}
|
||||
|
||||
timeout := relayCleanupInterval + keepUnusedServerTime + 1*time.Second
|
||||
t.Logf("waiting for relay cleanup: %s", timeout)
|
||||
time.Sleep(timeout)
|
||||
t.Logf("waiting for relay cleanup: %s", relayCleanupInterval+1*time.Second)
|
||||
time.Sleep(relayCleanupInterval + 1*time.Second)
|
||||
if len(mgr.relayClients) != 0 {
|
||||
t.Errorf("expected 0, got %d", len(mgr.relayClients))
|
||||
}
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestServerPicker_UnavailableServers(t *testing.T) {
|
||||
@@ -12,7 +13,7 @@ func TestServerPicker_UnavailableServers(t *testing.T) {
|
||||
PeerID: "test",
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), connectionTimeout+1)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
go func() {
|
||||
|
||||
Reference in New Issue
Block a user