Skip to content

Commit fc73ba1

Browse files
committed
fix some tests
1 parent 6584f54 commit fc73ba1

File tree

2 files changed

+29
-10
lines changed

2 files changed

+29
-10
lines changed

gql/transport/aiohttp_websockets.py

+28-9
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,7 @@
1616
)
1717

1818
import aiohttp
19-
from aiohttp.client_reqrep import Fingerprint
20-
from aiohttp.helpers import BasicAuth, hdrs
19+
from aiohttp import hdrs, BasicAuth, Fingerprint, WSMsgType
2120
from aiohttp.typedefs import LooseHeaders, StrOrURL
2221
from graphql import DocumentNode, ExecutionResult, print_ast
2322
from multidict import CIMultiDict, CIMultiDictProxy
@@ -32,6 +31,11 @@
3231
)
3332
from gql.transport.websockets_base import ListenerQueue
3433

34+
try:
35+
from json.decoder import JSONDecodeError
36+
except ImportError:
37+
from simplejson import JSONDecodeError
38+
3539
log = logging.getLogger("gql.transport.aiohttp_websockets")
3640

3741

@@ -149,7 +153,7 @@ def __init__(
149153
self.close_exception: Optional[Exception] = None
150154

151155
def _parse_answer_graphqlws(
152-
self, json_answer: Dict[str, Any]
156+
self, answer: Dict[str, Any]
153157
) -> Tuple[str, Optional[int], Optional[ExecutionResult]]:
154158
"""Parse the answer received from the server if the server supports the
155159
graphql-ws protocol.
@@ -175,14 +179,14 @@ def _parse_answer_graphqlws(
175179
execution_result: Optional[ExecutionResult] = None
176180

177181
try:
178-
answer_type = str(json_answer.get("type"))
182+
answer_type = str(answer.get("type"))
179183

180184
if answer_type in ["next", "error", "complete"]:
181-
answer_id = int(str(json_answer.get("id")))
185+
answer_id = int(str(answer.get("id")))
182186

183187
if answer_type == "next" or answer_type == "error":
184188

185-
payload = json_answer.get("payload")
189+
payload = answer.get("payload")
186190

187191
if answer_type == "next":
188192

@@ -213,7 +217,7 @@ def _parse_answer_graphqlws(
213217
)
214218

215219
elif answer_type in ["ping", "pong", "connection_ack"]:
216-
self.payloads[answer_type] = json_answer.get("payload", None)
220+
self.payloads[answer_type] = answer.get("payload", None)
217221

218222
else:
219223
raise ValueError
@@ -223,7 +227,7 @@ def _parse_answer_graphqlws(
223227

224228
except ValueError as e:
225229
raise TransportProtocolError(
226-
f"Server did not return a GraphQL result: {json_answer}"
230+
f"Server did not return a GraphQL result: {answer}"
227231
) from e
228232

229233
return answer_type, answer_id, execution_result
@@ -471,14 +475,27 @@ async def _send(self, message: Dict[str, Any]) -> None:
471475
raise e
472476

473477
async def _receive(self) -> Dict[str, Any]:
478+
log.debug("Entering _receive()")
474479

475480
if self.websocket is None:
476481
raise TransportClosed("WebSocket connection is closed")
477482

478-
answer = await self.websocket.receive_json()
483+
try:
484+
answer = await self.websocket.receive_json()
485+
except TypeError as e:
486+
answer = await self.websocket.receive()
487+
if answer.type in (WSMsgType.CLOSE, WSMsgType.CLOSED, WSMsgType.CLOSING):
488+
self._fail(e, clean_close=True)
489+
raise ConnectionResetError
490+
else:
491+
self._fail(e, clean_close=False)
492+
except JSONDecodeError as e:
493+
self._fail(e)
479494

480495
log.info("<<< %s", answer)
481496

497+
log.debug("Exiting _receive()")
498+
482499
return answer
483500

484501
def _remove_listener(self, query_id) -> None:
@@ -546,6 +563,8 @@ async def _handle_answer(
546563
async def _receive_data_loop(self) -> None:
547564
"""Main asyncio task which will listen to the incoming messages and will
548565
call the parse_answer and handle_answer methods of the subclass."""
566+
log.debug("Entering _receive_data_loop()")
567+
549568
try:
550569
while True:
551570

tests/test_aiohttp_websocket_exceptions.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -250,7 +250,7 @@ async def test_aiohttp_websocket_transport_protocol_errors(
250250

251251
query = gql("query { hello }")
252252

253-
with pytest.raises(TransportProtocolError):
253+
with pytest.raises((TransportProtocolError, TransportQueryError)):
254254
await session.execute(query)
255255

256256

0 commit comments

Comments
 (0)