diff --git a/railib/api.py b/railib/api.py index aa8990b..80dc440 100644 --- a/railib/api.py +++ b/railib/api.py @@ -101,6 +101,19 @@ class Permission(str, Enum): LIST_ACCESS_KEYS = "list:accesskey" +@unique +class EngineState(str, Enum): + REQUESTED = "REQUESTED" + UPDATING = "UPDATING" + PROVISIONING = "PROVISIONING" + PROVISIONED = "PROVISIONED" + PROVISION_FAILED = "PROVISION_FAILED" + DELETING = "DELETING" + SUSPENDED = "SUSPENDED" + DEPROVISIONING = "DEPROVISIONING" + UNKNOWN = "UNKNOWN" + + __all__ = [ "Context", "Mode", @@ -112,6 +125,7 @@ class Permission(str, Enum): "create_oauth_client", "delete_database", "delete_engine", + "delete_engine_wait", "delete_model", "disable_user", "enable_user", @@ -134,8 +148,13 @@ class Permission(str, Enum): "list_oauth_clients", "load_csv", "update_user", + "ResourceNotFoundError", ] +class ResourceNotFoundError(Exception): + """An error response, typically triggered by a 412 response (for update) or 404 (for get/post)""" + pass + # Context contains the state required to make rAI API calls. class Context(rest.Context): @@ -221,11 +240,15 @@ def _get_resource(ctx: Context, path: str, key=None, **kwargs) -> Dict: url = _mkurl(ctx, path) rsp = rest.get(ctx, url, **kwargs) rsp = json.loads(rsp.read()) + if key: rsp = rsp[key] - if rsp and isinstance(rsp, list): - assert len(rsp) == 1 + + if isinstance(rsp, list): + if len(rsp) == 0: + raise ResourceNotFoundError(f"Resource not found at {url}") return rsp[0] + return rsp @@ -356,8 +379,8 @@ def poll_with_specified_overhead( time.sleep(duration) -def is_engine_term_state(state: str) -> bool: - return state == "PROVISIONED" or ("FAILED" in state) +def is_engine_provisioning_term_state(state: str) -> bool: + return state in [EngineState.PROVISIONED, EngineState.PROVISION_FAILED] def create_engine(ctx: Context, engine: str, size: str = "XS", **kwargs): @@ -370,7 +393,7 @@ def create_engine(ctx: Context, engine: str, size: str = "XS", **kwargs): def create_engine_wait(ctx: Context, engine: str, size: str = "XS", **kwargs): create_engine(ctx, engine, size, **kwargs) poll_with_specified_overhead( - lambda: is_engine_term_state(get_engine(ctx, engine)["state"]), + lambda: is_engine_provisioning_term_state(get_engine(ctx, engine)["state"]), overhead_rate=0.2, timeout=30 * 60, ) @@ -416,6 +439,18 @@ def delete_engine(ctx: Context, engine: str, **kwargs) -> Dict: return json.loads(rsp.read()) +def delete_engine_wait(ctx: Context, engine: str, **kwargs): + rsp = delete_engine(ctx, engine, **kwargs) + rsp = rsp["status"] + + while rsp["state"] in [EngineState.DEPROVISIONING, EngineState.DELETING]: + try: + rsp = get_engine(ctx, engine) + except ResourceNotFoundError: + break + time.sleep(3) + + def delete_user(ctx: Context, id: str, **kwargs) -> Dict: url = _mkurl(ctx, f"{PATH_USER}/{id}") rsp = rest.delete(ctx, url, None, **kwargs) diff --git a/test/test_unit.py b/test/test_unit.py index cf885d2..ee8c650 100644 --- a/test/test_unit.py +++ b/test/test_unit.py @@ -1,8 +1,12 @@ +import json import unittest +from unittest.mock import patch, MagicMock from railib import api +ctx = MagicMock() + class TestPolling(unittest.TestCase): def test_timeout_exception(self): try: @@ -23,5 +27,85 @@ def test_validation(self): api.poll_with_specified_overhead(lambda: True, overhead_rate=0.1, timeout=1, max_tries=1) +class TestEngineAPI(unittest.TestCase): + @patch('railib.rest.get') + def test_get_engine(self, mock_get): + response_json = { + "computes": [ + { + "name": "test-engine" + } + ] + } + mock_response = MagicMock() + mock_response.read.return_value = json.dumps(response_json).encode() + mock_get.return_value = mock_response + + engine = api.get_engine(ctx, "test-engine") + + self.assertEqual(engine, response_json['computes'][0]) + mock_get.assert_called_once() + + @patch('railib.rest.delete') + def test_delete_engine(self, mock_delete): + response_json = { + "status": { + "name": "test-engine", + "state": api.EngineState.DELETING.value, + "message": "engine \"test-engine\" deleted successfully" + } + } + + mock_response = MagicMock() + mock_response.read.return_value = json.dumps(response_json).encode() + mock_delete.return_value = mock_response + + res = api.delete_engine(ctx, "test-engine") + + self.assertEqual(res, response_json) + mock_delete.assert_called_once() + + @patch('railib.rest.delete') + @patch('railib.rest.get') + def test_delete_engine_wait(self, mock_get, mock_delete): + # mock response for engine delete + response_delete_json = { + "status": { + "name": "test-engine", + "state": "DELETING", + "message": "engine \"test-engine\" deleted successfully" + } + } + mock_response_delete = MagicMock() + mock_response_delete.read.return_value = json.dumps(response_delete_json).encode() + mock_delete.return_value = mock_response_delete + + # mock response for engine get and return an engine in DEPROVISIONING state + response_get_deprovisioning_engine_json = { + "computes": [ + { + "name": "test-engine", + "state": api.EngineState.DEPROVISIONING.value + } + ] + } + mock_response_get_deprovisioning_engine = MagicMock() + mock_response_get_deprovisioning_engine.read.return_value = json.dumps(response_get_deprovisioning_engine_json).encode() + + # mock response for engine get and return empty list as if the engine has been completely deleted + response_get_no_engine_json = { + "computes": [] + } + mock_response_get_no_engine = MagicMock() + mock_response_get_no_engine.read.return_value = json.dumps(response_get_no_engine_json).encode() + + mock_get.side_effect = [mock_response_get_deprovisioning_engine, mock_response_get_no_engine] + + res = api.delete_engine_wait(ctx, "test-engine") + self.assertEqual(res, None) + self.assertEqual(mock_delete.call_count, 1) + self.assertEqual(mock_get.call_count, 2) + + if __name__ == '__main__': unittest.main()