@@ -97,68 +97,136 @@ func (u *MySQLInsertOnDuplicateUndoLogBuilder) buildBeforeImageSQL(insertStmt *a
97
97
if err := checkDuplicateKeyUpdate (insertStmt , metaData ); err != nil {
98
98
return "" , nil , err
99
99
}
100
- var selectArgs []driver.Value
100
+
101
+ // Reset primary keys map
102
+ u .BeforeImageSqlPrimaryKeys = make (map [string ]bool )
103
+
101
104
pkIndexMap := u .getPkIndex (insertStmt , metaData )
102
105
var pkIndexArray []int
103
106
for _ , val := range pkIndexMap {
104
- tmpVal := val
105
- pkIndexArray = append (pkIndexArray , tmpVal )
107
+ pkIndexArray = append (pkIndexArray , val )
106
108
}
107
109
insertRows , err := getInsertRows (insertStmt , pkIndexArray )
108
110
if err != nil {
109
111
return "" , nil , err
110
112
}
111
- insertNum := len ( insertRows )
113
+
112
114
paramMap , err := u .buildImageParameters (insertStmt , args , insertRows )
113
115
if err != nil {
114
116
return "" , nil , err
115
117
}
116
118
117
- sql := strings.Builder {}
118
- sql .WriteString ("SELECT * FROM " + metaData .TableName + " " )
119
+ // 如果没有参数或没有主键索引,直接返回空
120
+ if len (paramMap ) == 0 || len (metaData .Indexs ) == 0 {
121
+ return "" , nil , nil
122
+ }
123
+
124
+ // 检查是否有主键
125
+ hasPK := false
126
+ for _ , index := range metaData .Indexs {
127
+ if strings .EqualFold ("PRIMARY" , index .Name ) {
128
+ hasPK = true
129
+ break
130
+ }
131
+ }
132
+ if ! hasPK {
133
+ return "" , nil , nil
134
+ }
135
+
136
+ var sql strings.Builder
137
+ sql .WriteString ("SELECT * FROM " + metaData .TableName + " " )
138
+
139
+ var selectArgs []driver.Value
119
140
isContainWhere := false
120
- for i := 0 ; i < insertNum ; i ++ {
121
- finalI := i
122
- paramAppenderTempList := make ([]driver.Value , 0 )
141
+ hasConditions := false
142
+
143
+ for i := 0 ; i < len (insertRows ); i ++ {
144
+ var rowConditions []string
145
+ var rowArgs []driver.Value
146
+ usedParams := make (map [string ]bool )
147
+
148
+ // First try unique indexes
123
149
for _ , index := range metaData .Indexs {
124
- //unique index
125
- if index .NonUnique || isIndexValueNotNull (index , paramMap , finalI ) == false {
150
+ if index .NonUnique || strings .EqualFold ("PRIMARY" , index .Name ) {
126
151
continue
127
152
}
128
- columnIsNull := true
129
- uniqueList := make ([]string , 0 )
130
- for _ , columnMeta := range index .Columns {
131
- columnName := columnMeta .ColumnName
132
- imageParameters , ok := paramMap [columnName ]
133
- if ! ok && columnMeta .ColumnDef != nil {
134
- if strings .EqualFold ("PRIMARY" , index .Name ) {
135
- u .BeforeImageSqlPrimaryKeys [columnName ] = true
136
- }
137
- uniqueList = append (uniqueList , columnName + " = DEFAULT(" + columnName + ") " )
138
- columnIsNull = false
139
- continue
153
+
154
+ if ! isIndexValueNotNull (index , paramMap , i ) {
155
+ continue
156
+ }
157
+
158
+ var indexConditions []string
159
+ var indexArgs []driver.Value
160
+ allColumnsPresent := true
161
+ for _ , colMeta := range index .Columns {
162
+ columnName := colMeta .ColumnName
163
+ if params , ok := paramMap [columnName ]; ok && len (params ) > i && params [i ] != nil {
164
+ indexConditions = append (indexConditions , columnName + " = ? " )
165
+ indexArgs = append (indexArgs , params [i ])
166
+ usedParams [columnName ] = true
167
+ } else if colMeta .ColumnDef != nil {
168
+ indexConditions = append (indexConditions , columnName + " = DEFAULT(" + columnName + ")" )
169
+ } else {
170
+ allColumnsPresent = false
171
+ break
140
172
}
141
- if strings .EqualFold ("PRIMARY" , index .Name ) {
142
- u .BeforeImageSqlPrimaryKeys [columnName ] = true
173
+ }
174
+
175
+ if allColumnsPresent && len (indexConditions ) > 0 {
176
+ rowConditions = append (rowConditions , "(" + strings .Join (indexConditions , " and " )+ ")" )
177
+ rowArgs = append (rowArgs , indexArgs ... )
178
+ hasConditions = true
179
+ }
180
+ }
181
+
182
+ // Then try primary key
183
+ for _ , index := range metaData .Indexs {
184
+ if ! strings .EqualFold ("PRIMARY" , index .Name ) {
185
+ continue
186
+ }
187
+
188
+ var pkConditions []string
189
+ var pkArgs []driver.Value
190
+ for _ , colMeta := range index .Columns {
191
+ columnName := colMeta .ColumnName
192
+ u .BeforeImageSqlPrimaryKeys [columnName ] = true
193
+ if params , ok := paramMap [columnName ]; ok && len (params ) > i && params [i ] != nil && ! usedParams [columnName ] {
194
+ pkConditions = append (pkConditions , columnName + " = ? " )
195
+ pkArgs = append (pkArgs , params [i ])
143
196
}
144
- columnIsNull = false
145
- uniqueList = append (uniqueList , columnName + " = ? " )
146
- paramAppenderTempList = append (paramAppenderTempList , imageParameters [finalI ])
147
197
}
148
198
149
- if ! columnIsNull {
150
- if isContainWhere {
151
- sql .WriteString (" OR (" + strings .Join (uniqueList , " and " ) + ") " )
152
- } else {
153
- sql .WriteString (" WHERE (" + strings .Join (uniqueList , " and " ) + ") " )
154
- isContainWhere = true
199
+ if len (pkConditions ) > 0 {
200
+ rowConditions = append (rowConditions , "(" + strings .Join (pkConditions , " and " )+ ")" )
201
+ rowArgs = append (rowArgs , pkArgs ... )
202
+ hasConditions = true
203
+ }
204
+ }
205
+
206
+ if len (rowConditions ) > 0 {
207
+ if ! isContainWhere {
208
+ sql .WriteString ("WHERE " )
209
+ isContainWhere = true
210
+ } else {
211
+ sql .WriteString (" OR " )
212
+ }
213
+ for j , condition := range rowConditions {
214
+ if j > 0 {
215
+ sql .WriteString (" OR " )
155
216
}
217
+ sql .WriteString (condition + " " )
156
218
}
219
+ selectArgs = append (selectArgs , rowArgs ... )
157
220
}
158
- selectArgs = append (selectArgs , paramAppenderTempList ... )
159
221
}
160
- log .Infof ("build select sql by insert on duplicate sourceQuery, sql {}" , sql .String ())
161
- return sql .String (), selectArgs , nil
222
+
223
+ if ! hasConditions {
224
+ return "" , nil , nil
225
+ }
226
+
227
+ sqlStr := sql .String ()
228
+ log .Infof ("build select sql by insert on duplicate sourceQuery, sql: %s" , sqlStr )
229
+ return sqlStr , selectArgs , nil
162
230
}
163
231
164
232
func (u * MySQLInsertOnDuplicateUndoLogBuilder ) AfterImage (ctx context.Context , execCtx * types.ExecContext , beforeImages []* types.RecordImage ) ([]* types.RecordImage , error ) {
@@ -168,18 +236,22 @@ func (u *MySQLInsertOnDuplicateUndoLogBuilder) AfterImage(ctx context.Context, e
168
236
log .Errorf ("build prepare stmt: %+v" , err )
169
237
return nil , err
170
238
}
239
+ defer stmt .Close ()
240
+
241
+ tableName := execCtx .ParseContext .InsertStmt .Table .TableRefs .Left .(* ast.TableSource ).Source .(* ast.TableName ).Name .O
242
+ metaData := execCtx .MetaDataMap [tableName ]
171
243
172
244
rows , err := stmt .Query (selectArgs )
173
245
if err != nil {
174
- log .Errorf ("stmt query: %+v" , err )
175
246
return nil , err
176
247
}
177
- tableName := execCtx . ParseContext . InsertStmt . Table . TableRefs . Left .( * ast. TableSource ). Source .( * ast. TableName ). Name . O
178
- metaData := execCtx . MetaDataMap [ tableName ]
248
+ defer rows . Close ()
249
+
179
250
image , err := u .buildRecordImages (rows , & metaData )
180
251
if err != nil {
181
252
return nil , err
182
253
}
254
+
183
255
return []* types.RecordImage {image }, nil
184
256
}
185
257
@@ -190,6 +262,13 @@ func (u *MySQLInsertOnDuplicateUndoLogBuilder) buildAfterImageSQL(ctx context.Co
190
262
if len (beforeImages ) > 0 {
191
263
beforeImage = beforeImages [0 ]
192
264
}
265
+
266
+ // 如果没有before image,直接返回原始SQL和参数
267
+ if beforeImage == nil || len (beforeImage .Rows ) == 0 {
268
+ return selectSQL , selectArgs
269
+ }
270
+
271
+ // 收集主键值
193
272
primaryValueMap := make (map [string ][]interface {})
194
273
for _ , row := range beforeImage .Rows {
195
274
for _ , col := range row .Columns {
@@ -200,23 +279,53 @@ func (u *MySQLInsertOnDuplicateUndoLogBuilder) buildAfterImageSQL(ctx context.Co
200
279
}
201
280
202
281
var afterImageSql strings.Builder
203
- var primaryValues []driver.Value
204
282
afterImageSql .WriteString (selectSQL )
205
- for i := 0 ; i < len (beforeImage .Rows ); i ++ {
206
- wherePrimaryList := make ([]string , 0 )
207
- for name , value := range primaryValueMap {
208
- if ! u .BeforeImageSqlPrimaryKeys [name ] {
209
- wherePrimaryList = append (wherePrimaryList , name + " = ? " )
210
- primaryValues = append (primaryValues , value [i ])
283
+
284
+ // 如果原始SQL已经包含了所有需要的条件,直接返回
285
+ if len (primaryValueMap ) == 0 || len (selectArgs ) == len (beforeImage .Rows )* len (primaryValueMap ) {
286
+ return selectSQL , selectArgs
287
+ }
288
+
289
+ // 添加主键条件
290
+ var primaryValues []driver.Value
291
+ usedPrimaryKeys := make (map [string ]bool )
292
+
293
+ for name := range primaryValueMap {
294
+ if ! u .BeforeImageSqlPrimaryKeys [name ] {
295
+ usedPrimaryKeys [name ] = true
296
+ for i := 0 ; i < len (beforeImage .Rows ); i ++ {
297
+ if value := primaryValueMap [name ][i ]; value != nil {
298
+ if dv , ok := value .(driver.Value ); ok {
299
+ primaryValues = append (primaryValues , dv )
300
+ } else {
301
+ primaryValues = append (primaryValues , value )
302
+ }
303
+ }
211
304
}
212
305
}
213
- if len (wherePrimaryList ) != 0 {
214
- afterImageSql .WriteString (" OR (" + strings .Join (wherePrimaryList , " and " ) + ") " )
306
+ }
307
+
308
+ if len (primaryValues ) > 0 {
309
+ afterImageSql .WriteString (" OR (" + strings .Join (u .buildPrimaryKeyConditions (primaryValueMap , usedPrimaryKeys ), " and " ) + ") " )
310
+ }
311
+
312
+ finalArgs := make ([]driver.Value , len (selectArgs )+ len (primaryValues ))
313
+ copy (finalArgs , selectArgs )
314
+ copy (finalArgs [len (selectArgs ):], primaryValues )
315
+
316
+ sqlStr := afterImageSql .String ()
317
+ log .Infof ("build after select sql by insert on duplicate sourceQuery, sql %s" , sqlStr )
318
+ return sqlStr , finalArgs
319
+ }
320
+
321
+ func (u * MySQLInsertOnDuplicateUndoLogBuilder ) buildPrimaryKeyConditions (primaryValueMap map [string ][]interface {}, usedPrimaryKeys map [string ]bool ) []string {
322
+ var conditions []string
323
+ for name := range primaryValueMap {
324
+ if ! usedPrimaryKeys [name ] {
325
+ conditions = append (conditions , name + " = ? " )
215
326
}
216
327
}
217
- selectArgs = append (selectArgs , primaryValues ... )
218
- log .Infof ("build after select sql by insert on duplicate sourceQuery, sql {}" , afterImageSql .String ())
219
- return afterImageSql .String (), selectArgs
328
+ return conditions
220
329
}
221
330
222
331
func checkDuplicateKeyUpdate (insert * ast.InsertStmt , metaData types.TableMeta ) error {
@@ -243,11 +352,10 @@ func checkDuplicateKeyUpdate(insert *ast.InsertStmt, metaData types.TableMeta) e
243
352
244
353
// build sql params
245
354
func (u * MySQLInsertOnDuplicateUndoLogBuilder ) buildImageParameters (insert * ast.InsertStmt , args []driver.Value , insertRows [][]interface {}) (map [string ][]driver.Value , error ) {
246
- var (
247
- parameterMap = make (map [string ][]driver.Value )
248
- )
355
+ parameterMap := make (map [string ][]driver.Value )
249
356
insertColumns := getInsertColumns (insert )
250
- var placeHolderIndex = 0
357
+ placeHolderIndex := 0
358
+
251
359
for _ , row := range insertRows {
252
360
if len (row ) != len (insertColumns ) {
253
361
log .Errorf ("insert row's column size not equal to insert column size" )
@@ -256,13 +364,14 @@ func (u *MySQLInsertOnDuplicateUndoLogBuilder) buildImageParameters(insert *ast.
256
364
for i , col := range insertColumns {
257
365
columnName := executor .DelEscape (col , types .DBTypeMySQL )
258
366
val := row [i ]
259
- rStr , ok := val .(string )
260
- if ok && strings .EqualFold (rStr , SqlPlaceholder ) {
261
- objects := args [placeHolderIndex ]
262
- parameterMap [columnName ] = append (parameterMap [col ], objects )
367
+ if str , ok := val .(string ); ok && strings .EqualFold (str , SqlPlaceholder ) {
368
+ if placeHolderIndex >= len (args ) {
369
+ return nil , fmt .Errorf ("not enough parameters for placeholders" )
370
+ }
371
+ parameterMap [columnName ] = append (parameterMap [columnName ], args [placeHolderIndex ])
263
372
placeHolderIndex ++
264
373
} else {
265
- parameterMap [columnName ] = append (parameterMap [col ], val )
374
+ parameterMap [columnName ] = append (parameterMap [columnName ], val )
266
375
}
267
376
}
268
377
}
0 commit comments