diff --git a/jwt/jws.py b/jwt/jws.py index 9b24baa..dca1680 100644 --- a/jwt/jws.py +++ b/jwt/jws.py @@ -18,6 +18,7 @@ from typing import ( AbstractSet, Tuple, + Callable, ) from .exceptions import ( @@ -49,7 +50,8 @@ def _retrieve_alg(self, alg: str) -> AbstractSigningAlgorithm: raise JWSDecodeError('Unsupported signing algorithm.') def encode(self, message: bytes, key: AbstractJWKBase = None, alg='HS256', - optional_headers: dict = None) -> str: + optional_headers: dict = None, + dumps: Callable = json.dumps) -> str: if alg not in self._supported_algs: # pragma: no cover raise JWSEncodeError('unsupported algorithm: {}'.format(alg)) alg_impl = self._retrieve_alg(alg) @@ -58,7 +60,7 @@ def encode(self, message: bytes, key: AbstractJWKBase = None, alg='HS256', header['alg'] = alg header_b64 = b64encode( - json.dumps(header, separators=(',', ':')).encode('ascii')) + dumps(header, separators=(',', ':')).encode('ascii')) message_b64 = b64encode(message) signing_message = header_b64 + '.' + message_b64 @@ -67,25 +69,27 @@ def encode(self, message: bytes, key: AbstractJWKBase = None, alg='HS256', return signing_message + '.' + signature_b64 - def _decode_segments(self, message: str) -> Tuple[dict, bytes, bytes, str]: + def _decode_segments(self, message: str, + loads: Callable) -> Tuple[dict, bytes, bytes, str]: try: signing_message, signature_b64 = message.rsplit('.', 1) header_b64, message_b64 = signing_message.split('.') except ValueError: raise JWSDecodeError('malformed JWS payload') - header = json.loads(b64decode(header_b64).decode('ascii')) + header = loads(b64decode(header_b64).decode('ascii')) message_bin = b64decode(message_b64) signature = b64decode(signature_b64) return header, message_bin, signature, signing_message def decode(self, message: str, key: AbstractJWKBase = None, - do_verify=True, algorithms: AbstractSet[str]=None) -> bytes: + do_verify=True, algorithms: AbstractSet[str]=None, + loads: Callable = json.loads) -> bytes: if algorithms is None: algorithms = set(supported_signing_algorithms().keys()) header, message_bin, signature, signing_message = \ - self._decode_segments(message) + self._decode_segments(message, loads=loads) alg_value = header['alg'] if alg_value not in algorithms: diff --git a/jwt/jwt.py b/jwt/jwt.py index 5985d02..60282cb 100644 --- a/jwt/jwt.py +++ b/jwt/jwt.py @@ -15,7 +15,7 @@ # limitations under the License. import json -from typing import AbstractSet +from typing import AbstractSet, Callable from .exceptions import ( JWSEncodeError, @@ -33,27 +33,31 @@ def __init__(self): self._jws = JWS() def encode(self, payload: dict, key: AbstractJWKBase = None, alg='HS256', - optional_headers: dict = None) -> str: + optional_headers: dict = None, + dumps: Callable = json.dumps) -> str: try: - message = json.dumps(payload).encode('utf-8') + message = dumps(payload).encode('utf-8') except ValueError as why: raise JWTEncodeError('payload must be able to encode in JSON') optional_headers = optional_headers and optional_headers.copy() or {} optional_headers['typ'] = 'JWT' try: - return self._jws.encode(message, key, alg, optional_headers) + return self._jws.encode(message, key, alg, optional_headers, + dumps=dumps) except JWSEncodeError as why: raise JWTEncodeError('failed to encode to JWT') from why def decode(self, message: str, key: AbstractJWKBase = None, - do_verify=True, algorithms: AbstractSet[str]=None) -> dict: + do_verify=True, algorithms: AbstractSet[str]=None, + loads: Callable = json.loads) -> dict: try: - message_bin = self._jws.decode(message, key, do_verify, algorithms) + message_bin = self._jws.decode(message, key, do_verify, algorithms, + loads=loads) except JWSDecodeError as why: raise JWTDecodeError('failed to decode JWT') from why try: - payload = json.loads(message_bin.decode('utf-8')) + payload = loads(message_bin.decode('utf-8')) return payload except ValueError as why: raise JWTDecodeError(