Skip to content

Commit

Permalink
Refactor for more correct error reporting on inclomplete tokens and b…
Browse files Browse the repository at this point in the history
…etter performance with multiple verification keys.
  • Loading branch information
pascaldekloe committed Dec 10, 2018
1 parent f2617b7 commit ded3fe6
Show file tree
Hide file tree
Showing 4 changed files with 164 additions and 169 deletions.
211 changes: 94 additions & 117 deletions check.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,146 +27,103 @@ var errPart = errors.New("jwt: missing base64 part")
// When the algorithm is not in ECDSAAlgs, then the error is ErrAlgUnk.
// See Valid to complete the verification.
func ECDSACheck(token []byte, key *ecdsa.PublicKey) (*Claims, error) {
firstDot, lastDot, buf, err := scan(token)
if err != nil {
return nil, err
}

var c Claims

// create signature
hash, err := c.parseHeader(ECDSAAlgs, token[:firstDot], buf)
if err != nil {
return nil, err
}
digest := hash.New()
digest.Write(token[:lastDot])

// verify signature
n, err := encoding.Decode(buf, token[lastDot+1:])
if err != nil {
return nil, errors.New("jwt: malformed signature: " + err.Error())
}
r := big.NewInt(0).SetBytes(buf[:n/2])
s := big.NewInt(0).SetBytes(buf[n/2 : n])
if !ecdsa.Verify(key, digest.Sum(buf[:0]), r, s) {
return nil, ErrSigMiss
}

return &c, c.parseClaims(token[firstDot+1:lastDot], buf)
return check(token, ECDSAAlgs, func(content, sig []byte, hash crypto.Hash) error {
r := big.NewInt(0).SetBytes(sig[:len(sig)/2])
s := big.NewInt(0).SetBytes(sig[len(sig)/2:])
digest := hash.New()
digest.Write(content)
if !ecdsa.Verify(key, digest.Sum(sig[:0]), r, s) {
return ErrSigMiss
}
return nil
})
}

// HMACCheck parses a JWT and returns the claims set if, and only if, the
// signature checks out. Note that this excludes unsecured JWTs [ErrUnsecured].
// When the algorithm is not in HMACAlgs, then the error is ErrAlgUnk.
// See Valid to complete the verification.
func HMACCheck(token, secret []byte) (*Claims, error) {
firstDot, lastDot, buf, err := scan(token)
if err != nil {
return nil, err
}

var c Claims

// create signature
hash, err := c.parseHeader(HMACAlgs, token[:firstDot], buf)
if err != nil {
return nil, err
}
digest := hmac.New(hash.New, secret)
digest.Write(token[:lastDot])

// verify signature
n, err := encoding.Decode(buf, token[lastDot+1:])
if err != nil {
return nil, errors.New("jwt: malformed signature: " + err.Error())
}
if !hmac.Equal(buf[:n], digest.Sum(buf[n:n])) {
return nil, ErrSigMiss
}

return &c, c.parseClaims(token[firstDot+1:lastDot], buf)
return check(token, HMACAlgs, func(content, sig []byte, hash crypto.Hash) error {
digest := hmac.New(hash.New, secret)
digest.Write(content)
if !hmac.Equal(sig, digest.Sum(sig[len(sig):])) {
return ErrSigMiss
}
return nil
})
}

// RSACheck parses a JWT and returns the claims set if, and only if, the
// signature checks out. Note that this excludes unsecured JWTs [ErrUnsecured].
// When the algorithm is not in RSAAlgs, then the error is ErrAlgUnk.
// See Valid to complete the verification.
func RSACheck(token []byte, key *rsa.PublicKey) (*Claims, error) {
firstDot, lastDot, buf, err := scan(token)
return check(token, RSAAlgs, func(content, sig []byte, hash crypto.Hash) error {
digest := hash.New()
digest.Write(content)
if err := rsa.VerifyPKCS1v15(key, hash, digest.Sum(sig[len(sig):]), sig); err != nil {
return ErrSigMiss
}
return nil
})
}

