Skip to content

Commit

Permalink
ENH: Covariance
Browse files Browse the repository at this point in the history
Make `Result[T, E]` and `Option[T]` covariant, so that subclasses of `T`
and `E` satisfy types for methods that return `Result[T, E]`. This
significantly improves the flexibility of methods like `.map()`,
`.and_then()`, `.map_err()`, etc.
  • Loading branch information
mplanchard committed Sep 23, 2020
1 parent 0eaef65 commit 2c4fa67
Show file tree
Hide file tree
Showing 7 changed files with 90 additions and 32 deletions.
19 changes: 19 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,21 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## [Unreleased]

## [1.5.0] - 2020-09-23

### Added

- Type inference now allows covariance for `Result` and `Option` wrapped types
- Allows a function of type `Callable[[int], RuntimeError]` to be applied
via flatmap (`.and_then()`) to a result of type `Result[int, Exception]`
- Allows e.g. any of the following to be assigned to a type of
`Result[Number, Exception]`:
- `Result[int, RuntimeError]`
- `Result[float, TypeError]`
- etc.
- This makes `.and_then()`/`.flatmap()`, `.map()`, `.map_err()`, and so on
much more convenient to use in result chains.

## [1.4.0] - 2020-03-09

### Added
Expand Down Expand Up @@ -84,8 +99,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Apache license

[Unreleased]: https://github.com/mplanchard/safetywrap/compare/v1.4.0...HEAD
<<<<<<< HEAD
[1.4.0]: https://github.com/mplanchard/safetywrap/compare/v1.3.1...v1.4.0
[1.3.1]: https://github.com/mplanchard/safetywrap/compare/v1.3.0...v1.3.1
=======
[1.4.0]: https://github.com/mplanchard/safetywrap/compare/v1.3.0...v1.4.0
>>>>>>> be5a2ed... Covariance
[1.3.0]: https://github.com/mplanchard/safetywrap/compare/v1.2.0...v1.3.0
[1.2.0]: https://github.com/mplanchard/safetywrap/compare/v1.1.0...v1.2.0
[1.1.0]: https://github.com/mplanchard/safetywrap/compare/v1.0.2...v1.1.0
Expand Down
4 changes: 2 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@
REQ_FILE = join(PACKAGE_DIR, "requirements_unfrozen.txt")
if exists(REQ_FILE):
with open(join(PACKAGE_DIR, "requirements.txt")) as reqfile:
for ln in (l.strip() for l in reqfile):
for ln in (line.strip() for line in reqfile):
if ln and not ln.startswith("#"):
PACKAGE_DEPENDENCIES += (ln,)

Expand Down Expand Up @@ -153,5 +153,5 @@
setup_requires=SETUP_DEPENDENCIES,
tests_require=TEST_DEPENDENCIES,
url=URL,
version=__version__,
version="0.0.0",
)
28 changes: 14 additions & 14 deletions src/safetywrap/_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@
from ._interface import _Option, _Result


T = t.TypeVar("T")
E = t.TypeVar("E")
T = t.TypeVar("T", covariant=True)
E = t.TypeVar("E", covariant=True)
U = t.TypeVar("U")
F = t.TypeVar("F")

Expand Down Expand Up @@ -49,8 +49,8 @@ def of(

@staticmethod
def collect(
iterable: t.Iterable["Result[T, E]"],
) -> "Result[t.Tuple[T, ...], E]":
iterable: t.Iterable["Result[U, F]"],
) -> "Result[t.Tuple[U, ...], F]":
"""Collect an iterable of Results into a Result of an iterable.
Given some iterable of type Iterable[Result[T, E]], try to collect
Expand All @@ -71,22 +71,22 @@ def collect(
hinted, either by a variable annotation or a return type.
"""
# Non-functional code here to enable true short-circuiting.
ok_vals: t.Tuple[T, ...] = ()
ok_vals: t.Tuple[U, ...] = ()
for result in iterable:
if result.is_err():
return result.map(lambda _: ())
ok_vals += (result.unwrap(),)
return Ok(ok_vals)

@staticmethod
def err_if(predicate: t.Callable[[T], bool], value: T) -> "Result[T, T]":
def err_if(predicate: t.Callable[[U], bool], value: U) -> "Result[U, U]":
"""Return Err(val) if predicate(val) is True, otherwise Ok(val)."""
if predicate(value):
return Err(value)
return Ok(value)

@staticmethod
def ok_if(predicate: t.Callable[[T], bool], value: T) -> "Result[T, T]":
def ok_if(predicate: t.Callable[[U], bool], value: U) -> "Result[U, U]":
"""Return Ok(val) if predicate(val) is True, otherwise Err(val)."""
if predicate(value):
return Ok(value)
Expand All @@ -110,14 +110,14 @@ def of(value: t.Optional[T]) -> "Option[T]":
return Some(value)

