Compare commits

...

1 Commits

Author SHA1 Message Date
Viktor Liu
e7eb0a451b Add local tcp listener 2026-02-06 19:05:58 +08:00
18 changed files with 1009 additions and 45 deletions

View File

@@ -663,6 +663,9 @@ func (m *Manager) filterOutbound(packetData []byte, size int) bool {
return true
}
case layers.LayerTypeTCP:
if m.tcpHooksDrop(uint16(d.tcp.DstPort), dstIP, packetData) {
return true
}
// Clamp MSS on all TCP SYN packets, including those from local IPs.
// SNATed routed traffic may appear as local IP but still requires clamping.
if m.mssClampEnabled {
@@ -880,6 +883,38 @@ func (m *Manager) udpHooksDrop(dport uint16, dstIP netip.Addr, packetData []byte
return false
}
// tcpHooksDrop checks if any TCP hooks should drop the packet
func (m *Manager) tcpHooksDrop(dport uint16, dstIP netip.Addr, packetData []byte) bool {
m.mutex.RLock()
defer m.mutex.RUnlock()
if rules, exists := m.outgoingRules[dstIP]; exists {
for _, rule := range rules {
if rule.tcpHook != nil && portsMatch(rule.dPort, dport) {
return rule.tcpHook(packetData)
}
}
}
if rules, exists := m.outgoingRules[netip.IPv4Unspecified()]; exists {
for _, rule := range rules {
if rule.tcpHook != nil && portsMatch(rule.dPort, dport) {
return rule.tcpHook(packetData)
}
}
}
if rules, exists := m.outgoingRules[netip.IPv6Unspecified()]; exists {
for _, rule := range rules {
if rule.tcpHook != nil && portsMatch(rule.dPort, dport) {
return rule.tcpHook(packetData)
}
}
}
return false
}
// filterInbound implements filtering logic for incoming packets.
// If it returns true, the packet should be dropped.
func (m *Manager) filterInbound(packetData []byte, size int) bool {
@@ -1224,12 +1259,14 @@ func validateRule(ip netip.Addr, packetData []byte, rules map[string]PeerRule, d
switch payloadLayer {
case layers.LayerTypeTCP:
if rule.tcpHook != nil {
return rule.mgmtId, rule.tcpHook(packetData), true
}
if portsMatch(rule.sPort, uint16(d.tcp.SrcPort)) && portsMatch(rule.dPort, uint16(d.tcp.DstPort)) {
return rule.mgmtId, rule.drop, true
}
case layers.LayerTypeUDP:
// if rule has UDP hook (and if we are here we match this rule)
// we ignore rule.drop and call this hook
if rule.udpHook != nil {
return rule.mgmtId, rule.udpHook(packetData), true
}
@@ -1327,6 +1364,40 @@ func (m *Manager) AddUDPPacketHook(in bool, ip netip.Addr, dPort uint16, hook fu
return r.id
}
// AddTCPPacketHook calls hook when TCP packet from given direction matched
//
// Hook function returns flag which indicates should be the matched package dropped or not
func (m *Manager) AddTCPPacketHook(in bool, ip netip.Addr, dPort uint16, hook func(packet []byte) bool) string {
r := PeerRule{
id: uuid.New().String(),
ip: ip,
protoLayer: layers.LayerTypeTCP,
dPort: &firewall.Port{Values: []uint16{dPort}},
ipLayer: layers.LayerTypeIPv6,
tcpHook: hook,
}
if ip.Is4() {
r.ipLayer = layers.LayerTypeIPv4
}
m.mutex.Lock()
if in {
if _, ok := m.incomingRules[r.ip]; !ok {
m.incomingRules[r.ip] = make(map[string]PeerRule)
}
m.incomingRules[r.ip][r.id] = r
} else {
if _, ok := m.outgoingRules[r.ip]; !ok {
m.outgoingRules[r.ip] = make(map[string]PeerRule)
}
m.outgoingRules[r.ip][r.id] = r
}
m.mutex.Unlock()
return r.id
}
// RemovePacketHook removes packet hook by given ID
func (m *Manager) RemovePacketHook(hookID string) error {
m.mutex.Lock()

View File

@@ -21,6 +21,7 @@ type PeerRule struct {
drop bool
udpHook func([]byte) bool
tcpHook func([]byte) bool
}
// ID returns the rule id

View File

@@ -21,6 +21,12 @@ type PacketFilter interface {
// Hook function receives raw network packet data as argument.
AddUDPPacketHook(in bool, ip netip.Addr, dPort uint16, hook func(packet []byte) bool) string
// AddTCPPacketHook calls hook when TCP packet from given direction matched
//
// Hook function returns flag which indicates should be the matched package dropped or not.
// Hook function receives raw network packet data as argument.
AddTCPPacketHook(in bool, ip netip.Addr, dPort uint16, hook func(packet []byte) bool) string
// RemovePacketHook removes hook by ID
RemovePacketHook(hookID string) error
}

View File

@@ -48,6 +48,20 @@ func (mr *MockPacketFilterMockRecorder) AddUDPPacketHook(arg0, arg1, arg2, arg3
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddUDPPacketHook", reflect.TypeOf((*MockPacketFilter)(nil).AddUDPPacketHook), arg0, arg1, arg2, arg3)
}
// AddTCPPacketHook mocks base method.
func (m *MockPacketFilter) AddTCPPacketHook(arg0 bool, arg1 netip.Addr, arg2 uint16, arg3 func([]byte) bool) string {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "AddTCPPacketHook", arg0, arg1, arg2, arg3)
ret0, _ := ret[0].(string)
return ret0
}
// AddTCPPacketHook indicates an expected call of AddTCPPacketHook.
func (mr *MockPacketFilterMockRecorder) AddTCPPacketHook(arg0, arg1, arg2, arg3 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddTCPPacketHook", reflect.TypeOf((*MockPacketFilter)(nil).AddTCPPacketHook), arg0, arg1, arg2, arg3)
}
// FilterInbound mocks base method.
func (m *MockPacketFilter) FilterInbound(arg0 []byte, arg1 int) bool {
m.ctrl.T.Helper()

View File

@@ -104,3 +104,19 @@ func (r *responseWriter) TsigTimersOnly(bool) {
// After a call to Hijack(), the DNS package will not do anything with the connection.
func (r *responseWriter) Hijack() {
}
// truncationAwareWriter wraps a UDP responseWriter and starts the TCP DNS
// stack when a truncated response is about to be sent. This ensures the
// TCP stack is ready when the client retries over TCP.
type truncationAwareWriter struct {
responseWriter
tcpDNS *tcpDNSServer
}
// WriteMsg checks if the response is truncated and starts the TCP stack if needed.
func (w *truncationAwareWriter) WriteMsg(msg *dns.Msg) error {
if msg.MsgHdr.Truncated && w.tcpDNS != nil {
w.tcpDNS.EnsureRunning()
}
return w.responseWriter.WriteMsg(msg)
}

View File

@@ -116,6 +116,7 @@ type registeredHandlerMap map[types.HandlerID]handlerWrapper
// DefaultServerConfig holds configuration parameters for NewDefaultServer
type DefaultServerConfig struct {
WgInterface WGIface
Firewall DNSFirewall
CustomAddress string
StatusRecorder *peer.Status
StateManager *statemanager.Manager
@@ -137,7 +138,7 @@ func NewDefaultServer(ctx context.Context, config DefaultServerConfig) (*Default
if config.WgInterface.IsUserspaceBind() {
dnsService = NewServiceViaMemory(config.WgInterface)
} else {
dnsService = newServiceViaListener(config.WgInterface, addrPort)
dnsService = newServiceViaListener(config.WgInterface, addrPort, config.Firewall)
}
server := newDefaultServer(ctx, config.WgInterface, dnsService, config.StatusRecorder, config.StateManager, config.DisableSys)

View File

@@ -15,6 +15,7 @@ import (
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
wgdevice "golang.zx2c4.com/wireguard/device"
"golang.zx2c4.com/wireguard/tun/netstack"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
@@ -65,6 +66,10 @@ func (w *mocWGIface) GetDevice() *device.FilteredDevice {
panic("implement me")
}
func (w *mocWGIface) GetWGDevice() *wgdevice.Device {
return nil
}
func (w *mocWGIface) GetInterfaceGUIDString() (string, error) {
panic("implement me")
}

View File

@@ -4,12 +4,22 @@ import (
"net/netip"
"github.com/miekg/dns"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
)
const (
DefaultPort = 53
)
// DNSFirewall provides DNAT capabilities for DNS port redirection.
// This is used when the DNS server cannot bind port 53 directly
// and needs firewall rules to redirect traffic.
type DNSFirewall interface {
AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error
RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error
}
type service interface {
Listen() error
Stop()

View File

@@ -12,6 +12,7 @@ import (
"github.com/miekg/dns"
log "github.com/sirupsen/logrus"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/internal/ebpf"
ebpfMgr "github.com/netbirdio/netbird/client/internal/ebpf/manager"
)
@@ -30,25 +31,33 @@ type serviceViaListener struct {
dnsMux *dns.ServeMux
customAddr *netip.AddrPort
server *dns.Server
tcpServer *dns.Server
listenIP netip.Addr
listenPort uint16
listenerIsRunning bool
listenerFlagLock sync.Mutex
ebpfService ebpfMgr.Manager
firewall DNSFirewall
tcpDNATConfigured bool
}
func newServiceViaListener(wgIface WGIface, customAddr *netip.AddrPort) *serviceViaListener {
func newServiceViaListener(wgIface WGIface, customAddr *netip.AddrPort, firewall DNSFirewall) *serviceViaListener {
mux := dns.NewServeMux()
s := &serviceViaListener{
wgInterface: wgIface,
dnsMux: mux,
customAddr: customAddr,
firewall: firewall,
server: &dns.Server{
Net: "udp",
Handler: mux,
UDPSize: 65535,
},
tcpServer: &dns.Server{
Net: "tcp",
Handler: mux,
},
}
return s
@@ -69,18 +78,39 @@ func (s *serviceViaListener) Listen() error {
return fmt.Errorf("eval listen address: %w", err)
}
s.listenIP = s.listenIP.Unmap()
s.server.Addr = fmt.Sprintf("%s:%d", s.listenIP, s.listenPort)
log.Debugf("starting dns on %s", s.server.Addr)
addr := fmt.Sprintf("%s:%d", s.listenIP, s.listenPort)
s.server.Addr = addr
s.tcpServer.Addr = addr
log.Debugf("starting dns on %s (UDP + TCP)", addr)
go func() {
s.setListenerStatus(true)
defer s.setListenerStatus(false)
err := s.server.ListenAndServe()
if err != nil {
log.Errorf("dns server running with %d port returned an error: %v. Will not retry", s.listenPort, err)
if err := s.server.ListenAndServe(); err != nil {
log.Errorf("dns UDP server on port %d returned an error: %v", s.listenPort, err)
}
}()
go func() {
if err := s.tcpServer.ListenAndServe(); err != nil {
log.Errorf("dns TCP server on port %d returned an error: %v", s.listenPort, err)
}
}()
// When eBPF redirects UDP port 53 to our listen port, TCP still needs
// a DNAT rule because eBPF only handles UDP.
if s.ebpfService != nil && s.firewall != nil && s.listenPort != DefaultPort {
if err := s.firewall.AddInboundDNAT(s.listenIP, firewall.ProtocolTCP, DefaultPort, s.listenPort); err != nil {
log.Warnf("failed to add DNS TCP DNAT rule: %v", err)
} else {
s.tcpDNATConfigured = true
log.Infof("added DNS TCP DNAT rule: %s:%d -> %s:%d", s.listenIP, DefaultPort, s.listenIP, s.listenPort)
}
}
return nil
}
@@ -95,15 +125,24 @@ func (s *serviceViaListener) Stop() {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
err := s.server.ShutdownContext(ctx)
if err != nil {
log.Errorf("stopping dns server listener returned an error: %v", err)
if err := s.server.ShutdownContext(ctx); err != nil {
log.Errorf("stopping dns UDP server: %v", err)
}
if err := s.tcpServer.ShutdownContext(ctx); err != nil {
log.Errorf("stopping dns TCP server: %v", err)
}
if s.tcpDNATConfigured && s.firewall != nil {
if err := s.firewall.RemoveInboundDNAT(s.listenIP, firewall.ProtocolTCP, DefaultPort, s.listenPort); err != nil {
log.Errorf("remove DNS TCP DNAT rule: %v", err)
}
s.tcpDNATConfigured = false
}
if s.ebpfService != nil {
err = s.ebpfService.FreeDNSFwd()
if err != nil {
log.Errorf("stopping traffic forwarder returned an error: %v", err)
if err := s.ebpfService.FreeDNSFwd(); err != nil {
log.Errorf("stopping traffic forwarder: %v", err)
}
}
}
@@ -186,18 +225,28 @@ func (s *serviceViaListener) testFreePort(port int) (netip.Addr, bool) {
}
func (s *serviceViaListener) tryToBind(ip netip.Addr, port int) bool {
addrString := fmt.Sprintf("%s:%d", ip, port)
udpAddr := net.UDPAddrFromAddrPort(netip.MustParseAddrPort(addrString))
probeListener, err := net.ListenUDP("udp", udpAddr)
addrPort := netip.AddrPortFrom(ip, uint16(port))
udpAddr := net.UDPAddrFromAddrPort(addrPort)
udpLn, err := net.ListenUDP("udp", udpAddr)
if err != nil {
log.Warnf("binding dns on %s is not available, error: %s", addrString, err)
log.Warnf("binding dns UDP on %s is not available: %s", addrPort, err)
return false
}
err = probeListener.Close()
if err != nil {
log.Errorf("got an error closing the probe listener, error: %s", err)
if err := udpLn.Close(); err != nil {
log.Debugf("close UDP probe listener: %s", err)
}
tcpAddr := net.TCPAddrFromAddrPort(addrPort)
tcpLn, err := net.ListenTCP("tcp", tcpAddr)
if err != nil {
log.Warnf("binding dns TCP on %s is not available: %s", addrPort, err)
return false
}
if err := tcpLn.Close(); err != nil {
log.Debugf("close TCP probe listener: %s", err)
}
return true
}

View File

@@ -0,0 +1,89 @@
package dns
import (
"fmt"
"net"
"net/netip"
"testing"
"time"
"github.com/miekg/dns"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestServiceViaListener_TCPAndUDP(t *testing.T) {
handler := dns.HandlerFunc(func(w dns.ResponseWriter, r *dns.Msg) {
m := new(dns.Msg)
m.SetReply(r)
m.Answer = append(m.Answer, &dns.A{
Hdr: dns.RR_Header{Name: r.Question[0].Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
A: net.ParseIP("192.0.2.1"),
})
if err := w.WriteMsg(m); err != nil {
t.Logf("write msg: %v", err)
}
})
// Create a service using a custom address to avoid needing root
svc := newServiceViaListener(nil, nil, nil)
svc.dnsMux.Handle(".", handler)
// Find a free port by binding and releasing
udpAddr := net.UDPAddrFromAddrPort(netip.AddrPortFrom(customIP, 0))
udpLn, err := net.ListenUDP("udp", udpAddr)
if err != nil {
t.Skip("cannot bind to 127.0.0.153, skipping")
}
port := uint16(udpLn.LocalAddr().(*net.UDPAddr).Port)
require.NoError(t, udpLn.Close())
// Check TCP is also available on this port
tcpAddr := net.TCPAddrFromAddrPort(netip.AddrPortFrom(customIP, port))
tcpLn, err := net.ListenTCP("tcp", tcpAddr)
if err != nil {
t.Skip("cannot bind TCP on same port, skipping")
}
require.NoError(t, tcpLn.Close())
addr := fmt.Sprintf("%s:%d", customIP, port)
svc.server.Addr = addr
svc.tcpServer.Addr = addr
svc.listenIP = customIP
svc.listenPort = port
go func() {
if err := svc.server.ListenAndServe(); err != nil {
t.Logf("udp server: %v", err)
}
}()
go func() {
if err := svc.tcpServer.ListenAndServe(); err != nil {
t.Logf("tcp server: %v", err)
}
}()
svc.listenerIsRunning = true
defer svc.Stop()
// Wait for servers to start
time.Sleep(100 * time.Millisecond)
q := new(dns.Msg).SetQuestion("example.com.", dns.TypeA)
// Test UDP query
udpClient := &dns.Client{Net: "udp", Timeout: 2 * time.Second}
udpResp, _, err := udpClient.Exchange(q, addr)
require.NoError(t, err, "UDP query should succeed")
require.NotNil(t, udpResp)
require.NotEmpty(t, udpResp.Answer)
assert.Contains(t, udpResp.Answer[0].String(), "192.0.2.1", "UDP response should contain expected IP")
// Test TCP query
tcpClient := &dns.Client{Net: "tcp", Timeout: 2 * time.Second}
tcpResp, _, err := tcpClient.Exchange(q, addr)
require.NoError(t, err, "TCP query should succeed")
require.NotNil(t, tcpResp)
require.NotEmpty(t, tcpResp.Answer)
assert.Contains(t, tcpResp.Answer[0].String(), "192.0.2.1", "TCP response should contain expected IP")
}

View File

@@ -10,6 +10,7 @@ import (
"github.com/miekg/dns"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/iface"
nbnet "github.com/netbirdio/netbird/client/net"
)
@@ -19,6 +20,8 @@ type ServiceViaMemory struct {
runtimeIP netip.Addr
runtimePort int
udpFilterHookID string
tcpFilterHookID string
tcpDNS *tcpDNSServer
listenerIsRunning bool
listenerFlagLock sync.Mutex
}
@@ -28,14 +31,15 @@ func NewServiceViaMemory(wgIface WGIface) *ServiceViaMemory {
if err != nil {
log.Errorf("get last ip from network: %v", err)
}
s := &ServiceViaMemory{
wgInterface: wgIface,
dnsMux: dns.NewServeMux(),
mux := dns.NewServeMux()
return &ServiceViaMemory{
wgInterface: wgIface,
dnsMux: mux,
runtimeIP: lastIP,
runtimePort: DefaultPort,
}
return s
}
func (s *ServiceViaMemory) Listen() error {
@@ -65,8 +69,19 @@ func (s *ServiceViaMemory) Stop() {
return
}
if err := s.wgInterface.GetFilter().RemovePacketHook(s.udpFilterHookID); err != nil {
log.Errorf("unable to remove DNS packet hook: %s", err)
filter := s.wgInterface.GetFilter()
if err := filter.RemovePacketHook(s.udpFilterHookID); err != nil {
log.Errorf("remove DNS UDP packet hook: %s", err)
}
if s.tcpFilterHookID != "" {
if err := filter.RemovePacketHook(s.tcpFilterHookID); err != nil {
log.Errorf("remove DNS TCP packet hook: %s", err)
}
}
if s.tcpDNS != nil {
s.tcpDNS.Stop()
}
s.listenerIsRunning = false
@@ -94,16 +109,22 @@ func (s *ServiceViaMemory) filterDNSTraffic() (string, error) {
return "", fmt.Errorf("can't set DNS filter, filter not initialized")
}
// Create TCP DNS server lazily here since the device may not exist at construction time.
if s.tcpDNS == nil {
if dev := s.wgInterface.GetDevice(); dev != nil {
// MTU only affects TCP segment sizing; DNS messages are small so this has no practical impact.
s.tcpDNS = newTCPDNSServer(s.dnsMux, dev.Device, s.runtimeIP, uint16(s.runtimePort), iface.DefaultMTU)
}
}
firstLayerDecoder := layers.LayerTypeIPv4
if s.wgInterface.Address().IP.Is6() {
firstLayerDecoder = layers.LayerTypeIPv6
}
hook := func(packetData []byte) bool {
// Decode the packet
packet := gopacket.NewPacket(packetData, firstLayerDecoder, gopacket.Default)
// Get the UDP layer
udpLayer := packet.Layer(layers.LayerTypeUDP)
udp := udpLayer.(*layers.UDP)
@@ -113,13 +134,27 @@ func (s *ServiceViaMemory) filterDNSTraffic() (string, error) {
return true
}
writer := responseWriter{
packet: packet,
device: s.wgInterface.GetDevice().Device,
writer := &truncationAwareWriter{
responseWriter: responseWriter{
packet: packet,
device: s.wgInterface.GetDevice().Device,
},
tcpDNS: s.tcpDNS,
}
go s.dnsMux.ServeDNS(&writer, msg)
go s.dnsMux.ServeDNS(writer, msg)
return true
}
return filter.AddUDPPacketHook(false, s.runtimeIP, uint16(s.runtimePort), hook), nil
udpHookID := filter.AddUDPPacketHook(false, s.runtimeIP, uint16(s.runtimePort), hook)
if s.tcpDNS != nil {
tcpHook := func(packetData []byte) bool {
s.tcpDNS.EnsureRunning()
s.tcpDNS.InjectPacket(packetData)
return true
}
s.tcpFilterHookID = filter.AddTCPPacketHook(false, s.runtimeIP, uint16(s.runtimePort), tcpHook)
}
return udpHookID, nil
}

View File

@@ -0,0 +1,371 @@
package dns
import (
"fmt"
"net"
"net/netip"
"sync"
"sync/atomic"
"time"
"github.com/miekg/dns"
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/tun"
"gvisor.dev/gvisor/pkg/buffer"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
"gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
"gvisor.dev/gvisor/pkg/waiter"
)
const (
dnsTCPReceiveWindow = 8192
dnsTCPMaxInFlight = 16
dnsTCPIdleTimeout = 30 * time.Second
)
// tcpDNSServer is an on-demand TCP DNS server backed by a minimal gvisor stack.
// It is started lazily when a truncated DNS response is detected and shuts down
// after a period of inactivity to conserve resources.
type tcpDNSServer struct {
mu sync.Mutex
s *stack.Stack
ep *dnsEndpoint
mux *dns.ServeMux
tunDev tun.Device
ip netip.Addr
port uint16
mtu uint16
running bool
timer *time.Timer
}
func newTCPDNSServer(mux *dns.ServeMux, tunDev tun.Device, ip netip.Addr, port uint16, mtu uint16) *tcpDNSServer {
return &tcpDNSServer{
mux: mux,
tunDev: tunDev,
ip: ip,
port: port,
mtu: mtu,
}
}
// EnsureRunning starts the TCP stack if not already running and resets the idle timer.
func (t *tcpDNSServer) EnsureRunning() {
t.mu.Lock()
defer t.mu.Unlock()
if t.running {
t.resetTimerLocked()
return
}
if err := t.startLocked(); err != nil {
log.Errorf("start TCP DNS stack: %v", err)
return
}
t.running = true
t.resetTimerLocked()
log.Debugf("TCP DNS stack started on %s:%d", t.ip, t.port)
}
// InjectPacket delivers a raw IP packet into the gvisor stack for TCP processing.
func (t *tcpDNSServer) InjectPacket(payload []byte) {
t.mu.Lock()
ep := t.ep
t.mu.Unlock()
if ep == nil {
return
}
pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
Payload: buffer.MakeWithData(payload),
})
defer pkt.DecRef()
if ep.dispatcher != nil {
ep.dispatcher.DeliverNetworkPacket(ipv4.ProtocolNumber, pkt)
}
}
// Stop tears down the gvisor stack and releases resources.
func (t *tcpDNSServer) Stop() {
t.mu.Lock()
defer t.mu.Unlock()
t.stopLocked()
}
func (t *tcpDNSServer) startLocked() error {
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol},
TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol},
HandleLocal: false,
})
nicID := tcpip.NICID(1)
ep := &dnsEndpoint{
tunDev: t.tunDev,
}
ep.mtu.Store(uint32(t.mtu))
if err := s.CreateNIC(nicID, ep); err != nil {
return fmt.Errorf("create NIC: %v", err)
}
protoAddr := tcpip.ProtocolAddress{
Protocol: ipv4.ProtocolNumber,
AddressWithPrefix: tcpip.AddressWithPrefix{
Address: tcpip.AddrFromSlice(t.ip.AsSlice()),
PrefixLen: 32,
},
}
if err := s.AddProtocolAddress(nicID, protoAddr, stack.AddressProperties{}); err != nil {
return fmt.Errorf("add protocol address: %s", err)
}
if err := s.SetPromiscuousMode(nicID, true); err != nil {
return fmt.Errorf("set promiscuous mode: %s", err)
}
if err := s.SetSpoofing(nicID, true); err != nil {
return fmt.Errorf("set spoofing: %s", err)
}
defaultSubnet, err := tcpip.NewSubnet(
tcpip.AddrFrom4([4]byte{0, 0, 0, 0}),
tcpip.MaskFromBytes([]byte{0, 0, 0, 0}),
)
if err != nil {
return fmt.Errorf("create default subnet: %w", err)
}
s.SetRouteTable([]tcpip.Route{
{Destination: defaultSubnet, NIC: nicID},
})
tcpFwd := tcp.NewForwarder(s, dnsTCPReceiveWindow, dnsTCPMaxInFlight, func(r *tcp.ForwarderRequest) {
t.handleTCPDNS(r)
})
s.SetTransportProtocolHandler(tcp.ProtocolNumber, tcpFwd.HandlePacket)
t.s = s
t.ep = ep
return nil
}
func (t *tcpDNSServer) stopLocked() {
if !t.running {
return
}
if t.timer != nil {
t.timer.Stop()
t.timer = nil
}
if t.s != nil {
t.s.Close()
t.s.Wait()
t.s = nil
}
t.ep = nil
t.running = false
log.Debugf("TCP DNS stack stopped")
}
func (t *tcpDNSServer) resetTimerLocked() {
if t.timer != nil {
t.timer.Stop()
}
t.timer = time.AfterFunc(dnsTCPIdleTimeout, func() {
t.mu.Lock()
defer t.mu.Unlock()
t.stopLocked()
})
}
func (t *tcpDNSServer) handleTCPDNS(r *tcp.ForwarderRequest) {
id := r.ID()
wq := waiter.Queue{}
ep, epErr := r.CreateEndpoint(&wq)
if epErr != nil {
log.Debugf("TCP DNS: create endpoint: %v", epErr)
r.Complete(true)
return
}
r.Complete(false)
conn := gonet.NewTCPConn(&wq, ep)
defer func() {
if err := conn.Close(); err != nil {
log.Tracef("TCP DNS: close conn: %v", err)
}
}()
// Reset idle timer on activity
t.mu.Lock()
t.resetTimerLocked()
t.mu.Unlock()
localAddr := &net.TCPAddr{
IP: id.LocalAddress.AsSlice(),
Port: int(id.LocalPort),
}
remoteAddr := &net.TCPAddr{
IP: id.RemoteAddress.AsSlice(),
Port: int(id.RemotePort),
}
for {
if err := conn.SetDeadline(time.Now().Add(5 * time.Second)); err != nil {
break
}
msg, err := readTCPDNSMessage(conn)
if err != nil {
break
}
writer := &tcpResponseWriter{
conn: conn,
localAddr: localAddr,
remoteAddr: remoteAddr,
}
t.mux.ServeDNS(writer, msg)
}
}
// dnsEndpoint implements stack.LinkEndpoint for writing packets back via the tun device.
type dnsEndpoint struct {
dispatcher stack.NetworkDispatcher
tunDev tun.Device
mtu atomic.Uint32
}
func (e *dnsEndpoint) Attach(dispatcher stack.NetworkDispatcher) { e.dispatcher = dispatcher }
func (e *dnsEndpoint) IsAttached() bool { return e.dispatcher != nil }
func (e *dnsEndpoint) MTU() uint32 { return e.mtu.Load() }
func (e *dnsEndpoint) Capabilities() stack.LinkEndpointCapabilities { return stack.CapabilityNone }
func (e *dnsEndpoint) MaxHeaderLength() uint16 { return 0 }
func (e *dnsEndpoint) LinkAddress() tcpip.LinkAddress { return "" }
func (e *dnsEndpoint) Wait() {}
func (e *dnsEndpoint) ARPHardwareType() header.ARPHardwareType { return header.ARPHardwareNone }
func (e *dnsEndpoint) AddHeader(*stack.PacketBuffer) {}
func (e *dnsEndpoint) ParseHeader(*stack.PacketBuffer) bool { return true }
func (e *dnsEndpoint) Close() {}
func (e *dnsEndpoint) SetLinkAddress(tcpip.LinkAddress) {}
func (e *dnsEndpoint) SetMTU(mtu uint32) { e.mtu.Store(mtu) }
func (e *dnsEndpoint) SetOnCloseAction(func()) {}
const tunPacketOffset = 40
func (e *dnsEndpoint) WritePackets(pkts stack.PacketBufferList) (int, tcpip.Error) {
var written int
for _, pkt := range pkts.AsSlice() {
data := stack.PayloadSince(pkt.NetworkHeader())
if data == nil {
continue
}
raw := data.AsSlice()
buf := make([]byte, tunPacketOffset, tunPacketOffset+len(raw))
buf = append(buf, raw...)
if _, err := e.tunDev.Write([][]byte{buf}, tunPacketOffset); err != nil {
log.Tracef("TCP DNS endpoint: write packet: %v", err)
continue
}
written++
}
return written, nil
}
// readTCPDNSMessage reads a single DNS message from a TCP connection (length-prefixed).
func readTCPDNSMessage(conn *gonet.TCPConn) (*dns.Msg, error) {
// DNS over TCP uses a 2-byte length prefix
lenBuf := make([]byte, 2)
if _, err := readFull(conn, lenBuf); err != nil {
return nil, fmt.Errorf("read length: %w", err)
}
msgLen := int(lenBuf[0])<<8 | int(lenBuf[1])
if msgLen == 0 || msgLen > 65535 {
return nil, fmt.Errorf("invalid message length: %d", msgLen)
}
msgBuf := make([]byte, msgLen)
if _, err := readFull(conn, msgBuf); err != nil {
return nil, fmt.Errorf("read message: %w", err)
}
msg := new(dns.Msg)
if err := msg.Unpack(msgBuf); err != nil {
return nil, fmt.Errorf("unpack: %w", err)
}
return msg, nil
}
func readFull(conn *gonet.TCPConn, buf []byte) (int, error) {
var total int
for total < len(buf) {
n, err := conn.Read(buf[total:])
total += n
if err != nil {
return total, err
}
}
return total, nil
}
// tcpResponseWriter implements dns.ResponseWriter for TCP DNS connections.
type tcpResponseWriter struct {
conn *gonet.TCPConn
localAddr net.Addr
remoteAddr net.Addr
}
func (w *tcpResponseWriter) LocalAddr() net.Addr {
return w.localAddr
}
func (w *tcpResponseWriter) RemoteAddr() net.Addr {
return w.remoteAddr
}
func (w *tcpResponseWriter) WriteMsg(msg *dns.Msg) error {
data, err := msg.Pack()
if err != nil {
return fmt.Errorf("pack: %w", err)
}
// DNS TCP: 2-byte length prefix + message
buf := make([]byte, 2+len(data))
buf[0] = byte(len(data) >> 8)
buf[1] = byte(len(data))
copy(buf[2:], data)
if _, err = w.conn.Write(buf); err != nil {
return err
}
return nil
}
func (w *tcpResponseWriter) Write(data []byte) (int, error) {
return w.conn.Write(data)
}
func (w *tcpResponseWriter) Close() error {
return w.conn.Close()
}
func (w *tcpResponseWriter) TsigStatus() error { return nil }
func (w *tcpResponseWriter) TsigTimersOnly(bool) {}
func (w *tcpResponseWriter) Hijack() {}

View File

@@ -45,6 +45,42 @@ const (
const testRecord = "com."
type dnsProtocolKey struct{}
// ContextWithDNSProtocol stores the inbound DNS protocol ("udp" or "tcp") in context.
func ContextWithDNSProtocol(ctx context.Context, network string) context.Context {
return context.WithValue(ctx, dnsProtocolKey{}, network)
}
// DNSProtocolFromContext retrieves the inbound DNS protocol from context.
func DNSProtocolFromContext(ctx context.Context) string {
if v, ok := ctx.Value(dnsProtocolKey{}).(string); ok {
return v
}
return ""
}
type upstreamProtocolKey struct{}
// UpstreamProtocolResult holds the protocol used for the upstream exchange.
// Stored as a pointer in context so the exchange function can set it.
type UpstreamProtocolResult struct {
Protocol string
}
// contextWithUpstreamProtocolResult stores a mutable result holder in the context.
func contextWithUpstreamProtocolResult(ctx context.Context) (context.Context, *UpstreamProtocolResult) {
r := &UpstreamProtocolResult{}
return context.WithValue(ctx, upstreamProtocolKey{}, r), r
}
// setUpstreamProtocol sets the upstream protocol on the result holder in context, if present.
func setUpstreamProtocol(ctx context.Context, protocol string) {
if r, ok := ctx.Value(upstreamProtocolKey{}).(*UpstreamProtocolResult); ok && r != nil {
r.Protocol = protocol
}
}
type upstreamClient interface {
exchange(ctx context.Context, upstream string, r *dns.Msg) (*dns.Msg, time.Duration, error)
}
@@ -131,7 +167,16 @@ func (u *upstreamResolverBase) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
return
}
ok, failures := u.tryUpstreamServers(w, r, logger)
// Propagate inbound protocol so upstream exchange can use TCP directly
// when the request came in over TCP.
ctx := u.ctx
if addr := w.RemoteAddr(); addr != nil {
network := addr.Network()
ctx = ContextWithDNSProtocol(ctx, network)
resutil.SetMeta(w, "protocol", network)
}
ok, failures := u.tryUpstreamServers(ctx, w, r, logger)
if len(failures) > 0 {
u.logUpstreamFailures(r.Question[0].Name, failures, ok, logger)
}
@@ -146,7 +191,7 @@ func (u *upstreamResolverBase) prepareRequest(r *dns.Msg) {
}
}
func (u *upstreamResolverBase) tryUpstreamServers(w dns.ResponseWriter, r *dns.Msg, logger *log.Entry) (bool, []upstreamFailure) {
func (u *upstreamResolverBase) tryUpstreamServers(ctx context.Context, w dns.ResponseWriter, r *dns.Msg, logger *log.Entry) (bool, []upstreamFailure) {
timeout := u.upstreamTimeout
if len(u.upstreamServers) > 1 {
maxTotal := 5 * time.Second
@@ -161,7 +206,7 @@ func (u *upstreamResolverBase) tryUpstreamServers(w dns.ResponseWriter, r *dns.M
var failures []upstreamFailure
for _, upstream := range u.upstreamServers {
if failure := u.queryUpstream(w, r, upstream, timeout, logger); failure != nil {
if failure := u.queryUpstream(ctx, w, r, upstream, timeout, logger); failure != nil {
failures = append(failures, *failure)
} else {
return true, failures
@@ -171,15 +216,17 @@ func (u *upstreamResolverBase) tryUpstreamServers(w dns.ResponseWriter, r *dns.M
}
// queryUpstream queries a single upstream server. Returns nil on success, or failure info to try next upstream.
func (u *upstreamResolverBase) queryUpstream(w dns.ResponseWriter, r *dns.Msg, upstream netip.AddrPort, timeout time.Duration, logger *log.Entry) *upstreamFailure {
func (u *upstreamResolverBase) queryUpstream(parentCtx context.Context, w dns.ResponseWriter, r *dns.Msg, upstream netip.AddrPort, timeout time.Duration, logger *log.Entry) *upstreamFailure {
var rm *dns.Msg
var t time.Duration
var err error
var startTime time.Time
var upstreamProto *UpstreamProtocolResult
func() {
ctx, cancel := context.WithTimeout(u.ctx, timeout)
ctx, cancel := context.WithTimeout(parentCtx, timeout)
defer cancel()
ctx, upstreamProto = contextWithUpstreamProtocolResult(ctx)
startTime = time.Now()
rm, t, err = u.upstreamClient.exchange(ctx, upstream.String(), r)
}()
@@ -196,7 +243,7 @@ func (u *upstreamResolverBase) queryUpstream(w dns.ResponseWriter, r *dns.Msg, u
return &upstreamFailure{upstream: upstream, reason: dns.RcodeToString[rm.Rcode]}
}
u.writeSuccessResponse(w, rm, upstream, r.Question[0].Name, t, logger)
u.writeSuccessResponse(w, rm, upstream, r.Question[0].Name, t, upstreamProto, logger)
return nil
}
@@ -213,10 +260,13 @@ func (u *upstreamResolverBase) handleUpstreamError(err error, upstream netip.Add
return &upstreamFailure{upstream: upstream, reason: reason}
}
func (u *upstreamResolverBase) writeSuccessResponse(w dns.ResponseWriter, rm *dns.Msg, upstream netip.AddrPort, domain string, t time.Duration, logger *log.Entry) bool {
func (u *upstreamResolverBase) writeSuccessResponse(w dns.ResponseWriter, rm *dns.Msg, upstream netip.AddrPort, domain string, t time.Duration, upstreamProto *UpstreamProtocolResult, logger *log.Entry) bool {
u.successCount.Add(1)
resutil.SetMeta(w, "upstream", upstream.String())
if upstreamProto != nil && upstreamProto.Protocol != "" {
resutil.SetMeta(w, "upstream_protocol", upstreamProto.Protocol)
}
// Clear Zero bit from external responses to prevent upstream servers from
// manipulating our internal fallthrough signaling mechanism
@@ -406,8 +456,27 @@ func (u *upstreamResolverBase) testNameserver(server netip.AddrPort, timeout tim
// ExchangeWithFallback exchanges a DNS message with the upstream server.
// It first tries to use UDP, and if it is truncated, it falls back to TCP.
// If the inbound request came over TCP (via context), it skips the UDP attempt.
// If the passed context is nil, this will use Exchange instead of ExchangeContext.
func ExchangeWithFallback(ctx context.Context, client *dns.Client, r *dns.Msg, upstream string) (*dns.Msg, time.Duration, error) {
// If the request came in over TCP, go straight to TCP upstream.
if DNSProtocolFromContext(ctx) == "tcp" {
client.Net = "tcp"
var rm *dns.Msg
var t time.Duration
var err error
if ctx == nil {
rm, t, err = client.Exchange(r, upstream)
} else {
rm, t, err = client.ExchangeContext(ctx, r, upstream)
}
if err != nil {
return nil, t, fmt.Errorf("with tcp: %w", err)
}
setUpstreamProtocol(ctx, "tcp")
return rm, t, nil
}
// MTU - ip + udp headers
// Note: this could be sent out on an interface that is not ours, but higher MTU settings could break truncation handling.
client.UDPSize = uint16(currentMTU - (60 + 8))
@@ -429,6 +498,7 @@ func ExchangeWithFallback(ctx context.Context, client *dns.Client, r *dns.Msg, u
}
if rm == nil || !rm.MsgHdr.Truncated {
setUpstreamProtocol(ctx, "udp")
return rm, t, nil
}
@@ -447,7 +517,17 @@ func ExchangeWithFallback(ctx context.Context, client *dns.Client, r *dns.Msg, u
return nil, t, fmt.Errorf("with tcp: %w", err)
}
// TODO: once TCP is implemented, rm.Truncate() if the request came in over UDP
setUpstreamProtocol(ctx, "tcp")
// Request came in over UDP but response was fetched via TCP.
// Truncate to fit the client's UDP buffer.
maxSize := dns.MinMsgSize
if opt := r.IsEdns0(); opt != nil {
maxSize = int(opt.UDPSize())
}
if rm.Len() > maxSize {
rm.Truncate(maxSize)
}
return rm, t, nil
}
@@ -455,6 +535,16 @@ func ExchangeWithFallback(ctx context.Context, client *dns.Client, r *dns.Msg, u
// ExchangeWithNetstack performs a DNS exchange using netstack for dialing.
// This is needed when netstack is enabled to reach peer IPs through the tunnel.
func ExchangeWithNetstack(ctx context.Context, nsNet *netstack.Net, r *dns.Msg, upstream string) (*dns.Msg, error) {
// If request came in over TCP, go straight to TCP upstream
if DNSProtocolFromContext(ctx) == "tcp" {
rm, err := netstackExchange(ctx, nsNet, r, upstream, "tcp")
if err != nil {
return nil, err
}
setUpstreamProtocol(ctx, "tcp")
return rm, nil
}
reply, err := netstackExchange(ctx, nsNet, r, upstream, "udp")
if err != nil {
return nil, err
@@ -464,9 +554,29 @@ func ExchangeWithNetstack(ctx context.Context, nsNet *netstack.Net, r *dns.Msg,
if reply != nil && reply.MsgHdr.Truncated {
log.Tracef("udp response for domain=%s type=%v class=%v is truncated, trying TCP",
r.Question[0].Name, r.Question[0].Qtype, r.Question[0].Qclass)
return netstackExchange(ctx, nsNet, r, upstream, "tcp")
rm, err := netstackExchange(ctx, nsNet, r, upstream, "tcp")
if err != nil {
return nil, err
}
setUpstreamProtocol(ctx, "tcp")
// Request came in over UDP but response was fetched via TCP.
// Truncate to fit the client's UDP buffer.
maxSize := dns.MinMsgSize
if opt := r.IsEdns0(); opt != nil {
maxSize = int(opt.UDPSize())
}
if rm.Len() > maxSize {
rm.Truncate(maxSize)
}
return rm, nil
}
setUpstreamProtocol(ctx, "udp")
return reply, nil
}

View File

@@ -12,6 +12,7 @@ import (
"github.com/miekg/dns"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
wgdevice "golang.zx2c4.com/wireguard/device"
"golang.zx2c4.com/wireguard/tun/netstack"
"github.com/netbirdio/netbird/client/iface/device"
@@ -127,6 +128,7 @@ func (m *mockNetstackProvider) ToInterface() *net.Interface { return nil }
func (m *mockNetstackProvider) IsUserspaceBind() bool { return false }
func (m *mockNetstackProvider) GetFilter() device.PacketFilter { return nil }
func (m *mockNetstackProvider) GetDevice() *device.FilteredDevice { return nil }
func (m *mockNetstackProvider) GetWGDevice() *wgdevice.Device { return nil }
func (m *mockNetstackProvider) GetNet() *netstack.Net { return nil }
func (m *mockNetstackProvider) GetInterfaceGUIDString() (string, error) {
return "", nil
@@ -475,3 +477,180 @@ func TestFormatFailures(t *testing.T) {
})
}
}
func TestDNSProtocolContext(t *testing.T) {
t.Run("roundtrip udp", func(t *testing.T) {
ctx := ContextWithDNSProtocol(context.Background(), "udp")
assert.Equal(t, "udp", DNSProtocolFromContext(ctx))
})
t.Run("roundtrip tcp", func(t *testing.T) {
ctx := ContextWithDNSProtocol(context.Background(), "tcp")
assert.Equal(t, "tcp", DNSProtocolFromContext(ctx))
})
t.Run("missing returns empty", func(t *testing.T) {
assert.Equal(t, "", DNSProtocolFromContext(context.Background()))
})
}
func TestExchangeWithFallback_TCPContext(t *testing.T) {
// Start a local DNS server that responds on TCP only
tcpHandler := dns.HandlerFunc(func(w dns.ResponseWriter, r *dns.Msg) {
m := new(dns.Msg)
m.SetReply(r)
m.Answer = append(m.Answer, &dns.A{
Hdr: dns.RR_Header{Name: r.Question[0].Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
A: net.ParseIP("10.0.0.1"),
})
if err := w.WriteMsg(m); err != nil {
t.Logf("write msg: %v", err)
}
})
tcpServer := &dns.Server{
Addr: "127.0.0.1:0",
Net: "tcp",
Handler: tcpHandler,
}
tcpLn, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
tcpServer.Listener = tcpLn
go func() {
if err := tcpServer.ActivateAndServe(); err != nil {
t.Logf("tcp server: %v", err)
}
}()
defer func() {
_ = tcpServer.Shutdown()
}()
upstream := tcpLn.Addr().String()
// With TCP context, should connect directly via TCP without trying UDP
ctx := ContextWithDNSProtocol(context.Background(), "tcp")
client := &dns.Client{Timeout: 2 * time.Second}
r := new(dns.Msg).SetQuestion("example.com.", dns.TypeA)
rm, _, err := ExchangeWithFallback(ctx, client, r, upstream)
require.NoError(t, err)
require.NotNil(t, rm)
require.NotEmpty(t, rm.Answer)
assert.Contains(t, rm.Answer[0].String(), "10.0.0.1")
}
func TestExchangeWithFallback_UDPFallbackToTCP(t *testing.T) {
// Start a server on both UDP and TCP.
// The handler returns a small response that works on both.
handler := dns.HandlerFunc(func(w dns.ResponseWriter, r *dns.Msg) {
m := new(dns.Msg)
m.SetReply(r)
m.Answer = append(m.Answer, &dns.A{
Hdr: dns.RR_Header{Name: r.Question[0].Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
A: net.ParseIP("10.0.0.3"),
})
if err := w.WriteMsg(m); err != nil {
t.Logf("write msg: %v", err)
}
})
udpPC, err := net.ListenPacket("udp", "127.0.0.1:0")
require.NoError(t, err)
addr := udpPC.LocalAddr().String()
udpServer := &dns.Server{
PacketConn: udpPC,
Net: "udp",
Handler: handler,
}
tcpLn, err := net.Listen("tcp", addr)
require.NoError(t, err)
tcpServer := &dns.Server{
Listener: tcpLn,
Net: "tcp",
Handler: handler,
}
go func() {
if err := udpServer.ActivateAndServe(); err != nil {
t.Logf("udp server: %v", err)
}
}()
go func() {
if err := tcpServer.ActivateAndServe(); err != nil {
t.Logf("tcp server: %v", err)
}
}()
defer func() {
_ = udpServer.Shutdown()
_ = tcpServer.Shutdown()
}()
// Normal UDP exchange without TCP context should succeed via UDP
ctx := context.Background()
client := &dns.Client{Timeout: 2 * time.Second}
r := new(dns.Msg).SetQuestion("example.com.", dns.TypeA)
rm, _, err := ExchangeWithFallback(ctx, client, r, addr)
require.NoError(t, err)
require.NotNil(t, rm)
require.NotEmpty(t, rm.Answer)
assert.Contains(t, rm.Answer[0].String(), "10.0.0.3")
assert.False(t, rm.Truncated, "small response should not be truncated")
}
func TestExchangeWithFallback_TCPContextSkipsUDP(t *testing.T) {
// Start only a TCP server (no UDP). With TCP context it should succeed.
tcpHandler := dns.HandlerFunc(func(w dns.ResponseWriter, r *dns.Msg) {
m := new(dns.Msg)
m.SetReply(r)
m.Answer = append(m.Answer, &dns.A{
Hdr: dns.RR_Header{Name: r.Question[0].Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
A: net.ParseIP("10.0.0.2"),
})
if err := w.WriteMsg(m); err != nil {
t.Logf("write msg: %v", err)
}
})
tcpLn, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
tcpServer := &dns.Server{
Listener: tcpLn,
Net: "tcp",
Handler: tcpHandler,
}
go func() {
if err := tcpServer.ActivateAndServe(); err != nil {
t.Logf("tcp server: %v", err)
}
}()
defer func() {
_ = tcpServer.Shutdown()
}()
upstream := tcpLn.Addr().String()
// TCP context: should skip UDP entirely and go directly to TCP
ctx := ContextWithDNSProtocol(context.Background(), "tcp")
client := &dns.Client{Timeout: 2 * time.Second}
r := new(dns.Msg).SetQuestion("example.com.", dns.TypeA)
rm, _, err := ExchangeWithFallback(ctx, client, r, upstream)
require.NoError(t, err)
require.NotNil(t, rm)
require.NotEmpty(t, rm.Answer)
assert.Contains(t, rm.Answer[0].String(), "10.0.0.2")
// Without TCP context, trying to reach a TCP-only server via UDP should fail
ctx2 := context.Background()
client2 := &dns.Client{Timeout: 500 * time.Millisecond}
_, _, err = ExchangeWithFallback(ctx2, client2, r, upstream)
assert.Error(t, err, "should fail when no UDP server and no TCP context")
}

View File

@@ -5,6 +5,7 @@ package dns
import (
"net"
wgdevice "golang.zx2c4.com/wireguard/device"
"golang.zx2c4.com/wireguard/tun/netstack"
"github.com/netbirdio/netbird/client/iface/device"
@@ -19,5 +20,6 @@ type WGIface interface {
IsUserspaceBind() bool
GetFilter() device.PacketFilter
GetDevice() *device.FilteredDevice
GetWGDevice() *wgdevice.Device
GetNet() *netstack.Net
}

View File

@@ -1,6 +1,7 @@
package dns
import (
wgdevice "golang.zx2c4.com/wireguard/device"
"golang.zx2c4.com/wireguard/tun/netstack"
"github.com/netbirdio/netbird/client/iface/device"
@@ -14,6 +15,7 @@ type WGIface interface {
IsUserspaceBind() bool
GetFilter() device.PacketFilter
GetDevice() *device.FilteredDevice
GetWGDevice() *wgdevice.Device
GetNet() *netstack.Net
GetInterfaceGUIDString() (string, error)
}

View File

@@ -1796,6 +1796,7 @@ func (e *Engine) newDnsServer(dnsConfig *nbdns.Config) (dns.Server, error) {
dnsServer, err := dns.NewDefaultServer(e.ctx, dns.DefaultServerConfig{
WgInterface: e.wgInterface,
Firewall: e.firewall,
CustomAddress: e.config.CustomDNSAddress,
StatusRecorder: e.statusRecorder,
StateManager: e.stateManager,

View File

@@ -4,6 +4,7 @@ import (
"net"
"net/netip"
wgdevice "golang.zx2c4.com/wireguard/device"
"golang.zx2c4.com/wireguard/tun/netstack"
"github.com/netbirdio/netbird/client/iface/device"
@@ -20,5 +21,6 @@ type wgIfaceBase interface {
IsUserspaceBind() bool
GetFilter() device.PacketFilter
GetDevice() *device.FilteredDevice
GetWGDevice() *wgdevice.Device
GetNet() *netstack.Net
}