-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathintegration.py
115 lines (89 loc) · 3.21 KB
/
integration.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
from time import sleep
import unittest
import os
import uuid
import tempfile
from pathlib import Path
from railib import api, config
# TODO: create_engine_wait should be added to the API, with exponential backoff.
def create_engine_wait(
    ctx: api.Context,
    engine: str,
    *,
    max_polls: int = 13,
    poll_interval: float = 30,
):
    """Create ``engine`` and block until it reports the ``PROVISIONED`` state.

    The state returned by ``api.create_engine`` is checked first. While it is
    not ``PROVISIONED``, the function sleeps ``poll_interval`` seconds and
    re-reads the state with ``api.get_engine``, at most ``max_polls`` times.
    The defaults (13 polls, 30 s apart) match the previously hard-coded limits.

    Args:
        ctx: Client context passed to the API calls.
        engine: Name of the engine to create.
        max_polls: Maximum number of ``api.get_engine`` polls after creation.
        poll_interval: Seconds to sleep before each poll.

    Raises:
        TimeoutError: If the engine is still not ``PROVISIONED`` after
            ``max_polls`` polls. Previously the function returned silently,
            which let tests run against an engine that was not ready.
    """
    state = api.create_engine(ctx, engine)["compute"]["state"]
    for _ in range(max_polls):
        if state == "PROVISIONED":
            return
        sleep(poll_interval)
        # The state is indexed differently here than in the create_engine
        # response above: top-level "state" rather than ["compute"]["state"].
        state = api.get_engine(ctx, engine)["state"]
    if state != "PROVISIONED":
        raise TimeoutError(
            f"engine {engine!r} not PROVISIONED after {max_polls} polls "
            f"(last state: {state!r})"
        )
# Take credentials from environment variables when CLIENT_ID is set;
# otherwise fall back to config.read() with its default arguments
# (presumably the user's local config file -- confirm against railib).
client_id = os.getenv("CLIENT_ID")
client_secret = os.getenv("CLIENT_SECRET")
client_credentials_url = os.getenv("CLIENT_CREDENTIALS_URL")
if client_id is None:
    print("CLIENT_ID not set; reading config via config.read() defaults")
    cfg = config.read()
else:
    # Materialize the env-var credentials as a temporary config file so that
    # config.read() can parse them. The context manager guarantees the file is
    # closed (and therefore deleted) even if config.read() raises.
    # NOTE(review): CLIENT_SECRET / CLIENT_CREDENTIALS_URL are not checked;
    # if unset they are written as the literal text "None".
    with tempfile.NamedTemporaryFile(mode="w") as file:
        file.write(f"""
[default]
client_id={client_id}
client_secret={client_secret}
client_credentials_url={client_credentials_url}
region=us-east
port=443
host=azure.relationalai.com
""")
        # Push buffered text to disk before config.read() reopens the file
        # by name.
        file.flush()
        cfg = config.read(fname=file.name)
ctx = api.Context(**cfg)
class TestTransactionAsync(unittest.TestCase):
    """Integration tests that run against a live engine and database.

    Each test gets a freshly created engine and database, named with a random
    UUID suffix. Both are deleted again after the test.
    """

    def setUp(self):
        self.suffix = uuid.uuid4()
        self.engine = f"python-sdk-{self.suffix}"
        self.dbname = f"python-sdk-{self.suffix}"

        create_engine_wait(ctx, self.engine)
        # Register each cleanup as soon as its resource exists. Unlike
        # tearDown, cleanups also run when a later setUp step fails, so a
        # failing create_database no longer leaks the engine. Cleanups run
        # LIFO (database first, then engine), and a failure in one does not
        # prevent the other from running.
        self.addCleanup(api.delete_engine, ctx, self.engine)

        api.create_database(ctx, self.dbname)
        self.addCleanup(api.delete_database, ctx, self.dbname)

    def test_v2_exec(self):
        cmd = "x, x^2, x^3, x^4 from x in {1; 2; 3; 4; 5}"
        rsp = api.exec(ctx, self.dbname, self.engine, cmd)

        # transaction
        self.assertEqual("COMPLETED", rsp.transaction["state"])

        # metadata: compare against the protobuf fixture stored next to this file
        data = (Path(__file__).parent / "metadata.pb").read_bytes()
        self.assertEqual(rsp.metadata, api._parse_metadata_proto(data))

        # problems
        self.assertEqual(0, len(rsp.problems))

        # results
        self.assertEqual(
            {
                'v1': [1, 2, 3, 4, 5],
                'v2': [1, 4, 9, 16, 25],
                'v3': [1, 8, 27, 64, 125],
                'v4': [1, 16, 81, 256, 625],
            },
            rsp.results[0]["table"].to_pydict(),
        )

    def test_models(self):
        # A fresh database already lists at least one model.
        self.assertGreater(len(api.list_models(ctx, self.dbname, self.engine)), 0)

        # Install a model and read it back.
        new_models = {'test_model': 'def foo=:bar'}
        resp = api.install_models(ctx, self.dbname, self.engine, new_models)
        self.assertEqual(resp.transaction['state'], 'COMPLETED')
        value = api.get_model(ctx, self.dbname, self.engine, 'test_model')
        self.assertEqual(new_models['test_model'], value)
        self.assertIn('test_model', api.list_models(ctx, self.dbname, self.engine))

        # Delete it and confirm it is gone.
        resp = api.delete_models(ctx, self.dbname, self.engine, ['test_model'])
        self.assertEqual(resp.transaction['state'], 'COMPLETED')
        self.assertNotIn('test_model', api.list_models(ctx, self.dbname, self.engine))
if __name__ == '__main__':
    # Running this file directly executes the integration suite via unittest.
    unittest.main()