diff --git a/src/datachain/catalog/catalog.py b/src/datachain/catalog/catalog.py index 80b5824ce..a762dd948 100644 --- a/src/datachain/catalog/catalog.py +++ b/src/datachain/catalog/catalog.py @@ -89,10 +89,6 @@ PULL_DATASET_CHECK_STATUS_INTERVAL = 20 # interval to check export status in Studio -def raise_remote_error(error_message: str) -> NoReturn: - raise DataChainError(f"Error from server: {error_message}") - - def noop(_: str): pass @@ -211,14 +207,14 @@ def check_for_status(self) -> None: self.remote_ds_name, self.remote_ds_version ) if not export_status_response.ok: - raise_remote_error(export_status_response.message) + raise DataChainError(export_status_response.message) export_status = export_status_response.data["status"] # type: ignore [index] if export_status == "failed": - raise_remote_error("Dataset export failed in Studio") + raise DataChainError("Dataset export failed in Studio") if export_status == "removed": - raise_remote_error("Dataset export removed in Studio") + raise DataChainError("Dataset export removed in Studio") self.last_status_check = time.time() @@ -1113,7 +1109,7 @@ def get_remote_dataset(self, name: str) -> DatasetRecord: info_response = studio_client.dataset_info(name) if not info_response.ok: - raise_remote_error(info_response.message) + raise DataChainError(info_response.message) dataset_info = info_response.data assert isinstance(dataset_info, dict) @@ -1409,7 +1405,7 @@ def _instantiate(ds_uri: str) -> None: remote_ds_name, remote_ds_version.version ) if not export_response.ok: - raise_remote_error(export_response.message) + raise DataChainError(export_response.message) signed_urls = export_response.data diff --git a/src/datachain/studio.py b/src/datachain/studio.py index caa5ed1e8..aaa3b28d6 100644 --- a/src/datachain/studio.py +++ b/src/datachain/studio.py @@ -3,7 +3,6 @@ import sys from typing import TYPE_CHECKING, Optional -from datachain.catalog.catalog import raise_remote_error from datachain.config import Config, ConfigLevel from datachain.dataset import QUERY_DATASET_PREFIX from datachain.error import DataChainError @@ -150,7 +149,7 @@ def list_datasets(team: Optional[str] = None, name: Optional[str] = None): response = client.ls_datasets() if not response.ok: - raise_remote_error(response.message) + raise DataChainError(response.message) if not response.data: return @@ -171,7 +170,7 @@ def list_dataset_versions(team: Optional[str] = None, name: str = ""): response = client.dataset_info(name) if not response.ok: - raise_remote_error(response.message) + raise DataChainError(response.message) if not response.data: return @@ -191,7 +190,7 @@ def edit_studio_dataset( client = StudioClient(team=team_name) response = client.edit_dataset(name, new_name, description, labels) if not response.ok: - raise_remote_error(response.message) + raise DataChainError(response.message) print(f"Dataset '{name}' updated in Studio") @@ -205,7 +204,7 @@ def remove_studio_dataset( client = StudioClient(team=team_name) response = client.rm_dataset(name, version, force) if not response.ok: - raise_remote_error(response.message) + raise DataChainError(response.message) print(f"Dataset '{name}' removed from Studio") @@ -235,7 +234,7 @@ async def _run(): response = client.dataset_job_versions(job_id) if not response.ok: - raise_remote_error(response.message) + raise DataChainError(response.message) response_data = response.data if response_data: @@ -286,7 +285,7 @@ def create_job( requirements=requirements, ) if not response.ok: - raise_remote_error(response.message) + raise DataChainError(response.message) if not response.data: raise DataChainError("Failed to create job") @@ -307,7 +306,7 @@ def upload_files(client: StudioClient, files: list[str]) -> list[str]: file_content = f.read() response = client.upload_file(file_content, file_name) if not response.ok: - raise_remote_error(response.message) + raise DataChainError(response.message) if not response.data: raise DataChainError(f"Failed to upload file {file_name}") @@ -328,7 +327,7 @@ def cancel_job(job_id: str, team_name: Optional[str]): client = StudioClient(team=team_name) response = client.cancel_job(job_id) if not response.ok: - raise_remote_error(response.message) + raise DataChainError(response.message) print(f"Job {job_id} canceled") diff --git a/tests/func/test_pull.py b/tests/func/test_pull.py index df4e55dc2..ab398683d 100644 --- a/tests/func/test_pull.py +++ b/tests/func/test_pull.py @@ -310,7 +310,7 @@ def test_pull_dataset_not_found_in_remote( with pytest.raises(DataChainError) as exc_info: catalog.pull_dataset("ds://dogs@v1") - assert str(exc_info.value) == "Error from server: Dataset not found" + assert str(exc_info.value) == "Dataset not found" @pytest.mark.parametrize("cloud_type, version_aware", [("s3", False)], indirect=True) @@ -332,9 +332,7 @@ def test_pull_dataset_exporting_dataset_failed_in_remote( with pytest.raises(DataChainError) as exc_info: catalog.pull_dataset("ds://dogs@v1") - assert str(exc_info.value) == ( - f"Error from server: Dataset export {export_status} in Studio" - ) + assert str(exc_info.value) == f"Dataset export {export_status} in Studio" @pytest.mark.parametrize("cloud_type, version_aware", [("s3", False)], indirect=True)