Skip to content

Commit 5078634

Browse files
committed
Deduplicate choices
Ensure choices are not duplicated, as that makes StrEnum impossible to resolve. Each element should only exist once. The deduplicate must be done at the schema level and therefore a decorator is utilized to mutate duplicated choices passed to Input().
1 parent 09fbcbc commit 5078634

File tree

2 files changed

+24
-1
lines changed

2 files changed

+24
-1
lines changed

python/cog/types.py

+23
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import inspect
12
import io
23
import mimetypes
34
import os
@@ -53,6 +54,28 @@ class CogBuildConfig(TypedDict, total=False): # pylint: disable=too-many-ancest
5354
run: Optional[Union[List[str], List[Dict[str, Any]]]]
5455

5556

57+
# The following decorator is used to mutate the definition of choices in the
58+
# case that the value(s) are duplicated. This results in the inability to
59+
# create the enum as the values are not unique. This is generally a hack to
60+
# work around previously created invalid schemas.
61+
def _deduplicate_choices(func: Any) -> Any:
62+
def wrapper(*args: Any, **kwargs: Any) -> Any:
63+
sig = inspect.signature(func)
64+
bound_args = sig.bind(*args, **kwargs)
65+
bound_args.apply_defaults()
66+
67+
if (
68+
"choices" in bound_args.arguments
69+
and bound_args.arguments["choices"] is not None
70+
):
71+
bound_args.arguments["choices"] = list(set(bound_args.arguments["choices"]))
72+
73+
return func(*bound_args.args, **bound_args.kwargs)
74+
75+
return wrapper
76+
77+
78+
@_deduplicate_choices
5679
def Input( # pylint: disable=invalid-name, too-many-arguments
5780
default: Any = ...,
5881
description: str = None,

python/tests/server/fixtures/input_choices.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,6 @@
22

33

44
class Predictor(BasePredictor):
5-
def predict(self, text: str = Input(choices=["foo", "bar"])) -> str:
5+
def predict(self, text: str = Input(choices=["foo", "bar", "foo"])) -> str:
66
assert type(text) == str
77
return text

0 commit comments

Comments
 (0)