From 30c8baa57352bba4c40018e1f104adf868405fae Mon Sep 17 00:00:00 2001 From: Denys Vitali Date: Wed, 3 Feb 2021 15:34:00 +0100 Subject: [PATCH] feat: add Always Encrypted support This commit adds partial support for the Microsoft SQL "Always Encrypted" feature (basically, E2E encryption). The current implementation is to be consider a __preview__ since it might not be perfectly implemented. Supported features: - PFX "keystore" - Seamless encryption Missing features: - Support for Private Keys that are not RSA - Encryption support (only Decryption is possible at the moment) The most probably needs to be improved a bit, but so far it's working for some of the use cases that I needed it for. Feel free to test the feature and open an issue if you find any problem: my goal is to have enough testers to spot eventual bugs. fix: lint issues --- README.md | 80 ++++++ always_encrypted_test.go | 34 +++ buf.go | 38 +++ cek.go | 29 ++ conn_str.go | 78 +++++- conn_str_test.go | 18 +- go.mod | 4 + go.sum | 18 ++ resources/test/always-encrypted/ae-1.pfx | Bin 0 -> 2568 bytes tds.go | 91 ++++-- token.go | 341 +++++++++++++++++++++-- types.go | 8 +- 12 files changed, 697 insertions(+), 42 deletions(-) create mode 100644 always_encrypted_test.go create mode 100644 cek.go create mode 100644 resources/test/always-encrypted/ae-1.pfx diff --git a/README.md b/README.md index 94d87fe0..7e4a5772 100644 --- a/README.md +++ b/README.md @@ -28,6 +28,12 @@ Other supported formats are listed below. * `false` - Data sent between client and server is not encrypted beyond the login packet. (Default) * `true` - Data sent between client and server is encrypted. * `app name` - The application name (default is go-mssqldb) +* `columnEncryption` - Set to "true" if you want to use [Always Encrypted](https://docs.microsoft.com/en-us/sql/relational-databases/security/encryption/always-encrypted-database-engine?view=sql-server-ver15) +* `keyStoreAuthentication` + * `pfx` - Use a PFX file as a key store to authenticate and perform Always Encrypted operations, used when `columnEncryption` is enabled +* `keyStoreLocation` - The location of the key store file (e.g: `./resources/test/always-encrypted/ae-1.pfx`), used when `columnEncryption` is enabled +* `keyStoreSecret` - The password of the key store file provided in `keyStoreLocation`, used when `columnEncryption` is enabled + ### Connection parameters for ODBC and ADO style connection strings: @@ -126,6 +132,80 @@ Where `tokenProvider` is a function that returns a fresh access token or an erro actually trigger the retrieval of a token, this happens when the first statment is issued and a connection is created. + +### Always Encrypted support (preview) + +`go-mssql` supports a client-side decryption of the column encrypted values for those databases +that are using the [Always Encrypted](https://docs.microsoft.com/en-us/sql/relational-databases/security/encryption/always-encrypted-database-engine?view=sql-server-ver15) +feature. + +To start using the feature, you have to use the following parameters in your DSN: + +* `columnEncryption=true` +* `keyStoreAuthentication=pfx` - Only `pfx` is supported at the moment +* `keyStoreLocation=/path/to/your/keystore.pfx` - The location of the key store file (e.g: `./resources/test/always-encrypted/ae-1.pfx`), used when `columnEncryption` is enabled +* `keyStoreSecret=secret` - The password of your keystore (`keyStoreLocation`) + +#### Usage + +Using the Always Encrypted feature should be transparent in the driver: +```go +query := url.Values{} +query.Add("database", "dbname") +query.Add("columnEncryption", "true") +query.Add("keyStoreAuthentication", "pfx") +query.Add("keyStoreLocation", "./resources/test/always-encrypted/ae-1.pfx") +query.Add("keyStoreSecret", "password") + + +hostname := "172.20.0.2" +port:= 1433 + +u := &url.URL{ + Scheme: "sqlserver", + User: url.UserPassword("sa", "superSecurePassword_"), + Host: fmt.Sprintf("%s:%d", hostname, port), + RawQuery: query.Encode(), +} + +db, err := sql.Open("sqlserver", u.String()) +if err != nil { + logrus.Fatalf("unable to open db: %v", err) +} +rows, err := db.Query("SELECT id, ssn FROM [dbo].[cid]") +if err != nil { + logrus.Fatalf("unable to perform query: %v", err) +} + +for ; rows.Next(); { + var dest struct { + Id int + SSN string + } + err = rows.Scan(&dest.Id, &dest.SSN) + if err != nil { + logrus.Fatalf("unable to scan into struct: %v", err) + } + fmt.Printf("%d, %s\n", dest.Id, dest.SSN) +} +``` + +The code above, when used against an Always Encrypted column, returns +the following: + +``` +1, 12345 +2, 00000 +``` + +If `columnEncryption` is set to false, the result will be similar to the following: +``` +1, B��v��3O뗇��a�R��o�l��U� +�iE�#wOS�T횡5�R��1�i_n/Q��oLPBy��kL���8'/� +2, �ކ��?�Y + Ѕ���i_n��-g|����v��2����x�Q)y�p�x��O��9������r��Bt�L�"N����.N]Rc +``` + ## Executing Stored Procedures To run a stored procedure, set the query text to the procedure name: diff --git a/always_encrypted_test.go b/always_encrypted_test.go new file mode 100644 index 00000000..a517a0ca --- /dev/null +++ b/always_encrypted_test.go @@ -0,0 +1,34 @@ +package mssql + +import ( + "fmt" + "testing" + "github.com/stretchr/testify/assert" +) + +func TestAlwaysEncrypted(t *testing.T) { + conn := open(t) + defer conn.Close() + rows, err := conn.Query("SELECT id, ssn FROM [dbo].[cid]") + defer rows.Close() + + if err != nil { + t.Fatalf("unable to query db: %s", err) + } + + var dest struct { + Id int + SSN string + } + + expectedValues := []string{"12345", "00000"} + expectedIdx := 0 + + for ; rows.Next() ; { + err = rows.Scan(&dest.Id, &dest.SSN) + assert.Equal(t, expectedValues[expectedIdx], dest.SSN) + expectedIdx++ + assert.Nil(t, err) + fmt.Printf("col: %v\n", dest) + } +} \ No newline at end of file diff --git a/buf.go b/buf.go index bad2b00d..b69cc987 100644 --- a/buf.go +++ b/buf.go @@ -269,3 +269,41 @@ func (r *tdsBuffer) Read(buf []byte) (copied int, err error) { r.rpos += copied return } + +type sqlIdentifier struct { + serverName string + databaseName string + schemaName string + objectName string +} + +func (r *tdsBuffer) sqlIdentifier() sqlIdentifier { + numParts := int(r.byte()) + if numParts < 1 || numParts >= 5 { + panic("invalid sqlIdentifier: numparts is not between 1 and 4") + } + + parts := make([]string, numParts) + + for i := range parts { + parts[i] = r.UsVarChar() + } + + sqlID := sqlIdentifier{ + objectName: parts[0], + } + + if numParts >= 2 { + sqlID.schemaName = parts[1] + } + + if numParts >= 3{ + sqlID.databaseName = parts[2] + } + + if numParts == 4 { + sqlID.serverName = parts[3] + } + + return sqlID +} diff --git a/cek.go b/cek.go new file mode 100644 index 00000000..9933ba08 --- /dev/null +++ b/cek.go @@ -0,0 +1,29 @@ +package mssql + +type cekTable struct { + entries []cekTableEntry +} + +type encryptionKeyInfo struct { + encryptedKey []byte + databaseID int + cekID int + cekVersion int + cekMdVersion []byte + keyPath string + keyStoreName string + algorithmName string +} + +type cekTableEntry struct { + databaseID int + keyId int + keyVersion int + mdVersion []byte + valueCount int + cekValues []encryptionKeyInfo +} + +func newCekTable(size uint16) cekTable { + return cekTable{entries: make([]cekTableEntry, size)} +} \ No newline at end of file diff --git a/conn_str.go b/conn_str.go index d7d9e06a..66f06a28 100644 --- a/conn_str.go +++ b/conn_str.go @@ -39,11 +39,21 @@ type connectParams struct { packetSize uint16 fedAuthLibrary int fedAuthADALWorkflow byte + columnEncryption bool + keyStoreAuthentication KeyStoreAuthentication + keyStoreLocation string + keyStoreSecret string } // default packet size for TDS buffer const defaultPacketSize = 4096 +type KeyStoreAuthentication string + +const ( + PFXKeystoreAuth = "pfx" +) + func parseConnectParams(dsn string) (connectParams, error) { p := connectParams{ fedAuthLibrary: fedAuthLibraryReserved, @@ -169,6 +179,54 @@ func parseConnectParams(dsn string) (connectParams, error) { } else { p.trustServerCertificate = true } + + columnEncryption, ok := params["columnencryption"] + if ok { + if strings.EqualFold(columnEncryption, "true") { + p.columnEncryption = true + } else { + var err error + p.columnEncryption, err = strconv.ParseBool(columnEncryption) + if err != nil { + f := "invalid columnEncryption '%s': %s" + return p, fmt.Errorf(f, columnEncryption, err.Error()) + } + } + } else { + p.columnEncryption = false + } + + ksAuth, ok := params["keystoreauthentication"] + if ok { + var authMethod KeyStoreAuthentication + switch strings.ToLower(ksAuth) { + case "pfx": + authMethod = PFXKeystoreAuth + default: + return p, fmt.Errorf("invalid keystotreAuthentication '%s'", ksAuth) + } + p.keyStoreAuthentication = authMethod + } + + ksLocation, ok := params["keystorelocation"] + if ok { + if ksLocation == "" { + return p, fmt.Errorf("invalid keystore location provided: '%s'", ksLocation) + } + + _, err := os.Stat(ksLocation) + if err != nil { + return p, fmt.Errorf("unable to find keystore %s: %v", ksLocation, err) + } + + p.keyStoreLocation = ksLocation + } + + ksSecret, ok := params["keystoresecret"] + if ok { + p.keyStoreSecret = ksSecret + } + trust, ok := params["trustservercertificate"] if ok { var err error @@ -248,6 +306,23 @@ func (p connectParams) toUrl() *url.URL { if p.logFlags != 0 { q.Add("log", strconv.FormatUint(p.logFlags, 10)) } + + if p.columnEncryption { + q.Add("columnEncryption", "true") + } + + if p.keyStoreAuthentication != "" { + q.Add("keyStoreAuthentication", string(p.keyStoreAuthentication)) + } + + if p.keyStoreLocation != "" { + q.Add("keyStoreLocation", p.keyStoreLocation) + } + + if p.keyStoreSecret != "" { + q.Add("keyStoreSecret", p.keyStoreSecret) + } + res := url.URL{ Scheme: "sqlserver", Host: p.host, @@ -256,6 +331,7 @@ func (p connectParams) toUrl() *url.URL { if p.instance != "" { res.Path = p.instance } + if len(q) > 0 { res.RawQuery = q.Encode() } @@ -274,7 +350,7 @@ func splitConnectionString(dsn string) (res map[string]string) { if len(name) == 0 { continue } - var value string = "" + var value = "" if len(lst) > 1 { value = strings.TrimSpace(lst[1]) } diff --git a/conn_str_test.go b/conn_str_test.go index bb6e2682..959c06b0 100644 --- a/conn_str_test.go +++ b/conn_str_test.go @@ -2,6 +2,7 @@ package mssql import ( "bufio" + "github.com/stretchr/testify/assert" "io" "os" "reflect" @@ -186,10 +187,10 @@ func testConnParams(t testing.TB) connectParams { } if len(os.Getenv("HOST")) > 0 && len(os.Getenv("DATABASE")) > 0 { return connectParams{ - host: os.Getenv("HOST"), + host: os.Getenv("HOST"), instance: os.Getenv("INSTANCE"), database: os.Getenv("DATABASE"), - user: os.Getenv("SQLUSER"), + user: os.Getenv("SQLUSER"), password: os.Getenv("SQLPASSWORD"), logFlags: logFlags, } @@ -227,3 +228,16 @@ func TestConnParseRoundTripFixed(t *testing.T) { t.Fatal("Parameters do not match after roundtrip", params, rtParams) } } + +func TestConnParseAlwaysEncrypted(t *testing.T) { + connStr := "sqlserver://sa:sa@localhost/instance?database=master&columnEncryption=true&keyStoreAuthentication=pfx&keyStoreLocation=./resources/test/always-encrypted/ae-1.pfx&keyStoreSecret=password" + params, err := parseConnectParams(connStr) + if err != nil { + t.Fatal("Test URL is not valid", err) + } + + assert.True(t, params.columnEncryption) + assert.Equal(t, KeyStoreAuthentication(PFXKeystoreAuth), params.keyStoreAuthentication) + assert.Equal(t, "./resources/test/always-encrypted/ae-1.pfx", params.keyStoreLocation) + assert.Equal(t, "password", params.keyStoreSecret) +} diff --git a/go.mod b/go.mod index ebc02ab8..0d64287d 100644 --- a/go.mod +++ b/go.mod @@ -3,6 +3,10 @@ module github.com/denisenkom/go-mssqldb go 1.11 require ( + github.com/davecgh/go-spew v1.1.1 // indirect github.com/golang-sql/civil v0.0.0-20190719163853-cb61b32ac6fe + github.com/stretchr/testify v1.7.0 + github.com/swisscom/mssql-always-encrypted v0.1.0 golang.org/x/crypto v0.0.0-20190325154230-a5d413f7728c + golang.org/x/text v0.3.5 ) diff --git a/go.sum b/go.sum index 1887801b..5416d834 100644 --- a/go.sum +++ b/go.sum @@ -1,5 +1,23 @@ +github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/golang-sql/civil v0.0.0-20190719163853-cb61b32ac6fe h1:lXe2qZdvpiX5WZkZR4hgp4KJVfY3nMkvmwbVkpv1rVY= github.com/golang-sql/civil v0.0.0-20190719163853-cb61b32ac6fe/go.mod h1:8vg3r2VgvsThLBIFL93Qb5yWzgyZWhEmBwUJWevAkK0= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY= +github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/swisscom/mssql-always-encrypted v0.1.0 h1:bmYt1My3KgQsYkAJTDXkJt6b5wjRX3rSMrvyYHhK60Y= +github.com/swisscom/mssql-always-encrypted v0.1.0/go.mod h1:FlEWLI3+svdMFq2w7GVMvk7iVhwBEBi7E7llAHb4B20= golang.org/x/crypto v0.0.0-20190325154230-a5d413f7728c h1:Vj5n4GlwjmQteupaxJ9+0FNOmBrHfq7vN4btdGoDZgI= golang.org/x/crypto v0.0.0-20190325154230-a5d413f7728c/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/text v0.3.5 h1:i6eZZ+zk0SOf0xgBpEpPD18qWcJda6q1sxt3S0kzyUQ= +golang.org/x/text v0.3.5/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/resources/test/always-encrypted/ae-1.pfx b/resources/test/always-encrypted/ae-1.pfx new file mode 100644 index 0000000000000000000000000000000000000000..3e6edc01d68d81de17ec9861f129424dc9f45aea GIT binary patch literal 2568 zcmb7GdpK0<8eeN>&6t_oGic;~X){gM7)1@HX)xt}Ng;9>M2LnITR9fFyI_O_VL&X5 z5O3j$dP%#J79>D1ummC2LkO`JrG+qp7ymC6?lXkomS6-okIF25-QQl~BoI@AKwn@4 zT81~`N&lwN;F1ubDL8etiD>c@1|#+mLZHVt%5GdTmqf%zwTW(yTAgG@yVc6evKRKc zExxlco+*WU9Xr^^UT>L`vD|DyOYLG`dFDPVlXt1^ov8D0OU`cPuvdVAvTg{IPT zyzNtQt=YOD5dBHF-N(SYHr+ygY9>->z4MUM3y%E5!%by5qKkTavxwAwW;p%XU~Yzc z9Nza#8h5A1`}x_kp_89AM-mF+9|Ya<4g0f~nN~Rqv`KG%&3>ICNdFSho~-uPS(*zB zjEOwimDxR=3GyOGl_JZ~2?R2pk9*Z=Vrr)p5}doWnYVi}orNRulc{*r;L+DQ>^$8v-|GJ+}T-&cuS0T?L@vA zSL7cmpzv0UCGzi`@haMD=;-r~e%F|BMSMMJyarR5Q^K;6E*?!+jp=MzDGjw2@k$rp zX(k-6qu%QCwY)DV6iCqt*s=XlmVPVt&eB>6vPY9&_v+j%u-;03L(=dv&e4Aue)Z_1 zU1tjn?w_ld+f~pt3_%+&m?k|-8M8KKJ}RhIT3r*YR$AxTXR)|qGmmNmb{(p?OH^|; zbI#^(wTfkU?vELn=6ex64?XWGDHg;1`jX(kzT*$mG_GS;w+(Hs^>bl(b{#>uRJZA# zHe#WXR^Lc3^eEz7DKfL*R@)8STPS<_?&NJ{inhg-$!&d`0d-8kB z`yF9Vj`cR=>DaW;#aKr_WpFnA$zVr83i5SHT=Q$p~RU;hxJ@$zVE@uLt9m(_csilkG4 z%ZapdqZCWl^yf7ym^ewn-FgDPv{(;jaf`-i=~>YkqG?iX(Z$z4*(QW-W~fHEgHMxR=+|&s75kO0;9U zI$b75Jr<58AcBPXC(0TnrzK63S`zm?|7D}lWC^f6p?5tNnjmrof_$afkMPK22_AuB zU`(4$2KBWV1yAyr4gIdAtjB)k)^_`+O1pz>tHT3sHeJln-lXZAy2U(dTc4;+N=ajf z*W<9o=B~408h2Umb90t%ViR+<@47EqoYcE+qfockMro(#Icf`G9(=B);BBnPXvhQ) z)QRi6o=WbbSg`~bNdtwwKiq%QGk3UyADV4zl4Q6LDEsj`Q|d&BEKu5(A2Jh8R%L*h z)=LT>x-;*buNHBQKWGk!rxmufWY>1p)t_u!V#O^!8A#>a=N1oSJKGl~OLKEFA51T; zmmIhg>Yotc-}?KA@^cUV5ysE4vi{kc^!QPD{kQ9qNT%Z<2m}BC_WTbiM|V)M10sQN zfCl&jet<8)0E~cb01w!TGJb#$%JNWT2*9XpN6~g50L63^!6*tqZMLI6Fgp677TbVr za1fqI#e)!t7|DQi4;6R776?Lv@PRNO0+j$Xof_Z-I0I~SbOty;Ae!a~`uYPQD8oT5 zqtOLLqPB-n?1SRKzapps4u3_7LT7%cH6KQ>`QN_fYpcK}K?oK>=~Jl0eBTeSsL}}~ z?Y?ybqWo6}Gt^-R%5Q@aj7kX>17a}J`f;BYwAu#~t;W5EdNV#ku*A0{;5r-t)96%K z7msy;u;Lo*kmEkl5o#7kc>L(_NPj;y_OZw?HI6?&GAJO3=M(8qhgB3L4Ct@{Y)nTV z!>#V<)Zn}6{~HbgHNIaEhy$<)CNWvf>2(*)F2fFv_JU#WNTf?N?pM#^r6F5%( zI>I>KZ3iury6=6h?LzJ8#y9GV<^>H@;qC{Af1?1wDO5}<<&%1bvUXdw4`n`d%|XS@ z^2U_&s+QC%^M$^f(NT$Y%kj6e<#p@cyl~mIYm1RdVVwDCZ^}pOh^XT=CbpR;#SoR6 zDqp}(K(fs0)L+Z65FzI^GY!tTkS~#%<^sugK2GxkoyH^+weqg*Gpu0!(KKwNe`2Uv zs2oCR{C&N$VfIv=a(M*rR1fiO3Aw4&B&a#1W`km=?S*?YOc@K;PSwXX=@iN#R>Rez z1Fz<@Yy=gx|2+Asv(NGE*lx=?zv|CpR3DFe3>E~yGAGypRz=T9SsPM>k2j_jQwH{r^1q7)lwk120>U1PVM_$ zN6P@+br48198+TR1H(8!iE?TEH+JC5s-v^1u{*qnO|;YZ=mpcNn1gENJsHdDibM8$ zd6Snm8Qy3cvw0CYGuH8CXTw@-nR9YP|MZ&K09VA^;4!~$C12>WZ!u1^pZ<@mabeKI zdE)JGVIK2o+emZcNM&28V8@Lw&Lu;GRxU2tg{H*7-4ou0J)|u8@v;qx!eP|IdE3u0au1f^R{0$-Z!-Q2~F+7eciN#PLxw80UPgG@2L_R$s*LRt!x97+DpDH0a_tc|Qag}o0 Uw8D?` support column encryption when encrypted data require enclave computations. + */ + return []byte{0x01} +} + +var _ featureExt = &featureExtColumnEncryption{} + type loginHeader struct { Length uint32 TDSVersion uint32 @@ -474,7 +506,7 @@ func ucs22str(s []byte) (string, error) { } func manglePassword(password string) []byte { - var ucs2password []byte = str2ucs2(password) + var ucs2password = str2ucs2(password) for i, ch := range ucs2password { ucs2password[i] = ((ch<<4)&0xff | (ch >> 4)) ^ 0xA5 } @@ -947,7 +979,7 @@ func interpretPreloginResponse(p connectParams, fe *featureExtFedAuth, fields ma // We need to be able to echo the value back to the server fe.FedAuthEcho = fedAuthSupport[0] != 0 } else if fe.FedAuthLibrary != fedAuthLibraryReserved { - return 0, fmt.Errorf("Federated authentication is not supported by the server") + return 0, fmt.Errorf("federated authentication is not supported by the server") } encryptBytes, ok := fields[preloginENCRYPTION] @@ -973,6 +1005,12 @@ func prepareLogin(ctx context.Context, c *Connector, p connectParams, log option AppName: p.appname, TypeFlags: p.typeFlags, } + + if p.columnEncryption { + // Support Always Encrypted + _ = l.FeatureExt.Add(&featureExtColumnEncryption{}) + } + switch { case fe.FedAuthLibrary == fedAuthLibrarySecurityToken: if p.logFlags&logDebug != 0 { @@ -987,14 +1025,14 @@ func prepareLogin(ctx context.Context, c *Connector, p connectParams, log option return nil, err } - l.FeatureExt.Add(fe) + _ = l.FeatureExt.Add(fe) case fe.FedAuthLibrary == fedAuthLibraryADAL: if p.logFlags&logDebug != 0 { log.Println("Starting federated authentication using ADAL") } - l.FeatureExt.Add(fe) + _ = l.FeatureExt.Add(fe) case auth != nil: if p.logFlags&logDebug != 0 { @@ -1203,6 +1241,21 @@ initiate_connection: case loginAckStruct: sess.loginAck = token loginAck = true + case featureExtAck: + for _, v := range token { + switch v:= v.(type) { + case colAckStruct: + if v.Version <= 2 && v.Version > 0 { + sess.alwaysEncrypted = true + sess.alwaysEncryptedSettings = &aeSettings{ + ksSecret: p.keyStoreSecret, + ksLocation: p.keyStoreLocation, + ksAuth: p.keyStoreAuthentication, + } + } + } + } + case doneStruct: if token.isError() { return nil, fmt.Errorf("login error: %s", token.getError()) diff --git a/token.go b/token.go index c9d45256..6650b8cf 100644 --- a/token.go +++ b/token.go @@ -1,12 +1,22 @@ package mssql import ( + "bytes" "context" + "crypto/rsa" + "crypto/sha1" "encoding/binary" "errors" "fmt" + alwaysencrypted "github.com/swisscom/mssql-always-encrypted/pkg" + "github.com/swisscom/mssql-always-encrypted/pkg/algorithms" + "github.com/swisscom/mssql-always-encrypted/pkg/encryption" + "github.com/swisscom/mssql-always-encrypted/pkg/keys" + "golang.org/x/crypto/pkcs12" + "golang.org/x/text/encoding/unicode" "io" "io/ioutil" + "os" "strconv" ) @@ -75,6 +85,10 @@ const ( fedAuthInfoSPN = 0x02 ) +const ( + cipherAlgCustom = 0x00 +) + // COLMETADATA flags // https://msdn.microsoft.com/en-us/library/dd357363.aspx const ( @@ -82,6 +96,9 @@ const ( // TODO implement more flags ) +// UTF-16 Decoder +var utf16Decoder = unicode.UTF16(unicode.LittleEndian, unicode.IgnoreBOM).NewDecoder() + // interface for all tokens type tokenStruct interface{} @@ -467,7 +484,7 @@ func parseFedAuthInfo(r *tdsBuffer) fedAuthInfoStruct { case fedAuthInfoSPN: SPN, err = ucs22str(optData) default: - err = fmt.Errorf("Unexpected fed auth info opt ID %d", int(opts[i].fedAuthInfoID)) + err = fmt.Errorf("unexpected fed auth info opt ID %d", int(opts[i].fedAuthInfoID)) } if err != nil { @@ -510,7 +527,13 @@ type fedAuthAckStruct struct { Signature []byte } -func parseFeatureExtAck(r *tdsBuffer) map[byte]interface{} { +type colAckStruct struct { + Version int +} + +type featureExtAck map[byte]interface{} + +func parseFeatureExtAck(r *tdsBuffer) featureExtAck { ack := map[byte]interface{}{} for feature := r.byte(); feature != featExtTERMINATOR; feature = r.byte() { @@ -532,7 +555,18 @@ func parseFeatureExtAck(r *tdsBuffer) map[byte]interface{} { length -= 32 } ack[feature] = fedAuthAck - + case featExtCOLUMNENCRYPTION: + colAck := colAckStruct{} + colAck.Version = int(r.byte()) + length-- + + if length > 0 { + enclaveLength := r.byte() + var enclaveType = make([]byte, enclaveLength) + r.ReadFull(enclaveType) + length -= uint32(enclaveLength) + } + ack[feature] = colAck } // Skip unprocessed bytes @@ -545,29 +579,302 @@ func parseFeatureExtAck(r *tdsBuffer) map[byte]interface{} { } // http://msdn.microsoft.com/en-us/library/dd357363.aspx -func parseColMetadata72(r *tdsBuffer) (columns []columnStruct) { +func parseColMetadata72(r *tdsBuffer, s *tdsSession) (columns []columnStruct) { count := r.uint16() if count == 0xffff { // no metadata is sent return nil } columns = make([]columnStruct, count) + + var cekTable *cekTable + if s.alwaysEncrypted { + // CEK table + cekTable = readCEKTable(r) + + if s.alwaysEncryptedSettings == nil { + panic("alwaysEncryptedSettings are nil!") + } + + if s.alwaysEncryptedSettings.pKey == nil { + // Load Keystore + f, err := os.Open(s.alwaysEncryptedSettings.ksLocation) + if err != nil { + panic(err) + } + + switch s.alwaysEncryptedSettings.ksAuth { + case PFXKeystoreAuth: + pfxBytes, err := ioutil.ReadAll(f) + if err != nil { + panic(err) + } + + pk, cert, err := pkcs12.Decode(pfxBytes, s.alwaysEncryptedSettings.ksSecret) + if err != nil { + panic(err) + } + + s.alwaysEncryptedSettings.pKey = pk + s.alwaysEncryptedSettings.cert = cert + default: + panic(fmt.Sprintf("ksAuth %v is unimplemented", s.alwaysEncryptedSettings.ksAuth)) + } + } + } + + dec := unicode.UTF16(unicode.LittleEndian, unicode.IgnoreBOM).NewDecoder() + for i := range columns { column := &columns[i] - column.UserType = r.uint32() - column.Flags = r.uint16() + baseTi := getBaseTypeInfo(r, true) + typeInfo := readTypeInfo(r, baseTi.TypeId) + typeInfo.UserType = baseTi.UserType + typeInfo.Flags = baseTi.Flags + typeInfo.TypeId = baseTi.TypeId + + // Table Name + if baseTi.TypeId == typeText || baseTi.TypeId == typeNText || baseTi.TypeId == typeImage { + _ = r.sqlIdentifier() + } + + column.Flags = baseTi.Flags + column.UserType = baseTi.UserType + column.ti = typeInfo + + if column.isEncrypted() && s.alwaysEncrypted { + // Read Crypto Metadata + cryptoMeta := parseCryptoMetadata(r, cekTable) + cryptoMeta.typeInfo.Flags = baseTi.Flags + column.cryptoMeta = cryptoMeta + } - // parsing TYPE_INFO structure - column.ti = readTypeInfo(r) - column.ColName = r.BVarChar() + colNameLen := r.byte() + colNameUtf16 := make([]byte, int(colNameLen)*2) + r.ReadFull(colNameUtf16) + colName, _ := dec.Bytes(colNameUtf16) + column.ColName = string(colName) } return columns } +func getBaseTypeInfo(r *tdsBuffer, parseFlags bool) typeInfo { + userType := r.uint32() + flags := uint16(0) + if parseFlags { + flags = r.uint16() + } + tId := r.byte() + + return typeInfo{ + UserType: userType, + Flags: flags, + TypeId: tId} +} + +type cryptoMetadata struct { + entry *cekTableEntry + ordinal uint16 + algorithmId byte + algorithmName *string + encType byte + normRuleVer byte + typeInfo typeInfo +} + +func parseCryptoMetadata(r *tdsBuffer, cekTable *cekTable) cryptoMetadata { + ordinal := uint16(0) + if cekTable != nil { + ordinal = r.uint16() + } + + typeInfo := getBaseTypeInfo(r, false) + ti := readTypeInfo(r, typeInfo.TypeId) + ti.UserType = typeInfo.UserType + ti.Flags = typeInfo.Flags + ti.TypeId = typeInfo.TypeId + + algorithmId := r.byte() + var algName *string = nil + + if algorithmId == cipherAlgCustom { + // Read the name when a custom algorithm is used + nameLen := int(r.byte()) + var algNameUtf16 = make([]byte, nameLen*2) + r.ReadFull(algNameUtf16) + algNameBytes, _ := unicode.UTF16(unicode.LittleEndian, unicode.IgnoreBOM).NewDecoder().Bytes(algNameUtf16) + mAlgName := string(algNameBytes) + algName = &mAlgName + } + + encType := r.byte() + normRuleVer := r.byte() + + var entry *cekTableEntry = nil + + if cekTable != nil { + if int(ordinal) > len(cekTable.entries)-1 { + panic(fmt.Errorf("invalid ordinal, cekTable only has %d entries", len(cekTable.entries))) + } + entry = &cekTable.entries[ordinal] + } + + return cryptoMetadata{ + entry: entry, + ordinal: ordinal, + algorithmId: algorithmId, + algorithmName: algName, + encType: encType, + normRuleVer: normRuleVer, + typeInfo: ti, + } +} + +func readCEKTable(r *tdsBuffer) *cekTable { + tableSize := r.uint16() + var cekTable *cekTable = nil + + if tableSize != 0 { + mCekTable := newCekTable(tableSize) + for i := uint16(0); i < tableSize; i++ { + mCekTable.entries[i] = readCekTableEntry(r) + } + cekTable = &mCekTable + } + + return cekTable +} + +func readCekTableEntry(r *tdsBuffer) cekTableEntry { + databaseId := r.int32() + cekID := r.int32() + cekVersion := r.int32() + var cekMdVersion = make([]byte, 8) + _, err := r.Read(cekMdVersion) + if err != nil { + panic("unable to read cekMdVersion") + } + + cekValueCount := uint(r.byte()) + enc := unicode.UTF16(unicode.LittleEndian, unicode.IgnoreBOM) + utf16dec := enc.NewDecoder() + cekValues := make([]encryptionKeyInfo, cekValueCount) + + for i := uint(0); i < cekValueCount; i++ { + encryptedCekLength := r.uint16() + encryptedCek := make([]byte, encryptedCekLength) + r.ReadFull(encryptedCek) + + keyStoreLength := r.byte() + keyStoreNameUtf16 := make([]byte, keyStoreLength*2) + r.ReadFull(keyStoreNameUtf16) + keyStoreName, _ := utf16dec.Bytes(keyStoreNameUtf16) + + keyPathLength := r.uint16() + keyPathUtf16 := make([]byte, keyPathLength*2) + r.ReadFull(keyPathUtf16) + keyPath, _ := utf16dec.Bytes(keyPathUtf16) + + algLength := r.byte() + algNameUtf16 := make([]byte, algLength*2) + r.ReadFull(algNameUtf16) + algName, _ := utf16dec.Bytes(algNameUtf16) + + cekValues[i] = encryptionKeyInfo{ + encryptedKey: encryptedCek, + databaseID: int(databaseId), + cekID: int(cekID), + cekVersion: int(cekVersion), + cekMdVersion: cekMdVersion, + keyPath: string(keyPath), + keyStoreName: string(keyStoreName), + algorithmName: string(algName), + } + } + + return cekTableEntry{ + databaseID: int(databaseId), + keyId: int(cekID), + keyVersion: int(cekVersion), + mdVersion: cekMdVersion, + valueCount: int(cekValueCount), + cekValues: cekValues, + } +} + +type RWCBuffer struct { + buffer *bytes.Reader +} + +func (R RWCBuffer) Read(p []byte) (n int, err error) { + return R.buffer.Read(p) +} + +func (R RWCBuffer) Write(p []byte) (n int, err error) { + return 0, nil +} + +func (R RWCBuffer) Close() error { + return nil +} + +var _ io.ReadWriteCloser = RWCBuffer{} + // http://msdn.microsoft.com/en-us/library/dd357254.aspx -func parseRow(r *tdsBuffer, columns []columnStruct, row []interface{}) { +func parseRow(r *tdsBuffer, s *tdsSession, columns []columnStruct, row []interface{}) { for i, column := range columns { - row[i] = column.ti.Reader(&column.ti, r) + columnContent := column.ti.Reader(&column.ti, r) + if column.isEncrypted() && s.alwaysEncrypted { + // Decrypt + cekValue := column.cryptoMeta.entry.cekValues[column.cryptoMeta.ordinal] + algVer := cekValue.cekVersion + encType := encryption.From(column.cryptoMeta.encType) + + // Get pKey + if s.alwaysEncryptedSettings.pKey == nil { + panic("alwaysEncrypted pKey not set: this should never happen") + } + + cekv := alwaysencrypted.LoadCEKV(column.cryptoMeta.entry.cekValues[0].encryptedKey) + if !cekv.Verify(s.alwaysEncryptedSettings.cert) { + panic(fmt.Errorf("invalid certificate being used to decrypt: %v requested but %v provided", + cekv.KeyPath, + fmt.Sprintf("%02x", sha1.Sum(s.alwaysEncryptedSettings.cert.Raw)), + )) + } + + // TODO: Support other private keys + rootKey, err := cekv.Decrypt(s.alwaysEncryptedSettings.pKey.(*rsa.PrivateKey)) + if err != nil { + panic(err) + } + + // Derive Root Key from encryptedKey + k := keys.NewAeadAes256CbcHmac256(rootKey) + alg := algorithms.NewAeadAes256CbcHmac256Algorithm(k, encType, byte(algVer)) + + d, err := alg.Decrypt(columnContent.([]byte)) + if err != nil { + panic(err) + } + + // Dirty workaround to keep compatibility with original types + // TODO: Improve me + var newBuff = make([]byte, 2) + binary.LittleEndian.PutUint16(newBuff, uint16(len(d))) + newBuff = append(newBuff, d...) + + rwc := RWCBuffer{ + buffer: bytes.NewReader(newBuff), + } + + column.cryptoMeta.typeInfo.Buffer = d + buffer := tdsBuffer{rpos: 0, rsize: len(newBuff), rbuf: newBuff, transport: rwc} + + row[i] = column.cryptoMeta.typeInfo.Reader(&column.cryptoMeta.typeInfo, &buffer) + } else { + row[i] = columnContent + } } } @@ -628,10 +935,10 @@ func parseReturnValue(r *tdsBuffer) (nv namedValue) { r.uint16() nv.Name = r.BVarChar() r.byte() - r.uint32() // UserType (uint16 prior to 7.2) - r.uint16() - ti := readTypeInfo(r) - nv.Value = ti.Reader(&ti, r) + + ti := getBaseTypeInfo(r, true) + ti2 := readTypeInfo(r, ti.TypeId) + nv.Value = ti2.Reader(&ti2, r) return } @@ -707,11 +1014,11 @@ func processSingleResponse(sess *tdsSession, ch chan tokenStruct, outs map[strin return } case tokenColMetadata: - columns = parseColMetadata72(sess.buf) + columns = parseColMetadata72(sess.buf, sess) ch <- columns case tokenRow: row := make([]interface{}, len(columns)) - parseRow(sess.buf, columns, row) + parseRow(sess.buf, sess, columns, row) ch <- row case tokenNbcRow: row := make([]interface{}, len(columns)) diff --git a/types.go b/types.go index cae19924..6cdc7be2 100644 --- a/types.go +++ b/types.go @@ -83,6 +83,8 @@ const _TVP_ROW_TOKEN = 0x01 // http://msdn.microsoft.com/en-us/library/dd358284.aspx type typeInfo struct { TypeId uint8 + UserType uint32 + Flags uint16 Size int Scale uint8 Prec uint8 @@ -113,9 +115,9 @@ type xmlInfo struct { XmlSchemaCollection string } -func readTypeInfo(r *tdsBuffer) (res typeInfo) { - res.TypeId = r.byte() - switch res.TypeId { +func readTypeInfo(r *tdsBuffer, typeId byte) (res typeInfo) { + res.TypeId = typeId + switch typeId { case typeNull, typeInt1, typeBit, typeInt2, typeInt4, typeDateTim4, typeFlt4, typeMoney, typeDateTime, typeFlt8, typeMoney4, typeInt8: // those are fixed length types