Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add iterator-based version of RouteListFilter #940

Merged
merged 1 commit into from
Jul 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions netlink_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ import (

type tearDownNetlinkTest func()

func skipUnlessRoot(t *testing.T) {
func skipUnlessRoot(t testing.TB) {
t.Helper()

if os.Getuid() != 0 {
Expand Down Expand Up @@ -53,7 +53,7 @@ func skipUnlessKModuleLoaded(t *testing.T, module ...string) {
}
}

func setUpNetlinkTest(t *testing.T) tearDownNetlinkTest {
func setUpNetlinkTest(t testing.TB) tearDownNetlinkTest {
skipUnlessRoot(t)

// new temporary namespace so we don't pollute the host
Expand Down
58 changes: 42 additions & 16 deletions nl/nl_linux.go
Original file line number Diff line number Diff line change
Expand Up @@ -488,10 +488,30 @@ func (req *NetlinkRequest) AddRawData(data []byte) {
req.RawData = append(req.RawData, data...)
}

// Execute the request against a the given sockType.
// Execute the request against the given sockType.
// Returns a list of netlink messages in serialized format, optionally filtered
// by resType.
func (req *NetlinkRequest) Execute(sockType int, resType uint16) ([][]byte, error) {
var res [][]byte
err := req.ExecuteIter(sockType, resType, func(msg []byte) bool {
res = append(res, msg)
return true
})
if err != nil {
return nil, err
}
return res, nil
}

// ExecuteIter executes the request against the given sockType.
// Calls the provided callback func once for each netlink message.
// If the callback returns false, it is not called again, but
// the remaining messages are consumed/discarded.
//
// Thread safety: ExecuteIter holds a lock on the socket until
// it finishes iteration so the callback must not call back into
// the netlink API.
func (req *NetlinkRequest) ExecuteIter(sockType int, resType uint16, f func(msg []byte) bool) error {
var (
s *NetlinkSocket
err error
Expand All @@ -508,18 +528,18 @@ func (req *NetlinkRequest) Execute(sockType int, resType uint16) ([][]byte, erro
if s == nil {
s, err = getNetlinkSocket(sockType)
if err != nil {
return nil, err
return err
}

if err := s.SetSendTimeout(&SocketTimeoutTv); err != nil {
return nil, err
return err
}
if err := s.SetReceiveTimeout(&SocketTimeoutTv); err != nil {
return nil, err
return err
}
if EnableErrorMessageReporting {
if err := s.SetExtAck(true); err != nil {
return nil, err
return err
}
}

Expand All @@ -530,38 +550,36 @@ func (req *NetlinkRequest) Execute(sockType int, resType uint16) ([][]byte, erro
}

if err := s.Send(req); err != nil {
return nil, err
return err
}

pid, err := s.GetPid()
if err != nil {
return nil, err
return err
}

var res [][]byte

done:
for {
msgs, from, err := s.Receive()
if err != nil {
return nil, err
return err
}
if from.Pid != PidKernel {
return nil, fmt.Errorf("Wrong sender portid %d, expected %d", from.Pid, PidKernel)
return fmt.Errorf("Wrong sender portid %d, expected %d", from.Pid, PidKernel)
}
for _, m := range msgs {
if m.Header.Seq != req.Seq {
if sharedSocket {
continue
}
return nil, fmt.Errorf("Wrong Seq nr %d, expected %d", m.Header.Seq, req.Seq)
return fmt.Errorf("Wrong Seq nr %d, expected %d", m.Header.Seq, req.Seq)
}
if m.Header.Pid != pid {
continue
}

if m.Header.Flags&unix.NLM_F_DUMP_INTR != 0 {
return nil, syscall.Errno(unix.EINTR)
return syscall.Errno(unix.EINTR)
}

if m.Header.Type == unix.NLMSG_DONE || m.Header.Type == unix.NLMSG_ERROR {
Expand Down Expand Up @@ -600,18 +618,26 @@ done:
}
}

return nil, err
return err
}
if resType != 0 && m.Header.Type != resType {
continue
}
res = append(res, m.Data)
if cont := f(m.Data); !cont {
// Drain the rest of the messages from the kernel but don't
// pass them to the iterator func.
f = dummyMsgIterFunc
}
if m.Header.Flags&unix.NLM_F_MULTI == 0 {
break done
}
}
}
return res, nil
return nil
}

func dummyMsgIterFunc(msg []byte) bool {
return true
}

// Create a new netlink request from proto and flags
Expand Down
Loading
Loading