Skip to content

Commit abdd938

Browse files
committed
bugfix apache#704
1 parent ad092d5 commit abdd938

File tree

2 files changed

+177
-65
lines changed

2 files changed

+177
-65
lines changed

pkg/datasource/sql/types/image.go

+7-4
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
package types
1919

2020
import (
21+
"database/sql/driver"
2122
"encoding/base64"
2223
"encoding/json"
2324
"reflect"
@@ -117,14 +118,16 @@ type RecordImage struct {
117118
// Rows data row
118119
Rows []RowImage `json:"rows"`
119120
// TableMeta table information schema
120-
TableMeta *TableMeta `json:"-"`
121+
TableMeta *TableMeta `json:"-"`
122+
PrimaryKeyMap map[string][]driver.Value `json:"primaryKeyMap,omitempty"`
121123
}
122124

123125
func NewEmptyRecordImage(tableMeta *TableMeta, sqlType SQLType) *RecordImage {
124126
return &RecordImage{
125-
TableName: tableMeta.TableName,
126-
TableMeta: tableMeta,
127-
SQLType: sqlType,
127+
TableName: tableMeta.TableName,
128+
TableMeta: tableMeta,
129+
SQLType: sqlType,
130+
PrimaryKeyMap: make(map[string][]driver.Value),
128131
}
129132
}
130133

pkg/datasource/sql/undo/builder/mysql_insertonduplicate_update_undo_log_builder.go

+170-61
Original file line numberDiff line numberDiff line change
@@ -97,68 +97,136 @@ func (u *MySQLInsertOnDuplicateUndoLogBuilder) buildBeforeImageSQL(insertStmt *a
9797
if err := checkDuplicateKeyUpdate(insertStmt, metaData); err != nil {
9898
return "", nil, err
9999
}
100-
var selectArgs []driver.Value
100+
101+
// Reset primary keys map
102+
u.BeforeImageSqlPrimaryKeys = make(map[string]bool)
103+
101104
pkIndexMap := u.getPkIndex(insertStmt, metaData)
102105
var pkIndexArray []int
103106
for _, val := range pkIndexMap {
104-
tmpVal := val
105-
pkIndexArray = append(pkIndexArray, tmpVal)
107+
pkIndexArray = append(pkIndexArray, val)
106108
}
107109
insertRows, err := getInsertRows(insertStmt, pkIndexArray)
108110
if err != nil {
109111
return "", nil, err
110112
}
111-
insertNum := len(insertRows)
113+
112114
paramMap, err := u.buildImageParameters(insertStmt, args, insertRows)
113115
if err != nil {
114116
return "", nil, err
115117
}
116118

117-
sql := strings.Builder{}
118-
sql.WriteString("SELECT * FROM " + metaData.TableName + " ")
119+
// 如果没有参数或没有主键索引,直接返回空
120+
if len(paramMap) == 0 || len(metaData.Indexs) == 0 {
121+
return "", nil, nil
122+
}
123+
124+
// 检查是否有主键
125+
hasPK := false
126+
for _, index := range metaData.Indexs {
127+
if strings.EqualFold("PRIMARY", index.Name) {
128+
hasPK = true
129+
break
130+
}
131+
}
132+
if !hasPK {
133+
return "", nil, nil
134+
}
135+
136+
var sql strings.Builder
137+
sql.WriteString("SELECT * FROM " + metaData.TableName + " ")
138+
139+
var selectArgs []driver.Value
119140
isContainWhere := false
120-
for i := 0; i < insertNum; i++ {
121-
finalI := i
122-
paramAppenderTempList := make([]driver.Value, 0)
141+
hasConditions := false
142+
143+
for i := 0; i < len(insertRows); i++ {
144+
var rowConditions []string
145+
var rowArgs []driver.Value
146+
usedParams := make(map[string]bool)
147+
148+
// First try unique indexes
123149
for _, index := range metaData.Indexs {
124-
//unique index
125-
if index.NonUnique || isIndexValueNotNull(index, paramMap, finalI) == false {
150+
if index.NonUnique || strings.EqualFold("PRIMARY", index.Name) {
126151
continue
127152
}
128-
columnIsNull := true
129-
uniqueList := make([]string, 0)
130-
for _, columnMeta := range index.Columns {
131-
columnName := columnMeta.ColumnName
132-
imageParameters, ok := paramMap[columnName]
133-
if !ok && columnMeta.ColumnDef != nil {
134-
if strings.EqualFold("PRIMARY", index.Name) {
135-
u.BeforeImageSqlPrimaryKeys[columnName] = true
136-
}
137-
uniqueList = append(uniqueList, columnName+" = DEFAULT("+columnName+") ")
138-
columnIsNull = false
139-
continue
153+
154+
if !isIndexValueNotNull(index, paramMap, i) {
155+
continue
156+
}
157+
158+
var indexConditions []string
159+
var indexArgs []driver.Value
160+
allColumnsPresent := true
161+
for _, colMeta := range index.Columns {
162+
columnName := colMeta.ColumnName
163+
if params, ok := paramMap[columnName]; ok && len(params) > i && params[i] != nil {
164+
indexConditions = append(indexConditions, columnName+" = ? ")
165+
indexArgs = append(indexArgs, params[i])
166+
usedParams[columnName] = true
167+
} else if colMeta.ColumnDef != nil {
168+
indexConditions = append(indexConditions, columnName+" = DEFAULT("+columnName+")")
169+
} else {
170+
allColumnsPresent = false
171+
break
140172
}
141-
if strings.EqualFold("PRIMARY", index.Name) {
142-
u.BeforeImageSqlPrimaryKeys[columnName] = true
173+
}
174+
175+
if allColumnsPresent && len(indexConditions) > 0 {
176+
rowConditions = append(rowConditions, "("+strings.Join(indexConditions, " and ")+")")
177+
rowArgs = append(rowArgs, indexArgs...)
178+
hasConditions = true
179+
}
180+
}
181+
182+
// Then try primary key
183+
for _, index := range metaData.Indexs {
184+
if !strings.EqualFold("PRIMARY", index.Name) {
185+
continue
186+
}
187+
188+
var pkConditions []string
189+
var pkArgs []driver.Value
190+
for _, colMeta := range index.Columns {
191+
columnName := colMeta.ColumnName
192+
u.BeforeImageSqlPrimaryKeys[columnName] = true
193+
if params, ok := paramMap[columnName]; ok && len(params) > i && params[i] != nil && !usedParams[columnName] {
194+
pkConditions = append(pkConditions, columnName+" = ? ")
195+
pkArgs = append(pkArgs, params[i])
143196
}
144-
columnIsNull = false
145-
uniqueList = append(uniqueList, columnName+" = ? ")
146-
paramAppenderTempList = append(paramAppenderTempList, imageParameters[finalI])
147197
}
148198

149-
if !columnIsNull {
150-
if isContainWhere {
151-
sql.WriteString(" OR (" + strings.Join(uniqueList, " and ") + ") ")
152-
} else {
153-
sql.WriteString(" WHERE (" + strings.Join(uniqueList, " and ") + ") ")
154-
isContainWhere = true
199+
if len(pkConditions) > 0 {
200+
rowConditions = append(rowConditions, "("+strings.Join(pkConditions, " and ")+")")
201+
rowArgs = append(rowArgs, pkArgs...)
202+
hasConditions = true
203+
}
204+
}
205+
206+
if len(rowConditions) > 0 {
207+
if !isContainWhere {
208+
sql.WriteString("WHERE ")
209+
isContainWhere = true
210+
} else {
211+
sql.WriteString(" OR ")
212+
}
213+
for j, condition := range rowConditions {
214+
if j > 0 {
215+
sql.WriteString(" OR ")
155216
}
217+
sql.WriteString(condition + " ")
156218
}
219+
selectArgs = append(selectArgs, rowArgs...)
157220
}
158-
selectArgs = append(selectArgs, paramAppenderTempList...)
159221
}
160-
log.Infof("build select sql by insert on duplicate sourceQuery, sql {}", sql.String())
161-
return sql.String(), selectArgs, nil
222+
223+
if !hasConditions {
224+
return "", nil, nil
225+
}
226+
227+
sqlStr := sql.String()
228+
log.Infof("build select sql by insert on duplicate sourceQuery, sql: %s", sqlStr)
229+
return sqlStr, selectArgs, nil
162230
}
163231

164232
func (u *MySQLInsertOnDuplicateUndoLogBuilder) AfterImage(ctx context.Context, execCtx *types.ExecContext, beforeImages []*types.RecordImage) ([]*types.RecordImage, error) {
@@ -168,18 +236,22 @@ func (u *MySQLInsertOnDuplicateUndoLogBuilder) AfterImage(ctx context.Context, e
168236
log.Errorf("build prepare stmt: %+v", err)
169237
return nil, err
170238
}
239+
defer stmt.Close()
240+
241+
tableName := execCtx.ParseContext.InsertStmt.Table.TableRefs.Left.(*ast.TableSource).Source.(*ast.TableName).Name.O
242+
metaData := execCtx.MetaDataMap[tableName]
171243

172244
rows, err := stmt.Query(selectArgs)
173245
if err != nil {
174-
log.Errorf("stmt query: %+v", err)
175246
return nil, err
176247
}
177-
tableName := execCtx.ParseContext.InsertStmt.Table.TableRefs.Left.(*ast.TableSource).Source.(*ast.TableName).Name.O
178-
metaData := execCtx.MetaDataMap[tableName]
248+
defer rows.Close()
249+
179250
image, err := u.buildRecordImages(rows, &metaData)
180251
if err != nil {
181252
return nil, err
182253
}
254+
183255
return []*types.RecordImage{image}, nil
184256
}
185257

@@ -190,6 +262,13 @@ func (u *MySQLInsertOnDuplicateUndoLogBuilder) buildAfterImageSQL(ctx context.Co
190262
if len(beforeImages) > 0 {
191263
beforeImage = beforeImages[0]
192264
}
265+
266+
// 如果没有before image,直接返回原始SQL和参数
267+
if beforeImage == nil || len(beforeImage.Rows) == 0 {
268+
return selectSQL, selectArgs
269+
}
270+
271+
// 收集主键值
193272
primaryValueMap := make(map[string][]interface{})
194273
for _, row := range beforeImage.Rows {
195274
for _, col := range row.Columns {
@@ -200,23 +279,53 @@ func (u *MySQLInsertOnDuplicateUndoLogBuilder) buildAfterImageSQL(ctx context.Co
200279
}
201280

202281
var afterImageSql strings.Builder
203-
var primaryValues []driver.Value
204282
afterImageSql.WriteString(selectSQL)
205-
for i := 0; i < len(beforeImage.Rows); i++ {
206-
wherePrimaryList := make([]string, 0)
207-
for name, value := range primaryValueMap {
208-
if !u.BeforeImageSqlPrimaryKeys[name] {
209-
wherePrimaryList = append(wherePrimaryList, name+" = ? ")
210-
primaryValues = append(primaryValues, value[i])
283+
284+
// 如果原始SQL已经包含了所有需要的条件,直接返回
285+
if len(primaryValueMap) == 0 || len(selectArgs) == len(beforeImage.Rows)*len(primaryValueMap) {
286+
return selectSQL, selectArgs
287+
}
288+
289+
// 添加主键条件
290+
var primaryValues []driver.Value
291+
usedPrimaryKeys := make(map[string]bool)
292+
293+
for name := range primaryValueMap {
294+
if !u.BeforeImageSqlPrimaryKeys[name] {
295+
usedPrimaryKeys[name] = true
296+
for i := 0; i < len(beforeImage.Rows); i++ {
297+
if value := primaryValueMap[name][i]; value != nil {
298+
if dv, ok := value.(driver.Value); ok {
299+
primaryValues = append(primaryValues, dv)
300+
} else {
301+
primaryValues = append(primaryValues, value)
302+
}
303+
}
211304
}
212305
}
213-
if len(wherePrimaryList) != 0 {
214-
afterImageSql.WriteString(" OR (" + strings.Join(wherePrimaryList, " and ") + ") ")
306+
}
307+
308+
if len(primaryValues) > 0 {
309+
afterImageSql.WriteString(" OR (" + strings.Join(u.buildPrimaryKeyConditions(primaryValueMap, usedPrimaryKeys), " and ") + ") ")
310+
}
311+
312+
finalArgs := make([]driver.Value, len(selectArgs)+len(primaryValues))
313+
copy(finalArgs, selectArgs)
314+
copy(finalArgs[len(selectArgs):], primaryValues)
315+
316+
sqlStr := afterImageSql.String()
317+
log.Infof("build after select sql by insert on duplicate sourceQuery, sql %s", sqlStr)
318+
return sqlStr, finalArgs
319+
}
320+
321+
func (u *MySQLInsertOnDuplicateUndoLogBuilder) buildPrimaryKeyConditions(primaryValueMap map[string][]interface{}, usedPrimaryKeys map[string]bool) []string {
322+
var conditions []string
323+
for name := range primaryValueMap {
324+
if !usedPrimaryKeys[name] {
325+
conditions = append(conditions, name+" = ? ")
215326
}
216327
}
217-
selectArgs = append(selectArgs, primaryValues...)
218-
log.Infof("build after select sql by insert on duplicate sourceQuery, sql {}", afterImageSql.String())
219-
return afterImageSql.String(), selectArgs
328+
return conditions
220329
}
221330

222331
func checkDuplicateKeyUpdate(insert *ast.InsertStmt, metaData types.TableMeta) error {
@@ -243,11 +352,10 @@ func checkDuplicateKeyUpdate(insert *ast.InsertStmt, metaData types.TableMeta) e
243352

244353
// build sql params
245354
func (u *MySQLInsertOnDuplicateUndoLogBuilder) buildImageParameters(insert *ast.InsertStmt, args []driver.Value, insertRows [][]interface{}) (map[string][]driver.Value, error) {
246-
var (
247-
parameterMap = make(map[string][]driver.Value)
248-
)
355+
parameterMap := make(map[string][]driver.Value)
249356
insertColumns := getInsertColumns(insert)
250-
var placeHolderIndex = 0
357+
placeHolderIndex := 0
358+
251359
for _, row := range insertRows {
252360
if len(row) != len(insertColumns) {
253361
log.Errorf("insert row's column size not equal to insert column size")
@@ -256,13 +364,14 @@ func (u *MySQLInsertOnDuplicateUndoLogBuilder) buildImageParameters(insert *ast.
256364
for i, col := range insertColumns {
257365
columnName := executor.DelEscape(col, types.DBTypeMySQL)
258366
val := row[i]
259-
rStr, ok := val.(string)
260-
if ok && strings.EqualFold(rStr, SqlPlaceholder) {
261-
objects := args[placeHolderIndex]
262-
parameterMap[columnName] = append(parameterMap[col], objects)
367+
if str, ok := val.(string); ok && strings.EqualFold(str, SqlPlaceholder) {
368+
if placeHolderIndex >= len(args) {
369+
return nil, fmt.Errorf("not enough parameters for placeholders")
370+
}
371+
parameterMap[columnName] = append(parameterMap[columnName], args[placeHolderIndex])
263372
placeHolderIndex++
264373
} else {
265-
parameterMap[columnName] = append(parameterMap[col], val)
374+
parameterMap[columnName] = append(parameterMap[columnName], val)
266375
}
267376
}
268377
}

0 commit comments

Comments
 (0)