From 4a366eacb58cd351cc6576e52f446c886d26a60a Mon Sep 17 00:00:00 2001 From: Evan Hubinger Date: Tue, 12 Nov 2024 00:47:08 -0800 Subject: [PATCH] Add memoize.RECURSIVE Refs #858. --- DOCS.md | 2 + __coconut__/__init__.pyi | 1 + coconut/compiler/templates/header.py_template | 41 ++++++++++++++++--- coconut/root.py | 2 +- .../src/cocotest/agnostic/primary_2.coco | 14 +++++++ .../tests/src/cocotest/agnostic/suite.coco | 5 ++- coconut/tests/src/cocotest/agnostic/util.coco | 20 ++++++--- 7 files changed, 72 insertions(+), 13 deletions(-) diff --git a/DOCS.md b/DOCS.md index 720241ff..8e4b5ef1 100644 --- a/DOCS.md +++ b/DOCS.md @@ -3116,6 +3116,8 @@ _Note: Passing `--strict` disables deprecated features._ Coconut provides `functools.lru_cache` as a built-in under the name `memoize` with the modification that the _maxsize_ parameter is set to `None` by default. `memoize` makes the use case of optimizing recursive functions easier, as a _maxsize_ of `None` is usually what is desired in that case. +`memoize` also supports a special `maxsize=memoize.RECURSIVE` argument, which will allow the cache to grow without bound within a single call to the top-level function, but clear the cache after the top-level call returns. + Use of `memoize` requires `functools.lru_cache`, which exists in the Python 3 standard library, but under Python 2 will require `pip install backports.functools_lru_cache` to function. Additionally, if on Python 2 and `backports.functools_lru_cache` is present, Coconut will patch `functools` such that `functools.lru_cache = backports.functools_lru_cache.lru_cache`. Note that, if the function to be memoized is a generator or otherwise returns an iterator, [`recursive_generator`](#recursive_generator) can also be used to achieve a similar effect, the use of which is required for recursive generators. diff --git a/__coconut__/__init__.pyi b/__coconut__/__init__.pyi index f2bc2bdf..5b342905 100644 --- a/__coconut__/__init__.pyi +++ b/__coconut__/__init__.pyi @@ -206,6 +206,7 @@ _coconut_zip = zip zip_longest = _coconut.zip_longest memoize = _lru_cache +memoize.RECURSIVE = None # type: ignore reduce = _coconut.functools.reduce takewhile = _coconut.itertools.takewhile dropwhile = _coconut.itertools.dropwhile diff --git a/coconut/compiler/templates/header.py_template b/coconut/compiler/templates/header.py_template index 33fb4b4b..fe149946 100644 --- a/coconut/compiler/templates/header.py_template +++ b/coconut/compiler/templates/header.py_template @@ -1643,17 +1643,46 @@ def fmap(func, obj, **kwargs): else: mapped_obj = _coconut_map(func, obj) return _coconut_base_makedata(obj.__class__, mapped_obj, from_fmap=True, fallback_to_init=fallback_to_init) -def _coconut_memoize_helper(maxsize=None, typed=False): - return maxsize, typed def memoize(*args, **kwargs): """Decorator that memoizes a function, preventing it from being recomputed if it is called multiple times with the same arguments.""" if not kwargs and _coconut.len(args) == 1 and _coconut.callable(args[0]): - return _coconut.functools.lru_cache(maxsize=None)(args[0]) + return _coconut_memoize_helper()(args[0]) if _coconut.len(kwargs) == 1 and "user_function" in kwargs and _coconut.callable(kwargs["user_function"]): - return _coconut.functools.lru_cache(maxsize=None)(kwargs["user_function"]) - maxsize, typed = _coconut_memoize_helper(*args, **kwargs) - return _coconut.functools.lru_cache(maxsize, typed) + return _coconut_memoize_helper()(kwargs["user_function"]) + return _coconut_memoize_helper(*args, **kwargs) +memoize.RECURSIVE = _coconut_Sentinel() +def _coconut_memoize_helper(maxsize=None, typed=False): + if maxsize is memoize.RECURSIVE: + def memoizer(func): + """memoize(...)""" + inside = [False] + cache = {empty_dict} + @_coconut_wraps(func) + def memoized_func(*args, **kwargs): + if typed: + key = (_coconut.tuple((x, _coconut.type(x)) for x in args), _coconut.tuple((k, _coconut.type(k), v, _coconut.type(v)) for k, v in kwargs.items())) + else: + key = (args, _coconut.tuple(kwargs.items())) + got = cache.get(key, _coconut_sentinel) + if got is not _coconut_sentinel: + return got + outer_inside, inside[0] = inside[0], True + try: + got = func(*args, **kwargs) + cache[key] = got + return got + finally: + inside[0] = outer_inside + if not inside[0]: + cache.clear() + memoized_func.__module__ = _coconut.getattr(func, "__module__", None) + memoized_func.__name__ = _coconut.getattr(func, "__name__", None) + memoized_func.__qualname__ = _coconut.getattr(func, "__qualname__", None) + return memoized_func + return memoizer + else: + return _coconut.functools.lru_cache(maxsize, typed) {def_call_set_names} class override(_coconut_baseclass): """Declare a method in a subclass as an override of a parent class method. diff --git a/coconut/root.py b/coconut/root.py index 51ebd26d..1ca4d11b 100644 --- a/coconut/root.py +++ b/coconut/root.py @@ -26,7 +26,7 @@ VERSION = "3.1.2" VERSION_NAME = None # False for release, int >= 1 for develop -DEVELOP = 3 +DEVELOP = 4 ALPHA = False # for pre releases rather than post releases assert DEVELOP is False or DEVELOP >= 1, "DEVELOP must be False or an int >= 1" diff --git a/coconut/tests/src/cocotest/agnostic/primary_2.coco b/coconut/tests/src/cocotest/agnostic/primary_2.coco index 7d164584..5025f0b8 100644 --- a/coconut/tests/src/cocotest/agnostic/primary_2.coco +++ b/coconut/tests/src/cocotest/agnostic/primary_2.coco @@ -491,6 +491,20 @@ def primary_test_2() -> bool: assert reduce(function=(+), iterable=range(5), initial=-1) == 9 # type: ignore assert takewhile(predicate=ident, iterable=[1, 2, 1, 0, 1]) |> list == [1, 2, 1] # type: ignore assert dropwhile(predicate=(not), iterable=range(5)) |> list == [1, 2, 3, 4] # type: ignore + @memoize(typed=True) + def typed_memoized_func(x): + if x is 1: + return None + else: + return (x, typed_memoized_func(1)) + assert typed_memoized_func(1.0) == (1.0, None) + assert typed_memoized_func(1.0)[0] |> type == float + @memoize() + def untyped_memoized_func(x=None): + if x is None: + return (untyped_memoized_func(1), untyped_memoized_func(1.0)) + return x + assert untyped_memoized_func() |> map$(type) |> tuple == (int, float) with process_map.multiple_sequential_calls(): # type: ignore assert map((+), range(3), range(4)$[:-1], strict=True) |> list == [0, 2, 4] == process_map((+), range(3), range(4)$[:-1], strict=True) |> list # type: ignore diff --git a/coconut/tests/src/cocotest/agnostic/suite.coco b/coconut/tests/src/cocotest/agnostic/suite.coco index 3589aedf..a20dde50 100644 --- a/coconut/tests/src/cocotest/agnostic/suite.coco +++ b/coconut/tests/src/cocotest/agnostic/suite.coco @@ -562,7 +562,10 @@ def suite_test() -> bool: assert plus1sqsum_all(1, 2) == 13 == plus1sqsum_all_(1, 2) # type: ignore assert sum_list_range(10) == 45 assert sum2([3, 4]) == 7 - assert ridiculously_recursive(300) == 201666561657114122540576123152528437944095370972927688812965354745141489205495516550423117825 == ridiculously_recursive_(300) + with process_map.multiple_sequential_calls(): + for ridiculously_recursive in (ridiculously_recursive1, ridiculously_recursive2, ridiculously_recursive3): + got = process_map(ridiculously_recursive, [300]) + assert got == (201666561657114122540576123152528437944095370972927688812965354745141489205495516550423117825,), got assert [fib(n) for n in range(16)] == [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610] == [fib_(n) for n in range(16)] assert [fib_alt1(n) for n in range(16)] == [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610] == [fib_alt2(n) for n in range(16)] assert fib.cache_info().hits == 28 diff --git a/coconut/tests/src/cocotest/agnostic/util.coco b/coconut/tests/src/cocotest/agnostic/util.coco index 56dbe52c..e478aeff 100644 --- a/coconut/tests/src/cocotest/agnostic/util.coco +++ b/coconut/tests/src/cocotest/agnostic/util.coco @@ -1346,24 +1346,34 @@ def sum2(ab) = a + b where: # Memoization import functools -@memoize() -def ridiculously_recursive(n): +@memoize(None) +def ridiculously_recursive1(n): """Requires maxsize=None when called on large numbers.""" if n <= 0: return 1 result = 0 for i in range(1, 200): - result += ridiculously_recursive(n-i) + result += ridiculously_recursive1(n-i) return result @functools.lru_cache(maxsize=None) # type: ignore -def ridiculously_recursive_(n): +def ridiculously_recursive2(n): """Requires maxsize=None when called on large numbers.""" if n <= 0: return 1 result = 0 for i in range(1, 200): - result += ridiculously_recursive_(n-i) + result += ridiculously_recursive2(n-i) + return result + +@memoize(memoize.RECURSIVE) # type: ignore +def ridiculously_recursive3(n): + """Requires maxsize=None when called on large numbers.""" + if n <= 0: + return 1 + result = 0 + for i in range(1, 200): + result += ridiculously_recursive3(n-i) return result def fib(n if n < 2) = n