Skip to content

Commit a8b276d

Browse files
committed
Modify transport init parameters
Adding connect_args parameter to be able to provide any argument to the ws_connect method Removing the following parameters (they can now be provided in the connect_args dict): - autoclose - autoping - compress - max_msg_size - verify_ssl - method Renaming protocols to subprotocols to be more similar to the websockets transport
1 parent 277fd5d commit a8b276d

File tree

3 files changed

+29
-30
lines changed

3 files changed

+29
-30
lines changed

gql/transport/aiohttp_websockets.py

+24-27
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
)
1818

1919
import aiohttp
20-
from aiohttp import BasicAuth, Fingerprint, WSMsgType, hdrs
20+
from aiohttp import BasicAuth, Fingerprint, WSMsgType
2121
from aiohttp.typedefs import LooseHeaders, StrOrURL
2222
from graphql import DocumentNode, ExecutionResult, print_ast
2323
from multidict import CIMultiDictProxy
@@ -110,23 +110,17 @@ def __init__(
110110
self,
111111
url: StrOrURL,
112112
*,
113-
method: str = hdrs.METH_GET,
114-
protocols: Collection[str] = (),
115-
autoclose: bool = True,
116-
autoping: bool = True,
113+
subprotocols: Optional[Collection[str]] = None,
117114
heartbeat: Optional[float] = None,
118115
auth: Optional[BasicAuth] = None,
119116
origin: Optional[str] = None,
120117
params: Optional[Mapping[str, str]] = None,
121118
headers: Optional[LooseHeaders] = None,
122119
proxy: Optional[StrOrURL] = None,
123120
proxy_auth: Optional[BasicAuth] = None,
121+
proxy_headers: Optional[LooseHeaders] = None,
124122
ssl: Optional[Union[SSLContext, Literal[False], Fingerprint]] = None,
125123
ssl_context: Optional[SSLContext] = None,
126-
verify_ssl: Optional[bool] = True,
127-
proxy_headers: Optional[LooseHeaders] = None,
128-
compress: int = 0,
129-
max_msg_size: int = 4 * 1024 * 1024,
130124
websocket_close_timeout: float = 10.0,
131125
receive_timeout: Optional[float] = None,
132126
ssl_close_timeout: Optional[Union[int, float]] = 10,
@@ -139,32 +133,31 @@ def __init__(
139133
pong_timeout: Optional[Union[int, float]] = None,
140134
answer_pings: bool = True,
141135
client_session_args: Optional[Dict[str, Any]] = None,
136+
connect_args: Dict[str, Any] = {},
142137
) -> None:
143138
self.url: StrOrURL = url
144-
self.headers: Optional[LooseHeaders] = headers
145-
self.auth: Optional[BasicAuth] = auth
146-
self.autoclose: bool = autoclose
147-
self.autoping: bool = autoping
148-
self.compress: int = compress
149139
self.heartbeat: Optional[float] = heartbeat
150-
self.max_msg_size: int = max_msg_size
151-
self.method: str = method
140+
self.auth: Optional[BasicAuth] = auth
152141
self.origin: Optional[str] = origin
153142
self.params: Optional[Mapping[str, str]] = params
154-
self.protocols: Collection[str] = protocols
143+
self.headers: Optional[LooseHeaders] = headers
144+
155145
self.proxy: Optional[StrOrURL] = proxy
156146
self.proxy_auth: Optional[BasicAuth] = proxy_auth
157147
self.proxy_headers: Optional[LooseHeaders] = proxy_headers
158-
self.ssl_close_timeout: Optional[Union[int, float]] = ssl_close_timeout
148+
159149
self.ssl: Optional[Union[SSLContext, Literal[False], Fingerprint]] = ssl
160150
self.ssl_context: Optional[SSLContext] = ssl_context
151+
161152
self.websocket_close_timeout: float = websocket_close_timeout
162153
self.receive_timeout: Optional[float] = receive_timeout
154+
155+
self.ssl_close_timeout: Optional[Union[int, float]] = ssl_close_timeout
163156
self.connect_timeout: Optional[Union[int, float]] = connect_timeout
164157
self.close_timeout: Optional[Union[int, float]] = close_timeout
165158
self.ack_timeout: Optional[Union[int, float]] = ack_timeout
166159
self.keep_alive_timeout: Optional[Union[int, float]] = keep_alive_timeout
167-
self.verify_ssl: Optional[bool] = verify_ssl
160+
168161
self.init_payload: Dict[str, Any] = init_payload
169162

170163
# We need to set an event loop here if there is none
@@ -221,12 +214,15 @@ def __init__(
221214
"""pong_received is an asyncio Event which will fire each time
222215
a pong is received with the graphql-ws protocol"""
223216

224-
self.supported_subprotocols: Collection[str] = protocols or (
217+
self.supported_subprotocols: Collection[str] = subprotocols or (
225218
self.APOLLO_SUBPROTOCOL,
226219
self.GRAPHQLWS_SUBPROTOCOL,
227220
)
221+
228222
self.close_exception: Optional[Exception] = None
223+
229224
self.client_session_args = client_session_args
225+
self.connect_args = connect_args
230226

231227
def _parse_answer_graphqlws(
232228
self, answer: Dict[str, Any]
@@ -782,28 +778,29 @@ async def connect(self) -> None:
782778
if self.websocket is None and not self._connecting:
783779
self._connecting = True
784780

781+
connect_args: Dict[str, Any] = {}
782+
783+
# Adding custom parameters passed from init
784+
if self.connect_args:
785+
connect_args.update(self.connect_args)
786+
785787
try:
786788
self.websocket = await self.session.ws_connect(
787-
method=self.method,
788789
url=self.url,
789790
headers=self.headers,
790791
auth=self.auth,
791-
autoclose=self.autoclose,
792-
autoping=self.autoping,
793-
compress=self.compress,
794792
heartbeat=self.heartbeat,
795-
max_msg_size=self.max_msg_size,
796793
origin=self.origin,
797794
params=self.params,
798795
protocols=self.supported_subprotocols,
799796
proxy=self.proxy,
800797
proxy_auth=self.proxy_auth,
801798
proxy_headers=self.proxy_headers,
799+
timeout=self.websocket_close_timeout,
802800
receive_timeout=self.receive_timeout,
803801
ssl=self.ssl,
804802
ssl_context=None,
805-
timeout=self.websocket_close_timeout,
806-
verify_ssl=self.verify_ssl,
803+
**connect_args,
807804
)
808805
finally:
809806
self._connecting = False

tests/conftest.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -516,7 +516,7 @@ async def client_and_aiohttp_websocket_graphql_server(graphqlws_server):
516516
url = f"ws://{graphqlws_server.hostname}:{graphqlws_server.port}{path}"
517517
sample_transport = AIOHTTPWebsocketsTransport(
518518
url=url,
519-
protocols=[AIOHTTPWebsocketsTransport.GRAPHQLWS_SUBPROTOCOL],
519+
subprotocols=[AIOHTTPWebsocketsTransport.GRAPHQLWS_SUBPROTOCOL],
520520
)
521521

522522
async with Client(transport=sample_transport) as session:

tests/test_aiohttp_websocket_query.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -499,10 +499,12 @@ async def test_aiohttp_websocket_add_extra_parameters_to_connect(event_loop, ser
499499

500500
url = f"ws://{server.hostname}:{server.port}/graphql"
501501

502-
# Increase max payload size to avoid websockets.exceptions.PayloadTooBig exceptions
502+
# Increase max payload size
503503
transport = AIOHTTPWebsocketsTransport(
504504
url=url,
505-
max_msg_size=(2**21),
505+
connect_args={
506+
"max_msg_size": 2**21,
507+
},
506508
)
507509

508510
query = gql(query1_str)

0 commit comments

Comments
 (0)