Skip to content

Commit 495167a

Browse files
authoredMar 19, 2025··
feat: implement Interceptor (#3)
1 parent 35ee52a commit 495167a

File tree

1 file changed

+102
-0
lines changed

1 file changed

+102
-0
lines changed
 

‎interceptor.go

+102
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
package queries
2+
3+
import (
4+
"context"
5+
"database/sql/driver"
6+
)
7+
8+
var (
9+
_ driver.Driver = Interceptor{}
10+
_ driver.DriverContext = Interceptor{}
11+
)
12+
13+
type Interceptor struct {
14+
Driver driver.Driver
15+
ExecContext func(ctx context.Context, query string, args []driver.NamedValue, execer driver.ExecerContext) (driver.Result, error)
16+
QueryContext func(ctx context.Context, query string, args []driver.NamedValue, queryer driver.QueryerContext) (driver.Rows, error)
17+
}
18+
19+
// DriverName returns the driver name to pass to [sql.Register] and [sql.Open].
20+
func (i Interceptor) DriverName() string { return "interceptor" }
21+
22+
// Open implements [driver.Driver].
23+
func (i Interceptor) Open(name string) (driver.Conn, error) {
24+
conn, err := i.Driver.Open(name)
25+
if err != nil {
26+
return nil, err
27+
}
28+
return wrappedConn{conn, i}, nil
29+
}
30+
31+
// OpenConnector implements [driver.DriverContext].
32+
func (i Interceptor) OpenConnector(name string) (driver.Connector, error) {
33+
if driver, ok := i.Driver.(driver.DriverContext); ok {
34+
connector, err := driver.OpenConnector(name)
35+
if err != nil {
36+
return nil, err
37+
}
38+
return wrappedConnector{connector, i}, nil
39+
}
40+
connector := dsnConnector{name, i.Driver}
41+
return wrappedConnector{connector, i}, nil
42+
}
43+
44+
var (
45+
_ driver.Conn = wrappedConn{}
46+
_ driver.ExecerContext = wrappedConn{}
47+
_ driver.QueryerContext = wrappedConn{}
48+
)
49+
50+
type wrappedConn struct {
51+
driver.Conn
52+
interceptor Interceptor
53+
}
54+
55+
// ExecContext implements [driver.ExecerContext].
56+
func (c wrappedConn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) {
57+
execer, ok := c.Conn.(driver.ExecerContext)
58+
if !ok {
59+
panic("queries: driver does not implement driver.ExecerContext")
60+
}
61+
if c.interceptor.ExecContext != nil {
62+
return c.interceptor.ExecContext(ctx, query, args, execer)
63+
}
64+
return execer.ExecContext(ctx, query, args)
65+
}
66+
67+
// QueryContext implements [driver.QueryContext].
68+
func (c wrappedConn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) {
69+
queryer, ok := c.Conn.(driver.QueryerContext)
70+
if !ok {
71+
panic("queries: driver does not implement driver.QueryerContext")
72+
}
73+
if c.interceptor.QueryContext != nil {
74+
return c.interceptor.QueryContext(ctx, query, args, queryer)
75+
}
76+
return queryer.QueryContext(ctx, query, args)
77+
}
78+
79+
var _ driver.Connector = wrappedConnector{}
80+
81+
type wrappedConnector struct {
82+
driver.Connector
83+
interceptor Interceptor
84+
}
85+
86+
// Connect implements [driver.Connector].
87+
func (c wrappedConnector) Connect(ctx context.Context) (driver.Conn, error) {
88+
conn, err := c.Connector.Connect(ctx)
89+
if err != nil {
90+
return nil, err
91+
}
92+
return wrappedConn{conn, c.interceptor}, nil
93+
}
94+
95+
// copied from https://go.dev/src/database/sql/sql.go
96+
type dsnConnector struct {
97+
dsn string
98+
driver driver.Driver
99+
}
100+
101+
func (t dsnConnector) Connect(_ context.Context) (driver.Conn, error) { return t.driver.Open(t.dsn) }
102+
func (t dsnConnector) Driver() driver.Driver { return t.driver }

0 commit comments

Comments
 (0)
Please sign in to comment.