diff --git a/python/cog/types.py b/python/cog/types.py index f4110e68ca..732bf9d818 100644 --- a/python/cog/types.py +++ b/python/cog/types.py @@ -1,3 +1,4 @@ +import inspect import io import mimetypes import os @@ -53,6 +54,28 @@ class CogBuildConfig(TypedDict, total=False): # pylint: disable=too-many-ancest run: Optional[Union[List[str], List[Dict[str, Any]]]] +# The following decorator is used to mutate the definition of choices in the +# case that the value(s) are duplicated. This results in the inability to +# create the enum as the values are not unique. This is generally a hack to +# work around previously created invalid schemas. +def _deduplicate_choices(func: Any) -> Any: + def wrapper(*args: Any, **kwargs: Any) -> Any: + sig = inspect.signature(func) + bound_args = sig.bind(*args, **kwargs) + bound_args.apply_defaults() + + if ( + "choices" in bound_args.arguments + and bound_args.arguments["choices"] is not None + ): + bound_args.arguments["choices"] = list(set(bound_args.arguments["choices"])) + + return func(*bound_args.args, **bound_args.kwargs) + + return wrapper + + +@_deduplicate_choices def Input( # pylint: disable=invalid-name, too-many-arguments default: Any = ..., description: str = None, diff --git a/python/tests/server/fixtures/input_choices.py b/python/tests/server/fixtures/input_choices.py index 659ee20e3f..28f0a8d99f 100644 --- a/python/tests/server/fixtures/input_choices.py +++ b/python/tests/server/fixtures/input_choices.py @@ -2,6 +2,6 @@ class Predictor(BasePredictor): - def predict(self, text: str = Input(choices=["foo", "bar"])) -> str: + def predict(self, text: str = Input(choices=["foo", "bar", "foo"])) -> str: assert type(text) == str return text