Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add method to get provider from type #135

Closed
wants to merge 12 commits into from
137 changes: 70 additions & 67 deletions src/sciline/pipeline.py
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

About the conceptual issues, I was wrong stating that this makes Pipeline a Mapping. It does not because we would have to also implement __iter__ and __len__. And actually, since there is a __setitem__, if anything, it should be MutableMapping, thus we would also need __delitem__.
However, Pipeline implements part of the interface (and also the interface of Sequence). I'm not entirely sure how it fits in.

More concretely, implementing __getitem__ makes Python think that pipelines are iterable with integer indices:

import sciline as sl

def to_string(i: int) -> str:
    return str(i)

pl = sl.Pipeline([to_string], params={int: 3})
list(pl)

raises

sciline.handler.UnsatisfiedRequirement: ('No provider found for type', 0)

This not great. So you would also need to implement __iter__. And since you would be going for a dict-like interface, you would also need keys, values, and items.

And that leaves us with the asymmetry between __getitem__ and __setitem__. For generic providers, the former currently returns different values and accepts different keys than the latter. Even if __setitem__ were extended to support providers on top of parameters.
Also, what about parameter tables or sentinels? Would __getitem__ return those special, internal providers? Or how would it represent them?

With the proposed semantics, the new functions allow checking what a pipeline can produce and how. So it does not allow treating Pipeline like a container. So named methods seem more appropriate to me.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this makes a lot of sense.
How about adding a get_provider method instead of __getitem__?

Also, what about parameter tables or sentinels? Would getitem return those special, internal providers? Or how would it represent them?

I don't know, what do you think? I'm thinking that it's likely that a user wants to access the parameter tables that a pipeline uses, but they should preferably be returned in the same form that they were inserted, i.e. as sl.ParamTable.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't know, what do you think? I'm thinking that it's likely that a user wants to access the parameter tables that a pipeline uses, but they should preferably be returned in the same form that they were inserted, i.e. as sl.ParamTable.

Just like the discussion about using a graph data structure inside Pipeline, this shows that the implicit way in which we handle param-tables is problematic. We may need to re-think that (come up with a more explicit interface, or way the express the nesting?).

