Skip to content

Commit 7b09d98

Browse files
authored
chore: add /groups endpoint to filter by organization and/or member (coder#14260)
* chore: merge get groups sql queries into 1 * Add endpoint for fetching groups with filters * remove 2 ways to customizing a fake authorizer
1 parent 83ccdaa commit 7b09d98

File tree

24 files changed

+537
-287
lines changed

24 files changed

+537
-287
lines changed

coderd/apidoc/docs.go

+44
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

coderd/apidoc/swagger.json

+40
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

coderd/coderdtest/authorize.go

+17-5
Original file line numberDiff line numberDiff line change
@@ -353,16 +353,28 @@ func (s *PreparedRecorder) CompileToSQL(ctx context.Context, cfg regosql.Convert
353353
return s.prepped.CompileToSQL(ctx, cfg)
354354
}
355355

356-
// FakeAuthorizer is an Authorizer that always returns the same error.
356+
// FakeAuthorizer is an Authorizer that will return an error based on the
357+
// "ConditionalReturn" function. By default, **no error** is returned.
358+
// Meaning 'FakeAuthorizer' by default will never return "unauthorized".
357359
type FakeAuthorizer struct {
358-
// AlwaysReturn is the error that will be returned by Authorize.
359-
AlwaysReturn error
360+
ConditionalReturn func(context.Context, rbac.Subject, policy.Action, rbac.Object) error
360361
}
361362

362363
var _ rbac.Authorizer = (*FakeAuthorizer)(nil)
363364

364-
func (d *FakeAuthorizer) Authorize(_ context.Context, _ rbac.Subject, _ policy.Action, _ rbac.Object) error {
365-
return d.AlwaysReturn
365+
// AlwaysReturn is the error that will be returned by Authorize.
366+
func (d *FakeAuthorizer) AlwaysReturn(err error) *FakeAuthorizer {
367+
d.ConditionalReturn = func(_ context.Context, _ rbac.Subject, _ policy.Action, _ rbac.Object) error {
368+
return err
369+
}
370+
return d
371+
}
372+
373+
func (d *FakeAuthorizer) Authorize(ctx context.Context, subject rbac.Subject, action policy.Action, object rbac.Object) error {
374+
if d.ConditionalReturn != nil {
375+
return d.ConditionalReturn(ctx, subject, action, object)
376+
}
377+
return nil
366378
}
367379

368380
func (d *FakeAuthorizer) Prepare(_ context.Context, subject rbac.Subject, action policy.Action, _ string) (rbac.PreparedAuthorized, error) {

coderd/database/dbauthz/dbauthz.go

+8-11
Original file line numberDiff line numberDiff line change
@@ -1491,19 +1491,16 @@ func (q *querier) GetGroupMembersCountByGroupID(ctx context.Context, groupID uui
14911491
return memberCount, nil
14921492
}
14931493

1494-
func (q *querier) GetGroups(ctx context.Context) ([]database.Group, error) {
1495-
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceSystem); err != nil {
1496-
return nil, err
1494+
func (q *querier) GetGroups(ctx context.Context, arg database.GetGroupsParams) ([]database.Group, error) {
1495+
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceSystem); err == nil {
1496+
// Optimize this query for system users as it is used in telemetry.
1497+
// Calling authz on all groups in a deployment for telemetry jobs is
1498+
// excessive. Most user calls should have some filtering applied to reduce
1499+
// the size of the set.
1500+
return q.db.GetGroups(ctx, arg)
14971501
}
1498-
return q.db.GetGroups(ctx)
1499-
}
1500-
1501-
func (q *querier) GetGroupsByOrganizationAndUserID(ctx context.Context, arg database.GetGroupsByOrganizationAndUserIDParams) ([]database.Group, error) {
1502-
return fetchWithPostFilter(q.auth, policy.ActionRead, q.db.GetGroupsByOrganizationAndUserID)(ctx, arg)
1503-
}
15041502

1505-
func (q *querier) GetGroupsByOrganizationID(ctx context.Context, organizationID uuid.UUID) ([]database.Group, error) {
1506-
return fetchWithPostFilter(q.auth, policy.ActionRead, q.db.GetGroupsByOrganizationID)(ctx, organizationID)
1503+
return fetchWithPostFilter(q.auth, policy.ActionRead, q.db.GetGroups)(ctx, arg)
15071504
}
15081505

15091506
func (q *querier) GetHealthSettings(ctx context.Context) (string, error) {

coderd/database/dbauthz/dbauthz_test.go

+19-12
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ func TestInTX(t *testing.T) {
8181

8282
db := dbmem.New()
8383
q := dbauthz.New(db, &coderdtest.RecordingAuthorizer{
84-
Wrapped: &coderdtest.FakeAuthorizer{AlwaysReturn: xerrors.New("custom error")},
84+
Wrapped: (&coderdtest.FakeAuthorizer{}).AlwaysReturn(xerrors.New("custom error")),
8585
}, slog.Make(), coderdtest.AccessControlStorePointer())
8686
actor := rbac.Subject{
8787
ID: uuid.NewString(),
@@ -110,7 +110,7 @@ func TestNew(t *testing.T) {
110110
db = dbmem.New()
111111
exp = dbgen.Workspace(t, db, database.Workspace{})
112112
rec = &coderdtest.RecordingAuthorizer{
113-
Wrapped: &coderdtest.FakeAuthorizer{AlwaysReturn: nil},
113+
Wrapped: &coderdtest.FakeAuthorizer{},
114114
}
115115
subj = rbac.Subject{}
116116
ctx = dbauthz.As(context.Background(), rbac.Subject{})
@@ -135,7 +135,7 @@ func TestNew(t *testing.T) {
135135
func TestDBAuthzRecursive(t *testing.T) {
136136
t.Parallel()
137137
q := dbauthz.New(dbmem.New(), &coderdtest.RecordingAuthorizer{
138-
Wrapped: &coderdtest.FakeAuthorizer{AlwaysReturn: nil},
138+
Wrapped: &coderdtest.FakeAuthorizer{},
139139
}, slog.Make(), coderdtest.AccessControlStorePointer())
140140
actor := rbac.Subject{
141141
ID: uuid.NewString(),
@@ -342,18 +342,21 @@ func (s *MethodTestSuite) TestGroup() {
342342
dbgen.GroupMember(s.T(), db, database.GroupMemberTable{GroupID: g.ID, UserID: u.ID})
343343
check.Asserts(rbac.ResourceSystem, policy.ActionRead)
344344
}))
345-
s.Run("GetGroups", s.Subtest(func(db database.Store, check *expects) {
345+
s.Run("System/GetGroups", s.Subtest(func(db database.Store, check *expects) {
346346
_ = dbgen.Group(s.T(), db, database.Group{})
347-
check.Asserts(rbac.ResourceSystem, policy.ActionRead)
347+
check.Args(database.GetGroupsParams{}).
348+
Asserts(rbac.ResourceSystem, policy.ActionRead)
348349
}))
349-
s.Run("GetGroupsByOrganizationAndUserID", s.Subtest(func(db database.Store, check *expects) {
350+
s.Run("GetGroups", s.Subtest(func(db database.Store, check *expects) {
350351
g := dbgen.Group(s.T(), db, database.Group{})
351352
u := dbgen.User(s.T(), db, database.User{})
352353
gm := dbgen.GroupMember(s.T(), db, database.GroupMemberTable{GroupID: g.ID, UserID: u.ID})
353-
check.Args(database.GetGroupsByOrganizationAndUserIDParams{
354+
check.Args(database.GetGroupsParams{
354355
OrganizationID: g.OrganizationID,
355-
UserID: gm.UserID,
356-
}).Asserts(g, policy.ActionRead)
356+
HasMemberID: gm.UserID,
357+
}).Asserts(rbac.ResourceSystem, policy.ActionRead, g, policy.ActionRead).
358+
// Fail the system resource skip
359+
FailSystemObjectChecks()
357360
}))
358361
s.Run("InsertAllUsersGroup", s.Subtest(func(db database.Store, check *expects) {
359362
o := dbgen.Organization(s.T(), db, database.Organization{})
@@ -597,12 +600,16 @@ func (s *MethodTestSuite) TestLicense() {
597600
}
598601

599602
func (s *MethodTestSuite) TestOrganization() {
600-
s.Run("GetGroupsByOrganizationID", s.Subtest(func(db database.Store, check *expects) {
603+
s.Run("ByOrganization/GetGroups", s.Subtest(func(db database.Store, check *expects) {
601604
o := dbgen.Organization(s.T(), db, database.Organization{})
602605
a := dbgen.Group(s.T(), db, database.Group{OrganizationID: o.ID})
603606
b := dbgen.Group(s.T(), db, database.Group{OrganizationID: o.ID})
604-
check.Args(o.ID).Asserts(a, policy.ActionRead, b, policy.ActionRead).
605-
Returns([]database.Group{a, b})
607+
check.Args(database.GetGroupsParams{
608+
OrganizationID: o.ID,
609+
}).Asserts(rbac.ResourceSystem, policy.ActionRead, a, policy.ActionRead, b, policy.ActionRead).
610+
Returns([]database.Group{a, b}).
611+
// Fail the system check shortcut
612+
FailSystemObjectChecks()
606613
}))
607614
s.Run("GetOrganizationByID", s.Subtest(func(db database.Store, check *expects) {
608615
o := dbgen.Organization(s.T(), db, database.Organization{})

coderd/database/dbauthz/setup_test.go

+27-7
Original file line numberDiff line numberDiff line change
@@ -114,9 +114,7 @@ func (s *MethodTestSuite) Subtest(testCaseF func(db database.Store, check *expec
114114
s.methodAccounting[methodName]++
115115

116116
db := dbmem.New()
117-
fakeAuthorizer := &coderdtest.FakeAuthorizer{
118-
AlwaysReturn: nil,
119-
}
117+
fakeAuthorizer := &coderdtest.FakeAuthorizer{}
120118
rec := &coderdtest.RecordingAuthorizer{
121119
Wrapped: fakeAuthorizer,
122120
}
@@ -174,7 +172,11 @@ func (s *MethodTestSuite) Subtest(testCaseF func(db database.Store, check *expec
174172
// Always run
175173
s.Run("Success", func() {
176174
rec.Reset()
177-
fakeAuthorizer.AlwaysReturn = nil
175+
if testCase.successAuthorizer != nil {
176+
fakeAuthorizer.ConditionalReturn = testCase.successAuthorizer
177+
} else {
178+
fakeAuthorizer.AlwaysReturn(nil)
179+
}
178180

179181
outputs, err := callMethod(ctx)
180182
if testCase.err == nil {
@@ -232,7 +234,7 @@ func (s *MethodTestSuite) NoActorErrorTest(callMethod func(ctx context.Context)
232234
// Asserts that the error returned is a NotAuthorizedError.
233235
func (s *MethodTestSuite) NotAuthorizedErrorTest(ctx context.Context, az *coderdtest.FakeAuthorizer, testCase expects, callMethod func(ctx context.Context) ([]reflect.Value, error)) {
234236
s.Run("NotAuthorized", func() {
235-
az.AlwaysReturn = rbac.ForbiddenWithInternal(xerrors.New("Always fail authz"), rbac.Subject{}, "", rbac.Object{}, nil)
237+
az.AlwaysReturn(rbac.ForbiddenWithInternal(xerrors.New("Always fail authz"), rbac.Subject{}, "", rbac.Object{}, nil))
236238

237239
// If we have assertions, that means the method should FAIL
238240
// if RBAC will disallow the request. The returned error should
@@ -257,8 +259,8 @@ func (s *MethodTestSuite) NotAuthorizedErrorTest(ctx context.Context, az *coderd
257259
// Pass in a canceled context
258260
ctx, cancel := context.WithCancel(ctx)
259261
cancel()
260-
az.AlwaysReturn = rbac.ForbiddenWithInternal(&topdown.Error{Code: topdown.CancelErr},
261-
rbac.Subject{}, "", rbac.Object{}, nil)
262+
az.AlwaysReturn(rbac.ForbiddenWithInternal(&topdown.Error{Code: topdown.CancelErr},
263+
rbac.Subject{}, "", rbac.Object{}, nil))
262264

263265
// If we have assertions, that means the method should FAIL
264266
// if RBAC will disallow the request. The returned error should
@@ -324,6 +326,7 @@ type expects struct {
324326
// instead.
325327
notAuthorizedExpect string
326328
cancelledCtxExpect string
329+
successAuthorizer func(ctx context.Context, subject rbac.Subject, action policy.Action, obj rbac.Object) error
327330
}
328331

329332
// Asserts is required. Asserts the RBAC authorize calls that should be made.
@@ -354,6 +357,23 @@ func (m *expects) Errors(err error) *expects {
354357
return m
355358
}
356359

360+
func (m *expects) FailSystemObjectChecks() *expects {
361+
return m.WithSuccessAuthorizer(func(ctx context.Context, subject rbac.Subject, action policy.Action, obj rbac.Object) error {
362+
if obj.Type == rbac.ResourceSystem.Type {
363+
return xerrors.Errorf("hard coded system authz failed")
364+
}
365+
return nil
366+
})
367+
}
368+
369+
// WithSuccessAuthorizer is helpful when an optimization authz check is made
370+
// to skip some RBAC checks. This check in testing would prevent the ability
371+
// to assert the more nuanced RBAC checks.
372+
func (m *expects) WithSuccessAuthorizer(f func(ctx context.Context, subject rbac.Subject, action policy.Action, obj rbac.Object) error) *expects {
373+
m.successAuthorizer = f
374+
return m
375+
}
376+
357377
func (m *expects) WithNotAuthorized(contains string) *expects {
358378
m.notAuthorizedExpect = contains
359379
return m

coderd/database/dbmem/dbmem.go

+25-30
Original file line numberDiff line numberDiff line change
@@ -2599,51 +2599,46 @@ func (q *FakeQuerier) GetGroupMembersCountByGroupID(ctx context.Context, groupID
25992599
return int64(len(users)), nil
26002600
}
26012601

2602-
func (q *FakeQuerier) GetGroups(_ context.Context) ([]database.Group, error) {
2603-
q.mutex.RLock()
2604-
defer q.mutex.RUnlock()
2605-
2606-
out := make([]database.Group, len(q.groups))
2607-
copy(out, q.groups)
2608-
return out, nil
2609-
}
2610-
2611-
func (q *FakeQuerier) GetGroupsByOrganizationAndUserID(_ context.Context, arg database.GetGroupsByOrganizationAndUserIDParams) ([]database.Group, error) {
2602+
func (q *FakeQuerier) GetGroups(_ context.Context, arg database.GetGroupsParams) ([]database.Group, error) {
26122603
err := validateDatabaseType(arg)
26132604
if err != nil {
26142605
return nil, err
26152606
}
26162607

26172608
q.mutex.RLock()
26182609
defer q.mutex.RUnlock()
2619-
var groupIDs []uuid.UUID
2620-
for _, member := range q.groupMembers {
2621-
if member.UserID == arg.UserID {
2622-
groupIDs = append(groupIDs, member.GroupID)
2610+
2611+
groupIDs := make(map[uuid.UUID]struct{})
2612+
if arg.HasMemberID != uuid.Nil {
2613+
for _, member := range q.groupMembers {
2614+
if member.UserID == arg.HasMemberID {
2615+
groupIDs[member.GroupID] = struct{}{}
2616+
}
26232617
}
2624-
}
2625-
groups := []database.Group{}
2626-
for _, group := range q.groups {
2627-
if slices.Contains(groupIDs, group.ID) && group.OrganizationID == arg.OrganizationID {
2628-
groups = append(groups, group)
2618+
2619+
// Handle the everyone group
2620+
for _, orgMember := range q.organizationMembers {
2621+
if orgMember.UserID == arg.HasMemberID {
2622+
groupIDs[orgMember.OrganizationID] = struct{}{}
2623+
}
26292624
}
26302625
}
26312626

2632-
return groups, nil
2633-
}
2634-
2635-
func (q *FakeQuerier) GetGroupsByOrganizationID(_ context.Context, id uuid.UUID) ([]database.Group, error) {
2636-
q.mutex.RLock()
2637-
defer q.mutex.RUnlock()
2638-
2639-
groups := make([]database.Group, 0, len(q.groups))
2627+
filtered := make([]database.Group, 0)
26402628
for _, group := range q.groups {
2641-
if group.OrganizationID == id {
2642-
groups = append(groups, group)
2629+
if arg.OrganizationID != uuid.Nil && group.OrganizationID != arg.OrganizationID {
2630+
continue
2631+
}
2632+
2633+
_, ok := groupIDs[group.ID]
2634+
if arg.HasMemberID != uuid.Nil && !ok {
2635+
continue
26432636
}
2637+
2638+
filtered = append(filtered, group)
26442639
}
26452640

2646-
return groups, nil
2641+
return filtered, nil
26472642
}
26482643

26492644
func (q *FakeQuerier) GetHealthSettings(_ context.Context) (string, error) {

0 commit comments

Comments
 (0)