diff --git a/access.go b/access.go index 0add2c1..152db9c 100644 --- a/access.go +++ b/access.go @@ -115,20 +115,17 @@ func (s *Server) HandleAccessRequest(w *Response, r *http.Request) *AccessReques // Only allow GET or POST if r.Method == "GET" { if !s.Config.AllowGetAccessRequest { - w.SetError(E_INVALID_REQUEST, "") - w.InternalError = errors.New("Request must be POST") + s.setErrorAndLog(w, E_INVALID_REQUEST, errors.New("Request must be POST"), "access_request=%s", "GET request not allowed") return nil } } else if r.Method != "POST" { - w.SetError(E_INVALID_REQUEST, "") - w.InternalError = errors.New("Request must be POST") + s.setErrorAndLog(w, E_INVALID_REQUEST, errors.New("Request must be POST"), "access_request=%s", "request must be POST") return nil } err := r.ParseForm() if err != nil { - w.SetError(E_INVALID_REQUEST, "") - w.InternalError = err + s.setErrorAndLog(w, E_INVALID_REQUEST, err, "access_request=%s", "parsing error") return nil } @@ -148,13 +145,13 @@ func (s *Server) HandleAccessRequest(w *Response, r *http.Request) *AccessReques } } - w.SetError(E_UNSUPPORTED_GRANT_TYPE, "") + s.setErrorAndLog(w, E_UNSUPPORTED_GRANT_TYPE, nil, "access_request=%s", "unknown grant type") return nil } func (s *Server) handleAuthorizationCodeRequest(w *Response, r *http.Request) *AccessRequest { // get client authentication - auth := getClientAuth(w, r, s.Config.AllowClientSecretInParams) + auth := s.getClientAuth(w, r, s.Config.AllowClientSecretInParams) if auth == nil { return nil } @@ -172,12 +169,12 @@ func (s *Server) handleAuthorizationCodeRequest(w *Response, r *http.Request) *A // "code" is required if ret.Code == "" { - w.SetError(E_INVALID_GRANT, "") + s.setErrorAndLog(w, E_INVALID_GRANT, nil, "auth_code_request=%s", "code is required") return nil } // must have a valid client - if ret.Client = getClient(auth, w.Storage, w); ret.Client == nil { + if ret.Client = s.getClient(auth, w.Storage, w); ret.Client == nil { return nil } @@ -185,30 +182,29 @@ func (s *Server) handleAuthorizationCodeRequest(w *Response, r *http.Request) *A var err error ret.AuthorizeData, err = w.Storage.LoadAuthorize(ret.Code) if err != nil { - w.SetError(E_INVALID_GRANT, "") - w.InternalError = err + s.setErrorAndLog(w, E_INVALID_GRANT, err, "auth_code_request=%s", "error loading authorize data") return nil } if ret.AuthorizeData == nil { - w.SetError(E_UNAUTHORIZED_CLIENT, "") + s.setErrorAndLog(w, E_UNAUTHORIZED_CLIENT, nil, "auth_code_request=%s", "authorization data is nil") return nil } if ret.AuthorizeData.Client == nil { - w.SetError(E_UNAUTHORIZED_CLIENT, "") + s.setErrorAndLog(w, E_UNAUTHORIZED_CLIENT, nil, "auth_code_request=%s", "authorization client is nil") return nil } if ret.AuthorizeData.Client.GetRedirectUri() == "" { - w.SetError(E_UNAUTHORIZED_CLIENT, "") + s.setErrorAndLog(w, E_UNAUTHORIZED_CLIENT, nil, "auth_code_request=%s", "client redirect uri is empty") return nil } if ret.AuthorizeData.IsExpiredAt(s.Now()) { - w.SetError(E_INVALID_GRANT, "") + s.setErrorAndLog(w, E_INVALID_GRANT, nil, "auth_code_request=%s", "authorization data is expired") return nil } // code must be from the client if ret.AuthorizeData.Client.GetId() != ret.Client.GetId() { - w.SetError(E_INVALID_GRANT, "") + s.setErrorAndLog(w, E_INVALID_GRANT, nil, "auth_code_request=%s", "client code does not match") return nil } @@ -217,13 +213,11 @@ func (s *Server) handleAuthorizationCodeRequest(w *Response, r *http.Request) *A ret.RedirectUri = FirstUri(ret.Client.GetRedirectUri(), s.Config.RedirectUriSeparator) } if err = ValidateUriList(ret.Client.GetRedirectUri(), ret.RedirectUri, s.Config.RedirectUriSeparator); err != nil { - w.SetError(E_INVALID_REQUEST, "") - w.InternalError = err + s.setErrorAndLog(w, E_INVALID_REQUEST, err, "auth_code_request=%s", "error validating client redirect") return nil } if ret.AuthorizeData.RedirectUri != ret.RedirectUri { - w.SetError(E_INVALID_REQUEST, "") - w.InternalError = errors.New("Redirect uri is different") + s.setErrorAndLog(w, E_INVALID_REQUEST, errors.New("Redirect uri is different"), "auth_code_request=%s", "client redirect does not match authorization data") return nil } @@ -231,8 +225,8 @@ func (s *Server) handleAuthorizationCodeRequest(w *Response, r *http.Request) *A if len(ret.AuthorizeData.CodeChallenge) > 0 { // https://tools.ietf.org/html/rfc7636#section-4.1 if matched := pkceMatcher.MatchString(ret.CodeVerifier); !matched { - w.SetError(E_INVALID_REQUEST, "code_verifier invalid (rfc7636)") - w.InternalError = errors.New("code_verifier has invalid format") + s.setErrorAndLog(w, E_INVALID_REQUEST, errors.New("code_verifier has invalid format"), + "auth_code_request=%s", "pkce code challenge verifier does not match") return nil } @@ -245,12 +239,13 @@ func (s *Server) handleAuthorizationCodeRequest(w *Response, r *http.Request) *A hash := sha256.Sum256([]byte(ret.CodeVerifier)) codeVerifier = base64.RawURLEncoding.EncodeToString(hash[:]) default: - w.SetError(E_INVALID_REQUEST, "code_challenge_method transform algorithm not supported (rfc7636)") + s.setErrorAndLog(w, E_INVALID_REQUEST, nil, + "auth_code_request=%s", "pkce transform algorithm not supported (rfc7636)") return nil } if codeVerifier != ret.AuthorizeData.CodeChallenge { - w.SetError(E_INVALID_GRANT, "code_verifier invalid (rfc7636)") - w.InternalError = errors.New("code_verifier failed comparison with code_challenge") + s.setErrorAndLog(w, E_INVALID_GRANT, errors.New("code_verifier failed comparison with code_challenge"), + "auth_code_request=%s", "pkce code verifier does not match challenge") return nil } } @@ -288,7 +283,7 @@ func extraScopes(access_scopes, refresh_scopes string) bool { func (s *Server) handleRefreshTokenRequest(w *Response, r *http.Request) *AccessRequest { // get client authentication - auth := getClientAuth(w, r, s.Config.AllowClientSecretInParams) + auth := s.getClientAuth(w, r, s.Config.AllowClientSecretInParams) if auth == nil { return nil } @@ -305,12 +300,12 @@ func (s *Server) handleRefreshTokenRequest(w *Response, r *http.Request) *Access // "refresh_token" is required if ret.Code == "" { - w.SetError(E_INVALID_GRANT, "") + s.setErrorAndLog(w, E_INVALID_GRANT, nil, "refresh_token=%s", "refresh_token is required") return nil } // must have a valid client - if ret.Client = getClient(auth, w.Storage, w); ret.Client == nil { + if ret.Client = s.getClient(auth, w.Storage, w); ret.Client == nil { return nil } @@ -318,27 +313,25 @@ func (s *Server) handleRefreshTokenRequest(w *Response, r *http.Request) *Access var err error ret.AccessData, err = w.Storage.LoadRefresh(ret.Code) if err != nil { - w.SetError(E_INVALID_GRANT, "") - w.InternalError = err + s.setErrorAndLog(w, E_INVALID_GRANT, err, "refresh_token=%s", "error loading access data") return nil } if ret.AccessData == nil { - w.SetError(E_UNAUTHORIZED_CLIENT, "") + s.setErrorAndLog(w, E_UNAUTHORIZED_CLIENT, nil, "refresh_token=%s", "access data is nil") return nil } if ret.AccessData.Client == nil { - w.SetError(E_UNAUTHORIZED_CLIENT, "") + s.setErrorAndLog(w, E_UNAUTHORIZED_CLIENT, nil, "refresh_token=%s", "access data client is nil") return nil } if ret.AccessData.Client.GetRedirectUri() == "" { - w.SetError(E_UNAUTHORIZED_CLIENT, "") + s.setErrorAndLog(w, E_UNAUTHORIZED_CLIENT, nil, "refresh_token=%s", "access data client redirect uri is empty") return nil } // client must be the same as the previous token if ret.AccessData.Client.GetId() != ret.Client.GetId() { - w.SetError(E_INVALID_CLIENT, "") - w.InternalError = errors.New("Client id must be the same from previous token") + s.setErrorAndLog(w, E_INVALID_CLIENT, errors.New("Client id must be the same from previous token"), "refresh_token=%s, current=%v, previous=%v", "client mismatch", ret.Client.GetId(), ret.AccessData.Client.GetId()) return nil } @@ -351,8 +344,8 @@ func (s *Server) handleRefreshTokenRequest(w *Response, r *http.Request) *Access } if extraScopes(ret.AccessData.Scope, ret.Scope) { - w.SetError(E_ACCESS_DENIED, "") - w.InternalError = errors.New("the requested scope must not include any scope not originally granted by the resource owner") + msg := "the requested scope must not include any scope not originally granted by the resource owner" + s.setErrorAndLog(w, E_ACCESS_DENIED, errors.New(msg), "refresh_token=%s", msg) return nil } @@ -361,7 +354,7 @@ func (s *Server) handleRefreshTokenRequest(w *Response, r *http.Request) *Access func (s *Server) handlePasswordRequest(w *Response, r *http.Request) *AccessRequest { // get client authentication - auth := getClientAuth(w, r, s.Config.AllowClientSecretInParams) + auth := s.getClientAuth(w, r, s.Config.AllowClientSecretInParams) if auth == nil { return nil } @@ -379,12 +372,12 @@ func (s *Server) handlePasswordRequest(w *Response, r *http.Request) *AccessRequ // "username" and "password" is required if ret.Username == "" || ret.Password == "" { - w.SetError(E_INVALID_GRANT, "") + s.setErrorAndLog(w, E_INVALID_GRANT, nil, "handle_password=%s", "username and pass required") return nil } // must have a valid client - if ret.Client = getClient(auth, w.Storage, w); ret.Client == nil { + if ret.Client = s.getClient(auth, w.Storage, w); ret.Client == nil { return nil } @@ -396,7 +389,7 @@ func (s *Server) handlePasswordRequest(w *Response, r *http.Request) *AccessRequ func (s *Server) handleClientCredentialsRequest(w *Response, r *http.Request) *AccessRequest { // get client authentication - auth := getClientAuth(w, r, s.Config.AllowClientSecretInParams) + auth := s.getClientAuth(w, r, s.Config.AllowClientSecretInParams) if auth == nil { return nil } @@ -411,7 +404,7 @@ func (s *Server) handleClientCredentialsRequest(w *Response, r *http.Request) *A } // must have a valid client - if ret.Client = getClient(auth, w.Storage, w); ret.Client == nil { + if ret.Client = s.getClient(auth, w.Storage, w); ret.Client == nil { return nil } @@ -423,7 +416,7 @@ func (s *Server) handleClientCredentialsRequest(w *Response, r *http.Request) *A func (s *Server) handleAssertionRequest(w *Response, r *http.Request) *AccessRequest { // get client authentication - auth := getClientAuth(w, r, s.Config.AllowClientSecretInParams) + auth := s.getClientAuth(w, r, s.Config.AllowClientSecretInParams) if auth == nil { return nil } @@ -441,12 +434,12 @@ func (s *Server) handleAssertionRequest(w *Response, r *http.Request) *AccessReq // "assertion_type" and "assertion" is required if ret.AssertionType == "" || ret.Assertion == "" { - w.SetError(E_INVALID_GRANT, "") + s.setErrorAndLog(w, E_INVALID_GRANT, nil, "handle_assertion_request=%s", "assertion and assertion_type required") return nil } // must have a valid client - if ret.Client = getClient(auth, w.Storage, w); ret.Client == nil { + if ret.Client = s.getClient(auth, w.Storage, w); ret.Client == nil { return nil } @@ -486,8 +479,7 @@ func (s *Server) FinishAccessRequest(w *Response, r *http.Request, ar *AccessReq // generate access token ret.AccessToken, ret.RefreshToken, err = s.AccessTokenGen.GenerateAccessToken(ret, ar.GenerateRefresh) if err != nil { - w.SetError(E_SERVER_ERROR, "") - w.InternalError = err + s.setErrorAndLog(w, E_SERVER_ERROR, err, "finish_access_request=%s", "error generating token") return } } else { @@ -496,8 +488,7 @@ func (s *Server) FinishAccessRequest(w *Response, r *http.Request, ar *AccessReq // save access token if err = w.Storage.SaveAccess(ret); err != nil { - w.SetError(E_SERVER_ERROR, "") - w.InternalError = err + s.setErrorAndLog(w, E_SERVER_ERROR, err, "finish_access_request=%s", "error saving access token") return } @@ -525,7 +516,7 @@ func (s *Server) FinishAccessRequest(w *Response, r *http.Request, ar *AccessReq w.Output["scope"] = ret.Scope } } else { - w.SetError(E_ACCESS_DENIED, "") + s.setErrorAndLog(w, E_ACCESS_DENIED, nil, "finish_access_request=%s", "authorization failed") } } @@ -533,30 +524,39 @@ func (s *Server) FinishAccessRequest(w *Response, r *http.Request, ar *AccessReq // getClient looks up and authenticates the basic auth using the given // storage. Sets an error on the response if auth fails or a server error occurs. -func getClient(auth *BasicAuth, storage Storage, w *Response) Client { +func (s Server) getClient(auth *BasicAuth, storage Storage, w *Response) Client { client, err := storage.GetClient(auth.Username) if err == ErrNotFound { - w.SetError(E_UNAUTHORIZED_CLIENT, "") + s.setErrorAndLog(w, E_UNAUTHORIZED_CLIENT, nil, "get_client=%s", "not found") return nil } if err != nil { - w.SetError(E_SERVER_ERROR, "") - w.InternalError = err + s.setErrorAndLog(w, E_SERVER_ERROR, err, "get_client=%s", "error finding client") return nil } if client == nil { - w.SetError(E_UNAUTHORIZED_CLIENT, "") + s.setErrorAndLog(w, E_UNAUTHORIZED_CLIENT, nil, "get_client=%s", "client is nil") return nil } if !CheckClientSecret(client, auth.Password) { - w.SetError(E_UNAUTHORIZED_CLIENT, "") + s.setErrorAndLog(w, E_UNAUTHORIZED_CLIENT, nil, "get_client=%s, client_id=%v", "client check failed", client.GetId()) return nil } if client.GetRedirectUri() == "" { - w.SetError(E_UNAUTHORIZED_CLIENT, "") + s.setErrorAndLog(w, E_UNAUTHORIZED_CLIENT, nil, "get_client=%s", "client redirect uri is empty") return nil } return client } + +// setErrorAndLog sets the response error and internal error (if non-nil) and logs them along with the provided debug format string and arguments. +func (s Server) setErrorAndLog(w *Response, responseError string, internalError error, debugFormat string, debugArgs ...interface{}) { + format := "error=%v, internal_error=%#v " + debugFormat + + w.InternalError = internalError + w.SetError(responseError, "") + + s.Logger.Printf(format, append([]interface{}{responseError, internalError}, debugArgs...)...) +} \ No newline at end of file diff --git a/access_test.go b/access_test.go index 34bca19..504e39c 100644 --- a/access_test.go +++ b/access_test.go @@ -291,6 +291,8 @@ func TestGetClientWithoutMatcher(t *testing.T) { RedirectUri: "http://www.example.com", } storage := &TestingStorage{clients: map[string]Client{myclient.Id: myclient}} + sconfig := NewServerConfig() + server := NewServer(sconfig, storage) // Ensure bad secret fails { @@ -299,7 +301,7 @@ func TestGetClientWithoutMatcher(t *testing.T) { Password: "invalidsecret", } w := &Response{} - client := getClient(auth, storage, w) + client := server.getClient(auth, storage, w) if client != nil { t.Errorf("Expected error, got client: %v", client) } @@ -320,7 +322,7 @@ func TestGetClientWithoutMatcher(t *testing.T) { Password: "nonexistent", } w := &Response{} - client := getClient(auth, storage, w) + client := server.getClient(auth, storage, w) if client != nil { t.Errorf("Expected error, got client: %v", client) } @@ -341,7 +343,7 @@ func TestGetClientWithoutMatcher(t *testing.T) { Password: "myclientsecret", } w := &Response{} - client := getClient(auth, storage, w) + client := server.getClient(auth, storage, w) if client != myclient { t.Errorf("Expected client, got nil with response: %v", w) } @@ -370,6 +372,8 @@ func TestGetClientSecretMatcher(t *testing.T) { RedirectUri: "http://www.example.com", } storage := &TestingStorage{clients: map[string]Client{myclient.Id: myclient}} + sconfig := NewServerConfig() + server := NewServer(sconfig, storage) // Ensure bad secret fails, but does not panic (doesn't call GetSecret) { @@ -378,7 +382,7 @@ func TestGetClientSecretMatcher(t *testing.T) { Password: "invalidsecret", } w := &Response{} - client := getClient(auth, storage, w) + client := server.getClient(auth, storage, w) if client != nil { t.Errorf("Expected error, got client: %v", client) } @@ -391,7 +395,7 @@ func TestGetClientSecretMatcher(t *testing.T) { Password: "myclientsecret", } w := &Response{} - client := getClient(auth, storage, w) + client := server.getClient(auth, storage, w) if client != myclient { t.Errorf("Expected client, got nil with response: %v", w) } diff --git a/info.go b/info.go index 00aa563..b3c73ca 100644 --- a/info.go +++ b/info.go @@ -17,7 +17,7 @@ func (s *Server) HandleInfoRequest(w *Response, r *http.Request) *InfoRequest { r.ParseForm() bearer := CheckBearerAuth(r) if bearer == nil { - w.SetError(E_INVALID_REQUEST, "") + s.setErrorAndLog(w, E_INVALID_REQUEST, nil, "handle_info_request=%s", "bearer is nil") return nil } @@ -27,7 +27,7 @@ func (s *Server) HandleInfoRequest(w *Response, r *http.Request) *InfoRequest { } if ret.Code == "" { - w.SetError(E_INVALID_REQUEST, "") + s.setErrorAndLog(w, E_INVALID_REQUEST, nil, "handle_info_request=%s", "code is nil") return nil } @@ -36,24 +36,23 @@ func (s *Server) HandleInfoRequest(w *Response, r *http.Request) *InfoRequest { // load access data ret.AccessData, err = w.Storage.LoadAccess(ret.Code) if err != nil { - w.SetError(E_INVALID_REQUEST, "") - w.InternalError = err + s.setErrorAndLog(w, E_INVALID_REQUEST, err, "handle_info_request=%s", "failed to load access data") return nil } if ret.AccessData == nil { - w.SetError(E_INVALID_REQUEST, "") + s.setErrorAndLog(w, E_INVALID_REQUEST, nil, "handle_info_request=%s", "access data is nil") return nil } if ret.AccessData.Client == nil { - w.SetError(E_UNAUTHORIZED_CLIENT, "") + s.setErrorAndLog(w, E_UNAUTHORIZED_CLIENT, nil, "handle_info_request=%s", "access data client is nil") return nil } if ret.AccessData.Client.GetRedirectUri() == "" { - w.SetError(E_UNAUTHORIZED_CLIENT, "") + s.setErrorAndLog(w, E_UNAUTHORIZED_CLIENT, nil, "handle_info_request=%s", "access data client redirect uri is empty") return nil } if ret.AccessData.IsExpiredAt(s.Now()) { - w.SetError(E_INVALID_GRANT, "") + s.setErrorAndLog(w, E_INVALID_GRANT, nil, "handle_info_request=%s", "access data is expired") return nil } diff --git a/log.go b/log.go new file mode 100644 index 0000000..66d4ae3 --- /dev/null +++ b/log.go @@ -0,0 +1,13 @@ +package osin + +// Logger creates a formatted log event. +// NOTE: Log is meant for internal use only and may contain sensitive info. +type Logger interface { + Printf(format string, v ...interface{}) +} + +type LoggerDefault struct { +} + +func (l LoggerDefault) Printf(format string, v ...interface{}) { +} diff --git a/log_test.go b/log_test.go new file mode 100644 index 0000000..b2c875a --- /dev/null +++ b/log_test.go @@ -0,0 +1,41 @@ +package osin + +import ( + "errors" + "fmt" + "reflect" + "testing" +) + +type testLogger struct { + Result string +} + +func (l *testLogger) Printf(format string, v ...interface{}) { + l.Result = fmt.Sprintf(format, v...) +} + +func TestServerErrorLogger(t *testing.T) { + sconfig := NewServerConfig() + server := NewServer(sconfig, NewTestingStorage()) + + tl := &testLogger{} + server.Logger = tl + + r := server.NewResponse() + r.ErrorStatusCode = 404 + + server.setErrorAndLog(r, E_INVALID_GRANT, errors.New("foo"), "foo=%s, bar=%s", "bar", "baz") + + if r.ErrorId != E_INVALID_GRANT { + t.Errorf("expected error to be set to %s", E_INVALID_GRANT) + } + if r.StatusText != deferror.Get(E_INVALID_GRANT) { + t.Errorf("expected status text to be %s, got %s", deferror.Get(E_INVALID_GRANT), r.StatusText) + } + + expectedResult := `error=invalid_grant, internal_error=&errors.errorString{s:"foo"} foo=bar, bar=baz` + if !reflect.DeepEqual(tl.Result, expectedResult) { + t.Errorf("expected %v, got %v", expectedResult, tl.Result) + } +} \ No newline at end of file diff --git a/server.go b/server.go index 57695ae..8b7b31b 100644 --- a/server.go +++ b/server.go @@ -11,6 +11,7 @@ type Server struct { AuthorizeTokenGen AuthorizeTokenGen AccessTokenGen AccessTokenGen Now func() time.Time + Logger Logger } // NewServer creates a new server instance @@ -21,6 +22,7 @@ func NewServer(config *ServerConfig, storage Storage) *Server { AuthorizeTokenGen: &AuthorizeTokenGenDefault{}, AccessTokenGen: &AccessTokenGenDefault{}, Now: time.Now, + Logger: &LoggerDefault{}, } } diff --git a/util.go b/util.go index a86af7b..42c9565 100644 --- a/util.go +++ b/util.go @@ -78,7 +78,7 @@ func CheckBearerAuth(r *http.Request) *BearerAuth { // getClientAuth checks client basic authentication in params if allowed, // otherwise gets it from the header. // Sets an error on the response if no auth is present or a server error occurs. -func getClientAuth(w *Response, r *http.Request, allowQueryParams bool) *BasicAuth { +func (s Server) getClientAuth(w *Response, r *http.Request, allowQueryParams bool) *BasicAuth { if allowQueryParams { // Allow for auth without password @@ -95,13 +95,11 @@ func getClientAuth(w *Response, r *http.Request, allowQueryParams bool) *BasicAu auth, err := CheckBasicAuth(r) if err != nil { - w.SetError(E_INVALID_REQUEST, "") - w.InternalError = err + s.setErrorAndLog(w, E_INVALID_REQUEST, err, "get_client_auth=%s", "check auth error") return nil } if auth == nil { - w.SetError(E_INVALID_REQUEST, "") - w.InternalError = errors.New("Client authentication not sent") + s.setErrorAndLog(w, E_INVALID_REQUEST, errors.New("Client authentication not sent"), "get_client_auth=%s", "client authentication not sent") return nil } return auth diff --git a/util_test.go b/util_test.go index 34ddf6c..234353f 100644 --- a/util_test.go +++ b/util_test.go @@ -54,6 +54,9 @@ func TestGetClientAuth(t *testing.T) { headerOKAuth := make(http.Header) headerOKAuth.Set("Authorization", goodAuthValue) + sconfig := NewServerConfig() + server := NewServer(sconfig, NewTestingStorage()) + var tests = []struct { header http.Header url *url.URL @@ -86,7 +89,7 @@ func TestGetClientAuth(t *testing.T) { w := new(Response) r := &http.Request{Header: tt.header, URL: tt.url} r.ParseForm() - auth := getClientAuth(w, r, tt.allowQueryParams) + auth := server.getClientAuth(w, r, tt.allowQueryParams) if tt.expectAuth && auth == nil { t.Errorf("Auth should not be nil for %v", tt) } else if !tt.expectAuth && auth != nil {