Skip to content

Commit

Permalink
Improve datachain ds ls output (#897)
Browse files Browse the repository at this point in the history
This changes the way datasets are listed in the CLI.

Closes #11280
  • Loading branch information
amritghimire authored Feb 5, 2025
1 parent 2b51c22 commit d0c5f94
Show file tree
Hide file tree
Showing 6 changed files with 92 additions and 20 deletions.
1 change: 1 addition & 0 deletions src/datachain/cli/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,7 @@ def handle_dataset_command(args, catalog):
local=args.local,
all=args.all,
team=args.team,
latest_only=not args.versions,
),
"rm": lambda: rm_dataset(
catalog,
Expand Down
67 changes: 59 additions & 8 deletions src/datachain/cli/commands/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,27 @@
from datachain.studio import list_datasets as list_datasets_studio


def group_dataset_versions(datasets, latest_only=True):
    """Group ``(name, version)`` pairs by dataset name.

    Args:
        datasets: iterable of ``(name, version)`` tuples.
        latest_only: when True, map each name to its single highest version;
            otherwise map each name to a sorted list of its unique versions.

    Returns:
        A dict keyed by dataset name, with names inserted in sorted order so
        the resulting listing is deterministic.
    """
    grouped = {}
    # Sort the pairs so names (and each name's versions) are accumulated in a
    # deterministic order for display.  Grouping itself is done with a plain
    # dict, so consecutiveness of equal keys is not required.
    for name, version in sorted(datasets):
        grouped.setdefault(name, []).append(version)

    if latest_only:
        # For each dataset name, pick the highest version.
        return {name: max(versions) for name, versions in grouped.items()}
    # For each dataset name, return a sorted list of unique versions.
    return {name: sorted(set(versions)) for name, versions in grouped.items()}


def list_datasets(
catalog: "Catalog",
studio: bool = False,
local: bool = False,
all: bool = True,
team: Optional[str] = None,
latest_only: bool = True,
):
token = Config().read().get("studio", {}).get("token")
all, local, studio = determine_flavors(studio, local, all, token)
Expand All @@ -27,15 +42,48 @@ def list_datasets(
set(list_datasets_studio(team=team)) if (all or studio) and token else set()
)

# Group the datasets for both local and studio sources.
local_grouped = group_dataset_versions(local_datasets, latest_only)
studio_grouped = group_dataset_versions(studio_datasets, latest_only)

# Merge all dataset names from both sources.
all_dataset_names = sorted(set(local_grouped.keys()) | set(studio_grouped.keys()))

datasets = []
if latest_only:
# For each dataset name, get the latest version from each source (if available).
for name in all_dataset_names:
datasets.append((name, (local_grouped.get(name), studio_grouped.get(name))))
else:
# For each dataset name, merge all versions from both sources.
for name in all_dataset_names:
local_versions = local_grouped.get(name, [])
studio_versions = studio_grouped.get(name, [])

# If neither source has any versions, record it as (None, None).
if not local_versions and not studio_versions:
datasets.append((name, (None, None)))
else:
# For each unique version from either source, record its presence.
for version in sorted(set(local_versions) | set(studio_versions)):
datasets.append(
(
name,
(
version if version in local_versions else None,
version if version in studio_versions else None,
),
)
)

rows = [
_datasets_tabulate_row(
name=name,
version=version,
both=(all or (local and studio)) and token,
local=(name, version) in local_datasets,
studio=(name, version) in studio_datasets,
local_version=local_version,
studio_version=studio_version,
)
for name, version in local_datasets.union(studio_datasets)
for name, (local_version, studio_version) in datasets
]

print(tabulate(rows, headers="keys"))
Expand All @@ -47,14 +95,17 @@ def list_datasets_local(catalog: "Catalog"):
yield (d.name, v.version)


def _datasets_tabulate_row(name, version, both, local, studio):
def _datasets_tabulate_row(name, both, local_version, studio_version):
row = {
"Name": name,
"Version": version,
}
if both:
row["Studio"] = "\u2714" if studio else "\u2716"
row["Local"] = "\u2714" if local else "\u2716"
row["Studio"] = f"v{studio_version}" if studio_version else "\u2716"
row["Local"] = f"v{local_version}" if local_version else "\u2716"
else:
latest_version = local_version or studio_version
row["Latest Version"] = f"v{latest_version}" if latest_version else "\u2716"

return row


Expand Down
6 changes: 6 additions & 0 deletions src/datachain/cli/parser/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,6 +254,12 @@ def get_parser() -> ArgumentParser: # noqa: PLR0915
description="List datasets.",
formatter_class=CustomHelpFormatter,
)
datasets_ls_parser.add_argument(
"--versions",
action="store_true",
default=False,
help="List all the versions of each dataset",
)
datasets_ls_parser.add_argument(
"--studio",
action="store_true",
Expand Down
3 changes: 3 additions & 0 deletions src/datachain/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -628,6 +628,9 @@ def merge_versions(self, other: "DatasetListRecord") -> "DatasetListRecord":
self.versions.sort(key=lambda v: v.version)
return self

def latest_version(self) -> DatasetListVersion:
    """Return the entry in ``self.versions`` with the greatest version number.

    Raises ``ValueError`` if the record has no versions (empty ``max``).
    """
    return max(self.versions, key=lambda entry: entry.version)

@property
def is_bucket_listing(self) -> bool:
"""
Expand Down
2 changes: 1 addition & 1 deletion tests/test_cli_e2e.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

def _tabulated_datasets(name, version):
    """Render the expected single-dataset `dataset ls` table for assertions."""
    rows = [
        {"Name": name, "Latest Version": f"v{version}"},
    ]
    return tabulate.tabulate(rows, headers="keys")

Expand Down
33 changes: 22 additions & 11 deletions tests/test_cli_studio.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,31 +129,38 @@ def list_datasets_local(_):
side_effect=list_datasets_local,
)
local_rows = [
{"Name": "both", "Version": "1"},
{"Name": "local", "Version": "1"},
{"Name": "both", "Latest Version": "v1"},
{"Name": "local", "Latest Version": "v1"},
]
local_output = tabulate(local_rows, headers="keys")

studio_rows = [
{"Name": "both", "Version": "1"},
{"Name": "both", "Latest Version": "v1"},
{
"Name": "cats",
"Version": "1",
"Latest Version": "v1",
},
{"Name": "dogs", "Version": "1"},
{"Name": "dogs", "Version": "2"},
{"Name": "dogs", "Latest Version": "v2"},
]
studio_output = tabulate(studio_rows, headers="keys")

both_rows = [
{"Name": "both", "Version": "1", "Studio": "\u2714", "Local": "\u2714"},
{"Name": "cats", "Version": "1", "Studio": "\u2714", "Local": "\u2716"},
{"Name": "dogs", "Version": "1", "Studio": "\u2714", "Local": "\u2716"},
{"Name": "dogs", "Version": "2", "Studio": "\u2714", "Local": "\u2716"},
{"Name": "local", "Version": "1", "Studio": "\u2716", "Local": "\u2714"},
{"Name": "both", "Studio": "v1", "Local": "v1"},
{"Name": "cats", "Studio": "v1", "Local": "\u2716"},
{"Name": "dogs", "Studio": "v2", "Local": "\u2716"},
{"Name": "local", "Studio": "\u2716", "Local": "v1"},
]
both_output = tabulate(both_rows, headers="keys")

both_rows_versions = [
{"Name": "both", "Studio": "v1", "Local": "v1"},
{"Name": "cats", "Studio": "v1", "Local": "\u2716"},
{"Name": "dogs", "Studio": "v1", "Local": "\u2716"},
{"Name": "dogs", "Studio": "v2", "Local": "\u2716"},
{"Name": "local", "Studio": "\u2716", "Local": "v1"},
]
both_output_versions = tabulate(both_rows_versions, headers="keys")

assert main(["dataset", "ls", "--local"]) == 0
out = capsys.readouterr().out
assert sorted(out.splitlines()) == sorted(local_output.splitlines())
Expand All @@ -174,6 +181,10 @@ def list_datasets_local(_):
out = capsys.readouterr().out
assert sorted(out.splitlines()) == sorted(both_output.splitlines())

assert main(["dataset", "ls", "--versions"]) == 0
out = capsys.readouterr().out
assert sorted(out.splitlines()) == sorted(both_output_versions.splitlines())


def test_studio_edit_dataset(capsys, mocker):
with requests_mock.mock() as m:
Expand Down

0 comments on commit d0c5f94

Please sign in to comment.