Commit 1484fb1

Merge branch 'main' into amrit/from_dataset
2 parents: c4cdeed + 1e8bec1

File tree

5 files changed: +33 -30 lines changed

    src/datachain/catalog/catalog.py
    src/datachain/remote/studio.py
    src/datachain/studio.py
    src/datachain/utils.py
    tests/func/test_pull.py

src/datachain/catalog/catalog.py (+5 -9)

@@ -89,10 +89,6 @@
 PULL_DATASET_CHECK_STATUS_INTERVAL = 20  # interval to check export status in Studio
 
 
-def raise_remote_error(error_message: str) -> NoReturn:
-    raise DataChainError(f"Error from server: {error_message}")
-
-
 def noop(_: str):
     pass
 
@@ -211,14 +207,14 @@ def check_for_status(self) -> None:
             self.remote_ds_name, self.remote_ds_version
         )
         if not export_status_response.ok:
-            raise_remote_error(export_status_response.message)
+            raise DataChainError(export_status_response.message)
 
         export_status = export_status_response.data["status"]  # type: ignore [index]
 
         if export_status == "failed":
-            raise_remote_error("Dataset export failed in Studio")
+            raise DataChainError("Dataset export failed in Studio")
         if export_status == "removed":
-            raise_remote_error("Dataset export removed in Studio")
+            raise DataChainError("Dataset export removed in Studio")
 
         self.last_status_check = time.time()
 
@@ -1113,7 +1109,7 @@ def get_remote_dataset(self, name: str) -> DatasetRecord:
 
         info_response = studio_client.dataset_info(name)
         if not info_response.ok:
-            raise_remote_error(info_response.message)
+            raise DataChainError(info_response.message)
 
         dataset_info = info_response.data
         assert isinstance(dataset_info, dict)
@@ -1409,7 +1405,7 @@ def _instantiate(ds_uri: str) -> None:
             remote_ds_name, remote_ds_version.version
         )
         if not export_response.ok:
-            raise_remote_error(export_response.message)
+            raise DataChainError(export_response.message)
 
         signed_urls = export_response.data
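
With raise_remote_error removed, each call site now raises DataChainError with the server's message as-is, so the old "Error from server: " prefix no longer appears in exception text. A minimal sketch of the shared guard pattern, using a hypothetical check_response helper that is not part of the codebase:

from datachain.error import DataChainError

def check_response(response) -> None:
    # Mirror of the pattern repeated at each call site after this commit:
    # surface the server-provided message directly, without any added prefix.
    if not response.ok:
        raise DataChainError(response.message)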

src/datachain/remote/studio.py (+6 -6)

@@ -14,6 +14,7 @@
 from urllib.parse import urlparse, urlunparse
 
 import websockets
+from requests.exceptions import HTTPError, Timeout
 
 from datachain.config import Config
 from datachain.error import DataChainError
@@ -111,8 +112,8 @@ def _get_team(self) -> str:
             raise DataChainError(
                 "Studio team is not set. "
                 "Use `datachain auth team <team_name>` "
-                "or environment variable `DVC_STUDIO_TEAM` to set it."
-                "You can also set it in the config file as team under studio."
+                "or environment variable `DVC_STUDIO_TEAM` to set it. "
+                "You can also set `studio.team` in the config file."
             )
 
         return team
@@ -165,15 +166,14 @@ def _send_request_msgpack(
         message = content.get("message", "")
         return Response(response_data, ok, message)
 
-    @retry_with_backoff(retries=5)
+    @retry_with_backoff(retries=3, errors=(HTTPError, Timeout))
     def _send_request(
         self, route: str, data: dict[str, Any], method: Optional[str] = "POST"
     ) -> Response[Any]:
         """
         Function that communicate Studio API.
         It will raise an exception, and try to retry, if 5xx status code is
-        returned, or if ConnectionError or Timeout exceptions are thrown from
-        requests lib
+        returned, or if Timeout exceptions is thrown from the requests lib
         """
         import requests
 
@@ -195,7 +195,7 @@ def _send_request(
         )
         try:
             response.raise_for_status()
-        except requests.exceptions.HTTPError:
+        except HTTPError:
             if _is_server_error(response.status_code):
                 # going to retry
                 raise
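
The client now retries only on HTTP and timeout errors from requests, with three attempts instead of five; any other exception propagates immediately. A small sketch of the same decorator configuration applied to a made-up function (ping_studio is illustrative, not part of the codebase):

from requests.exceptions import HTTPError, Timeout

from datachain.utils import retry_with_backoff

@retry_with_backoff(retries=3, errors=(HTTPError, Timeout))
def ping_studio() -> None:
    # A Timeout or HTTPError raised here would be retried up to three times
    # with exponential backoff; any other exception type is not retried.
    raise Timeout("Studio did not respond in time")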

src/datachain/studio.py (+8 -9)

@@ -3,7 +3,6 @@
 import sys
 from typing import TYPE_CHECKING, Optional
 
-from datachain.catalog.catalog import raise_remote_error
 from datachain.config import Config, ConfigLevel
 from datachain.dataset import QUERY_DATASET_PREFIX
 from datachain.error import DataChainError
@@ -150,7 +149,7 @@ def list_datasets(team: Optional[str] = None, name: Optional[str] = None):
     response = client.ls_datasets()
 
     if not response.ok:
-        raise_remote_error(response.message)
+        raise DataChainError(response.message)
 
     if not response.data:
         return
@@ -171,7 +170,7 @@ def list_dataset_versions(team: Optional[str] = None, name: str = ""):
     response = client.dataset_info(name)
 
     if not response.ok:
-        raise_remote_error(response.message)
+        raise DataChainError(response.message)
 
     if not response.data:
         return
@@ -191,7 +190,7 @@ def edit_studio_dataset(
     client = StudioClient(team=team_name)
     response = client.edit_dataset(name, new_name, description, labels)
     if not response.ok:
-        raise_remote_error(response.message)
+        raise DataChainError(response.message)
 
     print(f"Dataset '{name}' updated in Studio")
 
@@ -205,7 +204,7 @@ def remove_studio_dataset(
     client = StudioClient(team=team_name)
     response = client.rm_dataset(name, version, force)
     if not response.ok:
-        raise_remote_error(response.message)
+        raise DataChainError(response.message)
 
     print(f"Dataset '{name}' removed from Studio")
 
@@ -235,7 +234,7 @@ async def _run():
 
     response = client.dataset_job_versions(job_id)
     if not response.ok:
-        raise_remote_error(response.message)
+        raise DataChainError(response.message)
 
     response_data = response.data
     if response_data:
@@ -286,7 +285,7 @@ def create_job(
         requirements=requirements,
     )
     if not response.ok:
-        raise_remote_error(response.message)
+        raise DataChainError(response.message)
 
     if not response.data:
         raise DataChainError("Failed to create job")
@@ -307,7 +306,7 @@ def upload_files(client: StudioClient, files: list[str]) -> list[str]:
         file_content = f.read()
         response = client.upload_file(file_content, file_name)
         if not response.ok:
-            raise_remote_error(response.message)
+            raise DataChainError(response.message)
 
         if not response.data:
             raise DataChainError(f"Failed to upload file {file_name}")
@@ -328,7 +327,7 @@ def cancel_job(job_id: str, team_name: Optional[str]):
     client = StudioClient(team=team_name)
    response = client.cancel_job(job_id)
     if not response.ok:
-        raise_remote_error(response.message)
+        raise DataChainError(response.message)
 
     print(f"Job {job_id} canceled")

src/datachain/utils.py (+12 -2)

@@ -1,6 +1,7 @@
 import glob
 import io
 import json
+import logging
 import os
 import os.path as osp
 import random
@@ -25,6 +26,9 @@
 import pandas as pd
 from typing_extensions import Self
 
+
+logger = logging.getLogger("datachain")
+
 NUL = b"\0"
 TIME_ZERO = datetime.fromtimestamp(0, tz=timezone.utc)
 
@@ -271,19 +275,25 @@ def flatten(items):
             yield item
 
 
-def retry_with_backoff(retries=5, backoff_sec=1):
+def retry_with_backoff(retries=5, backoff_sec=1, errors=(Exception,)):
     def retry(f):
         def wrapper(*args, **kwargs):
             num_tried = 0
             while True:
                 try:
                     return f(*args, **kwargs)
-                except Exception:
+                except errors:
                     if num_tried == retries:
                         raise
                     sleep = (
                         backoff_sec * 2**num_tried + random.uniform(0, 1)  # noqa: S311
                     )
+                    logger.exception(
+                        "Error in %s, retrying in %ds, attempt %d",
+                        f.__name__,
+                        sleep,
+                        num_tried,
+                    )
                     time.sleep(sleep)
                     num_tried += 1
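
The decorator keeps its old default of retrying on any Exception, but callers can now restrict retries to specific error types, and each failed attempt is logged through the "datachain" logger before the backoff sleep. An illustrative, self-contained usage; the flaky function and its attempt counter are made up for the example:

import logging

from datachain.utils import retry_with_backoff

logging.basicConfig(level=logging.INFO)

attempts = 0

@retry_with_backoff(retries=2, backoff_sec=0.1, errors=(ValueError,))
def flaky() -> str:
    # Fails twice with a retriable error, then succeeds on the third attempt.
    global attempts
    attempts += 1
    if attempts < 3:
        raise ValueError("transient failure")
    return "ok"

print(flaky())  # each failure is logged as "Error in flaky, retrying in ..."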

tests/func/test_pull.py (+2 -4)

@@ -350,7 +350,7 @@ def test_pull_dataset_not_found_in_remote(
 
     with pytest.raises(DataChainError) as exc_info:
         catalog.pull_dataset("ds://dogs@v1")
-    assert str(exc_info.value) == "Error from server: Dataset not found"
+    assert str(exc_info.value) == "Dataset not found"
 
 
 @pytest.mark.parametrize("cloud_type, version_aware", [("s3", False)], indirect=True)
@@ -372,9 +372,7 @@ def test_pull_dataset_exporting_dataset_failed_in_remote(
 
     with pytest.raises(DataChainError) as exc_info:
         catalog.pull_dataset("ds://dogs@v1")
-    assert str(exc_info.value) == (
-        f"Error from server: Dataset export {export_status} in Studio"
-    )
+    assert str(exc_info.value) == f"Dataset export {export_status} in Studio"
 
 
 @pytest.mark.parametrize("cloud_type, version_aware", [("s3", False)], indirect=True)
