diff --git a/coderd/database/dbauthz/dbauthz.go b/coderd/database/dbauthz/dbauthz.go index 077d704be1300..d3bbceec14252 100644 --- a/coderd/database/dbauthz/dbauthz.go +++ b/coderd/database/dbauthz/dbauthz.go @@ -1041,6 +1041,10 @@ func (q *querier) DeleteCoordinator(ctx context.Context, id uuid.UUID) error { return q.db.DeleteCoordinator(ctx, id) } +func (q *querier) DeleteCryptoKey(ctx context.Context, arg database.DeleteCryptoKeyParams) (database.CryptoKey, error) { + panic("not implemented") +} + func (q *querier) DeleteCustomRole(ctx context.Context, arg database.DeleteCustomRoleParams) error { if arg.OrganizationID.UUID != uuid.Nil { if err := q.authorizeContext(ctx, policy.ActionDelete, rbac.ResourceAssignOrgRole.InOrg(arg.OrganizationID.UUID)); err != nil { @@ -1383,6 +1387,14 @@ func (q *querier) GetCoordinatorResumeTokenSigningKey(ctx context.Context) (stri return q.db.GetCoordinatorResumeTokenSigningKey(ctx) } +func (q *querier) GetCryptoKeyByFeatureAndSequence(ctx context.Context, arg database.GetCryptoKeyByFeatureAndSequenceParams) (database.CryptoKey, error) { + panic("not implemented") +} + +func (q *querier) GetCryptoKeys(ctx context.Context) ([]database.CryptoKey, error) { + panic("not implemented") +} + func (q *querier) GetDBCryptKeys(ctx context.Context) ([]database.DBCryptKey, error) { if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceSystem); err != nil { return nil, err @@ -1549,6 +1561,10 @@ func (q *querier) GetLastUpdateCheck(ctx context.Context) (string, error) { return q.db.GetLastUpdateCheck(ctx) } +func (q *querier) GetLatestCryptoKeyByFeature(ctx context.Context, feature database.CryptoKeyFeature) (database.CryptoKey, error) { + panic("not implemented") +} + func (q *querier) GetLatestWorkspaceBuildByWorkspaceID(ctx context.Context, workspaceID uuid.UUID) (database.WorkspaceBuild, error) { if _, err := q.GetWorkspaceByID(ctx, workspaceID); err != nil { return database.WorkspaceBuild{}, err @@ -2654,6 +2670,10 @@ func (q *querier) InsertAuditLog(ctx context.Context, arg database.InsertAuditLo return insert(q.log, q.auth, rbac.ResourceAuditLog, q.db.InsertAuditLog)(ctx, arg) } +func (q *querier) InsertCryptoKey(ctx context.Context, arg database.InsertCryptoKeyParams) (database.CryptoKey, error) { + panic("not implemented") +} + func (q *querier) InsertCustomRole(ctx context.Context, arg database.InsertCustomRoleParams) (database.CustomRole, error) { // Org and site role upsert share the same query. So switch the assertion based on the org uuid. if arg.OrganizationID.UUID != uuid.Nil { @@ -3157,6 +3177,10 @@ func (q *querier) UpdateAPIKeyByID(ctx context.Context, arg database.UpdateAPIKe return update(q.log, q.auth, fetch, q.db.UpdateAPIKeyByID)(ctx, arg) } +func (q *querier) UpdateCryptoKeyDeletesAt(ctx context.Context, arg database.UpdateCryptoKeyDeletesAtParams) (database.CryptoKey, error) { + panic("not implemented") +} + func (q *querier) UpdateCustomRole(ctx context.Context, arg database.UpdateCustomRoleParams) (database.CustomRole, error) { if arg.OrganizationID.UUID != uuid.Nil { if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceAssignOrgRole.InOrg(arg.OrganizationID.UUID)); err != nil { diff --git a/coderd/database/dbgen/dbgen.go b/coderd/database/dbgen/dbgen.go index 79aee59d97dbe..06e40287cff29 100644 --- a/coderd/database/dbgen/dbgen.go +++ b/coderd/database/dbgen/dbgen.go @@ -2,6 +2,7 @@ package dbgen import ( "context" + "crypto/rand" "crypto/sha256" "database/sql" "encoding/hex" @@ -893,6 +894,36 @@ func CustomRole(t testing.TB, db database.Store, seed database.CustomRole) datab return role } +func CryptoKey(t testing.TB, db database.Store, seed database.CryptoKey) database.CryptoKey { + t.Helper() + + b := make([]byte, 96) + _, err := rand.Read(b) + require.NoError(t, err, "generate secret") + + key, err := db.InsertCryptoKey(genCtx, database.InsertCryptoKeyParams{ + Sequence: takeFirst(seed.Sequence, 123), + Secret: takeFirst(seed.Secret, sql.NullString{ + String: hex.EncodeToString(b), + Valid: true, + }), + SecretKeyID: takeFirst(seed.SecretKeyID, sql.NullString{}), + Feature: takeFirst(seed.Feature, database.CryptoKeyFeatureWorkspaceApps), + StartsAt: takeFirst(seed.StartsAt, time.Now()), + }) + require.NoError(t, err, "insert crypto key") + + if seed.DeletesAt.Valid { + key, err = db.UpdateCryptoKeyDeletesAt(genCtx, database.UpdateCryptoKeyDeletesAtParams{ + Feature: key.Feature, + Sequence: key.Sequence, + DeletesAt: sql.NullTime{Time: seed.DeletesAt.Time, Valid: true}, + }) + require.NoError(t, err, "update crypto key deletes_at") + } + return key +} + func must[V any](v V, err error) V { if err != nil { panic(err) diff --git a/coderd/database/dbmem/dbmem.go b/coderd/database/dbmem/dbmem.go index ed766d48ecd43..774d9296e51bc 100644 --- a/coderd/database/dbmem/dbmem.go +++ b/coderd/database/dbmem/dbmem.go @@ -153,6 +153,7 @@ type data struct { // New tables workspaceAgentStats []database.WorkspaceAgentStat auditLogs []database.AuditLog + cryptoKeys []database.CryptoKey dbcryptKeys []database.DBCryptKey files []database.File externalAuthLinks []database.ExternalAuthLink @@ -1434,6 +1435,15 @@ func (*FakeQuerier) DeleteCoordinator(context.Context, uuid.UUID) error { return ErrUnimplemented } +func (q *FakeQuerier) DeleteCryptoKey(ctx context.Context, arg database.DeleteCryptoKeyParams) (database.CryptoKey, error) { + err := validateDatabaseType(arg) + if err != nil { + return database.CryptoKey{}, err + } + + panic("not implemented") +} + func (q *FakeQuerier) DeleteCustomRole(_ context.Context, arg database.DeleteCustomRoleParams) error { err := validateDatabaseType(arg) if err != nil { @@ -2309,6 +2319,32 @@ func (q *FakeQuerier) GetCoordinatorResumeTokenSigningKey(_ context.Context) (st return q.coordinatorResumeTokenSigningKey, nil } +func (q *FakeQuerier) GetCryptoKeyByFeatureAndSequence(_ context.Context, arg database.GetCryptoKeyByFeatureAndSequenceParams) (database.CryptoKey, error) { + err := validateDatabaseType(arg) + if err != nil { + return database.CryptoKey{}, err + } + + q.mutex.RLock() + defer q.mutex.RUnlock() + + for _, key := range q.cryptoKeys { + if key.Feature == arg.Feature && key.Sequence == arg.Sequence { + // Keys with NULL secrets are considered deleted. + if key.Secret.Valid { + return key, nil + } + return database.CryptoKey{}, sql.ErrNoRows + } + } + + return database.CryptoKey{}, sql.ErrNoRows +} + +func (q *FakeQuerier) GetCryptoKeys(ctx context.Context) ([]database.CryptoKey, error) { + panic("not implemented") +} + func (q *FakeQuerier) GetDBCryptKeys(_ context.Context) ([]database.DBCryptKey, error) { q.mutex.RLock() defer q.mutex.RUnlock() @@ -2806,6 +2842,10 @@ func (q *FakeQuerier) GetLastUpdateCheck(_ context.Context) (string, error) { return string(q.lastUpdateCheck), nil } +func (q *FakeQuerier) GetLatestCryptoKeyByFeature(ctx context.Context, feature database.CryptoKeyFeature) (database.CryptoKey, error) { + panic("not implemented") +} + func (q *FakeQuerier) GetLatestWorkspaceBuildByWorkspaceID(ctx context.Context, workspaceID uuid.UUID) (database.WorkspaceBuild, error) { q.mutex.RLock() defer q.mutex.RUnlock() @@ -6305,6 +6345,28 @@ func (q *FakeQuerier) InsertAuditLog(_ context.Context, arg database.InsertAudit return alog, nil } +func (q *FakeQuerier) InsertCryptoKey(_ context.Context, arg database.InsertCryptoKeyParams) (database.CryptoKey, error) { + err := validateDatabaseType(arg) + if err != nil { + return database.CryptoKey{}, err + } + + q.mutex.Lock() + defer q.mutex.Unlock() + + key := database.CryptoKey{ + Feature: arg.Feature, + Sequence: arg.Sequence, + Secret: arg.Secret, + SecretKeyID: arg.SecretKeyID, + StartsAt: arg.StartsAt, + } + + q.cryptoKeys = append(q.cryptoKeys, key) + + return key, nil +} + func (q *FakeQuerier) InsertCustomRole(_ context.Context, arg database.InsertCustomRoleParams) (database.CustomRole, error) { err := validateDatabaseType(arg) if err != nil { @@ -7774,6 +7836,15 @@ func (q *FakeQuerier) UpdateAPIKeyByID(_ context.Context, arg database.UpdateAPI return sql.ErrNoRows } +func (q *FakeQuerier) UpdateCryptoKeyDeletesAt(ctx context.Context, arg database.UpdateCryptoKeyDeletesAtParams) (database.CryptoKey, error) { + err := validateDatabaseType(arg) + if err != nil { + return database.CryptoKey{}, err + } + + panic("not implemented") +} + func (q *FakeQuerier) UpdateCustomRole(_ context.Context, arg database.UpdateCustomRoleParams) (database.CustomRole, error) { err := validateDatabaseType(arg) if err != nil { diff --git a/coderd/database/dbmetrics/dbmetrics.go b/coderd/database/dbmetrics/dbmetrics.go index 0ec70c1736d43..bf95ad82896d8 100644 --- a/coderd/database/dbmetrics/dbmetrics.go +++ b/coderd/database/dbmetrics/dbmetrics.go @@ -214,6 +214,13 @@ func (m metricsStore) DeleteCoordinator(ctx context.Context, id uuid.UUID) error return r0 } +func (m metricsStore) DeleteCryptoKey(ctx context.Context, arg database.DeleteCryptoKeyParams) (database.CryptoKey, error) { + start := time.Now() + r0, r1 := m.s.DeleteCryptoKey(ctx, arg) + m.queryLatencies.WithLabelValues("DeleteCryptoKey").Observe(time.Since(start).Seconds()) + return r0, r1 +} + func (m metricsStore) DeleteCustomRole(ctx context.Context, arg database.DeleteCustomRoleParams) error { start := time.Now() r0 := m.s.DeleteCustomRole(ctx, arg) @@ -543,6 +550,20 @@ func (m metricsStore) GetCoordinatorResumeTokenSigningKey(ctx context.Context) ( return r0, r1 } +func (m metricsStore) GetCryptoKeyByFeatureAndSequence(ctx context.Context, arg database.GetCryptoKeyByFeatureAndSequenceParams) (database.CryptoKey, error) { + start := time.Now() + r0, r1 := m.s.GetCryptoKeyByFeatureAndSequence(ctx, arg) + m.queryLatencies.WithLabelValues("GetCryptoKeyByFeatureAndSequence").Observe(time.Since(start).Seconds()) + return r0, r1 +} + +func (m metricsStore) GetCryptoKeys(ctx context.Context) ([]database.CryptoKey, error) { + start := time.Now() + r0, r1 := m.s.GetCryptoKeys(ctx) + m.queryLatencies.WithLabelValues("GetCryptoKeys").Observe(time.Since(start).Seconds()) + return r0, r1 +} + func (m metricsStore) GetDBCryptKeys(ctx context.Context) ([]database.DBCryptKey, error) { start := time.Now() r0, r1 := m.s.GetDBCryptKeys(ctx) @@ -711,6 +732,13 @@ func (m metricsStore) GetLastUpdateCheck(ctx context.Context) (string, error) { return version, err } +func (m metricsStore) GetLatestCryptoKeyByFeature(ctx context.Context, feature database.CryptoKeyFeature) (database.CryptoKey, error) { + start := time.Now() + r0, r1 := m.s.GetLatestCryptoKeyByFeature(ctx, feature) + m.queryLatencies.WithLabelValues("GetLatestCryptoKeyByFeature").Observe(time.Since(start).Seconds()) + return r0, r1 +} + func (m metricsStore) GetLatestWorkspaceBuildByWorkspaceID(ctx context.Context, workspaceID uuid.UUID) (database.WorkspaceBuild, error) { start := time.Now() build, err := m.s.GetLatestWorkspaceBuildByWorkspaceID(ctx, workspaceID) @@ -1593,6 +1621,13 @@ func (m metricsStore) InsertAuditLog(ctx context.Context, arg database.InsertAud return log, err } +func (m metricsStore) InsertCryptoKey(ctx context.Context, arg database.InsertCryptoKeyParams) (database.CryptoKey, error) { + start := time.Now() + key, err := m.s.InsertCryptoKey(ctx, arg) + m.queryLatencies.WithLabelValues("InsertCryptoKey").Observe(time.Since(start).Seconds()) + return key, err +} + func (m metricsStore) InsertCustomRole(ctx context.Context, arg database.InsertCustomRoleParams) (database.CustomRole, error) { start := time.Now() r0, r1 := m.s.InsertCustomRole(ctx, arg) @@ -1992,6 +2027,13 @@ func (m metricsStore) UpdateAPIKeyByID(ctx context.Context, arg database.UpdateA return err } +func (m metricsStore) UpdateCryptoKeyDeletesAt(ctx context.Context, arg database.UpdateCryptoKeyDeletesAtParams) (database.CryptoKey, error) { + start := time.Now() + key, err := m.s.UpdateCryptoKeyDeletesAt(ctx, arg) + m.queryLatencies.WithLabelValues("UpdateCryptoKeyDeletesAt").Observe(time.Since(start).Seconds()) + return key, err +} + func (m metricsStore) UpdateCustomRole(ctx context.Context, arg database.UpdateCustomRoleParams) (database.CustomRole, error) { start := time.Now() r0, r1 := m.s.UpdateCustomRole(ctx, arg) diff --git a/coderd/database/dbmock/dbmock.go b/coderd/database/dbmock/dbmock.go index c5d579e1c2656..0ab399f573bfe 100644 --- a/coderd/database/dbmock/dbmock.go +++ b/coderd/database/dbmock/dbmock.go @@ -317,6 +317,21 @@ func (mr *MockStoreMockRecorder) DeleteCoordinator(arg0, arg1 any) *gomock.Call return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteCoordinator", reflect.TypeOf((*MockStore)(nil).DeleteCoordinator), arg0, arg1) } +// DeleteCryptoKey mocks base method. +func (m *MockStore) DeleteCryptoKey(arg0 context.Context, arg1 database.DeleteCryptoKeyParams) (database.CryptoKey, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "DeleteCryptoKey", arg0, arg1) + ret0, _ := ret[0].(database.CryptoKey) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// DeleteCryptoKey indicates an expected call of DeleteCryptoKey. +func (mr *MockStoreMockRecorder) DeleteCryptoKey(arg0, arg1 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteCryptoKey", reflect.TypeOf((*MockStore)(nil).DeleteCryptoKey), arg0, arg1) +} + // DeleteCustomRole mocks base method. func (m *MockStore) DeleteCustomRole(arg0 context.Context, arg1 database.DeleteCustomRoleParams) error { m.ctrl.T.Helper() @@ -1058,6 +1073,36 @@ func (mr *MockStoreMockRecorder) GetCoordinatorResumeTokenSigningKey(arg0 any) * return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetCoordinatorResumeTokenSigningKey", reflect.TypeOf((*MockStore)(nil).GetCoordinatorResumeTokenSigningKey), arg0) } +// GetCryptoKeyByFeatureAndSequence mocks base method. +func (m *MockStore) GetCryptoKeyByFeatureAndSequence(arg0 context.Context, arg1 database.GetCryptoKeyByFeatureAndSequenceParams) (database.CryptoKey, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetCryptoKeyByFeatureAndSequence", arg0, arg1) + ret0, _ := ret[0].(database.CryptoKey) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetCryptoKeyByFeatureAndSequence indicates an expected call of GetCryptoKeyByFeatureAndSequence. +func (mr *MockStoreMockRecorder) GetCryptoKeyByFeatureAndSequence(arg0, arg1 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetCryptoKeyByFeatureAndSequence", reflect.TypeOf((*MockStore)(nil).GetCryptoKeyByFeatureAndSequence), arg0, arg1) +} + +// GetCryptoKeys mocks base method. +func (m *MockStore) GetCryptoKeys(arg0 context.Context) ([]database.CryptoKey, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetCryptoKeys", arg0) + ret0, _ := ret[0].([]database.CryptoKey) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetCryptoKeys indicates an expected call of GetCryptoKeys. +func (mr *MockStoreMockRecorder) GetCryptoKeys(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetCryptoKeys", reflect.TypeOf((*MockStore)(nil).GetCryptoKeys), arg0) +} + // GetDBCryptKeys mocks base method. func (m *MockStore) GetDBCryptKeys(arg0 context.Context) ([]database.DBCryptKey, error) { m.ctrl.T.Helper() @@ -1418,6 +1463,21 @@ func (mr *MockStoreMockRecorder) GetLastUpdateCheck(arg0 any) *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetLastUpdateCheck", reflect.TypeOf((*MockStore)(nil).GetLastUpdateCheck), arg0) } +// GetLatestCryptoKeyByFeature mocks base method. +func (m *MockStore) GetLatestCryptoKeyByFeature(arg0 context.Context, arg1 database.CryptoKeyFeature) (database.CryptoKey, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetLatestCryptoKeyByFeature", arg0, arg1) + ret0, _ := ret[0].(database.CryptoKey) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetLatestCryptoKeyByFeature indicates an expected call of GetLatestCryptoKeyByFeature. +func (mr *MockStoreMockRecorder) GetLatestCryptoKeyByFeature(arg0, arg1 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetLatestCryptoKeyByFeature", reflect.TypeOf((*MockStore)(nil).GetLatestCryptoKeyByFeature), arg0, arg1) +} + // GetLatestWorkspaceBuildByWorkspaceID mocks base method. func (m *MockStore) GetLatestWorkspaceBuildByWorkspaceID(arg0 context.Context, arg1 uuid.UUID) (database.WorkspaceBuild, error) { m.ctrl.T.Helper() @@ -3352,6 +3412,21 @@ func (mr *MockStoreMockRecorder) InsertAuditLog(arg0, arg1 any) *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InsertAuditLog", reflect.TypeOf((*MockStore)(nil).InsertAuditLog), arg0, arg1) } +// InsertCryptoKey mocks base method. +func (m *MockStore) InsertCryptoKey(arg0 context.Context, arg1 database.InsertCryptoKeyParams) (database.CryptoKey, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "InsertCryptoKey", arg0, arg1) + ret0, _ := ret[0].(database.CryptoKey) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// InsertCryptoKey indicates an expected call of InsertCryptoKey. +func (mr *MockStoreMockRecorder) InsertCryptoKey(arg0, arg1 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InsertCryptoKey", reflect.TypeOf((*MockStore)(nil).InsertCryptoKey), arg0, arg1) +} + // InsertCustomRole mocks base method. func (m *MockStore) InsertCustomRole(arg0 context.Context, arg1 database.InsertCustomRoleParams) (database.CustomRole, error) { m.ctrl.T.Helper() @@ -4204,6 +4279,21 @@ func (mr *MockStoreMockRecorder) UpdateAPIKeyByID(arg0, arg1 any) *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateAPIKeyByID", reflect.TypeOf((*MockStore)(nil).UpdateAPIKeyByID), arg0, arg1) } +// UpdateCryptoKeyDeletesAt mocks base method. +func (m *MockStore) UpdateCryptoKeyDeletesAt(arg0 context.Context, arg1 database.UpdateCryptoKeyDeletesAtParams) (database.CryptoKey, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "UpdateCryptoKeyDeletesAt", arg0, arg1) + ret0, _ := ret[0].(database.CryptoKey) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// UpdateCryptoKeyDeletesAt indicates an expected call of UpdateCryptoKeyDeletesAt. +func (mr *MockStoreMockRecorder) UpdateCryptoKeyDeletesAt(arg0, arg1 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateCryptoKeyDeletesAt", reflect.TypeOf((*MockStore)(nil).UpdateCryptoKeyDeletesAt), arg0, arg1) +} + // UpdateCustomRole mocks base method. func (m *MockStore) UpdateCustomRole(arg0 context.Context, arg1 database.UpdateCustomRoleParams) (database.CustomRole, error) { m.ctrl.T.Helper() diff --git a/coderd/database/dump.sql b/coderd/database/dump.sql index 6638d52745ba6..17fd3511442ec 100644 --- a/coderd/database/dump.sql +++ b/coderd/database/dump.sql @@ -36,6 +36,12 @@ CREATE TYPE build_reason AS ENUM ( 'autodelete' ); +CREATE TYPE crypto_key_feature AS ENUM ( + 'workspace_apps', + 'oidc_convert', + 'peer_reconnect' +); + CREATE TYPE display_app AS ENUM ( 'vscode', 'vscode_insiders', @@ -494,6 +500,15 @@ CREATE TABLE audit_logs ( resource_icon text NOT NULL ); +CREATE TABLE crypto_keys ( + feature crypto_key_feature NOT NULL, + sequence integer NOT NULL, + secret text, + secret_key_id text, + starts_at timestamp with time zone NOT NULL, + deletes_at timestamp with time zone +); + CREATE TABLE custom_roles ( name text NOT NULL, display_name text NOT NULL, @@ -1640,6 +1655,9 @@ ALTER TABLE ONLY api_keys ALTER TABLE ONLY audit_logs ADD CONSTRAINT audit_logs_pkey PRIMARY KEY (id); +ALTER TABLE ONLY crypto_keys + ADD CONSTRAINT crypto_keys_pkey PRIMARY KEY (feature, sequence); + ALTER TABLE ONLY custom_roles ADD CONSTRAINT custom_roles_unique_key UNIQUE (name, organization_id); @@ -2035,6 +2053,9 @@ CREATE TRIGGER update_notification_message_dedupe_hash BEFORE INSERT OR UPDATE O ALTER TABLE ONLY api_keys ADD CONSTRAINT api_keys_user_id_uuid_fkey FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE; +ALTER TABLE ONLY crypto_keys + ADD CONSTRAINT crypto_keys_secret_key_id_fkey FOREIGN KEY (secret_key_id) REFERENCES dbcrypt_keys(active_key_digest); + ALTER TABLE ONLY external_auth_links ADD CONSTRAINT git_auth_links_oauth_access_token_key_id_fkey FOREIGN KEY (oauth_access_token_key_id) REFERENCES dbcrypt_keys(active_key_digest); diff --git a/coderd/database/foreign_key_constraint.go b/coderd/database/foreign_key_constraint.go index 0c578255f091c..6046a94e3bcad 100644 --- a/coderd/database/foreign_key_constraint.go +++ b/coderd/database/foreign_key_constraint.go @@ -7,6 +7,7 @@ type ForeignKeyConstraint string // ForeignKeyConstraint enums. const ( ForeignKeyAPIKeysUserIDUUID ForeignKeyConstraint = "api_keys_user_id_uuid_fkey" // ALTER TABLE ONLY api_keys ADD CONSTRAINT api_keys_user_id_uuid_fkey FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE; + ForeignKeyCryptoKeysSecretKeyID ForeignKeyConstraint = "crypto_keys_secret_key_id_fkey" // ALTER TABLE ONLY crypto_keys ADD CONSTRAINT crypto_keys_secret_key_id_fkey FOREIGN KEY (secret_key_id) REFERENCES dbcrypt_keys(active_key_digest); ForeignKeyGitAuthLinksOauthAccessTokenKeyID ForeignKeyConstraint = "git_auth_links_oauth_access_token_key_id_fkey" // ALTER TABLE ONLY external_auth_links ADD CONSTRAINT git_auth_links_oauth_access_token_key_id_fkey FOREIGN KEY (oauth_access_token_key_id) REFERENCES dbcrypt_keys(active_key_digest); ForeignKeyGitAuthLinksOauthRefreshTokenKeyID ForeignKeyConstraint = "git_auth_links_oauth_refresh_token_key_id_fkey" // ALTER TABLE ONLY external_auth_links ADD CONSTRAINT git_auth_links_oauth_refresh_token_key_id_fkey FOREIGN KEY (oauth_refresh_token_key_id) REFERENCES dbcrypt_keys(active_key_digest); ForeignKeyGitSSHKeysUserID ForeignKeyConstraint = "gitsshkeys_user_id_fkey" // ALTER TABLE ONLY gitsshkeys ADD CONSTRAINT gitsshkeys_user_id_fkey FOREIGN KEY (user_id) REFERENCES users(id); diff --git a/coderd/database/migrations/000250_crypto_keys.down.sql b/coderd/database/migrations/000250_crypto_keys.down.sql new file mode 100644 index 0000000000000..8b0cc3702bcc4 --- /dev/null +++ b/coderd/database/migrations/000250_crypto_keys.down.sql @@ -0,0 +1 @@ +DROP TABLE "keys"; diff --git a/coderd/database/migrations/000250_crypto_keys.up.sql b/coderd/database/migrations/000250_crypto_keys.up.sql new file mode 100644 index 0000000000000..7c1aa7888fdd1 --- /dev/null +++ b/coderd/database/migrations/000250_crypto_keys.up.sql @@ -0,0 +1,16 @@ +CREATE TYPE crypto_key_feature AS ENUM ( + 'workspace_apps', + 'oidc_convert', + 'peer_reconnect' +); + +CREATE TABLE crypto_keys ( + feature crypto_key_feature NOT NULL, + sequence integer NOT NULL, + secret text NULL, + secret_key_id text NULL REFERENCES dbcrypt_keys(active_key_digest), + starts_at timestamptz NOT NULL, + deletes_at timestamptz NULL, + PRIMARY KEY (feature, sequence) +); + diff --git a/coderd/database/modelmethods.go b/coderd/database/modelmethods.go index 816fc4c9214b0..82be5e710c058 100644 --- a/coderd/database/modelmethods.go +++ b/coderd/database/modelmethods.go @@ -447,3 +447,7 @@ func (r GetAuthorizationUserRolesRow) RoleNames() ([]rbac.RoleIdentifier, error) } return names, nil } + +func (k CryptoKey) ExpiresAt(keyDuration time.Duration) time.Time { + return k.StartsAt.Add(keyDuration).UTC() +} diff --git a/coderd/database/models.go b/coderd/database/models.go index 9e0283ba859c1..e9bb8e42b8960 100644 --- a/coderd/database/models.go +++ b/coderd/database/models.go @@ -339,6 +339,67 @@ func AllBuildReasonValues() []BuildReason { } } +type CryptoKeyFeature string + +const ( + CryptoKeyFeatureWorkspaceApps CryptoKeyFeature = "workspace_apps" + CryptoKeyFeatureOidcConvert CryptoKeyFeature = "oidc_convert" + CryptoKeyFeaturePeerReconnect CryptoKeyFeature = "peer_reconnect" +) + +func (e *CryptoKeyFeature) Scan(src interface{}) error { + switch s := src.(type) { + case []byte: + *e = CryptoKeyFeature(s) + case string: + *e = CryptoKeyFeature(s) + default: + return fmt.Errorf("unsupported scan type for CryptoKeyFeature: %T", src) + } + return nil +} + +type NullCryptoKeyFeature struct { + CryptoKeyFeature CryptoKeyFeature `json:"crypto_key_feature"` + Valid bool `json:"valid"` // Valid is true if CryptoKeyFeature is not NULL +} + +// Scan implements the Scanner interface. +func (ns *NullCryptoKeyFeature) Scan(value interface{}) error { + if value == nil { + ns.CryptoKeyFeature, ns.Valid = "", false + return nil + } + ns.Valid = true + return ns.CryptoKeyFeature.Scan(value) +} + +// Value implements the driver Valuer interface. +func (ns NullCryptoKeyFeature) Value() (driver.Value, error) { + if !ns.Valid { + return nil, nil + } + return string(ns.CryptoKeyFeature), nil +} + +func (e CryptoKeyFeature) Valid() bool { + switch e { + case CryptoKeyFeatureWorkspaceApps, + CryptoKeyFeatureOidcConvert, + CryptoKeyFeaturePeerReconnect: + return true + } + return false +} + +func AllCryptoKeyFeatureValues() []CryptoKeyFeature { + return []CryptoKeyFeature{ + CryptoKeyFeatureWorkspaceApps, + CryptoKeyFeatureOidcConvert, + CryptoKeyFeaturePeerReconnect, + } +} + type DisplayApp string const ( @@ -2043,6 +2104,15 @@ type AuditLog struct { ResourceIcon string `db:"resource_icon" json:"resource_icon"` } +type CryptoKey struct { + Feature CryptoKeyFeature `db:"feature" json:"feature"` + Sequence int32 `db:"sequence" json:"sequence"` + Secret sql.NullString `db:"secret" json:"secret"` + SecretKeyID sql.NullString `db:"secret_key_id" json:"secret_key_id"` + StartsAt time.Time `db:"starts_at" json:"starts_at"` + DeletesAt sql.NullTime `db:"deletes_at" json:"deletes_at"` +} + // Custom roles allow dynamic roles expanded at runtime type CustomRole struct { Name string `db:"name" json:"name"` diff --git a/coderd/database/querier.go b/coderd/database/querier.go index ee9a64f12076d..8e8f587d302c8 100644 --- a/coderd/database/querier.go +++ b/coderd/database/querier.go @@ -69,6 +69,7 @@ type sqlcQuerier interface { DeleteAllTailnetTunnels(ctx context.Context, arg DeleteAllTailnetTunnelsParams) error DeleteApplicationConnectAPIKeysByUserID(ctx context.Context, userID uuid.UUID) error DeleteCoordinator(ctx context.Context, id uuid.UUID) error + DeleteCryptoKey(ctx context.Context, arg DeleteCryptoKeyParams) (CryptoKey, error) DeleteCustomRole(ctx context.Context, arg DeleteCustomRoleParams) error DeleteExternalAuthLink(ctx context.Context, arg DeleteExternalAuthLinkParams) error DeleteGitSSHKey(ctx context.Context, userID uuid.UUID) error @@ -131,6 +132,8 @@ type sqlcQuerier interface { // are included. GetAuthorizationUserRoles(ctx context.Context, userID uuid.UUID) (GetAuthorizationUserRolesRow, error) GetCoordinatorResumeTokenSigningKey(ctx context.Context) (string, error) + GetCryptoKeyByFeatureAndSequence(ctx context.Context, arg GetCryptoKeyByFeatureAndSequenceParams) (CryptoKey, error) + GetCryptoKeys(ctx context.Context) ([]CryptoKey, error) GetDBCryptKeys(ctx context.Context) ([]DBCryptKey, error) GetDERPMeshKey(ctx context.Context) (string, error) GetDefaultOrganization(ctx context.Context) (Organization, error) @@ -159,6 +162,7 @@ type sqlcQuerier interface { GetHungProvisionerJobs(ctx context.Context, updatedAt time.Time) ([]ProvisionerJob, error) GetJFrogXrayScanByWorkspaceAndAgentID(ctx context.Context, arg GetJFrogXrayScanByWorkspaceAndAgentIDParams) (JfrogXrayScan, error) GetLastUpdateCheck(ctx context.Context) (string, error) + GetLatestCryptoKeyByFeature(ctx context.Context, feature CryptoKeyFeature) (CryptoKey, error) GetLatestWorkspaceBuildByWorkspaceID(ctx context.Context, workspaceID uuid.UUID) (WorkspaceBuild, error) GetLatestWorkspaceBuilds(ctx context.Context) ([]WorkspaceBuild, error) GetLatestWorkspaceBuildsByWorkspaceIDs(ctx context.Context, ids []uuid.UUID) ([]WorkspaceBuild, error) @@ -337,6 +341,7 @@ type sqlcQuerier interface { // every member of the org. InsertAllUsersGroup(ctx context.Context, organizationID uuid.UUID) (Group, error) InsertAuditLog(ctx context.Context, arg InsertAuditLogParams) (AuditLog, error) + InsertCryptoKey(ctx context.Context, arg InsertCryptoKeyParams) (CryptoKey, error) InsertCustomRole(ctx context.Context, arg InsertCustomRoleParams) (CustomRole, error) InsertDBCryptKey(ctx context.Context, arg InsertDBCryptKeyParams) error InsertDERPMeshKey(ctx context.Context, value string) error @@ -410,6 +415,7 @@ type sqlcQuerier interface { UnarchiveTemplateVersion(ctx context.Context, arg UnarchiveTemplateVersionParams) error UnfavoriteWorkspace(ctx context.Context, id uuid.UUID) error UpdateAPIKeyByID(ctx context.Context, arg UpdateAPIKeyByIDParams) error + UpdateCryptoKeyDeletesAt(ctx context.Context, arg UpdateCryptoKeyDeletesAtParams) (CryptoKey, error) UpdateCustomRole(ctx context.Context, arg UpdateCustomRoleParams) (CustomRole, error) UpdateExternalAuthLink(ctx context.Context, arg UpdateExternalAuthLinkParams) (ExternalAuthLink, error) UpdateGitSSHKey(ctx context.Context, arg UpdateGitSSHKeyParams) (GitSSHKey, error) diff --git a/coderd/database/queries.sql.go b/coderd/database/queries.sql.go index 6831415907b67..b5c198c94da42 100644 --- a/coderd/database/queries.sql.go +++ b/coderd/database/queries.sql.go @@ -761,6 +761,186 @@ func (q *sqlQuerier) InsertAuditLog(ctx context.Context, arg InsertAuditLogParam return i, err } +const deleteCryptoKey = `-- name: DeleteCryptoKey :one +UPDATE crypto_keys +SET secret = NULL +WHERE feature = $1 AND sequence = $2 RETURNING feature, sequence, secret, secret_key_id, starts_at, deletes_at +` + +type DeleteCryptoKeyParams struct { + Feature CryptoKeyFeature `db:"feature" json:"feature"` + Sequence int32 `db:"sequence" json:"sequence"` +} + +func (q *sqlQuerier) DeleteCryptoKey(ctx context.Context, arg DeleteCryptoKeyParams) (CryptoKey, error) { + row := q.db.QueryRowContext(ctx, deleteCryptoKey, arg.Feature, arg.Sequence) + var i CryptoKey + err := row.Scan( + &i.Feature, + &i.Sequence, + &i.Secret, + &i.SecretKeyID, + &i.StartsAt, + &i.DeletesAt, + ) + return i, err +} + +const getCryptoKeyByFeatureAndSequence = `-- name: GetCryptoKeyByFeatureAndSequence :one +SELECT feature, sequence, secret, secret_key_id, starts_at, deletes_at +FROM crypto_keys +WHERE feature = $1 + AND sequence = $2 + AND secret IS NOT NULL +` + +type GetCryptoKeyByFeatureAndSequenceParams struct { + Feature CryptoKeyFeature `db:"feature" json:"feature"` + Sequence int32 `db:"sequence" json:"sequence"` +} + +func (q *sqlQuerier) GetCryptoKeyByFeatureAndSequence(ctx context.Context, arg GetCryptoKeyByFeatureAndSequenceParams) (CryptoKey, error) { + row := q.db.QueryRowContext(ctx, getCryptoKeyByFeatureAndSequence, arg.Feature, arg.Sequence) + var i CryptoKey + err := row.Scan( + &i.Feature, + &i.Sequence, + &i.Secret, + &i.SecretKeyID, + &i.StartsAt, + &i.DeletesAt, + ) + return i, err +} + +const getCryptoKeys = `-- name: GetCryptoKeys :many +SELECT feature, sequence, secret, secret_key_id, starts_at, deletes_at +FROM crypto_keys +WHERE secret IS NOT NULL +` + +func (q *sqlQuerier) GetCryptoKeys(ctx context.Context) ([]CryptoKey, error) { + rows, err := q.db.QueryContext(ctx, getCryptoKeys) + if err != nil { + return nil, err + } + defer rows.Close() + var items []CryptoKey + for rows.Next() { + var i CryptoKey + if err := rows.Scan( + &i.Feature, + &i.Sequence, + &i.Secret, + &i.SecretKeyID, + &i.StartsAt, + &i.DeletesAt, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const getLatestCryptoKeyByFeature = `-- name: GetLatestCryptoKeyByFeature :one +SELECT feature, sequence, secret, secret_key_id, starts_at, deletes_at +FROM crypto_keys +WHERE feature = $1 +ORDER BY sequence DESC +LIMIT 1 +` + +func (q *sqlQuerier) GetLatestCryptoKeyByFeature(ctx context.Context, feature CryptoKeyFeature) (CryptoKey, error) { + row := q.db.QueryRowContext(ctx, getLatestCryptoKeyByFeature, feature) + var i CryptoKey + err := row.Scan( + &i.Feature, + &i.Sequence, + &i.Secret, + &i.SecretKeyID, + &i.StartsAt, + &i.DeletesAt, + ) + return i, err +} + +const insertCryptoKey = `-- name: InsertCryptoKey :one +INSERT INTO crypto_keys ( + feature, + sequence, + secret, + starts_at, + secret_key_id +) VALUES ( + $1, + $2, + $3, + $4, + $5 +) RETURNING feature, sequence, secret, secret_key_id, starts_at, deletes_at +` + +type InsertCryptoKeyParams struct { + Feature CryptoKeyFeature `db:"feature" json:"feature"` + Sequence int32 `db:"sequence" json:"sequence"` + Secret sql.NullString `db:"secret" json:"secret"` + StartsAt time.Time `db:"starts_at" json:"starts_at"` + SecretKeyID sql.NullString `db:"secret_key_id" json:"secret_key_id"` +} + +func (q *sqlQuerier) InsertCryptoKey(ctx context.Context, arg InsertCryptoKeyParams) (CryptoKey, error) { + row := q.db.QueryRowContext(ctx, insertCryptoKey, + arg.Feature, + arg.Sequence, + arg.Secret, + arg.StartsAt, + arg.SecretKeyID, + ) + var i CryptoKey + err := row.Scan( + &i.Feature, + &i.Sequence, + &i.Secret, + &i.SecretKeyID, + &i.StartsAt, + &i.DeletesAt, + ) + return i, err +} + +const updateCryptoKeyDeletesAt = `-- name: UpdateCryptoKeyDeletesAt :one +UPDATE crypto_keys +SET deletes_at = $3 +WHERE feature = $1 AND sequence = $2 RETURNING feature, sequence, secret, secret_key_id, starts_at, deletes_at +` + +type UpdateCryptoKeyDeletesAtParams struct { + Feature CryptoKeyFeature `db:"feature" json:"feature"` + Sequence int32 `db:"sequence" json:"sequence"` + DeletesAt sql.NullTime `db:"deletes_at" json:"deletes_at"` +} + +func (q *sqlQuerier) UpdateCryptoKeyDeletesAt(ctx context.Context, arg UpdateCryptoKeyDeletesAtParams) (CryptoKey, error) { + row := q.db.QueryRowContext(ctx, updateCryptoKeyDeletesAt, arg.Feature, arg.Sequence, arg.DeletesAt) + var i CryptoKey + err := row.Scan( + &i.Feature, + &i.Sequence, + &i.Secret, + &i.SecretKeyID, + &i.StartsAt, + &i.DeletesAt, + ) + return i, err +} + const getDBCryptKeys = `-- name: GetDBCryptKeys :many SELECT number, active_key_digest, revoked_key_digest, created_at, revoked_at, test FROM dbcrypt_keys ORDER BY number ASC ` diff --git a/coderd/database/queries/crypto_keys.sql b/coderd/database/queries/crypto_keys.sql new file mode 100644 index 0000000000000..39dc8175f95ab --- /dev/null +++ b/coderd/database/queries/crypto_keys.sql @@ -0,0 +1,44 @@ +-- name: GetCryptoKeys :many +SELECT * +FROM crypto_keys +WHERE secret IS NOT NULL; + +-- name: GetLatestCryptoKeyByFeature :one +SELECT * +FROM crypto_keys +WHERE feature = $1 +ORDER BY sequence DESC +LIMIT 1; + + +-- name: GetCryptoKeyByFeatureAndSequence :one +SELECT * +FROM crypto_keys +WHERE feature = $1 + AND sequence = $2 + AND secret IS NOT NULL; + +-- name: DeleteCryptoKey :one +UPDATE crypto_keys +SET secret = NULL +WHERE feature = $1 AND sequence = $2 RETURNING *; + +-- name: InsertCryptoKey :one +INSERT INTO crypto_keys ( + feature, + sequence, + secret, + starts_at, + secret_key_id +) VALUES ( + $1, + $2, + $3, + $4, + $5 +) RETURNING *; + +-- name: UpdateCryptoKeyDeletesAt :one +UPDATE crypto_keys +SET deletes_at = $3 +WHERE feature = $1 AND sequence = $2 RETURNING *; diff --git a/coderd/database/unique_constraint.go b/coderd/database/unique_constraint.go index b3bf72f8178b6..01a811af9c5ed 100644 --- a/coderd/database/unique_constraint.go +++ b/coderd/database/unique_constraint.go @@ -9,6 +9,7 @@ const ( UniqueAgentStatsPkey UniqueConstraint = "agent_stats_pkey" // ALTER TABLE ONLY workspace_agent_stats ADD CONSTRAINT agent_stats_pkey PRIMARY KEY (id); UniqueAPIKeysPkey UniqueConstraint = "api_keys_pkey" // ALTER TABLE ONLY api_keys ADD CONSTRAINT api_keys_pkey PRIMARY KEY (id); UniqueAuditLogsPkey UniqueConstraint = "audit_logs_pkey" // ALTER TABLE ONLY audit_logs ADD CONSTRAINT audit_logs_pkey PRIMARY KEY (id); + UniqueCryptoKeysPkey UniqueConstraint = "crypto_keys_pkey" // ALTER TABLE ONLY crypto_keys ADD CONSTRAINT crypto_keys_pkey PRIMARY KEY (feature, sequence); UniqueCustomRolesUniqueKey UniqueConstraint = "custom_roles_unique_key" // ALTER TABLE ONLY custom_roles ADD CONSTRAINT custom_roles_unique_key UNIQUE (name, organization_id); UniqueDbcryptKeysActiveKeyDigestKey UniqueConstraint = "dbcrypt_keys_active_key_digest_key" // ALTER TABLE ONLY dbcrypt_keys ADD CONSTRAINT dbcrypt_keys_active_key_digest_key UNIQUE (active_key_digest); UniqueDbcryptKeysPkey UniqueConstraint = "dbcrypt_keys_pkey" // ALTER TABLE ONLY dbcrypt_keys ADD CONSTRAINT dbcrypt_keys_pkey PRIMARY KEY (number); diff --git a/coderd/keyrotate/rotate.go b/coderd/keyrotate/rotate.go new file mode 100644 index 0000000000000..e9e4305a99aab --- /dev/null +++ b/coderd/keyrotate/rotate.go @@ -0,0 +1,243 @@ +package keyrotate + +import ( + "context" + "crypto/rand" + "database/sql" + "encoding/hex" + "time" + + "golang.org/x/xerrors" + + "cdr.dev/slog" + "github.com/coder/coder/v2/coderd/database" + "github.com/coder/coder/v2/coderd/database/dbtime" + "github.com/coder/quartz" +) + +const ( + WorkspaceAppsTokenDuration = time.Minute + OIDCConvertTokenDuration = time.Minute * 5 + PeerReconnectTokenDuration = time.Hour * 24 +) + +type KeyRotator struct { + DB database.Store + KeyDuration time.Duration + Clock quartz.Clock + Logger slog.Logger + ScanInterval time.Duration + ResultsCh chan []database.CryptoKey + features []database.CryptoKeyFeature +} + +func (k *KeyRotator) Start(ctx context.Context) { + ticker := k.Clock.NewTicker(k.ScanInterval) + defer ticker.Stop() + + if len(k.features) == 0 { + k.features = database.AllCryptoKeyFeatureValues() + } + + for { + modifiedKeys, err := k.rotateKeys(ctx) + if err != nil { + k.Logger.Error(ctx, "failed to rotate keys", slog.Error(err)) + } + + // This should only be called in test code so we don't + // both to select on the push. + if k.ResultsCh != nil { + k.ResultsCh <- modifiedKeys + } + + select { + case <-ctx.Done(): + return + case <-ticker.C: + } + } +} + +// rotateKeys checks for keys nearing expiration and rotates them if necessary. +func (k *KeyRotator) rotateKeys(ctx context.Context) ([]database.CryptoKey, error) { + var modifiedKeys []database.CryptoKey + return modifiedKeys, database.ReadModifyUpdate(k.DB, func(tx database.Store) error { + // Reset the modified keys slice for each iteration. + modifiedKeys = make([]database.CryptoKey, 0) + keys, err := tx.GetCryptoKeys(ctx) + if err != nil { + return xerrors.Errorf("get keys: %w", err) + } + + // Groups the keys by feature so that we can + // ensure we have at least one key for each feature. + keysByFeature := keysByFeature(keys, k.features) + now := dbtime.Time(k.Clock.Now().UTC()) + for feature, keys := range keysByFeature { + // It's possible there are no keys if someone + // has manually deleted all the keys. + if len(keys) == 0 { + k.Logger.Info(ctx, "no valid keys detected, inserting new key", + slog.F("feature", feature), + ) + newKey, err := k.insertNewKey(ctx, tx, feature, now) + if err != nil { + return xerrors.Errorf("insert new key: %w", err) + } + modifiedKeys = append(modifiedKeys, newKey) + } + + for _, key := range keys { + switch { + case shouldDeleteKey(key, now): + deletedKey, err := tx.DeleteCryptoKey(ctx, database.DeleteCryptoKeyParams{ + Feature: key.Feature, + Sequence: key.Sequence, + }) + if err != nil { + return xerrors.Errorf("delete key: %w", err) + } + modifiedKeys = append(modifiedKeys, deletedKey) + case shouldRotateKey(key, k.KeyDuration, now): + rotatedKeys, err := k.rotateKey(ctx, tx, key) + if err != nil { + return xerrors.Errorf("rotate key: %w", err) + } + modifiedKeys = append(modifiedKeys, rotatedKeys...) + default: + continue + } + } + } + return nil + }) +} + +func (k *KeyRotator) insertNewKey(ctx context.Context, tx database.Store, feature database.CryptoKeyFeature, now time.Time) (database.CryptoKey, error) { + secret, err := generateNewSecret(feature) + if err != nil { + return database.CryptoKey{}, xerrors.Errorf("generate new secret: %w", err) + } + + latestKey, err := tx.GetLatestCryptoKeyByFeature(ctx, feature) + if err != nil && !xerrors.Is(err, sql.ErrNoRows) { + return database.CryptoKey{}, xerrors.Errorf("get latest key: %w", err) + } + + newKey, err := tx.InsertCryptoKey(ctx, database.InsertCryptoKeyParams{ + Feature: feature, + // We'll assume that the first key we insert is 1. + Sequence: latestKey.Sequence + 1, + Secret: sql.NullString{ + String: secret, + Valid: true, + }, + StartsAt: now.UTC(), + }) + if err != nil { + return database.CryptoKey{}, xerrors.Errorf("inserting new key: %w", err) + } + + k.Logger.Info(ctx, "inserted new key for feature", slog.F("feature", feature)) + return newKey, nil +} + +func (k *KeyRotator) rotateKey(ctx context.Context, tx database.Store, key database.CryptoKey) ([]database.CryptoKey, error) { + // The starts at of the new key is the expiration of the old key. + newStartsAt := key.ExpiresAt(k.KeyDuration) + + secret, err := generateNewSecret(key.Feature) + if err != nil { + return nil, xerrors.Errorf("generate new secret: %w", err) + } + + // Insert new key + newKey, err := tx.InsertCryptoKey(ctx, database.InsertCryptoKeyParams{ + Feature: key.Feature, + Sequence: key.Sequence + 1, + Secret: sql.NullString{ + String: secret, + Valid: true, + }, + StartsAt: newStartsAt.UTC(), + }) + if err != nil { + return nil, xerrors.Errorf("inserting new key: %w", err) + } + + // Set old key's deletes_at + deletesAt := newStartsAt.Add(time.Hour).Add(tokenDuration(key.Feature)) + + updatedKey, err := tx.UpdateCryptoKeyDeletesAt(ctx, database.UpdateCryptoKeyDeletesAtParams{ + Feature: key.Feature, + Sequence: key.Sequence, + DeletesAt: sql.NullTime{ + Time: deletesAt.UTC(), + Valid: true, + }, + }) + if err != nil { + return nil, xerrors.Errorf("update old key's deletes_at: %w", err) + } + + return []database.CryptoKey{updatedKey, newKey}, nil +} + +func generateNewSecret(feature database.CryptoKeyFeature) (string, error) { + switch feature { + case database.CryptoKeyFeatureWorkspaceApps: + return generateKey(96) + case database.CryptoKeyFeatureOidcConvert: + return generateKey(32) + case database.CryptoKeyFeaturePeerReconnect: + return generateKey(64) + } + return "", xerrors.Errorf("unknown feature: %s", feature) +} + +func generateKey(length int) (string, error) { + b := make([]byte, length) + _, err := rand.Read(b) + if err != nil { + return "", xerrors.Errorf("rand read: %w", err) + } + return hex.EncodeToString(b), nil +} + +func tokenDuration(feature database.CryptoKeyFeature) time.Duration { + switch feature { + case database.CryptoKeyFeatureWorkspaceApps: + return WorkspaceAppsTokenDuration + case database.CryptoKeyFeatureOidcConvert: + return OIDCConvertTokenDuration + case database.CryptoKeyFeaturePeerReconnect: + return PeerReconnectTokenDuration + default: + return 0 + } +} + +func shouldDeleteKey(key database.CryptoKey, now time.Time) bool { + return key.DeletesAt.Valid && key.DeletesAt.Time.UTC().After(now.UTC()) +} + +func shouldRotateKey(key database.CryptoKey, keyDuration time.Duration, now time.Time) bool { + // If deletes_at is set, we've already inserted a key. + if key.DeletesAt.Valid { + return false + } + expirationTime := key.ExpiresAt(keyDuration) + return now.Add(time.Hour).UTC().After(expirationTime.UTC()) +} + +func keysByFeature(keys []database.CryptoKey, features []database.CryptoKeyFeature) map[database.CryptoKeyFeature][]database.CryptoKey { + m := map[database.CryptoKeyFeature][]database.CryptoKey{} + for _, feature := range features { + m[feature] = []database.CryptoKey{} + } + for _, key := range keys { + m[key.Feature] = append(m[key.Feature], key) + } + return m +} diff --git a/coderd/keyrotate/rotate_internal_test.go b/coderd/keyrotate/rotate_internal_test.go new file mode 100644 index 0000000000000..a0f7e9522507a --- /dev/null +++ b/coderd/keyrotate/rotate_internal_test.go @@ -0,0 +1,186 @@ +package keyrotate + +import ( + "database/sql" + "encoding/hex" + "testing" + "time" + + "github.com/stretchr/testify/require" + + "cdr.dev/slog" + "cdr.dev/slog/sloggers/slogtest" + "github.com/coder/coder/v2/coderd/database" + "github.com/coder/coder/v2/coderd/database/dbgen" + "github.com/coder/coder/v2/coderd/database/dbtestutil" + "github.com/coder/coder/v2/coderd/database/dbtime" + "github.com/coder/coder/v2/testutil" + "github.com/coder/quartz" +) + +func Test_rotateKeys(t *testing.T) { + t.Parallel() + + t.Run("RotatesKeysNearExpiration", func(t *testing.T) { + t.Parallel() + + var ( + db, _ = dbtestutil.NewDB(t) + clock = quartz.NewMock(t) + keyDuration = time.Hour * 24 * 7 + logger = slogtest.Make(t, nil).Leveled(slog.LevelDebug) + ctx = testutil.Context(t, testutil.WaitShort) + resultsCh = make(chan []database.CryptoKey, 1) + ) + + kr := &KeyRotator{ + DB: db, + KeyDuration: keyDuration, + Clock: clock, + Logger: logger, + ScanInterval: 0, + ResultsCh: resultsCh, + features: []database.CryptoKeyFeature{ + database.CryptoKeyFeatureWorkspaceApps, + }, + } + + now := dbnow(clock) + + // Seed the database with an existing key. + oldKey := dbgen.CryptoKey(t, db, database.CryptoKey{ + Feature: database.CryptoKeyFeatureWorkspaceApps, + StartsAt: now, + Sequence: 15, + }) + + // Advance the window to just inside rotation time. + _ = clock.Advance(keyDuration - time.Minute*59) + keys, err := kr.rotateKeys(ctx) + require.NoError(t, err) + require.Len(t, keys, 2) + + now = dbnow(clock) + expectedDeletesAt := oldKey.ExpiresAt(keyDuration).Add(WorkspaceAppsTokenDuration + time.Hour) + + // Fetch the old key, it should have an expires_at now. + oldKey, err = db.GetCryptoKeyByFeatureAndSequence(ctx, database.GetCryptoKeyByFeatureAndSequenceParams{ + Feature: oldKey.Feature, + Sequence: oldKey.Sequence, + }) + require.NoError(t, err) + require.Equal(t, oldKey.DeletesAt.Time.UTC(), expectedDeletesAt) + + // The new key should be created and have a starts_at of the old key's expires_at. + newKey, err := db.GetCryptoKeyByFeatureAndSequence(ctx, database.GetCryptoKeyByFeatureAndSequenceParams{ + Feature: database.CryptoKeyFeatureWorkspaceApps, + Sequence: oldKey.Sequence + 1, + }) + require.NoError(t, err) + requireKey(t, newKey, database.CryptoKeyFeatureWorkspaceApps, oldKey.ExpiresAt(keyDuration), time.Time{}, oldKey.Sequence+1) + + // Advance the clock just past the keys delete time. + clock.Advance(oldKey.DeletesAt.Time.UTC().Sub(now) - time.Second) + + // We should have deleted the old key. + keys, err = kr.rotateKeys(ctx) + require.NoError(t, err) + require.Len(t, keys, 1) + + // The old key should be "deleted". + _, err = db.GetCryptoKeyByFeatureAndSequence(ctx, database.GetCryptoKeyByFeatureAndSequenceParams{ + Feature: oldKey.Feature, + Sequence: oldKey.Sequence, + }) + require.ErrorIs(t, err, sql.ErrNoRows) + }) + + t.Run("DoesNotRotateValidKeys", func(t *testing.T) { + t.Parallel() + + var ( + db, _ = dbtestutil.NewDB(t) + clock = quartz.NewMock(t) + keyDuration = time.Hour * 24 * 7 + logger = slogtest.Make(t, nil).Leveled(slog.LevelDebug) + ctx = testutil.Context(t, testutil.WaitShort) + resultsCh = make(chan []database.CryptoKey, 1) + ) + + kr := &KeyRotator{ + DB: db, + KeyDuration: keyDuration, + Clock: clock, + Logger: logger, + ScanInterval: 0, + ResultsCh: resultsCh, + features: []database.CryptoKeyFeature{ + database.CryptoKeyFeatureWorkspaceApps, + }, + } + + now := dbnow(clock) + + // Seed the database with an existing key + existingKey := dbgen.CryptoKey(t, db, database.CryptoKey{ + Feature: database.CryptoKeyFeatureWorkspaceApps, + StartsAt: now, + Sequence: 1, + }) + + // Advance the clock by 6 days, 23 hours. Once we + // breach the last hour we will insert a new key. + clock.Advance(keyDuration - time.Hour) + + keys, err := kr.rotateKeys(ctx) + require.NoError(t, err) + require.Empty(t, keys) + + // Verify that the existing key is still the only key in the database + dbKeys, err := db.GetCryptoKeys(ctx) + require.NoError(t, err) + require.Len(t, dbKeys, 1) + requireKey(t, dbKeys[0], existingKey.Feature, existingKey.StartsAt.UTC(), existingKey.DeletesAt.Time.UTC(), existingKey.Sequence) + }) + + t.Run("DeletesExpiredKeys", func(t *testing.T) { + t.Parallel() + // TODO: Implement test for deleting expired keys + }) + + t.Run("HandlesMultipleKeyTypes", func(t *testing.T) { + t.Parallel() + // TODO: Implement test for handling multiple key types + }) + + t.Run("GracefullyHandlesErrors", func(t *testing.T) { + t.Parallel() + // TODO: Implement test for error handling + }) +} + +func dbnow(c quartz.Clock) time.Time { + return dbtime.Time(c.Now().UTC()) +} + +func requireKey(t *testing.T, key database.CryptoKey, feature database.CryptoKeyFeature, startsAt time.Time, deletesAt time.Time, sequence int32) { + t.Helper() + require.Equal(t, feature, key.Feature) + require.Equal(t, startsAt, key.StartsAt.UTC()) + require.Equal(t, deletesAt, key.DeletesAt.Time.UTC()) + require.Equal(t, sequence, key.Sequence) + + secret, err := hex.DecodeString(key.Secret.String) + require.NoError(t, err) + + switch key.Feature { + case database.CryptoKeyFeatureOidcConvert: + require.Len(t, secret, 32) + case database.CryptoKeyFeatureWorkspaceApps: + require.Len(t, secret, 96) + case database.CryptoKeyFeaturePeerReconnect: + require.Len(t, secret, 64) + default: + t.Fatalf("unknown key feature: %s", key.Feature) + } +} diff --git a/coderd/keyrotate/rotate_test.go b/coderd/keyrotate/rotate_test.go new file mode 100644 index 0000000000000..5f82a0647c99d --- /dev/null +++ b/coderd/keyrotate/rotate_test.go @@ -0,0 +1,60 @@ +package keyrotate_test + +import ( + "testing" + + "github.com/stretchr/testify/require" + + "github.com/coder/coder/v2/coderd/database" +) + +func TestKeyRotator(t *testing.T) { + t.Run("NoExistingKeys", func(t *testing.T) { + // t.Parallel() + + // var ( + // db, _ = dbtestutil.NewDB(t) + // clock = quartz.NewMock(t) + // logger = slogtest.Make(t, nil).Leveled(slog.LevelDebug) + // ctx = testutil.Context(t, testutil.WaitShort) + // resultsCh = make(chan []database.CryptoKey, 1) + // ) + + // kr := &KeyRotator{ + // DB: db, + // KeyDuration: 0, + // Clock: clock, + // Logger: logger, + // ScanInterval: 0, + // ResultsCh: resultsCh, + // } + + // now := dbnow(clock) + // keys, err := kr.rotateKeys(ctx) + // require.NoError(t, err) + // require.Len(t, keys, len(database.AllCryptoKeyFeatureValues())) + + // // Fetch the keys from the database and ensure they + // // are as expected. + // dbkeys, err := db.GetCryptoKeys(ctx) + // require.NoError(t, err) + // require.Equal(t, keys, dbkeys) + // requireContainsAllFeatures(t, keys) + // for _, key := range keys { + // requireKey(t, key, key.Feature, now, time.Time{}, 1) + // } + }) + +} + +func requireContainsAllFeatures(t *testing.T, keys []database.CryptoKey) { + t.Helper() + + features := make(map[database.CryptoKeyFeature]bool) + for _, key := range keys { + features[key.Feature] = true + } + require.True(t, features[database.CryptoKeyFeatureOidcConvert]) + require.True(t, features[database.CryptoKeyFeatureWorkspaceApps]) + require.True(t, features[database.CryptoKeyFeaturePeerReconnect]) +} diff --git a/enterprise/dbcrypt/dbcrypt.go b/enterprise/dbcrypt/dbcrypt.go index ec56a4897a1e3..2717ef1d48188 100644 --- a/enterprise/dbcrypt/dbcrypt.go +++ b/enterprise/dbcrypt/dbcrypt.go @@ -261,6 +261,31 @@ func (db *dbCrypt) UpdateExternalAuthLink(ctx context.Context, params database.U return link, nil } +func (db *dbCrypt) GetCryptoKeyByFeatureAndSequence(ctx context.Context, params database.GetCryptoKeyByFeatureAndSequenceParams) (database.CryptoKey, error) { + key, err := db.Store.GetCryptoKeyByFeatureAndSequence(ctx, params) + if err != nil { + return database.CryptoKey{}, err + } + if err := db.decryptField(&key.Secret.String, key.SecretKeyID); err != nil { + return database.CryptoKey{}, err + } + return key, nil +} + +func (db *dbCrypt) InsertCryptoKey(ctx context.Context, params database.InsertCryptoKeyParams) (database.CryptoKey, error) { + if err := db.encryptField(¶ms.Secret.String, ¶ms.SecretKeyID); err != nil { + return database.CryptoKey{}, err + } + key, err := db.Store.InsertCryptoKey(ctx, params) + if err != nil { + return database.CryptoKey{}, err + } + if err := db.decryptField(&key.Secret.String, key.SecretKeyID); err != nil { + return database.CryptoKey{}, err + } + return key, nil +} + func (db *dbCrypt) encryptField(field *string, digest *sql.NullString) error { // If no cipher is loaded, then we can't encrypt anything! if db.ciphers == nil || db.primaryCipherDigest == "" { diff --git a/enterprise/dbcrypt/dbcrypt_internal_test.go b/enterprise/dbcrypt/dbcrypt_internal_test.go index 37fcc8cae55a3..7dad716b8139b 100644 --- a/enterprise/dbcrypt/dbcrypt_internal_test.go +++ b/enterprise/dbcrypt/dbcrypt_internal_test.go @@ -349,6 +349,51 @@ func TestExternalAuthLinks(t *testing.T) { }) } +func TestCryptoKeys(t *testing.T) { + t.Parallel() + ctx := context.Background() + db, crypt, ciphers := setup(t) + + // We don't write a GetCryptoKeyByFeatureAndSequence test + // because it's basically the same as InsertCryptoKey. + t.Run("InsertCryptoKey", func(t *testing.T) { + t.Parallel() + key := dbgen.CryptoKey(t, crypt, database.CryptoKey{ + Secret: sql.NullString{String: "test", Valid: true}, + }) + require.Equal(t, "test", key.Secret.String) + + key, err := db.GetCryptoKeyByFeatureAndSequence(ctx, database.GetCryptoKeyByFeatureAndSequenceParams{ + Feature: key.Feature, + Sequence: key.Sequence, + }) + require.NoError(t, err) + require.Equal(t, ciphers[0].HexDigest(), key.SecretKeyID.String) + requireEncryptedEquals(t, ciphers[0], key.Secret.String, "test") + }) + t.Run("DecryptErr", func(t *testing.T) { + t.Parallel() + db, crypt, ciphers := setup(t) + key := dbgen.CryptoKey(t, db, database.CryptoKey{ + Secret: sql.NullString{ + String: fakeBase64RandomData(t, 32), + Valid: true, + }, + SecretKeyID: sql.NullString{ + String: ciphers[0].HexDigest(), + Valid: true, + }, + }) + _, err := crypt.GetCryptoKeyByFeatureAndSequence(ctx, database.GetCryptoKeyByFeatureAndSequenceParams{ + Feature: key.Feature, + Sequence: key.Sequence, + }) + require.Error(t, err, "expected an error") + var derr *DecryptFailedError + require.ErrorAs(t, err, &derr, "expected a decrypt error") + }) +} + func TestNew(t *testing.T) { t.Parallel()