diff --git a/pkg/managers/pluginmanager/pluginmanager.go b/pkg/managers/pluginmanager/pluginmanager.go index 15d133382f..79bd66cc99 100644 --- a/pkg/managers/pluginmanager/pluginmanager.go +++ b/pkg/managers/pluginmanager/pluginmanager.go @@ -39,7 +39,7 @@ type PluginManager struct { plugins map[string]plugin.Plugin tel telemetry.Telemetry - watcherManager watchermanager.IWatcherManager + watcherManager watchermanager.Manager } func NewPluginManager(cfg *kcfg.Config, tel telemetry.Telemetry) (*PluginManager, error) { @@ -126,17 +126,19 @@ func (p *PluginManager) Start(ctx context.Context) error { return ErrZeroInterval } - if p.cfg.EnablePodLevel { - p.l.Info("starting watchers") + g, ctx := errgroup.WithContext(ctx) - // Start watcher manager - if err := p.watcherManager.Start(ctx); err != nil { - return errors.Wrap(err, "failed to start watcher manager") - } + if p.cfg.EnablePodLevel { + g.Go(func() error { + p.l.Info("starting watchers") + // Start watcher manager + if err = p.watcherManager.Start(ctx); err != nil { + return errors.Wrap(err, "failed to start watcher manager") + } + return nil + }) } - g, ctx := errgroup.WithContext(ctx) - // run conntrack GC ct, err := conntrack.New() if err != nil { diff --git a/pkg/managers/pluginmanager/pluginmanager_test.go b/pkg/managers/pluginmanager/pluginmanager_test.go index 2203468b5b..9221ad39c3 100644 --- a/pkg/managers/pluginmanager/pluginmanager_test.go +++ b/pkg/managers/pluginmanager/pluginmanager_test.go @@ -38,8 +38,8 @@ var ( } ) -func setupWatcherManagerMock(ctl *gomock.Controller) (m *watchermock.MockIWatcherManager) { - m = watchermock.NewMockIWatcherManager(ctl) +func setupWatcherManagerMock(ctl *gomock.Controller) (m *watchermock.MockManager) { + m = watchermock.NewMockManager(ctl) m.EXPECT().Start(gomock.Any()).Return(nil).AnyTimes() m.EXPECT().Stop(gomock.Any()).Return(nil).AnyTimes() return @@ -456,7 +456,7 @@ func TestWatcherManagerFailure(t *testing.T) { defer ctl.Finish() 
log.SetupZapLogger(log.GetDefaultLogOpts()) - m := watchermock.NewMockIWatcherManager(ctl) + m := watchermock.NewMockManager(ctl) m.EXPECT().Start(gomock.Any()).Return(errors.New("error")).AnyTimes() cfg := cfgPodLevelEnabled diff --git a/pkg/managers/watchermanager/mocks/mock_types.go b/pkg/managers/watchermanager/mocks/mock_types.go index b848396086..4150e4b7a6 100644 --- a/pkg/managers/watchermanager/mocks/mock_types.go +++ b/pkg/managers/watchermanager/mocks/mock_types.go @@ -16,59 +16,59 @@ import ( gomock "go.uber.org/mock/gomock" ) -// MockIWatcher is a mock of IWatcher interface. -type MockIWatcher struct { +// MockWatcher is a mock of Watcher interface. +type MockWatcher struct { ctrl *gomock.Controller - recorder *MockIWatcherMockRecorder + recorder *MockWatcherMockRecorder } -// MockIWatcherMockRecorder is the mock recorder for MockIWatcher. -type MockIWatcherMockRecorder struct { - mock *MockIWatcher +// MockWatcherMockRecorder is the mock recorder for MockWatcher. +type MockWatcherMockRecorder struct { + mock *MockWatcher } -// NewMockIWatcher creates a new mock instance. -func NewMockIWatcher(ctrl *gomock.Controller) *MockIWatcher { - mock := &MockIWatcher{ctrl: ctrl} - mock.recorder = &MockIWatcherMockRecorder{mock} +// NewMockWatcher creates a new mock instance. +func NewMockWatcher(ctrl *gomock.Controller) *MockWatcher { + mock := &MockWatcher{ctrl: ctrl} + mock.recorder = &MockWatcherMockRecorder{mock} return mock } // EXPECT returns an object that allows the caller to indicate expected use. -func (m *MockIWatcher) EXPECT() *MockIWatcherMockRecorder { +func (m *MockWatcher) EXPECT() *MockWatcherMockRecorder { return m.recorder } -// Init mocks base method. -func (m *MockIWatcher) Init(ctx context.Context) error { +// Name mocks base method. 
+func (m *MockWatcher) Name() string { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Init", ctx) - ret0, _ := ret[0].(error) + ret := m.ctrl.Call(m, "Name") + ret0, _ := ret[0].(string) return ret0 } -// Init indicates an expected call of Init. -func (mr *MockIWatcherMockRecorder) Init(ctx any) *gomock.Call { +// Name indicates an expected call of Name. +func (mr *MockWatcherMockRecorder) Name() *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Init", reflect.TypeOf((*MockIWatcher)(nil).Init), ctx) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Name", reflect.TypeOf((*MockWatcher)(nil).Name)) } -// Refresh mocks base method. -func (m *MockIWatcher) Refresh(ctx context.Context) error { +// Start mocks base method. +func (m *MockWatcher) Start(ctx context.Context) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Refresh", ctx) + ret := m.ctrl.Call(m, "Start", ctx) ret0, _ := ret[0].(error) return ret0 } -// Refresh indicates an expected call of Refresh. -func (mr *MockIWatcherMockRecorder) Refresh(ctx any) *gomock.Call { +// Start indicates an expected call of Start. +func (mr *MockWatcherMockRecorder) Start(ctx any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Refresh", reflect.TypeOf((*MockIWatcher)(nil).Refresh), ctx) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Start", reflect.TypeOf((*MockWatcher)(nil).Start), ctx) } // Stop mocks base method. -func (m *MockIWatcher) Stop(ctx context.Context) error { +func (m *MockWatcher) Stop(ctx context.Context) error { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "Stop", ctx) ret0, _ := ret[0].(error) @@ -76,36 +76,36 @@ func (m *MockIWatcher) Stop(ctx context.Context) error { } // Stop indicates an expected call of Stop. 
-func (mr *MockIWatcherMockRecorder) Stop(ctx any) *gomock.Call { +func (mr *MockWatcherMockRecorder) Stop(ctx any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Stop", reflect.TypeOf((*MockIWatcher)(nil).Stop), ctx) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Stop", reflect.TypeOf((*MockWatcher)(nil).Stop), ctx) } -// MockIWatcherManager is a mock of IWatcherManager interface. -type MockIWatcherManager struct { +// MockManager is a mock of Manager interface. +type MockManager struct { ctrl *gomock.Controller - recorder *MockIWatcherManagerMockRecorder + recorder *MockManagerMockRecorder } -// MockIWatcherManagerMockRecorder is the mock recorder for MockIWatcherManager. -type MockIWatcherManagerMockRecorder struct { - mock *MockIWatcherManager +// MockManagerMockRecorder is the mock recorder for MockManager. +type MockManagerMockRecorder struct { + mock *MockManager } -// NewMockIWatcherManager creates a new mock instance. -func NewMockIWatcherManager(ctrl *gomock.Controller) *MockIWatcherManager { - mock := &MockIWatcherManager{ctrl: ctrl} - mock.recorder = &MockIWatcherManagerMockRecorder{mock} +// NewMockManager creates a new mock instance. +func NewMockManager(ctrl *gomock.Controller) *MockManager { + mock := &MockManager{ctrl: ctrl} + mock.recorder = &MockManagerMockRecorder{mock} return mock } // EXPECT returns an object that allows the caller to indicate expected use. -func (m *MockIWatcherManager) EXPECT() *MockIWatcherManagerMockRecorder { +func (m *MockManager) EXPECT() *MockManagerMockRecorder { return m.recorder } // Start mocks base method. -func (m *MockIWatcherManager) Start(ctx context.Context) error { +func (m *MockManager) Start(ctx context.Context) error { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "Start", ctx) ret0, _ := ret[0].(error) @@ -113,13 +113,13 @@ func (m *MockIWatcherManager) Start(ctx context.Context) error { } // Start indicates an expected call of Start. 
-func (mr *MockIWatcherManagerMockRecorder) Start(ctx any) *gomock.Call { +func (mr *MockManagerMockRecorder) Start(ctx any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Start", reflect.TypeOf((*MockIWatcherManager)(nil).Start), ctx) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Start", reflect.TypeOf((*MockManager)(nil).Start), ctx) } // Stop mocks base method. -func (m *MockIWatcherManager) Stop(ctx context.Context) error { +func (m *MockManager) Stop(ctx context.Context) error { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "Stop", ctx) ret0, _ := ret[0].(error) @@ -127,7 +127,7 @@ func (m *MockIWatcherManager) Stop(ctx context.Context) error { } // Stop indicates an expected call of Stop. -func (mr *MockIWatcherManagerMockRecorder) Stop(ctx any) *gomock.Call { +func (mr *MockManagerMockRecorder) Stop(ctx any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Stop", reflect.TypeOf((*MockIWatcherManager)(nil).Stop), ctx) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Stop", reflect.TypeOf((*MockManager)(nil).Stop), ctx) } diff --git a/pkg/managers/watchermanager/types.go b/pkg/managers/watchermanager/types.go index 2417650c06..170928eea6 100644 --- a/pkg/managers/watchermanager/types.go +++ b/pkg/managers/watchermanager/types.go @@ -5,29 +5,24 @@ package watchermanager import ( "context" - "sync" - "time" "github.com/microsoft/retina/pkg/log" ) //go:generate go run go.uber.org/mock/mockgen@v0.4.0 -source=types.go -destination=mocks/mock_types.go -package=mocks . -type IWatcher interface { - // Init, Stop, and Refresh should only be called by watchermanager. - Init(ctx context.Context) error +type Watcher interface { + // Start and Stop should only be called by watchermanager. 
+ Start(ctx context.Context) error Stop(ctx context.Context) error - Refresh(ctx context.Context) error + Name() string } -type IWatcherManager interface { +type Manager interface { Start(ctx context.Context) error Stop(ctx context.Context) error } type WatcherManager struct { - Watchers []IWatcher - l *log.ZapLogger - refreshRate time.Duration - cancel context.CancelFunc - wg sync.WaitGroup + Watchers []Watcher + l *log.ZapLogger } diff --git a/pkg/managers/watchermanager/watchermanager.go b/pkg/managers/watchermanager/watchermanager.go index 55617db286..390a283854 100644 --- a/pkg/managers/watchermanager/watchermanager.go +++ b/pkg/managers/watchermanager/watchermanager.go @@ -5,13 +5,14 @@ package watchermanager import ( "context" - "fmt" "time" "github.com/microsoft/retina/pkg/log" "github.com/microsoft/retina/pkg/watchers/apiserver" "github.com/microsoft/retina/pkg/watchers/endpoint" + "github.com/pkg/errors" "go.uber.org/zap" + "golang.org/x/sync/errgroup" ) const ( @@ -21,61 +22,39 @@ const ( func NewWatcherManager() *WatcherManager { return &WatcherManager{ - Watchers: []IWatcher{ - endpoint.Watcher(), - apiserver.Watcher(), + Watchers: []Watcher{ + apiserver.NewWatcher(), + endpoint.NewWatcher(), }, - l: log.Logger().Named("watcher-manager"), - refreshRate: DefaultRefreshRate, + l: log.Logger().Named("watcher-manager"), } } func (wm *WatcherManager) Start(ctx context.Context) error { - newCtx, cancelCtx := context.WithCancel(ctx) - wm.cancel = cancelCtx - + wm.l.Info("starting watcher manager") + // start all watchers + g, ctx := errgroup.WithContext(ctx) for _, w := range wm.Watchers { - if err := w.Init(ctx); err != nil { - wm.l.Error("init failed", zap.String("watcher_type", fmt.Sprintf("%T", w)), zap.Error(err)) - return err - } - wm.wg.Add(1) - go wm.runWatcher(newCtx, w) - wm.l.Info("watcher started", zap.String("watcher_type", fmt.Sprintf("%T", w))) + w := w + g.Go(func() error { + wm.l.Info("starting watcher", zap.String("name", w.Name())) + err := 
w.Start(ctx) + if err != nil { + wm.l.Error("watcher exited with error", zap.Error(err), zap.String("name", w.Name())) + return errors.Wrap(err, "watcher exited with error") + } + return nil + }) + } + err := g.Wait() + if err != nil { + wm.l.Error("watcher manager exited with error", zap.Error(err)) + return errors.Wrap(err, "watcher manager exited with error") } return nil } func (wm *WatcherManager) Stop(ctx context.Context) error { - if wm.cancel != nil { - wm.cancel() // cancel all runWatcher - } - for _, w := range wm.Watchers { - if err := w.Stop(ctx); err != nil { - wm.l.Error("failed to stop", zap.String("watcher_type", fmt.Sprintf("%T", w)), zap.Error(err)) - return err - } - } - wm.wg.Wait() // wait for all runWatcher to stop wm.l.Info("watcher manager stopped") return nil } - -func (wm *WatcherManager) runWatcher(ctx context.Context, w IWatcher) error { - defer wm.wg.Done() // signal that this runWatcher is done - ticker := time.NewTicker(wm.refreshRate) - defer ticker.Stop() - for { - select { - case <-ctx.Done(): - wm.l.Info("watcher stopping...", zap.String("watcher_type", fmt.Sprintf("%T", w))) - return nil - case <-ticker.C: - err := w.Refresh(ctx) - if err != nil { - wm.l.Error("refresh failed", zap.Error(err)) - return err - } - } - } -} diff --git a/pkg/managers/watchermanager/watchermanager_test.go b/pkg/managers/watchermanager/watchermanager_test.go index 37dcf4809f..a611a29464 100644 --- a/pkg/managers/watchermanager/watchermanager_test.go +++ b/pkg/managers/watchermanager/watchermanager_test.go @@ -4,71 +4,32 @@ package watchermanager import ( "context" - "errors" "testing" "github.com/microsoft/retina/pkg/log" - mock "github.com/microsoft/retina/pkg/managers/watchermanager/mocks" "github.com/stretchr/testify/require" "go.uber.org/mock/gomock" "golang.org/x/sync/errgroup" ) -var errInitFailed = errors.New("init failed") - func TestStopWatcherManagerGracefully(t *testing.T) { ctl := gomock.NewController(t) defer ctl.Finish() 
log.SetupZapLogger(log.GetDefaultLogOpts()) mgr := NewWatcherManager() - mockAPIServerWatcher := mock.NewMockIWatcher(ctl) - mockEndpointWatcher := mock.NewMockIWatcher(ctl) - - mgr.Watchers = []IWatcher{ - mockEndpointWatcher, - mockAPIServerWatcher, - } - - mockAPIServerWatcher.EXPECT().Init(gomock.Any()).Return(nil).AnyTimes() - mockEndpointWatcher.EXPECT().Init(gomock.Any()).Return(nil).AnyTimes() - - mockEndpointWatcher.EXPECT().Stop(gomock.Any()).Return(nil).AnyTimes() - mockAPIServerWatcher.EXPECT().Stop(gomock.Any()).Return(nil).AnyTimes() - - ctx, _ := context.WithCancel(context.Background()) + ctx := context.Background() g, errctx := errgroup.WithContext(ctx) + var err error g.Go(func() error { - return mgr.Start(errctx) + err = mgr.Start(errctx) + return err }) - err := g.Wait() - mgr.Stop(errctx) require.NoError(t, err) } -func TestWatcherInitFailsGracefully(t *testing.T) { - ctl := gomock.NewController(t) - defer ctl.Finish() - log.SetupZapLogger(log.GetDefaultLogOpts()) - - mockAPIServerWatcher := mock.NewMockIWatcher(ctl) - mockEndpointWatcher := mock.NewMockIWatcher(ctl) - - mgr := NewWatcherManager() - mgr.Watchers = []IWatcher{ - mockAPIServerWatcher, - mockEndpointWatcher, - } - - mockAPIServerWatcher.EXPECT().Init(gomock.Any()).Return(errInitFailed).AnyTimes() - mockEndpointWatcher.EXPECT().Init(gomock.Any()).Return(errInitFailed).AnyTimes() - - err := mgr.Start(context.Background()) - require.NotNil(t, err, "Expected error when starting watcher manager") -} - func TestWatcherStopWithoutStart(t *testing.T) { ctl := gomock.NewController(t) defer ctl.Finish() diff --git a/pkg/plugin/packetparser/packetparser_linux.go b/pkg/plugin/packetparser/packetparser_linux.go index fb54c8bb5e..a945ab0b7d 100644 --- a/pkg/plugin/packetparser/packetparser_linux.go +++ b/pkg/plugin/packetparser/packetparser_linux.go @@ -225,7 +225,7 @@ func (p *packetParser) Init() error { } p.tcMap = &sync.Map{} - p.interfaceLockMap = &sync.Map{} + p.interfaceMap = &sync.Map{} 
return nil } @@ -382,23 +382,26 @@ func (p *packetParser) endpointWatcherCallbackFn(obj interface{}) { iface := event.Obj.(netlink.LinkAttrs) ifaceKey := ifaceToKey(iface) - lockMapVal, _ := p.interfaceLockMap.LoadOrStore(ifaceKey, &sync.Mutex{}) - mu := lockMapVal.(*sync.Mutex) - mu.Lock() - defer mu.Unlock() + _, ifaceExist := p.interfaceMap.LoadOrStore(ifaceKey, struct{}{}) switch event.Type { case endpoint.EndpointCreated: - p.l.Debug("Endpoint created", zap.String("name", iface.Name)) - p.createQdiscAndAttach(iface, Veth) + if !ifaceExist { + p.l.Debug("Endpoint created", zap.String("name", iface.Name)) + p.createQdiscAndAttach(iface, Veth) + } case endpoint.EndpointDeleted: - p.l.Debug("Endpoint deleted", zap.String("name", iface.Name)) - // Clean. - if value, ok := p.tcMap.Load(ifaceKey); ok { - v := value.(*tcValue) - p.clean(v.tc, v.qdisc) - // Delete from map. - p.tcMap.Delete(ifaceKey) + if ifaceExist { + p.l.Debug("Endpoint deleted", zap.String("name", iface.Name)) + // Clean. + if value, ok := p.tcMap.Load(ifaceKey); ok { + v := value.(*tcValue) + p.clean(v.tc, v.qdisc) + // Delete from map. + p.tcMap.Delete(ifaceKey) + } + // Delete from interfaceMap + p.interfaceMap.Delete(ifaceKey) } default: // Unknown. 
diff --git a/pkg/plugin/packetparser/packetparser_linux_test.go b/pkg/plugin/packetparser/packetparser_linux_test.go index aa27cfa356..afa3811c2a 100644 --- a/pkg/plugin/packetparser/packetparser_linux_test.go +++ b/pkg/plugin/packetparser/packetparser_linux_test.go @@ -163,9 +163,9 @@ func TestEndpointWatcherCallbackFn_EndpointDeleted(t *testing.T) { defer ctrl.Finish() p := &packetParser{ - cfg: cfgPodLevelEnabled, - l: log.Logger().Named("test"), - interfaceLockMap: &sync.Map{}, + cfg: cfgPodLevelEnabled, + l: log.Logger().Named("test"), + interfaceMap: &sync.Map{}, } p.tcMap = &sync.Map{} linkAttr := netlink.LinkAttrs{ @@ -175,6 +175,7 @@ func TestEndpointWatcherCallbackFn_EndpointDeleted(t *testing.T) { } key := ifaceToKey(linkAttr) p.tcMap.Store(key, &tcValue{nil, &tc.Object{}}) + p.interfaceMap.Store(key, struct{}{}) // Create EndpointDeleted event. e := &endpoint.EndpointEvent{ @@ -186,6 +187,8 @@ func TestEndpointWatcherCallbackFn_EndpointDeleted(t *testing.T) { _, ok := p.tcMap.Load(key) assert.False(t, ok) + _, ok = p.interfaceMap.Load(key) + assert.False(t, ok) } func TestCreateQdiscAndAttach(t *testing.T) { @@ -224,10 +227,10 @@ func TestCreateQdiscAndAttach(t *testing.T) { pObj.EndpointEgressFilter = &ebpf.Program{} p := &packetParser{ - cfg: cfgPodLevelEnabled, - l: log.Logger().Named("test"), - objs: pObj, - interfaceLockMap: &sync.Map{}, + cfg: cfgPodLevelEnabled, + l: log.Logger().Named("test"), + objs: pObj, + interfaceMap: &sync.Map{}, endpointIngressInfo: &ebpf.ProgramInfo{ Name: "ingress", }, @@ -412,12 +415,12 @@ func TestStartWithDataAggregationLevelLow(t *testing.T) { pObj.EndpointEgressFilter = &ebpf.Program{} p := &packetParser{ - cfg: cfgDataAggregationLevelLow, - l: log.Logger().Named("test"), - objs: pObj, - reader: mockReader, - recordsChannel: make(chan perf.Record, buffer), - interfaceLockMap: &sync.Map{}, + cfg: cfgDataAggregationLevelLow, + l: log.Logger().Named("test"), + objs: pObj, + reader: mockReader, + recordsChannel: 
make(chan perf.Record, buffer), + interfaceMap: &sync.Map{}, endpointIngressInfo: &ebpf.ProgramInfo{ Name: "ingress", }, @@ -491,12 +494,12 @@ func TestStartWithDataAggregationLevelHigh(t *testing.T) { pObj.EndpointEgressFilter = &ebpf.Program{} p := &packetParser{ - cfg: cfgDataAggregationLevelHigh, - l: log.Logger().Named("test"), - objs: pObj, - reader: mockReader, - recordsChannel: make(chan perf.Record, buffer), - interfaceLockMap: &sync.Map{}, + cfg: cfgDataAggregationLevelHigh, + l: log.Logger().Named("test"), + objs: pObj, + reader: mockReader, + recordsChannel: make(chan perf.Record, buffer), + interfaceMap: &sync.Map{}, endpointIngressInfo: &ebpf.ProgramInfo{ Name: "ingress", }, diff --git a/pkg/plugin/packetparser/types_linux.go b/pkg/plugin/packetparser/types_linux.go index 4fc06f35c3..ab1d321ae9 100644 --- a/pkg/plugin/packetparser/types_linux.go +++ b/pkg/plugin/packetparser/types_linux.go @@ -114,8 +114,8 @@ type packetParser struct { tcMap *sync.Map reader perfReader enricher enricher.EnricherInterface - // interfaceLockMap is a map of key to *sync.Mutex. 
- interfaceLockMap *sync.Map + // interfaceMap is a map of existing interfaces + interfaceMap *sync.Map endpointIngressInfo *ebpf.ProgramInfo endpointEgressInfo *ebpf.ProgramInfo hostIngressInfo *ebpf.ProgramInfo diff --git a/pkg/watchers/apiserver/apiserver.go b/pkg/watchers/apiserver/apiserver.go index 606c2cde74..55ca6e6cab 100644 --- a/pkg/watchers/apiserver/apiserver.go +++ b/pkg/watchers/apiserver/apiserver.go @@ -9,208 +9,168 @@ import ( "net" "net/url" "strings" + "time" "github.com/microsoft/retina/pkg/common" cc "github.com/microsoft/retina/pkg/controllers/cache" "github.com/microsoft/retina/pkg/log" fm "github.com/microsoft/retina/pkg/managers/filtermanager" "github.com/microsoft/retina/pkg/pubsub" - "github.com/microsoft/retina/pkg/utils" + "github.com/pkg/errors" "go.uber.org/zap" - "k8s.io/client-go/rest" kcfg "sigs.k8s.io/controller-runtime/pkg/client/config" ) -const ( - filterManagerRetries = 3 - hostLookupRetries = 6 // 6 retries for a total of 63 seconds. -) - -type ApiServerWatcher struct { - isRunning bool - l *log.ZapLogger - current cache - new cache - apiServerHostName string - hostResolver IHostResolver - filterManager fm.IFilterManager - restConfig *rest.Config -} - -var a *ApiServerWatcher - -// Watcher creates a new ApiServerWatcher instance. -func Watcher() *ApiServerWatcher { - if a == nil { - a = &ApiServerWatcher{ - isRunning: false, - l: log.Logger().Named("apiserver-watcher"), - current: make(cache), - hostResolver: net.DefaultResolver, - } - } - - return a +func (w *Watcher) Name() string { + return watcherName } -func (a *ApiServerWatcher) Init(ctx context.Context) error { - if a.isRunning { - a.l.Info("apiserver watcher is already running") - return nil - } - - // Get filter manager. 
- if a.filterManager == nil { - var err error - a.filterManager, err = fm.Init(filterManagerRetries) - if err != nil { - a.l.Error("failed to init filter manager", zap.Error(err)) - return fmt.Errorf("failed to init filter manager: %w", err) +// Start the apiserver watcher. +func (w *Watcher) Start(ctx context.Context) error { + ticker := time.NewTicker(w.refreshRate) + for { + select { + case <-ctx.Done(): + w.l.Info("context done, stopping apiserver watcher") + return nil + case <-ticker.C: + err := w.initNewCache(ctx) + if err != nil { + return err + } + // Compare the new ips with the old ones. + created, deleted := w.diffCache() + + // Publish the new ips. + createdIps := []net.IP{} + deletedIps := []net.IP{} + + for _, v := range created { + w.l.Info("New Apiserver ips:", zap.Any("ip", v)) + ip := net.ParseIP(v.(string)).To4() + createdIps = append(createdIps, ip) + } + + for _, v := range deleted { + w.l.Info("Deleted Apiserver ips:", zap.Any("ip", v)) + ip := net.ParseIP(v.(string)).To4() + deletedIps = append(deletedIps, ip) + } + + if len(createdIps) > 0 { + // Publish the new ips. + w.publish(createdIps, cc.EventTypeAddAPIServerIPs) + // Add ips to filter manager if any. + err := w.filtermanager.AddIPs(createdIps, "apiserver-watcher", fm.RequestMetadata{RuleID: "apiserver-watcher"}) + if err != nil { + w.l.Error("Failed to add ips to filter manager", zap.Error(err)) + } + } + + if len(deletedIps) > 0 { + // Publish the deleted ips. + w.publish(deletedIps, cc.EventTypeDeleteAPIServerIPs) + // Delete ips from filter manager if any. + err := w.filtermanager.DeleteIPs(deletedIps, "apiserver-watcher", fm.RequestMetadata{RuleID: "apiserver-watcher"}) + if err != nil { + w.l.Error("Failed to delete ips from filter manager", zap.Error(err)) + } + } + + // update the current cache and reset the new cache + w.current = w.new.deepcopy() + w.new = nil } } - - // Get kubeconfig. 
- if a.restConfig == nil { - config, err := kcfg.GetConfig() - if err != nil { - a.l.Error("failed to get kubeconfig", zap.Error(err)) - return fmt.Errorf("failed to get kubeconfig: %w", err) - } - a.restConfig = config - } - - hostName, err := a.getHostName() - if err != nil { - a.l.Error("failed to get host name", zap.Error(err)) - return fmt.Errorf("failed to get host name: %w", err) - } - a.apiServerHostName = hostName - - a.isRunning = true - - return nil } -// Stop stops the ApiServerWatcher. -func (a *ApiServerWatcher) Stop(ctx context.Context) error { - if !a.isRunning { - a.l.Info("apiserver watcher is not running") - return nil - } - a.isRunning = false +// Stop the apiserver watcher. +func (w *Watcher) Stop(_ context.Context) error { + w.l.Info("stopping apiserver watcher") return nil } -func (a *ApiServerWatcher) Refresh(ctx context.Context) error { - err := a.initNewCache(ctx) +func (w *Watcher) initNewCache(ctx context.Context) error { + ips, err := w.getAPIServerIPs(ctx) if err != nil { - a.l.Error("failed to initialize new cache", zap.Error(err)) return err } - // Compare the new IPs with the old ones. - created, deleted := a.diffCache() - - createdIPs := []net.IP{} - deletedIPs := []net.IP{} - - for _, v := range created { - a.l.Info("New Apiserver IPs:", zap.Any("ip", v)) - ip := net.ParseIP(v.(string)).To4() - createdIPs = append(createdIPs, ip) - } - - for _, v := range deleted { - a.l.Info("Deleted Apiserver IPs:", zap.Any("ip", v)) - ip := net.ParseIP(v.(string)).To4() - deletedIPs = append(deletedIPs, ip) + // Reset the new cache. 
+ w.new = make(cache) + for _, ip := range ips { + w.new[ip] = struct{}{} } + return nil +} - if len(createdIPs) > 0 { - a.publish(createdIPs, cc.EventTypeAddAPIServerIPs) - err := a.filterManager.AddIPs(createdIPs, "apiserver-watcher", fm.RequestMetadata{RuleID: "apiserver-watcher"}) - if err != nil { - a.l.Error("Failed to add IPs to filter manager", zap.Error(err)) +func (w *Watcher) diffCache() (created, deleted []interface{}) { + // check if there are new ips + for k := range w.new { + if _, ok := w.current[k]; !ok { + created = append(created, k) } } - if len(deletedIPs) > 0 { - a.publish(deletedIPs, cc.EventTypeDeleteAPIServerIPs) - err := a.filterManager.DeleteIPs(deletedIPs, "apiserver-watcher", fm.RequestMetadata{RuleID: "apiserver-watcher"}) - if err != nil { - a.l.Error("Failed to delete IPs from filter manager", zap.Error(err)) + // check if there are deleted ips + for k := range w.current { + if _, ok := w.new[k]; !ok { + deleted = append(deleted, k) } } - - a.current = a.new.deepcopy() - a.new = nil - - return nil + return } -func (a *ApiServerWatcher) initNewCache(ctx context.Context) error { - ips, err := a.resolveIPs(ctx, a.apiServerHostName) +func (w *Watcher) getAPIServerIPs(ctx context.Context) ([]string, error) { + // Parse the URL + host, err := w.retrieveAPIServerHostname() if err != nil { - return fmt.Errorf("failed to resolve IPs: %w", err) + return nil, err } - // Reset new cache. - a.new = make(cache) - for _, ip := range ips { - a.new[ip] = struct{}{} + // Get the ips for the host + ips, err := w.resolveIPs(ctx, host) + if err != nil { + return nil, err } - return nil + + return ips, nil } -func (a *ApiServerWatcher) diffCache() (created, deleted []interface{}) { - // Check if there are any new IPs. 
- for k := range a.new { - if _, ok := a.current[k]; !ok { - created = append(created, k) - } +// parse url to extract hostname +func (w *Watcher) retrieveAPIServerHostname() (string, error) { + // Parse the URL + parsedURL, err := url.Parse(w.apiServerURL) + if err != nil { + w.l.Error("failed to parse url", zap.Error(err)) + return "", errors.Wrap(err, "failed to parse url") } - // Check if there are any deleted IPs. - for k := range a.current { - if _, ok := a.new[k]; !ok { - deleted = append(deleted, k) - } + // Remove the scheme (http:// or https://) and port from the host + host := strings.TrimPrefix(parsedURL.Host, "www.") + colonIndex := strings.IndexByte(host, ':') + if colonIndex != -1 { + host = host[:colonIndex] } - return + return host, nil } -func (a *ApiServerWatcher) resolveIPs(ctx context.Context, host string) ([]string, error) { - // perform a DNS lookup for the host URL using the net.DefaultResolver which uses the local resolver. - // Possible errors here are: - // - Canceled context: The context was canceled before the lookup completed. - // -DNS server errors ie NXDOMAIN, SERVFAIL. - // - Network errors ie timeout, unreachable DNS server. - // -Other DNS-related errors encapsulated in a DNSError. - var hostIPs []string - var err error - - retryFunc := func() error { - hostIPs, err = a.hostResolver.LookupHost(ctx, host) - if err != nil { - return fmt.Errorf("APIServer LookupHost failed: %w", err) - } - return nil - } - - // Retry the lookup for hostIPs in case of failure. 
- err = utils.Retry(retryFunc, hostLookupRetries) +// Resolve the list of ips for the given host +func (w *Watcher) resolveIPs(ctx context.Context, host string) ([]string, error) { + hostIps, err := w.hostResolver.LookupHost(ctx, host) if err != nil { return nil, err } - if len(hostIPs) == 0 { - a.l.Debug("no IPs found for host", zap.String("host", host)) + if len(hostIps) == 0 { + w.l.Error("no ips found for host", zap.String("host", host)) + return nil, fmt.Errorf("no ips found for host %s", host) //nolint:err113 // static err is not necessary } - return hostIPs, nil + return hostIps, nil } -func (a *ApiServerWatcher) publish(netIPs []net.IP, eventType cc.EventType) { +func (w *Watcher) publish(netIPs []net.IP, eventType cc.EventType) { if len(netIPs) == 0 { return } @@ -220,23 +180,30 @@ func (a *ApiServerWatcher) publish(netIPs []net.IP, eventType cc.EventType) { ipsToPublish = append(ipsToPublish, ip.String()) } ps := pubsub.New() - ps.Publish(common.PubSubAPIServer, cc.NewCacheEvent(eventType, common.NewAPIServerObject(ipsToPublish))) - a.l.Debug("Published event", zap.Any("eventType", eventType), zap.Any("netIPs", ipsToPublish)) + ps.Publish(common.PubSubAPIServer, + cc.NewCacheEvent( + eventType, + common.NewAPIServerObject(ipsToPublish), + ), + ) + w.l.Debug("Published event", zap.Any("eventType", eventType), zap.Any("netIPs", ipsToPublish)) } -func (a *ApiServerWatcher) getHostName() (string, error) { - // Parse the host URL. - hostURL := a.restConfig.Host - parsedURL, err := url.ParseRequestURI(hostURL) +// getHostURL returns the host url from the config. +func getHostURL() string { + config, err := kcfg.GetConfig() if err != nil { - log.Logger().Error("failed to parse URL", zap.String("url", hostURL), zap.Error(err)) - return "", fmt.Errorf("failed to parse URL: %w", err) + log.Logger().Error("failed to get config", zap.Error(err)) + return "" } + return config.Host +} - // Extract the host name from the URL. 
- host := strings.TrimPrefix(parsedURL.Host, "www.") - if colonIndex := strings.IndexByte(host, ':'); colonIndex != -1 { - host = host[:colonIndex] +// Get FilterManager +func (w *Watcher) getFilterManager() *fm.FilterManager { + f, err := fm.Init(filterManagerRetries) + if err != nil { + w.l.Error("failed to init filter manager", zap.Error(err)) } - return host, nil + return f } diff --git a/pkg/watchers/apiserver/apiserver_test.go b/pkg/watchers/apiserver/apiserver_test.go index 04105e85c1..903f0034ae 100644 --- a/pkg/watchers/apiserver/apiserver_test.go +++ b/pkg/watchers/apiserver/apiserver_test.go @@ -18,52 +18,9 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "go.uber.org/mock/gomock" - "k8s.io/client-go/rest" ) -var errDNS = errors.New("DNS error") - -func TestGetWatcher(t *testing.T) { - log.SetupZapLogger(log.GetDefaultLogOpts()) - - a := Watcher() - assert.NotNil(t, a) - - a_again := Watcher() - assert.Equal(t, a, a_again, "Expected the same veth watcher instance") -} - -func TestAPIServerWatcherStop(t *testing.T) { - log.SetupZapLogger(log.GetDefaultLogOpts()) - ctx, cancel := context.WithTimeout(context.Background(), 20*time.Second) - defer cancel() - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - mockedFilterManager := filtermanagermocks.NewMockIFilterManager(ctrl) - - // When apiserver is already stopped. - a := &ApiServerWatcher{ - isRunning: false, - l: log.Logger().Named("apiserver-watcher"), - filterManager: mockedFilterManager, - restConfig: getMockConfig(true), - } - err := a.Stop(ctx) - assert.NoError(t, err, "Expected no error when stopping a stopped apiserver watcher") - assert.Equal(t, false, a.isRunning, "Expected apiserver watcher to be stopped") - - // Start the watcher. - err = a.Init(ctx) - assert.NoError(t, err, "Expected no error when starting a stopped apiserver watcher") - - // Stop the watcher. 
- err = a.Stop(ctx) - assert.NoError(t, err, "Expected no error when stopping a running apiserver watcher") - assert.Equal(t, false, a.isRunning, "Expected apiserver watcher to be stopped") -} - -func TestRefresh(t *testing.T) { +func TestStart(t *testing.T) { log.SetupZapLogger(log.GetDefaultLogOpts()) ctrl := gomock.NewController(t) defer ctrl.Finish() @@ -74,10 +31,12 @@ func TestRefresh(t *testing.T) { mockedResolver := mocks.NewMockIHostResolver(ctrl) mockedFilterManager := filtermanagermocks.NewMockIFilterManager(ctrl) - a := &ApiServerWatcher{ - l: log.Logger().Named("apiserver-watcher"), + w := &Watcher{ + l: log.Logger().Named(watcherName), + apiServerURL: "https://kubernetes.default.svc.cluster.local:443", hostResolver: mockedResolver, - filterManager: mockedFilterManager, + filtermanager: mockedFilterManager, + refreshRate: 1 * time.Second, } // Return 2 random IPs for the host everytime LookupHost is called. @@ -88,8 +47,8 @@ func TestRefresh(t *testing.T) { mockedFilterManager.EXPECT().AddIPs(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil).AnyTimes() mockedFilterManager.EXPECT().DeleteIPs(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil).AnyTimes() - a.Refresh(ctx) - assert.NoError(t, a.Refresh(context.Background()), "Expected no error when refreshing the cache") + err := w.Start(ctx) // watcher will timeout after 20 seconds + require.NoError(t, err, "Expected no error when refreshing the cache") } func TestDiffCache(t *testing.T) { @@ -107,19 +66,21 @@ func TestDiffCache(t *testing.T) { new["192.168.1.2"] = struct{}{} new["192.168.1.3"] = struct{}{} - a := &ApiServerWatcher{ - l: log.Logger().Named("apiserver-watcher"), + w := &Watcher{ + l: log.Logger().Named(watcherName), + apiServerURL: "https://kubernetes.default.svc.cluster.local:443", hostResolver: mockedResolver, current: old, new: new, + refreshRate: 1 * time.Second, } - created, deleted := a.diffCache() + created, deleted := w.diffCache() assert.Equal(t, 1, len(created), "Expected 1 
created host") assert.Equal(t, 1, len(deleted), "Expected 1 deleted host") } -func TestRefreshLookUpAlwaysFail(t *testing.T) { +func TestStartError(t *testing.T) { log.SetupZapLogger(log.GetDefaultLogOpts()) ctrl := gomock.NewController(t) defer ctrl.Finish() @@ -128,19 +89,23 @@ func TestRefreshLookUpAlwaysFail(t *testing.T) { defer cancel() mockedResolver := mocks.NewMockIHostResolver(ctrl) + mockedFilterManager := filtermanagermocks.NewMockIFilterManager(ctrl) - a := &ApiServerWatcher{ - l: log.Logger().Named("apiserver-watcher"), - hostResolver: mockedResolver, + w := &Watcher{ + l: log.Logger().Named(watcherName), + apiServerURL: "https://kubernetes.default.svc.cluster.local:443", + hostResolver: mockedResolver, + filtermanager: mockedFilterManager, + refreshRate: 1 * time.Second, } mockedResolver.EXPECT().LookupHost(gomock.Any(), gomock.Any()).Return(nil, errors.New("Error")).AnyTimes() - a.Refresh(ctx) - require.Error(t, a.Refresh(context.Background()), "Expected error when refreshing the cache") + err := w.Start(ctx) + require.Error(t, err, "Expected error when refreshing the cache") } -func TestInitWithIncorrectURL(t *testing.T) { +func TestResolveIPEmpty(t *testing.T) { log.SetupZapLogger(log.GetDefaultLogOpts()) ctrl := gomock.NewController(t) defer ctrl.Finish() @@ -149,62 +114,20 @@ func TestInitWithIncorrectURL(t *testing.T) { defer cancel() mockedResolver := mocks.NewMockIHostResolver(ctrl) - mockedFilterManager := filtermanagermocks.NewMockIFilterManager(ctrl) - a := &ApiServerWatcher{ - l: log.Logger().Named("apiserver-watcher"), - hostResolver: mockedResolver, - restConfig: getMockConfig(false), - filterManager: mockedFilterManager, + w := &Watcher{ + l: log.Logger().Named(watcherName), + apiServerURL: "https://kubernetes.default.svc.cluster.local:443", + hostResolver: mockedResolver, + refreshRate: 1 * time.Second, } mockedResolver.EXPECT().LookupHost(gomock.Any(), gomock.Any()).Return([]string{}, nil).AnyTimes() - require.Error(t, a.Init(ctx), 
"Expected error during init") + + err := w.Start(ctx) + require.Error(t, err, "Expected error when resolving the IP") } func randomIP() string { return fmt.Sprintf("%d.%d.%d.%d", rand.Intn(256), rand.Intn(256), rand.Intn(256), rand.Intn(256)) } - -// Mock function to simulate getting a Kubernetes config -func getMockConfig(isCorrect bool) *rest.Config { - if isCorrect { - return &rest.Config{ - Host: "https://kubernetes.default.svc.cluster.local:443", - } - } - return &rest.Config{ - Host: "", - } -} - -func TestRefreshFailsFirstFourAttemptsSucceedsOnFifth(t *testing.T) { - _, err := log.SetupZapLogger(log.GetDefaultLogOpts()) - require.NoError(t, err) - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - ctx, cancel := context.WithTimeout(context.Background(), 20*time.Second) - defer cancel() - - mockedResolver := mocks.NewMockIHostResolver(ctrl) - mockedFilterManager := filtermanagermocks.NewMockIFilterManager(ctrl) - - a := &ApiServerWatcher{ - l: log.Logger().Named("apiserver-watcher"), - hostResolver: mockedResolver, - filterManager: mockedFilterManager, - } - - // Simulate LookupHost failing the first four times and succeeding on the fifth. 
- gomock.InOrder( - mockedResolver.EXPECT().LookupHost(gomock.Any(), gomock.Any()).Return(nil, errDNS).Times(4), - mockedResolver.EXPECT().LookupHost(gomock.Any(), gomock.Any()).Return([]string{"127.0.0.1"}, nil).Times(1), - ) - - mockedFilterManager.EXPECT().AddIPs(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil).AnyTimes() - mockedFilterManager.EXPECT().DeleteIPs(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil).AnyTimes() - - err = a.Refresh(ctx) - require.NoError(t, err, "Expected no error when refreshing the cache") -} diff --git a/pkg/watchers/apiserver/types.go b/pkg/watchers/apiserver/types.go index bf8952ab13..af0f7cf01c 100644 --- a/pkg/watchers/apiserver/types.go +++ b/pkg/watchers/apiserver/types.go @@ -3,13 +3,49 @@ package apiserver -import "context" +import ( + "context" + "net" + "time" + + "github.com/microsoft/retina/pkg/log" + fm "github.com/microsoft/retina/pkg/managers/filtermanager" +) //go:generate go run go.uber.org/mock/mockgen@v0.4.0 -source=types.go -destination=mocks/mock_types.go -package=mocks . type IHostResolver interface { LookupHost(context context.Context, host string) ([]string, error) } +const ( + watcherName = "apiserver-watcher" + filterManagerRetries = 3 + defaultRefreshRate = 30 * time.Second +) + +type Watcher struct { + l *log.ZapLogger + current cache + new cache + apiServerURL string + hostResolver IHostResolver + filtermanager fm.IFilterManager + refreshRate time.Duration +} + +// NewWatcher creates a new apiserver watcher. 
+func NewWatcher() *Watcher { + w := &Watcher{ + l: log.Logger().Named(watcherName), + current: make(cache), + apiServerURL: getHostURL(), + hostResolver: net.DefaultResolver, + refreshRate: defaultRefreshRate, + } + w.filtermanager = w.getFilterManager() + return w +} + // define cache as a set type cache map[string]struct{} diff --git a/pkg/watchers/endpoint/endpoint.go b/pkg/watchers/endpoint/endpoint.go deleted file mode 100644 index 1a90ea451f..0000000000 --- a/pkg/watchers/endpoint/endpoint.go +++ /dev/null @@ -1,104 +0,0 @@ -// Copyright (c) Microsoft Corporation. -// Licensed under the MIT license. - -package endpoint - -import ( - "context" - - "github.com/microsoft/retina/pkg/common" - "github.com/microsoft/retina/pkg/log" - "github.com/microsoft/retina/pkg/pubsub" - "go.uber.org/zap" -) - -type EndpointWatcher struct { - isRunning bool - l *log.ZapLogger - current cache - new cache - p pubsub.PubSubInterface -} - -var e *EndpointWatcher - -// NewEndpointWatcher creates a new endpoint watcher. -func Watcher() *EndpointWatcher { - if e == nil { - e = &EndpointWatcher{ - isRunning: false, - l: log.Logger().Named("endpoint-watcher"), - p: pubsub.New(), - current: make(cache), - } - } - - return e -} - -func (e *EndpointWatcher) Init(ctx context.Context) error { - if e.isRunning { - e.l.Info("endpoint watcher is already running") - return nil - } - e.isRunning = true - return nil -} - -func (e *EndpointWatcher) Stop(ctx context.Context) error { - if !e.isRunning { - e.l.Info("endpoint watcher is not running") - return nil - } - e.isRunning = false - return nil -} - -func (e *EndpointWatcher) Refresh(ctx context.Context) error { - // initNewCache is OS specific. - // Based on GOOS, will be implemented by either endpoint_linux, or - // endpoint_windows. - err := e.initNewCache() - if err != nil { - return err - } - - // Compare the new veths with the old ones. - created, deleted := e.diffCache() - - // Publish the new veths. 
- for _, v := range created { - e.l.Debug("Endpoint created", zap.Any("veth", v)) - e.p.Publish(common.PubSubEndpoints, NewEndpointEvent(EndpointCreated, v)) - } - - // Publish the deleted veths. - for _, v := range deleted { - e.l.Debug("Endpoint deleted", zap.Any("veth", v)) - e.p.Publish(common.PubSubEndpoints, NewEndpointEvent(EndpointDeleted, v)) - } - - // Update the cache and reset the new cache. - e.current = e.new.deepcopy() - e.new = nil - - return nil -} - -// Function to differentiate between two caches. -func (e *EndpointWatcher) diffCache() (created, deleted []interface{}) { - // Check if there are any new veths. - for k, v := range e.new { - if _, ok := e.current[k]; !ok { - created = append(created, v) - } - } - - // Check if there are any deleted veths. - for k, v := range e.current { - if _, ok := e.new[k]; !ok { - deleted = append(deleted, v) - } - } - return -} diff --git a/pkg/watchers/endpoint/endpoint_linux.go b/pkg/watchers/endpoint/endpoint_linux.go deleted file mode 100644 index 40a601e1f6..0000000000 --- a/pkg/watchers/endpoint/endpoint_linux.go +++ /dev/null @@ -1,54 +0,0 @@ -// Copyright (c) Microsoft Corporation. -// Licensed under the MIT license. - -package endpoint - -import ( - "github.com/vishvananda/netlink" -) - -var showLink = netlink.LinkList - -func (e *EndpointWatcher) initNewCache() error { - veths, err := listVeths() - if err != nil { - return err - } - - // Reset new cache. - e.new = make(cache) - for _, veth := range veths { - k := key{ - name: veth.Attrs().Name, - hardwareAddr: veth.Attrs().HardwareAddr.String(), - netNsID: veth.Attrs().NetNsID, - } - - e.new[k] = *veth.Attrs() - } - - return nil -} - -// Helper functions. - -// Get all the veth interfaces. 
-// Similar to ip link show type veth -func listVeths() ([]netlink.Link, error) { - links, err := showLink() - if err != nil { - return nil, err - } - - var veths []netlink.Link - for _, link := range links { - // Ref: https://github.com/vishvananda/netlink/blob/ced5aaba43e3f25bb5f04860641d3e3dd04a8544/link.go#L367 - // Unfortunately, there is no type/constant defined for "veth" in the netlink package. - // Version of netlink tested - https://github.com/vishvananda/netlink/tree/v1.2.1-beta.2 - if link.Type() == "veth" { - veths = append(veths, link) - } - } - - return veths, nil -} diff --git a/pkg/watchers/endpoint/endpoint_linux_test.go b/pkg/watchers/endpoint/endpoint_linux_test.go deleted file mode 100644 index a9fbc439e2..0000000000 --- a/pkg/watchers/endpoint/endpoint_linux_test.go +++ /dev/null @@ -1,228 +0,0 @@ -// Copyright (c) Microsoft Corporation. -// Licensed under the MIT license. -// nolint - -package endpoint - -import ( - "context" - "errors" - "net" - "testing" - - "github.com/microsoft/retina/pkg/log" - "github.com/microsoft/retina/pkg/pubsub" - "github.com/stretchr/testify/assert" - "github.com/vishvananda/netlink" -) - -func TestGetWatcher(t *testing.T) { - log.SetupZapLogger(log.GetDefaultLogOpts()) - - v := Watcher() - assert.NotNil(t, v) - - v_again := Watcher() - assert.Equal(t, v, v_again, "Expected the same veth watcher instance") -} - -func TestEndpointWatcherStart(t *testing.T) { - log.SetupZapLogger(log.GetDefaultLogOpts()) - c := context.Background() - - // When veth is already running. - v := &EndpointWatcher{ - isRunning: true, - l: log.Logger().Named("veth-watcher"), - } - err := v.Init(c) - assert.NoError(t, err, "Expected no error when starting a running veth watcher") - assert.Equal(t, true, v.isRunning, "Expected veth watcher to be running") - - // When veth is not running. 
- v.isRunning = false - err = v.Init(c) - assert.NoError(t, err, "Expected no error when starting a stopped veth watcher") - assert.Equal(t, true, v.isRunning, "Expected veth watcher to be running") - - // Stop the watcher. - err = v.Stop(c) - assert.NoError(t, err, "Expected no error when stopping a running veth watcher") - - // Restart the watcher. - err = v.Init(c) - assert.NoError(t, err, "Expected no error when starting a stopped veth watcher") - assert.Equal(t, true, v.isRunning, "Expected veth watcher to be running") - - // Stop the watcher. - err = v.Stop(c) - assert.NoError(t, err, "Expected no error when stopping a running veth watcher") -} - -func TestEndpointWatcherStop(t *testing.T) { - log.SetupZapLogger(log.GetDefaultLogOpts()) - c := context.Background() - - // When veth is already stopped. - v := &EndpointWatcher{ - isRunning: false, - l: log.Logger().Named("veth-watcher"), - } - err := v.Stop(c) - assert.NoError(t, err, "Expected no error when stopping a stopped veth watcher") - assert.Equal(t, false, v.isRunning, "Expected veth watcher to be stopped") - - // Start the watcher. - err = v.Init(c) - assert.NoError(t, err, "Expected no error when starting a stopped veth watcher") - - // Stop the watcher. 
- err = v.Stop(c) - assert.NoError(t, err, "Expected no error when stopping a running veth watcher") - assert.Equal(t, false, v.isRunning, "Expected veth watcher to be stopped") -} - -func TestRun(t *testing.T) { - showLink = func() ([]netlink.Link, error) { - return []netlink.Link{ - &netlink.Veth{ - LinkAttrs: netlink.LinkAttrs{ - Name: "veth0", - }, - }, - &netlink.Vxlan{ - LinkAttrs: netlink.LinkAttrs{ - Name: "eth0", - }, - }, - }, nil - } - - links, err := listVeths() - assert.NoError(t, err, "Expected no error when listing veths") - assert.Equal(t, 1, len(links), "Expected to find 1 veth") - assert.Equal(t, "veth0", links[0].Attrs().Name, "Expected to find veth0") -} - -func TestDiffCache(t *testing.T) { - old := cache{ - key{ - name: "veth0", - hardwareAddr: "00:00:00:00:00:00", - netNsID: 0, - }: netlink.LinkAttrs{ - Name: "veth0", - }, - } - new := cache{ - key{ - name: "veth1", - hardwareAddr: "00:00:00:00:00:FF", - netNsID: 1, - }: netlink.LinkAttrs{ - Name: "veth1", - }, - } - e := &EndpointWatcher{current: old, new: new} - c, d := e.diffCache() - assert.Equal(t, 1, len(c), "Expected to find 1 created veth") - assert.Equal(t, 1, len(d), "Expected to find 1 deleted veth") - assert.Equal(t, "veth1", c[0].(netlink.LinkAttrs).Name, "Expected to find veth1") - assert.Equal(t, "veth0", d[0].(netlink.LinkAttrs).Name, "Expected to find veth0") -} - -func TestRefreshAndCallback(t *testing.T) { - log.SetupZapLogger(log.GetDefaultLogOpts()) - c := context.Background() - - showLink = func() ([]netlink.Link, error) { - return []netlink.Link{ - &netlink.Veth{ - LinkAttrs: netlink.LinkAttrs{ - Name: "veth0", - HardwareAddr: func() net.HardwareAddr { - mac, _ := net.ParseMAC("00:00:00:00:00:00") - return mac - }(), - NetNsID: 0, - }, - }, - &netlink.Veth{ - LinkAttrs: netlink.LinkAttrs{ - Name: "veth1", - HardwareAddr: func() net.HardwareAddr { - mac, _ := net.ParseMAC("00:00:00:00:00:01") - return mac - }(), - NetNsID: 1, - }, - }, - }, nil - } - - cache := 
make(cache) - cache[key{ - name: "veth2", - hardwareAddr: "00:00:00:00:00:02", - netNsID: 2, - }] = &netlink.Veth{ - LinkAttrs: netlink.LinkAttrs{ - Name: "veth2", - }, - } - - v := &EndpointWatcher{ - isRunning: true, - current: cache, - l: log.Logger().Named("veth-watcher"), - p: pubsub.New(), - } - - // When cache is empty. - assert.Equal(t, 1, len(v.current), "Expected to find 0 veths") - - // Post refresh. - err := v.Refresh(c) - assert.NoError(t, err, "Expected no error when refreshing veth cache") - assert.Equal(t, 2, len(v.current), "Expected to find 2 veths") - assert.Equal(t, "veth0", v.current[key{ - name: "veth0", - hardwareAddr: "00:00:00:00:00:00", - netNsID: 0, - }].(netlink.LinkAttrs).Name, "Expected to find veth0") - assert.Equal(t, "veth1", v.current[key{ - name: "veth1", - hardwareAddr: "00:00:00:00:00:01", - netNsID: 1, - }].(netlink.LinkAttrs).Name, "Expected to find veth1") -} - -func TestRefreshError(t *testing.T) { - log.SetupZapLogger(log.GetDefaultLogOpts()) - c := context.Background() - - showLink = func() ([]netlink.Link, error) { - return nil, errors.New("error") - } - - v := &EndpointWatcher{ - isRunning: true, - current: make(cache), - l: log.Logger().Named("veth-watcher"), - p: pubsub.New(), - } - - err := v.Refresh(c) - assert.Error(t, err, "Expected an error when refreshing veth cache") -} - -func TestListVethsError(t *testing.T) { - log.SetupZapLogger(log.GetDefaultLogOpts()) - - showLink = func() ([]netlink.Link, error) { - return nil, errors.New("error") - } - - _, err := listVeths() - assert.Error(t, err, "Expected an error when listing veths") -} diff --git a/pkg/watchers/endpoint/endpoint_windows.go b/pkg/watchers/endpoint/endpoint_windows.go deleted file mode 100644 index 0a12ff9694..0000000000 --- a/pkg/watchers/endpoint/endpoint_windows.go +++ /dev/null @@ -1,18 +0,0 @@ -// Copyright (c) Microsoft Corporation. -// Licensed under the MIT license. 
- -/* Template */ - -package endpoint - -import ( - "github.com/Microsoft/hcsshim/hcn" -) - -func (e *EndpointWatcher) initNewCache() error { - return nil -} - -func listVeths() ([]hcn.HostComputeEndpoint, error) { - return nil, nil -} diff --git a/pkg/watchers/endpoint/types.go b/pkg/watchers/endpoint/types.go index 693a4c79a7..652f1eb192 100644 --- a/pkg/watchers/endpoint/types.go +++ b/pkg/watchers/endpoint/types.go @@ -3,11 +3,32 @@ package endpoint +import ( + "github.com/microsoft/retina/pkg/log" + "github.com/microsoft/retina/pkg/pubsub" +) + const ( + watcherName = "endpoint-watcher" endpointCreated string = "endpoint_created" endpointDeleted string = "endpoint_deleted" ) +type Watcher struct { + l *log.ZapLogger + p pubsub.PubSubInterface +} + +// NewWatcher creates a new endpoint watcher. +func NewWatcher() *Watcher { + w := &Watcher{ + l: log.Logger().Named(watcherName), + p: pubsub.New(), + } + + return w +} + type key struct { name string hardwareAddr string diff --git a/pkg/watchers/endpoint/watcher_linux.go b/pkg/watchers/endpoint/watcher_linux.go new file mode 100644 index 0000000000..381da9ecb6 --- /dev/null +++ b/pkg/watchers/endpoint/watcher_linux.go @@ -0,0 +1,64 @@ +package endpoint + +import ( + "context" + "syscall" + + "github.com/microsoft/retina/pkg/common" + "github.com/pkg/errors" + "github.com/vishvananda/netlink" + "go.uber.org/zap" +) + +func (w *Watcher) Name() string { + return watcherName +} + +func (w *Watcher) Start(ctx context.Context) error { + w.l.Info("endpoint watcher started") + + // Create a channel to receive netlink events. + netlinkEvCh := make(chan netlink.LinkUpdate) + done := make(chan struct{}) + // Options for subscribing to link updates. We want to list existing links. + opt := netlink.LinkSubscribeOptions{ + ListExisting: true, + } + // Subscribe to link updates. 
+ if err := netlink.LinkSubscribeWithOptions(netlinkEvCh, done, opt); err != nil { + return errors.Wrap(err, "failed to subscribe to link updates") + } + defer close(done) + + for { + select { + case <-ctx.Done(): + w.l.Info("stopping endpoint watcher") + return nil + case ev := <-netlinkEvCh: + // Filter for veth devices. + if ev.Link.Type() == "veth" { + veth := ev.Link.(*netlink.Veth) + switch ev.Header.Type { + case syscall.RTM_NEWLINK: + // Check if the veth device is up. + if veth.Attrs().OperState == netlink.OperUp { + w.l.Info("veth device is up", zap.String("veth", veth.Attrs().Name)) + w.p.Publish(common.PubSubEndpoints, NewEndpointEvent(EndpointCreated, *veth.Attrs())) + } + case syscall.RTM_DELLINK: + // Check if the veth device is down. + if veth.Attrs().OperState == netlink.OperDown { + w.l.Info("veth device is down", zap.String("veth", veth.Attrs().Name)) + w.p.Publish(common.PubSubEndpoints, NewEndpointEvent(EndpointDeleted, *veth.Attrs())) + } + } + } + } + } +} + +func (w *Watcher) Stop(_ context.Context) error { + w.l.Info("stopping veth watcher") + return nil +} diff --git a/pkg/watchers/endpoint/watcher_windows.go b/pkg/watchers/endpoint/watcher_windows.go new file mode 100644 index 0000000000..2116ed52a4 --- /dev/null +++ b/pkg/watchers/endpoint/watcher_windows.go @@ -0,0 +1,24 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. 
+ +/* Template */ + +package endpoint + +import ( + "context" +) + +func (w *Watcher) Name() string { + return watcherName +} + +func (w *Watcher) Start(_ context.Context) error { + // Not implemented for Windows + return nil +} + +func (w *Watcher) Stop(_ context.Context) error { + // Not implemented for Windows + return nil +} diff --git a/test/managers/filtermanager/main.go b/test/managers/filtermanager/main.go index 7de2a9fd89..8e01b5fd7c 100644 --- a/test/managers/filtermanager/main.go +++ b/test/managers/filtermanager/main.go @@ -19,7 +19,9 @@ import ( "github.com/microsoft/retina/pkg/metrics" "github.com/microsoft/retina/pkg/watchers/apiserver" "github.com/microsoft/retina/pkg/watchers/endpoint" + "github.com/pkg/errors" "go.uber.org/zap" + "golang.org/x/sync/errgroup" ) func main() { @@ -34,17 +36,23 @@ func main() { metrics.InitializeMetrics() - ctx := context.Background() + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() // watcher manager wm := watchermanager.NewWatcherManager() - wm.Watchers = []watchermanager.IWatcher{endpoint.Watcher(), apiserver.Watcher()} + wm.Watchers = []watchermanager.Watcher{endpoint.NewWatcher(), apiserver.NewWatcher()} - err := wm.Start(ctx) - if err != nil { - l.Error("Failed to start endpoint watcher", zap.Error(err)) - panic(err) - } + g, ctx := errgroup.WithContext(ctx) + // Start watcher manager + g.Go(func() error { + err := wm.Start(ctx) + if err != nil { + l.Error("watcher manager exited with error", zap.Error(err)) + return errors.Wrap(err, "watcher manager exited with error") + } + return nil + }) defer func() { if err := wm.Stop(ctx); err != nil { l.Error("Failed to stop endpoint watcher", zap.Error(err)) diff --git a/test/plugin/packetparser/main_linux.go b/test/plugin/packetparser/main_linux.go index 2596dd1fe7..1515344875 100644 --- a/test/plugin/packetparser/main_linux.go +++ b/test/plugin/packetparser/main_linux.go @@ -33,7 +33,7 @@ func main() { // watcher manager wm := 
watchermanager.NewWatcherManager() - wm.Watchers = []watchermanager.IWatcher{endpoint.Watcher()} + wm.Watchers = []watchermanager.Watcher{endpoint.NewWatcher()} err := wm.Start(ctxTimeout) if err != nil { diff --git a/test/watchers/apiserver/main.go b/test/watchers/apiserver/main.go index 4703e12922..fc0c7801be 100644 --- a/test/watchers/apiserver/main.go +++ b/test/watchers/apiserver/main.go @@ -36,7 +36,7 @@ func main() { }() // watcher manager wm := watchermanager.NewWatcherManager() - wm.Watchers = []watchermanager.IWatcher{apiserver.Watcher()} + wm.Watchers = []watchermanager.Watcher{apiserver.NewWatcher()} // apiserver watcher. err = wm.Start(ctx) diff --git a/test/watchers/veth/main.go b/test/watchers/veth/main.go index 50a00c1c01..8a5a5cae38 100644 --- a/test/watchers/veth/main.go +++ b/test/watchers/veth/main.go @@ -24,7 +24,7 @@ func main() { // watcher manager wm := watchermanager.NewWatcherManager() - wm.Watchers = []watchermanager.IWatcher{endpoint.Watcher()} + wm.Watchers = []watchermanager.Watcher{endpoint.NewWatcher()} err := wm.Start(ctx) if err != nil {