Improve match def default handling
Resolves #618.
evhub committed Feb 25, 2024
1 parent 0adf2d6 commit fcd1043
Showing 4 changed files with 84 additions and 30 deletions.
12 changes: 9 additions & 3 deletions DOCS.md
@@ -1149,11 +1149,17 @@ depth: 1

### `match`

-Coconut provides fully-featured, functional pattern-matching through its `match` statements.
+Coconut provides fully-featured, functional pattern-matching through its `match` statements. Coconut `match` syntax is a strict superset of [Python's `match` syntax](https://peps.python.org/pep-0636/).

_Note: In describing Coconut's pattern-matching syntax, this section focuses on `match` statements, but Coconut's pattern-matching can also be used in many other places, such as [pattern-matching function definition](#pattern-matching-functions), [`case` statements](#case), [destructuring assignment](#destructuring-assignment), [`match data`](#match-data), and [`match for`](#match-for)._

##### Overview

-Match statements follow the basic syntax `match <pattern> in <value>`. The match statement will attempt to match the value against the pattern, and if successful, bind any variables in the pattern to whatever is in the same position in the value, and execute the code below the match statement. Match statements also support, in their basic syntax, an `if <cond>` that will check the condition after executing the match before executing the code below, and an `else` statement afterwards that will only be executed if the `match` statement is not. What is allowed in the match statement's pattern has no equivalent in Python, and thus the specifications below are provided to explain it.
+Match statements follow the basic syntax `match <pattern> in <value>`. The match statement will attempt to match the value against the pattern, and if successful, bind any variables in the pattern to whatever is in the same position in the value, and execute the code below the match statement.
+
+Match statements also support, in their basic syntax, an `if <cond>` that will check the condition after executing the match before executing the code below, and an `else` statement afterwards that will only be executed if the `match` statement is not.
+
+All pattern-matching in Coconut is atomic, such that no assignments will be executed unless the whole match succeeds.

##### Syntax Specification

@@ -2494,7 +2500,7 @@ If `<pattern>` has a variable name (via any variable binding that binds the enti

In addition to supporting pattern-matching in their arguments, pattern-matching function definitions also have a couple of notable differences compared to Python functions. Specifically:
- If pattern-matching function definition fails, it will raise a [`MatchError`](#matcherror) (just like [destructuring assignment](#destructuring-assignment)) instead of a `TypeError`.
-- All defaults in pattern-matching function definition are late-bound rather than early-bound. Thus, `match def f(xs=[]) = xs` will instantiate a new list for each call where `xs` is not given, unlike `def f(xs=[]) = xs`, which will use the same list for all calls where `xs` is unspecified.
+- All defaults in pattern-matching function definition are late-bound rather than early-bound. Thus, `match def f(xs=[]) = xs` will instantiate a new list for each call where `xs` is not given, unlike `def f(xs=[]) = xs`, which will use the same list for all calls where `xs` is unspecified. This also allows defaults for later arguments to be specified in terms of matched values from earlier arguments, as in `match def f(x, y=x) = (x, y)`.

Pattern-matching function definition can also be combined with `async` functions, [`copyclosure` functions](#copyclosure-functions), [`yield` functions](#explicit-generators), [infix function definition](#infix-functions), and [assignment function syntax](#assignment-functions). The various keywords in front of the `def` can be put in any order.
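
Reviewer note: the late-bound default behaviour documented in this hunk can be approximated in plain Python with a sentinel default. This is only an illustrative sketch, not the code Coconut generates; `_SENTINEL`, `early`, `late`, and `maybe_dup` are hypothetical names chosen for the example.

```python
_SENTINEL = object()  # hypothetical stand-in for an "argument not given" marker


def early(xs=[]):
    # Ordinary Python default: the list is created once, at definition time,
    # and shared by every call that omits xs.
    xs.append(1)
    return xs


def late(xs=_SENTINEL):
    # Late-bound default: a fresh list is created on each call that omits xs.
    if xs is _SENTINEL:
        xs = []
    xs.append(1)
    return xs


def maybe_dup(x, y=_SENTINEL):
    # Late binding also lets a default refer to an earlier argument,
    # mirroring `match def f(x, y=x) = (x, y)` from the docs.
    if y is _SENTINEL:
        y = x
    return (x, y)
```

Calling `early()` twice returns the same list object, while `late()` returns a new list each time; `maybe_dup(1)` falls back to the value of `x` for `y`.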

97 changes: 71 additions & 26 deletions coconut/compiler/matching.py
@@ -307,6 +307,39 @@ def get_set_name_var(self, name):
"""Gets the var for checking whether a name should be set."""
return match_set_name_var + "_" + name

+    def add_default_expr(self, assign_to, default_expr):
+        """Add code that evaluates expr in the context of any names that have been matched so far
+        and assigns the result to assign_to if assign_to is currently _coconut_sentinel."""
+        vars_var = self.get_temp_var()
+        add_names_code = []
+        for name in self.names:
+            add_names_code.append(
+                handle_indentation(
+                    """
+                    if {set_name_var} is not _coconut_sentinel:
+                        {vars_var}["{name}"] = {set_name_var}
+                    """,
+                    add_newline=True,
+                ).format(
+                    set_name_var=self.get_set_name_var(name),
+                    vars_var=vars_var,
+                    name=name,
+                )
+            )
+        code = self.comp.reformat_post_deferred_code_proc(assign_to + " = " + default_expr)
+        self.add_def(handle_indentation("""
+            if {assign_to} is _coconut_sentinel:
+                {vars_var} = _coconut.globals().copy()
+                {vars_var}.update(_coconut.locals().copy())
+                {add_names_code}_coconut_exec({code_str}, {vars_var})
+                {assign_to} = {vars_var}["{assign_to}"]
+        """).format(
+            vars_var=vars_var,
+            add_names_code="".join(add_names_code),
+            assign_to=assign_to,
+            code_str=self.comp.wrap_str_of(code),
+        ))

def register_name(self, name):
"""Register a new name at the current position."""
internal_assert(lambda: name not in self.parent_names and name not in self.names, "attempt to register duplicate name", name)
@@ -373,7 +406,7 @@ def match_function(
).format(
first_arg=first_arg,
args=args,
-),
+)
)

with self.down_a_level():
@@ -418,7 +451,7 @@ def match_in_args_kwargs(self, pos_only_match_args, match_args, args, kwargs, al
# if i >= req_len
"_coconut.sum((_coconut.len(" + args + ") > " + str(i) + ", "
+ ", ".join('"' + name + '" in ' + kwargs for name in names)
-+ ")) == 1",
++ ")) == 1"
)
tempvar = self.get_temp_var()
self.add_def(
@@ -428,24 +461,27 @@ def match_in_args_kwargs(self, pos_only_match_args, match_args, args, kwargs, al
kwargs + '.pop("' + name + '") if "' + name + '" in ' + kwargs + " else "
for name in names[:-1]
)
-+ kwargs + '.pop("' + names[-1] + '")',
++ kwargs + '.pop("' + names[-1] + '")'
)
with self.down_a_level():
self.match(match, tempvar)
else:
if not names:
tempvar = self.get_temp_var()
-self.add_def(tempvar + " = " + args + "[" + str(i) + "] if _coconut.len(" + args + ") > " + str(i) + " else " + default)
-with self.down_a_level():
-    self.match(match, tempvar)
+self.add_def(tempvar + " = " + args + "[" + str(i) + "] if _coconut.len(" + args + ") > " + str(i) + " else _coconut_sentinel")
+# go down to end to ensure we've matched as much as possible before evaluating the default
+with self.down_to_end():
+    self.add_default_expr(tempvar, default)
+    with self.down_a_level():
+        self.match(match, tempvar)
else:
arg_checks[i] = (
# if i < req_len
None,
# if i >= req_len
"_coconut.sum((_coconut.len(" + args + ") > " + str(i) + ", "
+ ", ".join('"' + name + '" in ' + kwargs for name in names)
-+ ")) <= 1",
++ ")) <= 1"
)
tempvar = self.get_temp_var()
self.add_def(
@@ -455,10 +491,13 @@ def match_in_args_kwargs(self, pos_only_match_args, match_args, args, kwargs, al
kwargs + '.pop("' + name + '") if "' + name + '" in ' + kwargs + " else "
for name in names
)
-+ default,
++ "_coconut_sentinel"
 )
-with self.down_a_level():
-    self.match(match, tempvar)
+# go down to end to ensure we've matched as much as possible before evaluating the default
+with self.down_to_end():
+    self.add_default_expr(tempvar, default)
+    with self.down_a_level():
+        self.match(match, tempvar)

# length checking
max_len = None if allow_star_args else len(pos_only_match_args) + len(match_args)
@@ -484,12 +523,18 @@ def match_in_kwargs(self, match_args, kwargs):
kwargs + '.pop("' + name + '") if "' + name + '" in ' + kwargs + " else "
for name in names
)
-+ (default if default is not None else "_coconut_sentinel"),
++ "_coconut_sentinel"
 )
-with self.down_a_level():
-    if default is None:
+if default is None:
+    with self.down_a_level():
         self.add_check(tempvar + " is not _coconut_sentinel")
-    self.match(match, tempvar)
+        self.match(match, tempvar)
+else:
+    # go down to end to ensure we've matched as much as possible before evaluating the default
+    with self.down_to_end():
+        self.add_default_expr(tempvar, default)
+        with self.down_a_level():
+            self.match(match, tempvar)

def match_dict(self, tokens, item):
"""Matches a dictionary."""
@@ -1054,7 +1099,7 @@ def match_class(self, tokens, item):
).format(
num_pos_matches=len(pos_matches),
cls_name=cls_name,
-),
+)
)
else:
self_match_matcher.match(pos_matches[0], item)
@@ -1077,7 +1122,7 @@ def match_class(self, tokens, item):
num_pos_matches=len(pos_matches),
type_any=self.comp.wrap_comment(" type: _coconut.typing.Any"),
type_ignore=self.comp.type_ignore_comment(),
-),
+)
)
with other_cls_matcher.down_a_level():
for i, match in enumerate(pos_matches):
@@ -1098,7 +1143,7 @@ def match_class(self, tokens, item):
star_match_var=star_match_var,
item=item,
num_pos_matches=len(pos_matches),
-),
+)
)
with self.down_a_level():
self.match(star_match, star_match_var)
@@ -1118,7 +1163,7 @@ def match_data(self, tokens, item):
"_coconut.len({item}) >= {min_len}".format(
item=item,
min_len=len(pos_matches),
-),
+)
)

self.match_all_in(pos_matches, item)
@@ -1152,7 +1197,7 @@ def match_data(self, tokens, item):
min_len=len(pos_matches),
name_matches=tuple_str_of(name_matches, add_quotes=True),
type_ignore=self.comp.type_ignore_comment(),
-),
+)
)
with self.down_a_level():
self.add_check(temp_var)
@@ -1172,7 +1217,7 @@ def match_data_or_class(self, tokens, item):
is_data_var=is_data_var,
cls_name=cls_name,
type_ignore=self.comp.type_ignore_comment(),
-),
+)
)

if_data, if_class = self.branches(2)
@@ -1248,7 +1293,7 @@ def match_view(self, tokens, item):
func_result_var=func_result_var,
view_func=view_func,
item=item,
-),
+)
)

with self.down_a_level():
@@ -1325,7 +1370,7 @@ def out(self):
check_var=self.check_var,
parameterization=parameterization,
child_checks=child.out().rstrip(),
-),
+)
)

# handle normal child groups
@@ -1353,7 +1398,7 @@ def out(self):
).format(
check_var=self.check_var,
children_checks=children_checks,
-),
+)
)

# commit variable definitions
@@ -1369,7 +1414,7 @@ def out(self):
).format(
set_name_var=self.get_set_name_var(name),
name=name,
-),
+)
)
if name_set_code:
out.append(
@@ -1381,7 +1426,7 @@ def out(self):
).format(
check_var=self.check_var,
name_set_code="".join(name_set_code),
-),
+)
)

# handle guards
@@ -1396,7 +1441,7 @@ def out(self):
).format(
check_var=self.check_var,
guards=paren_join(self.guards, "and"),
-),
+)
)

return "".join(out)
2 changes: 1 addition & 1 deletion coconut/root.py
@@ -26,7 +26,7 @@
VERSION = "3.0.4"
VERSION_NAME = None
# False for release, int >= 1 for develop
-DEVELOP = 21
+DEVELOP = 22
ALPHA = False # for pre releases rather than post releases

assert DEVELOP is False or DEVELOP >= 1, "DEVELOP must be False or an int >= 1"
3 changes: 3 additions & 0 deletions coconut/tests/src/cocotest/agnostic/primary_2.coco
@@ -452,6 +452,9 @@ def primary_test_2() -> bool:
assert [=> y for y in range(2)] |> map$(call) |> list == [1, 1]
assert [def => y for y in range(2)] |> map$(call) |> list == [0, 1]
assert (x -> x -> def y -> (x, y))(1)(2)(3) == (2, 3)
+match def maybe_dup(x, y=x) = (x, y)
+assert maybe_dup(1) == (1, 1) == maybe_dup(x=1)
+assert maybe_dup(1, 2) == (1, 2) == maybe_dup(x=1, y=2)

with process_map.multiple_sequential_calls(): # type: ignore
assert map((+), range(3), range(4)$[:-1], strict=True) |> list == [0, 2, 4] == process_map((+), range(3), range(4)$[:-1], strict=True) |> list # type: ignore