Skip to content

Commit

Permalink
Merge pull request #50 from MicahParks/response_extractor
Browse files Browse the repository at this point in the history
Add response extractor
  • Loading branch information
MicahParks authored Sep 26, 2022
2 parents 57ba545 + 03d3e04 commit f20aea8
Show file tree
Hide file tree
Showing 12 changed files with 103 additions and 13 deletions.
16 changes: 10 additions & 6 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ jwksURL := os.Getenv("JWKS_URL")

// Confirm the environment variable is not empty.
if jwksURL == "" {
log.Fatalln("JWKS_URL environment variable must be populated.")
log.Fatalln("JWKS_URL environment variable must be populated.")
}
```

Expand All @@ -79,9 +79,9 @@ if jwksURL == "" {
Via HTTP:
```go
// Create the JWKS from the resource at the given URL.
jwks, err := keyfunc.Get(jwksURL, keyfunc.Options{})
jwks, err := keyfunc.Get(jwksURL, keyfunc.Options{}) // See recommended options in the examples directory.
if err != nil {
log.Fatalf("Failed to get the JWKS from the given URL.\nError: %s", err)
log.Fatalf("Failed to get the JWKS from the given URL.\nError: %s", err)
}
```
Via JSON:
Expand All @@ -92,7 +92,7 @@ var jwksJSON = json.RawMessage(`{"keys":[{"kid":"zXew0UJ1h6Q4CCcd_9wxMzvcp5cEBif
// Create the JWKS from the resource at the given URL.
jwks, err := keyfunc.NewJSON(jwksJSON)
if err != nil {
log.Fatalf("Failed to create JWKS from JSON.\nError: %s", err)
log.Fatalf("Failed to create JWKS from JSON.\nError: %s", err)
}
```
Via a given key:
Expand All @@ -103,7 +103,7 @@ uniqueKeyID := "myKeyID"

// Create the JWKS from the HMAC key.
jwks := keyfunc.NewGiven(map[string]keyfunc.GivenKey{
uniqueKeyID: keyfunc.NewGivenHMAC(key),
uniqueKeyID: keyfunc.NewGivenHMAC(key),
})
```

Expand All @@ -117,7 +117,7 @@ features mentioned at the bottom of this `README.md`.
// Parse the JWT.
token, err := jwt.Parse(jwtB64, jwks.Keyfunc)
if err != nil {
return nil, fmt.Errorf("failed to parse token: %w", err)
return nil, fmt.Errorf("failed to parse token: %w", err)
}
```

Expand Down Expand Up @@ -151,6 +151,10 @@ These features can be configured by populating fields in the
* A custom HTTP client can be used.
* A custom HTTP request factory can be provided to create HTTP requests for the remote JWKS resource. For example, an
HTTP header can be added to indicate a User-Agent.
* A custom HTTP response extractor can be provided to get the raw JWKS JSON from the `*http.Response`. For example, the
HTTP response code could be checked. Implementations are responsible for closing the response body. By default, the
response body is read and closed, the status code is ignored. The default behavior is likely to be changed soon.
See https://github.com/MicahParks/keyfunc/issues/48.
* A map of JWT key IDs (`kid`) to keys can be given and used for the `jwt.Keyfunc`. For an example, see
the `examples/given` directory.
* A copy of the latest raw JWKS `[]byte` can be returned.
Expand Down
1 change: 1 addition & 0 deletions examples/aws_cognito/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ func main() {
RefreshRateLimit: time.Minute * 5,
RefreshTimeout: time.Second * 10,
RefreshUnknownKID: true,
ResponseExtractor: keyfunc.ResponseExtractorStatusOK,
}

// Create the JWKS from the resource at the given URL.
Expand Down
1 change: 1 addition & 0 deletions examples/ctx/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ func main() {
RefreshErrorHandler: func(err error) {
log.Printf("There was an error with the jwt.Keyfunc\nError: %s", err.Error())
},
ResponseExtractor: keyfunc.ResponseExtractorStatusOK,
}

// Create the JWKS from the resource at the given URL.
Expand Down
1 change: 1 addition & 0 deletions examples/given/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ func main() {
RefreshRateLimit: time.Minute * 5,
RefreshTimeout: time.Second * 10,
RefreshUnknownKID: true,
ResponseExtractor: keyfunc.ResponseExtractorStatusOK,
}

// Create the JWKS from the resource at the given URL.
Expand Down
1 change: 1 addition & 0 deletions examples/interval/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ func main() {
RefreshErrorHandler: func(err error) {
log.Printf("There was an error with the jwt.Keyfunc\nError: %s", err.Error())
},
ResponseExtractor: keyfunc.ResponseExtractorStatusOK,
}

// Create the JWKS from the resource at the given URL.
Expand Down
1 change: 1 addition & 0 deletions examples/keycloak/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ func main() {
RefreshRateLimit: time.Minute * 5,
RefreshTimeout: time.Second * 10,
RefreshUnknownKID: true,
ResponseExtractor: keyfunc.ResponseExtractorStatusOK,
}

// Create the JWKS from the resource at the given URL.
Expand Down
1 change: 1 addition & 0 deletions examples/recommended_options/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ func main() {
RefreshRateLimit: time.Minute * 5,
RefreshTimeout: time.Second * 10,
RefreshUnknownKID: true,
ResponseExtractor: keyfunc.ResponseExtractorStatusOK,
}

// Create the JWKS from the resource at the given URL.
Expand Down
20 changes: 15 additions & 5 deletions get.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ package keyfunc
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"sync"
Expand All @@ -29,6 +31,16 @@ func Get(jwksURL string, options Options) (jwks *JWKS, err error) {
if jwks.requestFactory == nil {
jwks.requestFactory = defaultRequestFactory
}
if jwks.responseExtractor == nil {
jwks.responseExtractor = func(ctx context.Context, resp *http.Response) (json.RawMessage, error) {
// This behavior is likely going to change in favor of checking the response code.
// See https://github.com/MicahParks/keyfunc/issues/48

//goland:noinspection GoUnhandledErrorResult
defer resp.Body.Close()
return io.ReadAll(resp.Body)
}
}
if jwks.refreshTimeout == 0 {
jwks.refreshTimeout = defaultRefreshTimeout
}
Expand Down Expand Up @@ -141,19 +153,17 @@ func (j *JWKS) refresh() (err error) {

req, err := j.requestFactory(ctx, j.jwksURL)
if err != nil {
return err
return fmt.Errorf("failed to create request via factory function: %w", err)
}

resp, err := j.client.Do(req)
if err != nil {
return err
}
//goland:noinspection GoUnhandledErrorResult
defer resp.Body.Close()

jwksBytes, err := io.ReadAll(resp.Body)
jwksBytes, err := j.responseExtractor(ctx, resp)
if err != nil {
return err
return fmt.Errorf("failed to extract response via extractor function: %w", err)
}

// Only reprocess if the JWKS has changed.
Expand Down
1 change: 1 addition & 0 deletions jwks.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ type JWKS struct {
refreshTimeout time.Duration
refreshUnknownKID bool
requestFactory func(ctx context.Context, url string) (*http.Request, error)
responseExtractor func(ctx context.Context, resp *http.Response) (json.RawMessage, error)
}

// rawJWKS represents a JWKS in JSON format.
Expand Down
4 changes: 2 additions & 2 deletions jwks_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -358,7 +358,7 @@ func TestRawJWKS(t *testing.T) {
}

raw := jwks.RawJWKS()
if bytes.Compare(raw, []byte(jwksJSON)) != 0 {
if !bytes.Equal(raw, []byte(jwksJSON)) {
t.Fatalf("Raw JWKS does not match remote JWKS resource.")
}

Expand All @@ -367,7 +367,7 @@ func TestRawJWKS(t *testing.T) {
copy(raw, emptySlice)

nextRaw := jwks.RawJWKS()
if bytes.Compare(nextRaw, emptySlice) == 0 {
if bytes.Equal(nextRaw, emptySlice) {
t.Fatalf("Raw JWKS is not a copy.")
}
}
Expand Down
24 changes: 24 additions & 0 deletions options.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,17 @@ package keyfunc

import (
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"time"
)

// ErrInvalidHTTPStatusCode indicates that the HTTP status code is invalid.
var ErrInvalidHTTPStatusCode = errors.New("invalid HTTP status code")

// Options represents the configuration options for a JWKS.
//
// If RefreshInterval and or RefreshUnknownKID is not nil, then a background goroutine will be launched to refresh the
Expand Down Expand Up @@ -58,6 +65,22 @@ type Options struct {
// RequestFactory creates HTTP requests for the remote JWKS resource located at the given url. For example, an
// HTTP header could be added to indicate a User-Agent.
RequestFactory func(ctx context.Context, url string) (*http.Request, error)

// ResponseExtractor consumes a *http.Response and produces the raw JSON for the JWKS. By default, the raw JSON is
// expected in the response body and the response's status code is not checked.
//
// The default behavior is likely to change soon. See this relevant GitHub issue:
// https://github.com/MicahParks/keyfunc/issues/48
ResponseExtractor func(ctx context.Context, resp *http.Response) (json.RawMessage, error)
}

// ResponseExtractorStatusOK is meant to be used as the ResponseExtractor field for Options. It confirms that response
// status code is 200 OK and returns the raw JSON from the response body.
func ResponseExtractorStatusOK(ctx context.Context, resp *http.Response) (json.RawMessage, error) {
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("%w: %d", ErrInvalidHTTPStatusCode, resp.StatusCode)
}
return io.ReadAll(resp.Body)
}

// applyOptions applies the given options to the given JWKS.
Expand All @@ -81,4 +104,5 @@ func applyOptions(jwks *JWKS, options Options) {
jwks.refreshTimeout = options.RefreshTimeout
jwks.refreshUnknownKID = options.RefreshUnknownKID
jwks.requestFactory = options.RequestFactory
jwks.responseExtractor = options.ResponseExtractor
}
45 changes: 45 additions & 0 deletions options_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
package keyfunc_test

import (
"errors"
"net/http"
"net/http/httptest"
"sync"
"testing"

"github.com/MicahParks/keyfunc"
)

func TestResponseExtractorStatusOK(t *testing.T) {
var mux sync.Mutex
statusCode := http.StatusOK

server := httptest.NewServer(http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) {
mux.Lock()
writer.WriteHeader(statusCode)
mux.Unlock()
_, _ = writer.Write([]byte(jwksJSON))
}))
defer server.Close()

options := keyfunc.Options{
ResponseExtractor: keyfunc.ResponseExtractorStatusOK,
}
jwks, err := keyfunc.Get(server.URL, options)
if err != nil {
t.Fatalf("Failed to get JWK Set from server.\nError: %s", err)
}

if len(jwks.ReadOnlyKeys()) == 0 {
t.Fatalf("Expected JWK Set to have keys.")
}

mux.Lock()
statusCode = http.StatusInternalServerError
mux.Unlock()

_, err = keyfunc.Get(server.URL, options)
if !errors.Is(err, keyfunc.ErrInvalidHTTPStatusCode) {
t.Fatalf("Expected error to be ErrInvalidHTTPStatusCode.\nError: %s", err)
}
}

0 comments on commit f20aea8

Please sign in to comment.