diff --git a/wgengine/magicsock/endpoint.go b/wgengine/magicsock/endpoint.go index 28983ce9e95c0..616ac16bc870d 100644 --- a/wgengine/magicsock/endpoint.go +++ b/wgengine/magicsock/endpoint.go @@ -15,6 +15,7 @@ import ( "net/netip" "reflect" "runtime" + "strconv" "sync" "sync/atomic" "time" @@ -62,6 +63,7 @@ type endpoint struct { lastFullPing mono.Time // last time we pinged all disco endpoints derpAddr netip.AddrPort // fallback/bootstrap path, if non-zero (non-zero for well-behaved clients) + blockEndpoints bool // if true, all new endpoints are discarded bestAddr addrLatency // best non-DERP path; zero if none bestAddrAt mono.Time // time best address re-confirmed trustBestAddrUntil mono.Time // time when bestAddr expires @@ -207,6 +209,30 @@ func (de *endpoint) deleteEndpointLocked(why string, ep netip.AddrPort) { } } +func (de *endpoint) setBlockEndpoints(blocked bool) { + de.mu.Lock() + defer de.mu.Unlock() + de.debugUpdates.Add(EndpointChange{ + When: time.Now(), + What: "setBlockEndpoints-" + strconv.FormatBool(blocked), + }) + + de.blockEndpoints = blocked + if blocked { + de.endpointState = map[netip.AddrPort]*endpointState{} + if de.bestAddr.AddrPort.IsValid() { + de.debugUpdates.Add(EndpointChange{ + When: time.Now(), + What: "setBlockEndpoints-" + strconv.FormatBool(blocked) + "-bestAddr", + From: de.bestAddr, + }) + de.bestAddr = addrLatency{} + } + de.c.logf("magicsock: disco: node %s %s now using DERP only (all endpoints deleted)", + de.publicKey.ShortString(), de.discoShort()) + } +} + // initFakeUDPAddr populates fakeWGAddr with a globally unique fake UDPAddr. // The current implementation just uses the pointer value of de jammed into an IPv6 // address, but it could also be, say, a counter. @@ -764,22 +790,29 @@ func (de *endpoint) updateFromNode(n *tailcfg.Node, heartbeatDisabled bool) { } var newIpps []netip.AddrPort - for i, epStr := range n.Endpoints { - if i > math.MaxInt16 { - // Seems unlikely. - continue - } - ipp, err := netip.ParseAddrPort(epStr) - if err != nil { - de.c.logf("magicsock: bogus netmap endpoint %q", epStr) - continue - } - if st, ok := de.endpointState[ipp]; ok { - st.index = int16(i) - } else { - de.endpointState[ipp] = &endpointState{index: int16(i)} - newIpps = append(newIpps, ipp) + if !de.blockEndpoints { + for i, epStr := range n.Endpoints { + if i > math.MaxInt16 { + // Seems unlikely. + continue + } + ipp, err := netip.ParseAddrPort(epStr) + if err != nil { + de.c.logf("magicsock: bogus netmap endpoint %q", epStr) + continue + } + if st, ok := de.endpointState[ipp]; ok { + st.index = int16(i) + } else { + de.endpointState[ipp] = &endpointState{index: int16(i)} + newIpps = append(newIpps, ipp) + } } + } else { + de.c.dlogf("[v1] magicsock: disco: updateFromNode: %v received %d endpoints, but endpoints blocked", + de.publicKey.ShortString(), + len(n.Endpoints), + ) } if len(newIpps) > 0 { de.debugUpdates.Add(EndpointChange{ @@ -809,6 +842,12 @@ func (de *endpoint) addCandidateEndpoint(ep netip.AddrPort, forRxPingTxID stun.T de.mu.Lock() defer de.mu.Unlock() + isDERP := ep.Addr() == tailcfg.DerpMagicIPAddr + if isDERP && de.blockEndpoints { + de.c.logf("[unexpected] attempted to add candidate endpoint %v to %v (%v) but endpoints blocked", ep, de.discoShort(), de.publicKey.ShortString()) + return false + } + if st, ok := de.endpointState[ep]; ok { duplicatePing = forRxPingTxID == st.lastGotPingTxID if !duplicatePing { diff --git a/wgengine/magicsock/magicsock.go b/wgengine/magicsock/magicsock.go index b27736822af30..3c77a11353012 100644 --- a/wgengine/magicsock/magicsock.go +++ b/wgengine/magicsock/magicsock.go @@ -217,6 +217,8 @@ type Conn struct { // blockEndpoints is whether to avoid capturing, storing and sending // endpoints gathered from local interfaces or STUN. Only DERP endpoints // will be sent. + // This will also block incoming endpoints received via call-me-maybe disco + // packets. blockEndpoints bool // endpointsUpdateActive indicates that updateEndpoints is // currently running. It's used to deduplicate concurrent endpoint @@ -855,10 +857,10 @@ func (c *Conn) DiscoPublicKey() key.DiscoPublic { return c.discoPublic } -// SetBlockEndpoints sets the blockEndpoints field. If changed, endpoints will -// be updated to apply the new settings. Existing connections may continue to -// use the old setting until they are reestablished. Disabling endpoints does -// not affect the UDP socket or portmapper. +// SetBlockEndpoints sets the blockEndpoints field. If enabled, all peer +// endpoints will be cleared from the peer map and every connection will +// immediately switch to DERP. Disabling endpoints does not affect the UDP +// socket or portmapper. func (c *Conn) SetBlockEndpoints(block bool) { c.mu.Lock() defer c.mu.Unlock() @@ -868,6 +870,7 @@ func (c *Conn) SetBlockEndpoints(block bool) { return } + // Re-gather local endpoints. const why = "SetBlockEndpoints" if c.endpointsUpdateActive { if c.wantEndpointsUpdate != why { @@ -878,6 +881,11 @@ func (c *Conn) SetBlockEndpoints(block bool) { c.endpointsUpdateActive = true go c.updateEndpoints(why) } + + // Update all endpoints to abide by the new setting. + c.peerMap.forEachEndpoint(func(ep *endpoint) { + ep.setBlockEndpoints(block) + }) } // determineEndpoints returns the machine's endpoint addresses. It @@ -1435,6 +1443,12 @@ func (c *Conn) handleDiscoMessage(msg []byte, src netip.AddrPort, derpNodeSrc ke return } + isDERP := src.Addr() == tailcfg.DerpMagicIPAddr + if !isDERP && c.blockEndpoints { + // Ignore disco messages over UDP if endpoints are blocked. + return + } + if !c.peerMap.anyEndpointForDiscoKey(sender) { metricRecvDiscoBadPeer.Add(1) if debugDisco() { @@ -1490,7 +1504,6 @@ func (c *Conn) handleDiscoMessage(msg []byte, src netip.AddrPort, derpNodeSrc ke return } - isDERP := src.Addr() == tailcfg.DerpMagicIPAddr if isDERP { metricRecvDiscoDERP.Add(1) } else { @@ -1535,7 +1548,15 @@ func (c *Conn) handleDiscoMessage(msg []byte, src netip.AddrPort, derpNodeSrc ke c.logf("[unexpected] CallMeMaybe from peer via DERP whose netmap discokey != disco source") return } - c.dlogf("[v1] magicsock: disco: %v<-%v (%v, %v) got call-me-maybe, %d endpoints", + if c.blockEndpoints { + c.dlogf("[v1] magicsock: disco: %v<-%v (%v, %v) got call-me-maybe with %d endpoints, but endpoints blocked", + c.discoShort, epDisco.short, + ep.publicKey.ShortString(), derpStr(src.String()), + len(dm.MyNumber), + ) + return + } + c.dlogf("[v1] magicsock: disco: %v<-%v (%v, %v) got call-me-maybe, %d endpoints", c.discoShort, epDisco.short, ep.publicKey.ShortString(), derpStr(src.String()), len(dm.MyNumber)) @@ -1963,13 +1984,19 @@ func (c *Conn) SetNetworkMap(nm *netmap.NetworkMap) { debugUpdates: ringbuffer.New[EndpointChange](entriesPerBuffer), publicKey: n.Key, publicKeyHex: n.Key.UntypedHexString(), + blockEndpoints: c.blockEndpoints, sentPing: map[stun.TxID]sentPing{}, endpointState: map[netip.AddrPort]*endpointState{}, heartbeatDisabled: heartbeatDisabled, isWireguardOnly: n.IsWireGuardOnly, } - if len(n.Addresses) > 0 { - ep.nodeAddr = n.Addresses[0].Addr() + for _, addr := range n.Addresses { + // Only set nodeAddr if it's a DERP address while endpoints are + // blocked. + if !c.blockEndpoints || addr.Addr() == tailcfg.DerpMagicIPAddr { + ep.nodeAddr = addr.Addr() + break + } } ep.initFakeUDPAddr() if n.DiscoKey.IsZero() { diff --git a/wgengine/magicsock/magicsock_test.go b/wgengine/magicsock/magicsock_test.go index ee5a390f47110..02fc474125a9f 100644 --- a/wgengine/magicsock/magicsock_test.go +++ b/wgengine/magicsock/magicsock_test.go @@ -3060,30 +3060,27 @@ func TestBlockEndpointsDERPOK(t *testing.T) { logf, closeLogf := logger.LogfCloser(t.Logf) defer closeLogf() - derpMap, cleanup := runDERPAndStun(t, t.Logf, localhostListener{}, netaddr.IPv4(127, 0, 0, 1)) - defer cleanup() + derpMap, cleanupDerp := runDERPAndStun(t, t.Logf, localhostListener{}, netaddr.IPv4(127, 0, 0, 1)) + defer cleanupDerp() ms1 := newMagicStack(t, logger.WithPrefix(logf, "conn1: "), d.m1, derpMap) defer ms1.Close() + ms1.conn.SetDebugLoggingEnabled(true) ms2 := newMagicStack(t, logger.WithPrefix(logf, "conn2: "), d.m2, derpMap) defer ms2.Close() + ms2.conn.SetDebugLoggingEnabled(true) - cleanup = meshStacks(logf, nil, ms1, ms2) - defer cleanup() + cleanupMesh := meshStacks(logf, nil, ms1, ms2) + defer cleanupMesh() m1IP := ms1.IP() m2IP := ms2.IP() logf("IPs: %s %s", m1IP, m2IP) - // SetBlockEndpoints is called later since it's incompatible with the test - // meshStacks implementations. - ms1.conn.SetBlockEndpoints(true) - ms2.conn.SetBlockEndpoints(true) - waitForNoEndpoints(t, ms1.conn) - waitForNoEndpoints(t, ms2.conn) - - cleanup = newPinger(t, logf, ms1, ms2) - defer cleanup() + cleanupPinger1 := newPinger(t, logf, ms1, ms2) + defer cleanupPinger1() + cleanupPinger2 := newPinger(t, logf, ms2, ms1) + defer cleanupPinger2() // Wait for both peers to know about each other. for { @@ -3098,14 +3095,42 @@ func TestBlockEndpointsDERPOK(t *testing.T) { break } - cleanup = newPinger(t, t.Logf, ms1, ms2) - defer cleanup() + waitForEndpoints(t, ms1.conn) + waitForEndpoints(t, ms2.conn) - if len(ms1.conn.activeDerp) == 0 { - t.Errorf("unexpected DERP empty got: %v want: >0", len(ms1.conn.activeDerp)) + // SetBlockEndpoints is called later since it's incompatible with the test + // meshStacks implementations. + // We only set it on ms1, since ms2's endpoints should be ignored by ms1. + ms1.conn.SetBlockEndpoints(true) + + // All endpoints should've been immediately removed from ms1. + ep2, ok := ms1.conn.peerMap.endpointForNodeKey(ms2.Public()) + if !ok { + t.Fatalf("endpoint not found for ms2 in ms1") } - if len(ms2.conn.activeDerp) == 0 { - t.Errorf("unexpected DERP empty got: %v want: >0", len(ms2.conn.activeDerp)) + ep2.mu.Lock() + if !ep2.blockEndpoints { + t.Fatalf("endpoints not blocked on ep2 in ms1") + } + if len(ep2.endpointState) != 0 { + ep2.mu.Unlock() + t.Fatalf("endpoints not removed on ep2 in ms1") + } + ep2.mu.Unlock() + + // Wait for endpoints to finish updating. + waitForNoEndpoints(t, ms1.conn) + + // Give time for another call-me-maybe packet to arrive. I couldn't think of + // a better way than sleeping without making a bunch of changes. + t.Logf("sleeping for call-me-maybe packet to be received and ignored") + time.Sleep(time.Second) + t.Logf("done sleeping") + + ep2.mu.Lock() + defer ep2.mu.Unlock() + for i := range ep2.endpointState { + t.Fatalf("endpoint %q not missing", i.String()) } } @@ -3129,3 +3154,20 @@ func waitForNoEndpoints(t *testing.T, ms *Conn) { } t.Log("endpoints are blocked") } + +func waitForEndpoints(t *testing.T, ms *Conn) { + t.Helper() + for i := 0; i < 50; i++ { + time.Sleep(100 * time.Millisecond) + ms.mu.Lock() + for _, ep := range ms.lastEndpoints { + if ep.Addr.Addr() != tailcfg.DerpMagicIPAddr { + t.Log("endpoint found") + ms.mu.Unlock() + return + } + } + ms.mu.Unlock() + } + t.Fatal("endpoint was not found after 50 attempts") +}