Skip to content

Commit c7edb9a

Browse files
Merge pull request #2945 from actiontech/2928
fix: panic when analysis insert into select stmt
2 parents 2d5c7dc + 290ae86 commit c7edb9a

File tree

2 files changed

+21
-2
lines changed

2 files changed

+21
-2
lines changed

sqle/driver/mysql/analysis.go

+6-1
Original file line numberDiff line numberDiff line change
@@ -310,7 +310,12 @@ func (i *MysqlDriverImpl) ExtractSchemaTableList(sql string) ([]SchemaTable, err
310310
case *ast.InsertStmt:
311311
getMultiTables(stmt.Table.TableRefs)
312312
if stmt.Select != nil {
313-
getMultiTables(stmt.Select.(*ast.SelectStmt).From.TableRefs)
313+
// TODO:INSERT INTO SQLE00115_t1_tmp_employee (id, cname, sex, age, salary) SELECT 4000002, '小张', 0, 25, (SELECT AVG(salary) FROM SQLE00115_t1_employee) 对于这条SQL,解析器无法解析子查询为一个Select语句,而是认为是一个文本
314+
if selectStmt, ok := stmt.Select.(*ast.SelectStmt); ok {
315+
if selectStmt.From != nil{
316+
getMultiTables(selectStmt.From.TableRefs)
317+
}
318+
}
314319
}
315320
case *ast.DeleteStmt:
316321
getMultiTables(stmt.TableRefs.TableRefs)

sqle/driver/mysql/util/util.go

+15-1
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,17 @@ func GetAffectedRowNum(ctx context.Context, originSql string, conn *executor.Exe
5050
// 包含子查询的insert语句,insert into t1 (name) select name from t2
5151
isSelectInsert := stmt.Select != nil && stmt.Lists == nil
5252
if isSelectInsert {
53-
newNode = getSelectNodeFromSelect(stmt.Select.(*ast.SelectStmt))
53+
if selectStmt, ok := stmt.Select.(*ast.SelectStmt); ok {
54+
newNode = getSelectNodeFromSelect(selectStmt)
55+
}
56+
// union语句,无法转换为select count语句
57+
if unionStmt, ok := stmt.Select.(*ast.UnionStmt); ok {
58+
cannotConvert = true
59+
originSql, err = restoreToSqlWithFlag(format.DefaultRestoreFlags, unionStmt)
60+
if err != nil {
61+
return 0, err
62+
}
63+
}
5464
} else if isCommonInsert {
5565
return int64(len(stmt.Lists)), nil
5666
} else {
@@ -73,6 +83,10 @@ func GetAffectedRowNum(ctx context.Context, originSql string, conn *executor.Exe
7383
trimSuffix := strings.TrimRight(originSql, ";")
7484
affectedRowSql = fmt.Sprintf("select count(*) from (%s) as t", trimSuffix)
7585
} else {
86+
if newNode == nil {
87+
log.NewEntry().Errorf("in GetAffectedRowNum, when getting select node from %v failed", originSql)
88+
return 0, fmt.Errorf("get select node from %v failed", originSql)
89+
}
7690
sqlBuilder := new(strings.Builder)
7791
err = newNode.Restore(format.NewRestoreCtx(format.DefaultRestoreFlags, sqlBuilder))
7892
if err != nil {

0 commit comments

Comments
 (0)