From 76c956cd8b1ad0d19d83e1c92ccdab090d6d16f6 Mon Sep 17 00:00:00 2001 From: Evan Hubinger Date: Thu, 21 Mar 2024 01:50:31 -0700 Subject: [PATCH] Improve PEP 695 implementation Refs #757. --- coconut/compiler/compiler.py | 59 +++++++++++++++---- coconut/compiler/templates/header.py_template | 2 +- coconut/constants.py | 2 +- coconut/root.py | 2 +- coconut/tests/main_test.py | 10 ++-- coconut/tests/src/extras.coco | 3 +- 6 files changed, 57 insertions(+), 21 deletions(-) diff --git a/coconut/compiler/compiler.py b/coconut/compiler/compiler.py index 4a42cf145..306f0aced 100644 --- a/coconut/compiler/compiler.py +++ b/coconut/compiler/compiler.py @@ -3180,7 +3180,16 @@ def classdef_handle(self, original, loc, tokens): """Process class definitions.""" decorators, name, paramdefs, classlist_toks, body = tokens - out = "".join(paramdefs) + decorators + "class " + name + out = "" + + # paramdefs are type params on >= 3.12 and type var assignments on < 3.12 + if paramdefs: + if self.target_info >= (3, 12): + name += "[" + ", ".join(paramdefs) + "]" + else: + out += "".join(paramdefs) + + out += decorators + "class " + name # handle classlist base_classes = [] @@ -3210,7 +3219,7 @@ def classdef_handle(self, original, loc, tokens): base_classes.append(join_args(pos_args, star_args, kwd_args, dubstar_args)) - if paramdefs: + if paramdefs and self.target_info < (3, 12): base_classes.append(self.get_generic_for_typevars()) if not classlist_toks and not self.target.startswith("3"): @@ -3442,9 +3451,16 @@ def assemble_data(self, decorators, name, namedtuple_call, inherit, extra_stmts, IMPORTANT: Any changes to assemble_data must be reflected in the definition of Expected in header.py_template. """ + print(paramdefs) # create class - out = [ - "".join(paramdefs), + out = [] + if paramdefs: + # paramdefs are type params on >= 3.12 and type var assignments on < 3.12 + if self.target_info >= (3, 12): + name += "[" + ", ".join(paramdefs) + "]" + else: + out += ["".join(paramdefs)] + out += [ decorators, "class ", name, @@ -3453,7 +3469,7 @@ def assemble_data(self, decorators, name, namedtuple_call, inherit, extra_stmts, ] if inherit is not None: out += [", ", inherit] - if paramdefs: + if paramdefs and self.target_info < (3, 12): out += [", ", self.get_generic_for_typevars()] if not self.target.startswith("3"): out.append(", _coconut.object") @@ -4564,15 +4580,21 @@ def funcname_typeparams_handle(self, tokens): return name else: name, paramdefs = tokens - return self.add_code_before_marker_with_replacement(name, "".join(paramdefs), add_spaces=False) + # paramdefs are type params on >= 3.12 and type var assignments on < 3.12 + if self.target_info >= (3, 12): + return name + "[" + ", ".join(paramdefs) + "]" + else: + return self.add_code_before_marker_with_replacement(name, "".join(paramdefs), add_spaces=False) funcname_typeparams_handle.ignore_one_token = True def type_param_handle(self, original, loc, tokens): """Compile a type param into an assignment.""" args = "" + raw_bound = None bound_op = None bound_op_type = "" + stars = "" if "TypeVar" in tokens: TypeVarFunc = "TypeVar" bound_op_type = "bound" @@ -4580,18 +4602,24 @@ def type_param_handle(self, original, loc, tokens): name_loc, name = tokens else: name_loc, name, bound_op, bound = tokens + # raw_bound is for >=3.12, so it is for_py_typedef, but args is for <3.12, so it isn't + raw_bound = self.wrap_typedef(bound, for_py_typedef=True) args = ", bound=" + self.wrap_typedef(bound, for_py_typedef=False) elif "TypeVar constraint" in tokens: TypeVarFunc = "TypeVar" bound_op_type = "constraint" name_loc, name, bound_op, constraints = tokens + # for_py_typedef is different in the two cases here as above + raw_bound = ", ".join(self.wrap_typedef(c, for_py_typedef=True) for c in constraints) args = ", " + ", ".join(self.wrap_typedef(c, for_py_typedef=False) for c in constraints) elif "TypeVarTuple" in tokens: TypeVarFunc = "TypeVarTuple" name_loc, name = tokens + stars = "*" elif "ParamSpec" in tokens: TypeVarFunc = "ParamSpec" name_loc, name = tokens + stars = "**" else: raise CoconutInternalException("invalid type_param tokens", tokens) @@ -4612,8 +4640,14 @@ def type_param_handle(self, original, loc, tokens): loc, ) + # on >= 3.12, return a type param + if self.target_info >= (3, 12): + return stars + name + (": " + raw_bound if raw_bound is not None else "") + + # on < 3.12, return a type variable assignment + kwargs = "" - # uncomment these lines whenever mypy adds support for infer_variance in TypeVar + # TODO: uncomment these lines whenever mypy adds support for infer_variance in TypeVar # (and remove the warning about it in the DOCS) # if TypeVarFunc == "TypeVar": # kwargs += ", infer_variance=True" @@ -4644,6 +4678,7 @@ def type_param_handle(self, original, loc, tokens): def get_generic_for_typevars(self): """Get the Generic instances for the current typevars.""" + internal_assert(self.target_info < (3, 12), "get_generic_for_typevars should only be used on targets < 3.12") typevar_info = self.current_parsing_context("typevars") internal_assert(typevar_info is not None, "get_generic_for_typevars called with no typevars") generics = [] @@ -4677,16 +4712,18 @@ def type_alias_stmt_handle(self, tokens): paramdefs = () else: name, paramdefs, typedef = tokens - out = "".join(paramdefs) + + # paramdefs are type params on >= 3.12 and type var assignments on < 3.12 if self.target_info >= (3, 12): - out += "type " + name + " = " + self.wrap_typedef(typedef, for_py_typedef=True) + if paramdefs: + name += "[" + ", ".join(paramdefs) + "]" + return "type " + name + " = " + self.wrap_typedef(typedef, for_py_typedef=True) else: - out += self.typed_assign_stmt_handle([ + return "".join(paramdefs) + self.typed_assign_stmt_handle([ name, "_coconut.typing.TypeAlias", self.wrap_typedef(typedef, for_py_typedef=False), ]) - return out def where_item_handle(self, tokens): """Manage where items.""" diff --git a/coconut/compiler/templates/header.py_template b/coconut/compiler/templates/header.py_template index eba9a8a2e..ca203aaed 100644 --- a/coconut/compiler/templates/header.py_template +++ b/coconut/compiler/templates/header.py_template @@ -61,7 +61,7 @@ class _coconut{object}:{COMMENT.EVERYTHING_HERE_MUST_BE_COPIED_TO_STUB_FILE} reiterables = abc.Sequence, abc.Mapping, abc.Set fmappables = list, tuple, dict, set, frozenset, bytes, bytearray abc.Sequence.register(collections.deque) - Ellipsis, NotImplemented, NotImplementedError, Exception, AttributeError, ImportError, IndexError, KeyError, NameError, TypeError, ValueError, StopIteration, RuntimeError, all, any, bool, bytes, callable, chr, classmethod, complex, dict, enumerate, filter, float, frozenset, getattr, hasattr, hash, id, int, isinstance, issubclass, iter, len, list, locals, globals, map, min, max, next, object, ord, property, range, reversed, set, setattr, slice, str, sum, super, tuple, type, vars, zip, repr, print{comma_bytearray} = Ellipsis, NotImplemented, NotImplementedError, Exception, AttributeError, ImportError, IndexError, KeyError, NameError, TypeError, ValueError, StopIteration, RuntimeError, all, any, bool, bytes, callable, chr, classmethod, complex, dict, enumerate, filter, float, frozenset, getattr, hasattr, hash, id, int, isinstance, issubclass, iter, len, list, locals, globals, map, min, max, next, object, ord, property, range, reversed, set, setattr, slice, str, sum, {lstatic}super{rstatic}, tuple, type, vars, zip, {lstatic}repr{rstatic}, {lstatic}print{rstatic}{comma_bytearray} + Ellipsis, NotImplemented, NotImplementedError, Exception, AttributeError, ImportError, IndexError, KeyError, NameError, TypeError, ValueError, StopIteration, RuntimeError, all, any, bool, bytes, callable, chr, classmethod, complex, dict, enumerate, filter, float, frozenset, getattr, hasattr, hash, id, int, isinstance, issubclass, iter, len, list, locals, globals, map, min, max, next, object, ord, property, range, reversed, set, setattr, slice, str, sum, super, tuple, type, vars, zip, repr, print{comma_bytearray} = Ellipsis, NotImplemented, NotImplementedError, Exception, AttributeError, ImportError, IndexError, KeyError, NameError, TypeError, ValueError, StopIteration, RuntimeError, all, any, bool, bytes, callable, chr, classmethod, complex, dict, enumerate, filter, float, frozenset, getattr, hasattr, hash, id, int, isinstance, issubclass, iter, len, list, locals, globals, map, {lstatic}min{rstatic}, {lstatic}max{rstatic}, next, object, ord, property, range, reversed, set, setattr, slice, str, sum, {lstatic}super{rstatic}, tuple, type, vars, zip, {lstatic}repr{rstatic}, {lstatic}print{rstatic}{comma_bytearray} @_coconut.functools.wraps(_coconut.functools.partial) def _coconut_partial(_coconut_func, *args, **kwargs): partial_func = _coconut.functools.partial(_coconut_func, *args, **kwargs) diff --git a/coconut/constants.py b/coconut/constants.py index 3df18eb5b..62e5f2967 100644 --- a/coconut/constants.py +++ b/coconut/constants.py @@ -93,7 +93,7 @@ def get_path_env_var(env_var, default): PY38 and not WINDOWS and not PYPY - # disabled until MyPy supports PEP 695 + # TODO: disabled until MyPy supports PEP 695 and not PY312 ) XONSH = ( diff --git a/coconut/root.py b/coconut/root.py index 9b0305794..bcd320001 100644 --- a/coconut/root.py +++ b/coconut/root.py @@ -26,7 +26,7 @@ VERSION = "3.1.0" VERSION_NAME = None # False for release, int >= 1 for develop -DEVELOP = 5 +DEVELOP = 6 ALPHA = False # for pre releases rather than post releases assert DEVELOP is False or DEVELOP >= 1, "DEVELOP must be False or an int >= 1" diff --git a/coconut/tests/main_test.py b/coconut/tests/main_test.py index ea43b7319..39d12d20a 100644 --- a/coconut/tests/main_test.py +++ b/coconut/tests/main_test.py @@ -1092,11 +1092,11 @@ def test_bbopt(self): if not PYPY and PY38 and not PY310: install_bbopt() - def test_pyprover(self): - with using_paths(pyprover): - comp_pyprover() - if PY38: - run_pyprover() + # def test_pyprover(self): + # with using_paths(pyprover): + # comp_pyprover() + # if PY38: + # run_pyprover() def test_pyston(self): with using_paths(pyston): diff --git a/coconut/tests/src/extras.coco b/coconut/tests/src/extras.coco index 0bb22fbde..91981ce4a 100644 --- a/coconut/tests/src/extras.coco +++ b/coconut/tests/src/extras.coco @@ -474,8 +474,7 @@ type Num = int | float""".strip()) assert parse("type L[T] = list[T]").strip().endswith(""" # Compiled Coconut: ----------------------------------------------------------- -_coconut_typevar_T_0 = _coconut.typing.TypeVar("_coconut_typevar_T_0") -type L = list[_coconut_typevar_T_0]""".strip()) +type L[T] = list[T]""".strip()) setup(line_numbers=False, minify=True) assert parse("123 # derp", "lenient") == "123# derp"