Skip to content

Commit

Permalink
Add RouteListFilteredIter API.
Browse files Browse the repository at this point in the history
Allows for listing large numbers of routes without
buffering the whole list in memory at once.

Add benchmarks for RouteListFiltered variants.
  • Loading branch information
fasaxc committed Apr 9, 2024
1 parent 19057e8 commit 166c7ff
Show file tree
Hide file tree
Showing 4 changed files with 262 additions and 57 deletions.
4 changes: 2 additions & 2 deletions netlink_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ import (

type tearDownNetlinkTest func()

func skipUnlessRoot(t *testing.T) {
func skipUnlessRoot(t testing.TB) {
t.Helper()

if os.Getuid() != 0 {
Expand Down Expand Up @@ -53,7 +53,7 @@ func skipUnlessKModuleLoaded(t *testing.T, module ...string) {
}
}

func setUpNetlinkTest(t *testing.T) tearDownNetlinkTest {
func setUpNetlinkTest(t testing.TB) tearDownNetlinkTest {
skipUnlessRoot(t)

// new temporary namespace so we don't pollute the host
Expand Down
58 changes: 42 additions & 16 deletions nl/nl_linux.go
Original file line number Diff line number Diff line change
Expand Up @@ -488,10 +488,30 @@ func (req *NetlinkRequest) AddRawData(data []byte) {
req.RawData = append(req.RawData, data...)
}

// Execute the request against a the given sockType.
// Execute the request against the given sockType.
// Returns a list of netlink messages in serialized format, optionally filtered
// by resType.
func (req *NetlinkRequest) Execute(sockType int, resType uint16) ([][]byte, error) {
var res [][]byte
err := req.ExecuteIter(sockType, resType, func(msg []byte) bool {
res = append(res, msg)
return true
})
if err != nil {
return nil, err
}
return res, nil
}

// ExecuteIter executes the request against the given sockType.
// Calls the provided callback func once for each netlink message.
// If the callback returns false, it is not called again, but
// the remaining messages are consumed/discarded.
//
// Thread safety: ExecuteIter holds a lock on the socket until
// it finishes iteration so the callback must not call back into
// the netlink API.
func (req *NetlinkRequest) ExecuteIter(sockType int, resType uint16, f func(msg []byte) bool) error {
var (
s *NetlinkSocket
err error
Expand All @@ -508,18 +528,18 @@ func (req *NetlinkRequest) Execute(sockType int, resType uint16) ([][]byte, erro
if s == nil {
s, err = getNetlinkSocket(sockType)
if err != nil {
return nil, err
return err
}

if err := s.SetSendTimeout(&SocketTimeoutTv); err != nil {
return nil, err
return err
}
if err := s.SetReceiveTimeout(&SocketTimeoutTv); err != nil {
return nil, err
return err
}
if EnableErrorMessageReporting {
if err := s.SetExtAck(true); err != nil {
return nil, err
return err
}
}

Expand All @@ -530,38 +550,36 @@ func (req *NetlinkRequest) Execute(sockType int, resType uint16) ([][]byte, erro
}

if err := s.Send(req); err != nil {
return nil, err
return err
}

pid, err := s.GetPid()
if err != nil {
return nil, err
return err
}

var res [][]byte

done:
for {
msgs, from, err := s.Receive()
if err != nil {
return nil, err
return err
}
if from.Pid != PidKernel {
return nil, fmt.Errorf("Wrong sender portid %d, expected %d", from.Pid, PidKernel)
return fmt.Errorf("Wrong sender portid %d, expected %d", from.Pid, PidKernel)
}
for _, m := range msgs {
if m.Header.Seq != req.Seq {
if sharedSocket {
continue
}
return nil, fmt.Errorf("Wrong Seq nr %d, expected %d", m.Header.Seq, req.Seq)
return fmt.Errorf("Wrong Seq nr %d, expected %d", m.Header.Seq, req.Seq)
}
if m.Header.Pid != pid {
continue
}

if m.Header.Flags&unix.NLM_F_DUMP_INTR != 0 {
return nil, syscall.Errno(unix.EINTR)
return syscall.Errno(unix.EINTR)
}

if m.Header.Type == unix.NLMSG_DONE || m.Header.Type == unix.NLMSG_ERROR {
Expand Down Expand Up @@ -600,18 +618,26 @@ done:
}
}

return nil, err
return err
}
if resType != 0 && m.Header.Type != resType {
continue
}
res = append(res, m.Data)
if cont := f(m.Data); !cont {
// Drain the rest of the messages from the kernel but don't
// pass them to the iterator func.
f = dummyMsgIterFunc
}
if m.Header.Flags&unix.NLM_F_MULTI == 0 {
break done
}
}
}
return res, nil
return nil
}

func dummyMsgIterFunc(msg []byte) bool {
return true
}

// Create a new netlink request from proto and flags
Expand Down
Loading

0 comments on commit 166c7ff

Please sign in to comment.