Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add sync batching to requests sync transport #431

Merged
merged 15 commits into from
Sep 5, 2023
Merged
1 change: 0 additions & 1 deletion docs/code_examples/console_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
import logging

from aioconsole import ainput

from gql import Client, gql
from gql.transport.aiohttp import AIOHTTPTransport

Expand Down
1 change: 0 additions & 1 deletion docs/code_examples/fastapi_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@

from fastapi import FastAPI, HTTPException
from fastapi.responses import HTMLResponse

from gql import Client, gql
from gql.transport.aiohttp import AIOHTTPTransport

Expand Down
178 changes: 169 additions & 9 deletions gql/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
Callable,
Dict,
Generator,
List,
Optional,
TypeVar,
Union,
Expand All @@ -27,6 +28,8 @@
validate,
)

from gql.transport.data_structures.graphql_request import GraphQLRequest

from .transport.async_transport import AsyncTransport
from .transport.exceptions import TransportClosed, TransportQueryError
from .transport.local_schema import LocalSchemaTransport
Expand Down Expand Up @@ -236,6 +239,24 @@ def execute_sync(
**kwargs,
)

def execute_batch_sync(
self,
reqs: List[GraphQLRequest],
serialize_variables: Optional[bool] = None,
parse_result: Optional[bool] = None,
get_execution_result: bool = False,
**kwargs,
) -> Union[List[Dict[str, Any]], List[ExecutionResult]]:
""":meta private:"""
with self as session:
return session.execute_batch(
reqs,
serialize_variables=serialize_variables,
parse_result=parse_result,
get_execution_result=get_execution_result,
**kwargs,
)

