From 6f4a504ce3ad01f5586925de002f5273d4adbfe7 Mon Sep 17 00:00:00 2001 From: Jon Ayers Date: Tue, 10 Sep 2024 22:08:38 +0000 Subject: [PATCH 1/5] feat: add key rotation --- coderd/database/dbauthz/dbauthz.go | 20 +++ coderd/database/dbmem/dbmem.go | 40 +++++ coderd/database/dbmetrics/dbmetrics.go | 35 ++++ coderd/database/dump.sql | 11 ++ .../database/migrations/000250_keys.down.sql | 1 + coderd/database/migrations/000250_keys.up.sql | 8 + coderd/database/modelmethods.go | 4 + coderd/database/models.go | 8 + coderd/database/querier.go | 5 + coderd/database/queries.sql.go | 128 ++++++++++++++ coderd/database/queries/keys.sql | 47 +++++ coderd/database/unique_constraint.go | 1 + coderd/keyrotate/rotate.go | 164 ++++++++++++++++++ coderd/keyrotate/rotate_test.go | 41 +++++ 14 files changed, 513 insertions(+) create mode 100644 coderd/database/migrations/000250_keys.down.sql create mode 100644 coderd/database/migrations/000250_keys.up.sql create mode 100644 coderd/database/queries/keys.sql create mode 100644 coderd/keyrotate/rotate.go create mode 100644 coderd/keyrotate/rotate_test.go diff --git a/coderd/database/dbauthz/dbauthz.go b/coderd/database/dbauthz/dbauthz.go index 077d704be1300..8c25b64857976 100644 --- a/coderd/database/dbauthz/dbauthz.go +++ b/coderd/database/dbauthz/dbauthz.go @@ -1078,6 +1078,10 @@ func (q *querier) DeleteGroupMemberFromGroup(ctx context.Context, arg database.D return update(q.log, q.auth, fetch, q.db.DeleteGroupMemberFromGroup)(ctx, arg) } +func (q *querier) DeleteKey(ctx context.Context, arg database.DeleteKeyParams) error { + panic("not implemented") +} + func (q *querier) DeleteLicense(ctx context.Context, id int32) (int32, error) { err := deleteQ(q.log, q.auth, q.db.GetLicenseByID, func(ctx context.Context, id int32) error { _, err := q.db.DeleteLicense(ctx, id) @@ -1542,6 +1546,14 @@ func (q *querier) GetJFrogXrayScanByWorkspaceAndAgentID(ctx context.Context, arg return q.db.GetJFrogXrayScanByWorkspaceAndAgentID(ctx, arg) } +func (q *querier) GetKeyByFeatureAndSequence(ctx context.Context, arg database.GetKeyByFeatureAndSequenceParams) (database.Key, error) { + panic("not implemented") +} + +func (q *querier) GetKeys(ctx context.Context) ([]database.Key, error) { + panic("not implemented") +} + func (q *querier) GetLastUpdateCheck(ctx context.Context) (string, error) { if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceSystem); err != nil { return "", err @@ -2726,6 +2738,10 @@ func (q *querier) InsertGroupMember(ctx context.Context, arg database.InsertGrou return update(q.log, q.auth, fetch, q.db.InsertGroupMember)(ctx, arg) } +func (q *querier) InsertKey(ctx context.Context, arg database.InsertKeyParams) error { + panic("not implemented") +} + func (q *querier) InsertLicense(ctx context.Context, arg database.InsertLicenseParams) (database.License, error) { if err := q.authorizeContext(ctx, policy.ActionCreate, rbac.ResourceLicense); err != nil { return database.License{}, err @@ -3212,6 +3228,10 @@ func (q *querier) UpdateInactiveUsersToDormant(ctx context.Context, lastSeenAfte return q.db.UpdateInactiveUsersToDormant(ctx, lastSeenAfter) } +func (q *querier) UpdateKeyDeletesAt(ctx context.Context, arg database.UpdateKeyDeletesAtParams) error { + panic("not implemented") +} + func (q *querier) UpdateMemberRoles(ctx context.Context, arg database.UpdateMemberRolesParams) (database.OrganizationMember, error) { // Authorized fetch will check that the actor has read access to the org member since the org member is returned. member, err := database.ExpectOne(q.OrganizationMembers(ctx, database.OrganizationMembersParams{ diff --git a/coderd/database/dbmem/dbmem.go b/coderd/database/dbmem/dbmem.go index ed766d48ecd43..f0b1a80325f38 100644 --- a/coderd/database/dbmem/dbmem.go +++ b/coderd/database/dbmem/dbmem.go @@ -1527,6 +1527,15 @@ func (q *FakeQuerier) DeleteGroupMemberFromGroup(_ context.Context, arg database return nil } +func (q *FakeQuerier) DeleteKey(ctx context.Context, arg database.DeleteKeyParams) error { + err := validateDatabaseType(arg) + if err != nil { + return err + } + + panic("not implemented") +} + func (q *FakeQuerier) DeleteLicense(_ context.Context, id int32) (int32, error) { q.mutex.Lock() defer q.mutex.Unlock() @@ -2796,6 +2805,19 @@ func (q *FakeQuerier) GetJFrogXrayScanByWorkspaceAndAgentID(_ context.Context, a return database.JfrogXrayScan{}, sql.ErrNoRows } +func (q *FakeQuerier) GetKeyByFeatureAndSequence(ctx context.Context, arg database.GetKeyByFeatureAndSequenceParams) (database.Key, error) { + err := validateDatabaseType(arg) + if err != nil { + return database.Key{}, err + } + + panic("not implemented") +} + +func (q *FakeQuerier) GetKeys(ctx context.Context) ([]database.Key, error) { + panic("not implemented") +} + func (q *FakeQuerier) GetLastUpdateCheck(_ context.Context) (string, error) { q.mutex.RLock() defer q.mutex.RUnlock() @@ -6492,6 +6514,15 @@ func (q *FakeQuerier) InsertGroupMember(_ context.Context, arg database.InsertGr return nil } +func (q *FakeQuerier) InsertKey(ctx context.Context, arg database.InsertKeyParams) error { + err := validateDatabaseType(arg) + if err != nil { + return err + } + + panic("not implemented") +} + func (q *FakeQuerier) InsertLicense( _ context.Context, arg database.InsertLicenseParams, ) (database.License, error) { @@ -7890,6 +7921,15 @@ func (q *FakeQuerier) UpdateInactiveUsersToDormant(_ context.Context, params dat return updated, nil } +func (q *FakeQuerier) UpdateKeyDeletesAt(ctx context.Context, arg database.UpdateKeyDeletesAtParams) error { + err := validateDatabaseType(arg) + if err != nil { + return err + } + + panic("not implemented") +} + func (q *FakeQuerier) UpdateMemberRoles(_ context.Context, arg database.UpdateMemberRolesParams) (database.OrganizationMember, error) { if err := validateDatabaseType(arg); err != nil { return database.OrganizationMember{}, err diff --git a/coderd/database/dbmetrics/dbmetrics.go b/coderd/database/dbmetrics/dbmetrics.go index 0ec70c1736d43..22ce3b085fed9 100644 --- a/coderd/database/dbmetrics/dbmetrics.go +++ b/coderd/database/dbmetrics/dbmetrics.go @@ -249,6 +249,13 @@ func (m metricsStore) DeleteGroupMemberFromGroup(ctx context.Context, arg databa return err } +func (m metricsStore) DeleteKey(ctx context.Context, arg database.DeleteKeyParams) error { + start := time.Now() + r0 := m.s.DeleteKey(ctx, arg) + m.queryLatencies.WithLabelValues("DeleteKey").Observe(time.Since(start).Seconds()) + return r0 +} + func (m metricsStore) DeleteLicense(ctx context.Context, id int32) (int32, error) { start := time.Now() licenseID, err := m.s.DeleteLicense(ctx, id) @@ -704,6 +711,20 @@ func (m metricsStore) GetJFrogXrayScanByWorkspaceAndAgentID(ctx context.Context, return r0, r1 } +func (m metricsStore) GetKeyByFeatureAndSequence(ctx context.Context, arg database.GetKeyByFeatureAndSequenceParams) (database.Key, error) { + start := time.Now() + r0, r1 := m.s.GetKeyByFeatureAndSequence(ctx, arg) + m.queryLatencies.WithLabelValues("GetKeyByFeatureAndSequence").Observe(time.Since(start).Seconds()) + return r0, r1 +} + +func (m metricsStore) GetKeys(ctx context.Context) ([]database.Key, error) { + start := time.Now() + r0, r1 := m.s.GetKeys(ctx) + m.queryLatencies.WithLabelValues("GetKeys").Observe(time.Since(start).Seconds()) + return r0, r1 +} + func (m metricsStore) GetLastUpdateCheck(ctx context.Context) (string, error) { start := time.Now() version, err := m.s.GetLastUpdateCheck(ctx) @@ -1656,6 +1677,13 @@ func (m metricsStore) InsertGroupMember(ctx context.Context, arg database.Insert return err } +func (m metricsStore) InsertKey(ctx context.Context, arg database.InsertKeyParams) error { + start := time.Now() + r0 := m.s.InsertKey(ctx, arg) + m.queryLatencies.WithLabelValues("InsertKey").Observe(time.Since(start).Seconds()) + return r0 +} + func (m metricsStore) InsertLicense(ctx context.Context, arg database.InsertLicenseParams) (database.License, error) { start := time.Now() license, err := m.s.InsertLicense(ctx, arg) @@ -2027,6 +2055,13 @@ func (m metricsStore) UpdateInactiveUsersToDormant(ctx context.Context, lastSeen return r0, r1 } +func (m metricsStore) UpdateKeyDeletesAt(ctx context.Context, arg database.UpdateKeyDeletesAtParams) error { + start := time.Now() + r0 := m.s.UpdateKeyDeletesAt(ctx, arg) + m.queryLatencies.WithLabelValues("UpdateKeyDeletesAt").Observe(time.Since(start).Seconds()) + return r0 +} + func (m metricsStore) UpdateMemberRoles(ctx context.Context, arg database.UpdateMemberRolesParams) (database.OrganizationMember, error) { start := time.Now() member, err := m.s.UpdateMemberRoles(ctx, arg) diff --git a/coderd/database/dump.sql b/coderd/database/dump.sql index 6638d52745ba6..bef8c1f9376af 100644 --- a/coderd/database/dump.sql +++ b/coderd/database/dump.sql @@ -668,6 +668,14 @@ CREATE TABLE jfrog_xray_scans ( results_url text DEFAULT ''::text NOT NULL ); +CREATE TABLE keys ( + feature text NOT NULL, + sequence integer NOT NULL, + secret text, + starts_at timestamp with time zone NOT NULL, + deletes_at timestamp with time zone +); + CREATE TABLE licenses ( id integer NOT NULL, uploaded_at timestamp with time zone NOT NULL, @@ -1676,6 +1684,9 @@ ALTER TABLE ONLY groups ALTER TABLE ONLY jfrog_xray_scans ADD CONSTRAINT jfrog_xray_scans_pkey PRIMARY KEY (agent_id, workspace_id); +ALTER TABLE ONLY keys + ADD CONSTRAINT keys_pkey PRIMARY KEY (feature, sequence); + ALTER TABLE ONLY licenses ADD CONSTRAINT licenses_jwt_key UNIQUE (jwt); diff --git a/coderd/database/migrations/000250_keys.down.sql b/coderd/database/migrations/000250_keys.down.sql new file mode 100644 index 0000000000000..8b0cc3702bcc4 --- /dev/null +++ b/coderd/database/migrations/000250_keys.down.sql @@ -0,0 +1 @@ +DROP TABLE "keys"; diff --git a/coderd/database/migrations/000250_keys.up.sql b/coderd/database/migrations/000250_keys.up.sql new file mode 100644 index 0000000000000..67f6cf8f14a29 --- /dev/null +++ b/coderd/database/migrations/000250_keys.up.sql @@ -0,0 +1,8 @@ +CREATE TABLE "keys" ( + "feature" text NOT NULL, + "sequence" integer NOT NULL, + "secret" text NULL, + "starts_at" timestamptz NOT NULL, + "deletes_at" timestamptz NULL, + PRIMARY KEY ("feature", "sequence") +); diff --git a/coderd/database/modelmethods.go b/coderd/database/modelmethods.go index 816fc4c9214b0..5698fa92cdb27 100644 --- a/coderd/database/modelmethods.go +++ b/coderd/database/modelmethods.go @@ -447,3 +447,7 @@ func (r GetAuthorizationUserRolesRow) RoleNames() ([]rbac.RoleIdentifier, error) } return names, nil } + +func (k Key) ExpiresAt(keyDuration time.Duration) time.Time { + return k.StartsAt.Add(keyDuration) +} diff --git a/coderd/database/models.go b/coderd/database/models.go index 9e0283ba859c1..3642cef53f405 100644 --- a/coderd/database/models.go +++ b/coderd/database/models.go @@ -2155,6 +2155,14 @@ type JfrogXrayScan struct { ResultsUrl string `db:"results_url" json:"results_url"` } +type Key struct { + Feature string `db:"feature" json:"feature"` + Sequence int32 `db:"sequence" json:"sequence"` + Secret sql.NullString `db:"secret" json:"secret"` + StartsAt time.Time `db:"starts_at" json:"starts_at"` + DeletesAt sql.NullTime `db:"deletes_at" json:"deletes_at"` +} + type License struct { ID int32 `db:"id" json:"id"` UploadedAt time.Time `db:"uploaded_at" json:"uploaded_at"` diff --git a/coderd/database/querier.go b/coderd/database/querier.go index ee9a64f12076d..c342c00581ad4 100644 --- a/coderd/database/querier.go +++ b/coderd/database/querier.go @@ -74,6 +74,7 @@ type sqlcQuerier interface { DeleteGitSSHKey(ctx context.Context, userID uuid.UUID) error DeleteGroupByID(ctx context.Context, id uuid.UUID) error DeleteGroupMemberFromGroup(ctx context.Context, arg DeleteGroupMemberFromGroupParams) error + DeleteKey(ctx context.Context, arg DeleteKeyParams) error DeleteLicense(ctx context.Context, id int32) (int32, error) DeleteOAuth2ProviderAppByID(ctx context.Context, id uuid.UUID) error DeleteOAuth2ProviderAppCodeByID(ctx context.Context, id uuid.UUID) error @@ -158,6 +159,8 @@ type sqlcQuerier interface { GetHealthSettings(ctx context.Context) (string, error) GetHungProvisionerJobs(ctx context.Context, updatedAt time.Time) ([]ProvisionerJob, error) GetJFrogXrayScanByWorkspaceAndAgentID(ctx context.Context, arg GetJFrogXrayScanByWorkspaceAndAgentIDParams) (JfrogXrayScan, error) + GetKeyByFeatureAndSequence(ctx context.Context, arg GetKeyByFeatureAndSequenceParams) (Key, error) + GetKeys(ctx context.Context) ([]Key, error) GetLastUpdateCheck(ctx context.Context) (string, error) GetLatestWorkspaceBuildByWorkspaceID(ctx context.Context, workspaceID uuid.UUID) (WorkspaceBuild, error) GetLatestWorkspaceBuilds(ctx context.Context) ([]WorkspaceBuild, error) @@ -346,6 +349,7 @@ type sqlcQuerier interface { InsertGitSSHKey(ctx context.Context, arg InsertGitSSHKeyParams) (GitSSHKey, error) InsertGroup(ctx context.Context, arg InsertGroupParams) (Group, error) InsertGroupMember(ctx context.Context, arg InsertGroupMemberParams) error + InsertKey(ctx context.Context, arg InsertKeyParams) error InsertLicense(ctx context.Context, arg InsertLicenseParams) (License, error) // Inserts any group by name that does not exist. All new groups are given // a random uuid, are inserted into the same organization. They have the default @@ -415,6 +419,7 @@ type sqlcQuerier interface { UpdateGitSSHKey(ctx context.Context, arg UpdateGitSSHKeyParams) (GitSSHKey, error) UpdateGroupByID(ctx context.Context, arg UpdateGroupByIDParams) (Group, error) UpdateInactiveUsersToDormant(ctx context.Context, arg UpdateInactiveUsersToDormantParams) ([]UpdateInactiveUsersToDormantRow, error) + UpdateKeyDeletesAt(ctx context.Context, arg UpdateKeyDeletesAtParams) error UpdateMemberRoles(ctx context.Context, arg UpdateMemberRolesParams) (OrganizationMember, error) UpdateNotificationTemplateMethodByID(ctx context.Context, arg UpdateNotificationTemplateMethodByIDParams) (NotificationTemplate, error) UpdateOAuth2ProviderAppByID(ctx context.Context, arg UpdateOAuth2ProviderAppByIDParams) (OAuth2ProviderApp, error) diff --git a/coderd/database/queries.sql.go b/coderd/database/queries.sql.go index 6831415907b67..de31a01a189aa 100644 --- a/coderd/database/queries.sql.go +++ b/coderd/database/queries.sql.go @@ -3164,6 +3164,134 @@ func (q *sqlQuerier) UpsertJFrogXrayScanByWorkspaceAndAgentID(ctx context.Contex return err } +const deleteKey = `-- name: DeleteKey :exec +UPDATE keys +SET secret = NULL +WHERE feature = $1 AND sequence = $2 +` + +type DeleteKeyParams struct { + Feature string `db:"feature" json:"feature"` + Sequence int32 `db:"sequence" json:"sequence"` +} + +func (q *sqlQuerier) DeleteKey(ctx context.Context, arg DeleteKeyParams) error { + _, err := q.db.ExecContext(ctx, deleteKey, arg.Feature, arg.Sequence) + return err +} + +const getKeyByFeatureAndSequence = `-- name: GetKeyByFeatureAndSequence :one +SELECT feature, sequence, secret, starts_at, deletes_at +FROM keys +WHERE feature = $1 + AND sequence = $2 + AND secret IS NOT NULL + AND $3 >= starts_at + AND ($3 < deletes_at OR deletes_at IS NULL) +` + +type GetKeyByFeatureAndSequenceParams struct { + Feature string `db:"feature" json:"feature"` + Sequence int32 `db:"sequence" json:"sequence"` + StartsAt time.Time `db:"starts_at" json:"starts_at"` +} + +func (q *sqlQuerier) GetKeyByFeatureAndSequence(ctx context.Context, arg GetKeyByFeatureAndSequenceParams) (Key, error) { + row := q.db.QueryRowContext(ctx, getKeyByFeatureAndSequence, arg.Feature, arg.Sequence, arg.StartsAt) + var i Key + err := row.Scan( + &i.Feature, + &i.Sequence, + &i.Secret, + &i.StartsAt, + &i.DeletesAt, + ) + return i, err +} + +const getKeys = `-- name: GetKeys :many +SELECT feature, sequence, secret, starts_at, deletes_at +FROM keys +WHERE secret IS NOT NULL +` + +func (q *sqlQuerier) GetKeys(ctx context.Context) ([]Key, error) { + rows, err := q.db.QueryContext(ctx, getKeys) + if err != nil { + return nil, err + } + defer rows.Close() + var items []Key + for rows.Next() { + var i Key + if err := rows.Scan( + &i.Feature, + &i.Sequence, + &i.Secret, + &i.StartsAt, + &i.DeletesAt, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const insertKey = `-- name: InsertKey :exec +INSERT INTO keys ( + feature, + sequence, + secret, + starts_at +) VALUES ( + $1, + $2, + $3, + $4 +) +` + +type InsertKeyParams struct { + Feature string `db:"feature" json:"feature"` + Sequence int32 `db:"sequence" json:"sequence"` + Secret sql.NullString `db:"secret" json:"secret"` + StartsAt time.Time `db:"starts_at" json:"starts_at"` +} + +func (q *sqlQuerier) InsertKey(ctx context.Context, arg InsertKeyParams) error { + _, err := q.db.ExecContext(ctx, insertKey, + arg.Feature, + arg.Sequence, + arg.Secret, + arg.StartsAt, + ) + return err +} + +const updateKeyDeletesAt = `-- name: UpdateKeyDeletesAt :exec +UPDATE keys +SET deletes_at = $3 +WHERE feature = $1 AND sequence = $2 +` + +type UpdateKeyDeletesAtParams struct { + Feature string `db:"feature" json:"feature"` + Sequence int32 `db:"sequence" json:"sequence"` + DeletesAt sql.NullTime `db:"deletes_at" json:"deletes_at"` +} + +func (q *sqlQuerier) UpdateKeyDeletesAt(ctx context.Context, arg UpdateKeyDeletesAtParams) error { + _, err := q.db.ExecContext(ctx, updateKeyDeletesAt, arg.Feature, arg.Sequence, arg.DeletesAt) + return err +} + const deleteLicense = `-- name: DeleteLicense :one DELETE FROM licenses diff --git a/coderd/database/queries/keys.sql b/coderd/database/queries/keys.sql new file mode 100644 index 0000000000000..d72b0f6132335 --- /dev/null +++ b/coderd/database/queries/keys.sql @@ -0,0 +1,47 @@ +-- name: GetKeys :many +SELECT * +FROM keys +WHERE secret IS NOT NULL; + +-- name: GetKeyByFeatureAndSequence :one +SELECT * +FROM keys +WHERE feature = $1 + AND sequence = $2 + AND secret IS NOT NULL + AND $3 >= starts_at + AND ($3 < deletes_at OR deletes_at IS NULL); + +-- name: DeleteKey :exec +UPDATE keys +SET secret = NULL +WHERE feature = $1 AND sequence = $2; + + +-- name: InsertKey :exec +INSERT INTO keys ( + feature, + sequence, + secret, + starts_at +) VALUES ( + $1, + $2, + $3, + $4 +); + +-- name: UpdateKeyDeletesAt :exec +UPDATE keys +SET deletes_at = $3 +WHERE feature = $1 AND sequence = $2; + + + + + + + + + + diff --git a/coderd/database/unique_constraint.go b/coderd/database/unique_constraint.go index b3bf72f8178b6..cc6d2c6758d9f 100644 --- a/coderd/database/unique_constraint.go +++ b/coderd/database/unique_constraint.go @@ -21,6 +21,7 @@ const ( UniqueGroupsNameOrganizationIDKey UniqueConstraint = "groups_name_organization_id_key" // ALTER TABLE ONLY groups ADD CONSTRAINT groups_name_organization_id_key UNIQUE (name, organization_id); UniqueGroupsPkey UniqueConstraint = "groups_pkey" // ALTER TABLE ONLY groups ADD CONSTRAINT groups_pkey PRIMARY KEY (id); UniqueJfrogXrayScansPkey UniqueConstraint = "jfrog_xray_scans_pkey" // ALTER TABLE ONLY jfrog_xray_scans ADD CONSTRAINT jfrog_xray_scans_pkey PRIMARY KEY (agent_id, workspace_id); + UniqueKeysPkey UniqueConstraint = "keys_pkey" // ALTER TABLE ONLY keys ADD CONSTRAINT keys_pkey PRIMARY KEY (feature, sequence); UniqueLicensesJWTKey UniqueConstraint = "licenses_jwt_key" // ALTER TABLE ONLY licenses ADD CONSTRAINT licenses_jwt_key UNIQUE (jwt); UniqueLicensesPkey UniqueConstraint = "licenses_pkey" // ALTER TABLE ONLY licenses ADD CONSTRAINT licenses_pkey PRIMARY KEY (id); UniqueNotificationMessagesPkey UniqueConstraint = "notification_messages_pkey" // ALTER TABLE ONLY notification_messages ADD CONSTRAINT notification_messages_pkey PRIMARY KEY (id); diff --git a/coderd/keyrotate/rotate.go b/coderd/keyrotate/rotate.go new file mode 100644 index 0000000000000..381cd5d5805e7 --- /dev/null +++ b/coderd/keyrotate/rotate.go @@ -0,0 +1,164 @@ +package keyrotate + +import ( + "context" + "crypto/rand" + "database/sql" + "encoding/hex" + "log/slog" + "time" + + "golang.org/x/xerrors" + + "github.com/coder/coder/v2/coderd/database" + "github.com/coder/coder/v2/coderd/database/dbtime" + "github.com/coder/quartz" +) + +const ( + WorkspaceAppsTokenDuration = time.Minute + OIDCConvertTokenDuration = time.Minute * 5 + PeerReconnectTokenDuration = time.Hour * 24 +) + +type KeyRotator struct { + DB database.Store + KeyDuration time.Duration + Clock quartz.Clock + Logger slog.Logger + ScanInterval time.Duration +} + +func (kr *KeyRotator) Start(ctx context.Context) { + ticker := kr.Clock.NewTicker(kr.ScanInterval) + defer ticker.Stop() + + for { + err := kr.rotateKeys(ctx) + if err != nil { + kr.Logger.Error("rotate keys", slog.Any("error", err)) + } + + select { + case <-ctx.Done(): + return + case <-ticker.C: + } + } +} + +// rotateKeys checks for keys nearing expiration and rotates them if necessary. +func (kr *KeyRotator) rotateKeys(ctx context.Context) error { + return database.ReadModifyUpdate(kr.DB, func(tx database.Store) error { + keys, err := tx.GetKeys(ctx) + if err != nil { + return xerrors.Errorf("get keys: %w", err) + } + + now := dbtime.Time(kr.Clock.Now()) + for _, key := range keys { + switch { + case shouldDeleteKey(key, now): + err := tx.DeleteKey(ctx, database.DeleteKeyParams{ + Feature: key.Feature, + Sequence: key.Sequence, + }) + if err != nil { + return xerrors.Errorf("delete key: %w", err) + } + case shouldRotateKey(key, kr.KeyDuration, now): + err := kr.rotateKey(ctx, tx, key) + if err != nil { + return xerrors.Errorf("rotate key: %w", err) + } + default: + continue + } + } + return nil + }) +} + +func (kr *KeyRotator) rotateKey(ctx context.Context, tx database.Store, key database.Key) error { + newStartsAt := key.ExpiresAt(kr.KeyDuration) + + secret, err := generateNewSecret(key.Feature) + if err != nil { + return xerrors.Errorf("generate new secret: %w", err) + } + + // Insert new key + err = tx.InsertKey(ctx, database.InsertKeyParams{ + Feature: key.Feature, + Sequence: key.Sequence + 1, + Secret: sql.NullString{ + String: secret, + Valid: true, + }, + StartsAt: newStartsAt, + }) + if err != nil { + return xerrors.Errorf("inserting new key: %w", err) + } + + // Update old key's deletes_at + maxTokenLength := tokenLength(key.Feature) + deletesAt := newStartsAt.Add(time.Hour).Add(maxTokenLength) + + err = tx.UpdateKeyDeletesAt(ctx, database.UpdateKeyDeletesAtParams{ + Feature: key.Feature, + Sequence: key.Sequence, + DeletesAt: sql.NullTime{ + Time: deletesAt, + Valid: true, + }, + }) + if err != nil { + return xerrors.Errorf("update old key's deletes_at: %w", err) + } + + return nil +} + +func generateNewSecret(feature string) (string, error) { + switch feature { + case "workspace_apps": + return generateKey(96) + case "oidc_convert": + return generateKey(32) + case "peer_reconnect": + return generateKey(64) + } + return "", xerrors.Errorf("unknown feature: %s", feature) +} + +func generateKey(length int) (string, error) { + b := make([]byte, length) + _, err := rand.Read(b) + if err != nil { + return "", xerrors.Errorf("rand read: %w", err) + } + return hex.EncodeToString(b), nil +} + +func tokenLength(feature string) time.Duration { + switch feature { + case "workspace_apps": + return WorkspaceAppsTokenDuration + case "oidc_convert": + return OIDCConvertTokenDuration + case "peer_reconnect": + return PeerReconnectTokenDuration + default: + return 0 + } +} + +func shouldDeleteKey(key database.Key, now time.Time) bool { + return key.DeletesAt.Valid && key.DeletesAt.Time.After(now) +} + +func shouldRotateKey(key database.Key, keyDuration time.Duration, now time.Time) bool { + expirationTime := key.ExpiresAt(keyDuration) + return now.Add(time.Hour).After(expirationTime) +} diff --git a/coderd/keyrotate/rotate_test.go b/coderd/keyrotate/rotate_test.go new file mode 100644 index 0000000000000..d7b48d32dc62d --- /dev/null +++ b/coderd/keyrotate/rotate_test.go @@ -0,0 +1,41 @@ +package keyrotate_test + +import ( + "context" + "testing" +) + +func TestKeyRotator(t *testing.T) { + t.Parallel() + + t.Run("RotatesKeysNearExpiration", func(t *testing.T) { + t.Parallel() + + db := database.NewTestDB(t) + kr := &KeyRotator{ + DB: db, + } + + kr.rotateKeys(context.Background()) + }) + + t.Run("DoesNotRotateValidKeys", func(t *testing.T) { + t.Parallel() + // TODO: Implement test to ensure valid keys are not rotated + }) + + t.Run("DeletesExpiredKeys", func(t *testing.T) { + t.Parallel() + // TODO: Implement test for deleting expired keys + }) + + t.Run("HandlesMultipleKeyTypes", func(t *testing.T) { + t.Parallel() + // TODO: Implement test for handling multiple key types + }) + + t.Run("GracefullyHandlesErrors", func(t *testing.T) { + t.Parallel() + // TODO: Implement test for error handling + }) +} From 1c7bfc8da34810836f0190b9ac55e81802d2649d Mon Sep 17 00:00:00 2001 From: Jon Ayers Date: Thu, 12 Sep 2024 16:32:55 +0000 Subject: [PATCH 2/5] refactor: replace keys with crypto_keys handling --- coderd/database/dbauthz/dbauthz.go | 44 ++--- coderd/database/dbmem/dbmem.go | 84 ++++----- coderd/database/dbmetrics/dbmetrics.go | 77 ++++---- coderd/database/dbmock/dbmock.go | 90 ++++++++++ coderd/database/dump.sql | 28 +-- ...s.down.sql => 000250_crypto_keys.down.sql} | 0 .../migrations/000250_crypto_keys.up.sql | 15 ++ coderd/database/migrations/000250_keys.up.sql | 8 - coderd/database/modelmethods.go | 2 +- coderd/database/models.go | 77 +++++++- coderd/database/querier.go | 11 +- coderd/database/queries.sql.go | 135 +++++++++----- coderd/database/queries/keys.sql | 35 ++-- coderd/database/unique_constraint.go | 2 +- coderd/keyrotate/rotate.go | 169 +++++++++++++----- coderd/keyrotate/rotate_test.go | 67 ++++++- 16 files changed, 605 insertions(+), 239 deletions(-) rename coderd/database/migrations/{000250_keys.down.sql => 000250_crypto_keys.down.sql} (100%) create mode 100644 coderd/database/migrations/000250_crypto_keys.up.sql delete mode 100644 coderd/database/migrations/000250_keys.up.sql diff --git a/coderd/database/dbauthz/dbauthz.go b/coderd/database/dbauthz/dbauthz.go index 8c25b64857976..d3bbceec14252 100644 --- a/coderd/database/dbauthz/dbauthz.go +++ b/coderd/database/dbauthz/dbauthz.go @@ -1041,6 +1041,10 @@ func (q *querier) DeleteCoordinator(ctx context.Context, id uuid.UUID) error { return q.db.DeleteCoordinator(ctx, id) } +func (q *querier) DeleteCryptoKey(ctx context.Context, arg database.DeleteCryptoKeyParams) (database.CryptoKey, error) { + panic("not implemented") +} + func (q *querier) DeleteCustomRole(ctx context.Context, arg database.DeleteCustomRoleParams) error { if arg.OrganizationID.UUID != uuid.Nil { if err := q.authorizeContext(ctx, policy.ActionDelete, rbac.ResourceAssignOrgRole.InOrg(arg.OrganizationID.UUID)); err != nil { @@ -1078,10 +1082,6 @@ func (q *querier) DeleteGroupMemberFromGroup(ctx context.Context, arg database.D return update(q.log, q.auth, fetch, q.db.DeleteGroupMemberFromGroup)(ctx, arg) } -func (q *querier) DeleteKey(ctx context.Context, arg database.DeleteKeyParams) error { - panic("not implemented") -} - func (q *querier) DeleteLicense(ctx context.Context, id int32) (int32, error) { err := deleteQ(q.log, q.auth, q.db.GetLicenseByID, func(ctx context.Context, id int32) error { _, err := q.db.DeleteLicense(ctx, id) @@ -1387,6 +1387,14 @@ func (q *querier) GetCoordinatorResumeTokenSigningKey(ctx context.Context) (stri return q.db.GetCoordinatorResumeTokenSigningKey(ctx) } +func (q *querier) GetCryptoKeyByFeatureAndSequence(ctx context.Context, arg database.GetCryptoKeyByFeatureAndSequenceParams) (database.CryptoKey, error) { + panic("not implemented") +} + +func (q *querier) GetCryptoKeys(ctx context.Context) ([]database.CryptoKey, error) { + panic("not implemented") +} + func (q *querier) GetDBCryptKeys(ctx context.Context) ([]database.DBCryptKey, error) { if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceSystem); err != nil { return nil, err @@ -1546,14 +1554,6 @@ func (q *querier) GetJFrogXrayScanByWorkspaceAndAgentID(ctx context.Context, arg return q.db.GetJFrogXrayScanByWorkspaceAndAgentID(ctx, arg) } -func (q *querier) GetKeyByFeatureAndSequence(ctx context.Context, arg database.GetKeyByFeatureAndSequenceParams) (database.Key, error) { - panic("not implemented") -} - -func (q *querier) GetKeys(ctx context.Context) ([]database.Key, error) { - panic("not implemented") -} - func (q *querier) GetLastUpdateCheck(ctx context.Context) (string, error) { if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceSystem); err != nil { return "", err @@ -1561,6 +1561,10 @@ func (q *querier) GetLastUpdateCheck(ctx context.Context) (string, error) { return q.db.GetLastUpdateCheck(ctx) } +func (q *querier) GetLatestCryptoKeyByFeature(ctx context.Context, feature database.CryptoKeyFeature) (database.CryptoKey, error) { + panic("not implemented") +} + func (q *querier) GetLatestWorkspaceBuildByWorkspaceID(ctx context.Context, workspaceID uuid.UUID) (database.WorkspaceBuild, error) { if _, err := q.GetWorkspaceByID(ctx, workspaceID); err != nil { return database.WorkspaceBuild{}, err @@ -2666,6 +2670,10 @@ func (q *querier) InsertAuditLog(ctx context.Context, arg database.InsertAuditLo return insert(q.log, q.auth, rbac.ResourceAuditLog, q.db.InsertAuditLog)(ctx, arg) } +func (q *querier) InsertCryptoKey(ctx context.Context, arg database.InsertCryptoKeyParams) (database.CryptoKey, error) { + panic("not implemented") +} + func (q *querier) InsertCustomRole(ctx context.Context, arg database.InsertCustomRoleParams) (database.CustomRole, error) { // Org and site role upsert share the same query. So switch the assertion based on the org uuid. if arg.OrganizationID.UUID != uuid.Nil { @@ -2738,10 +2746,6 @@ func (q *querier) InsertGroupMember(ctx context.Context, arg database.InsertGrou return update(q.log, q.auth, fetch, q.db.InsertGroupMember)(ctx, arg) } -func (q *querier) InsertKey(ctx context.Context, arg database.InsertKeyParams) error { - panic("not implemented") -} - func (q *querier) InsertLicense(ctx context.Context, arg database.InsertLicenseParams) (database.License, error) { if err := q.authorizeContext(ctx, policy.ActionCreate, rbac.ResourceLicense); err != nil { return database.License{}, err @@ -3173,6 +3177,10 @@ func (q *querier) UpdateAPIKeyByID(ctx context.Context, arg database.UpdateAPIKe return update(q.log, q.auth, fetch, q.db.UpdateAPIKeyByID)(ctx, arg) } +func (q *querier) UpdateCryptoKeyDeletesAt(ctx context.Context, arg database.UpdateCryptoKeyDeletesAtParams) (database.CryptoKey, error) { + panic("not implemented") +} + func (q *querier) UpdateCustomRole(ctx context.Context, arg database.UpdateCustomRoleParams) (database.CustomRole, error) { if arg.OrganizationID.UUID != uuid.Nil { if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceAssignOrgRole.InOrg(arg.OrganizationID.UUID)); err != nil { @@ -3228,10 +3236,6 @@ func (q *querier) UpdateInactiveUsersToDormant(ctx context.Context, lastSeenAfte return q.db.UpdateInactiveUsersToDormant(ctx, lastSeenAfter) } -func (q *querier) UpdateKeyDeletesAt(ctx context.Context, arg database.UpdateKeyDeletesAtParams) error { - panic("not implemented") -} - func (q *querier) UpdateMemberRoles(ctx context.Context, arg database.UpdateMemberRolesParams) (database.OrganizationMember, error) { // Authorized fetch will check that the actor has read access to the org member since the org member is returned. member, err := database.ExpectOne(q.OrganizationMembers(ctx, database.OrganizationMembersParams{ diff --git a/coderd/database/dbmem/dbmem.go b/coderd/database/dbmem/dbmem.go index f0b1a80325f38..d3a4342b9c7a8 100644 --- a/coderd/database/dbmem/dbmem.go +++ b/coderd/database/dbmem/dbmem.go @@ -1434,6 +1434,15 @@ func (*FakeQuerier) DeleteCoordinator(context.Context, uuid.UUID) error { return ErrUnimplemented } +func (q *FakeQuerier) DeleteCryptoKey(ctx context.Context, arg database.DeleteCryptoKeyParams) (database.CryptoKey, error) { + err := validateDatabaseType(arg) + if err != nil { + return database.CryptoKey{}, err + } + + panic("not implemented") +} + func (q *FakeQuerier) DeleteCustomRole(_ context.Context, arg database.DeleteCustomRoleParams) error { err := validateDatabaseType(arg) if err != nil { @@ -1527,15 +1536,6 @@ func (q *FakeQuerier) DeleteGroupMemberFromGroup(_ context.Context, arg database return nil } -func (q *FakeQuerier) DeleteKey(ctx context.Context, arg database.DeleteKeyParams) error { - err := validateDatabaseType(arg) - if err != nil { - return err - } - - panic("not implemented") -} - func (q *FakeQuerier) DeleteLicense(_ context.Context, id int32) (int32, error) { q.mutex.Lock() defer q.mutex.Unlock() @@ -2318,6 +2318,19 @@ func (q *FakeQuerier) GetCoordinatorResumeTokenSigningKey(_ context.Context) (st return q.coordinatorResumeTokenSigningKey, nil } +func (q *FakeQuerier) GetCryptoKeyByFeatureAndSequence(ctx context.Context, arg database.GetCryptoKeyByFeatureAndSequenceParams) (database.CryptoKey, error) { + err := validateDatabaseType(arg) + if err != nil { + return database.CryptoKey{}, err + } + + panic("not implemented") +} + +func (q *FakeQuerier) GetCryptoKeys(ctx context.Context) ([]database.CryptoKey, error) { + panic("not implemented") +} + func (q *FakeQuerier) GetDBCryptKeys(_ context.Context) ([]database.DBCryptKey, error) { q.mutex.RLock() defer q.mutex.RUnlock() @@ -2805,19 +2818,6 @@ func (q *FakeQuerier) GetJFrogXrayScanByWorkspaceAndAgentID(_ context.Context, a return database.JfrogXrayScan{}, sql.ErrNoRows } -func (q *FakeQuerier) GetKeyByFeatureAndSequence(ctx context.Context, arg database.GetKeyByFeatureAndSequenceParams) (database.Key, error) { - err := validateDatabaseType(arg) - if err != nil { - return database.Key{}, err - } - - panic("not implemented") -} - -func (q *FakeQuerier) GetKeys(ctx context.Context) ([]database.Key, error) { - panic("not implemented") -} - func (q *FakeQuerier) GetLastUpdateCheck(_ context.Context) (string, error) { q.mutex.RLock() defer q.mutex.RUnlock() @@ -2828,6 +2828,10 @@ func (q *FakeQuerier) GetLastUpdateCheck(_ context.Context) (string, error) { return string(q.lastUpdateCheck), nil } +func (q *FakeQuerier) GetLatestCryptoKeyByFeature(ctx context.Context, feature database.CryptoKeyFeature) (database.CryptoKey, error) { + panic("not implemented") +} + func (q *FakeQuerier) GetLatestWorkspaceBuildByWorkspaceID(ctx context.Context, workspaceID uuid.UUID) (database.WorkspaceBuild, error) { q.mutex.RLock() defer q.mutex.RUnlock() @@ -6327,6 +6331,15 @@ func (q *FakeQuerier) InsertAuditLog(_ context.Context, arg database.InsertAudit return alog, nil } +func (q *FakeQuerier) InsertCryptoKey(ctx context.Context, arg database.InsertCryptoKeyParams) (database.CryptoKey, error) { + err := validateDatabaseType(arg) + if err != nil { + return database.CryptoKey{}, err + } + + panic("not implemented") +} + func (q *FakeQuerier) InsertCustomRole(_ context.Context, arg database.InsertCustomRoleParams) (database.CustomRole, error) { err := validateDatabaseType(arg) if err != nil { @@ -6514,15 +6527,6 @@ func (q *FakeQuerier) InsertGroupMember(_ context.Context, arg database.InsertGr return nil } -func (q *FakeQuerier) InsertKey(ctx context.Context, arg database.InsertKeyParams) error { - err := validateDatabaseType(arg) - if err != nil { - return err - } - - panic("not implemented") -} - func (q *FakeQuerier) InsertLicense( _ context.Context, arg database.InsertLicenseParams, ) (database.License, error) { @@ -7805,6 +7809,15 @@ func (q *FakeQuerier) UpdateAPIKeyByID(_ context.Context, arg database.UpdateAPI return sql.ErrNoRows } +func (q *FakeQuerier) UpdateCryptoKeyDeletesAt(ctx context.Context, arg database.UpdateCryptoKeyDeletesAtParams) (database.CryptoKey, error) { + err := validateDatabaseType(arg) + if err != nil { + return database.CryptoKey{}, err + } + + panic("not implemented") +} + func (q *FakeQuerier) UpdateCustomRole(_ context.Context, arg database.UpdateCustomRoleParams) (database.CustomRole, error) { err := validateDatabaseType(arg) if err != nil { @@ -7921,15 +7934,6 @@ func (q *FakeQuerier) UpdateInactiveUsersToDormant(_ context.Context, params dat return updated, nil } -func (q *FakeQuerier) UpdateKeyDeletesAt(ctx context.Context, arg database.UpdateKeyDeletesAtParams) error { - err := validateDatabaseType(arg) - if err != nil { - return err - } - - panic("not implemented") -} - func (q *FakeQuerier) UpdateMemberRoles(_ context.Context, arg database.UpdateMemberRolesParams) (database.OrganizationMember, error) { if err := validateDatabaseType(arg); err != nil { return database.OrganizationMember{}, err diff --git a/coderd/database/dbmetrics/dbmetrics.go b/coderd/database/dbmetrics/dbmetrics.go index 22ce3b085fed9..bf95ad82896d8 100644 --- a/coderd/database/dbmetrics/dbmetrics.go +++ b/coderd/database/dbmetrics/dbmetrics.go @@ -214,6 +214,13 @@ func (m metricsStore) DeleteCoordinator(ctx context.Context, id uuid.UUID) error return r0 } +func (m metricsStore) DeleteCryptoKey(ctx context.Context, arg database.DeleteCryptoKeyParams) (database.CryptoKey, error) { + start := time.Now() + r0, r1 := m.s.DeleteCryptoKey(ctx, arg) + m.queryLatencies.WithLabelValues("DeleteCryptoKey").Observe(time.Since(start).Seconds()) + return r0, r1 +} + func (m metricsStore) DeleteCustomRole(ctx context.Context, arg database.DeleteCustomRoleParams) error { start := time.Now() r0 := m.s.DeleteCustomRole(ctx, arg) @@ -249,13 +256,6 @@ func (m metricsStore) DeleteGroupMemberFromGroup(ctx context.Context, arg databa return err } -func (m metricsStore) DeleteKey(ctx context.Context, arg database.DeleteKeyParams) error { - start := time.Now() - r0 := m.s.DeleteKey(ctx, arg) - m.queryLatencies.WithLabelValues("DeleteKey").Observe(time.Since(start).Seconds()) - return r0 -} - func (m metricsStore) DeleteLicense(ctx context.Context, id int32) (int32, error) { start := time.Now() licenseID, err := m.s.DeleteLicense(ctx, id) @@ -550,6 +550,20 @@ func (m metricsStore) GetCoordinatorResumeTokenSigningKey(ctx context.Context) ( return r0, r1 } +func (m metricsStore) GetCryptoKeyByFeatureAndSequence(ctx context.Context, arg database.GetCryptoKeyByFeatureAndSequenceParams) (database.CryptoKey, error) { + start := time.Now() + r0, r1 := m.s.GetCryptoKeyByFeatureAndSequence(ctx, arg) + m.queryLatencies.WithLabelValues("GetCryptoKeyByFeatureAndSequence").Observe(time.Since(start).Seconds()) + return r0, r1 +} + +func (m metricsStore) GetCryptoKeys(ctx context.Context) ([]database.CryptoKey, error) { + start := time.Now() + r0, r1 := m.s.GetCryptoKeys(ctx) + m.queryLatencies.WithLabelValues("GetCryptoKeys").Observe(time.Since(start).Seconds()) + return r0, r1 +} + func (m metricsStore) GetDBCryptKeys(ctx context.Context) ([]database.DBCryptKey, error) { start := time.Now() r0, r1 := m.s.GetDBCryptKeys(ctx) @@ -711,20 +725,6 @@ func (m metricsStore) GetJFrogXrayScanByWorkspaceAndAgentID(ctx context.Context, return r0, r1 } -func (m metricsStore) GetKeyByFeatureAndSequence(ctx context.Context, arg database.GetKeyByFeatureAndSequenceParams) (database.Key, error) { - start := time.Now() - r0, r1 := m.s.GetKeyByFeatureAndSequence(ctx, arg) - m.queryLatencies.WithLabelValues("GetKeyByFeatureAndSequence").Observe(time.Since(start).Seconds()) - return r0, r1 -} - -func (m metricsStore) GetKeys(ctx context.Context) ([]database.Key, error) { - start := time.Now() - r0, r1 := m.s.GetKeys(ctx) - m.queryLatencies.WithLabelValues("GetKeys").Observe(time.Since(start).Seconds()) - return r0, r1 -} - func (m metricsStore) GetLastUpdateCheck(ctx context.Context) (string, error) { start := time.Now() version, err := m.s.GetLastUpdateCheck(ctx) @@ -732,6 +732,13 @@ func (m metricsStore) GetLastUpdateCheck(ctx context.Context) (string, error) { return version, err } +func (m metricsStore) GetLatestCryptoKeyByFeature(ctx context.Context, feature database.CryptoKeyFeature) (database.CryptoKey, error) { + start := time.Now() + r0, r1 := m.s.GetLatestCryptoKeyByFeature(ctx, feature) + m.queryLatencies.WithLabelValues("GetLatestCryptoKeyByFeature").Observe(time.Since(start).Seconds()) + return r0, r1 +} + func (m metricsStore) GetLatestWorkspaceBuildByWorkspaceID(ctx context.Context, workspaceID uuid.UUID) (database.WorkspaceBuild, error) { start := time.Now() build, err := m.s.GetLatestWorkspaceBuildByWorkspaceID(ctx, workspaceID) @@ -1614,6 +1621,13 @@ func (m metricsStore) InsertAuditLog(ctx context.Context, arg database.InsertAud return log, err } +func (m metricsStore) InsertCryptoKey(ctx context.Context, arg database.InsertCryptoKeyParams) (database.CryptoKey, error) { + start := time.Now() + key, err := m.s.InsertCryptoKey(ctx, arg) + m.queryLatencies.WithLabelValues("InsertCryptoKey").Observe(time.Since(start).Seconds()) + return key, err +} + func (m metricsStore) InsertCustomRole(ctx context.Context, arg database.InsertCustomRoleParams) (database.CustomRole, error) { start := time.Now() r0, r1 := m.s.InsertCustomRole(ctx, arg) @@ -1677,13 +1691,6 @@ func (m metricsStore) InsertGroupMember(ctx context.Context, arg database.Insert return err } -func (m metricsStore) InsertKey(ctx context.Context, arg database.InsertKeyParams) error { - start := time.Now() - r0 := m.s.InsertKey(ctx, arg) - m.queryLatencies.WithLabelValues("InsertKey").Observe(time.Since(start).Seconds()) - return r0 -} - func (m metricsStore) InsertLicense(ctx context.Context, arg database.InsertLicenseParams) (database.License, error) { start := time.Now() license, err := m.s.InsertLicense(ctx, arg) @@ -2020,6 +2027,13 @@ func (m metricsStore) UpdateAPIKeyByID(ctx context.Context, arg database.UpdateA return err } +func (m metricsStore) UpdateCryptoKeyDeletesAt(ctx context.Context, arg database.UpdateCryptoKeyDeletesAtParams) (database.CryptoKey, error) { + start := time.Now() + key, err := m.s.UpdateCryptoKeyDeletesAt(ctx, arg) + m.queryLatencies.WithLabelValues("UpdateCryptoKeyDeletesAt").Observe(time.Since(start).Seconds()) + return key, err +} + func (m metricsStore) UpdateCustomRole(ctx context.Context, arg database.UpdateCustomRoleParams) (database.CustomRole, error) { start := time.Now() r0, r1 := m.s.UpdateCustomRole(ctx, arg) @@ -2055,13 +2069,6 @@ func (m metricsStore) UpdateInactiveUsersToDormant(ctx context.Context, lastSeen return r0, r1 } -func (m metricsStore) UpdateKeyDeletesAt(ctx context.Context, arg database.UpdateKeyDeletesAtParams) error { - start := time.Now() - r0 := m.s.UpdateKeyDeletesAt(ctx, arg) - m.queryLatencies.WithLabelValues("UpdateKeyDeletesAt").Observe(time.Since(start).Seconds()) - return r0 -} - func (m metricsStore) UpdateMemberRoles(ctx context.Context, arg database.UpdateMemberRolesParams) (database.OrganizationMember, error) { start := time.Now() member, err := m.s.UpdateMemberRoles(ctx, arg) diff --git a/coderd/database/dbmock/dbmock.go b/coderd/database/dbmock/dbmock.go index c5d579e1c2656..0ab399f573bfe 100644 --- a/coderd/database/dbmock/dbmock.go +++ b/coderd/database/dbmock/dbmock.go @@ -317,6 +317,21 @@ func (mr *MockStoreMockRecorder) DeleteCoordinator(arg0, arg1 any) *gomock.Call return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteCoordinator", reflect.TypeOf((*MockStore)(nil).DeleteCoordinator), arg0, arg1) } +// DeleteCryptoKey mocks base method. +func (m *MockStore) DeleteCryptoKey(arg0 context.Context, arg1 database.DeleteCryptoKeyParams) (database.CryptoKey, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "DeleteCryptoKey", arg0, arg1) + ret0, _ := ret[0].(database.CryptoKey) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// DeleteCryptoKey indicates an expected call of DeleteCryptoKey. +func (mr *MockStoreMockRecorder) DeleteCryptoKey(arg0, arg1 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteCryptoKey", reflect.TypeOf((*MockStore)(nil).DeleteCryptoKey), arg0, arg1) +} + // DeleteCustomRole mocks base method. func (m *MockStore) DeleteCustomRole(arg0 context.Context, arg1 database.DeleteCustomRoleParams) error { m.ctrl.T.Helper() @@ -1058,6 +1073,36 @@ func (mr *MockStoreMockRecorder) GetCoordinatorResumeTokenSigningKey(arg0 any) * return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetCoordinatorResumeTokenSigningKey", reflect.TypeOf((*MockStore)(nil).GetCoordinatorResumeTokenSigningKey), arg0) } +// GetCryptoKeyByFeatureAndSequence mocks base method. +func (m *MockStore) GetCryptoKeyByFeatureAndSequence(arg0 context.Context, arg1 database.GetCryptoKeyByFeatureAndSequenceParams) (database.CryptoKey, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetCryptoKeyByFeatureAndSequence", arg0, arg1) + ret0, _ := ret[0].(database.CryptoKey) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetCryptoKeyByFeatureAndSequence indicates an expected call of GetCryptoKeyByFeatureAndSequence. +func (mr *MockStoreMockRecorder) GetCryptoKeyByFeatureAndSequence(arg0, arg1 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetCryptoKeyByFeatureAndSequence", reflect.TypeOf((*MockStore)(nil).GetCryptoKeyByFeatureAndSequence), arg0, arg1) +} + +// GetCryptoKeys mocks base method. +func (m *MockStore) GetCryptoKeys(arg0 context.Context) ([]database.CryptoKey, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetCryptoKeys", arg0) + ret0, _ := ret[0].([]database.CryptoKey) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetCryptoKeys indicates an expected call of GetCryptoKeys. +func (mr *MockStoreMockRecorder) GetCryptoKeys(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetCryptoKeys", reflect.TypeOf((*MockStore)(nil).GetCryptoKeys), arg0) +} + // GetDBCryptKeys mocks base method. func (m *MockStore) GetDBCryptKeys(arg0 context.Context) ([]database.DBCryptKey, error) { m.ctrl.T.Helper() @@ -1418,6 +1463,21 @@ func (mr *MockStoreMockRecorder) GetLastUpdateCheck(arg0 any) *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetLastUpdateCheck", reflect.TypeOf((*MockStore)(nil).GetLastUpdateCheck), arg0) } +// GetLatestCryptoKeyByFeature mocks base method. +func (m *MockStore) GetLatestCryptoKeyByFeature(arg0 context.Context, arg1 database.CryptoKeyFeature) (database.CryptoKey, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetLatestCryptoKeyByFeature", arg0, arg1) + ret0, _ := ret[0].(database.CryptoKey) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetLatestCryptoKeyByFeature indicates an expected call of GetLatestCryptoKeyByFeature. +func (mr *MockStoreMockRecorder) GetLatestCryptoKeyByFeature(arg0, arg1 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetLatestCryptoKeyByFeature", reflect.TypeOf((*MockStore)(nil).GetLatestCryptoKeyByFeature), arg0, arg1) +} + // GetLatestWorkspaceBuildByWorkspaceID mocks base method. func (m *MockStore) GetLatestWorkspaceBuildByWorkspaceID(arg0 context.Context, arg1 uuid.UUID) (database.WorkspaceBuild, error) { m.ctrl.T.Helper() @@ -3352,6 +3412,21 @@ func (mr *MockStoreMockRecorder) InsertAuditLog(arg0, arg1 any) *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InsertAuditLog", reflect.TypeOf((*MockStore)(nil).InsertAuditLog), arg0, arg1) } +// InsertCryptoKey mocks base method. +func (m *MockStore) InsertCryptoKey(arg0 context.Context, arg1 database.InsertCryptoKeyParams) (database.CryptoKey, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "InsertCryptoKey", arg0, arg1) + ret0, _ := ret[0].(database.CryptoKey) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// InsertCryptoKey indicates an expected call of InsertCryptoKey. +func (mr *MockStoreMockRecorder) InsertCryptoKey(arg0, arg1 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InsertCryptoKey", reflect.TypeOf((*MockStore)(nil).InsertCryptoKey), arg0, arg1) +} + // InsertCustomRole mocks base method. func (m *MockStore) InsertCustomRole(arg0 context.Context, arg1 database.InsertCustomRoleParams) (database.CustomRole, error) { m.ctrl.T.Helper() @@ -4204,6 +4279,21 @@ func (mr *MockStoreMockRecorder) UpdateAPIKeyByID(arg0, arg1 any) *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateAPIKeyByID", reflect.TypeOf((*MockStore)(nil).UpdateAPIKeyByID), arg0, arg1) } +// UpdateCryptoKeyDeletesAt mocks base method. +func (m *MockStore) UpdateCryptoKeyDeletesAt(arg0 context.Context, arg1 database.UpdateCryptoKeyDeletesAtParams) (database.CryptoKey, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "UpdateCryptoKeyDeletesAt", arg0, arg1) + ret0, _ := ret[0].(database.CryptoKey) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// UpdateCryptoKeyDeletesAt indicates an expected call of UpdateCryptoKeyDeletesAt. +func (mr *MockStoreMockRecorder) UpdateCryptoKeyDeletesAt(arg0, arg1 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateCryptoKeyDeletesAt", reflect.TypeOf((*MockStore)(nil).UpdateCryptoKeyDeletesAt), arg0, arg1) +} + // UpdateCustomRole mocks base method. func (m *MockStore) UpdateCustomRole(arg0 context.Context, arg1 database.UpdateCustomRoleParams) (database.CustomRole, error) { m.ctrl.T.Helper() diff --git a/coderd/database/dump.sql b/coderd/database/dump.sql index bef8c1f9376af..71f21a3d5a75a 100644 --- a/coderd/database/dump.sql +++ b/coderd/database/dump.sql @@ -36,6 +36,12 @@ CREATE TYPE build_reason AS ENUM ( 'autodelete' ); +CREATE TYPE crypto_key_feature AS ENUM ( + 'workspace_apps', + 'oidc_convert', + 'peer_reconnect' +); + CREATE TYPE display_app AS ENUM ( 'vscode', 'vscode_insiders', @@ -494,6 +500,14 @@ CREATE TABLE audit_logs ( resource_icon text NOT NULL ); +CREATE TABLE crypto_keys ( + feature crypto_key_feature NOT NULL, + sequence integer NOT NULL, + secret text, + starts_at timestamp with time zone NOT NULL, + deletes_at timestamp with time zone +); + CREATE TABLE custom_roles ( name text NOT NULL, display_name text NOT NULL, @@ -668,14 +682,6 @@ CREATE TABLE jfrog_xray_scans ( results_url text DEFAULT ''::text NOT NULL ); -CREATE TABLE keys ( - feature text NOT NULL, - sequence integer NOT NULL, - secret text, - starts_at timestamp with time zone NOT NULL, - deletes_at timestamp with time zone -); - CREATE TABLE licenses ( id integer NOT NULL, uploaded_at timestamp with time zone NOT NULL, @@ -1648,6 +1654,9 @@ ALTER TABLE ONLY api_keys ALTER TABLE ONLY audit_logs ADD CONSTRAINT audit_logs_pkey PRIMARY KEY (id); +ALTER TABLE ONLY crypto_keys + ADD CONSTRAINT crypto_keys_pkey PRIMARY KEY (feature, sequence); + ALTER TABLE ONLY custom_roles ADD CONSTRAINT custom_roles_unique_key UNIQUE (name, organization_id); @@ -1684,9 +1693,6 @@ ALTER TABLE ONLY groups ALTER TABLE ONLY jfrog_xray_scans ADD CONSTRAINT jfrog_xray_scans_pkey PRIMARY KEY (agent_id, workspace_id); -ALTER TABLE ONLY keys - ADD CONSTRAINT keys_pkey PRIMARY KEY (feature, sequence); - ALTER TABLE ONLY licenses ADD CONSTRAINT licenses_jwt_key UNIQUE (jwt); diff --git a/coderd/database/migrations/000250_keys.down.sql b/coderd/database/migrations/000250_crypto_keys.down.sql similarity index 100% rename from coderd/database/migrations/000250_keys.down.sql rename to coderd/database/migrations/000250_crypto_keys.down.sql diff --git a/coderd/database/migrations/000250_crypto_keys.up.sql b/coderd/database/migrations/000250_crypto_keys.up.sql new file mode 100644 index 0000000000000..f6f1457e9eb92 --- /dev/null +++ b/coderd/database/migrations/000250_crypto_keys.up.sql @@ -0,0 +1,15 @@ +CREATE TYPE "crypto_key_feature" AS ENUM ( + 'workspace_apps', + 'oidc_convert', + 'peer_reconnect' +); + +CREATE TABLE "crypto_keys" ( + "feature" "crypto_key_feature" NOT NULL, + "sequence" integer NOT NULL, + "secret" text NULL, + "starts_at" timestamptz NOT NULL, + "deletes_at" timestamptz NULL, + PRIMARY KEY ("feature", "sequence") +); + diff --git a/coderd/database/migrations/000250_keys.up.sql b/coderd/database/migrations/000250_keys.up.sql deleted file mode 100644 index 67f6cf8f14a29..0000000000000 --- a/coderd/database/migrations/000250_keys.up.sql +++ /dev/null @@ -1,8 +0,0 @@ -CREATE TABLE "keys" ( - "feature" text NOT NULL, - "sequence" integer NOT NULL, - "secret" text NULL, - "starts_at" timestamptz NOT NULL, - "deletes_at" timestamptz NULL, - PRIMARY KEY ("feature", "sequence") -); diff --git a/coderd/database/modelmethods.go b/coderd/database/modelmethods.go index 5698fa92cdb27..8b5a18cc90c3a 100644 --- a/coderd/database/modelmethods.go +++ b/coderd/database/modelmethods.go @@ -448,6 +448,6 @@ func (r GetAuthorizationUserRolesRow) RoleNames() ([]rbac.RoleIdentifier, error) return names, nil } -func (k Key) ExpiresAt(keyDuration time.Duration) time.Time { +func (k CryptoKey) ExpiresAt(keyDuration time.Duration) time.Time { return k.StartsAt.Add(keyDuration) } diff --git a/coderd/database/models.go b/coderd/database/models.go index 3642cef53f405..b2217d96f2978 100644 --- a/coderd/database/models.go +++ b/coderd/database/models.go @@ -339,6 +339,67 @@ func AllBuildReasonValues() []BuildReason { } } +type CryptoKeyFeature string + +const ( + CryptoKeyFeatureWorkspaceApps CryptoKeyFeature = "workspace_apps" + CryptoKeyFeatureOidcConvert CryptoKeyFeature = "oidc_convert" + CryptoKeyFeaturePeerReconnect CryptoKeyFeature = "peer_reconnect" +) + +func (e *CryptoKeyFeature) Scan(src interface{}) error { + switch s := src.(type) { + case []byte: + *e = CryptoKeyFeature(s) + case string: + *e = CryptoKeyFeature(s) + default: + return fmt.Errorf("unsupported scan type for CryptoKeyFeature: %T", src) + } + return nil +} + +type NullCryptoKeyFeature struct { + CryptoKeyFeature CryptoKeyFeature `json:"crypto_key_feature"` + Valid bool `json:"valid"` // Valid is true if CryptoKeyFeature is not NULL +} + +// Scan implements the Scanner interface. +func (ns *NullCryptoKeyFeature) Scan(value interface{}) error { + if value == nil { + ns.CryptoKeyFeature, ns.Valid = "", false + return nil + } + ns.Valid = true + return ns.CryptoKeyFeature.Scan(value) +} + +// Value implements the driver Valuer interface. +func (ns NullCryptoKeyFeature) Value() (driver.Value, error) { + if !ns.Valid { + return nil, nil + } + return string(ns.CryptoKeyFeature), nil +} + +func (e CryptoKeyFeature) Valid() bool { + switch e { + case CryptoKeyFeatureWorkspaceApps, + CryptoKeyFeatureOidcConvert, + CryptoKeyFeaturePeerReconnect: + return true + } + return false +} + +func AllCryptoKeyFeatureValues() []CryptoKeyFeature { + return []CryptoKeyFeature{ + CryptoKeyFeatureWorkspaceApps, + CryptoKeyFeatureOidcConvert, + CryptoKeyFeaturePeerReconnect, + } +} + type DisplayApp string const ( @@ -2043,6 +2104,14 @@ type AuditLog struct { ResourceIcon string `db:"resource_icon" json:"resource_icon"` } +type CryptoKey struct { + Feature CryptoKeyFeature `db:"feature" json:"feature"` + Sequence int32 `db:"sequence" json:"sequence"` + Secret sql.NullString `db:"secret" json:"secret"` + StartsAt time.Time `db:"starts_at" json:"starts_at"` + DeletesAt sql.NullTime `db:"deletes_at" json:"deletes_at"` +} + // Custom roles allow dynamic roles expanded at runtime type CustomRole struct { Name string `db:"name" json:"name"` @@ -2155,14 +2224,6 @@ type JfrogXrayScan struct { ResultsUrl string `db:"results_url" json:"results_url"` } -type Key struct { - Feature string `db:"feature" json:"feature"` - Sequence int32 `db:"sequence" json:"sequence"` - Secret sql.NullString `db:"secret" json:"secret"` - StartsAt time.Time `db:"starts_at" json:"starts_at"` - DeletesAt sql.NullTime `db:"deletes_at" json:"deletes_at"` -} - type License struct { ID int32 `db:"id" json:"id"` UploadedAt time.Time `db:"uploaded_at" json:"uploaded_at"` diff --git a/coderd/database/querier.go b/coderd/database/querier.go index c342c00581ad4..8e8f587d302c8 100644 --- a/coderd/database/querier.go +++ b/coderd/database/querier.go @@ -69,12 +69,12 @@ type sqlcQuerier interface { DeleteAllTailnetTunnels(ctx context.Context, arg DeleteAllTailnetTunnelsParams) error DeleteApplicationConnectAPIKeysByUserID(ctx context.Context, userID uuid.UUID) error DeleteCoordinator(ctx context.Context, id uuid.UUID) error + DeleteCryptoKey(ctx context.Context, arg DeleteCryptoKeyParams) (CryptoKey, error) DeleteCustomRole(ctx context.Context, arg DeleteCustomRoleParams) error DeleteExternalAuthLink(ctx context.Context, arg DeleteExternalAuthLinkParams) error DeleteGitSSHKey(ctx context.Context, userID uuid.UUID) error DeleteGroupByID(ctx context.Context, id uuid.UUID) error DeleteGroupMemberFromGroup(ctx context.Context, arg DeleteGroupMemberFromGroupParams) error - DeleteKey(ctx context.Context, arg DeleteKeyParams) error DeleteLicense(ctx context.Context, id int32) (int32, error) DeleteOAuth2ProviderAppByID(ctx context.Context, id uuid.UUID) error DeleteOAuth2ProviderAppCodeByID(ctx context.Context, id uuid.UUID) error @@ -132,6 +132,8 @@ type sqlcQuerier interface { // are included. GetAuthorizationUserRoles(ctx context.Context, userID uuid.UUID) (GetAuthorizationUserRolesRow, error) GetCoordinatorResumeTokenSigningKey(ctx context.Context) (string, error) + GetCryptoKeyByFeatureAndSequence(ctx context.Context, arg GetCryptoKeyByFeatureAndSequenceParams) (CryptoKey, error) + GetCryptoKeys(ctx context.Context) ([]CryptoKey, error) GetDBCryptKeys(ctx context.Context) ([]DBCryptKey, error) GetDERPMeshKey(ctx context.Context) (string, error) GetDefaultOrganization(ctx context.Context) (Organization, error) @@ -159,9 +161,8 @@ type sqlcQuerier interface { GetHealthSettings(ctx context.Context) (string, error) GetHungProvisionerJobs(ctx context.Context, updatedAt time.Time) ([]ProvisionerJob, error) GetJFrogXrayScanByWorkspaceAndAgentID(ctx context.Context, arg GetJFrogXrayScanByWorkspaceAndAgentIDParams) (JfrogXrayScan, error) - GetKeyByFeatureAndSequence(ctx context.Context, arg GetKeyByFeatureAndSequenceParams) (Key, error) - GetKeys(ctx context.Context) ([]Key, error) GetLastUpdateCheck(ctx context.Context) (string, error) + GetLatestCryptoKeyByFeature(ctx context.Context, feature CryptoKeyFeature) (CryptoKey, error) GetLatestWorkspaceBuildByWorkspaceID(ctx context.Context, workspaceID uuid.UUID) (WorkspaceBuild, error) GetLatestWorkspaceBuilds(ctx context.Context) ([]WorkspaceBuild, error) GetLatestWorkspaceBuildsByWorkspaceIDs(ctx context.Context, ids []uuid.UUID) ([]WorkspaceBuild, error) @@ -340,6 +341,7 @@ type sqlcQuerier interface { // every member of the org. InsertAllUsersGroup(ctx context.Context, organizationID uuid.UUID) (Group, error) InsertAuditLog(ctx context.Context, arg InsertAuditLogParams) (AuditLog, error) + InsertCryptoKey(ctx context.Context, arg InsertCryptoKeyParams) (CryptoKey, error) InsertCustomRole(ctx context.Context, arg InsertCustomRoleParams) (CustomRole, error) InsertDBCryptKey(ctx context.Context, arg InsertDBCryptKeyParams) error InsertDERPMeshKey(ctx context.Context, value string) error @@ -349,7 +351,6 @@ type sqlcQuerier interface { InsertGitSSHKey(ctx context.Context, arg InsertGitSSHKeyParams) (GitSSHKey, error) InsertGroup(ctx context.Context, arg InsertGroupParams) (Group, error) InsertGroupMember(ctx context.Context, arg InsertGroupMemberParams) error - InsertKey(ctx context.Context, arg InsertKeyParams) error InsertLicense(ctx context.Context, arg InsertLicenseParams) (License, error) // Inserts any group by name that does not exist. All new groups are given // a random uuid, are inserted into the same organization. They have the default @@ -414,12 +415,12 @@ type sqlcQuerier interface { UnarchiveTemplateVersion(ctx context.Context, arg UnarchiveTemplateVersionParams) error UnfavoriteWorkspace(ctx context.Context, id uuid.UUID) error UpdateAPIKeyByID(ctx context.Context, arg UpdateAPIKeyByIDParams) error + UpdateCryptoKeyDeletesAt(ctx context.Context, arg UpdateCryptoKeyDeletesAtParams) (CryptoKey, error) UpdateCustomRole(ctx context.Context, arg UpdateCustomRoleParams) (CustomRole, error) UpdateExternalAuthLink(ctx context.Context, arg UpdateExternalAuthLinkParams) (ExternalAuthLink, error) UpdateGitSSHKey(ctx context.Context, arg UpdateGitSSHKeyParams) (GitSSHKey, error) UpdateGroupByID(ctx context.Context, arg UpdateGroupByIDParams) (Group, error) UpdateInactiveUsersToDormant(ctx context.Context, arg UpdateInactiveUsersToDormantParams) ([]UpdateInactiveUsersToDormantRow, error) - UpdateKeyDeletesAt(ctx context.Context, arg UpdateKeyDeletesAtParams) error UpdateMemberRoles(ctx context.Context, arg UpdateMemberRolesParams) (OrganizationMember, error) UpdateNotificationTemplateMethodByID(ctx context.Context, arg UpdateNotificationTemplateMethodByIDParams) (NotificationTemplate, error) UpdateOAuth2ProviderAppByID(ctx context.Context, arg UpdateOAuth2ProviderAppByIDParams) (OAuth2ProviderApp, error) diff --git a/coderd/database/queries.sql.go b/coderd/database/queries.sql.go index de31a01a189aa..19fd2a59354b7 100644 --- a/coderd/database/queries.sql.go +++ b/coderd/database/queries.sql.go @@ -3164,25 +3164,33 @@ func (q *sqlQuerier) UpsertJFrogXrayScanByWorkspaceAndAgentID(ctx context.Contex return err } -const deleteKey = `-- name: DeleteKey :exec -UPDATE keys +const deleteCryptoKey = `-- name: DeleteCryptoKey :one +UPDATE crypto_keys SET secret = NULL -WHERE feature = $1 AND sequence = $2 +WHERE feature = $1 AND sequence = $2 RETURNING feature, sequence, secret, starts_at, deletes_at ` -type DeleteKeyParams struct { - Feature string `db:"feature" json:"feature"` - Sequence int32 `db:"sequence" json:"sequence"` +type DeleteCryptoKeyParams struct { + Feature CryptoKeyFeature `db:"feature" json:"feature"` + Sequence int32 `db:"sequence" json:"sequence"` } -func (q *sqlQuerier) DeleteKey(ctx context.Context, arg DeleteKeyParams) error { - _, err := q.db.ExecContext(ctx, deleteKey, arg.Feature, arg.Sequence) - return err +func (q *sqlQuerier) DeleteCryptoKey(ctx context.Context, arg DeleteCryptoKeyParams) (CryptoKey, error) { + row := q.db.QueryRowContext(ctx, deleteCryptoKey, arg.Feature, arg.Sequence) + var i CryptoKey + err := row.Scan( + &i.Feature, + &i.Sequence, + &i.Secret, + &i.StartsAt, + &i.DeletesAt, + ) + return i, err } -const getKeyByFeatureAndSequence = `-- name: GetKeyByFeatureAndSequence :one +const getCryptoKeyByFeatureAndSequence = `-- name: GetCryptoKeyByFeatureAndSequence :one SELECT feature, sequence, secret, starts_at, deletes_at -FROM keys +FROM crypto_keys WHERE feature = $1 AND sequence = $2 AND secret IS NOT NULL @@ -3190,15 +3198,15 @@ WHERE feature = $1 AND ($3 < deletes_at OR deletes_at IS NULL) ` -type GetKeyByFeatureAndSequenceParams struct { - Feature string `db:"feature" json:"feature"` - Sequence int32 `db:"sequence" json:"sequence"` - StartsAt time.Time `db:"starts_at" json:"starts_at"` +type GetCryptoKeyByFeatureAndSequenceParams struct { + Feature CryptoKeyFeature `db:"feature" json:"feature"` + Sequence int32 `db:"sequence" json:"sequence"` + StartsAt time.Time `db:"starts_at" json:"starts_at"` } -func (q *sqlQuerier) GetKeyByFeatureAndSequence(ctx context.Context, arg GetKeyByFeatureAndSequenceParams) (Key, error) { - row := q.db.QueryRowContext(ctx, getKeyByFeatureAndSequence, arg.Feature, arg.Sequence, arg.StartsAt) - var i Key +func (q *sqlQuerier) GetCryptoKeyByFeatureAndSequence(ctx context.Context, arg GetCryptoKeyByFeatureAndSequenceParams) (CryptoKey, error) { + row := q.db.QueryRowContext(ctx, getCryptoKeyByFeatureAndSequence, arg.Feature, arg.Sequence, arg.StartsAt) + var i CryptoKey err := row.Scan( &i.Feature, &i.Sequence, @@ -3209,21 +3217,21 @@ func (q *sqlQuerier) GetKeyByFeatureAndSequence(ctx context.Context, arg GetKeyB return i, err } -const getKeys = `-- name: GetKeys :many +const getCryptoKeys = `-- name: GetCryptoKeys :many SELECT feature, sequence, secret, starts_at, deletes_at -FROM keys +FROM crypto_keys WHERE secret IS NOT NULL ` -func (q *sqlQuerier) GetKeys(ctx context.Context) ([]Key, error) { - rows, err := q.db.QueryContext(ctx, getKeys) +func (q *sqlQuerier) GetCryptoKeys(ctx context.Context) ([]CryptoKey, error) { + rows, err := q.db.QueryContext(ctx, getCryptoKeys) if err != nil { return nil, err } defer rows.Close() - var items []Key + var items []CryptoKey for rows.Next() { - var i Key + var i CryptoKey if err := rows.Scan( &i.Feature, &i.Sequence, @@ -3244,8 +3252,29 @@ func (q *sqlQuerier) GetKeys(ctx context.Context) ([]Key, error) { return items, nil } -const insertKey = `-- name: InsertKey :exec -INSERT INTO keys ( +const getLatestCryptoKeyByFeature = `-- name: GetLatestCryptoKeyByFeature :one +SELECT feature, sequence, secret, starts_at, deletes_at +FROM crypto_keys +WHERE feature = $1 +ORDER BY sequence DESC +LIMIT 1 +` + +func (q *sqlQuerier) GetLatestCryptoKeyByFeature(ctx context.Context, feature CryptoKeyFeature) (CryptoKey, error) { + row := q.db.QueryRowContext(ctx, getLatestCryptoKeyByFeature, feature) + var i CryptoKey + err := row.Scan( + &i.Feature, + &i.Sequence, + &i.Secret, + &i.StartsAt, + &i.DeletesAt, + ) + return i, err +} + +const insertCryptoKey = `-- name: InsertCryptoKey :one +INSERT INTO crypto_keys ( feature, sequence, secret, @@ -3255,41 +3284,57 @@ INSERT INTO keys ( $2, $3, $4 -) +) RETURNING feature, sequence, secret, starts_at, deletes_at ` -type InsertKeyParams struct { - Feature string `db:"feature" json:"feature"` - Sequence int32 `db:"sequence" json:"sequence"` - Secret sql.NullString `db:"secret" json:"secret"` - StartsAt time.Time `db:"starts_at" json:"starts_at"` +type InsertCryptoKeyParams struct { + Feature CryptoKeyFeature `db:"feature" json:"feature"` + Sequence int32 `db:"sequence" json:"sequence"` + Secret sql.NullString `db:"secret" json:"secret"` + StartsAt time.Time `db:"starts_at" json:"starts_at"` } -func (q *sqlQuerier) InsertKey(ctx context.Context, arg InsertKeyParams) error { - _, err := q.db.ExecContext(ctx, insertKey, +func (q *sqlQuerier) InsertCryptoKey(ctx context.Context, arg InsertCryptoKeyParams) (CryptoKey, error) { + row := q.db.QueryRowContext(ctx, insertCryptoKey, arg.Feature, arg.Sequence, arg.Secret, arg.StartsAt, ) - return err + var i CryptoKey + err := row.Scan( + &i.Feature, + &i.Sequence, + &i.Secret, + &i.StartsAt, + &i.DeletesAt, + ) + return i, err } -const updateKeyDeletesAt = `-- name: UpdateKeyDeletesAt :exec -UPDATE keys +const updateCryptoKeyDeletesAt = `-- name: UpdateCryptoKeyDeletesAt :one +UPDATE crypto_keys SET deletes_at = $3 -WHERE feature = $1 AND sequence = $2 +WHERE feature = $1 AND sequence = $2 RETURNING feature, sequence, secret, starts_at, deletes_at ` -type UpdateKeyDeletesAtParams struct { - Feature string `db:"feature" json:"feature"` - Sequence int32 `db:"sequence" json:"sequence"` - DeletesAt sql.NullTime `db:"deletes_at" json:"deletes_at"` +type UpdateCryptoKeyDeletesAtParams struct { + Feature CryptoKeyFeature `db:"feature" json:"feature"` + Sequence int32 `db:"sequence" json:"sequence"` + DeletesAt sql.NullTime `db:"deletes_at" json:"deletes_at"` } -func (q *sqlQuerier) UpdateKeyDeletesAt(ctx context.Context, arg UpdateKeyDeletesAtParams) error { - _, err := q.db.ExecContext(ctx, updateKeyDeletesAt, arg.Feature, arg.Sequence, arg.DeletesAt) - return err +func (q *sqlQuerier) UpdateCryptoKeyDeletesAt(ctx context.Context, arg UpdateCryptoKeyDeletesAtParams) (CryptoKey, error) { + row := q.db.QueryRowContext(ctx, updateCryptoKeyDeletesAt, arg.Feature, arg.Sequence, arg.DeletesAt) + var i CryptoKey + err := row.Scan( + &i.Feature, + &i.Sequence, + &i.Secret, + &i.StartsAt, + &i.DeletesAt, + ) + return i, err } const deleteLicense = `-- name: DeleteLicense :one diff --git a/coderd/database/queries/keys.sql b/coderd/database/queries/keys.sql index d72b0f6132335..49be2a5b71a7f 100644 --- a/coderd/database/queries/keys.sql +++ b/coderd/database/queries/keys.sql @@ -1,25 +1,32 @@ --- name: GetKeys :many +-- name: GetCryptoKeys :many SELECT * -FROM keys +FROM crypto_keys WHERE secret IS NOT NULL; --- name: GetKeyByFeatureAndSequence :one +-- name: GetLatestCryptoKeyByFeature :one SELECT * -FROM keys +FROM crypto_keys +WHERE feature = $1 +ORDER BY sequence DESC +LIMIT 1; + + +-- name: GetCryptoKeyByFeatureAndSequence :one +SELECT * +FROM crypto_keys WHERE feature = $1 AND sequence = $2 AND secret IS NOT NULL AND $3 >= starts_at AND ($3 < deletes_at OR deletes_at IS NULL); --- name: DeleteKey :exec -UPDATE keys +-- name: DeleteCryptoKey :one +UPDATE crypto_keys SET secret = NULL -WHERE feature = $1 AND sequence = $2; - +WHERE feature = $1 AND sequence = $2 RETURNING *; --- name: InsertKey :exec -INSERT INTO keys ( +-- name: InsertCryptoKey :one +INSERT INTO crypto_keys ( feature, sequence, secret, @@ -29,12 +36,12 @@ INSERT INTO keys ( $2, $3, $4 -); +) RETURNING *; --- name: UpdateKeyDeletesAt :exec -UPDATE keys +-- name: UpdateCryptoKeyDeletesAt :one +UPDATE crypto_keys SET deletes_at = $3 -WHERE feature = $1 AND sequence = $2; +WHERE feature = $1 AND sequence = $2 RETURNING *; diff --git a/coderd/database/unique_constraint.go b/coderd/database/unique_constraint.go index cc6d2c6758d9f..01a811af9c5ed 100644 --- a/coderd/database/unique_constraint.go +++ b/coderd/database/unique_constraint.go @@ -9,6 +9,7 @@ const ( UniqueAgentStatsPkey UniqueConstraint = "agent_stats_pkey" // ALTER TABLE ONLY workspace_agent_stats ADD CONSTRAINT agent_stats_pkey PRIMARY KEY (id); UniqueAPIKeysPkey UniqueConstraint = "api_keys_pkey" // ALTER TABLE ONLY api_keys ADD CONSTRAINT api_keys_pkey PRIMARY KEY (id); UniqueAuditLogsPkey UniqueConstraint = "audit_logs_pkey" // ALTER TABLE ONLY audit_logs ADD CONSTRAINT audit_logs_pkey PRIMARY KEY (id); + UniqueCryptoKeysPkey UniqueConstraint = "crypto_keys_pkey" // ALTER TABLE ONLY crypto_keys ADD CONSTRAINT crypto_keys_pkey PRIMARY KEY (feature, sequence); UniqueCustomRolesUniqueKey UniqueConstraint = "custom_roles_unique_key" // ALTER TABLE ONLY custom_roles ADD CONSTRAINT custom_roles_unique_key UNIQUE (name, organization_id); UniqueDbcryptKeysActiveKeyDigestKey UniqueConstraint = "dbcrypt_keys_active_key_digest_key" // ALTER TABLE ONLY dbcrypt_keys ADD CONSTRAINT dbcrypt_keys_active_key_digest_key UNIQUE (active_key_digest); UniqueDbcryptKeysPkey UniqueConstraint = "dbcrypt_keys_pkey" // ALTER TABLE ONLY dbcrypt_keys ADD CONSTRAINT dbcrypt_keys_pkey PRIMARY KEY (number); @@ -21,7 +22,6 @@ const ( UniqueGroupsNameOrganizationIDKey UniqueConstraint = "groups_name_organization_id_key" // ALTER TABLE ONLY groups ADD CONSTRAINT groups_name_organization_id_key UNIQUE (name, organization_id); UniqueGroupsPkey UniqueConstraint = "groups_pkey" // ALTER TABLE ONLY groups ADD CONSTRAINT groups_pkey PRIMARY KEY (id); UniqueJfrogXrayScansPkey UniqueConstraint = "jfrog_xray_scans_pkey" // ALTER TABLE ONLY jfrog_xray_scans ADD CONSTRAINT jfrog_xray_scans_pkey PRIMARY KEY (agent_id, workspace_id); - UniqueKeysPkey UniqueConstraint = "keys_pkey" // ALTER TABLE ONLY keys ADD CONSTRAINT keys_pkey PRIMARY KEY (feature, sequence); UniqueLicensesJWTKey UniqueConstraint = "licenses_jwt_key" // ALTER TABLE ONLY licenses ADD CONSTRAINT licenses_jwt_key UNIQUE (jwt); UniqueLicensesPkey UniqueConstraint = "licenses_pkey" // ALTER TABLE ONLY licenses ADD CONSTRAINT licenses_pkey PRIMARY KEY (id); UniqueNotificationMessagesPkey UniqueConstraint = "notification_messages_pkey" // ALTER TABLE ONLY notification_messages ADD CONSTRAINT notification_messages_pkey PRIMARY KEY (id); diff --git a/coderd/keyrotate/rotate.go b/coderd/keyrotate/rotate.go index 381cd5d5805e7..39cc7975381dd 100644 --- a/coderd/keyrotate/rotate.go +++ b/coderd/keyrotate/rotate.go @@ -5,11 +5,11 @@ import ( "crypto/rand" "database/sql" "encoding/hex" - "log/slog" "time" "golang.org/x/xerrors" + "cdr.dev/slog" "github.com/coder/coder/v2/coderd/database" "github.com/coder/coder/v2/coderd/database/dbtime" "github.com/coder/quartz" @@ -27,16 +27,28 @@ type KeyRotator struct { Clock quartz.Clock Logger slog.Logger ScanInterval time.Duration + ResultsCh chan []database.CryptoKey + features []database.CryptoKeyFeature } -func (kr *KeyRotator) Start(ctx context.Context) { - ticker := kr.Clock.NewTicker(kr.ScanInterval) +func (k *KeyRotator) Start(ctx context.Context) { + ticker := k.Clock.NewTicker(k.ScanInterval) defer ticker.Stop() + if len(k.features) == 0 { + k.features = database.AllCryptoKeyFeatureValues() + } + for { - err := kr.rotateKeys(ctx) + modifiedKeys, err := k.rotateKeys(ctx) if err != nil { - kr.Logger.Error("rotate keys", slog.Any("error", err)) + k.Logger.Error(ctx, "failed to rotate keys", slog.Error(err)) + } + + // This should only be called in test code so we don't + // both to select on the push. + if k.ResultsCh != nil { + k.ResultsCh <- modifiedKeys } select { @@ -48,47 +60,101 @@ func (kr *KeyRotator) Start(ctx context.Context) { } // rotateKeys checks for keys nearing expiration and rotates them if necessary. -func (kr *KeyRotator) rotateKeys(ctx context.Context) error { - return database.ReadModifyUpdate(kr.DB, func(tx database.Store) error { - keys, err := tx.GetKeys(ctx) +func (k *KeyRotator) rotateKeys(ctx context.Context) ([]database.CryptoKey, error) { + var modifiedKeys []database.CryptoKey + return modifiedKeys, database.ReadModifyUpdate(k.DB, func(tx database.Store) error { + // Reset the modified keys slice for each iteration. + modifiedKeys = make([]database.CryptoKey, 0) + keys, err := tx.GetCryptoKeys(ctx) if err != nil { return xerrors.Errorf("get keys: %w", err) } - now := dbtime.Time(kr.Clock.Now()) - for _, key := range keys { - switch { - case shouldDeleteKey(key, now): - err := tx.DeleteKey(ctx, database.DeleteKeyParams{ - Feature: key.Feature, - Sequence: key.Sequence, - }) + // Groups the keys by feature so that we can + // ensure we have at least one key for each feature. + keysByFeature := keysByFeature(keys) + + now := dbtime.Time(k.Clock.Now()) + for feature, keys := range keysByFeature { + // It's possible there are no keys if someone + // has manually deleted all the keys. + if len(keys) == 0 { + k.Logger.Info(ctx, "no valid keys detected, inserting new key", + slog.F("feature", feature), + ) + newKey, err := k.insertNewKey(ctx, tx, feature, now) if err != nil { - return xerrors.Errorf("delete key: %w", err) + return xerrors.Errorf("insert new key: %w", err) } - case shouldRotateKey(key, kr.KeyDuration, now): - err := kr.rotateKey(ctx, tx, key) - if err != nil { - return xerrors.Errorf("rotate key: %w", err) + modifiedKeys = append(modifiedKeys, newKey) + } + + for _, key := range keys { + switch { + case shouldDeleteKey(key, now): + deletedKey, err := tx.DeleteCryptoKey(ctx, database.DeleteCryptoKeyParams{ + Feature: key.Feature, + Sequence: key.Sequence, + }) + if err != nil { + return xerrors.Errorf("delete key: %w", err) + } + modifiedKeys = append(modifiedKeys, deletedKey) + case shouldRotateKey(key, k.KeyDuration, now): + rotatedKeys, err := k.rotateKey(ctx, tx, key) + if err != nil { + return xerrors.Errorf("rotate key: %w", err) + } + modifiedKeys = append(modifiedKeys, rotatedKeys...) + default: + continue } - default: - continue } } return nil }) } -func (kr *KeyRotator) rotateKey(ctx context.Context, tx database.Store, key database.Key) error { - newStartsAt := key.ExpiresAt(kr.KeyDuration) +func (k *KeyRotator) insertNewKey(ctx context.Context, tx database.Store, feature database.CryptoKeyFeature, now time.Time) (database.CryptoKey, error) { + secret, err := generateNewSecret(feature) + if err != nil { + return database.CryptoKey{}, xerrors.Errorf("generate new secret: %w", err) + } + + latestKey, err := tx.GetLatestCryptoKeyByFeature(ctx, feature) + if err != nil && !xerrors.Is(err, sql.ErrNoRows) { + return database.CryptoKey{}, xerrors.Errorf("get latest key: %w", err) + } + + newKey, err := tx.InsertCryptoKey(ctx, database.InsertCryptoKeyParams{ + Feature: feature, + // We'll assume that the first key we insert is 1. + Sequence: latestKey.Sequence + 1, + Secret: sql.NullString{ + String: secret, + Valid: true, + }, + StartsAt: now, + }) + if err != nil { + return database.CryptoKey{}, xerrors.Errorf("inserting new key: %w", err) + } + + k.Logger.Info(ctx, "inserted new key for feature", slog.F("feature", feature)) + return newKey, nil +} + +func (k *KeyRotator) rotateKey(ctx context.Context, tx database.Store, key database.CryptoKey) ([]database.CryptoKey, error) { + // The starts at of the new key is the expiration of the old key. + newStartsAt := key.ExpiresAt(k.KeyDuration) secret, err := generateNewSecret(key.Feature) if err != nil { - return xerrors.Errorf("generate new secret: %w", err) + return nil, xerrors.Errorf("generate new secret: %w", err) } // Insert new key - err = tx.InsertKey(ctx, database.InsertKeyParams{ + newKey, err := tx.InsertCryptoKey(ctx, database.InsertCryptoKeyParams{ Feature: key.Feature, Sequence: key.Sequence + 1, Secret: sql.NullString{ @@ -98,14 +164,13 @@ func (kr *KeyRotator) rotateKey(ctx context.Context, tx database.Store, key data StartsAt: newStartsAt, }) if err != nil { - return xerrors.Errorf("inserting new key: %w", err) + return nil, xerrors.Errorf("inserting new key: %w", err) } - // Update old key's deletes_at - maxTokenLength := tokenLength(key.Feature) - deletesAt := newStartsAt.Add(time.Hour).Add(maxTokenLength) + // Set old key's deletes_at + deletesAt := newStartsAt.Add(time.Hour).Add(tokenDuration(key.Feature)) - err = tx.UpdateKeyDeletesAt(ctx, database.UpdateKeyDeletesAtParams{ + updatedKey, err := tx.UpdateCryptoKeyDeletesAt(ctx, database.UpdateCryptoKeyDeletesAtParams{ Feature: key.Feature, Sequence: key.Sequence, DeletesAt: sql.NullTime{ @@ -114,19 +179,19 @@ func (kr *KeyRotator) rotateKey(ctx context.Context, tx database.Store, key data }, }) if err != nil { - return xerrors.Errorf("update old key's deletes_at: %w", err) + return nil, xerrors.Errorf("update old key's deletes_at: %w", err) } - return nil + return []database.CryptoKey{updatedKey, newKey}, nil } -func generateNewSecret(feature string) (string, error) { +func generateNewSecret(feature database.CryptoKeyFeature) (string, error) { switch feature { - case "workspace_apps": + case database.CryptoKeyFeatureWorkspaceApps: return generateKey(96) - case "oidc_convert": + case database.CryptoKeyFeatureOidcConvert: return generateKey(32) - case "peer_reconnect": + case database.CryptoKeyFeaturePeerReconnect: return generateKey(64) } return "", xerrors.Errorf("unknown feature: %s", feature) @@ -141,24 +206,40 @@ func generateKey(length int) (string, error) { return hex.EncodeToString(b), nil } -func tokenLength(feature string) time.Duration { +func tokenDuration(feature database.CryptoKeyFeature) time.Duration { switch feature { - case "workspace_apps": + case database.CryptoKeyFeatureWorkspaceApps: return WorkspaceAppsTokenDuration - case "oidc_convert": + case database.CryptoKeyFeatureOidcConvert: return OIDCConvertTokenDuration - case "peer_reconnect": + case database.CryptoKeyFeaturePeerReconnect: return PeerReconnectTokenDuration default: return 0 } } -func shouldDeleteKey(key database.Key, now time.Time) bool { +func shouldDeleteKey(key database.CryptoKey, now time.Time) bool { return key.DeletesAt.Valid && key.DeletesAt.Time.After(now) } -func shouldRotateKey(key database.Key, keyDuration time.Duration, now time.Time) bool { +func shouldRotateKey(key database.CryptoKey, keyDuration time.Duration, now time.Time) bool { + // If deletes_at is set, we've already inserted a key. + if key.DeletesAt.Valid { + return false + } expirationTime := key.ExpiresAt(keyDuration) return now.Add(time.Hour).After(expirationTime) } + +func keysByFeature(keys []database.CryptoKey) map[database.CryptoKeyFeature][]database.CryptoKey { + m := map[database.CryptoKeyFeature][]database.CryptoKey{ + database.CryptoKeyFeatureWorkspaceApps: {}, + database.CryptoKeyFeatureOidcConvert: {}, + database.CryptoKeyFeaturePeerReconnect: {}, + } + for _, key := range keys { + m[key.Feature] = append(m[key.Feature], key) + } + return m +} diff --git a/coderd/keyrotate/rotate_test.go b/coderd/keyrotate/rotate_test.go index d7b48d32dc62d..e2392ac27115c 100644 --- a/coderd/keyrotate/rotate_test.go +++ b/coderd/keyrotate/rotate_test.go @@ -1,27 +1,80 @@ -package keyrotate_test +package keyrotate import ( - "context" "testing" + "time" + + "github.com/stretchr/testify/require" + + "cdr.dev/slog" + "cdr.dev/slog/sloggers/slogtest" + "github.com/coder/coder/v2/coderd/database" + "github.com/coder/coder/v2/coderd/database/dbtestutil" + "github.com/coder/coder/v2/testutil" + "github.com/coder/quartz" ) func TestKeyRotator(t *testing.T) { t.Parallel() + t.Run("NoExistingKeys", func(t *testing.T) { + t.Parallel() + + var ( + db, _ = dbtestutil.NewDB(t) + clock = quartz.NewMock(t) + logger = slogtest.Make(t, nil).Leveled(slog.LevelDebug) + ctx = testutil.Context(t, testutil.WaitShort) + resultsCh = make(chan []database.CryptoKey, 1) + ) + + kr := &KeyRotator{ + DB: db, + KeyDuration: 0, + Clock: clock, + Logger: logger, + ScanInterval: 0, + ResultsCh: resultsCh, + } + + go kr.Start(ctx) + + keys := testutil.RequireRecvCtx(ctx, t, resultsCh) + require.Len(t, keys, len(database.AllCryptoKeyFeatureValues())) + }) + t.Run("RotatesKeysNearExpiration", func(t *testing.T) { t.Parallel() - db := database.NewTestDB(t) + var ( + db, _ = dbtestutil.NewDB(t) + clock = quartz.NewMock(t) + keyDuration = time.Hour * 24 * 7 + logger = slogtest.Make(t, nil).Leveled(slog.LevelDebug) + ctx = testutil.Context(t, testutil.WaitShort) + resultsCh = make(chan []database.CryptoKey, 1) + scanInterval = time.Minute * 10 + ) + kr := &KeyRotator{ - DB: db, + DB: db, + KeyDuration: keyDuration, + Clock: clock, + Logger: logger, + ScanInterval: scanInterval, + ResultsCh: resultsCh, } + keys, err := kr.rotateKeys(ctx) + require.NoError(t, err) + require.Len(t, keys, len(database.AllCryptoKeyFeatureValues())) - kr.rotateKeys(context.Background()) + clock.Advance(keyDuration - time.Minute*59) + keys, err = kr.rotateKeys(ctx) + require.NoError(t, err) + require.Len(t, keys, 2*len(database.AllCryptoKeyFeatureValues())) }) t.Run("DoesNotRotateValidKeys", func(t *testing.T) { - t.Parallel() - // TODO: Implement test to ensure valid keys are not rotated }) t.Run("DeletesExpiredKeys", func(t *testing.T) { From e91fd312792083d1379b98884f5251bb92c675b0 Mon Sep 17 00:00:00 2001 From: Jon Ayers Date: Thu, 12 Sep 2024 21:16:39 +0000 Subject: [PATCH 3/5] Add crypto key generation and refactor rotation tests --- coderd/database/dbgen/dbgen.go | 34 +++++ coderd/database/queries.sql.go | 4 +- coderd/database/queries/keys.sql | 4 +- coderd/keyrotate/rotate.go | 11 +- coderd/keyrotate/rotate_internal_test.go | 184 +++++++++++++++++++++++ coderd/keyrotate/rotate_test.go | 94 ------------ 6 files changed, 227 insertions(+), 104 deletions(-) create mode 100644 coderd/keyrotate/rotate_internal_test.go delete mode 100644 coderd/keyrotate/rotate_test.go diff --git a/coderd/database/dbgen/dbgen.go b/coderd/database/dbgen/dbgen.go index 79aee59d97dbe..74aea31767c44 100644 --- a/coderd/database/dbgen/dbgen.go +++ b/coderd/database/dbgen/dbgen.go @@ -893,6 +893,40 @@ func CustomRole(t testing.TB, db database.Store, seed database.CustomRole) datab return role } +func CryptoKey(t testing.TB, db database.Store, seed database.CryptoKey) database.CryptoKey { + t.Helper() + + secret, err := cryptorand.String(10) + require.NoError(t, err, "generate secret") + + sequence, err := cryptorand.Intn(1<<31 - 1) + require.NoError(t, err, "generate sequence") + + feature, err := cryptorand.Element(database.AllCryptoKeyFeatureValues()) + require.NoError(t, err, "generate feature") + + key, err := db.InsertCryptoKey(genCtx, database.InsertCryptoKeyParams{ + Sequence: takeFirst(seed.Sequence, int32(sequence)), + Secret: takeFirst(seed.Secret, sql.NullString{ + String: secret, + Valid: true, + }), + Feature: takeFirst(seed.Feature, feature), + StartsAt: takeFirst(seed.StartsAt, time.Now()), + }) + require.NoError(t, err, "insert crypto key") + + if seed.DeletesAt.Valid { + key, err = db.UpdateCryptoKeyDeletesAt(genCtx, database.UpdateCryptoKeyDeletesAtParams{ + Feature: key.Feature, + Sequence: key.Sequence, + DeletesAt: sql.NullTime{Time: seed.DeletesAt.Time, Valid: true}, + }) + require.NoError(t, err, "update crypto key deletes_at") + } + return key +} + func must[V any](v V, err error) V { if err != nil { panic(err) diff --git a/coderd/database/queries.sql.go b/coderd/database/queries.sql.go index 19fd2a59354b7..ef29d829ea3b8 100644 --- a/coderd/database/queries.sql.go +++ b/coderd/database/queries.sql.go @@ -3201,11 +3201,11 @@ WHERE feature = $1 type GetCryptoKeyByFeatureAndSequenceParams struct { Feature CryptoKeyFeature `db:"feature" json:"feature"` Sequence int32 `db:"sequence" json:"sequence"` - StartsAt time.Time `db:"starts_at" json:"starts_at"` + Time time.Time `db:"time" json:"time"` } func (q *sqlQuerier) GetCryptoKeyByFeatureAndSequence(ctx context.Context, arg GetCryptoKeyByFeatureAndSequenceParams) (CryptoKey, error) { - row := q.db.QueryRowContext(ctx, getCryptoKeyByFeatureAndSequence, arg.Feature, arg.Sequence, arg.StartsAt) + row := q.db.QueryRowContext(ctx, getCryptoKeyByFeatureAndSequence, arg.Feature, arg.Sequence, arg.Time) var i CryptoKey err := row.Scan( &i.Feature, diff --git a/coderd/database/queries/keys.sql b/coderd/database/queries/keys.sql index 49be2a5b71a7f..8d1c60e637b50 100644 --- a/coderd/database/queries/keys.sql +++ b/coderd/database/queries/keys.sql @@ -17,8 +17,8 @@ FROM crypto_keys WHERE feature = $1 AND sequence = $2 AND secret IS NOT NULL - AND $3 >= starts_at - AND ($3 < deletes_at OR deletes_at IS NULL); + AND @time >= starts_at + AND (@time < deletes_at OR deletes_at IS NULL); -- name: DeleteCryptoKey :one UPDATE crypto_keys diff --git a/coderd/keyrotate/rotate.go b/coderd/keyrotate/rotate.go index 39cc7975381dd..bd1caabadc8df 100644 --- a/coderd/keyrotate/rotate.go +++ b/coderd/keyrotate/rotate.go @@ -72,7 +72,7 @@ func (k *KeyRotator) rotateKeys(ctx context.Context) ([]database.CryptoKey, erro // Groups the keys by feature so that we can // ensure we have at least one key for each feature. - keysByFeature := keysByFeature(keys) + keysByFeature := keysByFeature(keys, k.features) now := dbtime.Time(k.Clock.Now()) for feature, keys := range keysByFeature { @@ -232,11 +232,10 @@ func shouldRotateKey(key database.CryptoKey, keyDuration time.Duration, now time return now.Add(time.Hour).After(expirationTime) } -func keysByFeature(keys []database.CryptoKey) map[database.CryptoKeyFeature][]database.CryptoKey { - m := map[database.CryptoKeyFeature][]database.CryptoKey{ - database.CryptoKeyFeatureWorkspaceApps: {}, - database.CryptoKeyFeatureOidcConvert: {}, - database.CryptoKeyFeaturePeerReconnect: {}, +func keysByFeature(keys []database.CryptoKey, features []database.CryptoKeyFeature) map[database.CryptoKeyFeature][]database.CryptoKey { + m := map[database.CryptoKeyFeature][]database.CryptoKey{} + for _, feature := range features { + m[feature] = []database.CryptoKey{} } for _, key := range keys { m[key.Feature] = append(m[key.Feature], key) diff --git a/coderd/keyrotate/rotate_internal_test.go b/coderd/keyrotate/rotate_internal_test.go new file mode 100644 index 0000000000000..b78e26ad59707 --- /dev/null +++ b/coderd/keyrotate/rotate_internal_test.go @@ -0,0 +1,184 @@ +package keyrotate + +import ( + "database/sql" + "encoding/hex" + "testing" + "time" + + "github.com/stretchr/testify/require" + + "cdr.dev/slog" + "cdr.dev/slog/sloggers/slogtest" + "github.com/coder/coder/v2/coderd/database" + "github.com/coder/coder/v2/coderd/database/dbgen" + "github.com/coder/coder/v2/coderd/database/dbtestutil" + "github.com/coder/coder/v2/coderd/database/dbtime" + "github.com/coder/coder/v2/testutil" + "github.com/coder/quartz" +) + +func Test_rotateKeys(t *testing.T) { + t.Parallel() + + t.Run("NoExistingKeys", func(t *testing.T) { + t.Parallel() + + var ( + db, _ = dbtestutil.NewDB(t) + clock = quartz.NewMock(t) + logger = slogtest.Make(t, nil).Leveled(slog.LevelDebug) + ctx = testutil.Context(t, testutil.WaitShort) + resultsCh = make(chan []database.CryptoKey, 1) + ) + + kr := &KeyRotator{ + DB: db, + KeyDuration: 0, + Clock: clock, + Logger: logger, + ScanInterval: 0, + ResultsCh: resultsCh, + } + + now := clock.Now() + keys, err := kr.rotateKeys(ctx) + require.NoError(t, err) + require.Len(t, keys, len(database.AllCryptoKeyFeatureValues())) + + // Fetch the keys from the database and ensure they + // are as expected. + dbkeys, err := db.GetCryptoKeys(ctx) + require.NoError(t, err) + require.Equal(t, keys, dbkeys) + requireContainsAllFeatures(t, keys) + for _, key := range keys { + requireKey(t, key, key.Feature, now, time.Time{}, 1) + } + }) + + t.Run("RotatesKeysNearExpiration", func(t *testing.T) { + t.Parallel() + + var ( + db, _ = dbtestutil.NewDB(t) + clock = quartz.NewMock(t) + keyDuration = time.Hour * 24 * 7 + logger = slogtest.Make(t, nil).Leveled(slog.LevelDebug) + ctx = testutil.Context(t, testutil.WaitShort) + resultsCh = make(chan []database.CryptoKey, 1) + ) + + kr := &KeyRotator{ + DB: db, + KeyDuration: keyDuration, + Clock: clock, + Logger: logger, + ScanInterval: 0, + ResultsCh: resultsCh, + features: []database.CryptoKeyFeature{ + database.CryptoKeyFeatureWorkspaceApps, + }, + } + + now := dbtime.Time(clock.Now()) + + // Seed the database with an existing key. + oldKey := dbgen.CryptoKey(t, db, database.CryptoKey{ + Feature: database.CryptoKeyFeatureWorkspaceApps, + StartsAt: now, + Sequence: 15, + }) + + // Advance the window to just inside rotation time. + _ = clock.Advance(keyDuration - time.Minute*59) + keys, err := kr.rotateKeys(ctx) + require.NoError(t, err) + require.Len(t, keys, 2) + + now = dbtime.Time(clock.Now()) + expectedDeletesAt := now.Add(WorkspaceAppsTokenDuration + time.Hour) + + // Fetch the old key, it should have an expires_at now. + oldKey, err = db.GetCryptoKeyByFeatureAndSequence(ctx, database.GetCryptoKeyByFeatureAndSequenceParams{ + Feature: oldKey.Feature, + Sequence: oldKey.Sequence, + Time: now, + }) + require.NoError(t, err) + require.Equal(t, oldKey.DeletesAt.Time, expectedDeletesAt) + + // The new key should be created and have a starts_at of the old key's expires_at. + newKey, err := db.GetCryptoKeyByFeatureAndSequence(ctx, database.GetCryptoKeyByFeatureAndSequenceParams{ + Feature: database.CryptoKeyFeatureWorkspaceApps, + Sequence: oldKey.Sequence + 1, + }) + require.NoError(t, err) + requireKey(t, newKey, database.CryptoKeyFeatureWorkspaceApps, oldKey.ExpiresAt(keyDuration), expectedDeletesAt, oldKey.Sequence+1) + + clock.Advance(oldKey.DeletesAt.Time.Sub(now) + time.Second) + + keys, err = kr.rotateKeys(ctx) + require.NoError(t, err) + require.Len(t, keys, 1) + + // The old key should be deleted. + _, err = db.GetCryptoKeyByFeatureAndSequence(ctx, database.GetCryptoKeyByFeatureAndSequenceParams{ + Feature: oldKey.Feature, + Sequence: oldKey.Sequence, + }) + require.ErrorIs(t, err, sql.ErrNoRows) + }) + + t.Run("DoesNotRotateValidKeys", func(t *testing.T) { + }) + + t.Run("DeletesExpiredKeys", func(t *testing.T) { + t.Parallel() + // TODO: Implement test for deleting expired keys + }) + + t.Run("HandlesMultipleKeyTypes", func(t *testing.T) { + t.Parallel() + // TODO: Implement test for handling multiple key types + }) + + t.Run("GracefullyHandlesErrors", func(t *testing.T) { + t.Parallel() + // TODO: Implement test for error handling + }) +} + +func requireKey(t *testing.T, key database.CryptoKey, feature database.CryptoKeyFeature, startsAt time.Time, deletesAt time.Time, sequence int32) { + t.Helper() + require.Equal(t, key.Feature, feature) + require.Equal(t, key.StartsAt, startsAt) + require.Equal(t, key.DeletesAt.Time, deletesAt) + require.Equal(t, key.Sequence, sequence) + + secret, err := hex.DecodeString(key.Secret.String) + require.NoError(t, err) + + switch key.Feature { + case database.CryptoKeyFeatureOidcConvert: + require.Len(t, secret, 32) + case database.CryptoKeyFeatureWorkspaceApps: + require.Len(t, secret, 96) + case database.CryptoKeyFeaturePeerReconnect: + require.Len(t, secret, 64) + default: + t.Fatalf("unknown key feature: %s", key.Feature) + } +} + +func requireContainsAllFeatures(t *testing.T, keys []database.CryptoKey) { + t.Helper() + + features := make(map[database.CryptoKeyFeature]bool) + for _, key := range keys { + features[key.Feature] = true + } + require.True(t, features[database.CryptoKeyFeatureOidcConvert]) + require.True(t, features[database.CryptoKeyFeatureWorkspaceApps]) + require.True(t, features[database.CryptoKeyFeaturePeerReconnect]) +} diff --git a/coderd/keyrotate/rotate_test.go b/coderd/keyrotate/rotate_test.go deleted file mode 100644 index e2392ac27115c..0000000000000 --- a/coderd/keyrotate/rotate_test.go +++ /dev/null @@ -1,94 +0,0 @@ -package keyrotate - -import ( - "testing" - "time" - - "github.com/stretchr/testify/require" - - "cdr.dev/slog" - "cdr.dev/slog/sloggers/slogtest" - "github.com/coder/coder/v2/coderd/database" - "github.com/coder/coder/v2/coderd/database/dbtestutil" - "github.com/coder/coder/v2/testutil" - "github.com/coder/quartz" -) - -func TestKeyRotator(t *testing.T) { - t.Parallel() - - t.Run("NoExistingKeys", func(t *testing.T) { - t.Parallel() - - var ( - db, _ = dbtestutil.NewDB(t) - clock = quartz.NewMock(t) - logger = slogtest.Make(t, nil).Leveled(slog.LevelDebug) - ctx = testutil.Context(t, testutil.WaitShort) - resultsCh = make(chan []database.CryptoKey, 1) - ) - - kr := &KeyRotator{ - DB: db, - KeyDuration: 0, - Clock: clock, - Logger: logger, - ScanInterval: 0, - ResultsCh: resultsCh, - } - - go kr.Start(ctx) - - keys := testutil.RequireRecvCtx(ctx, t, resultsCh) - require.Len(t, keys, len(database.AllCryptoKeyFeatureValues())) - }) - - t.Run("RotatesKeysNearExpiration", func(t *testing.T) { - t.Parallel() - - var ( - db, _ = dbtestutil.NewDB(t) - clock = quartz.NewMock(t) - keyDuration = time.Hour * 24 * 7 - logger = slogtest.Make(t, nil).Leveled(slog.LevelDebug) - ctx = testutil.Context(t, testutil.WaitShort) - resultsCh = make(chan []database.CryptoKey, 1) - scanInterval = time.Minute * 10 - ) - - kr := &KeyRotator{ - DB: db, - KeyDuration: keyDuration, - Clock: clock, - Logger: logger, - ScanInterval: scanInterval, - ResultsCh: resultsCh, - } - keys, err := kr.rotateKeys(ctx) - require.NoError(t, err) - require.Len(t, keys, len(database.AllCryptoKeyFeatureValues())) - - clock.Advance(keyDuration - time.Minute*59) - keys, err = kr.rotateKeys(ctx) - require.NoError(t, err) - require.Len(t, keys, 2*len(database.AllCryptoKeyFeatureValues())) - }) - - t.Run("DoesNotRotateValidKeys", func(t *testing.T) { - }) - - t.Run("DeletesExpiredKeys", func(t *testing.T) { - t.Parallel() - // TODO: Implement test for deleting expired keys - }) - - t.Run("HandlesMultipleKeyTypes", func(t *testing.T) { - t.Parallel() - // TODO: Implement test for handling multiple key types - }) - - t.Run("GracefullyHandlesErrors", func(t *testing.T) { - t.Parallel() - // TODO: Implement test for error handling - }) -} From 86195f01a48f8fbc016d8a8341b48bc2704c39cb Mon Sep 17 00:00:00 2001 From: Jon Ayers Date: Thu, 12 Sep 2024 23:11:20 +0000 Subject: [PATCH 4/5] tests are passing --- coderd/database/dbgen/dbgen.go | 16 ++-- coderd/database/modelmethods.go | 2 +- coderd/database/queries.sql.go | 5 +- coderd/database/queries/keys.sql | 14 +-- coderd/keyrotate/rotate.go | 13 ++- coderd/keyrotate/rotate_internal_test.go | 110 +++++++++++++---------- coderd/keyrotate/rotate_test.go | 60 +++++++++++++ 7 files changed, 137 insertions(+), 83 deletions(-) create mode 100644 coderd/keyrotate/rotate_test.go diff --git a/coderd/database/dbgen/dbgen.go b/coderd/database/dbgen/dbgen.go index 74aea31767c44..5dc3b0fa3eb73 100644 --- a/coderd/database/dbgen/dbgen.go +++ b/coderd/database/dbgen/dbgen.go @@ -2,6 +2,7 @@ package dbgen import ( "context" + "crypto/rand" "crypto/sha256" "database/sql" "encoding/hex" @@ -896,22 +897,17 @@ func CustomRole(t testing.TB, db database.Store, seed database.CustomRole) datab func CryptoKey(t testing.TB, db database.Store, seed database.CryptoKey) database.CryptoKey { t.Helper() - secret, err := cryptorand.String(10) + b := make([]byte, 96) + _, err := rand.Read(b) require.NoError(t, err, "generate secret") - sequence, err := cryptorand.Intn(1<<31 - 1) - require.NoError(t, err, "generate sequence") - - feature, err := cryptorand.Element(database.AllCryptoKeyFeatureValues()) - require.NoError(t, err, "generate feature") - key, err := db.InsertCryptoKey(genCtx, database.InsertCryptoKeyParams{ - Sequence: takeFirst(seed.Sequence, int32(sequence)), + Sequence: takeFirst(seed.Sequence, 123), Secret: takeFirst(seed.Secret, sql.NullString{ - String: secret, + String: hex.EncodeToString(b), Valid: true, }), - Feature: takeFirst(seed.Feature, feature), + Feature: takeFirst(seed.Feature, database.CryptoKeyFeatureWorkspaceApps), StartsAt: takeFirst(seed.StartsAt, time.Now()), }) require.NoError(t, err, "insert crypto key") diff --git a/coderd/database/modelmethods.go b/coderd/database/modelmethods.go index 8b5a18cc90c3a..82be5e710c058 100644 --- a/coderd/database/modelmethods.go +++ b/coderd/database/modelmethods.go @@ -449,5 +449,5 @@ func (r GetAuthorizationUserRolesRow) RoleNames() ([]rbac.RoleIdentifier, error) } func (k CryptoKey) ExpiresAt(keyDuration time.Duration) time.Time { - return k.StartsAt.Add(keyDuration) + return k.StartsAt.Add(keyDuration).UTC() } diff --git a/coderd/database/queries.sql.go b/coderd/database/queries.sql.go index ef29d829ea3b8..aecf3c0b3d44c 100644 --- a/coderd/database/queries.sql.go +++ b/coderd/database/queries.sql.go @@ -3194,18 +3194,15 @@ FROM crypto_keys WHERE feature = $1 AND sequence = $2 AND secret IS NOT NULL - AND $3 >= starts_at - AND ($3 < deletes_at OR deletes_at IS NULL) ` type GetCryptoKeyByFeatureAndSequenceParams struct { Feature CryptoKeyFeature `db:"feature" json:"feature"` Sequence int32 `db:"sequence" json:"sequence"` - Time time.Time `db:"time" json:"time"` } func (q *sqlQuerier) GetCryptoKeyByFeatureAndSequence(ctx context.Context, arg GetCryptoKeyByFeatureAndSequenceParams) (CryptoKey, error) { - row := q.db.QueryRowContext(ctx, getCryptoKeyByFeatureAndSequence, arg.Feature, arg.Sequence, arg.Time) + row := q.db.QueryRowContext(ctx, getCryptoKeyByFeatureAndSequence, arg.Feature, arg.Sequence) var i CryptoKey err := row.Scan( &i.Feature, diff --git a/coderd/database/queries/keys.sql b/coderd/database/queries/keys.sql index 8d1c60e637b50..64cca798d44d9 100644 --- a/coderd/database/queries/keys.sql +++ b/coderd/database/queries/keys.sql @@ -16,9 +16,7 @@ SELECT * FROM crypto_keys WHERE feature = $1 AND sequence = $2 - AND secret IS NOT NULL - AND @time >= starts_at - AND (@time < deletes_at OR deletes_at IS NULL); + AND secret IS NOT NULL; -- name: DeleteCryptoKey :one UPDATE crypto_keys @@ -42,13 +40,3 @@ INSERT INTO crypto_keys ( UPDATE crypto_keys SET deletes_at = $3 WHERE feature = $1 AND sequence = $2 RETURNING *; - - - - - - - - - - diff --git a/coderd/keyrotate/rotate.go b/coderd/keyrotate/rotate.go index bd1caabadc8df..e9e4305a99aab 100644 --- a/coderd/keyrotate/rotate.go +++ b/coderd/keyrotate/rotate.go @@ -73,8 +73,7 @@ func (k *KeyRotator) rotateKeys(ctx context.Context) ([]database.CryptoKey, erro // Groups the keys by feature so that we can // ensure we have at least one key for each feature. keysByFeature := keysByFeature(keys, k.features) - - now := dbtime.Time(k.Clock.Now()) + now := dbtime.Time(k.Clock.Now().UTC()) for feature, keys := range keysByFeature { // It's possible there are no keys if someone // has manually deleted all the keys. @@ -134,7 +133,7 @@ func (k *KeyRotator) insertNewKey(ctx context.Context, tx database.Store, featur String: secret, Valid: true, }, - StartsAt: now, + StartsAt: now.UTC(), }) if err != nil { return database.CryptoKey{}, xerrors.Errorf("inserting new key: %w", err) @@ -161,7 +160,7 @@ func (k *KeyRotator) rotateKey(ctx context.Context, tx database.Store, key datab String: secret, Valid: true, }, - StartsAt: newStartsAt, + StartsAt: newStartsAt.UTC(), }) if err != nil { return nil, xerrors.Errorf("inserting new key: %w", err) @@ -174,7 +173,7 @@ func (k *KeyRotator) rotateKey(ctx context.Context, tx database.Store, key datab Feature: key.Feature, Sequence: key.Sequence, DeletesAt: sql.NullTime{ - Time: deletesAt, + Time: deletesAt.UTC(), Valid: true, }, }) @@ -220,7 +219,7 @@ func tokenDuration(feature database.CryptoKeyFeature) time.Duration { } func shouldDeleteKey(key database.CryptoKey, now time.Time) bool { - return key.DeletesAt.Valid && key.DeletesAt.Time.After(now) + return key.DeletesAt.Valid && key.DeletesAt.Time.UTC().After(now.UTC()) } func shouldRotateKey(key database.CryptoKey, keyDuration time.Duration, now time.Time) bool { @@ -229,7 +228,7 @@ func shouldRotateKey(key database.CryptoKey, keyDuration time.Duration, now time return false } expirationTime := key.ExpiresAt(keyDuration) - return now.Add(time.Hour).After(expirationTime) + return now.Add(time.Hour).UTC().After(expirationTime.UTC()) } func keysByFeature(keys []database.CryptoKey, features []database.CryptoKeyFeature) map[database.CryptoKeyFeature][]database.CryptoKey { diff --git a/coderd/keyrotate/rotate_internal_test.go b/coderd/keyrotate/rotate_internal_test.go index b78e26ad59707..a827c96b73132 100644 --- a/coderd/keyrotate/rotate_internal_test.go +++ b/coderd/keyrotate/rotate_internal_test.go @@ -21,42 +21,6 @@ import ( func Test_rotateKeys(t *testing.T) { t.Parallel() - t.Run("NoExistingKeys", func(t *testing.T) { - t.Parallel() - - var ( - db, _ = dbtestutil.NewDB(t) - clock = quartz.NewMock(t) - logger = slogtest.Make(t, nil).Leveled(slog.LevelDebug) - ctx = testutil.Context(t, testutil.WaitShort) - resultsCh = make(chan []database.CryptoKey, 1) - ) - - kr := &KeyRotator{ - DB: db, - KeyDuration: 0, - Clock: clock, - Logger: logger, - ScanInterval: 0, - ResultsCh: resultsCh, - } - - now := clock.Now() - keys, err := kr.rotateKeys(ctx) - require.NoError(t, err) - require.Len(t, keys, len(database.AllCryptoKeyFeatureValues())) - - // Fetch the keys from the database and ensure they - // are as expected. - dbkeys, err := db.GetCryptoKeys(ctx) - require.NoError(t, err) - require.Equal(t, keys, dbkeys) - requireContainsAllFeatures(t, keys) - for _, key := range keys { - requireKey(t, key, key.Feature, now, time.Time{}, 1) - } - }) - t.Run("RotatesKeysNearExpiration", func(t *testing.T) { t.Parallel() @@ -81,7 +45,7 @@ func Test_rotateKeys(t *testing.T) { }, } - now := dbtime.Time(clock.Now()) + now := dbnow(clock) // Seed the database with an existing key. oldKey := dbgen.CryptoKey(t, db, database.CryptoKey{ @@ -96,17 +60,16 @@ func Test_rotateKeys(t *testing.T) { require.NoError(t, err) require.Len(t, keys, 2) - now = dbtime.Time(clock.Now()) - expectedDeletesAt := now.Add(WorkspaceAppsTokenDuration + time.Hour) + now = dbnow(clock) + expectedDeletesAt := oldKey.ExpiresAt(keyDuration).Add(WorkspaceAppsTokenDuration + time.Hour) // Fetch the old key, it should have an expires_at now. oldKey, err = db.GetCryptoKeyByFeatureAndSequence(ctx, database.GetCryptoKeyByFeatureAndSequenceParams{ Feature: oldKey.Feature, Sequence: oldKey.Sequence, - Time: now, }) require.NoError(t, err) - require.Equal(t, oldKey.DeletesAt.Time, expectedDeletesAt) + require.Equal(t, oldKey.DeletesAt.Time.UTC(), expectedDeletesAt) // The new key should be created and have a starts_at of the old key's expires_at. newKey, err := db.GetCryptoKeyByFeatureAndSequence(ctx, database.GetCryptoKeyByFeatureAndSequenceParams{ @@ -114,15 +77,17 @@ func Test_rotateKeys(t *testing.T) { Sequence: oldKey.Sequence + 1, }) require.NoError(t, err) - requireKey(t, newKey, database.CryptoKeyFeatureWorkspaceApps, oldKey.ExpiresAt(keyDuration), expectedDeletesAt, oldKey.Sequence+1) + requireKey(t, newKey, database.CryptoKeyFeatureWorkspaceApps, oldKey.ExpiresAt(keyDuration), time.Time{}, oldKey.Sequence+1) - clock.Advance(oldKey.DeletesAt.Time.Sub(now) + time.Second) + // Advance the clock just past the keys delete time. + clock.Advance(oldKey.DeletesAt.Time.UTC().Sub(now) - time.Second) + // We should have deleted the old key. keys, err = kr.rotateKeys(ctx) require.NoError(t, err) require.Len(t, keys, 1) - // The old key should be deleted. + // The old key should be "deleted". _, err = db.GetCryptoKeyByFeatureAndSequence(ctx, database.GetCryptoKeyByFeatureAndSequenceParams{ Feature: oldKey.Feature, Sequence: oldKey.Sequence, @@ -131,6 +96,51 @@ func Test_rotateKeys(t *testing.T) { }) t.Run("DoesNotRotateValidKeys", func(t *testing.T) { + t.Parallel() + + var ( + db, _ = dbtestutil.NewDB(t) + clock = quartz.NewMock(t) + keyDuration = time.Hour * 24 * 7 + logger = slogtest.Make(t, nil).Leveled(slog.LevelDebug) + ctx = testutil.Context(t, testutil.WaitShort) + resultsCh = make(chan []database.CryptoKey, 1) + ) + + kr := &KeyRotator{ + DB: db, + KeyDuration: keyDuration, + Clock: clock, + Logger: logger, + ScanInterval: 0, + ResultsCh: resultsCh, + features: []database.CryptoKeyFeature{ + database.CryptoKeyFeatureWorkspaceApps, + }, + } + + now := dbnow(clock) + + // Seed the database with an existing key + existingKey := dbgen.CryptoKey(t, db, database.CryptoKey{ + Feature: database.CryptoKeyFeatureWorkspaceApps, + StartsAt: now, + Sequence: 1, + }) + + // Advance the clock by 6 days, 23 hours. Once we + // breach the last hour we will insert a new key. + clock.Advance(keyDuration - time.Hour) + + keys, err := kr.rotateKeys(ctx) + require.NoError(t, err) + require.Empty(t, keys) + + // Verify that the existing key is still the only key in the database + dbKeys, err := db.GetCryptoKeys(ctx) + require.NoError(t, err) + require.Len(t, dbKeys, 1) + requireKey(t, dbKeys[0], existingKey.Feature, existingKey.StartsAt.UTC(), existingKey.DeletesAt.Time.UTC(), existingKey.Sequence) }) t.Run("DeletesExpiredKeys", func(t *testing.T) { @@ -149,12 +159,16 @@ func Test_rotateKeys(t *testing.T) { }) } +func dbnow(c quartz.Clock) time.Time { + return dbtime.Time(c.Now().UTC()) +} + func requireKey(t *testing.T, key database.CryptoKey, feature database.CryptoKeyFeature, startsAt time.Time, deletesAt time.Time, sequence int32) { t.Helper() - require.Equal(t, key.Feature, feature) - require.Equal(t, key.StartsAt, startsAt) - require.Equal(t, key.DeletesAt.Time, deletesAt) - require.Equal(t, key.Sequence, sequence) + require.Equal(t, feature, key.Feature) + require.Equal(t, startsAt, key.StartsAt.UTC()) + require.Equal(t, deletesAt, key.DeletesAt.Time.UTC()) + require.Equal(t, sequence, key.Sequence) secret, err := hex.DecodeString(key.Secret.String) require.NoError(t, err) diff --git a/coderd/keyrotate/rotate_test.go b/coderd/keyrotate/rotate_test.go new file mode 100644 index 0000000000000..5f82a0647c99d --- /dev/null +++ b/coderd/keyrotate/rotate_test.go @@ -0,0 +1,60 @@ +package keyrotate_test + +import ( + "testing" + + "github.com/stretchr/testify/require" + + "github.com/coder/coder/v2/coderd/database" +) + +func TestKeyRotator(t *testing.T) { + t.Run("NoExistingKeys", func(t *testing.T) { + // t.Parallel() + + // var ( + // db, _ = dbtestutil.NewDB(t) + // clock = quartz.NewMock(t) + // logger = slogtest.Make(t, nil).Leveled(slog.LevelDebug) + // ctx = testutil.Context(t, testutil.WaitShort) + // resultsCh = make(chan []database.CryptoKey, 1) + // ) + + // kr := &KeyRotator{ + // DB: db, + // KeyDuration: 0, + // Clock: clock, + // Logger: logger, + // ScanInterval: 0, + // ResultsCh: resultsCh, + // } + + // now := dbnow(clock) + // keys, err := kr.rotateKeys(ctx) + // require.NoError(t, err) + // require.Len(t, keys, len(database.AllCryptoKeyFeatureValues())) + + // // Fetch the keys from the database and ensure they + // // are as expected. + // dbkeys, err := db.GetCryptoKeys(ctx) + // require.NoError(t, err) + // require.Equal(t, keys, dbkeys) + // requireContainsAllFeatures(t, keys) + // for _, key := range keys { + // requireKey(t, key, key.Feature, now, time.Time{}, 1) + // } + }) + +} + +func requireContainsAllFeatures(t *testing.T, keys []database.CryptoKey) { + t.Helper() + + features := make(map[database.CryptoKeyFeature]bool) + for _, key := range keys { + features[key.Feature] = true + } + require.True(t, features[database.CryptoKeyFeatureOidcConvert]) + require.True(t, features[database.CryptoKeyFeatureWorkspaceApps]) + require.True(t, features[database.CryptoKeyFeaturePeerReconnect]) +} From 8980d32be53fc67a3fae667d11d87ea2293ac10c Mon Sep 17 00:00:00 2001 From: Jon Ayers Date: Fri, 13 Sep 2024 00:36:43 +0000 Subject: [PATCH 5/5] dbcrypt --- coderd/database/dbgen/dbgen.go | 5 +- coderd/database/dbmem/dbmem.go | 35 +- coderd/database/dump.sql | 4 + coderd/database/foreign_key_constraint.go | 1 + .../migrations/000250_crypto_keys.up.sql | 17 +- coderd/database/models.go | 11 +- coderd/database/queries.sql.go | 350 +++++++++--------- .../queries/{keys.sql => crypto_keys.sql} | 6 +- coderd/keyrotate/rotate_internal_test.go | 12 - enterprise/dbcrypt/dbcrypt.go | 25 ++ enterprise/dbcrypt/dbcrypt_internal_test.go | 45 +++ 11 files changed, 308 insertions(+), 203 deletions(-) rename coderd/database/queries/{keys.sql => crypto_keys.sql} (94%) diff --git a/coderd/database/dbgen/dbgen.go b/coderd/database/dbgen/dbgen.go index 5dc3b0fa3eb73..06e40287cff29 100644 --- a/coderd/database/dbgen/dbgen.go +++ b/coderd/database/dbgen/dbgen.go @@ -907,8 +907,9 @@ func CryptoKey(t testing.TB, db database.Store, seed database.CryptoKey) databas String: hex.EncodeToString(b), Valid: true, }), - Feature: takeFirst(seed.Feature, database.CryptoKeyFeatureWorkspaceApps), - StartsAt: takeFirst(seed.StartsAt, time.Now()), + SecretKeyID: takeFirst(seed.SecretKeyID, sql.NullString{}), + Feature: takeFirst(seed.Feature, database.CryptoKeyFeatureWorkspaceApps), + StartsAt: takeFirst(seed.StartsAt, time.Now()), }) require.NoError(t, err, "insert crypto key") diff --git a/coderd/database/dbmem/dbmem.go b/coderd/database/dbmem/dbmem.go index d3a4342b9c7a8..774d9296e51bc 100644 --- a/coderd/database/dbmem/dbmem.go +++ b/coderd/database/dbmem/dbmem.go @@ -153,6 +153,7 @@ type data struct { // New tables workspaceAgentStats []database.WorkspaceAgentStat auditLogs []database.AuditLog + cryptoKeys []database.CryptoKey dbcryptKeys []database.DBCryptKey files []database.File externalAuthLinks []database.ExternalAuthLink @@ -2318,13 +2319,26 @@ func (q *FakeQuerier) GetCoordinatorResumeTokenSigningKey(_ context.Context) (st return q.coordinatorResumeTokenSigningKey, nil } -func (q *FakeQuerier) GetCryptoKeyByFeatureAndSequence(ctx context.Context, arg database.GetCryptoKeyByFeatureAndSequenceParams) (database.CryptoKey, error) { +func (q *FakeQuerier) GetCryptoKeyByFeatureAndSequence(_ context.Context, arg database.GetCryptoKeyByFeatureAndSequenceParams) (database.CryptoKey, error) { err := validateDatabaseType(arg) if err != nil { return database.CryptoKey{}, err } - panic("not implemented") + q.mutex.RLock() + defer q.mutex.RUnlock() + + for _, key := range q.cryptoKeys { + if key.Feature == arg.Feature && key.Sequence == arg.Sequence { + // Keys with NULL secrets are considered deleted. + if key.Secret.Valid { + return key, nil + } + return database.CryptoKey{}, sql.ErrNoRows + } + } + + return database.CryptoKey{}, sql.ErrNoRows } func (q *FakeQuerier) GetCryptoKeys(ctx context.Context) ([]database.CryptoKey, error) { @@ -6331,13 +6345,26 @@ func (q *FakeQuerier) InsertAuditLog(_ context.Context, arg database.InsertAudit return alog, nil } -func (q *FakeQuerier) InsertCryptoKey(ctx context.Context, arg database.InsertCryptoKeyParams) (database.CryptoKey, error) { +func (q *FakeQuerier) InsertCryptoKey(_ context.Context, arg database.InsertCryptoKeyParams) (database.CryptoKey, error) { err := validateDatabaseType(arg) if err != nil { return database.CryptoKey{}, err } - panic("not implemented") + q.mutex.Lock() + defer q.mutex.Unlock() + + key := database.CryptoKey{ + Feature: arg.Feature, + Sequence: arg.Sequence, + Secret: arg.Secret, + SecretKeyID: arg.SecretKeyID, + StartsAt: arg.StartsAt, + } + + q.cryptoKeys = append(q.cryptoKeys, key) + + return key, nil } func (q *FakeQuerier) InsertCustomRole(_ context.Context, arg database.InsertCustomRoleParams) (database.CustomRole, error) { diff --git a/coderd/database/dump.sql b/coderd/database/dump.sql index 71f21a3d5a75a..17fd3511442ec 100644 --- a/coderd/database/dump.sql +++ b/coderd/database/dump.sql @@ -504,6 +504,7 @@ CREATE TABLE crypto_keys ( feature crypto_key_feature NOT NULL, sequence integer NOT NULL, secret text, + secret_key_id text, starts_at timestamp with time zone NOT NULL, deletes_at timestamp with time zone ); @@ -2052,6 +2053,9 @@ CREATE TRIGGER update_notification_message_dedupe_hash BEFORE INSERT OR UPDATE O ALTER TABLE ONLY api_keys ADD CONSTRAINT api_keys_user_id_uuid_fkey FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE; +ALTER TABLE ONLY crypto_keys + ADD CONSTRAINT crypto_keys_secret_key_id_fkey FOREIGN KEY (secret_key_id) REFERENCES dbcrypt_keys(active_key_digest); + ALTER TABLE ONLY external_auth_links ADD CONSTRAINT git_auth_links_oauth_access_token_key_id_fkey FOREIGN KEY (oauth_access_token_key_id) REFERENCES dbcrypt_keys(active_key_digest); diff --git a/coderd/database/foreign_key_constraint.go b/coderd/database/foreign_key_constraint.go index 0c578255f091c..6046a94e3bcad 100644 --- a/coderd/database/foreign_key_constraint.go +++ b/coderd/database/foreign_key_constraint.go @@ -7,6 +7,7 @@ type ForeignKeyConstraint string // ForeignKeyConstraint enums. const ( ForeignKeyAPIKeysUserIDUUID ForeignKeyConstraint = "api_keys_user_id_uuid_fkey" // ALTER TABLE ONLY api_keys ADD CONSTRAINT api_keys_user_id_uuid_fkey FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE; + ForeignKeyCryptoKeysSecretKeyID ForeignKeyConstraint = "crypto_keys_secret_key_id_fkey" // ALTER TABLE ONLY crypto_keys ADD CONSTRAINT crypto_keys_secret_key_id_fkey FOREIGN KEY (secret_key_id) REFERENCES dbcrypt_keys(active_key_digest); ForeignKeyGitAuthLinksOauthAccessTokenKeyID ForeignKeyConstraint = "git_auth_links_oauth_access_token_key_id_fkey" // ALTER TABLE ONLY external_auth_links ADD CONSTRAINT git_auth_links_oauth_access_token_key_id_fkey FOREIGN KEY (oauth_access_token_key_id) REFERENCES dbcrypt_keys(active_key_digest); ForeignKeyGitAuthLinksOauthRefreshTokenKeyID ForeignKeyConstraint = "git_auth_links_oauth_refresh_token_key_id_fkey" // ALTER TABLE ONLY external_auth_links ADD CONSTRAINT git_auth_links_oauth_refresh_token_key_id_fkey FOREIGN KEY (oauth_refresh_token_key_id) REFERENCES dbcrypt_keys(active_key_digest); ForeignKeyGitSSHKeysUserID ForeignKeyConstraint = "gitsshkeys_user_id_fkey" // ALTER TABLE ONLY gitsshkeys ADD CONSTRAINT gitsshkeys_user_id_fkey FOREIGN KEY (user_id) REFERENCES users(id); diff --git a/coderd/database/migrations/000250_crypto_keys.up.sql b/coderd/database/migrations/000250_crypto_keys.up.sql index f6f1457e9eb92..7c1aa7888fdd1 100644 --- a/coderd/database/migrations/000250_crypto_keys.up.sql +++ b/coderd/database/migrations/000250_crypto_keys.up.sql @@ -1,15 +1,16 @@ -CREATE TYPE "crypto_key_feature" AS ENUM ( +CREATE TYPE crypto_key_feature AS ENUM ( 'workspace_apps', 'oidc_convert', 'peer_reconnect' ); -CREATE TABLE "crypto_keys" ( - "feature" "crypto_key_feature" NOT NULL, - "sequence" integer NOT NULL, - "secret" text NULL, - "starts_at" timestamptz NOT NULL, - "deletes_at" timestamptz NULL, - PRIMARY KEY ("feature", "sequence") +CREATE TABLE crypto_keys ( + feature crypto_key_feature NOT NULL, + sequence integer NOT NULL, + secret text NULL, + secret_key_id text NULL REFERENCES dbcrypt_keys(active_key_digest), + starts_at timestamptz NOT NULL, + deletes_at timestamptz NULL, + PRIMARY KEY (feature, sequence) ); diff --git a/coderd/database/models.go b/coderd/database/models.go index b2217d96f2978..e9bb8e42b8960 100644 --- a/coderd/database/models.go +++ b/coderd/database/models.go @@ -2105,11 +2105,12 @@ type AuditLog struct { } type CryptoKey struct { - Feature CryptoKeyFeature `db:"feature" json:"feature"` - Sequence int32 `db:"sequence" json:"sequence"` - Secret sql.NullString `db:"secret" json:"secret"` - StartsAt time.Time `db:"starts_at" json:"starts_at"` - DeletesAt sql.NullTime `db:"deletes_at" json:"deletes_at"` + Feature CryptoKeyFeature `db:"feature" json:"feature"` + Sequence int32 `db:"sequence" json:"sequence"` + Secret sql.NullString `db:"secret" json:"secret"` + SecretKeyID sql.NullString `db:"secret_key_id" json:"secret_key_id"` + StartsAt time.Time `db:"starts_at" json:"starts_at"` + DeletesAt sql.NullTime `db:"deletes_at" json:"deletes_at"` } // Custom roles allow dynamic roles expanded at runtime diff --git a/coderd/database/queries.sql.go b/coderd/database/queries.sql.go index aecf3c0b3d44c..b5c198c94da42 100644 --- a/coderd/database/queries.sql.go +++ b/coderd/database/queries.sql.go @@ -761,6 +761,186 @@ func (q *sqlQuerier) InsertAuditLog(ctx context.Context, arg InsertAuditLogParam return i, err } +const deleteCryptoKey = `-- name: DeleteCryptoKey :one +UPDATE crypto_keys +SET secret = NULL +WHERE feature = $1 AND sequence = $2 RETURNING feature, sequence, secret, secret_key_id, starts_at, deletes_at +` + +type DeleteCryptoKeyParams struct { + Feature CryptoKeyFeature `db:"feature" json:"feature"` + Sequence int32 `db:"sequence" json:"sequence"` +} + +func (q *sqlQuerier) DeleteCryptoKey(ctx context.Context, arg DeleteCryptoKeyParams) (CryptoKey, error) { + row := q.db.QueryRowContext(ctx, deleteCryptoKey, arg.Feature, arg.Sequence) + var i CryptoKey + err := row.Scan( + &i.Feature, + &i.Sequence, + &i.Secret, + &i.SecretKeyID, + &i.StartsAt, + &i.DeletesAt, + ) + return i, err +} + +const getCryptoKeyByFeatureAndSequence = `-- name: GetCryptoKeyByFeatureAndSequence :one +SELECT feature, sequence, secret, secret_key_id, starts_at, deletes_at +FROM crypto_keys +WHERE feature = $1 + AND sequence = $2 + AND secret IS NOT NULL +` + +type GetCryptoKeyByFeatureAndSequenceParams struct { + Feature CryptoKeyFeature `db:"feature" json:"feature"` + Sequence int32 `db:"sequence" json:"sequence"` +} + +func (q *sqlQuerier) GetCryptoKeyByFeatureAndSequence(ctx context.Context, arg GetCryptoKeyByFeatureAndSequenceParams) (CryptoKey, error) { + row := q.db.QueryRowContext(ctx, getCryptoKeyByFeatureAndSequence, arg.Feature, arg.Sequence) + var i CryptoKey + err := row.Scan( + &i.Feature, + &i.Sequence, + &i.Secret, + &i.SecretKeyID, + &i.StartsAt, + &i.DeletesAt, + ) + return i, err +} + +const getCryptoKeys = `-- name: GetCryptoKeys :many +SELECT feature, sequence, secret, secret_key_id, starts_at, deletes_at +FROM crypto_keys +WHERE secret IS NOT NULL +` + +func (q *sqlQuerier) GetCryptoKeys(ctx context.Context) ([]CryptoKey, error) { + rows, err := q.db.QueryContext(ctx, getCryptoKeys) + if err != nil { + return nil, err + } + defer rows.Close() + var items []CryptoKey + for rows.Next() { + var i CryptoKey + if err := rows.Scan( + &i.Feature, + &i.Sequence, + &i.Secret, + &i.SecretKeyID, + &i.StartsAt, + &i.DeletesAt, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const getLatestCryptoKeyByFeature = `-- name: GetLatestCryptoKeyByFeature :one +SELECT feature, sequence, secret, secret_key_id, starts_at, deletes_at +FROM crypto_keys +WHERE feature = $1 +ORDER BY sequence DESC +LIMIT 1 +` + +func (q *sqlQuerier) GetLatestCryptoKeyByFeature(ctx context.Context, feature CryptoKeyFeature) (CryptoKey, error) { + row := q.db.QueryRowContext(ctx, getLatestCryptoKeyByFeature, feature) + var i CryptoKey + err := row.Scan( + &i.Feature, + &i.Sequence, + &i.Secret, + &i.SecretKeyID, + &i.StartsAt, + &i.DeletesAt, + ) + return i, err +} + +const insertCryptoKey = `-- name: InsertCryptoKey :one +INSERT INTO crypto_keys ( + feature, + sequence, + secret, + starts_at, + secret_key_id +) VALUES ( + $1, + $2, + $3, + $4, + $5 +) RETURNING feature, sequence, secret, secret_key_id, starts_at, deletes_at +` + +type InsertCryptoKeyParams struct { + Feature CryptoKeyFeature `db:"feature" json:"feature"` + Sequence int32 `db:"sequence" json:"sequence"` + Secret sql.NullString `db:"secret" json:"secret"` + StartsAt time.Time `db:"starts_at" json:"starts_at"` + SecretKeyID sql.NullString `db:"secret_key_id" json:"secret_key_id"` +} + +func (q *sqlQuerier) InsertCryptoKey(ctx context.Context, arg InsertCryptoKeyParams) (CryptoKey, error) { + row := q.db.QueryRowContext(ctx, insertCryptoKey, + arg.Feature, + arg.Sequence, + arg.Secret, + arg.StartsAt, + arg.SecretKeyID, + ) + var i CryptoKey + err := row.Scan( + &i.Feature, + &i.Sequence, + &i.Secret, + &i.SecretKeyID, + &i.StartsAt, + &i.DeletesAt, + ) + return i, err +} + +const updateCryptoKeyDeletesAt = `-- name: UpdateCryptoKeyDeletesAt :one +UPDATE crypto_keys +SET deletes_at = $3 +WHERE feature = $1 AND sequence = $2 RETURNING feature, sequence, secret, secret_key_id, starts_at, deletes_at +` + +type UpdateCryptoKeyDeletesAtParams struct { + Feature CryptoKeyFeature `db:"feature" json:"feature"` + Sequence int32 `db:"sequence" json:"sequence"` + DeletesAt sql.NullTime `db:"deletes_at" json:"deletes_at"` +} + +func (q *sqlQuerier) UpdateCryptoKeyDeletesAt(ctx context.Context, arg UpdateCryptoKeyDeletesAtParams) (CryptoKey, error) { + row := q.db.QueryRowContext(ctx, updateCryptoKeyDeletesAt, arg.Feature, arg.Sequence, arg.DeletesAt) + var i CryptoKey + err := row.Scan( + &i.Feature, + &i.Sequence, + &i.Secret, + &i.SecretKeyID, + &i.StartsAt, + &i.DeletesAt, + ) + return i, err +} + const getDBCryptKeys = `-- name: GetDBCryptKeys :many SELECT number, active_key_digest, revoked_key_digest, created_at, revoked_at, test FROM dbcrypt_keys ORDER BY number ASC ` @@ -3164,176 +3344,6 @@ func (q *sqlQuerier) UpsertJFrogXrayScanByWorkspaceAndAgentID(ctx context.Contex return err } -const deleteCryptoKey = `-- name: DeleteCryptoKey :one -UPDATE crypto_keys -SET secret = NULL -WHERE feature = $1 AND sequence = $2 RETURNING feature, sequence, secret, starts_at, deletes_at -` - -type DeleteCryptoKeyParams struct { - Feature CryptoKeyFeature `db:"feature" json:"feature"` - Sequence int32 `db:"sequence" json:"sequence"` -} - -func (q *sqlQuerier) DeleteCryptoKey(ctx context.Context, arg DeleteCryptoKeyParams) (CryptoKey, error) { - row := q.db.QueryRowContext(ctx, deleteCryptoKey, arg.Feature, arg.Sequence) - var i CryptoKey - err := row.Scan( - &i.Feature, - &i.Sequence, - &i.Secret, - &i.StartsAt, - &i.DeletesAt, - ) - return i, err -} - -const getCryptoKeyByFeatureAndSequence = `-- name: GetCryptoKeyByFeatureAndSequence :one -SELECT feature, sequence, secret, starts_at, deletes_at -FROM crypto_keys -WHERE feature = $1 - AND sequence = $2 - AND secret IS NOT NULL -` - -type GetCryptoKeyByFeatureAndSequenceParams struct { - Feature CryptoKeyFeature `db:"feature" json:"feature"` - Sequence int32 `db:"sequence" json:"sequence"` -} - -func (q *sqlQuerier) GetCryptoKeyByFeatureAndSequence(ctx context.Context, arg GetCryptoKeyByFeatureAndSequenceParams) (CryptoKey, error) { - row := q.db.QueryRowContext(ctx, getCryptoKeyByFeatureAndSequence, arg.Feature, arg.Sequence) - var i CryptoKey - err := row.Scan( - &i.Feature, - &i.Sequence, - &i.Secret, - &i.StartsAt, - &i.DeletesAt, - ) - return i, err -} - -const getCryptoKeys = `-- name: GetCryptoKeys :many -SELECT feature, sequence, secret, starts_at, deletes_at -FROM crypto_keys -WHERE secret IS NOT NULL -` - -func (q *sqlQuerier) GetCryptoKeys(ctx context.Context) ([]CryptoKey, error) { - rows, err := q.db.QueryContext(ctx, getCryptoKeys) - if err != nil { - return nil, err - } - defer rows.Close() - var items []CryptoKey - for rows.Next() { - var i CryptoKey - if err := rows.Scan( - &i.Feature, - &i.Sequence, - &i.Secret, - &i.StartsAt, - &i.DeletesAt, - ); err != nil { - return nil, err - } - items = append(items, i) - } - if err := rows.Close(); err != nil { - return nil, err - } - if err := rows.Err(); err != nil { - return nil, err - } - return items, nil -} - -const getLatestCryptoKeyByFeature = `-- name: GetLatestCryptoKeyByFeature :one -SELECT feature, sequence, secret, starts_at, deletes_at -FROM crypto_keys -WHERE feature = $1 -ORDER BY sequence DESC -LIMIT 1 -` - -func (q *sqlQuerier) GetLatestCryptoKeyByFeature(ctx context.Context, feature CryptoKeyFeature) (CryptoKey, error) { - row := q.db.QueryRowContext(ctx, getLatestCryptoKeyByFeature, feature) - var i CryptoKey - err := row.Scan( - &i.Feature, - &i.Sequence, - &i.Secret, - &i.StartsAt, - &i.DeletesAt, - ) - return i, err -} - -const insertCryptoKey = `-- name: InsertCryptoKey :one -INSERT INTO crypto_keys ( - feature, - sequence, - secret, - starts_at -) VALUES ( - $1, - $2, - $3, - $4 -) RETURNING feature, sequence, secret, starts_at, deletes_at -` - -type InsertCryptoKeyParams struct { - Feature CryptoKeyFeature `db:"feature" json:"feature"` - Sequence int32 `db:"sequence" json:"sequence"` - Secret sql.NullString `db:"secret" json:"secret"` - StartsAt time.Time `db:"starts_at" json:"starts_at"` -} - -func (q *sqlQuerier) InsertCryptoKey(ctx context.Context, arg InsertCryptoKeyParams) (CryptoKey, error) { - row := q.db.QueryRowContext(ctx, insertCryptoKey, - arg.Feature, - arg.Sequence, - arg.Secret, - arg.StartsAt, - ) - var i CryptoKey - err := row.Scan( - &i.Feature, - &i.Sequence, - &i.Secret, - &i.StartsAt, - &i.DeletesAt, - ) - return i, err -} - -const updateCryptoKeyDeletesAt = `-- name: UpdateCryptoKeyDeletesAt :one -UPDATE crypto_keys -SET deletes_at = $3 -WHERE feature = $1 AND sequence = $2 RETURNING feature, sequence, secret, starts_at, deletes_at -` - -type UpdateCryptoKeyDeletesAtParams struct { - Feature CryptoKeyFeature `db:"feature" json:"feature"` - Sequence int32 `db:"sequence" json:"sequence"` - DeletesAt sql.NullTime `db:"deletes_at" json:"deletes_at"` -} - -func (q *sqlQuerier) UpdateCryptoKeyDeletesAt(ctx context.Context, arg UpdateCryptoKeyDeletesAtParams) (CryptoKey, error) { - row := q.db.QueryRowContext(ctx, updateCryptoKeyDeletesAt, arg.Feature, arg.Sequence, arg.DeletesAt) - var i CryptoKey - err := row.Scan( - &i.Feature, - &i.Sequence, - &i.Secret, - &i.StartsAt, - &i.DeletesAt, - ) - return i, err -} - const deleteLicense = `-- name: DeleteLicense :one DELETE FROM licenses diff --git a/coderd/database/queries/keys.sql b/coderd/database/queries/crypto_keys.sql similarity index 94% rename from coderd/database/queries/keys.sql rename to coderd/database/queries/crypto_keys.sql index 64cca798d44d9..39dc8175f95ab 100644 --- a/coderd/database/queries/keys.sql +++ b/coderd/database/queries/crypto_keys.sql @@ -28,12 +28,14 @@ INSERT INTO crypto_keys ( feature, sequence, secret, - starts_at + starts_at, + secret_key_id ) VALUES ( $1, $2, $3, - $4 + $4, + $5 ) RETURNING *; -- name: UpdateCryptoKeyDeletesAt :one diff --git a/coderd/keyrotate/rotate_internal_test.go b/coderd/keyrotate/rotate_internal_test.go index a827c96b73132..a0f7e9522507a 100644 --- a/coderd/keyrotate/rotate_internal_test.go +++ b/coderd/keyrotate/rotate_internal_test.go @@ -184,15 +184,3 @@ func requireKey(t *testing.T, key database.CryptoKey, feature database.CryptoKey t.Fatalf("unknown key feature: %s", key.Feature) } } - -func requireContainsAllFeatures(t *testing.T, keys []database.CryptoKey) { - t.Helper() - - features := make(map[database.CryptoKeyFeature]bool) - for _, key := range keys { - features[key.Feature] = true - } - require.True(t, features[database.CryptoKeyFeatureOidcConvert]) - require.True(t, features[database.CryptoKeyFeatureWorkspaceApps]) - require.True(t, features[database.CryptoKeyFeaturePeerReconnect]) -} diff --git a/enterprise/dbcrypt/dbcrypt.go b/enterprise/dbcrypt/dbcrypt.go index ec56a4897a1e3..2717ef1d48188 100644 --- a/enterprise/dbcrypt/dbcrypt.go +++ b/enterprise/dbcrypt/dbcrypt.go @@ -261,6 +261,31 @@ func (db *dbCrypt) UpdateExternalAuthLink(ctx context.Context, params database.U return link, nil } +func (db *dbCrypt) GetCryptoKeyByFeatureAndSequence(ctx context.Context, params database.GetCryptoKeyByFeatureAndSequenceParams) (database.CryptoKey, error) { + key, err := db.Store.GetCryptoKeyByFeatureAndSequence(ctx, params) + if err != nil { + return database.CryptoKey{}, err + } + if err := db.decryptField(&key.Secret.String, key.SecretKeyID); err != nil { + return database.CryptoKey{}, err + } + return key, nil +} + +func (db *dbCrypt) InsertCryptoKey(ctx context.Context, params database.InsertCryptoKeyParams) (database.CryptoKey, error) { + if err := db.encryptField(¶ms.Secret.String, ¶ms.SecretKeyID); err != nil { + return database.CryptoKey{}, err + } + key, err := db.Store.InsertCryptoKey(ctx, params) + if err != nil { + return database.CryptoKey{}, err + } + if err := db.decryptField(&key.Secret.String, key.SecretKeyID); err != nil { + return database.CryptoKey{}, err + } + return key, nil +} + func (db *dbCrypt) encryptField(field *string, digest *sql.NullString) error { // If no cipher is loaded, then we can't encrypt anything! if db.ciphers == nil || db.primaryCipherDigest == "" { diff --git a/enterprise/dbcrypt/dbcrypt_internal_test.go b/enterprise/dbcrypt/dbcrypt_internal_test.go index 37fcc8cae55a3..7dad716b8139b 100644 --- a/enterprise/dbcrypt/dbcrypt_internal_test.go +++ b/enterprise/dbcrypt/dbcrypt_internal_test.go @@ -349,6 +349,51 @@ func TestExternalAuthLinks(t *testing.T) { }) } +func TestCryptoKeys(t *testing.T) { + t.Parallel() + ctx := context.Background() + db, crypt, ciphers := setup(t) + + // We don't write a GetCryptoKeyByFeatureAndSequence test + // because it's basically the same as InsertCryptoKey. + t.Run("InsertCryptoKey", func(t *testing.T) { + t.Parallel() + key := dbgen.CryptoKey(t, crypt, database.CryptoKey{ + Secret: sql.NullString{String: "test", Valid: true}, + }) + require.Equal(t, "test", key.Secret.String) + + key, err := db.GetCryptoKeyByFeatureAndSequence(ctx, database.GetCryptoKeyByFeatureAndSequenceParams{ + Feature: key.Feature, + Sequence: key.Sequence, + }) + require.NoError(t, err) + require.Equal(t, ciphers[0].HexDigest(), key.SecretKeyID.String) + requireEncryptedEquals(t, ciphers[0], key.Secret.String, "test") + }) + t.Run("DecryptErr", func(t *testing.T) { + t.Parallel() + db, crypt, ciphers := setup(t) + key := dbgen.CryptoKey(t, db, database.CryptoKey{ + Secret: sql.NullString{ + String: fakeBase64RandomData(t, 32), + Valid: true, + }, + SecretKeyID: sql.NullString{ + String: ciphers[0].HexDigest(), + Valid: true, + }, + }) + _, err := crypt.GetCryptoKeyByFeatureAndSequence(ctx, database.GetCryptoKeyByFeatureAndSequenceParams{ + Feature: key.Feature, + Sequence: key.Sequence, + }) + require.Error(t, err, "expected an error") + var derr *DecryptFailedError + require.ErrorAs(t, err, &derr, "expected a decrypt error") + }) +} + func TestNew(t *testing.T) { t.Parallel()