Skip to content

Commit

Permalink
Add custom error and CustomHelpFormatter for cli (#888)
Browse files Browse the repository at this point in the history
  • Loading branch information
amritghimire authored Feb 3, 2025
1 parent 7f757b3 commit cc91ed9
Show file tree
Hide file tree
Showing 6 changed files with 100 additions and 21 deletions.
65 changes: 53 additions & 12 deletions src/datachain/cli/parser/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import argparse
from argparse import ArgumentParser
from importlib.metadata import PackageNotFoundError, version

import shtab
Expand All @@ -10,12 +9,16 @@
from .studio import add_auth_parser
from .utils import (
FIND_COLUMNS,
CustomHelpFormatter,
add_anon_arg,
add_show_args,
add_sources_arg,
add_update_arg,
find_columns_type,
)
from .utils import (
CustomArgumentParser as ArgumentParser,
)


def get_parser() -> ArgumentParser: # noqa: PLR0915
Expand All @@ -28,10 +31,11 @@ def get_parser() -> ArgumentParser: # noqa: PLR0915
parser = ArgumentParser(
description="DataChain: Wrangle unstructured AI data at scale.",
prog="datachain",
formatter_class=CustomHelpFormatter,
)
parser.add_argument("-V", "--version", action="version", version=__version__)

parent_parser = ArgumentParser(add_help=False)
parent_parser = ArgumentParser(add_help=False, formatter_class=CustomHelpFormatter)
parent_parser.add_argument(
"-v", "--verbose", action="count", default=0, help="Be verbose"
)
Expand Down Expand Up @@ -59,7 +63,10 @@ def get_parser() -> ArgumentParser: # noqa: PLR0915
help=f"Use `{parser.prog} command --help` for command-specific help",
)
parse_cp = subp.add_parser(
"cp", parents=[parent_parser], description="Copy data files from the cloud."
"cp",
parents=[parent_parser],
description="Copy data files from the cloud.",
formatter_class=CustomHelpFormatter,
)
add_sources_arg(parse_cp).complete = shtab.DIR # type: ignore[attr-defined]
parse_cp.add_argument(
Expand Down Expand Up @@ -90,7 +97,10 @@ def get_parser() -> ArgumentParser: # noqa: PLR0915
add_update_arg(parse_cp)

parse_clone = subp.add_parser(
"clone", parents=[parent_parser], description="Copy data files from the cloud."
"clone",
parents=[parent_parser],
description="Copy data files from the cloud.",
formatter_class=CustomHelpFormatter,
)
add_sources_arg(parse_clone).complete = shtab.DIR # type: ignore[attr-defined]
parse_clone.add_argument(
Expand Down Expand Up @@ -134,6 +144,7 @@ def get_parser() -> ArgumentParser: # noqa: PLR0915
aliases=["ds"],
parents=[parent_parser],
description="Commands for managing datasets.",
formatter_class=CustomHelpFormatter,
)
add_anon_arg(datasets_parser)
datasets_subparser = datasets_parser.add_subparsers(
Expand All @@ -145,6 +156,7 @@ def get_parser() -> ArgumentParser: # noqa: PLR0915
"pull",
parents=[parent_parser],
description="Pull specific dataset version from Studio.",
formatter_class=CustomHelpFormatter,
)
parse_pull.add_argument(
"dataset",
Expand Down Expand Up @@ -188,7 +200,10 @@ def get_parser() -> ArgumentParser: # noqa: PLR0915
)

parse_edit_dataset = datasets_subparser.add_parser(
"edit", parents=[parent_parser], description="Edit dataset metadata."
"edit",
parents=[parent_parser],
description="Edit dataset metadata.",
formatter_class=CustomHelpFormatter,
)
parse_edit_dataset.add_argument("name", type=str, help="Dataset name")
parse_edit_dataset.add_argument(
Expand Down Expand Up @@ -234,7 +249,10 @@ def get_parser() -> ArgumentParser: # noqa: PLR0915
)

datasets_ls_parser = datasets_subparser.add_parser(
"ls", parents=[parent_parser], description="List datasets."
"ls",
parents=[parent_parser],
description="List datasets.",
formatter_class=CustomHelpFormatter,
)
datasets_ls_parser.add_argument(
"--studio",
Expand Down Expand Up @@ -264,7 +282,11 @@ def get_parser() -> ArgumentParser: # noqa: PLR0915
)

rm_dataset_parser = datasets_subparser.add_parser(
"rm", parents=[parent_parser], description="Remove dataset.", aliases=["remove"]
"rm",
parents=[parent_parser],
description="Remove dataset.",
aliases=["remove"],
formatter_class=CustomHelpFormatter,
)
rm_dataset_parser.add_argument("name", type=str, help="Dataset name")
rm_dataset_parser.add_argument(
Expand Down Expand Up @@ -308,7 +330,10 @@ def get_parser() -> ArgumentParser: # noqa: PLR0915
)

parse_ls = subp.add_parser(
"ls", parents=[parent_parser], description="List storage contents."
"ls",
parents=[parent_parser],
description="List storage contents.",
formatter_class=CustomHelpFormatter,
)
add_anon_arg(parse_ls)
add_update_arg(parse_ls)
Expand Down Expand Up @@ -348,7 +373,10 @@ def get_parser() -> ArgumentParser: # noqa: PLR0915
)

