diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index af76552..7a726bf 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -42,7 +42,7 @@ repos:
hooks:
- id: mypy
# uses py311 syntax, mypy configured for py39
- exclude: tests/(eval|autofix)_files/.*_py311.py
+ exclude: tests/(eval|autofix)_files/.*_py(310|311).py
- repo: https://github.com/RobertCraigie/pyright-python
rev: v1.1.396
diff --git a/docs/changelog.rst b/docs/changelog.rst
index 682b5db..0cf1f93 100644
--- a/docs/changelog.rst
+++ b/docs/changelog.rst
@@ -4,6 +4,10 @@ Changelog
`CalVer, YY.month.patch `_
+25.4.1
+======
+- Add match-case (structural pattern matching) support to ASYNC103, 104, 910, 911 & 912.
+
25.3.1
======
- Add except* support to ASYNC102, 103, 104, 120, 910, 911, 912.
diff --git a/docs/usage.rst b/docs/usage.rst
index 950a01f..4ba6d70 100644
--- a/docs/usage.rst
+++ b/docs/usage.rst
@@ -33,7 +33,7 @@ adding the following to your ``.pre-commit-config.yaml``:
minimum_pre_commit_version: '2.9.0'
repos:
- repo: https://github.com/python-trio/flake8-async
- rev: 25.3.1
+ rev: 25.4.1
hooks:
- id: flake8-async
# args: ["--enable=ASYNC100,ASYNC112", "--disable=", "--autofix=ASYNC"]
diff --git a/flake8_async/__init__.py b/flake8_async/__init__.py
index f635c79..a1f5546 100644
--- a/flake8_async/__init__.py
+++ b/flake8_async/__init__.py
@@ -38,7 +38,7 @@
# CalVer: YY.month.patch, e.g. first release of July 2022 == "22.7.1"
-__version__ = "25.3.1"
+__version__ = "25.4.1"
# taken from https://github.com/Zac-HD/shed
diff --git a/flake8_async/visitors/visitor103_104.py b/flake8_async/visitors/visitor103_104.py
index 502fa73..3e234e7 100644
--- a/flake8_async/visitors/visitor103_104.py
+++ b/flake8_async/visitors/visitor103_104.py
@@ -199,6 +199,24 @@ def visit_If(self, node: ast.If):
# if body didn't raise, or it's unraised after else, set unraise
self.unraised = not body_raised or self.unraised
+ def visit_Match(self, node: ast.Match): # type: ignore[name-defined]
+ if not self.unraised:
+ return
+ all_cases_raise = True
+ has_fallback = False
+ for case in node.cases:
+ # check for "bare pattern", i.e `case varname:`
+ has_fallback |= (
+ case.guard is None
+ and isinstance(case.pattern, ast.MatchAs) # type: ignore[attr-defined]
+ and case.pattern.pattern is None
+ )
+ self.visit_nodes(case.body)
+ all_cases_raise &= not self.unraised
+ self.unraised = True
+
+ self.unraised = not (all_cases_raise and has_fallback)
+
# A loop is guaranteed to raise if:
# condition always raises, or
# else always raises, and
diff --git a/flake8_async/visitors/visitor91x.py b/flake8_async/visitors/visitor91x.py
index adee79a..1a54a2e 100644
--- a/flake8_async/visitors/visitor91x.py
+++ b/flake8_async/visitors/visitor91x.py
@@ -198,6 +198,22 @@ def copy(self):
)
+@dataclass
+class MatchState:
+ # TryState, LoopState, and MatchState all do fairly similar things. It would be nice
+ # to harmonize them and share logic.
+ base_uncheckpointed_statements: set[Statement] = field(default_factory=set)
+ case_uncheckpointed_statements: set[Statement] = field(default_factory=set)
+ has_fallback: bool = False
+
+ def copy(self):
+ return MatchState(
+ base_uncheckpointed_statements=self.base_uncheckpointed_statements.copy(),
+ case_uncheckpointed_statements=self.case_uncheckpointed_statements.copy(),
+ has_fallback=self.has_fallback,
+ )
+
+
def checkpoint_statement(library: str) -> cst.SimpleStatementLine:
# logic before this should stop code from wanting to insert the non-existing
# asyncio.lowlevel.checkpoint
@@ -373,6 +389,7 @@ def __init__(self, *args: Any, **kwargs: Any):
self.loop_state = LoopState()
self.try_state = TryState()
+ self.match_state = MatchState()
# ASYNC100
self.has_checkpoint_stack: list[bool] = []
@@ -894,6 +911,55 @@ def visit_IfExp(self, node: cst.IfExp) -> bool:
self.leave_If(node, node) # type: ignore
return False # libcst shouldn't visit subnodes again
+ def leave_Match_subject(self, node: cst.Match) -> None:
+ # We start the match logic after parsing the subject, instead of visit_Match,
+ # since the subject is always executed and might checkpoint.
+ if not self.async_function:
+ return
+ self.save_state(node, "match_state", copy=True)
+ self.match_state = MatchState(
+ base_uncheckpointed_statements=self.uncheckpointed_statements.copy()
+ )
+
+ def visit_MatchCase(self, node: cst.MatchCase) -> None:
+ # enter each case from the state after parsing the subject
+ self.uncheckpointed_statements = self.match_state.base_uncheckpointed_statements
+
+ def leave_MatchCase_guard(self, node: cst.MatchCase) -> None:
+ # `case _:` is no pattern and no guard, which means we know body is executed.
+ # But we also know that `case _ if :` is guaranteed to execute the guard,
+ # so for later logic we can treat them the same *if* there's no pattern and that
+ # guard checkpoints.
+ if (
+ isinstance(node.pattern, cst.MatchAs)
+ and node.pattern.pattern is None
+ and (node.guard is None or not self.uncheckpointed_statements)
+ ):
+ self.match_state.has_fallback = True
+
+ def leave_MatchCase(
+ self, original_node: cst.MatchCase, updated_node: cst.MatchCase
+ ) -> cst.MatchCase:
+ # collect the state at the end of each case
+ self.match_state.case_uncheckpointed_statements.update(
+ self.uncheckpointed_statements
+ )
+ return updated_node
+
+ def leave_Match(
+ self, original_node: cst.Match, updated_node: cst.Match
+ ) -> cst.Match:
+ # leave the Match with the worst-case of all branches
+ self.uncheckpointed_statements = self.match_state.case_uncheckpointed_statements
+ # if no fallback, also add the state at entering the match (after parsing subject)
+ if not self.match_state.has_fallback:
+ self.uncheckpointed_statements.update(
+ self.match_state.base_uncheckpointed_statements
+ )
+
+ self.restore_state(original_node)
+ return updated_node
+
def visit_While(self, node: cst.While | cst.For):
self.save_state(
node,
diff --git a/flake8_async/visitors/visitor_utility.py b/flake8_async/visitors/visitor_utility.py
index 4474f21..1e70785 100644
--- a/flake8_async/visitors/visitor_utility.py
+++ b/flake8_async/visitors/visitor_utility.py
@@ -17,7 +17,7 @@
from re import Match
import libcst as cst
- from libcst._position import CodeRange
+ from libcst.metadata import CodeRange
@utility_visitor
diff --git a/tests/autofix_files/async91x_py310.py b/tests/autofix_files/async91x_py310.py
new file mode 100644
index 0000000..2f691d0
--- /dev/null
+++ b/tests/autofix_files/async91x_py310.py
@@ -0,0 +1,90 @@
+# ARG --enable=ASYNC910,ASYNC911,ASYNC913
+# AUTOFIX
+# ASYNCIO_NO_AUTOFIX
+import trio
+
+
+async def foo(): ...
+
+
+async def match_subject() -> None:
+ match await foo():
+ case False:
+ pass
+
+
+async def match_not_all_cases() -> ( # ASYNC910: 0, "exit", Statement("function definition", lineno)
+ None
+):
+ match foo():
+ case 1:
+ ...
+ case _:
+ await foo()
+ await trio.lowlevel.checkpoint()
+
+
+async def match_no_fallback() -> ( # ASYNC910: 0, "exit", Statement("function definition", lineno)
+ None
+):
+ match foo():
+ case 1:
+ await foo()
+ case 2:
+ await foo()
+ case _ if True:
+ await foo()
+ await trio.lowlevel.checkpoint()
+
+
+async def match_fallback_is_guarded() -> ( # ASYNC910: 0, "exit", Statement("function definition", lineno)
+ None
+):
+ match foo():
+ case 1:
+ await foo()
+ case 2:
+ await foo()
+ case _ if foo():
+ await foo()
+ await trio.lowlevel.checkpoint()
+
+
+async def match_all_cases() -> None:
+ match foo():
+ case 1:
+ await foo()
+ case 2:
+ await foo()
+ case _:
+ await foo()
+
+
+async def match_fallback_await_in_guard() -> None:
+ # The case guard is only executed if the pattern matches, so we can mostly treat
+ # it as part of the body, except for a special case for fallback+checkpointing guard.
+ match foo():
+ case 1 if await foo():
+ ...
+ case _ if await foo():
+ ...
+
+
+async def match_checkpoint_guard() -> None:
+ # The above pattern is quite cursed, but this seems fairly reasonable to do.
+ match foo():
+ case 1 if await foo():
+ ...
+ case _:
+ await foo()
+
+
+async def match_not_checkpoint_in_all_guards() -> ( # ASYNC910: 0, "exit", Statement("function definition", lineno)
+ None
+):
+ match foo():
+ case 1:
+ ...
+ case _ if await foo():
+ ...
+ await trio.lowlevel.checkpoint()
diff --git a/tests/autofix_files/async91x_py310.py.diff b/tests/autofix_files/async91x_py310.py.diff
new file mode 100644
index 0000000..47c84f3
--- /dev/null
+++ b/tests/autofix_files/async91x_py310.py.diff
@@ -0,0 +1,31 @@
+---
++++
+@@ x,6 x,7 @@
+ ...
+ case _:
+ await foo()
++ await trio.lowlevel.checkpoint()
+
+
+ async def match_no_fallback() -> ( # ASYNC910: 0, "exit", Statement("function definition", lineno)
+@@ x,6 x,7 @@
+ await foo()
+ case _ if True:
+ await foo()
++ await trio.lowlevel.checkpoint()
+
+
+ async def match_fallback_is_guarded() -> ( # ASYNC910: 0, "exit", Statement("function definition", lineno)
+@@ x,6 x,7 @@
+ await foo()
+ case _ if foo():
+ await foo()
++ await trio.lowlevel.checkpoint()
+
+
+ async def match_all_cases() -> None:
+@@ x,3 x,4 @@
+ ...
+ case _ if await foo():
+ ...
++ await trio.lowlevel.checkpoint()
diff --git a/tests/eval_files/async103_104_py310.py b/tests/eval_files/async103_104_py310.py
new file mode 100644
index 0000000..965fe32
--- /dev/null
+++ b/tests/eval_files/async103_104_py310.py
@@ -0,0 +1,58 @@
+"""Test for ASYNC103/ASYNC104 with structural pattern matching
+
+ASYNC103: no-reraise-cancelled
+ASYNC104: cancelled-not-raised
+"""
+
+# ARG --enable=ASYNC103,ASYNC104
+
+
+def foo() -> Any: ...
+
+
+try:
+ ...
+except BaseException as e: # ASYNC103_trio: 7, "BaseException"
+ match foo():
+ case True:
+ raise e
+ case False:
+ ...
+ case _:
+ raise e
+
+try:
+ ...
+except BaseException: # ASYNC103_trio: 7, "BaseException"
+ match foo():
+ case True:
+ raise
+
+try:
+ ...
+except BaseException: # safe
+ match foo():
+ case True:
+ raise
+ case False:
+ raise
+ case _:
+ raise
+try:
+ ...
+except BaseException: # ASYNC103_trio: 7, "BaseException"
+ match foo():
+ case _ if foo():
+ raise
+try:
+ ...
+except BaseException: # ASYNC103_trio: 7, "BaseException"
+ match foo():
+ case 1:
+ return # ASYNC104: 12
+ case 2:
+ raise
+ case 3:
+ return # ASYNC104: 12
+ case blah:
+ raise
diff --git a/tests/eval_files/async91x_py310.py b/tests/eval_files/async91x_py310.py
new file mode 100644
index 0000000..9367fac
--- /dev/null
+++ b/tests/eval_files/async91x_py310.py
@@ -0,0 +1,86 @@
+# ARG --enable=ASYNC910,ASYNC911,ASYNC913
+# AUTOFIX
+# ASYNCIO_NO_AUTOFIX
+import trio
+
+
+async def foo(): ...
+
+
+async def match_subject() -> None:
+ match await foo():
+ case False:
+ pass
+
+
+async def match_not_all_cases() -> ( # ASYNC910: 0, "exit", Statement("function definition", lineno)
+ None
+):
+ match foo():
+ case 1:
+ ...
+ case _:
+ await foo()
+
+
+async def match_no_fallback() -> ( # ASYNC910: 0, "exit", Statement("function definition", lineno)
+ None
+):
+ match foo():
+ case 1:
+ await foo()
+ case 2:
+ await foo()
+ case _ if True:
+ await foo()
+
+
+async def match_fallback_is_guarded() -> ( # ASYNC910: 0, "exit", Statement("function definition", lineno)
+ None
+):
+ match foo():
+ case 1:
+ await foo()
+ case 2:
+ await foo()
+ case _ if foo():
+ await foo()
+
+
+async def match_all_cases() -> None:
+ match foo():
+ case 1:
+ await foo()
+ case 2:
+ await foo()
+ case _:
+ await foo()
+
+
+async def match_fallback_await_in_guard() -> None:
+ # The case guard is only executed if the pattern matches, so we can mostly treat
+ # it as part of the body, except for a special case for fallback+checkpointing guard.
+ match foo():
+ case 1 if await foo():
+ ...
+ case _ if await foo():
+ ...
+
+
+async def match_checkpoint_guard() -> None:
+ # The above pattern is quite cursed, but this seems fairly reasonable to do.
+ match foo():
+ case 1 if await foo():
+ ...
+ case _:
+ await foo()
+
+
+async def match_not_checkpoint_in_all_guards() -> ( # ASYNC910: 0, "exit", Statement("function definition", lineno)
+ None
+):
+ match foo():
+ case 1:
+ ...
+ case _ if await foo():
+ ...