From 213d90d2c71c6394fb5dc4ad74e4f2cf2d9b3818 Mon Sep 17 00:00:00 2001 From: Shaun Crampton Date: Tue, 9 Apr 2024 10:04:24 +0100 Subject: [PATCH] Add RouteListFilteredIter API. Allows for listing large numbers of routes without buffering the whole list in memory at once. Add benchmarks for RouteListFiltered variants. --- netlink_test.go | 4 +- nl/nl_linux.go | 58 +++++++++++++------ route_linux.go | 114 ++++++++++++++++++++++++------------- route_test.go | 145 ++++++++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 263 insertions(+), 58 deletions(-) diff --git a/netlink_test.go b/netlink_test.go index 0ec5256c..2224ed98 100644 --- a/netlink_test.go +++ b/netlink_test.go @@ -23,7 +23,7 @@ import ( type tearDownNetlinkTest func() -func skipUnlessRoot(t *testing.T) { +func skipUnlessRoot(t testing.TB) { t.Helper() if os.Getuid() != 0 { @@ -53,7 +53,7 @@ func skipUnlessKModuleLoaded(t *testing.T, module ...string) { } } -func setUpNetlinkTest(t *testing.T) tearDownNetlinkTest { +func setUpNetlinkTest(t testing.TB) tearDownNetlinkTest { skipUnlessRoot(t) // new temporary namespace so we don't pollute the host diff --git a/nl/nl_linux.go b/nl/nl_linux.go index c0cef5a1..42d5e6f6 100644 --- a/nl/nl_linux.go +++ b/nl/nl_linux.go @@ -488,10 +488,30 @@ func (req *NetlinkRequest) AddRawData(data []byte) { req.RawData = append(req.RawData, data...) } -// Execute the request against a the given sockType. +// Execute the request against the given sockType. // Returns a list of netlink messages in serialized format, optionally filtered // by resType. func (req *NetlinkRequest) Execute(sockType int, resType uint16) ([][]byte, error) { + var res [][]byte + err := req.ExecuteIter(sockType, resType, func(msg []byte) bool { + res = append(res, msg) + return true + }) + if err != nil { + return nil, err + } + return res, nil +} + +// ExecuteIter executes the request against the given sockType. +// Calls the provided callback func once for each netlink message. 
+// If the callback returns false, it is not called again, but +// the remaining messages are consumed/discarded. +// +// Thread safety: ExecuteIter holds a lock on the socket until +// it finishes iteration so the callback must not call back into +// the netlink API. +func (req *NetlinkRequest) ExecuteIter(sockType int, resType uint16, f func(msg []byte) bool) error { var ( s *NetlinkSocket err error @@ -508,18 +528,18 @@ func (req *NetlinkRequest) Execute(sockType int, resType uint16) ([][]byte, erro if s == nil { s, err = getNetlinkSocket(sockType) if err != nil { - return nil, err + return err } if err := s.SetSendTimeout(&SocketTimeoutTv); err != nil { - return nil, err + return err } if err := s.SetReceiveTimeout(&SocketTimeoutTv); err != nil { - return nil, err + return err } if EnableErrorMessageReporting { if err := s.SetExtAck(true); err != nil { - return nil, err + return err } } @@ -530,38 +550,36 @@ func (req *NetlinkRequest) Execute(sockType int, resType uint16) ([][]byte, erro } if err := s.Send(req); err != nil { - return nil, err + return err } pid, err := s.GetPid() if err != nil { - return nil, err + return err } - var res [][]byte - done: for { msgs, from, err := s.Receive() if err != nil { - return nil, err + return err } if from.Pid != PidKernel { - return nil, fmt.Errorf("Wrong sender portid %d, expected %d", from.Pid, PidKernel) + return fmt.Errorf("Wrong sender portid %d, expected %d", from.Pid, PidKernel) } for _, m := range msgs { if m.Header.Seq != req.Seq { if sharedSocket { continue } - return nil, fmt.Errorf("Wrong Seq nr %d, expected %d", m.Header.Seq, req.Seq) + return fmt.Errorf("Wrong Seq nr %d, expected %d", m.Header.Seq, req.Seq) } if m.Header.Pid != pid { continue } if m.Header.Flags&unix.NLM_F_DUMP_INTR != 0 { - return nil, syscall.Errno(unix.EINTR) + return syscall.Errno(unix.EINTR) } if m.Header.Type == unix.NLMSG_DONE || m.Header.Type == unix.NLMSG_ERROR { @@ -600,18 +618,26 @@ done: } } - return nil, err + return err } if 
resType != 0 && m.Header.Type != resType { continue } - res = append(res, m.Data) + if cont := f(m.Data); !cont { + // Drain the rest of the messages from the kernel but don't + // pass them to the iterator func. + f = dummyMsgIterFunc + } if m.Header.Flags&unix.NLM_F_MULTI == 0 { break done } } } - return res, nil + return nil +} + +func dummyMsgIterFunc(msg []byte) bool { + return true } // Create a new netlink request from proto and flags diff --git a/route_linux.go b/route_linux.go index 929738e1..59a6858c 100644 --- a/route_linux.go +++ b/route_linux.go @@ -400,7 +400,7 @@ func (e *SEG6LocalEncap) String() string { } if e.Flags[nl.SEG6_LOCAL_SRH] { segs := make([]string, 0, len(e.Segments)) - //append segment backwards (from n to 0) since seg#0 is the last segment. + // append segment backwards (from n to 0) since seg#0 is the last segment. for i := len(e.Segments); i > 0; i-- { segs = append(segs, e.Segments[i-1].String()) } @@ -835,8 +835,22 @@ func (h *Handle) RouteDel(route *Route) error { } func (h *Handle) routeHandle(route *Route, req *nl.NetlinkRequest, msg *nl.RtMsg) ([][]byte, error) { + if err := h.prepareRouteReq(route, req, msg); err != nil { + return nil, err + } + return req.Execute(unix.NETLINK_ROUTE, 0) +} + +func (h *Handle) routeHandleIter(route *Route, req *nl.NetlinkRequest, msg *nl.RtMsg, f func(msg []byte) bool) error { + if err := h.prepareRouteReq(route, req, msg); err != nil { + return err + } + return req.ExecuteIter(unix.NETLINK_ROUTE, 0, f) +} + +func (h *Handle) prepareRouteReq(route *Route, req *nl.NetlinkRequest, msg *nl.RtMsg) error { if req.NlMsghdr.Type != unix.RTM_GETROUTE && (route.Dst == nil || route.Dst.IP == nil) && route.Src == nil && route.Gw == nil && route.MPLSDst == nil { - return nil, fmt.Errorf("Either Dst.IP, Src.IP or Gw must be set") + return fmt.Errorf("either Dst.IP, Src.IP or Gw must be set") } family := -1 @@ -863,11 +877,11 @@ func (h *Handle) routeHandle(route *Route, req *nl.NetlinkRequest, msg *nl.RtMsg 
if route.NewDst != nil { if family != -1 && family != route.NewDst.Family() { - return nil, fmt.Errorf("new destination and destination are not the same address family") + return fmt.Errorf("new destination and destination are not the same address family") } buf, err := route.NewDst.Encode() if err != nil { - return nil, err + return err } rtAttrs = append(rtAttrs, nl.NewRtAttr(unix.RTA_NEWDST, buf)) } @@ -878,7 +892,7 @@ func (h *Handle) routeHandle(route *Route, req *nl.NetlinkRequest, msg *nl.RtMsg rtAttrs = append(rtAttrs, nl.NewRtAttr(unix.RTA_ENCAP_TYPE, buf)) buf, err := route.Encap.Encode() if err != nil { - return nil, err + return err } switch route.Encap.Type() { case nl.LWTUNNEL_ENCAP_BPF: @@ -892,7 +906,7 @@ func (h *Handle) routeHandle(route *Route, req *nl.NetlinkRequest, msg *nl.RtMsg if route.Src != nil { srcFamily := nl.GetIPFamily(route.Src) if family != -1 && family != srcFamily { - return nil, fmt.Errorf("source and destination ip are not the same IP family") + return fmt.Errorf("source and destination ip are not the same IP family") } family = srcFamily var srcData []byte @@ -908,7 +922,7 @@ func (h *Handle) routeHandle(route *Route, req *nl.NetlinkRequest, msg *nl.RtMsg if route.Gw != nil { gwFamily := nl.GetIPFamily(route.Gw) if family != -1 && family != gwFamily { - return nil, fmt.Errorf("gateway, source, and destination ip are not the same IP family") + return fmt.Errorf("gateway, source, and destination ip are not the same IP family") } family = gwFamily var gwData []byte @@ -923,7 +937,7 @@ func (h *Handle) routeHandle(route *Route, req *nl.NetlinkRequest, msg *nl.RtMsg if route.Via != nil { buf, err := route.Via.Encode() if err != nil { - return nil, fmt.Errorf("failed to encode RTA_VIA: %v", err) + return fmt.Errorf("failed to encode RTA_VIA: %v", err) } rtAttrs = append(rtAttrs, nl.NewRtAttr(unix.RTA_VIA, buf)) } @@ -942,7 +956,7 @@ func (h *Handle) routeHandle(route *Route, req *nl.NetlinkRequest, msg *nl.RtMsg if nh.Gw != nil { 
gwFamily := nl.GetIPFamily(nh.Gw) if family != -1 && family != gwFamily { - return nil, fmt.Errorf("gateway, source, and destination ip are not the same IP family") + return fmt.Errorf("gateway, source, and destination ip are not the same IP family") } if gwFamily == FAMILY_V4 { children = append(children, nl.NewRtAttr(unix.RTA_GATEWAY, []byte(nh.Gw.To4()))) @@ -952,11 +966,11 @@ func (h *Handle) routeHandle(route *Route, req *nl.NetlinkRequest, msg *nl.RtMsg } if nh.NewDst != nil { if family != -1 && family != nh.NewDst.Family() { - return nil, fmt.Errorf("new destination and destination are not the same address family") + return fmt.Errorf("new destination and destination are not the same address family") } buf, err := nh.NewDst.Encode() if err != nil { - return nil, err + return err } children = append(children, nl.NewRtAttr(unix.RTA_NEWDST, buf)) } @@ -966,14 +980,14 @@ func (h *Handle) routeHandle(route *Route, req *nl.NetlinkRequest, msg *nl.RtMsg children = append(children, nl.NewRtAttr(unix.RTA_ENCAP_TYPE, buf)) buf, err := nh.Encap.Encode() if err != nil { - return nil, err + return err } children = append(children, nl.NewRtAttr(unix.RTA_ENCAP, buf)) } if nh.Via != nil { buf, err := nh.Via.Encode() if err != nil { - return nil, err + return err } children = append(children, nl.NewRtAttr(unix.RTA_VIA, buf)) } @@ -1104,8 +1118,7 @@ func (h *Handle) routeHandle(route *Route, req *nl.NetlinkRequest, msg *nl.RtMsg native.PutUint32(b, uint32(route.LinkIndex)) req.AddData(nl.NewRtAttr(unix.RTA_OIF, b)) } - - return req.Execute(unix.NETLINK_ROUTE, 0) + return nil } // RouteList gets a list of routes in the system. @@ -1137,73 +1150,94 @@ func RouteListFiltered(family int, filter *Route, filterMask uint64) ([]Route, e // RouteListFiltered gets a list of routes in the system filtered with specified rules. 
// All rules must be defined in RouteFilter struct func (h *Handle) RouteListFiltered(family int, filter *Route, filterMask uint64) ([]Route, error) { - req := h.newNetlinkRequest(unix.RTM_GETROUTE, unix.NLM_F_DUMP) - rtmsg := &nl.RtMsg{} - rtmsg.Family = uint8(family) - msgs, err := h.routeHandle(filter, req, rtmsg) + var res []Route + err := h.RouteListFilteredIter(family, filter, filterMask, func(route Route) (cont bool) { + res = append(res, route) + return true + }) if err != nil { return nil, err } + return res, nil +} - var res []Route - for _, m := range msgs { +// RouteListFilteredIter passes each route that matches the filter to the given iterator func. Iteration continues +// until all routes are loaded or the func returns false. +func RouteListFilteredIter(family int, filter *Route, filterMask uint64, f func(Route) (cont bool)) error { + return pkgHandle.RouteListFilteredIter(family, filter, filterMask, f) +} + +func (h *Handle) RouteListFilteredIter(family int, filter *Route, filterMask uint64, f func(Route) (cont bool)) error { + req := h.newNetlinkRequest(unix.RTM_GETROUTE, unix.NLM_F_DUMP) + rtmsg := &nl.RtMsg{} + rtmsg.Family = uint8(family) + + var parseErr error + err := h.routeHandleIter(filter, req, rtmsg, func(m []byte) bool { msg := nl.DeserializeRtMsg(m) if family != FAMILY_ALL && msg.Family != uint8(family) { // Ignore routes not matching requested family - continue + return true } if msg.Flags&unix.RTM_F_CLONED != 0 { // Ignore cloned routes - continue + return true } if msg.Table != unix.RT_TABLE_MAIN { if filter == nil || filterMask&RT_FILTER_TABLE == 0 { // Ignore non-main tables - continue + return true } } route, err := deserializeRoute(m) if err != nil { - return nil, err + parseErr = err + return false } if filter != nil { switch { case filterMask&RT_FILTER_TABLE != 0 && filter.Table != unix.RT_TABLE_UNSPEC && route.Table != filter.Table: - continue + return true case filterMask&RT_FILTER_PROTOCOL != 0 && route.Protocol != 
filter.Protocol: - continue + return true case filterMask&RT_FILTER_SCOPE != 0 && route.Scope != filter.Scope: - continue + return true case filterMask&RT_FILTER_TYPE != 0 && route.Type != filter.Type: - continue + return true case filterMask&RT_FILTER_TOS != 0 && route.Tos != filter.Tos: - continue + return true case filterMask&RT_FILTER_REALM != 0 && route.Realm != filter.Realm: - continue + return true case filterMask&RT_FILTER_OIF != 0 && route.LinkIndex != filter.LinkIndex: - continue + return true case filterMask&RT_FILTER_IIF != 0 && route.ILinkIndex != filter.ILinkIndex: - continue + return true case filterMask&RT_FILTER_GW != 0 && !route.Gw.Equal(filter.Gw): - continue + return true case filterMask&RT_FILTER_SRC != 0 && !route.Src.Equal(filter.Src): - continue + return true case filterMask&RT_FILTER_DST != 0: if filter.MPLSDst == nil || route.MPLSDst == nil || (*filter.MPLSDst) != (*route.MPLSDst) { if filter.Dst == nil { filter.Dst = genZeroIPNet(family) } if !ipNetEqual(route.Dst, filter.Dst) { - continue + return true } } case filterMask&RT_FILTER_HOPLIMIT != 0 && route.Hoplimit != filter.Hoplimit: - continue + return true } } - res = append(res, route) + return f(route) + }) + if err != nil { + return err } - return res, nil + if parseErr != nil { + return parseErr + } + return nil } // deserializeRoute decodes a binary netlink message into a Route struct @@ -1723,7 +1757,7 @@ func (p RouteProtocol) String() string { return "gated" case unix.RTPROT_ISIS: return "isis" - //case unix.RTPROT_KEEPALIVED: + // case unix.RTPROT_KEEPALIVED: // return "keepalived" case unix.RTPROT_KERNEL: return "kernel" diff --git a/route_test.go b/route_test.go index c73205f2..069a7fbf 100644 --- a/route_test.go +++ b/route_test.go @@ -949,6 +949,151 @@ func TestRouteFilterByFamily(t *testing.T) { } } +func TestRouteFilterIterCanStop(t *testing.T) { + tearDown := setUpNetlinkTest(t) + defer tearDown() + + // get loopback interface + link, err := LinkByName("lo") + if err != 
nil { + t.Fatal(err) + } + // bring the interface up + if err = LinkSetUp(link); err != nil { + t.Fatal(err) + } + + // destination shared by the three routes added below + dst := &net.IPNet{ + IP: net.IPv4(1, 1, 1, 1), + Mask: net.CIDRMask(32, 32), + } + + for i := 0; i < 3; i++ { + route := Route{ + LinkIndex: link.Attrs().Index, + Dst: dst, + Scope: unix.RT_SCOPE_LINK, + Priority: 1 + i, + Table: 1000, + Type: unix.RTN_UNICAST, + } + if err := RouteAdd(&route); err != nil { + t.Fatal(err) + } + } + + var routes []Route + err = RouteListFilteredIter(FAMILY_V4, &Route{ + Dst: dst, + Scope: unix.RT_SCOPE_LINK, + Table: 1000, + Type: unix.RTN_UNICAST, + }, RT_FILTER_TABLE, func(route Route) (cont bool) { + routes = append(routes, route) + return len(routes) < 2 + }) + if err != nil { + t.Fatal(err) + } + if len(routes) != 2 { + t.Fatal("Unexpected number of iterations") + } + for _, route := range routes { + if route.Scope != unix.RT_SCOPE_LINK { + t.Fatal("Invalid Scope. Route not added properly") + } + if route.Priority < 1 || route.Priority > 3 { + t.Fatal("Priority outside expected range. Route not added properly") + } + if route.Table != 1000 { + t.Fatalf("Invalid Table %d. Route not added properly", route.Table) + } + if route.Type != unix.RTN_UNICAST { + t.Fatal("Invalid Type. 
Route not added properly") + } + } +} + +func BenchmarkRouteListFilteredNew(b *testing.B) { + tearDown := setUpNetlinkTest(b) + defer tearDown() + + link, err := setUpRoutesBench(b) + + b.ResetTimer() + b.ReportAllocs() + var routes []Route + for i := 0; i < b.N; i++ { + routes, err = pkgHandle.RouteListFiltered(FAMILY_V4, &Route{ + LinkIndex: link.Attrs().Index, + }, RT_FILTER_OIF) + if err != nil { + b.Fatal(err) + } + if len(routes) != 65535 { + b.Fatal("Incorrect number of routes.", len(routes)) + } + } + runtime.KeepAlive(routes) +} + +func BenchmarkRouteListIter(b *testing.B) { + tearDown := setUpNetlinkTest(b) + defer tearDown() + + link, err := setUpRoutesBench(b) + + b.ResetTimer() + b.ReportAllocs() + for i := 0; i < b.N; i++ { + var routes int + err = RouteListFilteredIter(FAMILY_V4, &Route{ + LinkIndex: link.Attrs().Index, + }, RT_FILTER_OIF, func(route Route) (cont bool) { + routes++ + return true + }) + if err != nil { + b.Fatal(err) + } + if routes != 65535 { + b.Fatal("Incorrect number of routes.", routes) + } + } +} + +func setUpRoutesBench(b *testing.B) (Link, error) { + // get loopback interface + link, err := LinkByName("lo") + if err != nil { + b.Fatal(err) + } + // bring the interface up + if err = LinkSetUp(link); err != nil { + b.Fatal(err) + } + + // add 65535 link-scope /32 routes on the loopback interface + for i := 0; i < 65535; i++ { + dst := &net.IPNet{ + IP: net.IPv4(1, 1, byte(i>>8), byte(i&0xff)), + Mask: net.CIDRMask(32, 32), + } + route := Route{ + LinkIndex: link.Attrs().Index, + Dst: dst, + Scope: unix.RT_SCOPE_LINK, + Priority: 10, + Type: unix.RTN_UNICAST, + } + if err := RouteAdd(&route); err != nil { + b.Fatal(err) + } + } + return link, err +} + func tableIDIn(ids []int, id int) bool { for _, v := range ids { if v == id {