diff --git a/query_column_add.go b/query_column_add.go index c3c781a1d..5c7e97bd6 100644 --- a/query_column_add.go +++ b/query_column_add.go @@ -137,6 +137,9 @@ func (q *AddColumnQuery) Exec(ctx context.Context, dest ...interface{}) (sql.Res return nil, feature.NewNotSupportError(feature.AlterColumnExists) } + // if a comment is propagated via the context, use it + setCommentFromContext(ctx, q) + queryBytes, err := q.AppendQuery(q.db.fmter, q.db.makeQueryBytes()) if err != nil { return nil, err diff --git a/query_column_drop.go b/query_column_drop.go index e66e35b9a..54f8e30b1 100644 --- a/query_column_drop.go +++ b/query_column_drop.go @@ -129,6 +129,9 @@ func (q *DropColumnQuery) AppendQuery(fmter schema.Formatter, b []byte) (_ []byt //------------------------------------------------------------------------------ func (q *DropColumnQuery) Exec(ctx context.Context, dest ...interface{}) (sql.Result, error) { + // if a comment is propagated via the context, use it + setCommentFromContext(ctx, q) + queryBytes, err := q.AppendQuery(q.db.fmter, q.db.makeQueryBytes()) if err != nil { return nil, err diff --git a/query_delete.go b/query_delete.go index 99ec37bb7..0a28ac1a9 100644 --- a/query_delete.go +++ b/query_delete.go @@ -321,6 +321,9 @@ func (q *DeleteQuery) scanOrExec( return nil, err } + // if a comment is propagated via the context, use it + setCommentFromContext(ctx, q) + // Generate the query before checking hasReturning. queryBytes, err := q.AppendQuery(q.db.fmter, q.db.makeQueryBytes()) if err != nil { diff --git a/query_index_create.go b/query_index_create.go index 4ac4ffd10..76ad515f0 100644 --- a/query_index_create.go +++ b/query_index_create.go @@ -248,6 +248,9 @@ func (q *CreateIndexQuery) AppendQuery(fmter schema.Formatter, b []byte) (_ []by //------------------------------------------------------------------------------ func (q *CreateIndexQuery) Exec(ctx context.Context, dest ...interface{}) (sql.Result, error) { + // if a comment is propagated via the context, use it + setCommentFromContext(ctx, q) + queryBytes, err := q.AppendQuery(q.db.fmter, q.db.makeQueryBytes()) if err != nil { return nil, err diff --git a/query_index_drop.go b/query_index_drop.go index 27c6e7f67..df7892b44 100644 --- a/query_index_drop.go +++ b/query_index_drop.go @@ -115,6 +115,9 @@ func (q *DropIndexQuery) AppendQuery(fmter schema.Formatter, b []byte) (_ []byte //------------------------------------------------------------------------------ func (q *DropIndexQuery) Exec(ctx context.Context, dest ...interface{}) (sql.Result, error) { + // if a comment is propagated via the context, use it + setCommentFromContext(ctx, q) + queryBytes, err := q.AppendQuery(q.db.fmter, q.db.makeQueryBytes()) if err != nil { return nil, err diff --git a/query_insert.go b/query_insert.go index d2e158d77..0b6e23b5a 100644 --- a/query_insert.go +++ b/query_insert.go @@ -586,6 +586,9 @@ func (q *InsertQuery) scanOrExec( return nil, err } + // if a comment is propagated via the context, use it + setCommentFromContext(ctx, q) + // Generate the query before checking hasReturning. queryBytes, err := q.AppendQuery(q.db.fmter, q.db.makeQueryBytes()) if err != nil { diff --git a/query_merge.go b/query_merge.go index 0c172f180..b9524c56b 100644 --- a/query_merge.go +++ b/query_merge.go @@ -243,6 +243,9 @@ func (q *MergeQuery) scanOrExec( return nil, err } + // if a comment is propagated via the context, use it + setCommentFromContext(ctx, q) + // Generate the query before checking hasReturning. queryBytes, err := q.AppendQuery(q.db.fmter, q.db.makeQueryBytes()) if err != nil { diff --git a/query_raw.go b/query_raw.go index 308329567..bfc0d3050 100644 --- a/query_raw.go +++ b/query_raw.go @@ -67,6 +67,9 @@ func (q *RawQuery) scanOrExec( } } + // if a comment is propagated via the context, use it + setCommentFromContext(ctx, q) + query := q.db.format(q.query, q.args) var res sql.Result diff --git a/query_select.go b/query_select.go index f5df92854..c10160187 100644 --- a/query_select.go +++ b/query_select.go @@ -791,6 +791,9 @@ func (q *SelectQuery) Rows(ctx context.Context) (*sql.Rows, error) { return nil, err } + // if a comment is propagated via the context, use it + setCommentFromContext(ctx, q) + queryBytes, err := q.AppendQuery(q.db.fmter, q.db.makeQueryBytes()) if err != nil { return nil, err @@ -812,6 +815,9 @@ func (q *SelectQuery) Exec(ctx context.Context, dest ...interface{}) (res sql.Re return nil, err } + // if a comment is propagated via the context, use it + setCommentFromContext(ctx, q) + queryBytes, err := q.AppendQuery(q.db.fmter, q.db.makeQueryBytes()) if err != nil { return nil, err @@ -872,6 +878,9 @@ func (q *SelectQuery) scanResult(ctx context.Context, dest ...interface{}) (sql. return nil, err } + // if a comment is propagated via the context, use it + setCommentFromContext(ctx, q) + queryBytes, err := q.AppendQuery(q.db.fmter, q.db.makeQueryBytes()) if err != nil { return nil, err @@ -924,6 +933,9 @@ func (q *SelectQuery) Count(ctx context.Context) (int, error) { return 0, q.err } + // if a comment is propagated via the context, use it + setCommentFromContext(ctx, q) + qq := countQuery{q} queryBytes, err := qq.AppendQuery(q.db.fmter, nil) @@ -1028,6 +1040,9 @@ func (q *SelectQuery) Exists(ctx context.Context) (bool, error) { } func (q *SelectQuery) selectExists(ctx context.Context) (bool, error) { + // if a comment is propagated via the context, use it + setCommentFromContext(ctx, q) + qq := selectExistsQuery{q} queryBytes, err := qq.AppendQuery(q.db.fmter, nil) @@ -1047,6 +1062,9 @@ func (q *SelectQuery) selectExists(ctx context.Context) (bool, error) { } func (q *SelectQuery) whereExists(ctx context.Context) (bool, error) { + // if a comment is propagated via the context, use it + setCommentFromContext(ctx, q) + qq := whereExistsQuery{q} queryBytes, err := qq.AppendQuery(q.db.fmter, nil) diff --git a/query_table_create.go b/query_table_create.go index d8c4566cb..0ae56bf78 100644 --- a/query_table_create.go +++ b/query_table_create.go @@ -358,6 +358,9 @@ func (q *CreateTableQuery) Exec(ctx context.Context, dest ...interface{}) (sql.R return nil, err } + // if a comment is propagated via the context, use it + setCommentFromContext(ctx, q) + queryBytes, err := q.AppendQuery(q.db.fmter, q.db.makeQueryBytes()) if err != nil { return nil, err diff --git a/query_table_drop.go b/query_table_drop.go index 4e7d305a9..78d964d7a 100644 --- a/query_table_drop.go +++ b/query_table_drop.go @@ -123,6 +123,9 @@ func (q *DropTableQuery) Exec(ctx context.Context, dest ...interface{}) (sql.Res } } + // if a comment is propagated via the context, use it + setCommentFromContext(ctx, q) + queryBytes, err := q.AppendQuery(q.db.fmter, q.db.makeQueryBytes()) if err != nil { return nil, err diff --git a/query_table_truncate.go b/query_table_truncate.go index 0f30a1d04..8e5e8a70e 100644 --- a/query_table_truncate.go +++ b/query_table_truncate.go @@ -136,6 +136,9 @@ func (q *TruncateTableQuery) AppendQuery( //------------------------------------------------------------------------------ func (q *TruncateTableQuery) Exec(ctx context.Context, dest ...interface{}) (sql.Result, error) { + // if a comment is propagated via the context, use it + setCommentFromContext(ctx, q) + queryBytes, err := q.AppendQuery(q.db.fmter, q.db.makeQueryBytes()) if err != nil { return nil, err diff --git a/query_update.go b/query_update.go index b700f2180..25b65b05e 100644 --- a/query_update.go +++ b/query_update.go @@ -556,6 +556,9 @@ func (q *UpdateQuery) scanOrExec( return nil, err } + // if a comment is propagated via the context, use it + setCommentFromContext(ctx, q) + // Generate the query before checking hasReturning. queryBytes, err := q.AppendQuery(q.db.fmter, q.db.makeQueryBytes()) if err != nil { diff --git a/util.go b/util.go index 97ed9228a..f4896136c 100644 --- a/util.go +++ b/util.go @@ -1,6 +1,7 @@ package bun import ( + "context" "fmt" "reflect" "strings" @@ -86,3 +87,26 @@ func appendComment(b []byte, name string) []byte { name = strings.ReplaceAll(name, `*/`, `*\/`) return append(b, fmt.Sprintf("/* %s */ ", name)...) } + +// queryCommentCtxKey is a context key for setting a query comment on a context instead of calling the Comment("...") API directly +type queryCommentCtxKey struct{} + +// ContextWithComment returns a context that includes a comment that may be included in a query for debugging +// +// If a context with an attached query is used, a comment set by the Comment("...") API will be overwritten. +func ContextWithComment(ctx context.Context, comment string) context.Context { + return context.WithValue(ctx, queryCommentCtxKey{}, comment) +} + +// commenter describes the Comment interface implemented by all of the query types +type commenter[T any] interface { + Comment(string) T +} + +// setCommentFromContext sets the comment on the given query from the supplied context if one is set using the Comment(...) method. +func setCommentFromContext[T any](ctx context.Context, q commenter[T]) { + s, _ := ctx.Value(queryCommentCtxKey{}).(string) + if s != "" { + q.Comment(s) + } +}