@overload
async def execute_async(
self,
Expand Down Expand Up @@ -375,7 +396,6 @@ def execute(
"""

if isinstance(self.transport, AsyncTransport):

# Get the current asyncio event loop
# Or create a new event loop if there isn't one (in a new Thread)
try:
Expand Down Expand Up @@ -418,6 +438,48 @@ def execute(
**kwargs,
)

def execute_batch(
self,
reqs: List[GraphQLRequest],
serialize_variables: Optional[bool] = None,
parse_result: Optional[bool] = None,
get_execution_result: bool = False,
**kwargs,
) -> Union[List[Dict[str, Any]], List[ExecutionResult]]:
"""Execute the provided document AST against the remote server using
the transport provided during init.

This function **WILL BLOCK** until the result is received from the server.

Either the transport is sync and we execute the query synchronously directly
OR the transport is async and we execute the query in the asyncio loop
(blocking here until answer).

This method will:

- connect using the transport to get a session
- execute the GraphQL request on the transport session
- close the session and close the connection to the server

If you have multiple requests to send, it is better to get your own session
and execute the requests in your session.

The extra arguments passed in the method will be passed to the transport
execute method.
"""

if isinstance(self.transport, AsyncTransport):
raise NotImplementedError("Batching is not implemented for async yet.")

else: # Sync transports
return self.execute_batch_sync(
reqs,
serialize_variables=serialize_variables,
parse_result=parse_result,
get_execution_result=get_execution_result,
**kwargs,
)

@overload
def subscribe_async(
self,
Expand Down Expand Up @@ -476,7 +538,6 @@ async def subscribe_async(
]:
""":meta private:"""
async with self as session:

generator = session.subscribe(
document,
variable_values=variable_values,
Expand Down Expand Up @@ -600,7 +661,6 @@ def subscribe(
pass

except (KeyboardInterrupt, Exception, GeneratorExit):

# Graceful shutdown
asyncio.ensure_future(async_generator.aclose(), loop=loop)

Expand Down Expand Up @@ -661,11 +721,9 @@ async def close_async(self):
await self.transport.close()

async def __aenter__(self):

return await self.connect_async()

async def __aexit__(self, exc_type, exc, tb):

await self.close_async()

def connect_sync(self):
Expand Down Expand Up @@ -705,7 +763,6 @@ def close_sync(self):
self.transport.close()

def __enter__(self):

return self.connect_sync()

def __exit__(self, *args):
Expand Down Expand Up @@ -880,6 +937,112 @@ def execute(

return result.data

def _execute_batch(
self,
reqs: List[GraphQLRequest],
serialize_variables: Optional[bool] = None,
parse_result: Optional[bool] = None,
**kwargs,
) -> List[ExecutionResult]:
"""Execute the provided document AST synchronously using
the sync transport, returning an ExecutionResult object.

:param document: GraphQL query as AST Node object.
:param variable_values: Dictionary of input parameters.
:param operation_name: Name of the operation that shall be executed.
:param serialize_variables: whether the variable values should be
serialized. Used for custom scalars and/or enums.
By default use the serialize_variables argument of the client.
:param parse_result: Whether gql will unserialize the result.
By default use the parse_results argument of the client.

The extra arguments are passed to the transport execute method."""

# Validate document
if self.client.schema:
for req in reqs:
self.client.validate(req.document)

# Parse variable values for custom scalars if requested
if serialize_variables or (
serialize_variables is None and self.client.serialize_variables
):
reqs = [
req.serialize_variable_values(self.client.schema)
if req.variable_values is not None
else req
for req in reqs
]

results = self.transport.execute_batch(reqs, **kwargs)

# Unserialize the result if requested
if self.client.schema:
if parse_result or (parse_result is None and self.client.parse_results):
for result in results:
result.data = parse_result_fn(
self.client.schema,
req.document,
result.data,
operation_name=req.operation_name,
)

return results

def execute_batch(
self,
reqs: List[GraphQLRequest],
serialize_variables: Optional[bool] = None,
parse_result: Optional[bool] = None,
get_execution_result: bool = False,
**kwargs,
) -> Union[List[Dict[str, Any]], List[ExecutionResult]]:
"""Execute the provided document AST synchronously using
the sync transport.

Raises a TransportQueryError if an error has been returned in
the ExecutionResult.

:param document: GraphQL query as AST Node object.
:param variable_values: Dictionary of input parameters.
:param operation_name: Name of the operation that shall be executed.
:param serialize_variables: whether the variable values should be
serialized. Used for custom scalars and/or enums.
By default use the serialize_variables argument of the client.
:param parse_result: Whether gql will unserialize the result.
By default use the parse_results argument of the client.
:param get_execution_result: return the full ExecutionResult instance instead of
only the "data" field. Necessary if you want to get the "extensions" field.

The extra arguments are passed to the transport execute method."""

# Validate and execute on the transport
results = self._execute_batch(
reqs,
serialize_variables=serialize_variables,
parse_result=parse_result,
**kwargs,
)

for result in results:
# Raise an error if an error is returned in the ExecutionResult object
if result.errors:
raise TransportQueryError(
str_first_element(result.errors),
errors=result.errors,
data=result.data,
extensions=result.extensions,
)

assert (
result.data is not None
), "Transport returned an ExecutionResult without data or errors"

if get_execution_result:
return results

return [result.data for result in results] # type: ignore

def fetch_schema(self) -> None:
"""Fetch the GraphQL schema explicitly using introspection.

Expand Down Expand Up @@ -966,7 +1129,6 @@ async def _subscribe(

try:
async for result in inner_generator:

if self.client.schema:
if parse_result or (
parse_result is None and self.client.parse_results
Expand Down Expand Up @@ -1070,7 +1232,6 @@ async def subscribe(
try:
# Validate and subscribe on the transport
async for result in inner_generator:

# Raise an error if an error is returned in the ExecutionResult object
if result.errors:
raise TransportQueryError(
Expand Down Expand Up @@ -1343,7 +1504,6 @@ async def _connection_loop(self):
"""

while True:

# Connect to the transport with the retry decorator
# By default it should keep retrying until it connect
await self._connect_with_retries()
Expand Down
3 changes: 3 additions & 0 deletions gql/transport/data_structures/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from .graphql_request import GraphQLRequest

__all__ = ["GraphQLRequest"]
37 changes: 37 additions & 0 deletions gql/transport/data_structures/graphql_request.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
from typing import Any, Dict, Optional

from attr import dataclass
from graphql import DocumentNode, GraphQLSchema

from gql.utilities import serialize_variable_values


@dataclass(frozen=True)
class GraphQLRequest:
"""GraphQL Request to be executed."""

document: DocumentNode
"""GraphQL query as AST Node object."""

variable_values: Optional[Dict[str, Any]] = None
"""Dictionary of input parameters (Default: None)."""

operation_name: Optional[str] = None
"""
Name of the operation that shall be executed.
Only required in multi-operation documents (Default: None).
"""

def serialize_variable_values(self, schema: GraphQLSchema) -> "GraphQLRequest":
assert self.variable_values

return GraphQLRequest(
document=self.document,
variable_values=serialize_variable_values(
schema=schema,
document=self.document,
variable_values=self.variable_values,
operation_name=self.operation_name,
),
operation_name=self.operation_name,
)
10 changes: 10 additions & 0 deletions gql/transport/httpx.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
import httpx
from graphql import DocumentNode, ExecutionResult, print_ast

from gql.transport.data_structures import GraphQLRequest

from ..utils import extract_files
from . import AsyncTransport, Transport
from .exceptions import (
Expand Down Expand Up @@ -229,6 +231,14 @@ def execute( # type: ignore

return self._prepare_result(response)

def execute_batch(
self,
reqs: List[GraphQLRequest],
*args,
**kwargs,
) -> List[ExecutionResult]:
return super().execute_batch(reqs, *args, **kwargs)

def close(self):
"""Closing the transport by closing the inner session"""
if self.client:
Expand Down
Loading