parse_du = subp.add_parser(
"du", parents=[parent_parser], description="Display space usage."
"du",
parents=[parent_parser],
description="Display space usage.",
formatter_class=CustomHelpFormatter,
)
add_sources_arg(parse_du)
add_anon_arg(parse_du)
Expand Down Expand Up @@ -380,7 +408,10 @@ def get_parser() -> ArgumentParser: # noqa: PLR0915
)

parse_find = subp.add_parser(
"find", parents=[parent_parser], description="Search in a directory hierarchy."
"find",
parents=[parent_parser],
description="Search in a directory hierarchy.",
formatter_class=CustomHelpFormatter,
)
add_anon_arg(parse_find)
add_update_arg(parse_find)
Expand Down Expand Up @@ -435,7 +466,10 @@ def get_parser() -> ArgumentParser: # noqa: PLR0915
)

parse_index = subp.add_parser(
"index", parents=[parent_parser], description="Index storage location."
"index",
parents=[parent_parser],
description="Index storage location.",
formatter_class=CustomHelpFormatter,
)
add_anon_arg(parse_index)
add_update_arg(parse_index)
Expand All @@ -445,6 +479,7 @@ def get_parser() -> ArgumentParser: # noqa: PLR0915
"show",
parents=[parent_parser],
description="Create a new dataset with a query script.",
formatter_class=CustomHelpFormatter,
)
show_parser.add_argument("name", type=str, help="Dataset name")
show_parser.add_argument(
Expand All @@ -461,6 +496,7 @@ def get_parser() -> ArgumentParser: # noqa: PLR0915
"query",
parents=[parent_parser],
description="Create a new dataset with a query script.",
formatter_class=CustomHelpFormatter,
)
add_anon_arg(query_parser)
query_parser.add_argument(
Expand Down Expand Up @@ -491,11 +527,15 @@ def get_parser() -> ArgumentParser: # noqa: PLR0915
"clear-cache",
parents=[parent_parser],
description="Clear the local file cache.",
formatter_class=CustomHelpFormatter,
)
add_anon_arg(parse_clear_cache)

parse_gc = subp.add_parser(
"gc", parents=[parent_parser], description="Garbage collect temporary tables."
"gc",
parents=[parent_parser],
description="Garbage collect temporary tables.",
formatter_class=CustomHelpFormatter,
)
add_anon_arg(parse_gc)

Expand All @@ -510,6 +550,7 @@ def add_completion_parser(subparsers, parents):
"completion",
parents=parents,
description="Output shell completion script.",
formatter_class=CustomHelpFormatter,
)
parser.add_argument(
"-s",
Expand Down
18 changes: 14 additions & 4 deletions src/datachain/cli/parser/job.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,15 @@
from datachain.cli.parser.utils import CustomHelpFormatter


def add_jobs_parser(subparsers, parent_parser) -> None:
jobs_help = "Manage jobs in Studio"
jobs_description = "Commands to manage job execution in Studio."
jobs_parser = subparsers.add_parser(
"job", parents=[parent_parser], description=jobs_description, help=jobs_help
"job",
parents=[parent_parser],
description=jobs_description,
help=jobs_help,
formatter_class=CustomHelpFormatter,
)
jobs_subparser = jobs_parser.add_subparsers(
dest="cmd",
Expand All @@ -17,10 +24,11 @@ def add_jobs_parser(subparsers, parent_parser) -> None:
parents=[parent_parser],
description=studio_run_description,
help=studio_run_help,
formatter_class=CustomHelpFormatter,
)

studio_run_parser.add_argument(
"query_file",
"file",
action="store",
help="Query file to run",
)
Expand Down Expand Up @@ -78,10 +86,11 @@ def add_jobs_parser(subparsers, parent_parser) -> None:
parents=[parent_parser],
description=studio_cancel_description,
help=studio_cancel_help,
formatter_class=CustomHelpFormatter,
)

studio_cancel_parser.add_argument(
"job_id",
"id",
action="store",
help="Job ID to cancel",
)
Expand All @@ -100,10 +109,11 @@ def add_jobs_parser(subparsers, parent_parser) -> None:
parents=[parent_parser],
description=studio_log_description,
help=studio_log_help,
formatter_class=CustomHelpFormatter,
)

studio_log_parser.add_argument(
"job_id",
"id",
action="store",
help="Job ID to show logs for",
)
Expand Down
8 changes: 8 additions & 0 deletions src/datachain/cli/parser/studio.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
from datachain.cli.parser.utils import CustomHelpFormatter


def add_auth_parser(subparsers, parent_parser) -> None:
from dvc_studio_client.auth import AVAILABLE_SCOPES

Expand All @@ -9,6 +12,7 @@ def add_auth_parser(subparsers, parent_parser) -> None:
parents=[parent_parser],
description=auth_description,
help=auth_help,
formatter_class=CustomHelpFormatter,
)
auth_subparser = auth_parser.add_subparsers(
dest="cmd",
Expand All @@ -27,6 +31,7 @@ def add_auth_parser(subparsers, parent_parser) -> None:
parents=[parent_parser],
description=auth_login_description,
help=auth_login_help,
formatter_class=CustomHelpFormatter,
)

