From b86791fee2322ae3d5ff898b9a75ea478d71fcca Mon Sep 17 00:00:00 2001
From: Matt Seddon
Date: Wed, 27 Nov 2024 16:49:56 +1100
Subject: [PATCH] add DatasetList dataclass to extract only needed fields from metastore

---
 src/datachain/catalog/catalog.py        |   7 +-
 src/datachain/data_storage/metastore.py |  74 +++++++++--
 src/datachain/dataset.py                | 156 +++++++++++++++++++++---
 src/datachain/lib/dataset_info.py       |  10 +-
 tests/func/test_catalog.py              |  21 ++++
 5 files changed, 237 insertions(+), 31 deletions(-)

diff --git a/src/datachain/catalog/catalog.py b/src/datachain/catalog/catalog.py
index 010caaf4f..7a46e6aa7 100644
--- a/src/datachain/catalog/catalog.py
+++ b/src/datachain/catalog/catalog.py
@@ -38,6 +38,7 @@
     DATASET_PREFIX,
     QUERY_DATASET_PREFIX,
     DatasetDependency,
+    DatasetListRecord,
     DatasetRecord,
     DatasetStats,
     DatasetStatus,
@@ -72,7 +73,7 @@
         AbstractMetastore,
         AbstractWarehouse,
     )
-    from datachain.dataset import DatasetVersion
+    from datachain.dataset import DatasetListVersion
     from datachain.job import Job
     from datachain.lib.file import File
     from datachain.listing import Listing
@@ -1135,7 +1136,7 @@ def get_dataset_dependencies(
 
         return direct_dependencies
 
-    def ls_datasets(self, include_listing: bool = False) -> Iterator[DatasetRecord]:
+    def ls_datasets(self, include_listing: bool = False) -> Iterator[DatasetListRecord]:
         datasets = self.metastore.list_datasets()
         for d in datasets:
             if not d.is_bucket_listing or include_listing:
@@ -1144,7 +1145,7 @@ def ls_datasets(self, include_listing: bool = False) -> Iterator[DatasetRecord]:
     def list_datasets_versions(
         self,
         include_listing: bool = False,
-    ) -> Iterator[tuple[DatasetRecord, "DatasetVersion", Optional["Job"]]]:
+    ) -> Iterator[tuple[DatasetListRecord, "DatasetListVersion", Optional["Job"]]]:
         """Iterate over all dataset versions with related jobs."""
         datasets = list(self.ls_datasets(include_listing=include_listing))
 
diff --git a/src/datachain/data_storage/metastore.py b/src/datachain/data_storage/metastore.py
index 7afbb3e72..ef81f56a5 100644
--- a/src/datachain/data_storage/metastore.py
+++ b/src/datachain/data_storage/metastore.py
@@ -27,6 +27,8 @@
 from datachain.data_storage.serializer import Serializable
 from datachain.dataset import (
     DatasetDependency,
+    DatasetListRecord,
+    DatasetListVersion,
     DatasetRecord,
     DatasetStatus,
     DatasetVersion,
@@ -59,6 +61,8 @@ class AbstractMetastore(ABC, Serializable):
 
     schema: "schema.Schema"
     dataset_class: type[DatasetRecord] = DatasetRecord
+    dataset_list_class: type[DatasetListRecord] = DatasetListRecord
+    dataset_list_version_class: type[DatasetListVersion] = DatasetListVersion
     dependency_class: type[DatasetDependency] = DatasetDependency
     job_class: type[Job] = Job
 
@@ -166,11 +170,11 @@ def remove_dataset_version(
         """
 
     @abstractmethod
-    def list_datasets(self) -> Iterator[DatasetRecord]:
+    def list_datasets(self) -> Iterator[DatasetListRecord]:
         """Lists all datasets."""
 
     @abstractmethod
-    def list_datasets_by_prefix(self, prefix: str) -> Iterator["DatasetRecord"]:
+    def list_datasets_by_prefix(self, prefix: str) -> Iterator["DatasetListRecord"]:
         """Lists all datasets which names start with prefix."""
 
     @abstractmethod
@@ -348,6 +352,14 @@ def _dataset_fields(self) -> list[str]:
             if c.name  # type: ignore [attr-defined]
         ]
 
+    @cached_property
+    def _dataset_list_fields(self) -> list[str]:
+        return [
+            c.name  # type: ignore [attr-defined]
+            for c in self._datasets_columns()
+            if c.name in self.dataset_list_class.__dataclass_fields__  # type: ignore [attr-defined]
+        ]
+
     @classmethod
     def _datasets_versions_columns(cls) -> list["SchemaItem"]:
         """Datasets versions table columns."""
@@ -390,6 +402,15 @@ def _dataset_version_fields(self) -> list[str]:
             if c.name  # type: ignore [attr-defined]
         ]
 
+    @cached_property
+    def _dataset_list_version_fields(self) -> list[str]:
+        return [
+            c.name  # type: ignore [attr-defined]
+            for c in self._datasets_versions_columns()
+            if c.name  # type: ignore [attr-defined]
+            in self.dataset_list_version_class.__dataclass_fields__
+        ]
+
     @classmethod
     def _datasets_dependencies_columns(cls) -> list["SchemaItem"]:
         """Datasets dependencies table columns."""
@@ -671,7 +692,25 @@ def _parse_datasets(self, rows) -> Iterator["DatasetRecord"]:
             if dataset:
                 yield dataset
 
-    def _base_dataset_query(self):
+    def _parse_list_dataset(self, rows) -> Optional[DatasetListRecord]:
+        versions = [self.dataset_list_class.parse(*r) for r in rows]
+        if not versions:
+            return None
+        return reduce(lambda ds, version: ds.merge_versions(version), versions)
+
+    def _parse_dataset_list(self, rows) -> Iterator["DatasetListRecord"]:
+        # grouping rows by dataset id
+        for _, g in groupby(rows, lambda r: r[0]):
+            dataset = self._parse_list_dataset(list(g))
+            if dataset:
+                yield dataset
+
+    def _get_dataset_query(
+        self,
+        dataset_fields: list[str],
+        dataset_version_fields: list[str],
+        isouter: bool = True,
+    ):
         if not (
             self.db.has_table(self._datasets.name)
             and self.db.has_table(self._datasets_versions.name)
@@ -680,23 +719,36 @@
 
         d = self._datasets
         dv = self._datasets_versions
+
         query = self._datasets_select(
-            *(getattr(d.c, f) for f in self._dataset_fields),
-            *(getattr(dv.c, f) for f in self._dataset_version_fields),
+            *(getattr(d.c, f) for f in dataset_fields),
+            *(getattr(dv.c, f) for f in dataset_version_fields),
         )
-        j = d.join(dv, d.c.id == dv.c.dataset_id, isouter=True)
+        j = d.join(dv, d.c.id == dv.c.dataset_id, isouter=isouter)
         return query.select_from(j)
 
-    def list_datasets(self) -> Iterator["DatasetRecord"]:
+    def _base_dataset_query(self):
+        return self._get_dataset_query(
+            self._dataset_fields, self._dataset_version_fields
+        )
+
+    def _base_list_datasets_query(self):
+        return self._get_dataset_query(
+            self._dataset_list_fields, self._dataset_list_version_fields, isouter=False
+        )
+
+    def list_datasets(self) -> Iterator["DatasetListRecord"]:
         """Lists all datasets."""
-        yield from self._parse_datasets(self.db.execute(self._base_dataset_query()))
+        yield from self._parse_dataset_list(
+            self.db.execute(self._base_list_datasets_query())
+        )
 
     def list_datasets_by_prefix(
         self, prefix: str, conn=None
-    ) -> Iterator["DatasetRecord"]:
-        query = self._base_dataset_query()
+    ) -> Iterator["DatasetListRecord"]:
+        query = self._base_list_datasets_query()
         query = query.where(self._datasets.c.name.startswith(prefix))
-        yield from self._parse_datasets(self.db.execute(query))
+        yield from self._parse_dataset_list(self.db.execute(query))
 
     def get_dataset(self, name: str, conn=None) -> DatasetRecord:
         """Gets a single dataset by name"""
diff --git a/src/datachain/dataset.py b/src/datachain/dataset.py
index a7d28be1f..d255a2d6d 100644
--- a/src/datachain/dataset.py
+++ b/src/datachain/dataset.py
@@ -15,7 +15,9 @@
 from datachain.sql.types import NAME_TYPES_MAPPING, SQLType
 
 T = TypeVar("T", bound="DatasetRecord")
+LT = TypeVar("LT", bound="DatasetListRecord")
 V = TypeVar("V", bound="DatasetVersion")
+LV = TypeVar("LV", bound="DatasetListVersion")
 DD = TypeVar("DD", bound="DatasetDependency")
 
 DATASET_PREFIX = "ds://"
@@ -264,6 +266,59 @@ def from_dict(cls, d: dict[str, Any]) -> "DatasetVersion":
         return cls(**kwargs)
 
 
+@dataclass
+class DatasetListVersion:
+    id: int
+    uuid: str
+    dataset_id: int
+    version: int
+    status: int
+    created_at: datetime
+    finished_at: Optional[datetime]
+    error_message: str
+    error_stack: str
+    num_objects: Optional[int]
+    size: Optional[int]
+    query_script: str = ""
+    job_id: Optional[str] = None
+
+    @classmethod
+    def parse(
+        cls: type[LV],
+        id: int,
+        uuid: str,
+        dataset_id: int,
+        version: int,
+        status: int,
+        created_at: datetime,
+        finished_at: Optional[datetime],
+        error_message: str,
+        error_stack: str,
+        num_objects: Optional[int],
+        size: Optional[int],
+        query_script: str = "",
+        job_id: Optional[str] = None,
+    ):
+        return cls(
+            id,
+            uuid,
+            dataset_id,
+            version,
+            status,
+            created_at,
+            finished_at,
+            error_message,
+            error_stack,
+            num_objects,
+            size,
+            query_script,
+            job_id,
+        )
+
+    def __hash__(self):
+        return hash(f"{self.dataset_id}_{self.version}")
+
+
 @dataclass
 class DatasetRecord:
     id: int
@@ -447,20 +502,6 @@ def uri(self, version: int) -> str:
         identifier = self.identifier(version)
         return f"{DATASET_PREFIX}{identifier}"
 
-    @property
-    def is_bucket_listing(self) -> bool:
-        """
-        For bucket listing we implicitly create underlying dataset to hold data. This
-        method is checking if this is one of those datasets.
-        """
-        from datachain.client import Client
-
-        # TODO refactor and maybe remove method in
-        # https://github.com/iterative/datachain/issues/318
-        return Client.is_data_source_uri(self.name) or self.name.startswith(
-            LISTING_PREFIX
-        )
-
     @property
     def versions_values(self) -> list[int]:
         """
@@ -499,5 +540,92 @@ def from_dict(cls, d: dict[str, Any]) -> "DatasetRecord":
         return cls(**kwargs, versions=versions)
 
 
+@dataclass
+class DatasetListRecord:
+    id: int
+    name: str
+    description: Optional[str]
+    labels: list[str]
+    versions: list[DatasetListVersion]
+    created_at: Optional[datetime] = None
+
+    @classmethod
+    def parse(  # noqa: PLR0913
+        cls: type[LT],
+        id: int,
+        name: str,
+        description: Optional[str],
+        labels: str,
+        created_at: datetime,
+        version_id: int,
+        version_uuid: str,
+        version_dataset_id: int,
+        version: int,
+        version_status: int,
+        version_created_at: datetime,
+        version_finished_at: Optional[datetime],
+        version_error_message: str,
+        version_error_stack: str,
+        version_num_objects: Optional[int],
+        version_size: Optional[int],
+        version_query_script: Optional[str],
+        version_job_id: Optional[str] = None,
+    ) -> "DatasetListRecord":
+        labels_lst: list[str] = json.loads(labels) if labels else []
+
+        dataset_version = DatasetListVersion.parse(
+            version_id,
+            version_uuid,
+            version_dataset_id,
+            version,
+            version_status,
+            version_created_at,
+            version_finished_at,
+            version_error_message,
+            version_error_stack,
+            version_num_objects,
+            version_size,
+            version_query_script,  # type: ignore[arg-type]
+            version_job_id,
+        )
+
+        return cls(
+            id,
+            name,
+            description,
+            labels_lst,
+            [dataset_version],
+            created_at,
+        )
+
+    def merge_versions(self, other: "DatasetListRecord") -> "DatasetListRecord":
+        """Merge versions from another dataset"""
+        if other.id != self.id:
+            raise RuntimeError("Cannot merge versions of datasets with different ids")
+        if not other.versions:
+            # nothing to merge
+            return self
+        if not self.versions:
+            self.versions = []
+
+        self.versions = list(set(self.versions + other.versions))
+        self.versions.sort(key=lambda v: v.version)
+        return self
+
+    @property
+    def is_bucket_listing(self) -> bool:
+        """
+        For bucket listing we implicitly create underlying dataset to hold data. This
+        method is checking if this is one of those datasets.
+        """
+        from datachain.client import Client
+
+        # TODO refactor and maybe remove method in
+        # https://github.com/iterative/datachain/issues/318
+        return Client.is_data_source_uri(self.name) or self.name.startswith(
+            LISTING_PREFIX
+        )
+
+
 class RowDict(dict):
     pass
diff --git a/src/datachain/lib/dataset_info.py b/src/datachain/lib/dataset_info.py
index a8f8ef949..bf6d46312 100644
--- a/src/datachain/lib/dataset_info.py
+++ b/src/datachain/lib/dataset_info.py
@@ -5,7 +5,11 @@
 
 from pydantic import Field, field_validator
 
-from datachain.dataset import DatasetRecord, DatasetStatus, DatasetVersion
+from datachain.dataset import (
+    DatasetListRecord,
+    DatasetListVersion,
+    DatasetStatus,
+)
 from datachain.job import Job
 from datachain.lib.data_model import DataModel
 from datachain.utils import TIME_ZERO
@@ -57,8 +61,8 @@ def validate_metrics(cls, v):
     @classmethod
     def from_models(
         cls,
-        dataset: DatasetRecord,
-        version: DatasetVersion,
+        dataset: DatasetListRecord,
+        version: DatasetListVersion,
         job: Optional[Job],
     ) -> "Self":
         return cls(
diff --git a/tests/func/test_catalog.py b/tests/func/test_catalog.py
index f23bf0b26..fdc9619d8 100644
--- a/tests/func/test_catalog.py
+++ b/tests/func/test_catalog.py
@@ -772,6 +772,27 @@ def test_dataset_stats(test_session):
     assert dataset_version2.size == 18
 
 
+def test_ls_datasets_no_json(test_session):
+    ids = [1, 2, 3]
+    values = tuple(zip(["a", "b", "c"], [1, 2, 3]))
+
+    DataChain.from_values(
+        ids=ids,
+        file=[File(path=name, size=size) for name, size in values],
+        session=test_session,
+    ).save()
+    datasets = test_session.catalog.ls_datasets()
+    assert datasets
+    for d in datasets:
+        assert hasattr(d, "id")
+        assert not hasattr(d, "feature_schema")
+        assert d.versions
+        for v in d.versions:
+            assert hasattr(v, "id")
+            assert not hasattr(v, "preview")
+            assert not hasattr(v, "feature_schema")
+
+
 @pytest.mark.parametrize("cloud_type", ["s3", "azure", "gs"], indirect=True)
 def test_listing_stats(cloud_test_catalog):
     catalog = cloud_test_catalog.catalog
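
A minimal usage sketch (not part of the patch): after this change, Catalog.ls_datasets() yields lightweight DatasetListRecord objects whose versions are DatasetListVersion instances, so heavy JSON columns such as feature_schema and preview are never selected from the metastore. The `session` name below is an assumption standing in for an existing datachain Session (for example the test_session fixture used in the test above); the field names come from the dataclasses added in this patch.

    # Sketch only: `session` is assumed to be an existing datachain Session.
    # Iterate the lightweight listing records and their versions.
    for ds in session.catalog.ls_datasets():
        for v in ds.versions:
            print(ds.name, v.version, v.num_objects, v.size)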