diff --git a/.travis.yml b/.travis.yml index 5a12f0a7..dc65c268 100644 --- a/.travis.yml +++ b/.travis.yml @@ -24,7 +24,7 @@ python: 3.7 before_install: - pip install pycodestyle - - pycodestyle --exclude=".git,.tox,__pycache,venv,.eggs,build,jupyterhub_config.py" + - pycodestyle --exclude=".git,.tox,__pycache,venv,.eggs,build,jupyterhub_config.py,cylc/uiserver/tornado_ws.py" install: - pip install .[all] script: diff --git a/cylc/uiserver/tornado_ws.py b/cylc/uiserver/tornado_ws.py new file mode 100644 index 00000000..1ecc682c --- /dev/null +++ b/cylc/uiserver/tornado_ws.py @@ -0,0 +1,120 @@ +# This file is a temporary solution for subscriptions with graphql_ws and +# Tornado, from the following pending PR to graphql-ws: +# https://github.com/graphql-python/graphql-ws/pull/25/files +# The file was copied from this revision: +# https://github.com/graphql-python/graphql-ws/blob/cf560b9a5d18d4a3908dc2cfe2199766cc988fef/graphql_ws/tornado.py + +from inspect import isawaitable + +from asyncio import ensure_future, wait, shield +from tornado.websocket import WebSocketClosedError +from graphql.execution.executors.asyncio import AsyncioExecutor + +from graphql_ws.base import ConnectionClosedException, BaseConnectionContext, BaseSubscriptionServer +from graphql_ws.observable_aiter import setup_observable_extension + +from graphql_ws.constants import ( + GQL_CONNECTION_ACK, + GQL_CONNECTION_ERROR, + GQL_COMPLETE +) + +setup_observable_extension() + + +class TornadoConnectionContext(BaseConnectionContext): + async def receive(self): + try: + msg = await self.ws.recv() + return msg + except WebSocketClosedError: + raise ConnectionClosedException() + + async def send(self, data): + if self.closed: + return + await self.ws.write_message(data) + + @property + def closed(self): + return self.ws.close_code is not None + + async def close(self, code): + await self.ws.close(code) + + +class TornadoSubscriptionServer(BaseSubscriptionServer): + def __init__(self, schema, keep_alive=True, loop=None): + self.loop = loop + super().__init__(schema, keep_alive) + + def get_graphql_params(self, *args, **kwargs): + params = super(TornadoSubscriptionServer, + self).get_graphql_params(*args, **kwargs) + return dict(params, return_promise=True, executor=AsyncioExecutor(loop=self.loop)) + + async def _handle(self, ws, request_context): + connection_context = TornadoConnectionContext(ws, request_context) + await self.on_open(connection_context) + pending = set() + while True: + try: + if connection_context.closed: + raise ConnectionClosedException() + message = await connection_context.receive() + except ConnectionClosedException: + break + finally: + if pending: + (_, pending) = await wait(pending, timeout=0, loop=self.loop) + + task = ensure_future( + self.on_message(connection_context, message), loop=self.loop) + pending.add(task) + + self.on_close(connection_context) + for task in pending: + task.cancel() + + async def handle(self, ws, request_context=None): + await shield(self._handle(ws, request_context), loop=self.loop) + + async def on_open(self, connection_context): + pass + + def on_close(self, connection_context): + remove_operations = list(connection_context.operations.keys()) + for op_id in remove_operations: + self.unsubscribe(connection_context, op_id) + + async def on_connect(self, connection_context, payload): + pass + + async def on_connection_init(self, connection_context, op_id, payload): + try: + await self.on_connect(connection_context, payload) + await self.send_message(connection_context, op_type=GQL_CONNECTION_ACK) + except Exception as e: + await self.send_error(connection_context, op_id, e, GQL_CONNECTION_ERROR) + await connection_context.close(1011) + + async def on_start(self, connection_context, op_id, params): + execution_result = self.execute( + connection_context.request_context, params) + + if isawaitable(execution_result): + execution_result = await execution_result + + if not hasattr(execution_result, '__aiter__'): + await self.send_execution_result(connection_context, op_id, execution_result) + else: + iterator = await execution_result.__aiter__() + connection_context.register_operation(op_id, iterator) + async for single_result in iterator: + if not connection_context.has_operation(op_id): + break + await self.send_execution_result(connection_context, op_id, single_result) + await self.send_message(connection_context, op_id, GQL_COMPLETE) + + async def on_stop(self, connection_context, op_id): + self.unsubscribe(connection_context, op_id) diff --git a/setup.py b/setup.py index a80e8b38..f59dd1b1 100644 --- a/setup.py +++ b/setup.py @@ -41,7 +41,8 @@ def find_version(*file_paths): 'jupyterhub==1.0.*', 'tornado==6.0.*', 'graphene-tornado==2.1.*', - 'cylc-flow==8.0a1' + 'cylc-flow==8.0a1', + 'graphql-ws==0.3.*' ] setup_requires = [