From b5ceeab13ebc3f75c994ff8f500a407849d3a7db Mon Sep 17 00:00:00 2001
From: Leszek Hanusz <leszek.hanusz@gmail.com>
Date: Thu, 8 Feb 2024 14:43:59 +0100
Subject: [PATCH] Adding json_unserialize attribute the httpx transport

---
 gql/transport/httpx.py    | 10 +++++++-
 tests/test_httpx_async.py | 50 +++++++++++++++++++++++++++++++++++++++
 2 files changed, 59 insertions(+), 1 deletion(-)

diff --git a/gql/transport/httpx.py b/gql/transport/httpx.py
index cfc25dc9..a328127f 100644
--- a/gql/transport/httpx.py
+++ b/gql/transport/httpx.py
@@ -38,6 +38,7 @@ def __init__(
         self,
         url: Union[str, httpx.URL],
         json_serialize: Callable = json.dumps,
+        json_unserialize: Callable = json.loads,
         **kwargs,
     ):
         """Initialize the transport with the given httpx parameters.
@@ -45,10 +46,13 @@ def __init__(
         :param url: The GraphQL server URL. Example: 'https://server.com:PORT/path'.
         :param json_serialize: Json serializer callable.
                 By default json.dumps() function.
+        :param json_unserialize: Json unserializer callable.
+                By default json.loads() function.
         :param kwargs: Extra args passed to the `httpx` client.
         """
         self.url = url
         self.json_serialize = json_serialize
+        self.json_unserialize = json_unserialize
         self.kwargs = kwargs
 
     def _prepare_request(
@@ -145,7 +149,11 @@ def _prepare_result(self, response: httpx.Response) -> ExecutionResult:
             log.debug("<<< %s", response.text)
 
         try:
-            result: Dict[str, Any] = response.json()
+            result: Dict[str, Any]
+            if self.json_unserialize == json.loads:
+                result = response.json()
+            else:
+                result = self.json_unserialize(response.content)
 
         except Exception:
             self._raise_response_error(response, "Not a JSON answer")
diff --git a/tests/test_httpx_async.py b/tests/test_httpx_async.py
index e5be73ec..e8350ad1 100644
--- a/tests/test_httpx_async.py
+++ b/tests/test_httpx_async.py
@@ -1389,3 +1389,53 @@ async def handler(request):
     # Checking that there is no space after the colon in the log
     expected_log = '"query":"query getContinents'
     assert expected_log in caplog.text
+
+
+query_float_str = """
+    query getPi {
+      pi
+    }
+"""
+
+query_float_server_answer_data = '{"pi": 3.141592653589793238462643383279502884197}'
+
+query_float_server_answer = f'{{"data":{query_float_server_answer_data}}}'
+
+
+@pytest.mark.asyncio
+async def test_httpx_json_unserializer(event_loop, aiohttp_server):
+    from aiohttp import web
+    from decimal import Decimal
+    from functools import partial
+    from gql.transport.httpx import HTTPXAsyncTransport
+
+    async def handler(request):
+        return web.Response(
+            text=query_float_server_answer,
+            content_type="application/json",
+        )
+
+    app = web.Application()
+    app.router.add_route("POST", "/", handler)
+    server = await aiohttp_server(app)
+
+    url = str(server.make_url("/"))
+
+    json_loads = partial(json.loads, parse_float=Decimal)
+
+    transport = HTTPXAsyncTransport(
+        url=url,
+        timeout=10,
+        json_unserialize=json_loads,
+    )
+
+    async with Client(transport=transport) as session:
+
+        query = gql(query_float_str)
+
+        # Execute query asynchronously
+        result = await session.execute(query)
+
+        pi = result["pi"]
+
+        assert pi == Decimal("3.141592653589793238462643383279502884197")