Skip to content

Commit ea9a904

Browse files
authored
Pull dataset from studio if not available locally (#901)
* Pull dataset from Studio if not available locally. If the following cases are met, this will pull the dataset from Studio: the user is logged in to Studio; the dataset or version does not exist locally; and the user has not passed studio=False to from_dataset. In such a case, the dataset is pulled from Studio before continuing further. A test is added to check for this behavior. Closes #874 * Move token check to util * Move to catalog
1 parent 0ff6d54 commit ea9a904

File tree

5 files changed

+102
-3
lines changed

5 files changed

+102
-3
lines changed

src/datachain/catalog/catalog.py

+25
Original file line numberDiff line numberDiff line change
@@ -1097,6 +1097,31 @@ def register_dataset(
10971097
def get_dataset(self, name: str) -> DatasetRecord:
10981098
return self.metastore.get_dataset(name)
10991099

1100+
def get_dataset_with_remote_fallback(
    self, name: str, version: Optional[int] = None
) -> DatasetRecord:
    """Return a local dataset, pulling it from Studio when it is missing.

    Looks the dataset up in the local metastore first; if the dataset (or
    the requested version) is absent locally, pulls it from Studio and then
    resolves it locally again.

    Args:
        name: dataset name to resolve.
        version: optional specific version that must be present.

    Returns:
        The (possibly freshly pulled) DatasetRecord.
    """
    try:
        dataset = self.get_dataset(name)
        # A locally known dataset may still lack the requested version.
        if version and not dataset.has_version(version):
            raise DatasetVersionNotFoundError(
                f"Dataset {name} does not have version {version}"
            )
    except (DatasetNotFoundError, DatasetVersionNotFoundError):
        print("Dataset not found in local catalog, trying to get from studio")

        remote_ds_uri = f"{DATASET_PREFIX}{name}"
        if version:
            remote_ds_uri += f"@v{version}"

        self.pull_dataset(
            remote_ds_uri=remote_ds_uri,
            local_ds_name=name,
            local_ds_version=version,
        )
        # Pull registers the dataset locally; re-resolve it from the metastore.
        return self.get_dataset(name)
    else:
        return dataset
1124+
11001125
def get_dataset_with_version_uuid(self, uuid: str) -> DatasetRecord:
11011126
"""Returns dataset that contains version with specific uuid"""
11021127
for dataset in self.ls_datasets():

src/datachain/lib/dc.py

+2
Original file line numberDiff line numberDiff line change
@@ -481,6 +481,7 @@ def from_dataset(
481481
version: Optional[int] = None,
482482
session: Optional[Session] = None,
483483
settings: Optional[dict] = None,
484+
fallback_to_remote: bool = True,
484485
) -> "Self":
485486
"""Get data from a saved Dataset. It returns the chain itself.
486487
@@ -498,6 +499,7 @@ def from_dataset(
498499
version=version,
499500
session=session,
500501
indexing_column_types=File._datachain_column_types,
502+
fallback_to_remote=fallback_to_remote,
501503
)
502504
telemetry.send_event_once("class", "datachain_init", name=name, version=version)
503505
if settings:

src/datachain/query/dataset.py

+28-3
Original file line numberDiff line numberDiff line change
@@ -42,13 +42,17 @@
4242
partition_col_names,
4343
partition_columns,
4444
)
45-
from datachain.dataset import DatasetStatus, RowDict
46-
from datachain.error import DatasetNotFoundError, QueryScriptCancelError
45+
from datachain.dataset import DATASET_PREFIX, DatasetStatus, RowDict
46+
from datachain.error import (
47+
DatasetNotFoundError,
48+
QueryScriptCancelError,
49+
)
4750
from datachain.func.base import Function
4851
from datachain.lib.udf import UDFAdapter, _get_cache
4952
from datachain.progress import CombinedDownloadCallback, TqdmCombinedDownloadCallback
5053
from datachain.query.schema import C, UDFParamSpec, normalize_param
5154
from datachain.query.session import Session
55+
from datachain.remote.studio import is_token_set
5256
from datachain.sql.functions.random import rand
5357
from datachain.utils import (
5458
batched,
@@ -1081,6 +1085,7 @@ def __init__(
10811085
session: Optional[Session] = None,
10821086
indexing_column_types: Optional[dict[str, Any]] = None,
10831087
in_memory: bool = False,
1088+
fallback_to_remote: bool = True,
10841089
) -> None:
10851090
self.session = Session.get(session, catalog=catalog, in_memory=in_memory)
10861091
self.catalog = catalog or self.session.catalog
@@ -1097,7 +1102,12 @@ def __init__(
10971102
self.column_types: Optional[dict[str, Any]] = None
10981103

10991104
self.name = name
1100-
ds = self.catalog.get_dataset(name)
1105+
1106+
if fallback_to_remote and is_token_set():
1107+
ds = self.catalog.get_dataset_with_remote_fallback(name, version)
1108+
else:
1109+
ds = self.catalog.get_dataset(name)
1110+
11011111
self.version = version or ds.latest_version
11021112
self.feature_schema = ds.get_version(self.version).feature_schema
11031113
self.column_types = copy(ds.schema)
@@ -1112,6 +1122,21 @@ def __iter__(self):
11121122
def __or__(self, other):
11131123
return self.union(other)
11141124

1125+
def pull_dataset(self, name: str, version: Optional[int] = None) -> "DatasetRecord":
    """Pull *name* (optionally pinned to *version*) from Studio into the
    local catalog and return the locally registered dataset record."""
    print("Dataset not found in local catalog, trying to get from studio")

    # Build the remote URI: "ds://<name>" plus an optional "@v<version>" pin.
    uri = f"{DATASET_PREFIX}{name}"
    if version:
        uri += f"@v{version}"

    self.catalog.pull_dataset(
        remote_ds_uri=uri,
        local_ds_name=name,
        local_ds_version=version,
    )
    # The pull registers the dataset locally; fetch and return that record.
    return self.catalog.get_dataset(name)
1139+
11151140
@staticmethod
11161141
def get_table() -> "TableClause":
11171142
table_name = "".join(

src/datachain/remote/studio.py

+7
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,13 @@ def _is_server_error(status_code: int) -> bool:
3939
return str(status_code).startswith("5")
4040

4141

42+
def is_token_set() -> bool:
    """Report whether a Studio auth token is configured.

    True when either the DVC_STUDIO_TOKEN environment variable holds a
    non-empty value or the "studio" section of the config has a token.
    """
    if os.environ.get("DVC_STUDIO_TOKEN"):
        return True
    return Config().read().get("studio", {}).get("token") is not None
47+
48+
4249
def _parse_dates(obj: dict, date_fields: list[str]):
4350
"""
4451
Function that converts string ISO dates to datetime.datetime instances in object

tests/func/test_pull.py

+40
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@
1010
from datachain.config import Config, ConfigLevel
1111
from datachain.dataset import DatasetStatus
1212
from datachain.error import DataChainError, DatasetNotFoundError
13+
from datachain.lib.dc import DataChain
14+
from datachain.query.session import Session
1315
from datachain.utils import STUDIO_URL, JSONSerialize
1416
from tests.data import ENTRIES
1517
from tests.utils import assert_row_names, skip_if_not_sqlite, tree_from_path
@@ -267,6 +269,44 @@ def test_pull_dataset_success(
267269
}
268270

269271

272+
@pytest.mark.parametrize("cloud_type, version_aware", [("s3", False)], indirect=True)
@skip_if_not_sqlite
def test_datachain_from_dataset_pull(
    mocker,
    cloud_test_catalog,
    remote_dataset_info,
    dataset_export,
    dataset_export_status,
    dataset_export_data_chunk,
):
    """from_dataset() should transparently pull the dataset from Studio
    when it is not present in the local catalog."""
    mocker.patch(
        "datachain.catalog.catalog.DatasetRowsFetcher.should_check_for_status",
        return_value=True,
    )

    catalog = cloud_test_catalog.catalog

    # Precondition: the dataset must be absent locally so the fallback fires.
    with pytest.raises(DatasetNotFoundError):
        catalog.get_dataset("dogs")

    with Session("testSession", catalog=catalog):
        chain = DataChain.from_dataset(
            name="dogs",
            version=1,
            fallback_to_remote=True,
        )

    assert chain.dataset.name == "dogs"
    assert chain.dataset.latest_version == 1
    assert chain.dataset.status == DatasetStatus.COMPLETE

    # The pull must have registered the dataset in the local catalog.
    pulled = catalog.get_dataset("dogs")
    assert pulled.name == "dogs"
309+
270310
@pytest.mark.parametrize("cloud_type, version_aware", [("s3", False)], indirect=True)
271311
@skip_if_not_sqlite
272312
def test_pull_dataset_wrong_dataset_uri_format(

0 commit comments

Comments
 (0)