Skip to content

fix: block incoming endpoints from call-me-maybe #55

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Apr 22, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 54 additions & 15 deletions wgengine/magicsock/endpoint.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import (
"net/netip"
"reflect"
"runtime"
"strconv"
"sync"
"sync/atomic"
"time"
Expand Down Expand Up @@ -62,6 +63,7 @@ type endpoint struct {
lastFullPing mono.Time // last time we pinged all disco endpoints
derpAddr netip.AddrPort // fallback/bootstrap path, if non-zero (non-zero for well-behaved clients)

blockEndpoints bool // if true, all new endpoints are discarded
bestAddr addrLatency // best non-DERP path; zero if none
bestAddrAt mono.Time // time best address re-confirmed
trustBestAddrUntil mono.Time // time when bestAddr expires
Expand Down Expand Up @@ -207,6 +209,30 @@ func (de *endpoint) deleteEndpointLocked(why string, ep netip.AddrPort) {
}
}

func (de *endpoint) setBlockEndpoints(blocked bool) {
de.mu.Lock()
defer de.mu.Unlock()
de.debugUpdates.Add(EndpointChange{
When: time.Now(),
What: "setBlockEndpoints-" + strconv.FormatBool(blocked),
})

de.blockEndpoints = blocked
if blocked {
de.endpointState = map[netip.AddrPort]*endpointState{}
if de.bestAddr.AddrPort.IsValid() {
de.debugUpdates.Add(EndpointChange{
When: time.Now(),
What: "setBlockEndpoints-" + strconv.FormatBool(blocked) + "-bestAddr",
From: de.bestAddr,
})
de.bestAddr = addrLatency{}
}
de.c.logf("magicsock: disco: node %s %s now using DERP only (all endpoints deleted)",
de.publicKey.ShortString(), de.discoShort())
}
}

// initFakeUDPAddr populates fakeWGAddr with a globally unique fake UDPAddr.
// The current implementation just uses the pointer value of de jammed into an IPv6
// address, but it could also be, say, a counter.
Expand Down Expand Up @@ -764,22 +790,29 @@ func (de *endpoint) updateFromNode(n *tailcfg.Node, heartbeatDisabled bool) {
}

var newIpps []netip.AddrPort
for i, epStr := range n.Endpoints {
if i > math.MaxInt16 {
// Seems unlikely.
continue
}
ipp, err := netip.ParseAddrPort(epStr)
if err != nil {
de.c.logf("magicsock: bogus netmap endpoint %q", epStr)
continue
}
if st, ok := de.endpointState[ipp]; ok {
st.index = int16(i)
} else {
de.endpointState[ipp] = &endpointState{index: int16(i)}
newIpps = append(newIpps, ipp)
if !de.blockEndpoints {
for i, epStr := range n.Endpoints {
if i > math.MaxInt16 {
// Seems unlikely.
continue
}
ipp, err := netip.ParseAddrPort(epStr)
if err != nil {
de.c.logf("magicsock: bogus netmap endpoint %q", epStr)
continue
}
if st, ok := de.endpointState[ipp]; ok {
st.index = int16(i)
} else {
de.endpointState[ipp] = &endpointState{index: int16(i)}
newIpps = append(newIpps, ipp)
}
}
} else {
de.c.dlogf("[v1] magicsock: disco: updateFromNode: %v received %d endpoints, but endpoints blocked",
de.publicKey.ShortString(),
len(n.Endpoints),
)
}
if len(newIpps) > 0 {
de.debugUpdates.Add(EndpointChange{
Expand Down Expand Up @@ -809,6 +842,12 @@ func (de *endpoint) addCandidateEndpoint(ep netip.AddrPort, forRxPingTxID stun.T
de.mu.Lock()
defer de.mu.Unlock()

isDERP := ep.Addr() == tailcfg.DerpMagicIPAddr
if isDERP && de.blockEndpoints {
de.c.logf("[unexpected] attempted to add candidate endpoint %v to %v (%v) but endpoints blocked", ep, de.discoShort(), de.publicKey.ShortString())
return false
}

if st, ok := de.endpointState[ep]; ok {
duplicatePing = forRxPingTxID == st.lastGotPingTxID
if !duplicatePing {
Expand Down
43 changes: 35 additions & 8 deletions wgengine/magicsock/magicsock.go
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,8 @@ type Conn struct {
// blockEndpoints is whether to avoid capturing, storing and sending
// endpoints gathered from local interfaces or STUN. Only DERP endpoints
// will be sent.
// This will also block incoming endpoints received via call-me-maybe disco
// packets.
blockEndpoints bool
// endpointsUpdateActive indicates that updateEndpoints is
// currently running. It's used to deduplicate concurrent endpoint
Expand Down Expand Up @@ -855,10 +857,10 @@ func (c *Conn) DiscoPublicKey() key.DiscoPublic {
return c.discoPublic
}

// SetBlockEndpoints sets the blockEndpoints field. If changed, endpoints will
// be updated to apply the new settings. Existing connections may continue to
// use the old setting until they are reestablished. Disabling endpoints does
// not affect the UDP socket or portmapper.
// SetBlockEndpoints sets the blockEndpoints field. If enabled, all peer
// endpoints will be cleared from the peer map and every connection will
// immediately switch to DERP. Disabling endpoints does not affect the UDP
// socket or portmapper.
func (c *Conn) SetBlockEndpoints(block bool) {
c.mu.Lock()
defer c.mu.Unlock()
Expand All @@ -868,6 +870,7 @@ func (c *Conn) SetBlockEndpoints(block bool) {
return
}

// Re-gather local endpoints.
const why = "SetBlockEndpoints"
if c.endpointsUpdateActive {
if c.wantEndpointsUpdate != why {
Expand All @@ -878,6 +881,11 @@ func (c *Conn) SetBlockEndpoints(block bool) {
c.endpointsUpdateActive = true
go c.updateEndpoints(why)
}

// Update all endpoints to abide by the new setting.
c.peerMap.forEachEndpoint(func(ep *endpoint) {
ep.setBlockEndpoints(block)
})
}

// determineEndpoints returns the machine's endpoint addresses. It
Expand Down Expand Up @@ -1435,6 +1443,12 @@ func (c *Conn) handleDiscoMessage(msg []byte, src netip.AddrPort, derpNodeSrc ke
return
}

isDERP := src.Addr() == tailcfg.DerpMagicIPAddr
if !isDERP && c.blockEndpoints {
// Ignore disco messages over UDP if endpoints are blocked.
return
}

if !c.peerMap.anyEndpointForDiscoKey(sender) {
metricRecvDiscoBadPeer.Add(1)
if debugDisco() {
Expand Down Expand Up @@ -1490,7 +1504,6 @@ func (c *Conn) handleDiscoMessage(msg []byte, src netip.AddrPort, derpNodeSrc ke
return
}

isDERP := src.Addr() == tailcfg.DerpMagicIPAddr
if isDERP {
metricRecvDiscoDERP.Add(1)
} else {
Expand Down Expand Up @@ -1535,7 +1548,15 @@ func (c *Conn) handleDiscoMessage(msg []byte, src netip.AddrPort, derpNodeSrc ke
c.logf("[unexpected] CallMeMaybe from peer via DERP whose netmap discokey != disco source")
return
}
c.dlogf("[v1] magicsock: disco: %v<-%v (%v, %v) got call-me-maybe, %d endpoints",
if c.blockEndpoints {
c.dlogf("[v1] magicsock: disco: %v<-%v (%v, %v) got call-me-maybe with %d endpoints, but endpoints blocked",
c.discoShort, epDisco.short,
ep.publicKey.ShortString(), derpStr(src.String()),
len(dm.MyNumber),
)
return
}
c.dlogf("[v1] magicsock: disco: %v<-%v (%v, %v) got call-me-maybe, %d endpoints",
c.discoShort, epDisco.short,
ep.publicKey.ShortString(), derpStr(src.String()),
len(dm.MyNumber))
Expand Down Expand Up @@ -1963,13 +1984,19 @@ func (c *Conn) SetNetworkMap(nm *netmap.NetworkMap) {
debugUpdates: ringbuffer.New[EndpointChange](entriesPerBuffer),
publicKey: n.Key,
publicKeyHex: n.Key.UntypedHexString(),
blockEndpoints: c.blockEndpoints,
sentPing: map[stun.TxID]sentPing{},
endpointState: map[netip.AddrPort]*endpointState{},
heartbeatDisabled: heartbeatDisabled,
isWireguardOnly: n.IsWireGuardOnly,
}
if len(n.Addresses) > 0 {
ep.nodeAddr = n.Addresses[0].Addr()
for _, addr := range n.Addresses {
// Only set nodeAddr if it's a DERP address while endpoints are
// blocked.
if !c.blockEndpoints || addr.Addr() == tailcfg.DerpMagicIPAddr {
ep.nodeAddr = addr.Addr()
break
}
}
ep.initFakeUDPAddr()
if n.DiscoKey.IsZero() {
Expand Down
80 changes: 61 additions & 19 deletions wgengine/magicsock/magicsock_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3060,30 +3060,27 @@ func TestBlockEndpointsDERPOK(t *testing.T) {
logf, closeLogf := logger.LogfCloser(t.Logf)
defer closeLogf()

derpMap, cleanup := runDERPAndStun(t, t.Logf, localhostListener{}, netaddr.IPv4(127, 0, 0, 1))
defer cleanup()
derpMap, cleanupDerp := runDERPAndStun(t, t.Logf, localhostListener{}, netaddr.IPv4(127, 0, 0, 1))
defer cleanupDerp()

ms1 := newMagicStack(t, logger.WithPrefix(logf, "conn1: "), d.m1, derpMap)
defer ms1.Close()
ms1.conn.SetDebugLoggingEnabled(true)
ms2 := newMagicStack(t, logger.WithPrefix(logf, "conn2: "), d.m2, derpMap)
defer ms2.Close()
ms2.conn.SetDebugLoggingEnabled(true)

cleanup = meshStacks(logf, nil, ms1, ms2)
defer cleanup()
cleanupMesh := meshStacks(logf, nil, ms1, ms2)
defer cleanupMesh()

m1IP := ms1.IP()
m2IP := ms2.IP()
logf("IPs: %s %s", m1IP, m2IP)

// SetBlockEndpoints is called later since it's incompatible with the test
// meshStacks implementations.
ms1.conn.SetBlockEndpoints(true)
ms2.conn.SetBlockEndpoints(true)
waitForNoEndpoints(t, ms1.conn)
waitForNoEndpoints(t, ms2.conn)

cleanup = newPinger(t, logf, ms1, ms2)
defer cleanup()
cleanupPinger1 := newPinger(t, logf, ms1, ms2)
defer cleanupPinger1()
cleanupPinger2 := newPinger(t, logf, ms2, ms1)
defer cleanupPinger2()

// Wait for both peers to know about each other.
for {
Expand All @@ -3098,14 +3095,42 @@ func TestBlockEndpointsDERPOK(t *testing.T) {
break
}

cleanup = newPinger(t, t.Logf, ms1, ms2)
defer cleanup()
waitForEndpoints(t, ms1.conn)
waitForEndpoints(t, ms2.conn)

if len(ms1.conn.activeDerp) == 0 {
t.Errorf("unexpected DERP empty got: %v want: >0", len(ms1.conn.activeDerp))
// SetBlockEndpoints is called later since it's incompatible with the test
// meshStacks implementations.
// We only set it on ms1, since ms2's endpoints should be ignored by ms1.
ms1.conn.SetBlockEndpoints(true)

// All endpoints should've been immediately removed from ms1.
ep2, ok := ms1.conn.peerMap.endpointForNodeKey(ms2.Public())
if !ok {
t.Fatalf("endpoint not found for ms2 in ms1")
}
if len(ms2.conn.activeDerp) == 0 {
t.Errorf("unexpected DERP empty got: %v want: >0", len(ms2.conn.activeDerp))
ep2.mu.Lock()
if !ep2.blockEndpoints {
t.Fatalf("endpoints not blocked on ep2 in ms1")
}
if len(ep2.endpointState) != 0 {
ep2.mu.Unlock()
t.Fatalf("endpoints not removed on ep2 in ms1")
}
ep2.mu.Unlock()

// Wait for endpoints to finish updating.
waitForNoEndpoints(t, ms1.conn)

// Give time for another call-me-maybe packet to arrive. I couldn't think of
// a better way than sleeping without making a bunch of changes.
t.Logf("sleeping for call-me-maybe packet to be received and ignored")
time.Sleep(time.Second)
t.Logf("done sleeping")

ep2.mu.Lock()
defer ep2.mu.Unlock()
for i := range ep2.endpointState {
t.Fatalf("endpoint %q not missing", i.String())
}
}

Expand All @@ -3129,3 +3154,20 @@ func waitForNoEndpoints(t *testing.T, ms *Conn) {
}
t.Log("endpoints are blocked")
}

func waitForEndpoints(t *testing.T, ms *Conn) {
t.Helper()
for i := 0; i < 50; i++ {
time.Sleep(100 * time.Millisecond)
ms.mu.Lock()
for _, ep := range ms.lastEndpoints {
if ep.Addr.Addr() != tailcfg.DerpMagicIPAddr {
t.Log("endpoint found")
ms.mu.Unlock()
return
}
}
ms.mu.Unlock()
}
t.Fatal("endpoint was not found after 50 attempts")
}
Loading