Skip to content

Commit 319de62

Browse files
authored
Merge pull request #1 from delvelabs/exception-handler-callback
Add a method to register a custom exception handler.
2 parents 95451eb + 213cbc0 commit 319de62

File tree

2 files changed

+68
-10
lines changed

2 files changed

+68
-10
lines changed

flask_oauthlib/provider/oauth2.py

+45-9
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,7 @@ def user():
7272
def __init__(self, app=None):
7373
self._before_request_funcs = []
7474
self._after_request_funcs = []
75+
self._exception_handler = None
7576
self._invalid_response = None
7677
if app:
7778
self.init_app(app)
@@ -85,6 +86,13 @@ def init_app(self, app):
8586
app.extensions = getattr(app, 'extensions', {})
8687
app.extensions['oauthlib.provider.oauth2'] = self
8788

89+
def _on_exception(self, error, redirect_content=None):
90+
91+
if self._exception_handler:
92+
return self._exception_handler(error, redirect_content)
93+
else:
94+
return redirect(redirect_content)
95+
8896
@cached_property
8997
def error_uri(self):
9098
"""The error page URI.
@@ -208,6 +216,34 @@ def valid_after_request(valid, oauth):
208216
self._after_request_funcs.append(f)
209217
return f
210218

219+
def exception_handler(self, f):
220+
"""Register a function as custom exception handler.
221+
222+
**As the default error handling is leaking error to the client, it is
223+
STRONGLY RECOMMENDED to implement your own handler to mask
224+
the server side errors in production environment.**
225+
226+
When an error occur during execution, we can
227+
handle the error with with the registered function. The function
228+
accepts two parameters:
229+
- error: the error raised
230+
- redirect_content: the content used in the redirect by default
231+
232+
usage with the flask error handler ::
233+
@oauth.exception_handler
234+
def custom_exception_handler(error, *args):
235+
raise error
236+
237+
@app.errorhandler(Exception)
238+
def all_exception_handler(*args):
239+
# any treatment you need for the error
240+
return "Server error", 500
241+
242+
If no function is registered, it will do a redirect with ``redirect_content`` as content.
243+
"""
244+
self._exception_handler = f
245+
return f
246+
211247
def invalid_response(self, f):
212248
"""Register a function for responsing with invalid request.
213249
@@ -391,13 +427,13 @@ def decorated(*args, **kwargs):
391427
kwargs.update(credentials)
392428
except oauth2.FatalClientError as e:
393429
log.debug('Fatal client error %r', e, exc_info=True)
394-
return redirect(e.in_uri(self.error_uri))
430+
return self._on_exception(e, e.in_uri(self.error_uri))
395431
except oauth2.OAuth2Error as e:
396432
log.debug('OAuth2Error: %r', e, exc_info=True)
397-
return redirect(e.in_uri(redirect_uri))
433+
return self._on_exception(e, e.in_uri(redirect_uri))
398434
except Exception as e:
399435
log.exception(e)
400-
return redirect(add_params_to_uri(
436+
return self._on_exception(e, add_params_to_uri(
401437
self.error_uri, {'error': str(e)}
402438
))
403439

@@ -410,10 +446,10 @@ def decorated(*args, **kwargs):
410446
rv = f(*args, **kwargs)
411447
except oauth2.FatalClientError as e:
412448
log.debug('Fatal client error %r', e, exc_info=True)
413-
return redirect(e.in_uri(self.error_uri))
449+
return self._on_exception(e, e.in_uri(self.error_uri))
414450
except oauth2.OAuth2Error as e:
415451
log.debug('OAuth2Error: %r', e, exc_info=True)
416-
return redirect(e.in_uri(redirect_uri))
452+
return self._on_exception(e, e.in_uri(redirect_uri))
417453

418454
if not isinstance(rv, bool):
419455
# if is a response or redirect
@@ -422,7 +458,7 @@ def decorated(*args, **kwargs):
422458
if not rv:
423459
# denied by user
424460
e = oauth2.AccessDeniedError()
425-
return redirect(e.in_uri(redirect_uri))
461+
return self._on_exception(e, e.in_uri(redirect_uri))
426462
return self.confirm_authorization_request()
427463
return decorated
428464

@@ -449,13 +485,13 @@ def confirm_authorization_request(self):
449485
return create_response(*ret)
450486
except oauth2.FatalClientError as e:
451487
log.debug('Fatal client error %r', e, exc_info=True)
452-
return redirect(e.in_uri(self.error_uri))
488+
return self._on_exception(e, e.in_uri(self.error_uri))
453489
except oauth2.OAuth2Error as e:
454490
log.debug('OAuth2Error: %r', e, exc_info=True)
455-
return redirect(e.in_uri(redirect_uri or self.error_uri))
491+
return self._on_exception(e, e.in_uri(redirect_uri or self.error_uri))
456492
except Exception as e:
457493
log.exception(e)
458-
return redirect(add_params_to_uri(
494+
return self._on_exception(e, add_params_to_uri(
459495
self.error_uri, {'error': str(e)}
460496
))
461497

tests/test_oauth2/test_code.py

+23-1
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
from datetime import datetime, timedelta
44
from .._base import to_base64
5-
from .base import TestCase
5+
from .base import TestCase, default_provider
66
from .base import create_server, sqlalchemy_provider, cache_provider
77
from .base import db, Client, User, Grant
88

@@ -159,3 +159,25 @@ def test_get_token(self):
159159
url += '&client_secret=' + self.oauth_client.client_secret
160160
rv = self.client.get(url)
161161
assert b'access_token' in rv.data
162+
163+
164+
class TestProviderWithExceptionHandler(TestCase):
165+
166+
def prepare_data(self):
167+
oauth = default_provider(self.app)
168+
169+
@oauth.exception_handler
170+
def custom_exception_handler(error, *args):
171+
raise error
172+
173+
@self.app.errorhandler(Exception)
174+
def all_exception_handler(*args):
175+
return "Testing server error", 500
176+
177+
create_server(self.app, oauth=oauth)
178+
179+
def test_exception_handler(self):
180+
rv = self.client.get('/oauth/authorize')
181+
182+
assert rv.status_code == 500
183+
assert rv.data.decode("utf-8") == "Testing server error"

0 commit comments

Comments
 (0)