@@ -69,11 +69,12 @@ def user():
69
69
return jsonify(request.oauth.user)
70
70
"""
71
71
72
- def __init__ (self , app = None ):
72
+ def __init__ (self , app = None , validator_class = None ):
73
73
self ._before_request_funcs = []
74
74
self ._after_request_funcs = []
75
75
self ._exception_handler = None
76
76
self ._invalid_response = None
77
+ self ._validator_class = validator_class
77
78
if app :
78
79
self .init_app (app )
79
80
@@ -163,7 +164,10 @@ def validate_client_id(self, client_id):
163
164
if hasattr (self , '_usergetter' ):
164
165
usergetter = self ._usergetter
165
166
166
- validator = OAuth2RequestValidator (
167
+ validator_class = self ._validator_class
168
+ if validator_class is None :
169
+ validator_class = OAuth2RequestValidator
170
+ validator = validator_class (
167
171
clientgetter = self ._clientgetter ,
168
172
tokengetter = self ._tokengetter ,
169
173
grantgetter = self ._grantgetter ,
@@ -430,7 +434,12 @@ def decorated(*args, **kwargs):
430
434
return self ._on_exception (e , e .in_uri (self .error_uri ))
431
435
except oauth2 .OAuth2Error as e :
432
436
log .debug ('OAuth2Error: %r' , e , exc_info = True )
437
+ # on auth error, we should preserve state if it's present according to RFC 6749
438
+ state = request .values .get ('state' )
439
+ if state and not e .state :
440
+ e .state = state # set e.state so e.in_uri() can add the state query parameter to redirect uri
433
441
return self ._on_exception (e , e .in_uri (redirect_uri ))
442
+
434
443
except Exception as e :
435
444
log .exception (e )
436
445
return self ._on_exception (e , add_params_to_uri (
@@ -449,6 +458,10 @@ def decorated(*args, **kwargs):
449
458
return self ._on_exception (e , e .in_uri (self .error_uri ))
450
459
except oauth2 .OAuth2Error as e :
451
460
log .debug ('OAuth2Error: %r' , e , exc_info = True )
461
+ # on auth error, we should preserve state if it's present according to RFC 6749
462
+ state = request .values .get ('state' )
463
+ if state and not e .state :
464
+ e .state = state # set e.state so e.in_uri() can add the state query parameter to redirect uri
452
465
return self ._on_exception (e , e .in_uri (redirect_uri ))
453
466
454
467
if not isinstance (rv , bool ):
@@ -457,8 +470,9 @@ def decorated(*args, **kwargs):
457
470
458
471
if not rv :
459
472
# denied by user
460
- e = oauth2 .AccessDeniedError ()
473
+ e = oauth2 .AccessDeniedError (state = request . values . get ( 'state' ) )
461
474
return self ._on_exception (e , e .in_uri (redirect_uri ))
475
+
462
476
return self .confirm_authorization_request ()
463
477
return decorated
464
478
@@ -488,6 +502,11 @@ def confirm_authorization_request(self):
488
502
return self ._on_exception (e , e .in_uri (self .error_uri ))
489
503
except oauth2 .OAuth2Error as e :
490
504
log .debug ('OAuth2Error: %r' , e , exc_info = True )
505
+
506
+ # on auth error, we should preserve state if it's present according to RFC 6749
507
+ state = request .values .get ('state' )
508
+ if state and not e .state :
509
+ e .state = state # set e.state so e.in_uri() can add the state query parameter to redirect uri
491
510
return self ._on_exception (e , e .in_uri (redirect_uri or self .error_uri ))
492
511
except Exception as e :
493
512
log .exception (e )
0 commit comments