Compare commits
9 Commits
fix/ssh-pr
...
feature/em
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
55824890da | ||
|
|
1c80258f49 | ||
|
|
9aaa05e8ea | ||
|
|
0af5a0441f | ||
|
|
0fc63ea0ba | ||
|
|
0b329f7881 | ||
|
|
5b85edb753 | ||
|
|
17cfa5fe1e | ||
|
|
2313494e0e |
2
.github/workflows/release.yml
vendored
2
.github/workflows/release.yml
vendored
@@ -170,6 +170,7 @@ jobs:
|
||||
run: sudo apt update && sudo apt install -y -q gcc-arm-linux-gnueabihf gcc-aarch64-linux-gnu
|
||||
|
||||
- name: Decode GPG signing key
|
||||
if: github.event_name != 'pull_request' || github.event.pull_request.head.repo.full_name == github.repository
|
||||
env:
|
||||
GPG_RPM_PRIVATE_KEY: ${{ secrets.GPG_RPM_PRIVATE_KEY }}
|
||||
run: |
|
||||
@@ -309,6 +310,7 @@ jobs:
|
||||
run: sudo apt update && sudo apt install -y -q libappindicator3-dev gir1.2-appindicator3-0.1 libxxf86vm-dev gcc-mingw-w64-x86-64
|
||||
|
||||
- name: Decode GPG signing key
|
||||
if: github.event_name != 'pull_request' || github.event.pull_request.head.repo.full_name == github.repository
|
||||
env:
|
||||
GPG_RPM_PRIVATE_KEY: ${{ secrets.GPG_RPM_PRIVATE_KEY }}
|
||||
run: |
|
||||
|
||||
@@ -171,6 +171,7 @@ nfpms:
|
||||
- maintainer: Netbird <dev@netbird.io>
|
||||
description: Netbird client.
|
||||
homepage: https://netbird.io/
|
||||
license: BSD-3-Clause
|
||||
id: netbird_deb
|
||||
bindir: /usr/bin
|
||||
builds:
|
||||
@@ -184,6 +185,7 @@ nfpms:
|
||||
- maintainer: Netbird <dev@netbird.io>
|
||||
description: Netbird client.
|
||||
homepage: https://netbird.io/
|
||||
license: BSD-3-Clause
|
||||
id: netbird_rpm
|
||||
bindir: /usr/bin
|
||||
builds:
|
||||
|
||||
@@ -181,10 +181,11 @@ func runForDuration(cmd *cobra.Command, args []string) error {
|
||||
|
||||
if stateWasDown {
|
||||
if _, err := client.Up(cmd.Context(), &proto.UpRequest{}); err != nil {
|
||||
return fmt.Errorf("failed to up: %v", status.Convert(err).Message())
|
||||
cmd.PrintErrf("Failed to bring service up: %v\n", status.Convert(err).Message())
|
||||
} else {
|
||||
cmd.Println("netbird up")
|
||||
time.Sleep(time.Second * 10)
|
||||
}
|
||||
cmd.Println("netbird up")
|
||||
time.Sleep(time.Second * 10)
|
||||
}
|
||||
|
||||
initialLevelTrace := initialLogLevel.GetLevel() >= proto.LogLevel_TRACE
|
||||
@@ -199,9 +200,10 @@ func runForDuration(cmd *cobra.Command, args []string) error {
|
||||
}
|
||||
|
||||
if _, err := client.Down(cmd.Context(), &proto.DownRequest{}); err != nil {
|
||||
return fmt.Errorf("failed to down: %v", status.Convert(err).Message())
|
||||
cmd.PrintErrf("Failed to bring service down: %v\n", status.Convert(err).Message())
|
||||
} else {
|
||||
cmd.Println("netbird down")
|
||||
}
|
||||
cmd.Println("netbird down")
|
||||
|
||||
time.Sleep(1 * time.Second)
|
||||
|
||||
@@ -209,13 +211,14 @@ func runForDuration(cmd *cobra.Command, args []string) error {
|
||||
if _, err := client.SetSyncResponsePersistence(cmd.Context(), &proto.SetSyncResponsePersistenceRequest{
|
||||
Enabled: true,
|
||||
}); err != nil {
|
||||
return fmt.Errorf("failed to enable sync response persistence: %v", status.Convert(err).Message())
|
||||
cmd.PrintErrf("Failed to enable sync response persistence: %v\n", status.Convert(err).Message())
|
||||
}
|
||||
|
||||
if _, err := client.Up(cmd.Context(), &proto.UpRequest{}); err != nil {
|
||||
return fmt.Errorf("failed to up: %v", status.Convert(err).Message())
|
||||
cmd.PrintErrf("Failed to bring service up: %v\n", status.Convert(err).Message())
|
||||
} else {
|
||||
cmd.Println("netbird up")
|
||||
}
|
||||
cmd.Println("netbird up")
|
||||
|
||||
time.Sleep(3 * time.Second)
|
||||
|
||||
@@ -263,16 +266,18 @@ func runForDuration(cmd *cobra.Command, args []string) error {
|
||||
|
||||
if stateWasDown {
|
||||
if _, err := client.Down(cmd.Context(), &proto.DownRequest{}); err != nil {
|
||||
return fmt.Errorf("failed to down: %v", status.Convert(err).Message())
|
||||
cmd.PrintErrf("Failed to restore service down state: %v\n", status.Convert(err).Message())
|
||||
} else {
|
||||
cmd.Println("netbird down")
|
||||
}
|
||||
cmd.Println("netbird down")
|
||||
}
|
||||
|
||||
if !initialLevelTrace {
|
||||
if _, err := client.SetLogLevel(cmd.Context(), &proto.SetLogLevelRequest{Level: initialLogLevel.GetLevel()}); err != nil {
|
||||
return fmt.Errorf("failed to restore log level: %v", status.Convert(err).Message())
|
||||
cmd.PrintErrf("Failed to restore log level: %v\n", status.Convert(err).Message())
|
||||
} else {
|
||||
cmd.Println("Log level restored to", initialLogLevel.GetLevel())
|
||||
}
|
||||
cmd.Println("Log level restored to", initialLogLevel.GetLevel())
|
||||
}
|
||||
|
||||
cmd.Printf("Local file:\n%s\n", resp.GetPath())
|
||||
|
||||
@@ -15,6 +15,7 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/spf13/cobra"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/expose"
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
)
|
||||
@@ -211,16 +212,21 @@ func exposeFn(cmd *cobra.Command, args []string) error {
|
||||
}
|
||||
|
||||
func toExposeProtocol(exposeProtocol string) (proto.ExposeProtocol, error) {
|
||||
switch strings.ToLower(exposeProtocol) {
|
||||
case "http":
|
||||
p, err := expose.ParseProtocolType(exposeProtocol)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("invalid protocol: %w", err)
|
||||
}
|
||||
|
||||
switch p {
|
||||
case expose.ProtocolHTTP:
|
||||
return proto.ExposeProtocol_EXPOSE_HTTP, nil
|
||||
case "https":
|
||||
case expose.ProtocolHTTPS:
|
||||
return proto.ExposeProtocol_EXPOSE_HTTPS, nil
|
||||
case "tcp":
|
||||
case expose.ProtocolTCP:
|
||||
return proto.ExposeProtocol_EXPOSE_TCP, nil
|
||||
case "udp":
|
||||
case expose.ProtocolUDP:
|
||||
return proto.ExposeProtocol_EXPOSE_UDP, nil
|
||||
case "tls":
|
||||
case expose.ProtocolTLS:
|
||||
return proto.ExposeProtocol_EXPOSE_TLS, nil
|
||||
default:
|
||||
return 0, fmt.Errorf("unsupported protocol %q: must be http, https, tcp, udp, or tls", exposeProtocol)
|
||||
|
||||
@@ -33,14 +33,14 @@ var (
|
||||
ErrConfigNotInitialized = errors.New("config not initialized")
|
||||
)
|
||||
|
||||
// PeerConnStatus is a peer's connection status.
|
||||
type PeerConnStatus = peer.ConnStatus
|
||||
|
||||
const (
|
||||
// PeerStatusConnected indicates the peer is in connected state.
|
||||
PeerStatusConnected = peer.StatusConnected
|
||||
)
|
||||
|
||||
// PeerConnStatus is a peer's connection status.
|
||||
type PeerConnStatus = peer.ConnStatus
|
||||
|
||||
// Client manages a netbird embedded client instance.
|
||||
type Client struct {
|
||||
deviceName string
|
||||
@@ -375,6 +375,33 @@ func (c *Client) NewHTTPClient() *http.Client {
|
||||
}
|
||||
}
|
||||
|
||||
// Expose exposes a local service via the NetBird reverse proxy, making it accessible through a public URL.
|
||||
// It returns an ExposeSession. Call Wait on the session to keep it alive.
|
||||
func (c *Client) Expose(ctx context.Context, req ExposeRequest) (*ExposeSession, error) {
|
||||
engine, err := c.getEngine()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
mgr := engine.GetExposeManager()
|
||||
if mgr == nil {
|
||||
return nil, fmt.Errorf("expose manager not available")
|
||||
}
|
||||
|
||||
resp, err := mgr.Expose(ctx, req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("expose: %w", err)
|
||||
}
|
||||
|
||||
return &ExposeSession{
|
||||
Domain: resp.Domain,
|
||||
ServiceName: resp.ServiceName,
|
||||
ServiceURL: resp.ServiceURL,
|
||||
mgr: mgr,
|
||||
ctx: ctx,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Status returns the current status of the client.
|
||||
func (c *Client) Status() (peer.FullStatus, error) {
|
||||
c.mu.Lock()
|
||||
|
||||
42
client/embed/expose.go
Normal file
42
client/embed/expose.go
Normal file
@@ -0,0 +1,42 @@
|
||||
package embed
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/expose"
|
||||
)
|
||||
|
||||
const (
|
||||
// ExposeProtocolHTTP exposes the service as HTTP.
|
||||
ExposeProtocolHTTP = expose.ProtocolHTTP
|
||||
// ExposeProtocolHTTPS exposes the service as HTTPS.
|
||||
ExposeProtocolHTTPS = expose.ProtocolHTTPS
|
||||
// ExposeProtocolTCP exposes the service as TCP.
|
||||
ExposeProtocolTCP = expose.ProtocolTCP
|
||||
// ExposeProtocolUDP exposes the service as UDP.
|
||||
ExposeProtocolUDP = expose.ProtocolUDP
|
||||
// ExposeProtocolTLS exposes the service as TLS.
|
||||
ExposeProtocolTLS = expose.ProtocolTLS
|
||||
)
|
||||
|
||||
// ExposeRequest is a request to expose a local service via the NetBird reverse proxy.
|
||||
type ExposeRequest = expose.Request
|
||||
|
||||
// ExposeProtocolType represents the protocol used for exposing a service.
|
||||
type ExposeProtocolType = expose.ProtocolType
|
||||
|
||||
// ExposeSession represents an active expose session. Use Wait to block until the session ends.
|
||||
type ExposeSession struct {
|
||||
Domain string
|
||||
ServiceName string
|
||||
ServiceURL string
|
||||
|
||||
mgr *expose.Manager
|
||||
ctx context.Context
|
||||
}
|
||||
|
||||
// Wait blocks while keeping the expose session alive.
|
||||
// It returns when ctx is cancelled or a keep-alive error occurs, then terminates the session.
|
||||
func (s *ExposeSession) Wait() error {
|
||||
return s.mgr.KeepAlive(s.ctx, s.Domain)
|
||||
}
|
||||
@@ -4,11 +4,14 @@ import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
mgm "github.com/netbirdio/netbird/shared/management/client"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
mgm "github.com/netbirdio/netbird/shared/management/client"
|
||||
)
|
||||
|
||||
const renewTimeout = 10 * time.Second
|
||||
const (
|
||||
renewTimeout = 10 * time.Second
|
||||
)
|
||||
|
||||
// Response holds the response from exposing a service.
|
||||
type Response struct {
|
||||
@@ -22,7 +25,7 @@ type Request struct {
|
||||
NamePrefix string
|
||||
Domain string
|
||||
Port uint16
|
||||
Protocol int
|
||||
Protocol ProtocolType
|
||||
Pin string
|
||||
Password string
|
||||
UserGroups []string
|
||||
|
||||
@@ -86,7 +86,7 @@ func TestNewRequest(t *testing.T) {
|
||||
exposeReq := NewRequest(req)
|
||||
|
||||
assert.Equal(t, uint16(8080), exposeReq.Port, "port should match")
|
||||
assert.Equal(t, int(daemonProto.ExposeProtocol_EXPOSE_HTTPS), exposeReq.Protocol, "protocol should match")
|
||||
assert.Equal(t, ProtocolType(daemonProto.ExposeProtocol_EXPOSE_HTTPS), exposeReq.Protocol, "protocol should match")
|
||||
assert.Equal(t, "123456", exposeReq.Pin, "pin should match")
|
||||
assert.Equal(t, "secret", exposeReq.Password, "password should match")
|
||||
assert.Equal(t, []string{"group1", "group2"}, exposeReq.UserGroups, "user groups should match")
|
||||
|
||||
40
client/internal/expose/protocol.go
Normal file
40
client/internal/expose/protocol.go
Normal file
@@ -0,0 +1,40 @@
|
||||
package expose
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// ProtocolType represents the protocol used for exposing a service.
|
||||
type ProtocolType int
|
||||
|
||||
const (
|
||||
// ProtocolHTTP exposes the service as HTTP.
|
||||
ProtocolHTTP ProtocolType = 0
|
||||
// ProtocolHTTPS exposes the service as HTTPS.
|
||||
ProtocolHTTPS ProtocolType = 1
|
||||
// ProtocolTCP exposes the service as TCP.
|
||||
ProtocolTCP ProtocolType = 2
|
||||
// ProtocolUDP exposes the service as UDP.
|
||||
ProtocolUDP ProtocolType = 3
|
||||
// ProtocolTLS exposes the service as TLS.
|
||||
ProtocolTLS ProtocolType = 4
|
||||
)
|
||||
|
||||
// ParseProtocolType parses a protocol string into a ProtocolType.
|
||||
func ParseProtocolType(s string) (ProtocolType, error) {
|
||||
switch strings.ToLower(s) {
|
||||
case "http":
|
||||
return ProtocolHTTP, nil
|
||||
case "https":
|
||||
return ProtocolHTTPS, nil
|
||||
case "tcp":
|
||||
return ProtocolTCP, nil
|
||||
case "udp":
|
||||
return ProtocolUDP, nil
|
||||
case "tls":
|
||||
return ProtocolTLS, nil
|
||||
default:
|
||||
return 0, fmt.Errorf("unsupported protocol %q: must be http, https, tcp, udp, or tls", s)
|
||||
}
|
||||
}
|
||||
@@ -9,7 +9,7 @@ import (
|
||||
func NewRequest(req *daemonProto.ExposeServiceRequest) *Request {
|
||||
return &Request{
|
||||
Port: uint16(req.Port),
|
||||
Protocol: int(req.Protocol),
|
||||
Protocol: ProtocolType(req.Protocol),
|
||||
Pin: req.Pin,
|
||||
Password: req.Password,
|
||||
UserGroups: req.UserGroups,
|
||||
@@ -24,7 +24,7 @@ func toClientExposeRequest(req Request) mgm.ExposeRequest {
|
||||
NamePrefix: req.NamePrefix,
|
||||
Domain: req.Domain,
|
||||
Port: req.Port,
|
||||
Protocol: req.Protocol,
|
||||
Protocol: int(req.Protocol),
|
||||
Pin: req.Pin,
|
||||
Password: req.Password,
|
||||
UserGroups: req.UserGroups,
|
||||
|
||||
@@ -141,7 +141,7 @@ func (p *SSHProxy) runProxySSHServer(jwtToken string) error {
|
||||
|
||||
func (p *SSHProxy) handleSSHSession(session ssh.Session) {
|
||||
ptyReq, winCh, isPty := session.Pty()
|
||||
hasCommand := session.RawCommand() != ""
|
||||
hasCommand := len(session.Command()) > 0
|
||||
|
||||
sshClient, err := p.getOrCreateBackendClient(session.Context(), session.User())
|
||||
if err != nil {
|
||||
@@ -180,7 +180,7 @@ func (p *SSHProxy) handleSSHSession(session ssh.Session) {
|
||||
}
|
||||
|
||||
if hasCommand {
|
||||
if err := serverSession.Run(session.RawCommand()); err != nil {
|
||||
if err := serverSession.Run(strings.Join(session.Command(), " ")); err != nil {
|
||||
log.Debugf("run command: %v", err)
|
||||
p.handleProxyExitCode(session, err)
|
||||
}
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
@@ -246,191 +245,6 @@ func TestSSHProxy_Connect(t *testing.T) {
|
||||
cancel()
|
||||
}
|
||||
|
||||
// TestSSHProxy_CommandQuoting verifies that the proxy preserves shell quoting
|
||||
// when forwarding commands to the backend. This is critical for tools like
|
||||
// Ansible that send commands such as:
|
||||
//
|
||||
// /bin/sh -c '( umask 77 && mkdir -p ... ) && sleep 0'
|
||||
//
|
||||
// The single quotes must be preserved so the backend shell receives the
|
||||
// subshell expression as a single argument to -c.
|
||||
func TestSSHProxy_CommandQuoting(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("Skipping integration test in short mode")
|
||||
}
|
||||
|
||||
sshClient, cleanup := setupProxySSHClient(t)
|
||||
defer cleanup()
|
||||
|
||||
// These commands simulate what the SSH protocol delivers as exec payloads.
|
||||
// When a user types: ssh host '/bin/sh -c "( echo hello )"'
|
||||
// the local shell strips the outer single quotes, and the SSH exec request
|
||||
// contains the raw string: /bin/sh -c "( echo hello )"
|
||||
//
|
||||
// The proxy must forward this string verbatim. Using session.Command()
|
||||
// (shlex.Split + strings.Join) strips the inner double quotes, breaking
|
||||
// the command on the backend.
|
||||
tests := []struct {
|
||||
name string
|
||||
command string
|
||||
expect string
|
||||
}{
|
||||
{
|
||||
name: "subshell_in_double_quotes",
|
||||
command: `/bin/sh -c "( echo from-subshell ) && echo outer"`,
|
||||
expect: "from-subshell\nouter\n",
|
||||
},
|
||||
{
|
||||
name: "printf_with_special_chars",
|
||||
command: `/bin/sh -c "printf '%s\n' 'hello world'"`,
|
||||
expect: "hello world\n",
|
||||
},
|
||||
{
|
||||
name: "nested_command_substitution",
|
||||
command: `/bin/sh -c "echo $(echo nested)"`,
|
||||
expect: "nested\n",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
session, err := sshClient.NewSession()
|
||||
require.NoError(t, err)
|
||||
defer func() { _ = session.Close() }()
|
||||
|
||||
var stderrBuf bytes.Buffer
|
||||
session.Stderr = &stderrBuf
|
||||
|
||||
outputCh := make(chan []byte, 1)
|
||||
errCh := make(chan error, 1)
|
||||
go func() {
|
||||
output, err := session.Output(tc.command)
|
||||
outputCh <- output
|
||||
errCh <- err
|
||||
}()
|
||||
|
||||
select {
|
||||
case output := <-outputCh:
|
||||
err := <-errCh
|
||||
if stderrBuf.Len() > 0 {
|
||||
t.Logf("stderr: %s", stderrBuf.String())
|
||||
}
|
||||
require.NoError(t, err, "command should succeed: %s", tc.command)
|
||||
assert.Equal(t, tc.expect, string(output), "output mismatch for: %s", tc.command)
|
||||
case <-time.After(5 * time.Second):
|
||||
t.Fatalf("command timed out: %s", tc.command)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// setupProxySSHClient creates a full proxy test environment and returns
|
||||
// an SSH client connected through the proxy to a backend NetBird SSH server.
|
||||
func setupProxySSHClient(t *testing.T) (*cryptossh.Client, func()) {
|
||||
t.Helper()
|
||||
|
||||
const (
|
||||
issuer = "https://test-issuer.example.com"
|
||||
audience = "test-audience"
|
||||
)
|
||||
|
||||
jwksServer, privateKey, jwksURL := setupJWKSServer(t)
|
||||
|
||||
hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
||||
require.NoError(t, err)
|
||||
hostPubKey, err := nbssh.GeneratePublicKey(hostKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
serverConfig := &server.Config{
|
||||
HostKeyPEM: hostKey,
|
||||
JWT: &server.JWTConfig{
|
||||
Issuer: issuer,
|
||||
Audiences: []string{audience},
|
||||
KeysLocation: jwksURL,
|
||||
},
|
||||
}
|
||||
sshServer := server.New(serverConfig)
|
||||
sshServer.SetAllowRootLogin(true)
|
||||
|
||||
testUsername := testutil.GetTestUsername(t)
|
||||
testJWTUser := "test-username"
|
||||
testUserHash, err := sshuserhash.HashUserID(testJWTUser)
|
||||
require.NoError(t, err)
|
||||
|
||||
authConfig := &sshauth.Config{
|
||||
UserIDClaim: sshauth.DefaultUserIDClaim,
|
||||
AuthorizedUsers: []sshuserhash.UserIDHash{testUserHash},
|
||||
MachineUsers: map[string][]uint32{
|
||||
testUsername: {0},
|
||||
},
|
||||
}
|
||||
sshServer.UpdateSSHAuth(authConfig)
|
||||
|
||||
sshServerAddr := server.StartTestServer(t, sshServer)
|
||||
|
||||
mockDaemon := startMockDaemon(t)
|
||||
|
||||
host, portStr, err := net.SplitHostPort(sshServerAddr)
|
||||
require.NoError(t, err)
|
||||
port, err := strconv.Atoi(portStr)
|
||||
require.NoError(t, err)
|
||||
|
||||
mockDaemon.setHostKey(host, hostPubKey)
|
||||
|
||||
validToken := generateValidJWT(t, privateKey, issuer, audience, testJWTUser)
|
||||
mockDaemon.setJWTToken(validToken)
|
||||
|
||||
proxyInstance, err := New(mockDaemon.addr, host, port, io.Discard, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
origStdin := os.Stdin
|
||||
origStdout := os.Stdout
|
||||
|
||||
stdinReader, stdinWriter, err := os.Pipe()
|
||||
require.NoError(t, err)
|
||||
stdoutReader, stdoutWriter, err := os.Pipe()
|
||||
require.NoError(t, err)
|
||||
|
||||
os.Stdin = stdinReader
|
||||
os.Stdout = stdoutWriter
|
||||
|
||||
clientConn, proxyConn := net.Pipe()
|
||||
|
||||
go func() { _, _ = io.Copy(stdinWriter, proxyConn) }()
|
||||
go func() { _, _ = io.Copy(proxyConn, stdoutReader) }()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
|
||||
go func() {
|
||||
_ = proxyInstance.Connect(ctx)
|
||||
}()
|
||||
|
||||
sshConfig := &cryptossh.ClientConfig{
|
||||
User: testutil.GetTestUsername(t),
|
||||
Auth: []cryptossh.AuthMethod{},
|
||||
HostKeyCallback: cryptossh.InsecureIgnoreHostKey(),
|
||||
Timeout: 5 * time.Second,
|
||||
}
|
||||
|
||||
sshClientConn, chans, reqs, err := cryptossh.NewClientConn(clientConn, "test", sshConfig)
|
||||
require.NoError(t, err)
|
||||
|
||||
client := cryptossh.NewClient(sshClientConn, chans, reqs)
|
||||
|
||||
cleanupFn := func() {
|
||||
_ = client.Close()
|
||||
_ = clientConn.Close()
|
||||
cancel()
|
||||
os.Stdin = origStdin
|
||||
os.Stdout = origStdout
|
||||
_ = sshServer.Stop()
|
||||
mockDaemon.stop()
|
||||
jwksServer.Close()
|
||||
}
|
||||
|
||||
return client, cleanupFn
|
||||
}
|
||||
|
||||
type mockDaemonServer struct {
|
||||
proto.UnimplementedDaemonServiceServer
|
||||
hostKeys map[string][]byte
|
||||
|
||||
@@ -60,7 +60,7 @@ func (s *Server) sessionHandler(session ssh.Session) {
|
||||
}
|
||||
|
||||
ptyReq, winCh, isPty := session.Pty()
|
||||
hasCommand := session.RawCommand() != ""
|
||||
hasCommand := len(session.Command()) > 0
|
||||
|
||||
if isPty && !hasCommand {
|
||||
// ssh <host> - PTY interactive session (login)
|
||||
|
||||
3
go.mod
3
go.mod
@@ -30,10 +30,10 @@ require (
|
||||
require (
|
||||
fyne.io/fyne/v2 v2.7.0
|
||||
fyne.io/systray v1.12.1-0.20260116214250-81f8e1a496f9
|
||||
github.com/TheJumpCloud/jcapi-go v3.0.0+incompatible
|
||||
github.com/awnumar/memguard v0.23.0
|
||||
github.com/aws/aws-sdk-go-v2 v1.36.3
|
||||
github.com/aws/aws-sdk-go-v2/config v1.29.14
|
||||
github.com/aws/aws-sdk-go-v2/credentials v1.17.67
|
||||
github.com/aws/aws-sdk-go-v2/service/s3 v1.79.2
|
||||
github.com/c-robinson/iplib v1.0.3
|
||||
github.com/caddyserver/certmagic v0.21.3
|
||||
@@ -144,7 +144,6 @@ require (
|
||||
github.com/apapsch/go-jsonmerge/v2 v2.0.0 // indirect
|
||||
github.com/awnumar/memcall v0.4.0 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.10 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/credentials v1.17.67 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.30 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.34 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.34 // indirect
|
||||
|
||||
2
go.sum
2
go.sum
@@ -34,8 +34,6 @@ github.com/Masterminds/sprig/v3 v3.3.0/go.mod h1:Zy1iXRYNqNLUolqCpL4uhk6SHUMAOSC
|
||||
github.com/Microsoft/go-winio v0.6.2 h1:F2VQgta7ecxGYO8k3ZZz3RS8fVIXVxONVUPlNERoyfY=
|
||||
github.com/Microsoft/go-winio v0.6.2/go.mod h1:yd8OoFMLzJbo9gZq8j5qaps8bJ9aShtEA8Ipt1oGCvU=
|
||||
github.com/RaveNoX/go-jsoncommentstrip v1.0.0/go.mod h1:78ihd09MekBnJnxpICcwzCMzGrKSKYe4AqU6PDYYpjk=
|
||||
github.com/TheJumpCloud/jcapi-go v3.0.0+incompatible h1:hqcTK6ZISdip65SR792lwYJTa/axESA0889D3UlZbLo=
|
||||
github.com/TheJumpCloud/jcapi-go v3.0.0+incompatible/go.mod h1:6B1nuc1MUs6c62ODZDl7hVE5Pv7O2XGSkgg2olnq34I=
|
||||
github.com/alexbrainman/sspi v0.0.0-20250919150558-7d374ff0d59e h1:4dAU9FXIyQktpoUAgOJK3OTFc/xug0PCXYCqU0FgDKI=
|
||||
github.com/alexbrainman/sspi v0.0.0-20250919150558-7d374ff0d59e/go.mod h1:cEWa1LVoE5KvSD9ONXsZrj0z6KqySlCCNKHlLzbqAt4=
|
||||
github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be h1:9AeTilPcZAjCFIImctFaOjnTIavg87rW78vTPkQqLI8=
|
||||
|
||||
@@ -262,7 +262,9 @@ func (s *Service) ToAPIResponse() *api.Service {
|
||||
if opts == nil {
|
||||
opts = &api.ServiceTargetOptions{}
|
||||
}
|
||||
opts.ProxyProtocol = &target.ProxyProtocol
|
||||
if target.ProxyProtocol {
|
||||
opts.ProxyProtocol = &target.ProxyProtocol
|
||||
}
|
||||
st.Options = opts
|
||||
apiTargets = append(apiTargets, st)
|
||||
}
|
||||
@@ -848,7 +850,7 @@ func IsPortBasedProtocol(mode string) bool {
|
||||
}
|
||||
|
||||
const (
|
||||
maxCustomHeaders = 16
|
||||
maxCustomHeaders = 16
|
||||
maxHeaderKeyLen = 128
|
||||
maxHeaderValueLen = 4096
|
||||
)
|
||||
@@ -945,7 +947,6 @@ func containsCRLF(s string) bool {
|
||||
}
|
||||
|
||||
func validateHeaderAuths(headers []*HeaderAuthConfig) error {
|
||||
seen := make(map[string]struct{})
|
||||
for i, h := range headers {
|
||||
if h == nil || !h.Enabled {
|
||||
continue
|
||||
@@ -966,10 +967,6 @@ func validateHeaderAuths(headers []*HeaderAuthConfig) error {
|
||||
if canonical == "Host" {
|
||||
return fmt.Errorf("header_auths[%d]: Host header cannot be used for auth", i)
|
||||
}
|
||||
if _, dup := seen[canonical]; dup {
|
||||
return fmt.Errorf("header_auths[%d]: duplicate header %q (same canonical form already configured)", i, h.Header)
|
||||
}
|
||||
seen[canonical] = struct{}{}
|
||||
if len(h.Value) > maxHeaderValueLen {
|
||||
return fmt.Errorf("header_auths[%d]: value exceeds maximum length of %d", i, maxHeaderValueLen)
|
||||
}
|
||||
|
||||
@@ -935,3 +935,107 @@ func TestExposeServiceRequest_Validate_HTTPAllowsAuth(t *testing.T) {
|
||||
req := ExposeServiceRequest{Port: 8080, Mode: "http", Pin: "123456"}
|
||||
require.NoError(t, req.Validate())
|
||||
}
|
||||
|
||||
func TestValidate_HeaderAuths(t *testing.T) {
|
||||
t.Run("single valid header", func(t *testing.T) {
|
||||
rp := validProxy()
|
||||
rp.Auth = AuthConfig{
|
||||
HeaderAuths: []*HeaderAuthConfig{
|
||||
{Enabled: true, Header: "X-API-Key", Value: "secret"},
|
||||
},
|
||||
}
|
||||
require.NoError(t, rp.Validate())
|
||||
})
|
||||
|
||||
t.Run("multiple headers same canonical name allowed", func(t *testing.T) {
|
||||
rp := validProxy()
|
||||
rp.Auth = AuthConfig{
|
||||
HeaderAuths: []*HeaderAuthConfig{
|
||||
{Enabled: true, Header: "Authorization", Value: "Bearer token-1"},
|
||||
{Enabled: true, Header: "Authorization", Value: "Bearer token-2"},
|
||||
},
|
||||
}
|
||||
require.NoError(t, rp.Validate())
|
||||
})
|
||||
|
||||
t.Run("multiple headers different case same canonical allowed", func(t *testing.T) {
|
||||
rp := validProxy()
|
||||
rp.Auth = AuthConfig{
|
||||
HeaderAuths: []*HeaderAuthConfig{
|
||||
{Enabled: true, Header: "x-api-key", Value: "key-1"},
|
||||
{Enabled: true, Header: "X-Api-Key", Value: "key-2"},
|
||||
},
|
||||
}
|
||||
require.NoError(t, rp.Validate())
|
||||
})
|
||||
|
||||
t.Run("multiple different headers allowed", func(t *testing.T) {
|
||||
rp := validProxy()
|
||||
rp.Auth = AuthConfig{
|
||||
HeaderAuths: []*HeaderAuthConfig{
|
||||
{Enabled: true, Header: "Authorization", Value: "Bearer tok"},
|
||||
{Enabled: true, Header: "X-API-Key", Value: "key"},
|
||||
},
|
||||
}
|
||||
require.NoError(t, rp.Validate())
|
||||
})
|
||||
|
||||
t.Run("empty header name rejected", func(t *testing.T) {
|
||||
rp := validProxy()
|
||||
rp.Auth = AuthConfig{
|
||||
HeaderAuths: []*HeaderAuthConfig{
|
||||
{Enabled: true, Header: "", Value: "val"},
|
||||
},
|
||||
}
|
||||
err := rp.Validate()
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "header name is required")
|
||||
})
|
||||
|
||||
t.Run("hop-by-hop header rejected", func(t *testing.T) {
|
||||
rp := validProxy()
|
||||
rp.Auth = AuthConfig{
|
||||
HeaderAuths: []*HeaderAuthConfig{
|
||||
{Enabled: true, Header: "Connection", Value: "val"},
|
||||
},
|
||||
}
|
||||
err := rp.Validate()
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "hop-by-hop")
|
||||
})
|
||||
|
||||
t.Run("host header rejected", func(t *testing.T) {
|
||||
rp := validProxy()
|
||||
rp.Auth = AuthConfig{
|
||||
HeaderAuths: []*HeaderAuthConfig{
|
||||
{Enabled: true, Header: "Host", Value: "val"},
|
||||
},
|
||||
}
|
||||
err := rp.Validate()
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "Host header cannot be used")
|
||||
})
|
||||
|
||||
t.Run("disabled entries skipped", func(t *testing.T) {
|
||||
rp := validProxy()
|
||||
rp.Auth = AuthConfig{
|
||||
HeaderAuths: []*HeaderAuthConfig{
|
||||
{Enabled: false, Header: "", Value: ""},
|
||||
{Enabled: true, Header: "X-Key", Value: "val"},
|
||||
},
|
||||
}
|
||||
require.NoError(t, rp.Validate())
|
||||
})
|
||||
|
||||
t.Run("value too long rejected", func(t *testing.T) {
|
||||
rp := validProxy()
|
||||
rp.Auth = AuthConfig{
|
||||
HeaderAuths: []*HeaderAuthConfig{
|
||||
{Enabled: true, Header: "X-Key", Value: strings.Repeat("a", maxHeaderValueLen+1)},
|
||||
},
|
||||
}
|
||||
err := rp.Validate()
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "exceeds maximum length")
|
||||
})
|
||||
}
|
||||
|
||||
@@ -197,6 +197,7 @@ func NewManager(ctx context.Context, config Config, appMetrics telemetry.AppMetr
|
||||
case "jumpcloud":
|
||||
return NewJumpCloudManager(JumpCloudClientConfig{
|
||||
APIToken: config.ExtraConfig["ApiToken"],
|
||||
ApiUrl: config.ExtraConfig["ApiUrl"],
|
||||
}, appMetrics)
|
||||
case "pocketid":
|
||||
return NewPocketIdManager(PocketIdClientConfig{
|
||||
|
||||
@@ -1,24 +1,40 @@
|
||||
package idp
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
v1 "github.com/TheJumpCloud/jcapi-go/v1"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
)
|
||||
|
||||
const (
|
||||
contentType = "application/json"
|
||||
accept = "application/json"
|
||||
jumpCloudDefaultApiUrl = "https://console.jumpcloud.com"
|
||||
jumpCloudSearchPageSize = 100
|
||||
)
|
||||
|
||||
// jumpCloudUser represents a JumpCloud V1 API system user.
|
||||
type jumpCloudUser struct {
|
||||
ID string `json:"_id"`
|
||||
Email string `json:"email"`
|
||||
Firstname string `json:"firstname"`
|
||||
Middlename string `json:"middlename"`
|
||||
Lastname string `json:"lastname"`
|
||||
}
|
||||
|
||||
// jumpCloudUserList represents the response from the JumpCloud search endpoint.
|
||||
type jumpCloudUserList struct {
|
||||
Results []jumpCloudUser `json:"results"`
|
||||
TotalCount int `json:"totalCount"`
|
||||
}
|
||||
|
||||
// JumpCloudManager JumpCloud manager client instance.
|
||||
type JumpCloudManager struct {
|
||||
client *v1.APIClient
|
||||
apiBase string
|
||||
apiToken string
|
||||
httpClient ManagerHTTPClient
|
||||
credentials ManagerCredentials
|
||||
@@ -29,6 +45,7 @@ type JumpCloudManager struct {
|
||||
// JumpCloudClientConfig JumpCloud manager client configurations.
|
||||
type JumpCloudClientConfig struct {
|
||||
APIToken string
|
||||
ApiUrl string
|
||||
}
|
||||
|
||||
// JumpCloudCredentials JumpCloud authentication information.
|
||||
@@ -55,7 +72,15 @@ func NewJumpCloudManager(config JumpCloudClientConfig, appMetrics telemetry.AppM
|
||||
return nil, fmt.Errorf("jumpCloud IdP configuration is incomplete, ApiToken is missing")
|
||||
}
|
||||
|
||||
client := v1.NewAPIClient(v1.NewConfiguration())
|
||||
apiBase := config.ApiUrl
|
||||
if apiBase == "" {
|
||||
apiBase = jumpCloudDefaultApiUrl
|
||||
}
|
||||
apiBase = strings.TrimSuffix(apiBase, "/")
|
||||
if !strings.HasSuffix(apiBase, "/api") {
|
||||
apiBase += "/api"
|
||||
}
|
||||
|
||||
credentials := &JumpCloudCredentials{
|
||||
clientConfig: config,
|
||||
httpClient: httpClient,
|
||||
@@ -64,7 +89,7 @@ func NewJumpCloudManager(config JumpCloudClientConfig, appMetrics telemetry.AppM
|
||||
}
|
||||
|
||||
return &JumpCloudManager{
|
||||
client: client,
|
||||
apiBase: apiBase,
|
||||
apiToken: config.APIToken,
|
||||
httpClient: httpClient,
|
||||
credentials: credentials,
|
||||
@@ -78,37 +103,58 @@ func (jc *JumpCloudCredentials) Authenticate(_ context.Context) (JWTToken, error
|
||||
return JWTToken{}, nil
|
||||
}
|
||||
|
||||
func (jm *JumpCloudManager) authenticationContext() context.Context {
|
||||
return context.WithValue(context.Background(), v1.ContextAPIKey, v1.APIKey{
|
||||
Key: jm.apiToken,
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateUserAppMetadata updates user app metadata based on userID and metadata map.
|
||||
func (jm *JumpCloudManager) UpdateUserAppMetadata(_ context.Context, _ string, _ AppMetadata) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetUserDataByID requests user data from JumpCloud via ID.
|
||||
func (jm *JumpCloudManager) GetUserDataByID(_ context.Context, userID string, appMetadata AppMetadata) (*UserData, error) {
|
||||
authCtx := jm.authenticationContext()
|
||||
user, resp, err := jm.client.SystemusersApi.SystemusersGet(authCtx, userID, contentType, accept, nil)
|
||||
// doRequest executes an HTTP request against the JumpCloud V1 API.
|
||||
func (jm *JumpCloudManager) doRequest(ctx context.Context, method, path string, body io.Reader) ([]byte, error) {
|
||||
reqURL := jm.apiBase + path
|
||||
req, err := http.NewRequestWithContext(ctx, method, reqURL, body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
req.Header.Set("x-api-key", jm.apiToken)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Accept", "application/json")
|
||||
|
||||
resp, err := jm.httpClient.Do(req)
|
||||
if err != nil {
|
||||
if jm.appMetrics != nil {
|
||||
jm.appMetrics.IDPMetrics().CountRequestError()
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
if jm.appMetrics != nil {
|
||||
jm.appMetrics.IDPMetrics().CountRequestStatusError()
|
||||
}
|
||||
return nil, fmt.Errorf("unable to get user %s, statusCode %d", userID, resp.StatusCode)
|
||||
return nil, fmt.Errorf("JumpCloud API request %s %s failed with status %d", method, path, resp.StatusCode)
|
||||
}
|
||||
|
||||
return io.ReadAll(resp.Body)
|
||||
}
|
||||
|
||||
// UpdateUserAppMetadata updates user app metadata based on userID and metadata map.
|
||||
func (jm *JumpCloudManager) UpdateUserAppMetadata(_ context.Context, _ string, _ AppMetadata) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetUserDataByID requests user data from JumpCloud via ID.
|
||||
func (jm *JumpCloudManager) GetUserDataByID(ctx context.Context, userID string, appMetadata AppMetadata) (*UserData, error) {
|
||||
body, err := jm.doRequest(ctx, http.MethodGet, "/systemusers/"+userID, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if jm.appMetrics != nil {
|
||||
jm.appMetrics.IDPMetrics().CountGetUserDataByID()
|
||||
}
|
||||
|
||||
var user jumpCloudUser
|
||||
if err = jm.helper.Unmarshal(body, &user); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
userData := parseJumpCloudUser(user)
|
||||
userData.AppMetadata = appMetadata
|
||||
|
||||
@@ -116,30 +162,20 @@ func (jm *JumpCloudManager) GetUserDataByID(_ context.Context, userID string, ap
|
||||
}
|
||||
|
||||
// GetAccount returns all the users for a given profile.
|
||||
func (jm *JumpCloudManager) GetAccount(_ context.Context, accountID string) ([]*UserData, error) {
|
||||
authCtx := jm.authenticationContext()
|
||||
userList, resp, err := jm.client.SearchApi.SearchSystemusersPost(authCtx, contentType, accept, nil)
|
||||
func (jm *JumpCloudManager) GetAccount(ctx context.Context, accountID string) ([]*UserData, error) {
|
||||
allUsers, err := jm.searchAllUsers(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
if jm.appMetrics != nil {
|
||||
jm.appMetrics.IDPMetrics().CountRequestStatusError()
|
||||
}
|
||||
return nil, fmt.Errorf("unable to get account %s users, statusCode %d", accountID, resp.StatusCode)
|
||||
}
|
||||
|
||||
if jm.appMetrics != nil {
|
||||
jm.appMetrics.IDPMetrics().CountGetAccount()
|
||||
}
|
||||
|
||||
users := make([]*UserData, 0)
|
||||
for _, user := range userList.Results {
|
||||
users := make([]*UserData, 0, len(allUsers))
|
||||
for _, user := range allUsers {
|
||||
userData := parseJumpCloudUser(user)
|
||||
userData.AppMetadata.WTAccountID = accountID
|
||||
|
||||
users = append(users, userData)
|
||||
}
|
||||
|
||||
@@ -148,27 +184,18 @@ func (jm *JumpCloudManager) GetAccount(_ context.Context, accountID string) ([]*
|
||||
|
||||
// GetAllAccounts gets all registered accounts with corresponding user data.
|
||||
// It returns a list of users indexed by accountID.
|
||||
func (jm *JumpCloudManager) GetAllAccounts(_ context.Context) (map[string][]*UserData, error) {
|
||||
authCtx := jm.authenticationContext()
|
||||
userList, resp, err := jm.client.SearchApi.SearchSystemusersPost(authCtx, contentType, accept, nil)
|
||||
func (jm *JumpCloudManager) GetAllAccounts(ctx context.Context) (map[string][]*UserData, error) {
|
||||
allUsers, err := jm.searchAllUsers(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
if jm.appMetrics != nil {
|
||||
jm.appMetrics.IDPMetrics().CountRequestStatusError()
|
||||
}
|
||||
return nil, fmt.Errorf("unable to get all accounts, statusCode %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
if jm.appMetrics != nil {
|
||||
jm.appMetrics.IDPMetrics().CountGetAllAccounts()
|
||||
}
|
||||
|
||||
indexedUsers := make(map[string][]*UserData)
|
||||
for _, user := range userList.Results {
|
||||
for _, user := range allUsers {
|
||||
userData := parseJumpCloudUser(user)
|
||||
indexedUsers[UnsetAccountID] = append(indexedUsers[UnsetAccountID], userData)
|
||||
}
|
||||
@@ -176,6 +203,41 @@ func (jm *JumpCloudManager) GetAllAccounts(_ context.Context) (map[string][]*Use
|
||||
return indexedUsers, nil
|
||||
}
|
||||
|
||||
// searchAllUsers paginates through all system users using limit/skip.
|
||||
func (jm *JumpCloudManager) searchAllUsers(ctx context.Context) ([]jumpCloudUser, error) {
|
||||
var allUsers []jumpCloudUser
|
||||
|
||||
for skip := 0; ; skip += jumpCloudSearchPageSize {
|
||||
searchReq := map[string]int{
|
||||
"limit": jumpCloudSearchPageSize,
|
||||
"skip": skip,
|
||||
}
|
||||
|
||||
payload, err := json.Marshal(searchReq)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
body, err := jm.doRequest(ctx, http.MethodPost, "/search/systemusers", bytes.NewReader(payload))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var userList jumpCloudUserList
|
||||
if err = jm.helper.Unmarshal(body, &userList); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
allUsers = append(allUsers, userList.Results...)
|
||||
|
||||
if skip+len(userList.Results) >= userList.TotalCount {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
return allUsers, nil
|
||||
}
|
||||
|
||||
// CreateUser creates a new user in JumpCloud Idp and sends an invitation.
|
||||
func (jm *JumpCloudManager) CreateUser(_ context.Context, _, _, _, _ string) (*UserData, error) {
|
||||
return nil, fmt.Errorf("method CreateUser not implemented")
|
||||
@@ -183,7 +245,7 @@ func (jm *JumpCloudManager) CreateUser(_ context.Context, _, _, _, _ string) (*U
|
||||
|
||||
// GetUserByEmail searches users with a given email.
|
||||
// If no users have been found, this function returns an empty list.
|
||||
func (jm *JumpCloudManager) GetUserByEmail(_ context.Context, email string) ([]*UserData, error) {
|
||||
func (jm *JumpCloudManager) GetUserByEmail(ctx context.Context, email string) ([]*UserData, error) {
|
||||
searchFilter := map[string]interface{}{
|
||||
"searchFilter": map[string]interface{}{
|
||||
"filter": []string{email},
|
||||
@@ -191,25 +253,26 @@ func (jm *JumpCloudManager) GetUserByEmail(_ context.Context, email string) ([]*
|
||||
},
|
||||
}
|
||||
|
||||
authCtx := jm.authenticationContext()
|
||||
userList, resp, err := jm.client.SearchApi.SearchSystemusersPost(authCtx, contentType, accept, searchFilter)
|
||||
payload, err := json.Marshal(searchFilter)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
if jm.appMetrics != nil {
|
||||
jm.appMetrics.IDPMetrics().CountRequestStatusError()
|
||||
}
|
||||
return nil, fmt.Errorf("unable to get user %s, statusCode %d", email, resp.StatusCode)
|
||||
body, err := jm.doRequest(ctx, http.MethodPost, "/search/systemusers", bytes.NewReader(payload))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if jm.appMetrics != nil {
|
||||
jm.appMetrics.IDPMetrics().CountGetUserByEmail()
|
||||
}
|
||||
|
||||
usersData := make([]*UserData, 0)
|
||||
var userList jumpCloudUserList
|
||||
if err = jm.helper.Unmarshal(body, &userList); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
usersData := make([]*UserData, 0, len(userList.Results))
|
||||
for _, user := range userList.Results {
|
||||
usersData = append(usersData, parseJumpCloudUser(user))
|
||||
}
|
||||
@@ -224,20 +287,11 @@ func (jm *JumpCloudManager) InviteUserByID(_ context.Context, _ string) error {
|
||||
}
|
||||
|
||||
// DeleteUser from jumpCloud directory
|
||||
func (jm *JumpCloudManager) DeleteUser(_ context.Context, userID string) error {
|
||||
authCtx := jm.authenticationContext()
|
||||
_, resp, err := jm.client.SystemusersApi.SystemusersDelete(authCtx, userID, contentType, accept, nil)
|
||||
func (jm *JumpCloudManager) DeleteUser(ctx context.Context, userID string) error {
|
||||
_, err := jm.doRequest(ctx, http.MethodDelete, "/systemusers/"+userID, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
if jm.appMetrics != nil {
|
||||
jm.appMetrics.IDPMetrics().CountRequestStatusError()
|
||||
}
|
||||
return fmt.Errorf("unable to delete user, statusCode %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
if jm.appMetrics != nil {
|
||||
jm.appMetrics.IDPMetrics().CountDeleteUser()
|
||||
@@ -247,11 +301,11 @@ func (jm *JumpCloudManager) DeleteUser(_ context.Context, userID string) error {
|
||||
}
|
||||
|
||||
// parseJumpCloudUser parse JumpCloud system user returned from API V1 to UserData.
|
||||
func parseJumpCloudUser(user v1.Systemuserreturn) *UserData {
|
||||
func parseJumpCloudUser(user jumpCloudUser) *UserData {
|
||||
names := []string{user.Firstname, user.Middlename, user.Lastname}
|
||||
return &UserData{
|
||||
Email: user.Email,
|
||||
Name: strings.Join(names, " "),
|
||||
ID: user.Id,
|
||||
ID: user.ID,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,8 +1,15 @@
|
||||
package idp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
@@ -44,3 +51,212 @@ func TestNewJumpCloudManager(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestJumpCloudGetUserDataByID(t *testing.T) {
|
||||
userResponse := jumpCloudUser{
|
||||
ID: "user123",
|
||||
Email: "test@example.com",
|
||||
Firstname: "John",
|
||||
Middlename: "",
|
||||
Lastname: "Doe",
|
||||
}
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
assert.Equal(t, "/systemusers/user123", r.URL.Path)
|
||||
assert.Equal(t, http.MethodGet, r.Method)
|
||||
assert.Equal(t, "test-api-key", r.Header.Get("x-api-key"))
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_ = json.NewEncoder(w).Encode(userResponse)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
manager := newTestJumpCloudManager(t, server.URL)
|
||||
|
||||
userData, err := manager.GetUserDataByID(context.Background(), "user123", AppMetadata{WTAccountID: "acc1"})
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, "user123", userData.ID)
|
||||
assert.Equal(t, "test@example.com", userData.Email)
|
||||
assert.Equal(t, "John Doe", userData.Name)
|
||||
assert.Equal(t, "acc1", userData.AppMetadata.WTAccountID)
|
||||
}
|
||||
|
||||
func TestJumpCloudGetAccount(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
assert.Equal(t, "/search/systemusers", r.URL.Path)
|
||||
assert.Equal(t, http.MethodPost, r.Method)
|
||||
|
||||
var reqBody map[string]any
|
||||
assert.NoError(t, json.NewDecoder(r.Body).Decode(&reqBody))
|
||||
assert.Contains(t, reqBody, "limit")
|
||||
assert.Contains(t, reqBody, "skip")
|
||||
|
||||
resp := jumpCloudUserList{
|
||||
Results: []jumpCloudUser{
|
||||
{ID: "u1", Email: "a@test.com", Firstname: "Alice", Lastname: "Smith"},
|
||||
{ID: "u2", Email: "b@test.com", Firstname: "Bob", Lastname: "Jones"},
|
||||
},
|
||||
TotalCount: 2,
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_ = json.NewEncoder(w).Encode(resp)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
manager := newTestJumpCloudManager(t, server.URL)
|
||||
|
||||
users, err := manager.GetAccount(context.Background(), "testAccount")
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, users, 2)
|
||||
assert.Equal(t, "testAccount", users[0].AppMetadata.WTAccountID)
|
||||
assert.Equal(t, "testAccount", users[1].AppMetadata.WTAccountID)
|
||||
}
|
||||
|
||||
func TestJumpCloudGetAllAccounts(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
resp := jumpCloudUserList{
|
||||
Results: []jumpCloudUser{
|
||||
{ID: "u1", Email: "a@test.com", Firstname: "Alice"},
|
||||
{ID: "u2", Email: "b@test.com", Firstname: "Bob"},
|
||||
},
|
||||
TotalCount: 2,
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_ = json.NewEncoder(w).Encode(resp)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
manager := newTestJumpCloudManager(t, server.URL)
|
||||
|
||||
indexedUsers, err := manager.GetAllAccounts(context.Background())
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, indexedUsers[UnsetAccountID], 2)
|
||||
}
|
||||
|
||||
func TestJumpCloudGetAllAccountsPagination(t *testing.T) {
|
||||
totalUsers := 250
|
||||
allUsers := make([]jumpCloudUser, totalUsers)
|
||||
for i := range allUsers {
|
||||
allUsers[i] = jumpCloudUser{
|
||||
ID: fmt.Sprintf("u%d", i),
|
||||
Email: fmt.Sprintf("user%d@test.com", i),
|
||||
Firstname: fmt.Sprintf("User%d", i),
|
||||
}
|
||||
}
|
||||
|
||||
requestCount := 0
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
var reqBody map[string]int
|
||||
assert.NoError(t, json.NewDecoder(r.Body).Decode(&reqBody))
|
||||
|
||||
limit := reqBody["limit"]
|
||||
skip := reqBody["skip"]
|
||||
requestCount++
|
||||
|
||||
end := skip + limit
|
||||
if end > totalUsers {
|
||||
end = totalUsers
|
||||
}
|
||||
|
||||
resp := jumpCloudUserList{
|
||||
Results: allUsers[skip:end],
|
||||
TotalCount: totalUsers,
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_ = json.NewEncoder(w).Encode(resp)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
manager := newTestJumpCloudManager(t, server.URL)
|
||||
|
||||
indexedUsers, err := manager.GetAllAccounts(context.Background())
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, indexedUsers[UnsetAccountID], totalUsers)
|
||||
assert.Equal(t, 3, requestCount, "should require 3 pages for 250 users at page size 100")
|
||||
}
|
||||
|
||||
func TestJumpCloudGetUserByEmail(t *testing.T) {
|
||||
searchResponse := jumpCloudUserList{
|
||||
Results: []jumpCloudUser{
|
||||
{ID: "u1", Email: "alice@test.com", Firstname: "Alice", Lastname: "Smith"},
|
||||
},
|
||||
TotalCount: 1,
|
||||
}
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
assert.Equal(t, "/search/systemusers", r.URL.Path)
|
||||
assert.Equal(t, http.MethodPost, r.Method)
|
||||
|
||||
body, err := io.ReadAll(r.Body)
|
||||
assert.NoError(t, err)
|
||||
assert.Contains(t, string(body), "alice@test.com")
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_ = json.NewEncoder(w).Encode(searchResponse)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
manager := newTestJumpCloudManager(t, server.URL)
|
||||
|
||||
users, err := manager.GetUserByEmail(context.Background(), "alice@test.com")
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, users, 1)
|
||||
assert.Equal(t, "alice@test.com", users[0].Email)
|
||||
}
|
||||
|
||||
func TestJumpCloudDeleteUser(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
assert.Equal(t, "/systemusers/user123", r.URL.Path)
|
||||
assert.Equal(t, http.MethodDelete, r.Method)
|
||||
assert.Equal(t, "test-api-key", r.Header.Get("x-api-key"))
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_ = json.NewEncoder(w).Encode(map[string]string{"_id": "user123"})
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
manager := newTestJumpCloudManager(t, server.URL)
|
||||
|
||||
err := manager.DeleteUser(context.Background(), "user123")
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestJumpCloudAPIError(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
manager := newTestJumpCloudManager(t, server.URL)
|
||||
|
||||
_, err := manager.GetUserDataByID(context.Background(), "user123", AppMetadata{})
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "401")
|
||||
}
|
||||
|
||||
func TestParseJumpCloudUser(t *testing.T) {
|
||||
user := jumpCloudUser{
|
||||
ID: "abc123",
|
||||
Email: "test@example.com",
|
||||
Firstname: "John",
|
||||
Middlename: "M",
|
||||
Lastname: "Doe",
|
||||
}
|
||||
|
||||
userData := parseJumpCloudUser(user)
|
||||
assert.Equal(t, "abc123", userData.ID)
|
||||
assert.Equal(t, "test@example.com", userData.Email)
|
||||
assert.Equal(t, "John M Doe", userData.Name)
|
||||
}
|
||||
|
||||
func newTestJumpCloudManager(t *testing.T, apiBase string) *JumpCloudManager {
|
||||
t.Helper()
|
||||
return &JumpCloudManager{
|
||||
apiBase: apiBase,
|
||||
apiToken: "test-api-key",
|
||||
httpClient: http.DefaultClient,
|
||||
helper: JsonParser{},
|
||||
appMetrics: nil,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -249,7 +249,7 @@ func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, user
|
||||
if err != nil {
|
||||
newLabel = ""
|
||||
} else {
|
||||
_, err := transaction.GetPeerIdByLabel(ctx, store.LockingStrengthNone, accountID, update.Name)
|
||||
_, err := transaction.GetPeerIdByLabel(ctx, store.LockingStrengthNone, accountID, newLabel)
|
||||
if err == nil {
|
||||
newLabel = ""
|
||||
}
|
||||
|
||||
@@ -37,6 +37,7 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server/job"
|
||||
"github.com/netbirdio/netbird/management/server/permissions"
|
||||
"github.com/netbirdio/netbird/management/server/settings"
|
||||
"github.com/netbirdio/netbird/shared/auth"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/util"
|
||||
@@ -2738,3 +2739,70 @@ func TestProcessPeerAddAuth(t *testing.T) {
|
||||
assert.Empty(t, config.GroupsToAdd)
|
||||
})
|
||||
}
|
||||
|
||||
func TestUpdatePeer_DnsLabelCollisionWithFQDN(t *testing.T) {
|
||||
manager, _, err := createManager(t)
|
||||
require.NoError(t, err, "unable to create account manager")
|
||||
|
||||
accountID, err := manager.GetAccountIDByUserID(context.Background(), auth.UserAuth{UserId: userID})
|
||||
require.NoError(t, err, "unable to create an account")
|
||||
|
||||
// Add first peer with hostname that produces DNS label "netbird1"
|
||||
key1, err := wgtypes.GenerateKey()
|
||||
require.NoError(t, err)
|
||||
peer1, _, _, err := manager.AddPeer(context.Background(), "", "", userID, &nbpeer.Peer{
|
||||
Key: key1.PublicKey().String(),
|
||||
Meta: nbpeer.PeerSystemMeta{Hostname: "netbird1.netbird.cloud"},
|
||||
}, false)
|
||||
require.NoError(t, err, "unable to add first peer")
|
||||
assert.Equal(t, "netbird1", peer1.DNSLabel)
|
||||
|
||||
// Add second peer with a different hostname
|
||||
key2, err := wgtypes.GenerateKey()
|
||||
require.NoError(t, err)
|
||||
peer2, _, _, err := manager.AddPeer(context.Background(), "", "", userID, &nbpeer.Peer{
|
||||
Key: key2.PublicKey().String(),
|
||||
Meta: nbpeer.PeerSystemMeta{Hostname: "ip-10-29-5-130"},
|
||||
}, false)
|
||||
require.NoError(t, err)
|
||||
|
||||
update := peer2.Copy()
|
||||
update.Name = "netbird1.demo.netbird.cloud"
|
||||
updated, err := manager.UpdatePeer(context.Background(), accountID, userID, update)
|
||||
require.NoError(t, err, "renaming peer should not fail with duplicate DNS label error")
|
||||
assert.Equal(t, "netbird1.demo.netbird.cloud", updated.Name)
|
||||
assert.NotEqual(t, "netbird1", updated.DNSLabel, "DNS label should not collide with existing peer")
|
||||
assert.Contains(t, updated.DNSLabel, "netbird1-", "DNS label should be IP-based fallback")
|
||||
}
|
||||
|
||||
func TestUpdatePeer_DnsLabelUniqueName(t *testing.T) {
|
||||
manager, _, err := createManager(t)
|
||||
require.NoError(t, err, "unable to create account manager")
|
||||
|
||||
accountID, err := manager.GetAccountIDByUserID(context.Background(), auth.UserAuth{UserId: userID})
|
||||
require.NoError(t, err, "unable to create an account")
|
||||
|
||||
key1, err := wgtypes.GenerateKey()
|
||||
require.NoError(t, err)
|
||||
peer1, _, _, err := manager.AddPeer(context.Background(), "", "", userID, &nbpeer.Peer{
|
||||
Key: key1.PublicKey().String(),
|
||||
Meta: nbpeer.PeerSystemMeta{Hostname: "web-server"},
|
||||
}, false)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "web-server", peer1.DNSLabel)
|
||||
|
||||
// Add second peer and rename it to a unique FQDN whose first label doesn't collide
|
||||
key2, err := wgtypes.GenerateKey()
|
||||
require.NoError(t, err)
|
||||
peer2, _, _, err := manager.AddPeer(context.Background(), "", "", userID, &nbpeer.Peer{
|
||||
Key: key2.PublicKey().String(),
|
||||
Meta: nbpeer.PeerSystemMeta{Hostname: "old-name"},
|
||||
}, false)
|
||||
require.NoError(t, err)
|
||||
|
||||
update := peer2.Copy()
|
||||
update.Name = "api-server.example.com"
|
||||
updated, err := manager.UpdatePeer(context.Background(), accountID, userID, update)
|
||||
require.NoError(t, err, "renaming to unique FQDN should succeed")
|
||||
assert.Equal(t, "api-server", updated.DNSLabel, "DNS label should be first label of FQDN")
|
||||
}
|
||||
|
||||
@@ -932,3 +932,71 @@ func TestProtect_HeaderAuth_SubsequentRequestUsesSessionCookie(t *testing.T) {
|
||||
assert.Equal(t, "header-user", capturedData2.GetUserID())
|
||||
assert.Equal(t, "header", capturedData2.GetAuthMethod())
|
||||
}
|
||||
|
||||
// TestProtect_HeaderAuth_MultipleValuesSameHeader verifies that the proxy
|
||||
// correctly handles multiple valid credentials for the same header name.
|
||||
// In production, the mgmt gRPC authenticateHeader iterates all configured
|
||||
// header auths and accepts if any hash matches (OR semantics). The proxy
|
||||
// creates one Header scheme per entry, but a single gRPC call checks all.
|
||||
func TestProtect_HeaderAuth_MultipleValuesSameHeader(t *testing.T) {
|
||||
mw := NewMiddleware(log.StandardLogger(), nil, nil)
|
||||
kp := generateTestKeyPair(t)
|
||||
|
||||
// Mock simulates mgmt behavior: accepts either token-a or token-b.
|
||||
accepted := map[string]bool{"Bearer token-a": true, "Bearer token-b": true}
|
||||
mock := &mockAuthenticator{fn: func(_ context.Context, req *proto.AuthenticateRequest) (*proto.AuthenticateResponse, error) {
|
||||
ha := req.GetHeaderAuth()
|
||||
if ha != nil && accepted[ha.GetHeaderValue()] {
|
||||
token, err := sessionkey.SignToken(kp.PrivateKey, "header-user", "example.com", auth.MethodHeader, time.Hour)
|
||||
require.NoError(t, err)
|
||||
return &proto.AuthenticateResponse{Success: true, SessionToken: token}, nil
|
||||
}
|
||||
return &proto.AuthenticateResponse{Success: false}, nil
|
||||
}}
|
||||
|
||||
// Single Header scheme (as if one entry existed), but the mock checks both values.
|
||||
hdr := NewHeader(mock, "svc1", "acc1", "Authorization")
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{hdr}, kp.PublicKey, time.Hour, "acc1", "svc1", nil))
|
||||
|
||||
var backendCalled bool
|
||||
handler := mw.Protect(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
backendCalled = true
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
|
||||
t.Run("first value accepted", func(t *testing.T) {
|
||||
backendCalled = false
|
||||
req := httptest.NewRequest(http.MethodGet, "http://example.com/", nil)
|
||||
req.Header.Set("Authorization", "Bearer token-a")
|
||||
req = req.WithContext(proxy.WithCapturedData(req.Context(), proxy.NewCapturedData("")))
|
||||
rec := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rec, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
assert.True(t, backendCalled, "first token should be accepted")
|
||||
})
|
||||
|
||||
t.Run("second value accepted", func(t *testing.T) {
|
||||
backendCalled = false
|
||||
req := httptest.NewRequest(http.MethodGet, "http://example.com/", nil)
|
||||
req.Header.Set("Authorization", "Bearer token-b")
|
||||
req = req.WithContext(proxy.WithCapturedData(req.Context(), proxy.NewCapturedData("")))
|
||||
rec := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rec, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
assert.True(t, backendCalled, "second token should be accepted")
|
||||
})
|
||||
|
||||
t.Run("unknown value rejected", func(t *testing.T) {
|
||||
backendCalled = false
|
||||
req := httptest.NewRequest(http.MethodGet, "http://example.com/", nil)
|
||||
req.Header.Set("Authorization", "Bearer token-c")
|
||||
req = req.WithContext(proxy.WithCapturedData(req.Context(), proxy.NewCapturedData("")))
|
||||
rec := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rec, req)
|
||||
|
||||
assert.Equal(t, http.StatusUnauthorized, rec.Code)
|
||||
assert.False(t, backendCalled, "unknown token should be rejected")
|
||||
})
|
||||
}
|
||||
|
||||
@@ -5,13 +5,12 @@ import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"runtime"
|
||||
"testing"
|
||||
|
||||
"github.com/aws/aws-sdk-go-v2/config"
|
||||
"github.com/aws/aws-sdk-go-v2/credentials"
|
||||
"github.com/aws/aws-sdk-go-v2/service/s3"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/testcontainers/testcontainers-go"
|
||||
"github.com/testcontainers/testcontainers-go/wait"
|
||||
@@ -20,45 +19,55 @@ import (
|
||||
)
|
||||
|
||||
func Test_S3HandlerGetUploadURL(t *testing.T) {
|
||||
if runtime.GOOS != "linux" && os.Getenv("CI") == "true" {
|
||||
t.Skip("Skipping test on non-Linux and CI environment due to docker dependency")
|
||||
}
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("Skipping test on Windows due to potential docker dependency")
|
||||
if runtime.GOOS != "linux" {
|
||||
t.Skip("Skipping test on non-Linux due to docker dependency")
|
||||
}
|
||||
|
||||
awsEndpoint := "http://127.0.0.1:4566"
|
||||
awsRegion := "us-east-1"
|
||||
|
||||
ctx := context.Background()
|
||||
containerRequest := testcontainers.ContainerRequest{
|
||||
Image: "localstack/localstack:s3-latest",
|
||||
ExposedPorts: []string{"4566:4566/tcp"},
|
||||
WaitingFor: wait.ForLog("Ready"),
|
||||
}
|
||||
|
||||
c, err := testcontainers.GenericContainer(ctx, testcontainers.GenericContainerRequest{
|
||||
ContainerRequest: containerRequest,
|
||||
Started: true,
|
||||
ContainerRequest: testcontainers.ContainerRequest{
|
||||
Image: "minio/minio:RELEASE.2025-04-22T22-12-26Z",
|
||||
ExposedPorts: []string{"9000/tcp"},
|
||||
Env: map[string]string{
|
||||
"MINIO_ROOT_USER": "minioadmin",
|
||||
"MINIO_ROOT_PASSWORD": "minioadmin",
|
||||
},
|
||||
Cmd: []string{"server", "/data"},
|
||||
WaitingFor: wait.ForHTTP("/minio/health/ready").WithPort("9000"),
|
||||
},
|
||||
Started: true,
|
||||
})
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
defer func(c testcontainers.Container, ctx context.Context) {
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
if err := c.Terminate(ctx); err != nil {
|
||||
t.Log(err)
|
||||
}
|
||||
}(c, ctx)
|
||||
})
|
||||
|
||||
mappedPort, err := c.MappedPort(ctx, "9000")
|
||||
require.NoError(t, err)
|
||||
|
||||
hostIP, err := c.Host(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
awsEndpoint := "http://" + hostIP + ":" + mappedPort.Port()
|
||||
|
||||
t.Setenv("AWS_REGION", awsRegion)
|
||||
t.Setenv("AWS_ENDPOINT_URL", awsEndpoint)
|
||||
t.Setenv("AWS_ACCESS_KEY_ID", "test")
|
||||
t.Setenv("AWS_SECRET_ACCESS_KEY", "test")
|
||||
t.Setenv("AWS_ACCESS_KEY_ID", "minioadmin")
|
||||
t.Setenv("AWS_SECRET_ACCESS_KEY", "minioadmin")
|
||||
t.Setenv("AWS_CONFIG_FILE", "")
|
||||
t.Setenv("AWS_SHARED_CREDENTIALS_FILE", "")
|
||||
t.Setenv("AWS_PROFILE", "")
|
||||
|
||||
cfg, err := config.LoadDefaultConfig(ctx, config.WithRegion(awsRegion), config.WithBaseEndpoint(awsEndpoint))
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
cfg, err := config.LoadDefaultConfig(ctx,
|
||||
config.WithRegion(awsRegion),
|
||||
config.WithBaseEndpoint(awsEndpoint),
|
||||
config.WithCredentialsProvider(credentials.NewStaticCredentialsProvider("minioadmin", "minioadmin", "")),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
client := s3.NewFromConfig(cfg, func(o *s3.Options) {
|
||||
o.UsePathStyle = true
|
||||
@@ -66,19 +75,16 @@ func Test_S3HandlerGetUploadURL(t *testing.T) {
|
||||
})
|
||||
|
||||
bucketName := "test"
|
||||
if _, err := client.CreateBucket(ctx, &s3.CreateBucketInput{
|
||||
_, err = client.CreateBucket(ctx, &s3.CreateBucketInput{
|
||||
Bucket: &bucketName,
|
||||
}); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
list, err := client.ListBuckets(ctx, &s3.ListBucketsInput{})
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, len(list.Buckets), 1)
|
||||
assert.Equal(t, *list.Buckets[0].Name, bucketName)
|
||||
require.Len(t, list.Buckets, 1)
|
||||
require.Equal(t, bucketName, *list.Buckets[0].Name)
|
||||
|
||||
t.Setenv(bucketVar, bucketName)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user