diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index af76552..7a726bf 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -42,7 +42,7 @@ repos: hooks: - id: mypy # uses py311 syntax, mypy configured for py39 - exclude: tests/(eval|autofix)_files/.*_py311.py + exclude: tests/(eval|autofix)_files/.*_py(310|311).py - repo: https://github.com/RobertCraigie/pyright-python rev: v1.1.396 diff --git a/docs/changelog.rst b/docs/changelog.rst index 682b5db..0cf1f93 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -4,6 +4,10 @@ Changelog `CalVer, YY.month.patch `_ +25.4.1 +====== +- Add match-case (structural pattern matching) support to ASYNC103, 104, 910, 911 & 912. + 25.3.1 ====== - Add except* support to ASYNC102, 103, 104, 120, 910, 911, 912. diff --git a/docs/usage.rst b/docs/usage.rst index 950a01f..4ba6d70 100644 --- a/docs/usage.rst +++ b/docs/usage.rst @@ -33,7 +33,7 @@ adding the following to your ``.pre-commit-config.yaml``: minimum_pre_commit_version: '2.9.0' repos: - repo: https://github.com/python-trio/flake8-async - rev: 25.3.1 + rev: 25.4.1 hooks: - id: flake8-async # args: ["--enable=ASYNC100,ASYNC112", "--disable=", "--autofix=ASYNC"] diff --git a/flake8_async/__init__.py b/flake8_async/__init__.py index f635c79..a1f5546 100644 --- a/flake8_async/__init__.py +++ b/flake8_async/__init__.py @@ -38,7 +38,7 @@ # CalVer: YY.month.patch, e.g. first release of July 2022 == "22.7.1" -__version__ = "25.3.1" +__version__ = "25.4.1" # taken from https://github.com/Zac-HD/shed diff --git a/flake8_async/visitors/visitor103_104.py b/flake8_async/visitors/visitor103_104.py index 502fa73..3e234e7 100644 --- a/flake8_async/visitors/visitor103_104.py +++ b/flake8_async/visitors/visitor103_104.py @@ -199,6 +199,24 @@ def visit_If(self, node: ast.If): # if body didn't raise, or it's unraised after else, set unraise self.unraised = not body_raised or self.unraised + def visit_Match(self, node: ast.Match): # type: ignore[name-defined] + if not self.unraised: + return + all_cases_raise = True + has_fallback = False + for case in node.cases: + # check for "bare pattern", i.e `case varname:` + has_fallback |= ( + case.guard is None + and isinstance(case.pattern, ast.MatchAs) # type: ignore[attr-defined] + and case.pattern.pattern is None + ) + self.visit_nodes(case.body) + all_cases_raise &= not self.unraised + self.unraised = True + + self.unraised = not (all_cases_raise and has_fallback) + # A loop is guaranteed to raise if: # condition always raises, or # else always raises, and diff --git a/flake8_async/visitors/visitor91x.py b/flake8_async/visitors/visitor91x.py index adee79a..1a54a2e 100644 --- a/flake8_async/visitors/visitor91x.py +++ b/flake8_async/visitors/visitor91x.py @@ -198,6 +198,22 @@ def copy(self): ) +@dataclass +class MatchState: + # TryState, LoopState, and MatchState all do fairly similar things. It would be nice + # to harmonize them and share logic. + base_uncheckpointed_statements: set[Statement] = field(default_factory=set) + case_uncheckpointed_statements: set[Statement] = field(default_factory=set) + has_fallback: bool = False + + def copy(self): + return MatchState( + base_uncheckpointed_statements=self.base_uncheckpointed_statements.copy(), + case_uncheckpointed_statements=self.case_uncheckpointed_statements.copy(), + has_fallback=self.has_fallback, + ) + + def checkpoint_statement(library: str) -> cst.SimpleStatementLine: # logic before this should stop code from wanting to insert the non-existing # asyncio.lowlevel.checkpoint @@ -373,6 +389,7 @@ def __init__(self, *args: Any, **kwargs: Any): self.loop_state = LoopState() self.try_state = TryState() + self.match_state = MatchState() # ASYNC100 self.has_checkpoint_stack: list[bool] = [] @@ -894,6 +911,55 @@ def visit_IfExp(self, node: cst.IfExp) -> bool: self.leave_If(node, node) # type: ignore return False # libcst shouldn't visit subnodes again + def leave_Match_subject(self, node: cst.Match) -> None: + # We start the match logic after parsing the subject, instead of visit_Match, + # since the subject is always executed and might checkpoint. + if not self.async_function: + return + self.save_state(node, "match_state", copy=True) + self.match_state = MatchState( + base_uncheckpointed_statements=self.uncheckpointed_statements.copy() + ) + + def visit_MatchCase(self, node: cst.MatchCase) -> None: + # enter each case from the state after parsing the subject + self.uncheckpointed_statements = self.match_state.base_uncheckpointed_statements + + def leave_MatchCase_guard(self, node: cst.MatchCase) -> None: + # `case _:` is no pattern and no guard, which means we know body is executed. + # But we also know that `case _ if :` is guaranteed to execute the guard, + # so for later logic we can treat them the same *if* there's no pattern and that + # guard checkpoints. + if ( + isinstance(node.pattern, cst.MatchAs) + and node.pattern.pattern is None + and (node.guard is None or not self.uncheckpointed_statements) + ): + self.match_state.has_fallback = True + + def leave_MatchCase( + self, original_node: cst.MatchCase, updated_node: cst.MatchCase + ) -> cst.MatchCase: + # collect the state at the end of each case + self.match_state.case_uncheckpointed_statements.update( + self.uncheckpointed_statements + ) + return updated_node + + def leave_Match( + self, original_node: cst.Match, updated_node: cst.Match + ) -> cst.Match: + # leave the Match with the worst-case of all branches + self.uncheckpointed_statements = self.match_state.case_uncheckpointed_statements + # if no fallback, also add the state at entering the match (after parsing subject) + if not self.match_state.has_fallback: + self.uncheckpointed_statements.update( + self.match_state.base_uncheckpointed_statements + ) + + self.restore_state(original_node) + return updated_node + def visit_While(self, node: cst.While | cst.For): self.save_state( node, diff --git a/flake8_async/visitors/visitor_utility.py b/flake8_async/visitors/visitor_utility.py index 4474f21..1e70785 100644 --- a/flake8_async/visitors/visitor_utility.py +++ b/flake8_async/visitors/visitor_utility.py @@ -17,7 +17,7 @@ from re import Match import libcst as cst - from libcst._position import CodeRange + from libcst.metadata import CodeRange @utility_visitor diff --git a/tests/autofix_files/async91x_py310.py b/tests/autofix_files/async91x_py310.py new file mode 100644 index 0000000..2f691d0 --- /dev/null +++ b/tests/autofix_files/async91x_py310.py @@ -0,0 +1,90 @@ +# ARG --enable=ASYNC910,ASYNC911,ASYNC913 +# AUTOFIX +# ASYNCIO_NO_AUTOFIX +import trio + + +async def foo(): ... + + +async def match_subject() -> None: + match await foo(): + case False: + pass + + +async def match_not_all_cases() -> ( # ASYNC910: 0, "exit", Statement("function definition", lineno) + None +): + match foo(): + case 1: + ... + case _: + await foo() + await trio.lowlevel.checkpoint() + + +async def match_no_fallback() -> ( # ASYNC910: 0, "exit", Statement("function definition", lineno) + None +): + match foo(): + case 1: + await foo() + case 2: + await foo() + case _ if True: + await foo() + await trio.lowlevel.checkpoint() + + +async def match_fallback_is_guarded() -> ( # ASYNC910: 0, "exit", Statement("function definition", lineno) + None +): + match foo(): + case 1: + await foo() + case 2: + await foo() + case _ if foo(): + await foo() + await trio.lowlevel.checkpoint() + + +async def match_all_cases() -> None: + match foo(): + case 1: + await foo() + case 2: + await foo() + case _: + await foo() + + +async def match_fallback_await_in_guard() -> None: + # The case guard is only executed if the pattern matches, so we can mostly treat + # it as part of the body, except for a special case for fallback+checkpointing guard. + match foo(): + case 1 if await foo(): + ... + case _ if await foo(): + ... + + +async def match_checkpoint_guard() -> None: + # The above pattern is quite cursed, but this seems fairly reasonable to do. + match foo(): + case 1 if await foo(): + ... + case _: + await foo() + + +async def match_not_checkpoint_in_all_guards() -> ( # ASYNC910: 0, "exit", Statement("function definition", lineno) + None +): + match foo(): + case 1: + ... + case _ if await foo(): + ... + await trio.lowlevel.checkpoint() diff --git a/tests/autofix_files/async91x_py310.py.diff b/tests/autofix_files/async91x_py310.py.diff new file mode 100644 index 0000000..47c84f3 --- /dev/null +++ b/tests/autofix_files/async91x_py310.py.diff @@ -0,0 +1,31 @@ +--- ++++ +@@ x,6 x,7 @@ + ... + case _: + await foo() ++ await trio.lowlevel.checkpoint() + + + async def match_no_fallback() -> ( # ASYNC910: 0, "exit", Statement("function definition", lineno) +@@ x,6 x,7 @@ + await foo() + case _ if True: + await foo() ++ await trio.lowlevel.checkpoint() + + + async def match_fallback_is_guarded() -> ( # ASYNC910: 0, "exit", Statement("function definition", lineno) +@@ x,6 x,7 @@ + await foo() + case _ if foo(): + await foo() ++ await trio.lowlevel.checkpoint() + + + async def match_all_cases() -> None: +@@ x,3 x,4 @@ + ... + case _ if await foo(): + ... ++ await trio.lowlevel.checkpoint() diff --git a/tests/eval_files/async103_104_py310.py b/tests/eval_files/async103_104_py310.py new file mode 100644 index 0000000..965fe32 --- /dev/null +++ b/tests/eval_files/async103_104_py310.py @@ -0,0 +1,58 @@ +"""Test for ASYNC103/ASYNC104 with structural pattern matching + +ASYNC103: no-reraise-cancelled +ASYNC104: cancelled-not-raised +""" + +# ARG --enable=ASYNC103,ASYNC104 + + +def foo() -> Any: ... + + +try: + ... +except BaseException as e: # ASYNC103_trio: 7, "BaseException" + match foo(): + case True: + raise e + case False: + ... + case _: + raise e + +try: + ... +except BaseException: # ASYNC103_trio: 7, "BaseException" + match foo(): + case True: + raise + +try: + ... +except BaseException: # safe + match foo(): + case True: + raise + case False: + raise + case _: + raise +try: + ... +except BaseException: # ASYNC103_trio: 7, "BaseException" + match foo(): + case _ if foo(): + raise +try: + ... +except BaseException: # ASYNC103_trio: 7, "BaseException" + match foo(): + case 1: + return # ASYNC104: 12 + case 2: + raise + case 3: + return # ASYNC104: 12 + case blah: + raise diff --git a/tests/eval_files/async91x_py310.py b/tests/eval_files/async91x_py310.py new file mode 100644 index 0000000..9367fac --- /dev/null +++ b/tests/eval_files/async91x_py310.py @@ -0,0 +1,86 @@ +# ARG --enable=ASYNC910,ASYNC911,ASYNC913 +# AUTOFIX +# ASYNCIO_NO_AUTOFIX +import trio + + +async def foo(): ... + + +async def match_subject() -> None: + match await foo(): + case False: + pass + + +async def match_not_all_cases() -> ( # ASYNC910: 0, "exit", Statement("function definition", lineno) + None +): + match foo(): + case 1: + ... + case _: + await foo() + + +async def match_no_fallback() -> ( # ASYNC910: 0, "exit", Statement("function definition", lineno) + None +): + match foo(): + case 1: + await foo() + case 2: + await foo() + case _ if True: + await foo() + + +async def match_fallback_is_guarded() -> ( # ASYNC910: 0, "exit", Statement("function definition", lineno) + None +): + match foo(): + case 1: + await foo() + case 2: + await foo() + case _ if foo(): + await foo() + + +async def match_all_cases() -> None: + match foo(): + case 1: + await foo() + case 2: + await foo() + case _: + await foo() + + +async def match_fallback_await_in_guard() -> None: + # The case guard is only executed if the pattern matches, so we can mostly treat + # it as part of the body, except for a special case for fallback+checkpointing guard. + match foo(): + case 1 if await foo(): + ... + case _ if await foo(): + ... + + +async def match_checkpoint_guard() -> None: + # The above pattern is quite cursed, but this seems fairly reasonable to do. + match foo(): + case 1 if await foo(): + ... + case _: + await foo() + + +async def match_not_checkpoint_in_all_guards() -> ( # ASYNC910: 0, "exit", Statement("function definition", lineno) + None +): + match foo(): + case 1: + ... + case _ if await foo(): + ...