-
Notifications
You must be signed in to change notification settings - Fork 149
/
Copy path test_generic_database.py
191 lines (154 loc) · 5.4 KB
/
test_generic_database.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
#
# Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
# or more contributor license agreements. Licensed under the Elastic License 2.0;
# you may not use this file except in compliance with the Elastic License 2.0.
#
"""Tests the Generic Database source class methods"""
from functools import partial
import pytest
from connectors.sources.generic_database import (
configured_tables,
fetch,
is_wildcard,
map_column_names,
)
from connectors.sources.mssql import MSSQLQueries
# Fixture identifiers the dummy cursor uses both to build the expected queries
# (via the query object) and as canned values in its responses.
SCHEMA, TABLE, USER = "dbo", "emp_table", "admin"
CUSTOMER_TABLE = "customer"
class ConnectionSync:
    """Stand-in for a database connection that hands out dummy cursors."""

    def __init__(self, query_object):
        """Keep the query builder that every created cursor will receive."""
        self.query_object = query_object

    def __enter__(self):
        """Open the fake connection; the connection object itself is yielded."""
        return self

    def __exit__(self, exception_type, exception_value, exception_traceback):
        """Leave the fake connection; there is nothing to release."""

    def execute(self, statement):
        """Wrap ``statement`` in a fresh ``CursorSync`` and return that cursor."""
        cursor = CursorSync(query_object=self.query_object, statement=statement)
        return cursor

    def close(self):
        """Close the fake connection (no-op)."""
class CursorSync:
    """Dummy cursor that serves canned rows for the queries used in these tests."""

    def __init__(self, query_object, *args, **kwargs):
        """Set up the dummy cursor.

        Args:
            query_object: Query builder whose generated SQL strings are compared
                against the executed statement to choose a canned response.
            **kwargs: Must contain ``statement``, the executed query.
        """
        self.first_call = True
        self.query = kwargs["statement"]
        self.query_object = query_object
        # Rows produced on the first ``fetchmany`` call but not yet handed out.
        self._pending_rows = []

    def __enter__(self):
        """Enter the cursor's context and return the cursor itself."""
        return self

    def __exit__(self, exception_type, exception_value, exception_traceback):
        """Leave the cursor's context.

        Nothing needs releasing. Returning None lets any exception propagate.
        """
        return None

    def keys(self):
        """Return Columns of table

        Returns:
            list: List of columns
        """
        return ["ids", "names"]

    def _result_rows(self):
        """Return the complete canned result set for ``self.query``."""
        query = self.query
        queries = self.query_object
        if query == queries.all_schemas():
            return [(SCHEMA,)]
        elif query == queries.all_tables(schema=SCHEMA, user=USER):
            return [(TABLE,)]
        elif query == queries.table_data_count(schema=SCHEMA, table=TABLE):
            return [(10,)]
        elif query == queries.table_primary_key(
            schema=SCHEMA, table=TABLE, user=USER
        ):
            return [("ids",)]
        elif query == queries.table_primary_key(
            schema=SCHEMA, table=CUSTOMER_TABLE, user=USER
        ):
            return [("ids",)]
        elif query == queries.table_last_update_time(schema=SCHEMA, table=TABLE):
            return [("2023-02-21T08:37:15+00:00",)]
        elif query == queries.table_last_update_time(
            schema=SCHEMA, table=CUSTOMER_TABLE
        ):
            return [("2023-02-21T08:37:15+00:00",)]
        elif query.lower() == "select * from customer":
            return [(1, "customer_1"), (2, "customer_2")]
        # Any other statement (including ``None``, stringified to "None")
        # falls back to the generic two-row table.
        return [(1, "abcd"), (2, "xyz")]

    def fetchmany(self, size):
        """Return the next batch of at most ``size`` rows.

        The canned result set is computed on the first call and then handed
        out in batches, mirroring DB-API ``fetchmany``. An empty list signals
        that every row has been consumed. ``size=None`` returns all remaining
        rows.

        Args:
            size (int): Maximum number of rows to return.

        Returns:
            list: List of rows
        """
        if self.first_call:
            self.first_call = False
            # The statement may be a SQLAlchemy-like object or None; compare
            # its string form against the generated queries.
            self.query = str(self.query)
            self._pending_rows = self._result_rows()
        batch = self._pending_rows[:size]
        self._pending_rows = self._pending_rows[len(batch):]
        return batch
@pytest.mark.parametrize(
    "tables, expected_tables",
    [
        ("", []),
        ("table", ["table"]),
        ("table_1, table_2", ["table_1", "table_2"]),
        (["table_1", "table_2"], ["table_1", "table_2"]),
        (["table_1", "table_2", ""], ["table_1", "table_2"]),
    ],
)
def test_configured_tables(tables, expected_tables):
    """Check that a comma-separated string or a list becomes a clean list of names.

    Surrounding whitespace is stripped and empty entries are dropped.
    """
    assert configured_tables(tables) == expected_tables
@pytest.mark.parametrize("tables", ["*", ["*"]])
def test_is_wildcard(tables):
    """A bare ``*``, as a string or a one-element list, counts as a wildcard."""
    assert is_wildcard(tables)


@pytest.mark.parametrize("tables", ["table", ["table_1", "table_2"]])
def test_is_not_wildcard(tables):
    """Explicit table names must not be mistaken for a wildcard.

    Without this negative case, an ``is_wildcard`` that always returned a
    truthy value would still pass ``test_is_wildcard``.
    """
    assert not is_wildcard(tables)
COLUMN_NAMES = ["Column_1", "Column_2"]


@pytest.mark.parametrize(
    "schema, tables, prefix",
    [
        (None, None, ""),
        ("Schema", None, "schema_"),
        ("Schema", [], "schema_"),
        (None, ["Table"], "table_"),
        (" ", ["Table"], "table_"),
        ("Schema", ["Table"], "schema_table_"),
        ("Schema", ["Table1", "Table2"], "schema_table1_table2_"),
    ],
)
def test_map_column_names(schema, tables, prefix):
    """Check that each column is lower-cased and gets the schema/table prefix.

    A blank schema contributes nothing to the prefix.
    """
    expected = [f"{prefix}{name}".lower() for name in COLUMN_NAMES]
    assert list(map_column_names(COLUMN_NAMES, schema, tables)) == expected
async def get_cursor(query_object, query):
    """Coroutine factory returning a dummy cursor for ``query``.

    Used as ``cursor_func`` for ``fetch``.
    """
    cursor = CursorSync(query_object=query_object, statement=query)
    return cursor
@pytest.mark.asyncio
async def test_fetch():
    """Check that fetch yields the column names first, then every cursor row."""
    cursor_factory = partial(get_cursor, MSSQLQueries(), None)
    rows = [
        row
        async for row in fetch(
            cursor_func=cursor_factory,
            fetch_columns=True,
            fetch_size=10,
            retry_count=3,
        )
    ]

    assert len(rows) == 3
    assert [(row[0], row[1]) for row in rows] == [
        ("ids", "names"),
        (1, "abcd"),
        (2, "xyz"),
    ]