Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for custom json serializer #23

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 10 additions & 6 deletions jwt/jws.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from typing import (
AbstractSet,
Tuple,
Callable,
)

from .exceptions import (
Expand Down Expand Up @@ -49,7 +50,8 @@ def _retrieve_alg(self, alg: str) -> AbstractSigningAlgorithm:
raise JWSDecodeError('Unsupported signing algorithm.')

def encode(self, message: bytes, key: AbstractJWKBase = None, alg='HS256',
optional_headers: dict = None) -> str:
optional_headers: dict = None,
dumps: Callable = json.dumps) -> str:
if alg not in self._supported_algs: # pragma: no cover
raise JWSEncodeError('unsupported algorithm: {}'.format(alg))
alg_impl = self._retrieve_alg(alg)
Expand All @@ -58,7 +60,7 @@ def encode(self, message: bytes, key: AbstractJWKBase = None, alg='HS256',
header['alg'] = alg

header_b64 = b64encode(
json.dumps(header, separators=(',', ':')).encode('ascii'))
dumps(header, separators=(',', ':')).encode('ascii'))
message_b64 = b64encode(message)
signing_message = header_b64 + '.' + message_b64

Expand All @@ -67,25 +69,27 @@ def encode(self, message: bytes, key: AbstractJWKBase = None, alg='HS256',

return signing_message + '.' + signature_b64

def _decode_segments(self, message: str) -> Tuple[dict, bytes, bytes, str]:
def _decode_segments(self, message: str,
loads: Callable) -> Tuple[dict, bytes, bytes, str]:
try:
signing_message, signature_b64 = message.rsplit('.', 1)
header_b64, message_b64 = signing_message.split('.')
except ValueError:
raise JWSDecodeError('malformed JWS payload')

header = json.loads(b64decode(header_b64).decode('ascii'))
header = loads(b64decode(header_b64).decode('ascii'))
message_bin = b64decode(message_b64)
signature = b64decode(signature_b64)
return header, message_bin, signature, signing_message

def decode(self, message: str, key: AbstractJWKBase = None,
do_verify=True, algorithms: AbstractSet[str]=None) -> bytes:
do_verify=True, algorithms: AbstractSet[str]=None,
loads: Callable = json.loads) -> bytes:
if algorithms is None:
algorithms = set(supported_signing_algorithms().keys())

header, message_bin, signature, signing_message = \
self._decode_segments(message)
self._decode_segments(message, loads=loads)

alg_value = header['alg']
if alg_value not in algorithms:
Expand Down
18 changes: 11 additions & 7 deletions jwt/jwt.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
# limitations under the License.

import json
from typing import AbstractSet
from typing import AbstractSet, Callable

from .exceptions import (
JWSEncodeError,
Expand All @@ -33,27 +33,31 @@ def __init__(self):
self._jws = JWS()

def encode(self, payload: dict, key: AbstractJWKBase = None, alg='HS256',
optional_headers: dict = None) -> str:
optional_headers: dict = None,
dumps: Callable = json.dumps) -> str:
try:
message = json.dumps(payload).encode('utf-8')
message = dumps(payload).encode('utf-8')
except ValueError as why:
raise JWTEncodeError('payload must be able to encode in JSON')

optional_headers = optional_headers and optional_headers.copy() or {}
optional_headers['typ'] = 'JWT'
try:
return self._jws.encode(message, key, alg, optional_headers)
return self._jws.encode(message, key, alg, optional_headers,
dumps=dumps)
except JWSEncodeError as why:
raise JWTEncodeError('failed to encode to JWT') from why

def decode(self, message: str, key: AbstractJWKBase = None,
do_verify=True, algorithms: AbstractSet[str]=None) -> dict:
do_verify=True, algorithms: AbstractSet[str]=None,
loads: Callable = json.loads) -> dict:
try:
message_bin = self._jws.decode(message, key, do_verify, algorithms)
message_bin = self._jws.decode(message, key, do_verify, algorithms,
loads=loads)
except JWSDecodeError as why:
raise JWTDecodeError('failed to decode JWT') from why
try:
payload = json.loads(message_bin.decode('utf-8'))
payload = loads(message_bin.decode('utf-8'))
return payload
except ValueError as why:
raise JWTDecodeError(
Expand Down