16
16
)
17
17
18
18
import aiohttp
19
- from aiohttp .client_reqrep import Fingerprint
20
- from aiohttp .helpers import BasicAuth , hdrs
19
+ from aiohttp import hdrs , BasicAuth , Fingerprint , WSMsgType
21
20
from aiohttp .typedefs import LooseHeaders , StrOrURL
22
21
from graphql import DocumentNode , ExecutionResult , print_ast
23
22
from multidict import CIMultiDict , CIMultiDictProxy
32
31
)
33
32
from gql .transport .websockets_base import ListenerQueue
34
33
34
+ try :
35
+ from json .decoder import JSONDecodeError
36
+ except ImportError :
37
+ from simplejson import JSONDecodeError
38
+
35
39
log = logging .getLogger ("gql.transport.aiohttp_websockets" )
36
40
37
41
@@ -149,7 +153,7 @@ def __init__(
149
153
self .close_exception : Optional [Exception ] = None
150
154
151
155
def _parse_answer_graphqlws (
152
- self , json_answer : Dict [str , Any ]
156
+ self , answer : Dict [str , Any ]
153
157
) -> Tuple [str , Optional [int ], Optional [ExecutionResult ]]:
154
158
"""Parse the answer received from the server if the server supports the
155
159
graphql-ws protocol.
@@ -175,14 +179,14 @@ def _parse_answer_graphqlws(
175
179
execution_result : Optional [ExecutionResult ] = None
176
180
177
181
try :
178
- answer_type = str (json_answer .get ("type" ))
182
+ answer_type = str (answer .get ("type" ))
179
183
180
184
if answer_type in ["next" , "error" , "complete" ]:
181
- answer_id = int (str (json_answer .get ("id" )))
185
+ answer_id = int (str (answer .get ("id" )))
182
186
183
187
if answer_type == "next" or answer_type == "error" :
184
188
185
- payload = json_answer .get ("payload" )
189
+ payload = answer .get ("payload" )
186
190
187
191
if answer_type == "next" :
188
192
@@ -213,7 +217,7 @@ def _parse_answer_graphqlws(
213
217
)
214
218
215
219
elif answer_type in ["ping" , "pong" , "connection_ack" ]:
216
- self .payloads [answer_type ] = json_answer .get ("payload" , None )
220
+ self .payloads [answer_type ] = answer .get ("payload" , None )
217
221
218
222
else :
219
223
raise ValueError
@@ -223,7 +227,7 @@ def _parse_answer_graphqlws(
223
227
224
228
except ValueError as e :
225
229
raise TransportProtocolError (
226
- f"Server did not return a GraphQL result: { json_answer } "
230
+ f"Server did not return a GraphQL result: { answer } "
227
231
) from e
228
232
229
233
return answer_type , answer_id , execution_result
@@ -471,14 +475,27 @@ async def _send(self, message: Dict[str, Any]) -> None:
471
475
raise e
472
476
473
477
async def _receive (self ) -> Dict [str , Any ]:
478
+ log .debug ("Entering _receive()" )
474
479
475
480
if self .websocket is None :
476
481
raise TransportClosed ("WebSocket connection is closed" )
477
482
478
- answer = await self .websocket .receive_json ()
483
+ try :
484
+ answer = await self .websocket .receive_json ()
485
+ except TypeError as e :
486
+ answer = await self .websocket .receive ()
487
+ if answer .type in (WSMsgType .CLOSE , WSMsgType .CLOSED , WSMsgType .CLOSING ):
488
+ self ._fail (e , clean_close = True )
489
+ raise ConnectionResetError
490
+ else :
491
+ self ._fail (e , clean_close = False )
492
+ except JSONDecodeError as e :
493
+ self ._fail (e )
479
494
480
495
log .info ("<<< %s" , answer )
481
496
497
+ log .debug ("Exiting _receive()" )
498
+
482
499
return answer
483
500
484
501
def _remove_listener (self , query_id ) -> None :
@@ -546,6 +563,8 @@ async def _handle_answer(
546
563
async def _receive_data_loop (self ) -> None :
547
564
"""Main asyncio task which will listen to the incoming messages and will
548
565
call the parse_answer and handle_answer methods of the subclass."""
566
+ log .debug ("Entering _receive_data_loop()" )
567
+
549
568
try :
550
569
while True :
551
570
0 commit comments