login_parser.add_argument(
Expand Down Expand Up @@ -69,6 +74,7 @@ def add_auth_parser(subparsers, parent_parser) -> None:
parents=[parent_parser],
description=auth_logout_description,
help=auth_logout_help,
formatter_class=CustomHelpFormatter,
)

auth_team_help = "Set default team for Studio operations"
Expand All @@ -79,6 +85,7 @@ def add_auth_parser(subparsers, parent_parser) -> None:
parents=[parent_parser],
description=auth_team_description,
help=auth_team_help,
formatter_class=CustomHelpFormatter,
)
team_parser.add_argument(
"team_name",
Expand All @@ -100,4 +107,5 @@ def add_auth_parser(subparsers, parent_parser) -> None:
parents=[parent_parser],
description=auth_token_description,
help=auth_token_help,
formatter_class=CustomHelpFormatter,
)
21 changes: 20 additions & 1 deletion src/datachain/cli/parser/utils.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,30 @@
from argparse import Action, ArgumentParser, ArgumentTypeError
from argparse import Action, ArgumentParser, ArgumentTypeError, HelpFormatter
from typing import Union

from datachain.cli.utils import CommaSeparatedArgs

FIND_COLUMNS = ["du", "name", "path", "size", "type"]


class CustomHelpFormatter(HelpFormatter):
    """Help formatter that lists generic options after command-specific ones.

    Within every argument group, the actions whose ``dest`` is ``help``,
    ``verbose`` or ``quiet`` are moved to the end; all other actions keep
    their original relative order.
    """

    # ``dest`` names of the actions that are pushed to the end of a group.
    _TRAILING_DESTS = ("help", "verbose", "quiet")

    def add_arguments(self, actions):
        """Add ``actions`` to the current section, generic options last.

        Args:
            actions: Iterable of ``argparse.Action`` objects. Any iterable is
                accepted, including one-shot iterators such as generators.
        """
        # Partition in a single pass: iterating ``actions`` twice would
        # silently drop the trailing actions when given a generator.
        leading, trailing = [], []
        for action in actions:
            bucket = trailing if action.dest in self._TRAILING_DESTS else leading
            bucket.append(action)
        super().add_arguments(leading + trailing)


class CustomArgumentParser(ArgumentParser):
    """Argument parser that keeps internal commands out of error messages."""

    # Sub-commands that must not be advertised to the user.
    _INTERNAL_COMMANDS = ("internal-run-udf", "internal-run-udf-worker")
    # Prefix argparse puts in front of the list of valid choices.
    _CHOICES_MARKER = "(choose from "

    @classmethod
    def _hide_internal_commands(cls, message):
        """Return ``message`` with internal commands removed from its choices.

        Only the trailing ``(choose from ...)`` list is touched, so the value
        typed by the user is never altered. Each entry is checked on its own,
        quoted or not, so the position of the internal commands in the list
        does not matter. Messages without such a list are returned unchanged.
        """
        # rpartition: the list of choices is always the last part of the
        # message, whatever the (user-supplied) invalid value contains.
        head, marker, tail = message.rpartition(cls._CHOICES_MARKER)
        if not marker or not tail.endswith(")"):
            return message
        visible = [
            choice
            for choice in tail[:-1].split(", ")
            if choice.strip("'\"") not in cls._INTERNAL_COMMANDS
        ]
        return f"{head}{marker}{', '.join(visible)})"

    def error(self, message):
        """Report ``message`` via ``ArgumentParser.error``, hiding internal commands.

        Args:
            message: Error text produced by argparse.
        """
        super().error(self._hide_internal_commands(message))


def find_columns_type(
columns_str: str,
default_colums_str: str = "path",
Expand Down
6 changes: 3 additions & 3 deletions src/datachain/studio.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def process_jobs_args(args: "Namespace"):

if args.cmd == "run":
return create_job(
args.query_file,
args.file,
args.team,
args.env_file,
args.env,
Expand All @@ -41,9 +41,9 @@ def process_jobs_args(args: "Namespace"):
)

if args.cmd == "cancel":
return cancel_job(args.job_id, args.team)
return cancel_job(args.id, args.team)
if args.cmd == "logs":
return show_job_logs(args.job_id, args.team)
return show_job_logs(args.id, args.team)
raise DataChainError(f"Unknown command '{args.cmd}'.")


Expand Down
3 changes: 2 additions & 1 deletion tests/unit/test_cli_parsing.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
import logging
from argparse import ArgumentParser, ArgumentTypeError
from argparse import ArgumentTypeError

import pytest

from datachain.cli import (
get_logging_level,
get_parser,
)
from datachain.cli.parser.utils import CustomArgumentParser as ArgumentParser
from datachain.cli.parser.utils import find_columns_type
from datachain.cli.utils import CommaSeparatedArgs, KeyValueArgs

Expand Down

0 comments on commit cc91ed9

Please sign in to comment.