Skip to content

Commit

Permalink
Rename pool misnomer.
Browse files Browse the repository at this point in the history
  • Loading branch information
pascaldekloe committed Sep 28, 2018
1 parent dfd9a30 commit 244172a
Show file tree
Hide file tree
Showing 5 changed files with 58 additions and 58 deletions.
6 changes: 3 additions & 3 deletions examples_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ func ExampleHandler_deny() {
}

// PEM with password protection.
func ExampleKeyPool_LoadPEM_encrypted() {
func ExampleKeyRegister_LoadPEM_encrypted() {
const key = `-----BEGIN RSA PRIVATE KEY-----
Proc-Type: 4,ENCRYPTED
DEK-Info: AES-128-CBC,65789712555A3E9FECD1D5E235B97B0C
Expand All @@ -159,8 +159,8 @@ xzvC4Vm1r/Oa4TTUbf5tVto7ua/lZvwnu5DIWn2zy5ZUPrtn22r1ymVui7Iuhl0b
SRcADdHh3NgrjDjalhLDB95ho5omG39l7qBKBTlBAYJhDuAk9rIk1FCfCB8upztt
-----END RSA PRIVATE KEY-----`

var p jwt.KeyPool
n, err := p.LoadPEM([]byte(key), []byte("dangerzone"))
var r jwt.KeyRegister
n, err := r.LoadPEM([]byte(key), []byte("dangerzone"))
if err != nil {
fmt.Println("load error:", err)
}
Expand Down
28 changes: 14 additions & 14 deletions pool.go → register.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,21 +9,21 @@ import (
"fmt"
)

// KeyPool contains a set of recognized credentials.
type KeyPool struct {
// KeyRegister contains a set of recognized credentials.
type KeyRegister struct {
ECDSAs []*ecdsa.PublicKey // ECDSA credentials
RSAs []*rsa.PublicKey // RSA credentials
Secrets [][]byte // HMAC credentials
}

// Check parses a JWT and returns the claims set if, and only if, the signature
// checks out. Note that this excludes unsecured JWTs [ErrUnsecured].
// See Valid to complete the verification.
func (p *KeyPool) Check(token []byte) (*Claims, error) {
// See Claims.Valid to complete the verification.
func (r *KeyRegister) Check(token []byte) (*Claims, error) {
err := ErrAlgUnk
var c *Claims

for _, secret := range p.Secrets {
for _, secret := range r.Secrets {
c, err = HMACCheck(token, secret)
if err == nil {
return c, nil
Expand All @@ -39,7 +39,7 @@ func (p *KeyPool) Check(token []byte) (*Claims, error) {
return nil, err
}

for _, key := range p.RSAs {
for _, key := range r.RSAs {
c, err = RSACheck(token, key)
if err == nil {
return c, nil
Expand All @@ -55,7 +55,7 @@ func (p *KeyPool) Check(token []byte) (*Claims, error) {
return nil, err
}

for _, key := range p.ECDSAs {
for _, key := range r.ECDSAs {
c, err = ECDSACheck(token, key)
if err == nil {
return c, nil
Expand All @@ -72,9 +72,9 @@ func (p *KeyPool) Check(token []byte) (*Claims, error) {

var errUnencryptedPEM = errors.New("jwt: unencrypted PEM rejected due password expectation")

// LoadPEM adds the keys from PEM-encoded data to the pool and returns the
// count. PEM encryption is enforced for non-empty password values.
func (p *KeyPool) LoadPEM(data, password []byte) (n int, err error) {
// LoadPEM adds keys from PEM-encoded data and returns the count.
// PEM encryption is enforced for non-empty password values.
func (r *KeyRegister) LoadPEM(data, password []byte) (n int, err error) {
for {
block, remainder := pem.Decode(data)
if block == nil {
Expand All @@ -99,9 +99,9 @@ func (p *KeyPool) LoadPEM(data, password []byte) (n int, err error) {
}
switch t := key.(type) {
case *ecdsa.PublicKey:
p.ECDSAs = append(p.ECDSAs, t)
r.ECDSAs = append(r.ECDSAs, t)
case *rsa.PublicKey:
p.RSAs = append(p.RSAs, t)
r.RSAs = append(r.RSAs, t)
default:
return n, fmt.Errorf("jwt: unsupported key type %T", t)
}
Expand All @@ -111,14 +111,14 @@ func (p *KeyPool) LoadPEM(data, password []byte) (n int, err error) {
if err != nil {
return n, err
}
p.ECDSAs = append(p.ECDSAs, &key.PublicKey)
r.ECDSAs = append(r.ECDSAs, &key.PublicKey)

case "RSA PRIVATE KEY":
key, err := x509.ParsePKCS1PrivateKey(block.Bytes)
if err != nil {
return n, err
}
p.RSAs = append(p.RSAs, &key.PublicKey)
r.RSAs = append(r.RSAs, &key.PublicKey)

default:
return n, fmt.Errorf("jwt: unknown PEM type %q", block.Type)
Expand Down
62 changes: 31 additions & 31 deletions pool_test.go → register_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import (
)

// Tests the golden cases.
func TestKeyPool(t *testing.T) {
func TestKeyRegister(t *testing.T) {
const fatPEM = `All samples from test_keys.go combined here:
-----BEGIN EC PRIVATE KEY-----
Expand Down Expand Up @@ -130,8 +130,8 @@ EeRpjDtIq46JS/EMcvoetl0Ch8l2tGLC1fpOD4kQsd9TSaTMO3MSy/5WIGg=
`

var p KeyPool
n, err := p.LoadPEM([]byte(fatPEM), nil)
var r KeyRegister
n, err := r.LoadPEM([]byte(fatPEM), nil)
if err != nil {
t.Fatal(err)
}
Expand All @@ -141,11 +141,11 @@ EeRpjDtIq46JS/EMcvoetl0Ch8l2tGLC1fpOD4kQsd9TSaTMO3MSy/5WIGg=

// add the HMAC keys
for _, gold := range goldenHMACs {
p.Secrets = append(p.Secrets, gold.secret)
r.Secrets = append(r.Secrets, gold.secret)
}

for i, gold := range goldenHMACs {
claims, err := p.Check([]byte(gold.token))
claims, err := r.Check([]byte(gold.token))
if err != nil {
t.Errorf("HMAC %d: check error: %s", i, err)
continue
Expand All @@ -156,7 +156,7 @@ EeRpjDtIq46JS/EMcvoetl0Ch8l2tGLC1fpOD4kQsd9TSaTMO3MSy/5WIGg=
}

for i, gold := range goldenECDSAs {
claims, err := p.Check([]byte(gold.token))
claims, err := r.Check([]byte(gold.token))
if err != nil {
t.Errorf("ECDSA %d: check error: %s", i, err)
continue
Expand All @@ -167,7 +167,7 @@ EeRpjDtIq46JS/EMcvoetl0Ch8l2tGLC1fpOD4kQsd9TSaTMO3MSy/5WIGg=
}

for i, gold := range goldenRSAs {
claims, err := p.Check([]byte(gold.token))
claims, err := r.Check([]byte(gold.token))
if err != nil {
t.Errorf("RSA %d: check error: %s", i, err)
continue
Expand All @@ -179,7 +179,7 @@ EeRpjDtIq46JS/EMcvoetl0Ch8l2tGLC1fpOD4kQsd9TSaTMO3MSy/5WIGg=
}

// Includes unsupported key.
func TestKeyPoolLoadPublicKeys(t *testing.T) {
func TestKeyRegisterLoadPublicKeys(t *testing.T) {
const keys = `Tree Public Keys
RSA:
-----BEGIN PUBLIC KEY-----
Expand Down Expand Up @@ -210,24 +210,24 @@ qsa4IOtmJV3zuw==
-----END PUBLIC KEY-----
`

var p KeyPool
n, err := p.LoadPEM([]byte(keys), nil)
var r KeyRegister
n, err := r.LoadPEM([]byte(keys), nil)
if n != 2 {
t.Errorf("loaded %d keys, want 2", n)
}
if want := "jwt: unsupported key type *dsa.PublicKey"; err == nil || err.Error() != want {
t.Errorf("got error %q, want %q", err, want)
}
if len(p.ECDSAs) != 1 {
t.Errorf("got %d ECDSA keys, want 1", len(p.ECDSAs))
if len(r.ECDSAs) != 1 {
t.Errorf("got %d ECDSA keys, want 1", len(r.ECDSAs))
}
if len(p.RSAs) != 1 {
t.Errorf("got %d RSA keys, want 1", len(p.RSAs))
if len(r.RSAs) != 1 {
t.Errorf("got %d RSA keys, want 1", len(r.RSAs))
}
}

func TestKeyPoolLoadUnkownType(t *testing.T) {
n, err := new(KeyPool).LoadPEM([]byte(`
func TestKeyRegisterLoadUnkownType(t *testing.T) {
n, err := new(KeyRegister).LoadPEM([]byte(`
-----BEGIN SPECIAL KEY-----
BLACKTi000000000000000000000000000000000000000000000000000000000
-----END SPECIAL KEY-----
Expand All @@ -240,8 +240,8 @@ BLACKTi000000000000000000000000000000000000000000000000000000000
}
}

func TestKeyPoolLoadPassNotNeeded(t *testing.T) {
n, err := new(KeyPool).LoadPEM([]byte(`
func TestKeyRegisterLoadPassNotNeeded(t *testing.T) {
n, err := new(KeyRegister).LoadPEM([]byte(`
-----BEGIN PUBLIC KEY-----
MFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAEX0iTLAcGqlWeGIRtIk0G2PRgpf/6
gLxOTyMAdriP4NLRkuu+9Idty3qmEizRC0N81j84E213/LuqLqnsrgfyiw==
Expand All @@ -254,7 +254,7 @@ gLxOTyMAdriP4NLRkuu+9Idty3qmEizRC0N81j84E213/LuqLqnsrgfyiw==
}
}

func TestKeyPoolLoadPassMiss(t *testing.T) {
func TestKeyRegisterLoadPassMiss(t *testing.T) {
const encryptedPEM = `-----BEGIN RSA PRIVATE KEY-----
Proc-Type: 4,ENCRYPTED
DEK-Info: AES-128-CBC,65789712555A3E9FECD1D5E235B97B0C
Expand All @@ -274,7 +274,7 @@ xzvC4Vm1r/Oa4TTUbf5tVto7ua/lZvwnu5DIWn2zy5ZUPrtn22r1ymVui7Iuhl0b
SRcADdHh3NgrjDjalhLDB95ho5omG39l7qBKBTlBAYJhDuAk9rIk1FCfCB8upztt
-----END RSA PRIVATE KEY-----`

n, err := new(KeyPool).LoadPEM([]byte(encryptedPEM), nil)
n, err := new(KeyRegister).LoadPEM([]byte(encryptedPEM), nil)
if n != 0 {
t.Errorf("loaded %d keys, want 0", n)
}
Expand All @@ -283,7 +283,7 @@ SRcADdHh3NgrjDjalhLDB95ho5omG39l7qBKBTlBAYJhDuAk9rIk1FCfCB8upztt
}
}

func TestKeyPoolLoadBroken(t *testing.T) {
func TestKeyRegisterLoadBroken(t *testing.T) {
pems := []string{`
-----BEGIN EC PRIVATE KEY-----
SRcADdHh3NgrjDjalhLDB95ho5omG39l7qBKBTlBAYJhDuAk9rIk1FCfCB8upztt
Expand All @@ -297,14 +297,14 @@ SRcADdHh3NgrjDjalhLDB95ho5omG39l7qBKBTlBAYJhDuAk9rIk1FCfCB8upztt
`}

for _, pem := range pems {
n, err := new(KeyPool).LoadPEM([]byte(pem), nil)
n, err := new(KeyRegister).LoadPEM([]byte(pem), nil)
if n != 0 || err == nil {
t.Errorf("loaded %d keys with error %v", n, err)
}
}
}

func TestKeyPoolCheckMiss(t *testing.T) {
func TestKeyRegisterCheckMiss(t *testing.T) {
const pem = `Unrelated Keys
ECDSA:
-----BEGIN EC PRIVATE KEY-----
Expand All @@ -322,39 +322,39 @@ P9j/1Whc92wzd4Osod3U6Tw9g+C1LuHuHOoLJhj5nUQQcP8UQk6jzKPwr4L4uKAc
-----END PUBLIC KEY-----
`

var p KeyPool
n, err := p.LoadPEM([]byte(pem), nil)
var r KeyRegister
n, err := r.LoadPEM([]byte(pem), nil)
if err != nil {
t.Fatal(err)
}
if n != 2 {
t.Fatalf("got %d keys, want 2", n)
}

p.Secrets = append(p.Secrets, []byte{1, 2})
r.Secrets = append(r.Secrets, []byte{1, 2})

// check unsupported algorithm
const encryptedToken = "eyJhbGciOiJSU0ExXzUiLCJlbmMiOiJBMTI4Q0JDLUhTMjU2In0.QR1Owv2ug2WyPBnbQrRARTeEk9kDO2w8qDcjiHnSJflSdv1iNqhWXaKH4MqAkQtMoNfABIPJaZm0HaA415sv3aeuBWnD8J-Ui7Ah6cWafs3ZwwFKDFUUsWHSK-IPKxLGTkND09XyjORj_CHAgOPJ-Sd8ONQRnJvWn_hXV1BNMHzUjPyYwEsRhDhzjAD26imasOTsgruobpYGoQcXUwFDn7moXPRfDE8-NoQX7N7ZYMmpUDkR-Cx9obNGwJQ3nM52YCitxoQVPzjbl7WBuB7AohdBoZOdZ24WlN1lVIeh8v1K4krB8xgKvRU8kgFrEn_a1rZgN5TiysnmzTROF869lQ.AxY8DCtDaGlsbGljb3RoZQ.MKOle7UQrG6nSxTLX6Mqwt0orbHvAKeWnDYvpIAeZ72deHxz3roJDXQyhxx0wKaMHDjUEOKIwrtkHthpqEanSBNYHZgmNOV7sln1Eu9g3J8.fiK51VwhsxJ-siBMR-YFiA"
_, err = p.Check([]byte(encryptedToken))
_, err = r.Check([]byte(encryptedToken))
if err != ErrAlgUnk {
t.Errorf("encrypted token got error %q, want %q", err, ErrAlgUnk)
}

// check golden cases
for i, gold := range goldenHMACs {
_, err := p.Check([]byte(gold.token))
_, err := r.Check([]byte(gold.token))
if err != ErrSigMiss {
t.Errorf("HMAC %d: got error %q, want %q", i, err, ErrSigMiss)
}
}
for i, gold := range goldenECDSAs {
_, err := p.Check([]byte(gold.token))
_, err := r.Check([]byte(gold.token))
if err != ErrSigMiss {
t.Errorf("ECDSA %d: got error %q, want %q", i, err, ErrSigMiss)
}
}
for i, gold := range goldenRSAs {
_, err := p.Check([]byte(gold.token))
_, err := r.Check([]byte(gold.token))
if err != ErrSigMiss {
t.Errorf("RSA %d: got error %q, want %q", i, err, ErrSigMiss)
}
Expand All @@ -368,7 +368,7 @@ P9j/1Whc92wzd4Osod3U6Tw9g+C1LuHuHOoLJhj5nUQQcP8UQk6jzKPwr4L4uKAc
defer delete(ECDSAAlgs, "EM4")
defer delete(RSAAlgs, "RM4")
for _, header := range []string{"eyJhbGciOiJFTTQifQ", "eyJhbGciOiJITTQifQ", "eyJhbGciOiJSTTQifQ"} {
_, err := p.Check([]byte(header + ".e30."))
_, err := r.Check([]byte(header + ".e30."))
if err != errHashLink {
t.Errorf("header %s got error %q, want %q", header, err, errHashLink)
}
Expand Down
14 changes: 7 additions & 7 deletions web.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,14 +51,14 @@ func RSACheckHeader(r *http.Request, key *rsa.PublicKey) (*Claims, error) {
return RSACheck(token, key)
}

// CheckHeader applies KeyPool.Check on a HTTP request.
// CheckHeader applies KeyRegister.Check on a HTTP request.
// Specifically it looks for a bearer token in the Authorization header.
func (p *KeyPool) CheckHeader(r *http.Request) (*Claims, error) {
func (reg *KeyRegister) CheckHeader(r *http.Request) (*Claims, error) {
token, err := tokenFromHeader(r)
if err != nil {
return nil, err
}
return p.Check(token)
return reg.Check(token)
}

func tokenFromHeader(r *http.Request) ([]byte, error) {
Expand Down Expand Up @@ -117,8 +117,8 @@ type Handler struct {
ECDSAKey *ecdsa.PublicKey
// RSAKey applies RSAAlgs and disables HMACAlgs when set.
RSAKey *rsa.PublicKey
// KeyPool disables Secret, ECDSAKey and RSAKey when set.
KeyPool *KeyPool
// KeyRegister disables Secret, ECDSAKey and RSAKey when set.
KeyRegister *KeyRegister

// HeaderBinding maps JWT claim names to HTTP header names.
// All requests passed to Target have these headers set. In
Expand All @@ -139,8 +139,8 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
// verify claims
var claims *Claims
var err error
if h.KeyPool != nil {
claims, err = h.KeyPool.CheckHeader(r)
if h.KeyRegister != nil {
claims, err = h.KeyRegister.CheckHeader(r)
} else if h.ECDSAKey == nil && h.RSAKey == nil {
claims, err = HMACCheckHeader(r, h.Secret)
} else {
Expand Down
6 changes: 3 additions & 3 deletions web_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,9 @@ func TestCheckHeaderPresent(t *testing.T) {
if err != ErrNoHeader {
t.Errorf("RSA check got %v, want %v", err, ErrNoHeader)
}
_, err = new(KeyPool).CheckHeader(req)
_, err = new(KeyRegister).CheckHeader(req)
if err != ErrNoHeader {
t.Errorf("KeyPool check got %v, want %v", err, ErrNoHeader)
t.Errorf("KeyRegister check got %v, want %v", err, ErrNoHeader)
}
}

Expand Down Expand Up @@ -79,7 +79,7 @@ func testUnauthorized(t *testing.T, reqHeader string) (body, header string) {
Target: http.HandlerFunc(func(http.ResponseWriter, *http.Request) {
t.Error("handler called")
}),
KeyPool: &KeyPool{
KeyRegister: &KeyRegister{
ECDSAs: []*ecdsa.PublicKey{&testKeyEC256.PublicKey},
},
HeaderBinding: map[string]string{
Expand Down

0 comments on commit 244172a

Please sign in to comment.