diff --git a/msdsn/conn_str.go b/msdsn/conn_str.go index 40bedb08..c4268b97 100644 --- a/msdsn/conn_str.go +++ b/msdsn/conn_str.go @@ -38,6 +38,7 @@ const ( ) type Config struct { + Scheme string Port uint64 Host string Instance string @@ -117,7 +118,7 @@ func Parse(dsn string) (Config, map[string]string, error) { return p, params, err } params = parameters - } else if strings.HasPrefix(dsn, "sqlserver://") { + } else if strings.HasPrefix(dsn, "sqlserver://") || strings.HasPrefix(dsn, "azuresql://") { parameters, err := splitConnectionStringURL(dsn) if err != nil { return p, params, err @@ -127,6 +128,11 @@ func Parse(dsn string) (Config, map[string]string, error) { params = splitConnectionString(dsn) } + p.Scheme = "sqlserver" + if strings.HasPrefix(dsn, "azuresql://") { + p.Scheme = "azuresql" + } + strlog, ok := params["log"] if ok { flags, err := strconv.ParseUint(strlog, 10, 64) @@ -342,7 +348,7 @@ func (p Config) URL() *url.URL { } q.Add("disableRetry", fmt.Sprintf("%t", p.DisableRetry)) res := url.URL{ - Scheme: "sqlserver", + Scheme: p.Scheme, Host: host, User: url.UserPassword(p.User, p.Password), } @@ -410,7 +416,7 @@ func splitConnectionStringURL(dsn string) (map[string]string, error) { return res, err } - if u.Scheme != "sqlserver" { + if u.Scheme != "sqlserver" && u.Scheme != "azuresql" { return res, fmt.Errorf("scheme %s is not recognized", u.Scheme) } diff --git a/msdsn/conn_str_test.go b/msdsn/conn_str_test.go index 594b5b3d..a60f4728 100644 --- a/msdsn/conn_str_test.go +++ b/msdsn/conn_str_test.go @@ -136,28 +136,28 @@ func TestValidConnectionString(t *testing.T) { // URL mode {"sqlserver://somehost?connection+timeout=30", func(p Config) bool { - return p.Host == "somehost" && p.Port == 0 && p.Instance == "" && p.ConnTimeout == 30*time.Second + return p.Scheme == "sqlserver" && p.Host == "somehost" && p.Port == 0 && p.Instance == "" && p.ConnTimeout == 30*time.Second }}, {"sqlserver://someuser@somehost?connection+timeout=30", func(p Config) bool { - return p.Host == "somehost" && p.Port == 0 && p.Instance == "" && p.User == "someuser" && p.Password == "" && p.ConnTimeout == 30*time.Second + return p.Scheme == "sqlserver" && p.Host == "somehost" && p.Port == 0 && p.Instance == "" && p.User == "someuser" && p.Password == "" && p.ConnTimeout == 30*time.Second }}, {"sqlserver://someuser:@somehost?connection+timeout=30", func(p Config) bool { - return p.Host == "somehost" && p.Port == 0 && p.Instance == "" && p.User == "someuser" && p.Password == "" && p.ConnTimeout == 30*time.Second + return p.Scheme == "sqlserver" && p.Host == "somehost" && p.Port == 0 && p.Instance == "" && p.User == "someuser" && p.Password == "" && p.ConnTimeout == 30*time.Second }}, {"sqlserver://someuser:foo%3A%2F%5C%21~%40;bar@somehost?connection+timeout=30", func(p Config) bool { - return p.Host == "somehost" && p.Port == 0 && p.Instance == "" && p.User == "someuser" && p.Password == "foo:/\\!~@;bar" && p.ConnTimeout == 30*time.Second + return p.Scheme == "sqlserver" && p.Host == "somehost" && p.Port == 0 && p.Instance == "" && p.User == "someuser" && p.Password == "foo:/\\!~@;bar" && p.ConnTimeout == 30*time.Second }}, {"sqlserver://someuser:foo%3A%2F%5C%21~%40;bar@somehost:1434?connection+timeout=30", func(p Config) bool { - return p.Host == "somehost" && p.Port == 1434 && p.Instance == "" && p.User == "someuser" && p.Password == "foo:/\\!~@;bar" && p.ConnTimeout == 30*time.Second + return p.Scheme == "sqlserver" && p.Host == "somehost" && p.Port == 1434 && p.Instance == "" && p.User == "someuser" && p.Password == "foo:/\\!~@;bar" && p.ConnTimeout == 30*time.Second }}, {"sqlserver://someuser:foo%3A%2F%5C%21~%40;bar@somehost:1434/someinstance?connection+timeout=30", func(p Config) bool { - return p.Host == "somehost" && p.Port == 1434 && p.Instance == "someinstance" && p.User == "someuser" && p.Password == "foo:/\\!~@;bar" && p.ConnTimeout == 30*time.Second + return p.Scheme == "sqlserver" && p.Host == "somehost" && p.Port == 1434 && p.Instance == "someinstance" && p.User == "someuser" && p.Password == "foo:/\\!~@;bar" && p.ConnTimeout == 30*time.Second }}, {"sqlserver://someuser@somehost?disableretry=true", func(p Config) bool { - return p.Host == "somehost" && p.Port == 0 && p.Instance == "" && p.User == "someuser" && p.Password == "" && p.DisableRetry + return p.Scheme == "sqlserver" && p.Host == "somehost" && p.Port == 0 && p.Instance == "" && p.User == "someuser" && p.Password == "" && p.DisableRetry }}, {"sqlserver://someuser@somehost?connection+timeout=30&disableretry=1", func(p Config) bool { - return p.Host == "somehost" && p.Port == 0 && p.Instance == "" && p.User == "someuser" && p.Password == "" && p.ConnTimeout == 30*time.Second && p.DisableRetry + return p.Scheme == "sqlserver" && p.Host == "somehost" && p.Port == 0 && p.Instance == "" && p.User == "someuser" && p.Password == "" && p.ConnTimeout == 30*time.Second && p.DisableRetry }}, } for _, ts := range connStrings { @@ -175,6 +175,41 @@ func TestValidConnectionString(t *testing.T) { } } +func TestValidConnectionStringWithParams(t *testing.T) { + // Like the test above, but strings where we want to have assertions on params being passed through; + // would be very verbose to change the test above + type testStruct struct { + connStr string + check func(Config, map[string]string) bool + } + connStrings := []testStruct{ + {"azuresql://somehost?connection+timeout=30&fedauth=ActiveDirectoryDefault", func(p Config, params map[string]string) bool { + // fedauth is in params, see TestParams + if !(p.Scheme == "azuresql" && p.Host == "somehost" && p.Port == 0 && p.Instance == "" && p.Password == "") { + return false + } + fedauth, ok := params["fedauth"] + if !ok { + return false + } + return fedauth == "ActiveDirectoryDefault" + }}, + } + for _, ts := range connStrings { + p, params, err := Parse(ts.connStr) + if err == nil { + t.Logf("Connection string was parsed successfully %s", ts.connStr) + } else { + t.Errorf("Connection string %s failed to parse with error %s", ts.connStr, err) + continue + } + + if !ts.check(p, params) { + t.Errorf("Check failed on conn str %s", ts.connStr) + } + } +} + func TestSplitConnectionStringURL(t *testing.T) { _, err := splitConnectionStringURL("http://bad") if err == nil {