Skip to content

Commit ea55fc0

Browse files
mdesmethashhar
authored andcommitted
Enhance Cursor.description
1 parent fd78e41 commit ea55fc0

File tree

3 files changed

+80
-17
lines changed

3 files changed

+80
-17
lines changed

tests/integration/test_dbapi_integration.py

+48-14
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,7 @@ def test_none_query_param(trino_connection):
119119
rows = cur.fetchall()
120120

121121
assert rows[0][0] is None
122+
assert_cursor_description(cur, trino_type="unknown")
122123

123124

124125
def test_string_query_param(trino_connection):
@@ -128,6 +129,7 @@ def test_string_query_param(trino_connection):
128129
rows = cur.fetchall()
129130

130131
assert rows[0][0] == "six'"
132+
assert_cursor_description(cur, trino_type="varchar(4)", size=4)
131133

132134

133135
def test_execute_many(trino_connection):
@@ -241,10 +243,11 @@ def test_legacy_primitive_types_with_connection_and_cursor(
241243
def test_decimal_query_param(trino_connection):
242244
cur = trino_connection.cursor()
243245

244-
cur.execute("SELECT ?", params=(Decimal('0.142857'),))
246+
cur.execute("SELECT ?", params=(Decimal('1112.142857'),))
245247
rows = cur.fetchall()
246248

247-
assert rows[0][0] == Decimal('0.142857')
249+
assert rows[0][0] == Decimal('1112.142857')
250+
assert_cursor_description(cur, trino_type="decimal(10, 6)", precision=10, scale=6)
248251

249252

250253
def test_null_decimal(trino_connection):
@@ -254,6 +257,7 @@ def test_null_decimal(trino_connection):
254257
rows = cur.fetchall()
255258

256259
assert rows[0][0] is None
260+
assert_cursor_description(cur, trino_type="decimal(38, 0)", precision=38, scale=0)
257261

258262

259263
def test_biggest_decimal(trino_connection):
@@ -264,6 +268,7 @@ def test_biggest_decimal(trino_connection):
264268
rows = cur.fetchall()
265269

266270
assert rows[0][0] == params
271+
assert_cursor_description(cur, trino_type="decimal(38, 0)", precision=38, scale=0)
267272

268273

269274
def test_smallest_decimal(trino_connection):
@@ -274,6 +279,7 @@ def test_smallest_decimal(trino_connection):
274279
rows = cur.fetchall()
275280

276281
assert rows[0][0] == params
282+
assert_cursor_description(cur, trino_type="decimal(38, 0)", precision=38, scale=0)
277283

278284

279285
def test_highest_precision_decimal(trino_connection):
@@ -284,6 +290,7 @@ def test_highest_precision_decimal(trino_connection):
284290
rows = cur.fetchall()
285291

286292
assert rows[0][0] == params
293+
assert_cursor_description(cur, trino_type="decimal(38, 38)", precision=38, scale=38)
287294

288295

289296
def test_datetime_query_param(trino_connection):
@@ -295,7 +302,7 @@ def test_datetime_query_param(trino_connection):
295302
rows = cur.fetchall()
296303

297304
assert rows[0][0] == params
298-
assert cur.description[0][1] == "timestamp(6)"
305+
assert_cursor_description(cur, trino_type="timestamp(6)", precision=6)
299306

300307

301308
def test_datetime_with_utc_time_zone_query_param(trino_connection):
@@ -307,7 +314,7 @@ def test_datetime_with_utc_time_zone_query_param(trino_connection):
307314
rows = cur.fetchall()
308315

309316
assert rows[0][0] == params
310-
assert cur.description[0][1] == "timestamp(6) with time zone"
317+
assert_cursor_description(cur, trino_type="timestamp(6) with time zone", precision=6)
311318

312319

313320
def test_datetime_with_numeric_offset_time_zone_query_param(trino_connection):
@@ -321,7 +328,7 @@ def test_datetime_with_numeric_offset_time_zone_query_param(trino_connection):
321328
rows = cur.fetchall()
322329

323330
assert rows[0][0] == params
324-
assert cur.description[0][1] == "timestamp(6) with time zone"
331+
assert_cursor_description(cur, trino_type="timestamp(6) with time zone", precision=6)
325332

326333

327334
def test_datetime_with_named_time_zone_query_param(trino_connection):
@@ -333,7 +340,7 @@ def test_datetime_with_named_time_zone_query_param(trino_connection):
333340
rows = cur.fetchall()
334341

335342
assert rows[0][0] == params
336-
assert cur.description[0][1] == "timestamp(6) with time zone"
343+
assert_cursor_description(cur, trino_type="timestamp(6) with time zone", precision=6)
337344

338345

339346
def test_datetime_with_trailing_zeros(trino_connection):
@@ -343,6 +350,7 @@ def test_datetime_with_trailing_zeros(trino_connection):
343350
rows = cur.fetchall()
344351

345352
assert rows[0][0] == datetime.strptime("2001-08-22 03:04:05.321000", "%Y-%m-%d %H:%M:%S.%f")
353+
assert_cursor_description(cur, trino_type="timestamp(6)", precision=6)
346354

347355

348356
def test_null_datetime_with_time_zone(trino_connection):
@@ -352,6 +360,7 @@ def test_null_datetime_with_time_zone(trino_connection):
352360
rows = cur.fetchall()
353361

354362
assert rows[0][0] is None
363+
assert_cursor_description(cur, trino_type="timestamp(3) with time zone", precision=3)
355364

356365

357366
def test_datetime_with_time_zone_numeric_offset(trino_connection):
@@ -361,6 +370,7 @@ def test_datetime_with_time_zone_numeric_offset(trino_connection):
361370
rows = cur.fetchall()
362371

363372
assert rows[0][0] == datetime.strptime("2001-08-22 03:04:05.321 -08:00", "%Y-%m-%d %H:%M:%S.%f %z")
373+
assert_cursor_description(cur, trino_type="timestamp(3) with time zone", precision=3)
364374

365375

366376
def test_datetimes_with_time_zone_in_dst_gap_query_param(trino_connection):
@@ -404,6 +414,7 @@ def test_date_query_param(trino_connection):
404414
rows = cur.fetchall()
405415

406416
assert rows[0][0] == params
417+
assert_cursor_description(cur, trino_type="date")
407418

408419

409420
def test_null_date(trino_connection):
@@ -413,6 +424,7 @@ def test_null_date(trino_connection):
413424
rows = cur.fetchall()
414425

415426
assert rows[0][0] is None
427+
assert_cursor_description(cur, trino_type="date")
416428

417429

418430
def test_unsupported_python_dates(trino_connection):
@@ -462,6 +474,16 @@ def test_supported_special_dates_query_param(trino_connection):
462474
assert rows[0][0] == params
463475

464476

477+
def test_char(trino_connection):
478+
cur = trino_connection.cursor()
479+
480+
cur.execute("SELECT CHAR 'trino'")
481+
rows = cur.fetchall()
482+
483+
assert rows[0][0] == 'trino'
484+
assert_cursor_description(cur, trino_type="char(5)", size=5)
485+
486+
465487
def test_time_query_param(trino_connection):
466488
cur = trino_connection.cursor()
467489

@@ -471,7 +493,7 @@ def test_time_query_param(trino_connection):
471493
rows = cur.fetchall()
472494

473495
assert rows[0][0] == params
474-
assert cur.description[0][1] == "time(6)"
496+
assert_cursor_description(cur, trino_type="time(6)", precision=6)
475497

476498

477499
def test_time_with_named_time_zone_query_param(trino_connection):
@@ -501,7 +523,7 @@ def test_time(trino_connection):
501523
rows = cur.fetchall()
502524

503525
assert rows[0][0] == time(1, 2, 3, 456000)
504-
assert cur.description[0][1] == "time(3)"
526+
assert_cursor_description(cur, trino_type="time(3)", precision=3)
505527

506528

507529
def test_null_time(trino_connection):
@@ -511,6 +533,7 @@ def test_null_time(trino_connection):
511533
rows = cur.fetchall()
512534

513535
assert rows[0][0] is None
536+
assert_cursor_description(cur, trino_type="time(3)", precision=3)
514537

515538

516539
def test_time_with_time_zone_negative_offset(trino_connection):
@@ -522,7 +545,7 @@ def test_time_with_time_zone_negative_offset(trino_connection):
522545
tz = timezone(-timedelta(hours=8, minutes=0))
523546

524547
assert rows[0][0] == time(1, 2, 3, 456000, tzinfo=tz)
525-
assert cur.description[0][1] == "time(3) with time zone"
548+
assert_cursor_description(cur, trino_type="time(3) with time zone", precision=3)
526549

527550

528551
def test_time_with_time_zone_positive_offset(trino_connection):
@@ -534,7 +557,7 @@ def test_time_with_time_zone_positive_offset(trino_connection):
534557
tz = timezone(timedelta(hours=8, minutes=0))
535558

536559
assert rows[0][0] == time(1, 2, 3, 456000, tzinfo=tz)
537-
assert cur.description[0][1] == "time(3) with time zone"
560+
assert_cursor_description(cur, trino_type="time(3) with time zone", precision=3)
538561

539562

540563
def test_null_date_with_time_zone(trino_connection):
@@ -544,6 +567,7 @@ def test_null_date_with_time_zone(trino_connection):
544567
rows = cur.fetchall()
545568

546569
assert rows[0][0] is None
570+
assert_cursor_description(cur, trino_type="time(3) with time zone", precision=3)
547571

548572

549573
@pytest.mark.parametrize(
@@ -717,7 +741,7 @@ def test_float_query_param(trino_connection):
717741
cur.execute("SELECT ?", params=(1.1,))
718742
rows = cur.fetchall()
719743

720-
assert cur.description[0][1] == "double"
744+
assert_cursor_description(cur, trino_type="double")
721745
assert rows[0][0] == 1.1
722746

723747

@@ -726,7 +750,7 @@ def test_float_nan_query_param(trino_connection):
726750
cur.execute("SELECT ?", params=(float("nan"),))
727751
rows = cur.fetchall()
728752

729-
assert cur.description[0][1] == "double"
753+
assert_cursor_description(cur, trino_type="double")
730754
assert isinstance(rows[0][0], float)
731755
assert math.isnan(rows[0][0])
732756

@@ -736,6 +760,7 @@ def test_float_inf_query_param(trino_connection):
736760
cur.execute("SELECT ?", params=(float("inf"),))
737761
rows = cur.fetchall()
738762

763+
assert_cursor_description(cur, trino_type="double")
739764
assert rows[0][0] == float("inf")
740765

741766
cur.execute("SELECT ?", params=(float("-inf"),))
@@ -750,13 +775,13 @@ def test_int_query_param(trino_connection):
750775
rows = cur.fetchall()
751776

752777
assert rows[0][0] == 3
753-
assert cur.description[0][1] == "integer"
778+
assert_cursor_description(cur, trino_type="integer")
754779

755780
cur.execute("SELECT ?", params=(9223372036854775807,))
756781
rows = cur.fetchall()
757782

758783
assert rows[0][0] == 9223372036854775807
759-
assert cur.description[0][1] == "bigint"
784+
assert_cursor_description(cur, trino_type="bigint")
760785

761786

762787
@pytest.mark.parametrize('params', [
@@ -1239,3 +1264,12 @@ def test_describe_table_query(run_trino):
12391264
aliased=False,
12401265
)
12411266
]
1267+
1268+
1269+
def assert_cursor_description(cur, trino_type, size=None, precision=None, scale=None):
1270+
assert cur.description[0][1] == trino_type
1271+
assert cur.description[0][2] is None
1272+
assert cur.description[0][3] is size
1273+
assert cur.description[0][4] is precision
1274+
assert cur.description[0][5] is scale
1275+
assert cur.description[0][6] is None

trino/constants.py

+4
Original file line numberDiff line numberDiff line change
@@ -53,3 +53,7 @@
5353
HEADER_SET_CATALOG = "X-Trino-Set-Catalog"
5454

5555
HEADER_CLIENT_CAPABILITIES = "X-Trino-Client-Capabilities"
56+
57+
LENGTH_TYPES = ["char", "varchar"]
58+
PRECISION_TYPES = ["time", "time with time zone", "timestamp", "timestamp with time zone", "decimal"]
59+
SCALE_TYPES = ["decimal"]

trino/dbapi.py

+28-3
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
import trino.exceptions
2929
import trino.logging
3030
from trino import constants
31+
from trino.constants import LENGTH_TYPES, PRECISION_TYPES, SCALE_TYPES
3132
from trino.exceptions import (
3233
DatabaseError,
3334
DataError,
@@ -237,6 +238,31 @@ def from_row(cls, row: List[Any]):
237238
return cls(*row)
238239

239240

241+
class ColumnDescription(NamedTuple):
242+
name: str
243+
type_code: int
244+
display_size: int
245+
internal_size: int
246+
precision: int
247+
scale: int
248+
null_ok: bool
249+
250+
@classmethod
251+
def from_column(cls, column: Dict[str, Any]):
252+
type_signature = column["typeSignature"]
253+
raw_type = type_signature["rawType"]
254+
arguments = type_signature["arguments"]
255+
return cls(
256+
column["name"], # name
257+
column["type"], # type_code
258+
None, # display_size
259+
arguments[0]["value"] if raw_type in LENGTH_TYPES else None, # internal_size
260+
arguments[0]["value"] if raw_type in PRECISION_TYPES else None, # precision
261+
arguments[1]["value"] if raw_type in SCALE_TYPES else None, # scale
262+
None # null_ok
263+
)
264+
265+
240266
class Cursor(object):
241267
"""Database cursor.
242268
@@ -278,14 +304,13 @@ def update_type(self):
278304
return None
279305

280306
@property
281-
def description(self):
307+
def description(self) -> List[ColumnDescription]:
282308
if self._query.columns is None:
283309
return None
284310

285311
# [ (name, type_code, display_size, internal_size, precision, scale, null_ok) ]
286312
return [
287-
(col["name"], col["type"], None, None, None, None, None)
288-
for col in self._query.columns
313+
ColumnDescription.from_column(col) for col in self._query.columns
289314
]
290315

291316
@property

0 commit comments

Comments
 (0)