Original file line number Diff line number Diff line change
Expand Up @@ -368,7 +368,6 @@ def __init__(
Dictionary of concrete values to provide for types.
"""
self._providers: Dict[Key, Provider] = {}
self._subproviders: Dict[type, Dict[Tuple[Key | TypeVar, ...], Provider]] = {}
self._param_tables: Dict[Key, ParamTable] = {}
self._param_name_to_table_key: Dict[Key, Key] = {}
for provider in providers or []:
Expand Down Expand Up @@ -434,6 +433,26 @@ def __setitem__(self, key: Type[T], param: T) -> None:
)
self._set_provider(key, Provider.parameter(param))

def get_provider(
self, tp: Union[Type[T], Item[T]]
) -> Union[Callable[[Any, ...], T], T]:
'''Get the provider that produces the type. If the type is provided by a
parameter the method returns the value of the parameter.'''
provider = self._get_unique_provider(tp, HandleAsBuildTimeException())[0]
if provider.kind == 'function':
return provider.func
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why does it return the function and not the provider object?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Because I agreed with the concern that the Provider class isn't currently exposed, and returning a Provider here would expose it. I think just returning the same type that was provided by the user keeps things simpler.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not concerned about exposing it. This can be useful in its own right.

Your solution makes it difficult to distinguish between providers and parameters. You can't check the type of the return value of get_provider because parameters can be callable and providers can now be any callable object, not just functions. So you would have to compare the returned type to tp.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My only concern about expose Provider is that it would increase the api surface area. I agree that it's a problem that you can't generally determine if the returned provider is a parameter or a function.

return provider.func()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What about other kinds?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes that's something we need to talk about. What other cases are there, I looked at the ProviderKinds type but didn't recognize all of them. However, I think it's unlikely that a user will request other types, so for now I'm open to just raising NotImplemented in the other cases.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Well, users might request Series. But that, along with the other kinds, may be subject to change if param tables are removed.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, that's what I thought as well, so I don't think it makes sense to invest time to solve that now.


def has_provider(self, tp: Union[Type[T], Item[T]]) -> bool:
'''Determines if the pipeline has a provider that produces the type.'''
try:
self.get_provider(tp)
return True
except AmbiguousProvider:
return True
except UnsatisfiedRequirement:
return False

def set_param_table(self, params: ParamTable) -> None:
"""
Set a parameter table for a row dimension.
Expand Down Expand Up @@ -532,72 +551,63 @@ def _set_provider(
'Series is a special container reserved for use in conjunction with '
'sciline.ParamTable and must not be provided directly.'
)
if (origin := get_origin(key)) is not None:
subproviders = self._subproviders.setdefault(origin, {})
args = get_args(key)
subproviders[args] = provider
else:
self._providers[key] = provider
self._providers[key] = provider

def _get_provider(
self, tp: Union[Type[T], Item[T]], handler: Optional[ErrorHandler] = None
) -> Tuple[Provider, Dict[TypeVar, Key]]:
handler = handler or HandleAsBuildTimeException()
explanation: List[str] = []
if (provider := self._providers.get(tp)) is not None:
return provider, {}
elif (origin := get_origin(tp)) is not None and (
subproviders := self._subproviders.get(origin)
) is not None:
requested = get_args(tp)

if provider := self._providers.get(tp):
# Optimization to quickly find non-generic providers
matches = [(provider, {})]
else:
matches = [
(subprovider, bound)
for args, subprovider in subproviders.items()
if (
bound := _find_bounds_to_make_compatible_type_tuple(requested, args)
)
(provider, bound)
for return_type, provider in self._providers.items()
if (bound := _find_bounds_to_make_compatible_type(tp, return_type))
is not None
]
typevar_counts = [len(bound) for _, bound in matches]
min_typevar_count = min(typevar_counts, default=0)
matches = [
m
for count, m in zip(typevar_counts, matches)
if count == min_typevar_count
]

if len(matches) == 1:
provider, bound = matches[0]
return provider, bound
elif len(matches) > 1:
matching_providers = [provider.location.name for provider, _ in matches]
raise AmbiguousProvider(
f"Multiple providers found for type {tp}."
f" Matching providers are: {matching_providers}."
)
else:
typevars_in_expression = _extract_typevars_from_generic_type(origin)
if typevars_in_expression:
explanation = [
''.join(
map(
str,
typevar_counts = [len(bound) for _, bound in matches]
min_typevar_count = min(typevar_counts, default=0)
matches = [
m for count, m in zip(typevar_counts, matches) if count == min_typevar_count
]

if len(matches) == 1:
provider, bound = matches[0]
return provider, bound
elif len(matches) > 1:
matching_providers = [provider.location.name for provider, _ in matches]
raise AmbiguousProvider(
f"Multiple providers found for type {tp}."
f" Matching providers are: {matching_providers}."
)
else:
origin = get_origin(tp)
typevars_of_generic = _extract_typevars_from_generic_type(origin)
if typevars_of_generic:
explanation = [
''.join(
map(
str,
(
'Note that ',
keyname(origin[typevars_of_generic]),
' has constraints ',
(
'Note that ',
keyname(origin[typevars_in_expression]),
' has constraints ',
(
{
keyname(tv): tuple(
map(keyname, tv.__constraints__)
)
for tv in typevars_in_expression
}
),
{
keyname(tv): tuple(
map(keyname, tv.__constraints__)
)
for tv in typevars_of_generic
}
),
)
),
)
]
)
]
return handler.handle_unsatisfied_requirement(tp, *explanation), {}

def _get_unique_provider(
Expand Down Expand Up @@ -923,7 +933,6 @@ def copy(self) -> Pipeline:
"""
out = Pipeline()
out._providers = self._providers.copy()
out._subproviders = {k: v.copy() for k, v in self._subproviders.items()}
out._param_tables = self._param_tables.copy()
out._param_name_to_table_key = self._param_name_to_table_key.copy()
return out
Expand All @@ -932,14 +941,8 @@ def __copy__(self) -> Pipeline:
return self.copy()

def _repr_html_(self) -> str:
providers_without_parameters = (
(origin, tuple(), value) for origin, value in self._providers.items()
) # type: ignore[var-annotated]
providers_with_parameters = (
(origin, args, value)
for origin in self._subproviders
for args, value in self._subproviders[origin].items()
)
return pipeline_html_repr(
chain(providers_without_parameters, providers_with_parameters)
)
providers = [
(get_origin(tp), get_args(tp), value)
for tp, value in self._providers.items()
]
return pipeline_html_repr(providers)
88 changes: 88 additions & 0 deletions tests/pipeline_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1543,3 +1543,91 @@ def __new__(cls, x: int) -> str: # type: ignore[misc]

with pytest.raises(TypeError):
sl.Pipeline([C], params={int: 3})


def test_pipeline_get_provider() -> None:
def p(c: int) -> float:
return float(c + 1)

pipeline = sl.Pipeline([p], params={int: 3})
assert pipeline.get_provider(int) == 3
assert pipeline.get_provider(float) is p
with pytest.raises(sl.UnsatisfiedRequirement):
pipeline.get_provider(str)


def test_pipeline_get_provider_generic() -> None:
Number = TypeVar('Number', int, float)

@dataclass
class Double(Generic[Number]):
number: Number

def p(n: Number) -> Double[Number]:
return 2 * n

pipeline = sl.Pipeline([p])
assert pipeline.get_provider(Double[int]) is p
assert pipeline.get_provider(Double[float]) is p

with pytest.raises(sl.UnsatisfiedRequirement):
pipeline.get_provider(Double[str])


def test_pipeline_get_provider_ambiguous() -> None:
N1 = TypeVar('N1', int, float)
N2 = TypeVar('N2', int, float)

@dataclass
class Two(Generic[N1, N2]):
a: N1
b: N2

def p1(n: N1) -> Two[N1, float]:
return Two[N1, N1](n, 1.0)

def p2(n: N2) -> Two[int, N2]:
return Two[N1, N2](1, n)

pipeline = sl.Pipeline([p1, p2])
assert pipeline.get_provider(Two[float, float]) is p1
assert pipeline.get_provider(Two[int, int]) is p2
with pytest.raises(sl.AmbiguousProvider):
pipeline.get_provider(Two[int, float])


def test_pipeline_has_provider() -> None:
N1 = TypeVar('N1', int, float)
N2 = TypeVar('N2', int, float)

@dataclass
class One(Generic[N1]):
a: N1

@dataclass
class Two(Generic[N1, N2]):
a: N1
b: N2

def p1(c: int) -> float:
return float(c)

def p2(n: N1) -> One[N1]:
return 2 * n

def p3(n: N1) -> Two[N1, float]:
return Two[N1, N1](n, 1.0)

def p4(n: N2) -> Two[int, N2]:
return Two[N1, N2](1, n)

pipeline = sl.Pipeline([p1, p2, p3, p4], params={int: 3})
assert pipeline.has_provider(float)
assert pipeline.has_provider(int)
assert pipeline.has_provider(One[int])
assert pipeline.has_provider(One[float])
assert pipeline.has_provider(Two[int, float])

assert not pipeline.has_provider(str)
assert not pipeline.has_provider(One[str])
assert not pipeline.has_provider(Two[str, float])
Loading