Skip to content

Commit f916c53

Browse files
committedAug 3, 2022
fixed SQL Injection for #2
1 parent 6081e00 commit f916c53

File tree

2 files changed

+144
-5
lines changed

2 files changed

+144
-5
lines changed
 

‎sql.go

+35-3
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ import (
55
"encoding/json"
66
"fmt"
77
"reflect"
8+
"regexp"
89
"strconv"
910
"strings"
1011
"time"
@@ -16,13 +17,41 @@ var (
1617
escaper = "'"
1718
nullStr = "NULL"
1819
singleQuoteEscaper = "\\"
20+
escapeRegexp = regexp.MustCompile(`[\0\t\x1a\n\r\"\'\\]`)
21+
22+
//see href='https://dev.mysql.com/doc/refman/8.0/en/string-literals.html#character-escape-sequences'
23+
characterEscapeMap = map[string]string{
24+
"\\0": `\\0`, //ASCII NULL
25+
"\b": `\\b`, //backspace
26+
"\t": `\\t`, //tab
27+
"\x1a": `\\Z`, //ASCII 26 (Control+Z);
28+
"\n": `\\n`, //newline character
29+
"\r": `\\r`, //return character
30+
"\"": `\\"`, //quote (")
31+
"'": `\'`, //quote (')
32+
"\\": `\\\\`, //backslash (\)
33+
// "\\%": `\\%`, //% character
34+
// "\\_": `\\_`, //_ character
35+
}
1936
)
2037

2138
//Escape escape the val for sql
2239
func Escape(val interface{}) string {
2340
return EscapeInLocation(val, time.Local)
2441
}
2542

43+
//toSqlString escape the string val for sql
44+
func toSqlString(val string) string {
45+
return escapeRegexp.ReplaceAllStringFunc(val, func(s string) string {
46+
47+
mVal, ok := characterEscapeMap[s]
48+
if ok {
49+
return mVal
50+
}
51+
return s
52+
})
53+
}
54+
2655
func timeToString(t time.Time, loc *time.Location) string {
2756
if t.IsZero() {
2857
return escaper + tmFmtZero + escaper
@@ -70,7 +99,7 @@ func EscapeInLocation(val interface{}, loc *time.Location) string {
7099
return fmt.Sprintf("%.6f", v)
71100

72101
case string:
73-
return escaper + strings.Replace(v, escaper, singleQuoteEscaper+escaper, -1) + escaper
102+
return escaper + toSqlString(v) + escaper
74103
default:
75104
refValue := reflect.ValueOf(v)
76105
if v == nil || !refValue.IsValid() {
@@ -94,7 +123,7 @@ func EscapeInLocation(val interface{}, loc *time.Location) string {
94123
if err != nil {
95124
return nullStr
96125
}
97-
return escaper + strings.Replace(string(stringifyData), escaper, singleQuoteEscaper+escaper, -1) + escaper
126+
return escaper + toSqlString(string(stringifyData)) + escaper
98127

99128
}
100129
}
@@ -144,6 +173,9 @@ func FormatInLocation(query string, loc *time.Location, args ...interface{}) str
144173
}
145174

146175
//SetSingleQuoteEscaper set the singleQuoteEscaper
176+
//default:\' , e.g. '' 、 \'
147177
func SetSingleQuoteEscaper(escaper string) {
148-
singleQuoteEscaper = escaper
178+
179+
characterEscapeMap["'"] = escaper
180+
// singleQuoteEscaper = escaper
149181
}

‎sql_test.go

+109-2
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,14 @@ func TestNULLEscape(t *testing.T) {
1313
}
1414
}
1515

16+
func Test0Escape(t *testing.T) {
17+
result := Escape(`\0`)
18+
t.Logf("Test0Escape result: %s", result)
19+
if result != `'\\\\0'` {
20+
t.Fatalf("escape error")
21+
}
22+
}
23+
1624
func TestEmptyStringEscape(t *testing.T) {
1725
result := Escape("")
1826
t.Logf("result :%s", result)
@@ -87,9 +95,27 @@ func TestStringEscape(t *testing.T) {
8795
}
8896
}
8997

98+
func TestStringEscape2(t *testing.T) {
99+
s := "hello world"
100+
result := Escape(s)
101+
if result != "'hello world'" {
102+
t.Fatalf("escape string error")
103+
104+
}
105+
106+
s = `hello \' world`
107+
t.Logf("TestStringEscape2 raw:%s", s)
108+
result = Escape(s)
109+
t.Logf("TestStringEscape2 result: %s", result)
110+
if result != `'hello \\\\\' world'` {
111+
t.Fatalf("escape string error")
112+
113+
}
114+
}
115+
90116
func TestStringCustomEscape(t *testing.T) {
91117
s := "hello world"
92-
SetSingleQuoteEscaper("'")
118+
SetSingleQuoteEscaper("''")
93119
result := Escape(s)
94120
if result != "'hello world'" {
95121
t.Fatalf("escape string error")
@@ -103,6 +129,8 @@ func TestStringCustomEscape(t *testing.T) {
103129
t.Fatalf("escape string error")
104130

105131
}
132+
SetSingleQuoteEscaper("\\'")
133+
106134
}
107135

108136
func TestBytesEscape(t *testing.T) {
@@ -210,13 +238,85 @@ func TestOtherEscape(t *testing.T) {
210238
result := Escape(x)
211239
t.Logf("escape reuslt %s", result)
212240

213-
if result != "'{\"key\":\"test\",\"name\":\"asd\\'fsadf\"}'" {
241+
if result != `'{\\"key\\":\\"test\\",\\"name\\":\\"asd\'fsadf\\"}'` {
214242
t.Fatalf("escape map error")
215243

216244
}
217245

218246
}
219247

248+
func TestNewlineEscape(t *testing.T) {
249+
s := "hello\nworld"
250+
result := Escape(s)
251+
t.Logf("escape newline reuslt: %s", result)
252+
253+
if result != "'hello\\\\nworld'" {
254+
t.Fatalf("escape string error")
255+
256+
}
257+
258+
}
259+
260+
func TestReturnEscape(t *testing.T) {
261+
s := "hello\rworld"
262+
result := Escape(s)
263+
t.Logf("escape newline reuslt: %s", result)
264+
265+
if result != "'hello\\\\rworld'" {
266+
t.Fatalf("escape string error")
267+
268+
}
269+
270+
}
271+
272+
func TestTabEscape(t *testing.T) {
273+
s := "hello\tworld"
274+
result := Escape(s)
275+
t.Logf("escape tab reuslt: %s", result)
276+
277+
if result != `'hello\\tworld'` {
278+
t.Fatalf("escape string error")
279+
280+
}
281+
282+
}
283+
284+
func TestDoubleBackslashEscape(t *testing.T) {
285+
s := "hello\\world"
286+
result := Escape(s)
287+
t.Logf("escape tab reuslt: %s", result)
288+
289+
if result != `'hello\\\\world'` {
290+
t.Fatalf("escape string error")
291+
292+
}
293+
294+
}
295+
296+
func TestCtrlZEscape(t *testing.T) {
297+
s := "hello\x1aworld"
298+
result := Escape(s)
299+
t.Logf("escape tab reuslt: %s", result)
300+
301+
if result != `'hello\\Zworld'` {
302+
t.Fatalf("escape string error")
303+
304+
}
305+
306+
}
307+
308+
func TestDoubleQouteEscape(t *testing.T) {
309+
s := "hello \" world"
310+
result := Escape(s)
311+
t.Logf("escape tab reuslt: %s", result)
312+
313+
if result != `'hello \\" world'` {
314+
t.Fatalf("escape string error")
315+
316+
}
317+
318+
}
319+
220320
func TestFormatSql(t *testing.T) {
221321

222322
sql := Format("select * from users where name=? and age=? limit ?,?", "t'est", 10, 10, 10)
@@ -282,4 +382,11 @@ func TestFormatSql(t *testing.T) {
282382
t.Fatalf("escape format time error")
283383

284384
}
385+
386+
sql = Format("select * from users where name=? and age=? limit ?,?", `t\'est`, 10, 10, 10)
387+
388+
if sql != `'select * from users where name='t\\\\\'est' and age=10 limit 10,10'` {
389+
390+
t.Logf("sql: %s\n", sql)
391+
}
285392
}

0 commit comments

Comments
 (0)
Please sign in to comment.