From b5ceeab13ebc3f75c994ff8f500a407849d3a7db Mon Sep 17 00:00:00 2001 From: Leszek Hanusz <leszek.hanusz@gmail.com> Date: Thu, 8 Feb 2024 14:43:59 +0100 Subject: [PATCH] Adding json_unserialize attribute the httpx transport --- gql/transport/httpx.py | 10 +++++++- tests/test_httpx_async.py | 50 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 59 insertions(+), 1 deletion(-) diff --git a/gql/transport/httpx.py b/gql/transport/httpx.py index cfc25dc9..a328127f 100644 --- a/gql/transport/httpx.py +++ b/gql/transport/httpx.py @@ -38,6 +38,7 @@ def __init__( self, url: Union[str, httpx.URL], json_serialize: Callable = json.dumps, + json_unserialize: Callable = json.loads, **kwargs, ): """Initialize the transport with the given httpx parameters. @@ -45,10 +46,13 @@ def __init__( :param url: The GraphQL server URL. Example: 'https://server.com:PORT/path'. :param json_serialize: Json serializer callable. By default json.dumps() function. + :param json_unserialize: Json unserializer callable. + By default json.loads() function. :param kwargs: Extra args passed to the `httpx` client. """ self.url = url self.json_serialize = json_serialize + self.json_unserialize = json_unserialize self.kwargs = kwargs def _prepare_request( @@ -145,7 +149,11 @@ def _prepare_result(self, response: httpx.Response) -> ExecutionResult: log.debug("<<< %s", response.text) try: - result: Dict[str, Any] = response.json() + result: Dict[str, Any] + if self.json_unserialize == json.loads: + result = response.json() + else: + result = self.json_unserialize(response.content) except Exception: self._raise_response_error(response, "Not a JSON answer") diff --git a/tests/test_httpx_async.py b/tests/test_httpx_async.py index e5be73ec..e8350ad1 100644 --- a/tests/test_httpx_async.py +++ b/tests/test_httpx_async.py @@ -1389,3 +1389,53 @@ async def handler(request): # Checking that there is no space after the colon in the log expected_log = '"query":"query getContinents' assert expected_log in caplog.text + + +query_float_str = """ + query getPi { + pi + } +""" + +query_float_server_answer_data = '{"pi": 3.141592653589793238462643383279502884197}' + +query_float_server_answer = f'{{"data":{query_float_server_answer_data}}}' + + +@pytest.mark.asyncio +async def test_httpx_json_unserializer(event_loop, aiohttp_server): + from aiohttp import web + from decimal import Decimal + from functools import partial + from gql.transport.httpx import HTTPXAsyncTransport + + async def handler(request): + return web.Response( + text=query_float_server_answer, + content_type="application/json", + ) + + app = web.Application() + app.router.add_route("POST", "/", handler) + server = await aiohttp_server(app) + + url = str(server.make_url("/")) + + json_loads = partial(json.loads, parse_float=Decimal) + + transport = HTTPXAsyncTransport( + url=url, + timeout=10, + json_unserialize=json_loads, + ) + + async with Client(transport=transport) as session: + + query = gql(query_float_str) + + # Execute query asynchronously + result = await session.execute(query) + + pi = result["pi"] + + assert pi == Decimal("3.141592653589793238462643383279502884197")