diff --git a/README.md b/README.md index 94d87fe0..7e4a5772 100644 --- a/README.md +++ b/README.md @@ -28,6 +28,12 @@ Other supported formats are listed below. * `false` - Data sent between client and server is not encrypted beyond the login packet. (Default) * `true` - Data sent between client and server is encrypted. * `app name` - The application name (default is go-mssqldb) +* `columnEncryption` - Set to "true" if you want to use [Always Encrypted](https://docs.microsoft.com/en-us/sql/relational-databases/security/encryption/always-encrypted-database-engine?view=sql-server-ver15) +* `keyStoreAuthentication` + * `pfx` - Use a PFX file as a key store to authenticate and perform Always Encrypted operations, used when `columnEncryption` is enabled +* `keyStoreLocation` - The location of the key store file (e.g: `./resources/test/always-encrypted/ae-1.pfx`), used when `columnEncryption` is enabled +* `keyStoreSecret` - The password of the key store file provided in `keyStoreLocation`, used when `columnEncryption` is enabled + ### Connection parameters for ODBC and ADO style connection strings: @@ -126,6 +132,80 @@ Where `tokenProvider` is a function that returns a fresh access token or an erro actually trigger the retrieval of a token, this happens when the first statment is issued and a connection is created. + +### Always Encrypted support (preview) + +`go-mssql` supports a client-side decryption of the column encrypted values for those databases +that are using the [Always Encrypted](https://docs.microsoft.com/en-us/sql/relational-databases/security/encryption/always-encrypted-database-engine?view=sql-server-ver15) +feature. + +To start using the feature, you have to use the following parameters in your DSN: + +* `columnEncryption=true` +* `keyStoreAuthentication=pfx` - Only `pfx` is supported at the moment +* `keyStoreLocation=/path/to/your/keystore.pfx` - The location of the key store file (e.g: `./resources/test/always-encrypted/ae-1.pfx`), used when `columnEncryption` is enabled +* `keyStoreSecret=secret` - The password of your keystore (`keyStoreLocation`) + +#### Usage + +Using the Always Encrypted feature should be transparent in the driver: +```go +query := url.Values{} +query.Add("database", "dbname") +query.Add("columnEncryption", "true") +query.Add("keyStoreAuthentication", "pfx") +query.Add("keyStoreLocation", "./resources/test/always-encrypted/ae-1.pfx") +query.Add("keyStoreSecret", "password") + + +hostname := "172.20.0.2" +port:= 1433 + +u := &url.URL{ + Scheme: "sqlserver", + User: url.UserPassword("sa", "superSecurePassword_"), + Host: fmt.Sprintf("%s:%d", hostname, port), + RawQuery: query.Encode(), +} + +db, err := sql.Open("sqlserver", u.String()) +if err != nil { + logrus.Fatalf("unable to open db: %v", err) +} +rows, err := db.Query("SELECT id, ssn FROM [dbo].[cid]") +if err != nil { + logrus.Fatalf("unable to perform query: %v", err) +} + +for ; rows.Next(); { + var dest struct { + Id int + SSN string + } + err = rows.Scan(&dest.Id, &dest.SSN) + if err != nil { + logrus.Fatalf("unable to scan into struct: %v", err) + } + fmt.Printf("%d, %s\n", dest.Id, dest.SSN) +} +``` + +The code above, when used against an Always Encrypted column, returns +the following: + +``` +1, 12345 +2, 00000 +``` + +If `columnEncryption` is set to false, the result will be similar to the following: +``` +1, B��v��3O뗇��a�R��o�l��U� +�iE�#wOS�T횡5�R��1�i_n/Q��oLPBy��kL���8'/� +2, �ކ��?�Y + Ѕ���i_n��-g|����v��2����x�Q)y�p�x��O��9������r��Bt�L�"N����.N]Rc +``` + ## Executing Stored Procedures To run a stored procedure, set the query text to the procedure name: diff --git a/always_encrypted_test.go b/always_encrypted_test.go new file mode 100644 index 00000000..a517a0ca --- /dev/null +++ b/always_encrypted_test.go @@ -0,0 +1,34 @@ +package mssql + +import ( + "fmt" + "testing" + "github.com/stretchr/testify/assert" +) + +func TestAlwaysEncrypted(t *testing.T) { + conn := open(t) + defer conn.Close() + rows, err := conn.Query("SELECT id, ssn FROM [dbo].[cid]") + defer rows.Close() + + if err != nil { + t.Fatalf("unable to query db: %s", err) + } + + var dest struct { + Id int + SSN string + } + + expectedValues := []string{"12345", "00000"} + expectedIdx := 0 + + for ; rows.Next() ; { + err = rows.Scan(&dest.Id, &dest.SSN) + assert.Equal(t, expectedValues[expectedIdx], dest.SSN) + expectedIdx++ + assert.Nil(t, err) + fmt.Printf("col: %v\n", dest) + } +} \ No newline at end of file diff --git a/buf.go b/buf.go index bad2b00d..b69cc987 100644 --- a/buf.go +++ b/buf.go @@ -269,3 +269,41 @@ func (r *tdsBuffer) Read(buf []byte) (copied int, err error) { r.rpos += copied return } + +type sqlIdentifier struct { + serverName string + databaseName string + schemaName string + objectName string +} + +func (r *tdsBuffer) sqlIdentifier() sqlIdentifier { + numParts := int(r.byte()) + if numParts < 1 || numParts >= 5 { + panic("invalid sqlIdentifier: numparts is not between 1 and 4") + } + + parts := make([]string, numParts) + + for i := range parts { + parts[i] = r.UsVarChar() + } + + sqlID := sqlIdentifier{ + objectName: parts[0], + } + + if numParts >= 2 { + sqlID.schemaName = parts[1] + } + + if numParts >= 3{ + sqlID.databaseName = parts[2] + } + + if numParts == 4 { + sqlID.serverName = parts[3] + } + + return sqlID +} diff --git a/cek.go b/cek.go new file mode 100644 index 00000000..9933ba08 --- /dev/null +++ b/cek.go @@ -0,0 +1,29 @@ +package mssql + +type cekTable struct { + entries []cekTableEntry +} + +type encryptionKeyInfo struct { + encryptedKey []byte + databaseID int + cekID int + cekVersion int + cekMdVersion []byte + keyPath string + keyStoreName string + algorithmName string +} + +type cekTableEntry struct { + databaseID int + keyId int + keyVersion int + mdVersion []byte + valueCount int + cekValues []encryptionKeyInfo +} + +func newCekTable(size uint16) cekTable { + return cekTable{entries: make([]cekTableEntry, size)} +} \ No newline at end of file diff --git a/conn_str.go b/conn_str.go index d7d9e06a..66f06a28 100644 --- a/conn_str.go +++ b/conn_str.go @@ -39,11 +39,21 @@ type connectParams struct { packetSize uint16 fedAuthLibrary int fedAuthADALWorkflow byte + columnEncryption bool + keyStoreAuthentication KeyStoreAuthentication + keyStoreLocation string + keyStoreSecret string } // default packet size for TDS buffer const defaultPacketSize = 4096 +type KeyStoreAuthentication string + +const ( + PFXKeystoreAuth = "pfx" +) + func parseConnectParams(dsn string) (connectParams, error) { p := connectParams{ fedAuthLibrary: fedAuthLibraryReserved, @@ -169,6 +179,54 @@ func parseConnectParams(dsn string) (connectParams, error) { } else { p.trustServerCertificate = true } + + columnEncryption, ok := params["columnencryption"] + if ok { + if strings.EqualFold(columnEncryption, "true") { + p.columnEncryption = true + } else { + var err error + p.columnEncryption, err = strconv.ParseBool(columnEncryption) + if err != nil { + f := "invalid columnEncryption '%s': %s" + return p, fmt.Errorf(f, columnEncryption, err.Error()) + } + } + } else { + p.columnEncryption = false + } + + ksAuth, ok := params["keystoreauthentication"] + if ok { + var authMethod KeyStoreAuthentication + switch strings.ToLower(ksAuth) { + case "pfx": + authMethod = PFXKeystoreAuth + default: + return p, fmt.Errorf("invalid keystotreAuthentication '%s'", ksAuth) + } + p.keyStoreAuthentication = authMethod + } + + ksLocation, ok := params["keystorelocation"] + if ok { + if ksLocation == "" { + return p, fmt.Errorf("invalid keystore location provided: '%s'", ksLocation) + } + + _, err := os.Stat(ksLocation) + if err != nil { + return p, fmt.Errorf("unable to find keystore %s: %v", ksLocation, err) + } + + p.keyStoreLocation = ksLocation + } + + ksSecret, ok := params["keystoresecret"] + if ok { + p.keyStoreSecret = ksSecret + } + trust, ok := params["trustservercertificate"] if ok { var err error @@ -248,6 +306,23 @@ func (p connectParams) toUrl() *url.URL { if p.logFlags != 0 { q.Add("log", strconv.FormatUint(p.logFlags, 10)) } + + if p.columnEncryption { + q.Add("columnEncryption", "true") + } + + if p.keyStoreAuthentication != "" { + q.Add("keyStoreAuthentication", string(p.keyStoreAuthentication)) + } + + if p.keyStoreLocation != "" { + q.Add("keyStoreLocation", p.keyStoreLocation) + } + + if p.keyStoreSecret != "" { + q.Add("keyStoreSecret", p.keyStoreSecret) + } + res := url.URL{ Scheme: "sqlserver", Host: p.host, @@ -256,6 +331,7 @@ func (p connectParams) toUrl() *url.URL { if p.instance != "" { res.Path = p.instance } + if len(q) > 0 { res.RawQuery = q.Encode() } @@ -274,7 +350,7 @@ func splitConnectionString(dsn string) (res map[string]string) { if len(name) == 0 { continue } - var value string = "" + var value = "" if len(lst) > 1 { value = strings.TrimSpace(lst[1]) } diff --git a/conn_str_test.go b/conn_str_test.go index bb6e2682..959c06b0 100644 --- a/conn_str_test.go +++ b/conn_str_test.go @@ -2,6 +2,7 @@ package mssql import ( "bufio" + "github.com/stretchr/testify/assert" "io" "os" "reflect" @@ -186,10 +187,10 @@ func testConnParams(t testing.TB) connectParams { } if len(os.Getenv("HOST")) > 0 && len(os.Getenv("DATABASE")) > 0 { return connectParams{ - host: os.Getenv("HOST"), + host: os.Getenv("HOST"), instance: os.Getenv("INSTANCE"), database: os.Getenv("DATABASE"), - user: os.Getenv("SQLUSER"), + user: os.Getenv("SQLUSER"), password: os.Getenv("SQLPASSWORD"), logFlags: logFlags, } @@ -227,3 +228,16 @@ func TestConnParseRoundTripFixed(t *testing.T) { t.Fatal("Parameters do not match after roundtrip", params, rtParams) } } + +func TestConnParseAlwaysEncrypted(t *testing.T) { + connStr := "sqlserver://sa:sa@localhost/instance?database=master&columnEncryption=true&keyStoreAuthentication=pfx&keyStoreLocation=./resources/test/always-encrypted/ae-1.pfx&keyStoreSecret=password" + params, err := parseConnectParams(connStr) + if err != nil { + t.Fatal("Test URL is not valid", err) + } + + assert.True(t, params.columnEncryption) + assert.Equal(t, KeyStoreAuthentication(PFXKeystoreAuth), params.keyStoreAuthentication) + assert.Equal(t, "./resources/test/always-encrypted/ae-1.pfx", params.keyStoreLocation) + assert.Equal(t, "password", params.keyStoreSecret) +} diff --git a/go.mod b/go.mod index ebc02ab8..0d64287d 100644 --- a/go.mod +++ b/go.mod @@ -3,6 +3,10 @@ module github.com/denisenkom/go-mssqldb go 1.11 require ( + github.com/davecgh/go-spew v1.1.1 // indirect github.com/golang-sql/civil v0.0.0-20190719163853-cb61b32ac6fe + github.com/stretchr/testify v1.7.0 + github.com/swisscom/mssql-always-encrypted v0.1.0 golang.org/x/crypto v0.0.0-20190325154230-a5d413f7728c + golang.org/x/text v0.3.5 ) diff --git a/go.sum b/go.sum index 1887801b..5416d834 100644 --- a/go.sum +++ b/go.sum @@ -1,5 +1,23 @@ +github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/golang-sql/civil v0.0.0-20190719163853-cb61b32ac6fe h1:lXe2qZdvpiX5WZkZR4hgp4KJVfY3nMkvmwbVkpv1rVY= github.com/golang-sql/civil v0.0.0-20190719163853-cb61b32ac6fe/go.mod h1:8vg3r2VgvsThLBIFL93Qb5yWzgyZWhEmBwUJWevAkK0= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY= +github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/swisscom/mssql-always-encrypted v0.1.0 h1:bmYt1My3KgQsYkAJTDXkJt6b5wjRX3rSMrvyYHhK60Y= +github.com/swisscom/mssql-always-encrypted v0.1.0/go.mod h1:FlEWLI3+svdMFq2w7GVMvk7iVhwBEBi7E7llAHb4B20= golang.org/x/crypto v0.0.0-20190325154230-a5d413f7728c h1:Vj5n4GlwjmQteupaxJ9+0FNOmBrHfq7vN4btdGoDZgI= golang.org/x/crypto v0.0.0-20190325154230-a5d413f7728c/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/text v0.3.5 h1:i6eZZ+zk0SOf0xgBpEpPD18qWcJda6q1sxt3S0kzyUQ= +golang.org/x/text v0.3.5/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/resources/test/always-encrypted/ae-1.pfx b/resources/test/always-encrypted/ae-1.pfx new file mode 100644 index 00000000..3e6edc01 Binary files /dev/null and b/resources/test/always-encrypted/ae-1.pfx differ diff --git a/tds.go b/tds.go index e1b63300..e4c6f594 100644 --- a/tds.go +++ b/tds.go @@ -131,17 +131,27 @@ const ( ) type tdsSession struct { - buf *tdsBuffer - loginAck loginAckStruct - database string - partner string - columns []columnStruct - tranid uint64 - logFlags uint64 - log optionalLogger - routedServer string - routedPort uint16 - returnStatus *ReturnStatus + buf *tdsBuffer + loginAck loginAckStruct + alwaysEncrypted bool + alwaysEncryptedSettings *aeSettings + database string + partner string + columns []columnStruct + tranid uint64 + logFlags uint64 + log optionalLogger + routedServer string + routedPort uint16 + returnStatus *ReturnStatus +} + +type aeSettings struct { + ksLocation string + ksSecret string + ksAuth KeyStoreAuthentication + pKey interface{} + cert *x509.Certificate } const ( @@ -155,10 +165,15 @@ const ( ) type columnStruct struct { - UserType uint32 - Flags uint16 - ColName string - ti typeInfo + UserType uint32 + Flags uint16 + ColName string + ti typeInfo + cryptoMeta cryptoMetadata +} + +func (c columnStruct) isEncrypted() bool { + return 0x0800 == (c.Flags & 0x0800) } type keySlice []uint8 @@ -408,6 +423,23 @@ func (e *featureExtFedAuth) toBytes() []byte { return d } +type featureExtColumnEncryption struct { +} + +func (f *featureExtColumnEncryption) featureID() byte { + return featExtCOLUMNENCRYPTION +} + +func (f *featureExtColumnEncryption) toBytes() []byte { + /* + 1 = The client supports column encryption without enclave computations. + 2 = The client SHOULD<25> support column encryption when encrypted data require enclave computations. + */ + return []byte{0x01} +} + +var _ featureExt = &featureExtColumnEncryption{} + type loginHeader struct { Length uint32 TDSVersion uint32 @@ -474,7 +506,7 @@ func ucs22str(s []byte) (string, error) { } func manglePassword(password string) []byte { - var ucs2password []byte = str2ucs2(password) + var ucs2password = str2ucs2(password) for i, ch := range ucs2password { ucs2password[i] = ((ch<<4)&0xff | (ch >> 4)) ^ 0xA5 } @@ -947,7 +979,7 @@ func interpretPreloginResponse(p connectParams, fe *featureExtFedAuth, fields ma // We need to be able to echo the value back to the server fe.FedAuthEcho = fedAuthSupport[0] != 0 } else if fe.FedAuthLibrary != fedAuthLibraryReserved { - return 0, fmt.Errorf("Federated authentication is not supported by the server") + return 0, fmt.Errorf("federated authentication is not supported by the server") } encryptBytes, ok := fields[preloginENCRYPTION] @@ -973,6 +1005,12 @@ func prepareLogin(ctx context.Context, c *Connector, p connectParams, log option AppName: p.appname, TypeFlags: p.typeFlags, } + + if p.columnEncryption { + // Support Always Encrypted + _ = l.FeatureExt.Add(&featureExtColumnEncryption{}) + } + switch { case fe.FedAuthLibrary == fedAuthLibrarySecurityToken: if p.logFlags&logDebug != 0 { @@ -987,14 +1025,14 @@ func prepareLogin(ctx context.Context, c *Connector, p connectParams, log option return nil, err } - l.FeatureExt.Add(fe) + _ = l.FeatureExt.Add(fe) case fe.FedAuthLibrary == fedAuthLibraryADAL: if p.logFlags&logDebug != 0 { log.Println("Starting federated authentication using ADAL") } - l.FeatureExt.Add(fe) + _ = l.FeatureExt.Add(fe) case auth != nil: if p.logFlags&logDebug != 0 { @@ -1203,6 +1241,21 @@ initiate_connection: case loginAckStruct: sess.loginAck = token loginAck = true + case featureExtAck: + for _, v := range token { + switch v:= v.(type) { + case colAckStruct: + if v.Version <= 2 && v.Version > 0 { + sess.alwaysEncrypted = true + sess.alwaysEncryptedSettings = &aeSettings{ + ksSecret: p.keyStoreSecret, + ksLocation: p.keyStoreLocation, + ksAuth: p.keyStoreAuthentication, + } + } + } + } + case doneStruct: if token.isError() { return nil, fmt.Errorf("login error: %s", token.getError()) diff --git a/token.go b/token.go index c9d45256..6650b8cf 100644 --- a/token.go +++ b/token.go @@ -1,12 +1,22 @@ package mssql import ( + "bytes" "context" + "crypto/rsa" + "crypto/sha1" "encoding/binary" "errors" "fmt" + alwaysencrypted "github.com/swisscom/mssql-always-encrypted/pkg" + "github.com/swisscom/mssql-always-encrypted/pkg/algorithms" + "github.com/swisscom/mssql-always-encrypted/pkg/encryption" + "github.com/swisscom/mssql-always-encrypted/pkg/keys" + "golang.org/x/crypto/pkcs12" + "golang.org/x/text/encoding/unicode" "io" "io/ioutil" + "os" "strconv" ) @@ -75,6 +85,10 @@ const ( fedAuthInfoSPN = 0x02 ) +const ( + cipherAlgCustom = 0x00 +) + // COLMETADATA flags // https://msdn.microsoft.com/en-us/library/dd357363.aspx const ( @@ -82,6 +96,9 @@ const ( // TODO implement more flags ) +// UTF-16 Decoder +var utf16Decoder = unicode.UTF16(unicode.LittleEndian, unicode.IgnoreBOM).NewDecoder() + // interface for all tokens type tokenStruct interface{} @@ -467,7 +484,7 @@ func parseFedAuthInfo(r *tdsBuffer) fedAuthInfoStruct { case fedAuthInfoSPN: SPN, err = ucs22str(optData) default: - err = fmt.Errorf("Unexpected fed auth info opt ID %d", int(opts[i].fedAuthInfoID)) + err = fmt.Errorf("unexpected fed auth info opt ID %d", int(opts[i].fedAuthInfoID)) } if err != nil { @@ -510,7 +527,13 @@ type fedAuthAckStruct struct { Signature []byte } -func parseFeatureExtAck(r *tdsBuffer) map[byte]interface{} { +type colAckStruct struct { + Version int +} + +type featureExtAck map[byte]interface{} + +func parseFeatureExtAck(r *tdsBuffer) featureExtAck { ack := map[byte]interface{}{} for feature := r.byte(); feature != featExtTERMINATOR; feature = r.byte() { @@ -532,7 +555,18 @@ func parseFeatureExtAck(r *tdsBuffer) map[byte]interface{} { length -= 32 } ack[feature] = fedAuthAck - + case featExtCOLUMNENCRYPTION: + colAck := colAckStruct{} + colAck.Version = int(r.byte()) + length-- + + if length > 0 { + enclaveLength := r.byte() + var enclaveType = make([]byte, enclaveLength) + r.ReadFull(enclaveType) + length -= uint32(enclaveLength) + } + ack[feature] = colAck } // Skip unprocessed bytes @@ -545,29 +579,302 @@ func parseFeatureExtAck(r *tdsBuffer) map[byte]interface{} { } // http://msdn.microsoft.com/en-us/library/dd357363.aspx -func parseColMetadata72(r *tdsBuffer) (columns []columnStruct) { +func parseColMetadata72(r *tdsBuffer, s *tdsSession) (columns []columnStruct) { count := r.uint16() if count == 0xffff { // no metadata is sent return nil } columns = make([]columnStruct, count) + + var cekTable *cekTable + if s.alwaysEncrypted { + // CEK table + cekTable = readCEKTable(r) + + if s.alwaysEncryptedSettings == nil { + panic("alwaysEncryptedSettings are nil!") + } + + if s.alwaysEncryptedSettings.pKey == nil { + // Load Keystore + f, err := os.Open(s.alwaysEncryptedSettings.ksLocation) + if err != nil { + panic(err) + } + + switch s.alwaysEncryptedSettings.ksAuth { + case PFXKeystoreAuth: + pfxBytes, err := ioutil.ReadAll(f) + if err != nil { + panic(err) + } + + pk, cert, err := pkcs12.Decode(pfxBytes, s.alwaysEncryptedSettings.ksSecret) + if err != nil { + panic(err) + } + + s.alwaysEncryptedSettings.pKey = pk + s.alwaysEncryptedSettings.cert = cert + default: + panic(fmt.Sprintf("ksAuth %v is unimplemented", s.alwaysEncryptedSettings.ksAuth)) + } + } + } + + dec := unicode.UTF16(unicode.LittleEndian, unicode.IgnoreBOM).NewDecoder() + for i := range columns { column := &columns[i] - column.UserType = r.uint32() - column.Flags = r.uint16() + baseTi := getBaseTypeInfo(r, true) + typeInfo := readTypeInfo(r, baseTi.TypeId) + typeInfo.UserType = baseTi.UserType + typeInfo.Flags = baseTi.Flags + typeInfo.TypeId = baseTi.TypeId + + // Table Name + if baseTi.TypeId == typeText || baseTi.TypeId == typeNText || baseTi.TypeId == typeImage { + _ = r.sqlIdentifier() + } + + column.Flags = baseTi.Flags + column.UserType = baseTi.UserType + column.ti = typeInfo + + if column.isEncrypted() && s.alwaysEncrypted { + // Read Crypto Metadata + cryptoMeta := parseCryptoMetadata(r, cekTable) + cryptoMeta.typeInfo.Flags = baseTi.Flags + column.cryptoMeta = cryptoMeta + } - // parsing TYPE_INFO structure - column.ti = readTypeInfo(r) - column.ColName = r.BVarChar() + colNameLen := r.byte() + colNameUtf16 := make([]byte, int(colNameLen)*2) + r.ReadFull(colNameUtf16) + colName, _ := dec.Bytes(colNameUtf16) + column.ColName = string(colName) } return columns } +func getBaseTypeInfo(r *tdsBuffer, parseFlags bool) typeInfo { + userType := r.uint32() + flags := uint16(0) + if parseFlags { + flags = r.uint16() + } + tId := r.byte() + + return typeInfo{ + UserType: userType, + Flags: flags, + TypeId: tId} +} + +type cryptoMetadata struct { + entry *cekTableEntry + ordinal uint16 + algorithmId byte + algorithmName *string + encType byte + normRuleVer byte + typeInfo typeInfo +} + +func parseCryptoMetadata(r *tdsBuffer, cekTable *cekTable) cryptoMetadata { + ordinal := uint16(0) + if cekTable != nil { + ordinal = r.uint16() + } + + typeInfo := getBaseTypeInfo(r, false) + ti := readTypeInfo(r, typeInfo.TypeId) + ti.UserType = typeInfo.UserType + ti.Flags = typeInfo.Flags + ti.TypeId = typeInfo.TypeId + + algorithmId := r.byte() + var algName *string = nil + + if algorithmId == cipherAlgCustom { + // Read the name when a custom algorithm is used + nameLen := int(r.byte()) + var algNameUtf16 = make([]byte, nameLen*2) + r.ReadFull(algNameUtf16) + algNameBytes, _ := unicode.UTF16(unicode.LittleEndian, unicode.IgnoreBOM).NewDecoder().Bytes(algNameUtf16) + mAlgName := string(algNameBytes) + algName = &mAlgName + } + + encType := r.byte() + normRuleVer := r.byte() + + var entry *cekTableEntry = nil + + if cekTable != nil { + if int(ordinal) > len(cekTable.entries)-1 { + panic(fmt.Errorf("invalid ordinal, cekTable only has %d entries", len(cekTable.entries))) + } + entry = &cekTable.entries[ordinal] + } + + return cryptoMetadata{ + entry: entry, + ordinal: ordinal, + algorithmId: algorithmId, + algorithmName: algName, + encType: encType, + normRuleVer: normRuleVer, + typeInfo: ti, + } +} + +func readCEKTable(r *tdsBuffer) *cekTable { + tableSize := r.uint16() + var cekTable *cekTable = nil + + if tableSize != 0 { + mCekTable := newCekTable(tableSize) + for i := uint16(0); i < tableSize; i++ { + mCekTable.entries[i] = readCekTableEntry(r) + } + cekTable = &mCekTable + } + + return cekTable +} + +func readCekTableEntry(r *tdsBuffer) cekTableEntry { + databaseId := r.int32() + cekID := r.int32() + cekVersion := r.int32() + var cekMdVersion = make([]byte, 8) + _, err := r.Read(cekMdVersion) + if err != nil { + panic("unable to read cekMdVersion") + } + + cekValueCount := uint(r.byte()) + enc := unicode.UTF16(unicode.LittleEndian, unicode.IgnoreBOM) + utf16dec := enc.NewDecoder() + cekValues := make([]encryptionKeyInfo, cekValueCount) + + for i := uint(0); i < cekValueCount; i++ { + encryptedCekLength := r.uint16() + encryptedCek := make([]byte, encryptedCekLength) + r.ReadFull(encryptedCek) + + keyStoreLength := r.byte() + keyStoreNameUtf16 := make([]byte, keyStoreLength*2) + r.ReadFull(keyStoreNameUtf16) + keyStoreName, _ := utf16dec.Bytes(keyStoreNameUtf16) + + keyPathLength := r.uint16() + keyPathUtf16 := make([]byte, keyPathLength*2) + r.ReadFull(keyPathUtf16) + keyPath, _ := utf16dec.Bytes(keyPathUtf16) + + algLength := r.byte() + algNameUtf16 := make([]byte, algLength*2) + r.ReadFull(algNameUtf16) + algName, _ := utf16dec.Bytes(algNameUtf16) + + cekValues[i] = encryptionKeyInfo{ + encryptedKey: encryptedCek, + databaseID: int(databaseId), + cekID: int(cekID), + cekVersion: int(cekVersion), + cekMdVersion: cekMdVersion, + keyPath: string(keyPath), + keyStoreName: string(keyStoreName), + algorithmName: string(algName), + } + } + + return cekTableEntry{ + databaseID: int(databaseId), + keyId: int(cekID), + keyVersion: int(cekVersion), + mdVersion: cekMdVersion, + valueCount: int(cekValueCount), + cekValues: cekValues, + } +} + +type RWCBuffer struct { + buffer *bytes.Reader +} + +func (R RWCBuffer) Read(p []byte) (n int, err error) { + return R.buffer.Read(p) +} + +func (R RWCBuffer) Write(p []byte) (n int, err error) { + return 0, nil +} + +func (R RWCBuffer) Close() error { + return nil +} + +var _ io.ReadWriteCloser = RWCBuffer{} + // http://msdn.microsoft.com/en-us/library/dd357254.aspx -func parseRow(r *tdsBuffer, columns []columnStruct, row []interface{}) { +func parseRow(r *tdsBuffer, s *tdsSession, columns []columnStruct, row []interface{}) { for i, column := range columns { - row[i] = column.ti.Reader(&column.ti, r) + columnContent := column.ti.Reader(&column.ti, r) + if column.isEncrypted() && s.alwaysEncrypted { + // Decrypt + cekValue := column.cryptoMeta.entry.cekValues[column.cryptoMeta.ordinal] + algVer := cekValue.cekVersion + encType := encryption.From(column.cryptoMeta.encType) + + // Get pKey + if s.alwaysEncryptedSettings.pKey == nil { + panic("alwaysEncrypted pKey not set: this should never happen") + } + + cekv := alwaysencrypted.LoadCEKV(column.cryptoMeta.entry.cekValues[0].encryptedKey) + if !cekv.Verify(s.alwaysEncryptedSettings.cert) { + panic(fmt.Errorf("invalid certificate being used to decrypt: %v requested but %v provided", + cekv.KeyPath, + fmt.Sprintf("%02x", sha1.Sum(s.alwaysEncryptedSettings.cert.Raw)), + )) + } + + // TODO: Support other private keys + rootKey, err := cekv.Decrypt(s.alwaysEncryptedSettings.pKey.(*rsa.PrivateKey)) + if err != nil { + panic(err) + } + + // Derive Root Key from encryptedKey + k := keys.NewAeadAes256CbcHmac256(rootKey) + alg := algorithms.NewAeadAes256CbcHmac256Algorithm(k, encType, byte(algVer)) + + d, err := alg.Decrypt(columnContent.([]byte)) + if err != nil { + panic(err) + } + + // Dirty workaround to keep compatibility with original types + // TODO: Improve me + var newBuff = make([]byte, 2) + binary.LittleEndian.PutUint16(newBuff, uint16(len(d))) + newBuff = append(newBuff, d...) + + rwc := RWCBuffer{ + buffer: bytes.NewReader(newBuff), + } + + column.cryptoMeta.typeInfo.Buffer = d + buffer := tdsBuffer{rpos: 0, rsize: len(newBuff), rbuf: newBuff, transport: rwc} + + row[i] = column.cryptoMeta.typeInfo.Reader(&column.cryptoMeta.typeInfo, &buffer) + } else { + row[i] = columnContent + } } } @@ -628,10 +935,10 @@ func parseReturnValue(r *tdsBuffer) (nv namedValue) { r.uint16() nv.Name = r.BVarChar() r.byte() - r.uint32() // UserType (uint16 prior to 7.2) - r.uint16() - ti := readTypeInfo(r) - nv.Value = ti.Reader(&ti, r) + + ti := getBaseTypeInfo(r, true) + ti2 := readTypeInfo(r, ti.TypeId) + nv.Value = ti2.Reader(&ti2, r) return } @@ -707,11 +1014,11 @@ func processSingleResponse(sess *tdsSession, ch chan tokenStruct, outs map[strin return } case tokenColMetadata: - columns = parseColMetadata72(sess.buf) + columns = parseColMetadata72(sess.buf, sess) ch <- columns case tokenRow: row := make([]interface{}, len(columns)) - parseRow(sess.buf, columns, row) + parseRow(sess.buf, sess, columns, row) ch <- row case tokenNbcRow: row := make([]interface{}, len(columns)) diff --git a/types.go b/types.go index cae19924..6cdc7be2 100644 --- a/types.go +++ b/types.go @@ -83,6 +83,8 @@ const _TVP_ROW_TOKEN = 0x01 // http://msdn.microsoft.com/en-us/library/dd358284.aspx type typeInfo struct { TypeId uint8 + UserType uint32 + Flags uint16 Size int Scale uint8 Prec uint8 @@ -113,9 +115,9 @@ type xmlInfo struct { XmlSchemaCollection string } -func readTypeInfo(r *tdsBuffer) (res typeInfo) { - res.TypeId = r.byte() - switch res.TypeId { +func readTypeInfo(r *tdsBuffer, typeId byte) (res typeInfo) { + res.TypeId = typeId + switch typeId { case typeNull, typeInt1, typeBit, typeInt2, typeInt4, typeDateTim4, typeFlt4, typeMoney, typeDateTime, typeFlt8, typeMoney4, typeInt8: // those are fixed length types