Skip to content

Commit 4e32a11

Browse files
committed
add an option deferred_fetch to Cursor.execute()
1 parent cbc0e12 commit 4e32a11

File tree

4 files changed

+87
-15
lines changed

4 files changed

+87
-15
lines changed

Diff for: tests/unit/test_client.py

+53
Original file line numberDiff line numberDiff line change
@@ -1016,6 +1016,59 @@ def json(self):
10161016
assert isinstance(result, TrinoResult)
10171017

10181018

1019+
def test_trino_query_deferred_fetch(sample_get_response_data):
1020+
"""
1021+
Validates that the `TrinoQuery.execute` function deferred_fetch and non-block execution
1022+
"""
1023+
1024+
class MockResponse(mock.Mock):
1025+
# Fake response class
1026+
@property
1027+
def headers(self):
1028+
return {
1029+
'X-Trino-Fake-1': 'one',
1030+
'X-Trino-Fake-2': 'two',
1031+
}
1032+
1033+
def json(self):
1034+
return sample_get_response_data
1035+
1036+
rows = sample_get_response_data['data']
1037+
sample_get_response_data['data'] = []
1038+
sql = 'execute my_stament using 1, 2, 3'
1039+
request = TrinoRequest(
1040+
host="coordinator",
1041+
port=8080,
1042+
client_session=ClientSession(
1043+
user="test",
1044+
source="test",
1045+
catalog="test",
1046+
schema="test",
1047+
properties={},
1048+
),
1049+
http_scheme="http",
1050+
)
1051+
query = TrinoQuery(
1052+
request=request,
1053+
query=sql
1054+
)
1055+
1056+
with (
1057+
mock.patch.object(request, 'post', return_value=MockResponse()) as mock_post,
1058+
mock.patch.object(query, 'fetch', return_value=rows) as mock_fetch,
1059+
):
1060+
result = query.execute()
1061+
mock_fetch.assert_called_once()
1062+
assert result.rows == rows
1063+
1064+
with (
1065+
mock.patch.object(request, 'post', return_value=MockResponse()) as mock_post,
1066+
mock.patch.object(query, 'fetch', return_value=rows) as mock_fetch,
1067+
):
1068+
result = query.execute(deferred_fetch=True)
1069+
mock_fetch.assert_not_called()
1070+
1071+
10191072
def test_delay_exponential_without_jitter():
10201073
max_delay = 1200.0
10211074
get_delay = _DelayExponential(base=5, jitter=False, max_delay=max_delay)

Diff for: trino/client.py

+17-10
Original file line numberDiff line numberDiff line change
@@ -775,13 +775,18 @@ def result(self):
775775
def info_uri(self):
776776
return self._info_uri
777777

778-
def execute(self, additional_http_headers=None) -> TrinoResult:
779-
"""Initiate a Trino query by sending the SQL statement
780-
781-
This is the first HTTP request sent to the coordinator.
782-
It sets the query_id and returns a Result object used to
783-
track the rows returned by the query. To fetch all rows,
784-
call fetch() until finished is true.
778+
def execute(
779+
self,
780+
additional_http_headers: Optional[Dict[str, Any]] = None,
781+
deferred_fetch: bool = False,
782+
) -> TrinoResult:
783+
"""Initiate a Trino query by sending the SQL statement to the coordinator.
784+
To fetch all rows, call fetch() until finished is true.
785+
786+
Parameters:
787+
additional_http_headers: extra headers send to the Trino server.
788+
deferred_fetch: By default, the execution is blocked until at least one row is received
789+
or query is finished or cancelled. To continue without waiting the result, set deferred_fetch=True.
785790
"""
786791
if self.cancelled:
787792
raise exceptions.TrinoUserError("Query has been cancelled", self.query_id)
@@ -799,9 +804,11 @@ def execute(self, additional_http_headers=None) -> TrinoResult:
799804
rows = self._row_mapper.map(status.rows) if self._row_mapper else status.rows
800805
self._result = TrinoResult(self, rows)
801806

802-
# Execute should block until at least one row is received or query is finished or cancelled
803-
while not self.finished and not self.cancelled and len(self._result.rows) == 0:
804-
self._result.rows += self.fetch()
807+
if not deferred_fetch:
808+
# Execute should block until at least one row is received or query is finished or cancelled
809+
while not self.finished and not self.cancelled and len(self._result.rows) == 0:
810+
self._result.rows += self.fetch()
811+
805812
return self._result
806813

807814
def _update_state(self, status):

Diff for: trino/dbapi.py

+16-4
Original file line numberDiff line numberDiff line change
@@ -558,7 +558,10 @@ def _deallocate_prepared_statement(self, statement_name: str) -> None:
558558
def _generate_unique_statement_name(self):
559559
return 'st_' + uuid.uuid4().hex.replace('-', '')
560560

561-
def execute(self, operation, params=None):
561+
def execute(self, operation, params=None, **kwargs: Any):
562+
additional_http_headers = kwargs.get("additional_http_headers", None)
563+
deferred_fetch = kwargs.get("deferred_fetch", False)
564+
562565
if params:
563566
assert isinstance(params, (list, tuple)), (
564567
'params must be a list or tuple containing the query '
@@ -575,7 +578,10 @@ def execute(self, operation, params=None):
575578
self._query = self._execute_prepared_statement(
576579
statement_name, params
577580
)
578-
self._iterator = iter(self._query.execute())
581+
self._iterator = iter(self._query.execute(
582+
additional_http_headers=additional_http_headers,
583+
deferred_fetch=deferred_fetch,
584+
))
579585
finally:
580586
# Send deallocate statement
581587
# At this point the query can be deallocated since it has already
@@ -584,12 +590,18 @@ def execute(self, operation, params=None):
584590
self._deallocate_prepared_statement(statement_name)
585591
else:
586592
self._query = self._execute_immediate_statement(operation, params)
587-
self._iterator = iter(self._query.execute())
593+
self._iterator = iter(self._query.execute(
594+
additional_http_headers=additional_http_headers,
595+
deferred_fetch=deferred_fetch,
596+
))
588597

589598
else:
590599
self._query = trino.client.TrinoQuery(self._request, query=operation,
591600
legacy_primitive_types=self._legacy_primitive_types)
592-
self._iterator = iter(self._query.execute())
601+
self._iterator = iter(self._query.execute(
602+
additional_http_headers=additional_http_headers,
603+
deferred_fetch=deferred_fetch,
604+
))
593605
return self
594606

595607
def executemany(self, operation, seq_of_params):

Diff for: trino/sqlalchemy/dialect.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -377,7 +377,7 @@ def _get_default_schema_name(self, connection: Connection) -> Optional[str]:
377377
def do_execute(
378378
self, cursor: Cursor, statement: str, parameters: Tuple[Any, ...], context: DefaultExecutionContext = None
379379
):
380-
cursor.execute(statement, parameters)
380+
cursor.execute(statement, parameters, **context.execution_options)
381381

382382
def do_rollback(self, dbapi_connection: trino_dbapi.Connection):
383383
if dbapi_connection.transaction is not None:

0 commit comments

Comments
 (0)