From 596b517e99eaad992b9d7bbb7582bfc1d823c887 Mon Sep 17 00:00:00 2001 From: Johannes Kasimir Date: Mon, 19 Feb 2024 14:31:56 +0100 Subject: [PATCH 01/12] refactor: remove subprovider concept --- src/sciline/pipeline.py | 117 +++++++++++++++++----------------------- 1 file changed, 48 insertions(+), 69 deletions(-) diff --git a/src/sciline/pipeline.py b/src/sciline/pipeline.py index 98fff4bc..423157c4 100644 --- a/src/sciline/pipeline.py +++ b/src/sciline/pipeline.py @@ -368,7 +368,6 @@ def __init__( Dictionary of concrete values to provide for types. """ self._providers: Dict[Key, Provider] = {} - self._subproviders: Dict[type, Dict[Tuple[Key | TypeVar, ...], Provider]] = {} self._param_tables: Dict[Key, ParamTable] = {} self._param_name_to_table_key: Dict[Key, Key] = {} for provider in providers or []: @@ -532,72 +531,59 @@ def _set_provider( 'Series is a special container reserved for use in conjunction with ' 'sciline.ParamTable and must not be provided directly.' ) - if (origin := get_origin(key)) is not None: - subproviders = self._subproviders.setdefault(origin, {}) - args = get_args(key) - subproviders[args] = provider - else: - self._providers[key] = provider + self._providers[key] = provider def _get_provider( self, tp: Union[Type[T], Item[T]], handler: Optional[ErrorHandler] = None ) -> Tuple[Provider, Dict[TypeVar, Key]]: handler = handler or HandleAsBuildTimeException() explanation: List[str] = [] - if (provider := self._providers.get(tp)) is not None: - return provider, {} - elif (origin := get_origin(tp)) is not None and ( - subproviders := self._subproviders.get(origin) - ) is not None: - requested = get_args(tp) - matches = [ - (subprovider, bound) - for args, subprovider in subproviders.items() - if ( - bound := _find_bounds_to_make_compatible_type_tuple(requested, args) - ) - is not None - ] - typevar_counts = [len(bound) for _, bound in matches] - min_typevar_count = min(typevar_counts, default=0) - matches = [ - m - for count, m in zip(typevar_counts, matches) - if count == min_typevar_count - ] - - if len(matches) == 1: - provider, bound = matches[0] - return provider, bound - elif len(matches) > 1: - matching_providers = [provider.location.name for provider, _ in matches] - raise AmbiguousProvider( - f"Multiple providers found for type {tp}." - f" Matching providers are: {matching_providers}." - ) - else: - typevars_in_expression = _extract_typevars_from_generic_type(origin) - if typevars_in_expression: - explanation = [ - ''.join( - map( - str, + + matches = [ + (provider, bound) + for ptype, provider in self._providers.items() + if (bound := _find_bounds_to_make_compatible_type_tuple((tp,), (ptype,))) + is not None + ] + typevar_counts = [len(bound) for _, bound in matches] + min_typevar_count = min(typevar_counts, default=0) + matches = [ + m for count, m in zip(typevar_counts, matches) if count == min_typevar_count + ] + + if len(matches) == 1: + provider, bound = matches[0] + return provider, bound + elif len(matches) > 1: + matching_providers = [provider.location.name for provider, _ in matches] + raise AmbiguousProvider( + f"Multiple providers found for type {tp}." + f" Matching providers are: {matching_providers}." + ) + else: + origin = get_origin(tp) + typevars_of_generic = _extract_typevars_from_generic_type(origin) + if typevars_of_generic: + explanation = [ + ''.join( + map( + str, + ( + 'Note that ', + keyname(origin[typevars_of_generic]), + ' has constraints ', ( - 'Note that ', - keyname(origin[typevars_in_expression]), - ' has constraints ', - ( - { - keyname(tv): tuple( - map(keyname, tv.__constraints__) - ) - for tv in typevars_in_expression - } - ), + { + keyname(tv): tuple( + map(keyname, tv.__constraints__) + ) + for tv in typevars_of_generic + } ), - ) + ), ) - ] + ) + ] return handler.handle_unsatisfied_requirement(tp, *explanation), {} def _get_unique_provider( @@ -923,7 +909,6 @@ def copy(self) -> Pipeline: """ out = Pipeline() out._providers = self._providers.copy() - out._subproviders = {k: v.copy() for k, v in self._subproviders.items()} out._param_tables = self._param_tables.copy() out._param_name_to_table_key = self._param_name_to_table_key.copy() return out @@ -932,14 +917,8 @@ def __copy__(self) -> Pipeline: return self.copy() def _repr_html_(self) -> str: - providers_without_parameters = ( - (origin, tuple(), value) for origin, value in self._providers.items() + providers = ( + (get_origin(tp), get_args(tp), value) + for tp, value in self._providers.items() ) # type: ignore[var-annotated] - providers_with_parameters = ( - (origin, args, value) - for origin in self._subproviders - for args, value in self._subproviders[origin].items() - ) - return pipeline_html_repr( - chain(providers_without_parameters, providers_with_parameters) - ) + return pipeline_html_repr(providers) From f86bae4b2e4a6f142f1313f69ec065ea0415c2bb Mon Sep 17 00:00:00 2001 From: Johannes Kasimir Date: Mon, 19 Feb 2024 15:32:49 +0100 Subject: [PATCH 02/12] optimization for nongeneric providers --- src/sciline/pipeline.py | 20 ++++++++++++++------ 1 file changed, 14 insertions(+), 6 deletions(-) diff --git a/src/sciline/pipeline.py b/src/sciline/pipeline.py index 423157c4..33c21100 100644 --- a/src/sciline/pipeline.py +++ b/src/sciline/pipeline.py @@ -539,12 +539,20 @@ def _get_provider( handler = handler or HandleAsBuildTimeException() explanation: List[str] = [] - matches = [ - (provider, bound) - for ptype, provider in self._providers.items() - if (bound := _find_bounds_to_make_compatible_type_tuple((tp,), (ptype,))) - is not None - ] + if tp in self._providers: + # Optimization to quickly find non-generic providers + matches = [(self._providers[tp], {})] + else: + matches = [ + (provider, bound) + for return_type, provider in self._providers.items() + if ( + bound := _find_bounds_to_make_compatible_type_tuple( + (tp,), (return_type,) + ) + ) + is not None + ] typevar_counts = [len(bound) for _, bound in matches] min_typevar_count = min(typevar_counts, default=0) matches = [ From 5a3a900c763656c8c423b51cc4762230cc6bafad Mon Sep 17 00:00:00 2001 From: Johannes Kasimir Date: Mon, 19 Feb 2024 15:40:55 +0100 Subject: [PATCH 03/12] fix: use more suitable method --- src/sciline/pipeline.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/src/sciline/pipeline.py b/src/sciline/pipeline.py index 33c21100..4b9618e8 100644 --- a/src/sciline/pipeline.py +++ b/src/sciline/pipeline.py @@ -546,11 +546,7 @@ def _get_provider( matches = [ (provider, bound) for return_type, provider in self._providers.items() - if ( - bound := _find_bounds_to_make_compatible_type_tuple( - (tp,), (return_type,) - ) - ) + if (bound := _find_bounds_to_make_compatible_type(tp, return_type)) is not None ] typevar_counts = [len(bound) for _, bound in matches] From 2d488170cbda74a92d763a2f252284f8f789b513 Mon Sep 17 00:00:00 2001 From: Johannes Kasimir Date: Mon, 19 Feb 2024 15:56:13 +0100 Subject: [PATCH 04/12] feat: add __getitem__ and __contains__ --- src/sciline/pipeline.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/src/sciline/pipeline.py b/src/sciline/pipeline.py index 4b9618e8..b5f5aef7 100644 --- a/src/sciline/pipeline.py +++ b/src/sciline/pipeline.py @@ -433,6 +433,18 @@ def __setitem__(self, key: Type[T], param: T) -> None: ) self._set_provider(key, Provider.parameter(param)) + def __getitem__(self, tp: Union[Type[T], Item[T]]) -> Provider: + return self.get(tp).graph[tp] + + def __contains__(self, tp: Union[Type[T], Item[T]]): + try: + self.get(tp) + except AmbiguousProvider: + return True + except UnsatisfiedRequirement: + return False + return True + def set_param_table(self, params: ParamTable) -> None: """ Set a parameter table for a row dimension. From ded4f8aed2c6ec0e71ff56a4605f6d7b9fcf4f7e Mon Sep 17 00:00:00 2001 From: Johannes Kasimir Date: Mon, 19 Feb 2024 15:58:43 +0100 Subject: [PATCH 05/12] fix: do only one lookup --- src/sciline/pipeline.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/sciline/pipeline.py b/src/sciline/pipeline.py index b5f5aef7..2e44fc39 100644 --- a/src/sciline/pipeline.py +++ b/src/sciline/pipeline.py @@ -551,9 +551,9 @@ def _get_provider( handler = handler or HandleAsBuildTimeException() explanation: List[str] = [] - if tp in self._providers: + if provider := self._providers.get(tp): # Optimization to quickly find non-generic providers - matches = [(self._providers[tp], {})] + matches = [(provider, {})] else: matches = [ (provider, bound) From 7b4149bc3bc0070c6551f3a2ff6f3d78d28c5e40 Mon Sep 17 00:00:00 2001 From: Johannes Kasimir Date: Mon, 19 Feb 2024 16:06:09 +0100 Subject: [PATCH 06/12] mypy --- src/sciline/pipeline.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/sciline/pipeline.py b/src/sciline/pipeline.py index 2e44fc39..f8f32688 100644 --- a/src/sciline/pipeline.py +++ b/src/sciline/pipeline.py @@ -436,7 +436,7 @@ def __setitem__(self, key: Type[T], param: T) -> None: def __getitem__(self, tp: Union[Type[T], Item[T]]) -> Provider: return self.get(tp).graph[tp] - def __contains__(self, tp: Union[Type[T], Item[T]]): + def __contains__(self, tp: Union[Type[T], Item[T]]) -> bool: try: self.get(tp) except AmbiguousProvider: From 86da6fc5a6f207746f6da75ec7f4e770c0b8b18f Mon Sep 17 00:00:00 2001 From: Johannes Kasimir Date: Mon, 19 Feb 2024 16:08:49 +0100 Subject: [PATCH 07/12] fix --- src/sciline/pipeline.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/sciline/pipeline.py b/src/sciline/pipeline.py index f8f32688..5a296c83 100644 --- a/src/sciline/pipeline.py +++ b/src/sciline/pipeline.py @@ -434,11 +434,11 @@ def __setitem__(self, key: Type[T], param: T) -> None: self._set_provider(key, Provider.parameter(param)) def __getitem__(self, tp: Union[Type[T], Item[T]]) -> Provider: - return self.get(tp).graph[tp] + return self._get_unique_provider(tp, HandleAsBuildTimeException()) def __contains__(self, tp: Union[Type[T], Item[T]]) -> bool: try: - self.get(tp) + self[tp] except AmbiguousProvider: return True except UnsatisfiedRequirement: From 81fba8238f95fd45fe9274ce2a60e55eed1c5a12 Mon Sep 17 00:00:00 2001 From: Johannes Kasimir Date: Mon, 19 Feb 2024 16:18:05 +0100 Subject: [PATCH 08/12] mypy --- src/sciline/pipeline.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/sciline/pipeline.py b/src/sciline/pipeline.py index 5a296c83..ba5cad39 100644 --- a/src/sciline/pipeline.py +++ b/src/sciline/pipeline.py @@ -434,7 +434,7 @@ def __setitem__(self, key: Type[T], param: T) -> None: self._set_provider(key, Provider.parameter(param)) def __getitem__(self, tp: Union[Type[T], Item[T]]) -> Provider: - return self._get_unique_provider(tp, HandleAsBuildTimeException()) + return self._get_unique_provider(tp, HandleAsBuildTimeException())[0] def __contains__(self, tp: Union[Type[T], Item[T]]) -> bool: try: @@ -933,8 +933,8 @@ def __copy__(self) -> Pipeline: return self.copy() def _repr_html_(self) -> str: - providers = ( + providers = [ (get_origin(tp), get_args(tp), value) for tp, value in self._providers.items() - ) # type: ignore[var-annotated] + ] return pipeline_html_repr(providers) From 33c015b876c8d36bb346cb8da7dfbe9b800b1db2 Mon Sep 17 00:00:00 2001 From: Johannes Kasimir Date: Tue, 20 Feb 2024 09:19:29 +0100 Subject: [PATCH 09/12] test: getitem --- tests/pipeline_test.py | 29 +++++++++++++++++++++++++++++ 1 file changed, 29 insertions(+) diff --git a/tests/pipeline_test.py b/tests/pipeline_test.py index 402d93b7..28e2bf77 100644 --- a/tests/pipeline_test.py +++ b/tests/pipeline_test.py @@ -1543,3 +1543,32 @@ def __new__(cls, x: int) -> str: # type: ignore[misc] with pytest.raises(TypeError): sl.Pipeline([C], params={int: 3}) + + +def test_getitem() -> None: + def p(c: int) -> float: + return float(c) + + pipeline = sl.Pipeline([p], params={int: 3}) + assert isinstance(pipeline[int], sl.typing.Provider) + assert isinstance(pipeline[float], sl.typing.Provider) + with pytest.raises(sl.UnsatisfiedRequirement): + pipeline[str] + + +def test_getitem_generic() -> None: + Number = TypeVar('Number', int, float) + + @dataclass + class Double(Generic[Number]): + number: Number + + def p(n: Number) -> Double[Number]: + return 2 * n + + pipeline = sl.Pipeline([p]) + assert isinstance(pipeline[Double[int]], sl.typing.Provider) + assert isinstance(pipeline[Double[float]], sl.typing.Provider) + + with pytest.raises(sl.UnsatisfiedRequirement): + pipeline[Double[str]] From 49b088a4b06009ba58a87fa0c1ac088b99b83a43 Mon Sep 17 00:00:00 2001 From: Johannes Kasimir Date: Tue, 20 Feb 2024 10:03:03 +0100 Subject: [PATCH 10/12] test: contains --- tests/pipeline_test.py | 63 ++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 61 insertions(+), 2 deletions(-) diff --git a/tests/pipeline_test.py b/tests/pipeline_test.py index 28e2bf77..3d7022b6 100644 --- a/tests/pipeline_test.py +++ b/tests/pipeline_test.py @@ -1545,7 +1545,7 @@ def __new__(cls, x: int) -> str: # type: ignore[misc] sl.Pipeline([C], params={int: 3}) -def test_getitem() -> None: +def test_pipeline_getitem() -> None: def p(c: int) -> float: return float(c) @@ -1556,7 +1556,7 @@ def p(c: int) -> float: pipeline[str] -def test_getitem_generic() -> None: +def test_pipeline_getitem_generic() -> None: Number = TypeVar('Number', int, float) @dataclass @@ -1572,3 +1572,62 @@ def p(n: Number) -> Double[Number]: with pytest.raises(sl.UnsatisfiedRequirement): pipeline[Double[str]] + + +def test_pipeline_getitem_ambiguous() -> None: + N1 = TypeVar('N1', int, float) + N2 = TypeVar('N2', int, float) + + @dataclass + class Two(Generic[N1, N2]): + a: N1 + b: N2 + + def p1(n: N1) -> Two[N1, float]: + return Two[N1, N1](n, 1.0) + + def p2(n: N2) -> Two[int, N2]: + return Two[N1, N2](1, n) + + pipeline = sl.Pipeline([p1, p2]) + assert isinstance(pipeline[Two[float, float]], sl.typing.Provider) + assert isinstance(pipeline[Two[int, int]], sl.typing.Provider) + with pytest.raises(sl.AmbiguousProvider): + pipeline[Two[int, float]] + + +def test_pipeline_contains() -> None: + N1 = TypeVar('N1', int, float) + N2 = TypeVar('N2', int, float) + + @dataclass + class One(Generic[N1]): + a: N1 + + @dataclass + class Two(Generic[N1, N2]): + a: N1 + b: N2 + + def p1(c: int) -> float: + return float(c) + + def p2(n: N1) -> One[N1]: + return 2 * n + + def p3(n: N1) -> Two[N1, float]: + return Two[N1, N1](n, 1.0) + + def p4(n: N2) -> Two[int, N2]: + return Two[N1, N2](1, n) + + pipeline = sl.Pipeline([p1, p2, p3, p4], params={int: 3}) + assert float in pipeline + assert int in pipeline + assert One[int] in pipeline + assert One[float] in pipeline + assert Two[int, float] in pipeline + + assert str not in pipeline + assert One[str] not in pipeline + assert Two[str, float] not in pipeline From 11002b71a07fa1a4638ec6134fc32ae8258995d0 Mon Sep 17 00:00:00 2001 From: Johannes Kasimir Date: Fri, 23 Feb 2024 10:19:50 +0100 Subject: [PATCH 11/12] better names + don't expose provider class --- src/sciline/pipeline.py | 20 ++++++++++++++------ 1 file changed, 14 insertions(+), 6 deletions(-) diff --git a/src/sciline/pipeline.py b/src/sciline/pipeline.py index ba5cad39..4eb4b5cc 100644 --- a/src/sciline/pipeline.py +++ b/src/sciline/pipeline.py @@ -433,17 +433,25 @@ def __setitem__(self, key: Type[T], param: T) -> None: ) self._set_provider(key, Provider.parameter(param)) - def __getitem__(self, tp: Union[Type[T], Item[T]]) -> Provider: - return self._get_unique_provider(tp, HandleAsBuildTimeException())[0] - - def __contains__(self, tp: Union[Type[T], Item[T]]) -> bool: + def get_provider( + self, tp: Union[Type[T], Item[T]] + ) -> Union[Callable[[Any, ...], T], T]: + '''Get the provider that produces the type. If the type is provided by a + parameter the method returns the value of the parameter.''' + provider = self._get_unique_provider(tp, HandleAsBuildTimeException())[0] + if provider.kind == 'function': + return provider.func + return provider.func() + + def has_provider(self, tp: Union[Type[T], Item[T]]) -> bool: + '''Determines if the pipeline has a provider that produces the type.''' try: - self[tp] + self.get_provider(tp) + return True except AmbiguousProvider: return True except UnsatisfiedRequirement: return False - return True def set_param_table(self, params: ParamTable) -> None: """ From bc536260bf7adce1491d660f3d9f51e61b448679 Mon Sep 17 00:00:00 2001 From: Johannes Kasimir Date: Fri, 23 Feb 2024 10:43:55 +0100 Subject: [PATCH 12/12] fix tests --- tests/pipeline_test.py | 46 +++++++++++++++++++++--------------------- 1 file changed, 23 insertions(+), 23 deletions(-) diff --git a/tests/pipeline_test.py b/tests/pipeline_test.py index 3d7022b6..34d4f64b 100644 --- a/tests/pipeline_test.py +++ b/tests/pipeline_test.py @@ -1545,18 +1545,18 @@ def __new__(cls, x: int) -> str: # type: ignore[misc] sl.Pipeline([C], params={int: 3}) -def test_pipeline_getitem() -> None: +def test_pipeline_get_provider() -> None: def p(c: int) -> float: - return float(c) + return float(c + 1) pipeline = sl.Pipeline([p], params={int: 3}) - assert isinstance(pipeline[int], sl.typing.Provider) - assert isinstance(pipeline[float], sl.typing.Provider) + assert pipeline.get_provider(int) == 3 + assert pipeline.get_provider(float) is p with pytest.raises(sl.UnsatisfiedRequirement): - pipeline[str] + pipeline.get_provider(str) -def test_pipeline_getitem_generic() -> None: +def test_pipeline_get_provider_generic() -> None: Number = TypeVar('Number', int, float) @dataclass @@ -1567,14 +1567,14 @@ def p(n: Number) -> Double[Number]: return 2 * n pipeline = sl.Pipeline([p]) - assert isinstance(pipeline[Double[int]], sl.typing.Provider) - assert isinstance(pipeline[Double[float]], sl.typing.Provider) + assert pipeline.get_provider(Double[int]) is p + assert pipeline.get_provider(Double[float]) is p with pytest.raises(sl.UnsatisfiedRequirement): - pipeline[Double[str]] + pipeline.get_provider(Double[str]) -def test_pipeline_getitem_ambiguous() -> None: +def test_pipeline_get_provider_ambiguous() -> None: N1 = TypeVar('N1', int, float) N2 = TypeVar('N2', int, float) @@ -1590,13 +1590,13 @@ def p2(n: N2) -> Two[int, N2]: return Two[N1, N2](1, n) pipeline = sl.Pipeline([p1, p2]) - assert isinstance(pipeline[Two[float, float]], sl.typing.Provider) - assert isinstance(pipeline[Two[int, int]], sl.typing.Provider) + assert pipeline.get_provider(Two[float, float]) is p1 + assert pipeline.get_provider(Two[int, int]) is p2 with pytest.raises(sl.AmbiguousProvider): - pipeline[Two[int, float]] + pipeline.get_provider(Two[int, float]) -def test_pipeline_contains() -> None: +def test_pipeline_has_provider() -> None: N1 = TypeVar('N1', int, float) N2 = TypeVar('N2', int, float) @@ -1622,12 +1622,12 @@ def p4(n: N2) -> Two[int, N2]: return Two[N1, N2](1, n) pipeline = sl.Pipeline([p1, p2, p3, p4], params={int: 3}) - assert float in pipeline - assert int in pipeline - assert One[int] in pipeline - assert One[float] in pipeline - assert Two[int, float] in pipeline - - assert str not in pipeline - assert One[str] not in pipeline - assert Two[str, float] not in pipeline + assert pipeline.has_provider(float) + assert pipeline.has_provider(int) + assert pipeline.has_provider(One[int]) + assert pipeline.has_provider(One[float]) + assert pipeline.has_provider(Two[int, float]) + + assert not pipeline.has_provider(str) + assert not pipeline.has_provider(One[str]) + assert not pipeline.has_provider(Two[str, float])