Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support structural pattern matching in ASYNC103,104 and 91X #363

Merged
merged 5 commits into from
Apr 1, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ repos:
hooks:
- id: mypy
# uses py311 syntax, mypy configured for py39
exclude: tests/(eval|autofix)_files/.*_py311.py
exclude: tests/(eval|autofix)_files/.*_py(310|311).py

- repo: https://github.com/RobertCraigie/pyright-python
rev: v1.1.396
Expand Down
4 changes: 4 additions & 0 deletions docs/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,10 @@ Changelog

`CalVer, YY.month.patch <https://calver.org/>`_

25.4.1
======
- Add match-case (structural pattern matching) support to ASYNC103, 104, 910, 911 & 912.

25.3.1
======
- Add except* support to ASYNC102, 103, 104, 120, 910, 911, 912.
Expand Down
2 changes: 1 addition & 1 deletion docs/usage.rst
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ adding the following to your ``.pre-commit-config.yaml``:
minimum_pre_commit_version: '2.9.0'
repos:
- repo: https://github.com/python-trio/flake8-async
rev: 25.3.1
rev: 25.4.1
hooks:
- id: flake8-async
# args: ["--enable=ASYNC100,ASYNC112", "--disable=", "--autofix=ASYNC"]
Expand Down
2 changes: 1 addition & 1 deletion flake8_async/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@


# CalVer: YY.month.patch, e.g. first release of July 2022 == "22.7.1"
__version__ = "25.3.1"
__version__ = "25.4.1"


# taken from https://github.com/Zac-HD/shed
Expand Down
18 changes: 18 additions & 0 deletions flake8_async/visitors/visitor103_104.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,24 @@ def visit_If(self, node: ast.If):
# if body didn't raise, or it's unraised after else, set unraise
self.unraised = not body_raised or self.unraised

def visit_Match(self, node: ast.Match): # type: ignore[name-defined]
if not self.unraised:
return
all_cases_raise = True
has_fallback = False
for case in node.cases:
# check for "bare pattern", i.e `case varname:`
has_fallback |= (
case.guard is None
and isinstance(case.pattern, ast.MatchAs) # type: ignore[attr-defined]
and case.pattern.pattern is None
)
self.visit_nodes(case.body)
all_cases_raise &= not self.unraised
self.unraised = True

self.unraised = not (all_cases_raise and has_fallback)

# A loop is guaranteed to raise if:
# condition always raises, or
# else always raises, and
Expand Down
66 changes: 66 additions & 0 deletions flake8_async/visitors/visitor91x.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,22 @@ def copy(self):
)


@dataclass
class MatchState:
# TryState, LoopState, and MatchState all do fairly similar things. It would be nice
# to harmonize them and share logic.
base_uncheckpointed_statements: set[Statement] = field(default_factory=set)
case_uncheckpointed_statements: set[Statement] = field(default_factory=set)
has_fallback: bool = False

def copy(self):
return MatchState(
base_uncheckpointed_statements=self.base_uncheckpointed_statements.copy(),
case_uncheckpointed_statements=self.case_uncheckpointed_statements.copy(),
has_fallback=self.has_fallback,
)


def checkpoint_statement(library: str) -> cst.SimpleStatementLine:
# logic before this should stop code from wanting to insert the non-existing
# asyncio.lowlevel.checkpoint
Expand Down Expand Up @@ -373,6 +389,7 @@ def __init__(self, *args: Any, **kwargs: Any):

self.loop_state = LoopState()
self.try_state = TryState()
self.match_state = MatchState()

# ASYNC100
self.has_checkpoint_stack: list[bool] = []
Expand Down Expand Up @@ -894,6 +911,55 @@ def visit_IfExp(self, node: cst.IfExp) -> bool:
self.leave_If(node, node) # type: ignore
return False # libcst shouldn't visit subnodes again

def leave_Match_subject(self, node: cst.Match) -> None:
# We start the match logic after parsing the subject, instead of visit_Match,
# since the subject is always executed and might checkpoint.
if not self.async_function:
return
self.save_state(node, "match_state", copy=True)
self.match_state = MatchState(
base_uncheckpointed_statements=self.uncheckpointed_statements.copy()
)

def visit_MatchCase(self, node: cst.MatchCase) -> None:
# enter each case from the state after parsing the subject
self.uncheckpointed_statements = self.match_state.base_uncheckpointed_statements

def leave_MatchCase_guard(self, node: cst.MatchCase) -> None:
# `case _:` is no pattern and no guard, which means we know body is executed.
# But we also know that `case _ if <guard>:` is guaranteed to execute the guard,
# so for later logic we can treat them the same *if* there's no pattern and that
# guard checkpoints.
if (
isinstance(node.pattern, cst.MatchAs)
and node.pattern.pattern is None
and (node.guard is None or not self.uncheckpointed_statements)
):
self.match_state.has_fallback = True