func check(token []byte, algs map[string]crypto.Hash, verifySig func(content, sig []byte, hash crypto.Hash) error) (*Claims, error) {
header, buf, err := parseHeader(token)
if err != nil {
return nil, err
}

var c Claims

// create signature
hash, err := c.parseHeader(RSAAlgs, token[:firstDot], buf)
hash, err := header.match(algs)
if err != nil {
return nil, err
}
digest := hash.New()
digest.Write(token[:lastDot])

// verify signature
n, err := encoding.Decode(buf, token[lastDot+1:])
claims, err := verifyAndParseClaims(token, buf, hash, verifySig)
if err != nil {
return nil, errors.New("jwt: malformed signature: " + err.Error())
}
if err := rsa.VerifyPKCS1v15(key, hash, digest.Sum(buf[n:n]), buf[:n]); err != nil {
return nil, ErrSigMiss
return nil, err
}

return &c, c.parseClaims(token[firstDot+1:lastDot], buf)
claims.KeyID = header.Kid
return claims, nil
}

// Scan detects the 3 base64 chunks and allocates matching buffer.
func scan(token []byte) (firstDot, lastDot int, buf []byte, err error) {
firstDot = bytes.IndexByte(token, '.')
lastDot = bytes.LastIndexByte(token, '.')
if lastDot <= firstDot {
// zero or one dot
return 0, 0, nil, errPart
}

// buffer must fit largest base64 chunk
// start with signature
max := len(token) - lastDot
// compare with payload
if l := lastDot - firstDot; l > max {
max = l
}
// compare with header
if firstDot > max {
max = firstDot
}
buf = make([]byte, encoding.DecodedLen(max))
return
// Header is a critical subset of the registered “JOSE Header Parameter Names”.
type header struct {
Alg string // algorithm
Kid string // key identifier
Crit []string // extensions which must be understood and processed
}

// ParseHeader decodes the enc(oded) “JOSE Header” and validates the applicability.
func (c *Claims) parseHeader(algs map[string]crypto.Hash, enc, buf []byte) (crypto.Hash, error) {
// parse critical subset of the registered “JOSE Header Parameter Names”
var header struct {
Alg string // algorithm
Kid string // key identifier
Crit []string // extensions which must be understood and processed.
// ParseHeader decodes the “JOSE Header” and allocates a matching buffer.
func parseHeader(token []byte) (h *header, buf []byte, err error) {
buf = make([]byte, encoding.DecodedLen(len(token)))

end := bytes.IndexByte(token, '.')
if end < 0 {
end = len(token)
}
n, err := encoding.Decode(buf, enc)
n, err := encoding.Decode(buf, token[:end])
if err != nil {
return 0, errors.New("jwt: malformed header: " + err.Error())
return nil, nil, errors.New("jwt: malformed header: " + err.Error())
}
if err := json.Unmarshal(buf[:n], &header); err != nil {
return 0, errors.New("jwt: malformed header: " + err.Error())

h = new(header)
if err := json.Unmarshal(buf[:n], h); err != nil {
return nil, nil, errors.New("jwt: malformed header: " + err.Error())
}
return
}

func (h *header) match(algs map[string]crypto.Hash) (crypto.Hash, error) {
// why would anyone do this?
if header.Alg == "none" {
if h.Alg == "none" {
return 0, ErrUnsecured
}

// availability check
hash, ok := algs[header.Alg]
hash, ok := algs[h.Alg]
if !ok {
return 0, ErrAlgUnk
}
Expand All @@ -177,33 +134,55 @@ func (c *Claims) parseHeader(algs map[string]crypto.Hash, enc, buf []byte) (cryp
// “If any of the listed extension Header Parameters are not understood
// and supported by the recipient, then the JWS is invalid.”
// — “JSON Web Signature (JWS)” RFC 7515, subsection 4.1.11
if len(header.Crit) != 0 {
return 0, fmt.Errorf("jwt: unsupported critical extension in JOSE header: %q", header.Crit)
if len(h.Crit) != 0 {
return 0, fmt.Errorf("jwt: unsupported critical extension in JOSE header: %q", h.Crit)
}

c.KeyID = header.Kid

return hash, nil
}

// ParseClaims unmarshals the payload from enc.
// Buf remains in use (by the Raw field)!
func (c *Claims) parseClaims(enc, buf []byte) error {
func verifyAndParseClaims(token, buf []byte, hash crypto.Hash, verifySig func(content, sig []byte, hash crypto.Hash) error) (*Claims, error) {
firstDot := bytes.IndexByte(token, '.')
lastDot := bytes.LastIndexByte(token, '.')
if lastDot <= firstDot {
// zero or one dot
return nil, errPart
}

// verify signature
n, err := encoding.Decode(buf, token[lastDot+1:])
if err != nil {
return nil, errors.New("jwt: malformed signature: " + err.Error())
}
err = verifySig(token[:lastDot], buf[:n], hash)
if err != nil {
return nil, err
}

// decode payload
n, err := encoding.Decode(buf, enc)
n, err = encoding.Decode(buf, token[firstDot+1:lastDot])
if err != nil {
return errors.New("jwt: malformed payload: " + err.Error())
return nil, errors.New("jwt: malformed payload: " + err.Error())
}
buf = buf[:n]
c.Raw = json.RawMessage(buf)

m := make(map[string]interface{})
c.Set = m
if err = json.Unmarshal(buf, &m); err != nil {
return errors.New("jwt: malformed payload: " + err.Error())
// construct result
c := &Claims{
Raw: json.RawMessage(buf),
Set: make(map[string]interface{}),
}
if err = json.Unmarshal(buf, &c.Set); err != nil {
return nil, errors.New("jwt: malformed payload: " + err.Error())
}

c.extractRegistered()
return c, nil
}

// map registered claims on type match
// move from Set to Registered on type match
func (c *Claims) extractRegistered() {
m := c.Set
if s, ok := m[issuer].(string); ok {
delete(m, issuer)
c.Issuer = s
Expand Down Expand Up @@ -255,6 +234,4 @@ func (c *Claims) parseClaims(enc, buf []byte) error {
delete(m, id)
c.ID = s
}

return nil
}
23 changes: 10 additions & 13 deletions check_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -132,13 +132,6 @@ func TestCheckMiss(t *testing.T) {
}
}

func TestErrUnsecured(t *testing.T) {
_, err := HMACCheck([]byte("eyJhbGciOiJub25lIn0.e30."), nil)
if err != ErrUnsecured {
t.Errorf("got error %v, want %v", err, ErrUnsecured)
}
}

func TestCheckAlgWrong(t *testing.T) {
_, err := ECDSACheck([]byte(goldenRSAs[0].token), nil)
if err != ErrAlgUnk {
Expand Down Expand Up @@ -180,17 +173,21 @@ func TestCheckIncomplete(t *testing.T) {
// header only
_, err := ECDSACheck([]byte("eyJhbGciOiJFUzI1NiJ9"), &testKeyEC256.PublicKey)
if err != errPart {
t.Errorf("one base64 chunk got error %v, want %v", err, errPart)
t.Errorf("header only got error %v, want %v", err, errPart)
}
_, err = RSACheck([]byte("eyJhbGciOiJub25lIn0"), &testKeyRSA1024.PublicKey)
if err != errPart {
t.Errorf("one base64 chunk got error %v, want %v", err, errPart)
if err != ErrUnsecured {
t.Errorf("unsecured header only got error %v, want %v", err, errPart)
}

// header + body; missing signature
_, err = HMACCheck([]byte("eyJhbGciOiJub25lIn0.e30"), nil)
// header + claims; no signature
_, err = ECDSACheck([]byte("eyJhbGciOiJFUzI1NiJ9.e30"), &testKeyEC384.PublicKey)
if err != errPart {
t.Errorf("two base64 chunks got error %v, want %v", err, errPart)
t.Errorf("missing signature got error %v, want %v", err, errPart)
}
_, err = HMACCheck([]byte("eyJhbGciOiJub25lIn0.e30"), nil)
if err != ErrUnsecured {
t.Errorf("unsecured got error %v, want %v", err, errPart)
}
}

Expand Down
Loading

0 comments on commit ded3fe6

Please sign in to comment.