@staticmethod
def nothing_if(predicate: t.Callable[[T], bool], value: T) -> "Option[T]":
def nothing_if(predicate: t.Callable[[U], bool], value: U) -> "Option[U]":
"""Return Nothing() if predicate(val) is True, else Some(val)."""
if predicate(value):
return Nothing()
return Some(value)

@staticmethod
def some_if(predicate: t.Callable[[T], bool], value: T) -> "Option[T]":
def some_if(predicate: t.Callable[[U], bool], value: U) -> "Option[U]":
"""Return Some(val) if predicate(val) is True, else Nothing()."""
if predicate(value):
return Some(value)
Expand Down Expand Up @@ -558,12 +558,12 @@ def map_or_else(
"""Apply `fn` to contained value, or compute a default."""
return fn(self._value)

def ok_or(self, err: E) -> Result[T, E]:
def ok_or(self, err: F) -> Result[T, F]:
"""Transform an option into a `Result`.
Maps `Some(v)` to `Ok(v)` or `None` to `Err(err)`.
"""
res: Result[T, E] = Ok(self._value)
res: Result[T, F] = Ok(self._value)
return res

def ok_or_else(self, err_fn: t.Callable[[], E]) -> Result[T, E]:
Expand Down Expand Up @@ -741,12 +741,12 @@ def map_or_else(
"""Apply `fn` to contained value, or compute a default."""
return default()

def ok_or(self, err: E) -> Result[T, E]:
def ok_or(self, err: F) -> Result[T, F]:
"""Transform an option into a `Result`.
Maps `Some(v)` to `Ok(v)` or `None` to `Err(err)`.
"""
res: Result[T, E] = Err(err)
res: Result[T, F] = Err(err)
return res

def ok_or_else(self, err_fn: t.Callable[[], E]) -> Result[T, E]:
Expand Down Expand Up @@ -790,7 +790,7 @@ def __ne__(self, other: t.Any) -> bool:

def __str__(self) -> str:
"""Return a string representation of Nothing()."""
return f"Nothing()"
return "Nothing()"

def __repr__(self) -> str:
"""Return a string representation of Nothing()."""
Expand Down
18 changes: 9 additions & 9 deletions src/safetywrap/_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@

# pylint: disable=invalid-name

T = t.TypeVar("T")
E = t.TypeVar("E")
T = t.TypeVar("T", covariant=True)
E = t.TypeVar("E", covariant=True)
U = t.TypeVar("U")
F = t.TypeVar("F")

Expand Down Expand Up @@ -51,8 +51,8 @@ def of(

@staticmethod
def collect(
iterable: t.Iterable["Result[T, E]"],
) -> "Result[t.Tuple[T, ...], E]":
iterable: t.Iterable["Result[U, F]"],
) -> "Result[t.Tuple[U, ...], F]":
"""Convert an iterable of Results into a Result of an iterable.
Given some iterable of type Iterable[Result[T, E]], try to collect
Expand All @@ -63,12 +63,12 @@ def collect(
raise NotImplementedError

@staticmethod
def err_if(predicate: t.Callable[[T], bool], value: T) -> "Result[T, T]":
def err_if(predicate: t.Callable[[U], bool], value: U) -> "Result[U, U]":
"""Return Err(val) if predicate(val) is True, otherwise Ok(val)."""
raise NotImplementedError

@staticmethod
def ok_if(predicate: t.Callable[[T], bool], value: T) -> "Result[T, T]":
def ok_if(predicate: t.Callable[[U], bool], value: U) -> "Result[U, U]":
"""Return Ok(val) if predicate(val) is True, otherwise Err(val)."""
raise NotImplementedError

Expand Down Expand Up @@ -234,12 +234,12 @@ def of(value: t.Optional[T]) -> "Option[T]":
raise NotImplementedError

@staticmethod
def nothing_if(predicate: t.Callable[[T], bool], value: T) -> "Option[T]":
def nothing_if(predicate: t.Callable[[U], bool], value: U) -> "Option[U]":
"""Return Nothing() if predicate(val) is True, else Some(val)."""
raise NotImplementedError

@staticmethod
def some_if(predicate: t.Callable[[T], bool], value: T) -> "Option[T]":
def some_if(predicate: t.Callable[[U], bool], value: U) -> "Option[U]":
"""Return Some(val) if predicate(val) is True, else Nothing()."""
raise NotImplementedError

Expand Down Expand Up @@ -352,7 +352,7 @@ def map_or_else(
"""Apply `fn` to contained value, or compute a default."""
raise NotImplementedError

def ok_or(self, err: E) -> "Result[T, E]":
def ok_or(self, err: F) -> "Result[T, F]":
"""Transform an option into a `Result`.
Maps `Some(v)` to `Ok(v)` or `None` to `Err(err)`.
Expand Down
14 changes: 11 additions & 3 deletions tests/test_meta.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,11 @@ def filter_meths(cls: t.Type, meth: str) -> bool:
@pytest.mark.parametrize(
"meth",
filter(
lambda m: TestNoConcretesInInterfaces.filter_meths(_Result, m),
lambda m: TestNoConcretesInInterfaces.filter_meths(
# No idea why it thinks `m` is "object", not "str"
_Result,
m, # type: ignore
),
_Result.__dict__,
),
)
Expand All @@ -115,8 +119,12 @@ def test_no_concrete_result_methods(self, meth: str) -> None:
@pytest.mark.parametrize(
"meth",
filter(
lambda m: TestNoConcretesInInterfaces.filter_meths(_Option, m),
_Option.__dict__,
lambda m: TestNoConcretesInInterfaces.filter_meths(
# No idea why it thinks `m` is "object", not "str"
_Option,
m, # type: ignore
),
_Option.__dict__.keys(),
),
)
def test_no_concrete_option_methods(self, meth: str) -> None:
Expand Down
2 changes: 1 addition & 1 deletion tests/test_option.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ def test_expect_and_aliases_raising(
self, method: str, exc_cls: t.Type[Exception]
) -> None:
"""Can specify exception msg/cls if value is not Some()."""
exp_exc = exc_cls if exc_cls else RuntimeError
exp_exc: t.Type[Exception] = exc_cls if exc_cls else RuntimeError
kwargs = {"exc_cls": exc_cls} if exc_cls else {}
msg = "not what I expected"

Expand Down
37 changes: 34 additions & 3 deletions tests/test_result.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,22 @@ def test_and_then(
"""Test that and_then chains result-generating functions."""
assert start.and_then(first).and_then(second) == exp

def test_and_then_covariance(self) -> None:
"""Covariant errors are acceptable for flatmapping."""

class MyInt(int):
"""A subclass of int."""

def process_int(i: int) -> Result[MyInt, RuntimeError]:
return Err(RuntimeError(f"it broke {i}"))

start: Result[int, Exception] = Ok(5)
# We can flatmap w/a function that takes any covariant type of
# int or Exception. The result remains the original exception type,
# since we cannot guarantee narrowing to the covariant type.
flatmapped: Result[int, Exception] = start.and_then(process_int)
assert flatmapped

def test_flatmap(self) -> None:
"""Flatmap is an alias for and_then"""
ok: Result[int, int] = Ok(2)
Expand Down Expand Up @@ -192,7 +208,7 @@ def test_err(self, start: Result[int, str], exp: Option[str]) -> None:
@pytest.mark.parametrize("exc_cls", (None, IOError))
def test_expect_raising(self, exc_cls: t.Type[Exception]) -> None:
"""Test expecting a value to be Ok()."""
exp_exc = exc_cls if exc_cls else RuntimeError
exp_exc: t.Type[Exception] = exc_cls if exc_cls else RuntimeError
kwargs = {"exc_cls": exc_cls} if exc_cls else {}
input_val = 2
msg = "not what I expected"
Expand All @@ -206,7 +222,7 @@ def test_expect_raising(self, exc_cls: t.Type[Exception]) -> None:
@pytest.mark.parametrize("exc_cls", (None, IOError))
def test_raise_if_err_raising(self, exc_cls: t.Type[Exception]) -> None:
"""Test raise_if_err for Err() values."""
exp_exc = exc_cls if exc_cls else RuntimeError
exp_exc: t.Type[Exception] = exc_cls if exc_cls else RuntimeError
kwargs = {"exc_cls": exc_cls} if exc_cls else {}
input_val = 2
msg = "not what I expected"
Expand All @@ -228,7 +244,7 @@ def test_raise_if_err_ok(self) -> None:
@pytest.mark.parametrize("exc_cls", (None, IOError))
def test_expect_err_raising(self, exc_cls: t.Type[Exception]) -> None:
"""Test expecting a value to be Ok()."""
exp_exc = exc_cls if exc_cls else RuntimeError
exp_exc: t.Type[Exception] = exc_cls if exc_cls else RuntimeError
kwargs = {"exc_cls": exc_cls} if exc_cls else {}
msg = "not what I expected"

Expand Down Expand Up @@ -265,6 +281,21 @@ def test_map(self, start: Result[int, str], exp: Result[int, str]) -> None:
""".map() will map onto Ok() and ignore Err()."""
assert start.map(lambda x: int(x ** 2)) == exp

def test_map_covariance(self) -> None:
"""The input type to the map fn is covariant."""

class MyStr(str):
"""Subclass of str."""

def to_mystr(string: str) -> MyStr:
return MyStr(string) if not isinstance(string, MyStr) else string

start: Result[str, str] = Ok("foo")
# We can assign the result to [str, str] even though we know it's
# actually a MyStr, since MyStr is covariant with str
end: Result[str, str] = start.map(to_mystr)
assert end == Ok(MyStr("foo"))

@pytest.mark.parametrize(
"start, exp", ((Ok("foo"), Ok("foo")), (Err(2), Err("2")))
)
Expand Down

0 comments on commit 2c4fa67

Please sign in to comment.