def leave_MatchCase(
self, original_node: cst.MatchCase, updated_node: cst.MatchCase
) -> cst.MatchCase:
# collect the state at the end of each case
self.match_state.case_uncheckpointed_statements.update(
self.uncheckpointed_statements
)
return updated_node

def leave_Match(
self, original_node: cst.Match, updated_node: cst.Match
) -> cst.Match:
# leave the Match with the worst-case of all branches
self.uncheckpointed_statements = self.match_state.case_uncheckpointed_statements
# if no fallback, also add the state at entering the match (after parsing subject)
if not self.match_state.has_fallback:
self.uncheckpointed_statements.update(
self.match_state.base_uncheckpointed_statements
)

self.restore_state(original_node)
return updated_node

def visit_While(self, node: cst.While | cst.For):
self.save_state(
node,
Expand Down
2 changes: 1 addition & 1 deletion flake8_async/visitors/visitor_utility.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from re import Match

import libcst as cst
from libcst._position import CodeRange
from libcst.metadata import CodeRange


@utility_visitor
Expand Down
90 changes: 90 additions & 0 deletions tests/autofix_files/async91x_py310.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
# ARG --enable=ASYNC910,ASYNC911,ASYNC913
# AUTOFIX
# ASYNCIO_NO_AUTOFIX
import trio


async def foo(): ...


async def match_subject() -> None:
match await foo():
case False:
pass


async def match_not_all_cases() -> ( # ASYNC910: 0, "exit", Statement("function definition", lineno)
None
):
match foo():
case 1:
...
case _:
await foo()
await trio.lowlevel.checkpoint()


async def match_no_fallback() -> ( # ASYNC910: 0, "exit", Statement("function definition", lineno)
None
):
match foo():
case 1:
await foo()
case 2:
await foo()
case _ if True:
await foo()
await trio.lowlevel.checkpoint()


async def match_fallback_is_guarded() -> ( # ASYNC910: 0, "exit", Statement("function definition", lineno)
None
):
match foo():
case 1:
await foo()
case 2:
await foo()
case _ if foo():
await foo()
await trio.lowlevel.checkpoint()


async def match_all_cases() -> None:
match foo():
case 1:
await foo()
case 2:
await foo()
case _:
await foo()


async def match_fallback_await_in_guard() -> None:
# The case guard is only executed if the pattern matches, so we can mostly treat
# it as part of the body, except for a special case for fallback+checkpointing guard.
match foo():
case 1 if await foo():
...
case _ if await foo():
...


async def match_checkpoint_guard() -> None:
# The above pattern is quite cursed, but this seems fairly reasonable to do.
match foo():
case 1 if await foo():
...
case _:
await foo()


async def match_not_checkpoint_in_all_guards() -> ( # ASYNC910: 0, "exit", Statement("function definition", lineno)
None
):
match foo():
case 1:
...
case _ if await foo():
...
await trio.lowlevel.checkpoint()
31 changes: 31 additions & 0 deletions tests/autofix_files/async91x_py310.py.diff
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
---
+++
@@ x,6 x,7 @@
...
case _:
await foo()
+ await trio.lowlevel.checkpoint()


async def match_no_fallback() -> ( # ASYNC910: 0, "exit", Statement("function definition", lineno)
@@ x,6 x,7 @@
await foo()
case _ if True:
await foo()
+ await trio.lowlevel.checkpoint()


async def match_fallback_is_guarded() -> ( # ASYNC910: 0, "exit", Statement("function definition", lineno)
@@ x,6 x,7 @@
await foo()
case _ if foo():
await foo()
+ await trio.lowlevel.checkpoint()


async def match_all_cases() -> None:
@@ x,3 x,4 @@
...
case _ if await foo():
...
+ await trio.lowlevel.checkpoint()
58 changes: 58 additions & 0 deletions tests/eval_files/async103_104_py310.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
"""Test for ASYNC103/ASYNC104 with structural pattern matching

ASYNC103: no-reraise-cancelled
ASYNC104: cancelled-not-raised
"""

# ARG --enable=ASYNC103,ASYNC104


def foo() -> Any: ...


try:
...
except BaseException as e: # ASYNC103_trio: 7, "BaseException"
match foo():
case True:
raise e
case False:
...
case _:
raise e

try:
...
except BaseException: # ASYNC103_trio: 7, "BaseException"
match foo():
case True:
raise

try:
...
except BaseException: # safe
match foo():
case True:
raise
case False:
raise
case _:
raise
try:
...
except BaseException: # ASYNC103_trio: 7, "BaseException"
match foo():
case _ if foo():
raise
try:
...
except BaseException: # ASYNC103_trio: 7, "BaseException"
match foo():
case 1:
return # ASYNC104: 12
case 2:
raise
case 3:
return # ASYNC104: 12
case blah:
raise
Loading
Loading