Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support Update join #761

Open
wants to merge 13 commits into
base: master
Choose a base branch
from
Prev Previous commit
Next Next commit
duplicate image row for update join
  • Loading branch information
lxfeng1997 committed Dec 27, 2024
commit e430bc6dd11702dadf4e9fc31412850614504570
4 changes: 4 additions & 0 deletions pkg/datasource/sql/conn.go
Original file line number Diff line number Diff line change
@@ -244,6 +244,10 @@ func (c *Conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, e
)
}

func (c *Conn) GetDbVersion() string {
return c.res.GetDbVersion()
Copy link
Contributor

@luky116 luky116 Feb 15, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个有需要支持吗

}

func (c *Conn) GetAutoCommit() bool {
return c.autoCommit
}
2 changes: 2 additions & 0 deletions pkg/datasource/sql/conn_at.go
Original file line number Diff line number Diff line change
@@ -63,6 +63,7 @@ func (c *ATConn) QueryContext(ctx context.Context, query string, args []driver.N
NamedValues: args,
Conn: c.targetConn,
DBName: c.dbName,
DbVersion: c.GetDbVersion(),
IsSupportsSavepoints: true,
IsAutoCommit: c.GetAutoCommit(),
}
@@ -102,6 +103,7 @@ func (c *ATConn) ExecContext(ctx context.Context, query string, args []driver.Na
NamedValues: args,
Conn: c.targetConn,
DBName: c.dbName,
DbVersion: c.GetDbVersion(),
IsSupportsSavepoints: true,
IsAutoCommit: c.GetAutoCommit(),
}
21 changes: 21 additions & 0 deletions pkg/datasource/sql/exec/at/base_executor.go
Original file line number Diff line number Diff line change
@@ -22,6 +22,7 @@ import (
"context"
"database/sql"
"database/sql/driver"
"errors"
"fmt"
"strings"

@@ -350,3 +351,23 @@ func (b *baseExecutor) buildLockKey(records *types.RecordImage, meta types.Table

return lockKeys.String()
}

func (u *updateExecutor) rowsPrepare(ctx context.Context, selectSQL string, selectArgs []driver.NamedValue) (driver.Rows, error) {
var queryer driver.Queryer

queryerContext, ok := u.execContext.Conn.(driver.QueryerContext)
if !ok {
queryer, ok = u.execContext.Conn.(driver.Queryer)
}
if ok {
var err error
rows, err = util.CtxDriverQuery(ctx, queryerContext, queryer, selectSQL, selectArgs)

if err != nil {
return nil, err
}
} else {
return nil, errors.New("target conn should been driver.QueryerContext or driver.Queryer")
}
return rows, nil
}
139 changes: 73 additions & 66 deletions pkg/datasource/sql/exec/at/update_executor.go
Original file line number Diff line number Diff line change
@@ -20,6 +20,7 @@ package at
import (
"context"
"database/sql/driver"
"errors"
"fmt"
"strings"

@@ -31,7 +32,6 @@ import (
"seata.apache.org/seata-go/pkg/datasource/sql/exec"
"seata.apache.org/seata-go/pkg/datasource/sql/types"
"seata.apache.org/seata-go/pkg/datasource/sql/undo"
"seata.apache.org/seata-go/pkg/datasource/sql/util"
"seata.apache.org/seata-go/pkg/util/bytes"
"seata.apache.org/seata-go/pkg/util/log"
)
@@ -90,37 +90,31 @@ func (u *updateExecutor) beforeImage(ctx context.Context) (*types.RecordImage, e
return nil, nil
}

selectSQL, selectArgs, err := u.buildBeforeImageSQL(ctx, u.execContext.NamedValues)
tableName, _ := u.parserCtx.GetTableName()
metaData, err := datasource.GetTableCache(types.DBTypeMySQL).GetTableMeta(ctx, u.execContext.DBName, tableName)
if err != nil {
return nil, err
}

tableName, _ := u.parserCtx.GetTableName()
metaData, err := datasource.GetTableCache(types.DBTypeMySQL).GetTableMeta(ctx, u.execContext.DBName, tableName)
selectSQL, selectArgs, err := u.buildBeforeImageSQL(ctx, metaData, u.execContext.NamedValues)
if err != nil {
return nil, err
}

var rowsi driver.Rows
queryerCtx, ok := u.execContext.Conn.(driver.QueryerContext)
var queryer driver.Queryer
if !ok {
queryer, ok = u.execContext.Conn.(driver.Queryer)
if selectSQL == "" {
return nil, errors.New("build select sql by update sourceQuery fail")
}
if ok {
rowsi, err = util.CtxDriverQuery(ctx, queryerCtx, queryer, selectSQL, selectArgs)
defer func() {
if rowsi != nil {
rowsi.Close()

rowsi, err := u.rowsPrepare(ctx, selectSQL, selectArgs)
defer func() {
if rowsi != nil {
if err := rowsi.Close(); err != nil {
log.Errorf("rows close fail, err:%v", err)
return
}
}()
if err != nil {
log.Errorf("ctx driver query: %+v", err)
return nil, err
}
} else {
log.Errorf("target conn should been driver.QueryerContext or driver.Queryer")
return nil, fmt.Errorf("invalid conn")
}()
if err != nil {
return nil, err
}

image, err := u.buildRecordImages(rowsi, metaData, types.SQLTypeUpdate)
@@ -151,26 +145,17 @@ func (u *updateExecutor) afterImage(ctx context.Context, beforeImage types.Recor
}
selectSQL, selectArgs := u.buildAfterImageSQL(beforeImage, metaData)

var rowsi driver.Rows
queryerCtx, ok := u.execContext.Conn.(driver.QueryerContext)
var queryer driver.Queryer
if !ok {
queryer, ok = u.execContext.Conn.(driver.Queryer)
}
if ok {
rowsi, err = util.CtxDriverQuery(ctx, queryerCtx, queryer, selectSQL, selectArgs)
defer func() {
if rowsi != nil {
rowsi.Close()
rowsi, err := u.rowsPrepare(ctx, selectSQL, selectArgs)
defer func() {
if rowsi != nil {
if err := rowsi.Close(); err != nil {
log.Errorf("rows close fail, err:%v", err)
return
}
}()
if err != nil {
log.Errorf("ctx driver query: %+v", err)
return nil, err
}
} else {
log.Errorf("target conn should been driver.QueryerContext or driver.Queryer")
return nil, fmt.Errorf("invalid conn")
}()
if err != nil {
return nil, err
}

afterImage, err := u.buildRecordImages(rowsi, metaData, types.SQLTypeUpdate)
@@ -212,34 +197,74 @@ func (u *updateExecutor) buildAfterImageSQL(beforeImage types.RecordImage, meta
}

// buildAfterImageSQL build the SQL to query before image data
func (u *updateExecutor) buildBeforeImageSQL(ctx context.Context, args []driver.NamedValue) (string, []driver.NamedValue, error) {
func (u *updateExecutor) buildBeforeImageSQL(ctx context.Context, tableMeta *types.TableMeta, args []driver.NamedValue) (string, []driver.NamedValue, error) {
if !u.isAstStmtValid() {
log.Errorf("invalid update stmt")
return "", nil, fmt.Errorf("invalid update stmt")
}

updateStmt := u.parserCtx.UpdateStmt
fields, err := u.buildSelectFields(ctx, tableMeta)
if err != nil {
return "", nil, err
}
if len(fields) == 0 {
return "", nil, err
}

selStmt := ast.SelectStmt{
SelectStmtOpts: &ast.SelectStmtOpts{},
From: updateStmt.TableRefs,
Where: updateStmt.Where,
Fields: &ast.FieldList{Fields: fields},
OrderBy: updateStmt.Order,
Limit: updateStmt.Limit,
TableHints: updateStmt.TableHints,
LockInfo: &ast.SelectLockInfo{
LockType: ast.SelectLockForUpdate,
},
}

b := bytes.NewByteBuffer([]byte{})
_ = selStmt.Restore(format.NewRestoreCtx(format.RestoreKeyWordUppercase, b))
sql := string(b.Bytes())
log.Infof("build select sql by update sourceQuery, sql {%s}", sql)

return sql, u.buildSelectArgs(&selStmt, args), nil
}

func (u *updateExecutor) buildSelectFields(ctx context.Context, tableMeta *types.TableMeta) ([]*ast.SelectField, error) {
updateStmt := u.parserCtx.UpdateStmt
fields := make([]*ast.SelectField, 0, len(updateStmt.List))

lowerTableName := strings.ToLower(tableMeta.TableName)
if undo.UndoConfig.OnlyCareUpdateColumns {
for _, column := range updateStmt.List {
tableName := column.Column.Table.L
if tableName != "" && lowerTableName != tableName {
continue
}

fields = append(fields, &ast.SelectField{
Expr: &ast.ColumnNameExpr{
Name: column.Column,
},
})
}

// select indexes columns
tableName, _ := u.parserCtx.GetTableName()
metaData, err := datasource.GetTableCache(types.DBTypeMySQL).GetTableMeta(ctx, u.execContext.DBName, tableName)
if err != nil {
return "", nil, err
if len(fields) == 0 {
return fields, nil
}
for _, columnName := range metaData.GetPrimaryKeyOnlyName() {

// select indexes columns
for _, columnName := range tableMeta.GetPrimaryKeyOnlyName() {
fields = append(fields, &ast.SelectField{
Expr: &ast.ColumnNameExpr{
Name: &ast.ColumnName{
Table: model.CIStr{
O: tableMeta.TableName,
L: lowerTableName,
},
Name: model.CIStr{
O: columnName,
L: columnName,
@@ -261,23 +286,5 @@ func (u *updateExecutor) buildBeforeImageSQL(ctx context.Context, args []driver.
})
}

selStmt := ast.SelectStmt{
SelectStmtOpts: &ast.SelectStmtOpts{},
From: updateStmt.TableRefs,
Where: updateStmt.Where,
Fields: &ast.FieldList{Fields: fields},
OrderBy: updateStmt.Order,
Limit: updateStmt.Limit,
TableHints: updateStmt.TableHints,
LockInfo: &ast.SelectLockInfo{
LockType: ast.SelectLockForUpdate,
},
}

b := bytes.NewByteBuffer([]byte{})
_ = selStmt.Restore(format.NewRestoreCtx(format.RestoreKeyWordUppercase, b))
sql := string(b.Bytes())
log.Infof("build select sql by update sourceQuery, sql {%s}", sql)

return sql, u.buildSelectArgs(&selStmt, args), nil
return fields, nil
}
Loading