diff --git a/README.rst b/README.rst index 68df7e09..d443589c 100644 --- a/README.rst +++ b/README.rst @@ -247,6 +247,12 @@ automatically to *async* test functions. .. |pytestmark| replace:: ``pytestmark`` .. _pytestmark: http://doc.pytest.org/en/latest/example/markers.html#marking-whole-classes-or-modules +Timeout protection +------------------ + +Sometime tests can work much slowly than expected or even hang. + + Note about unittest ------------------- diff --git a/pytest_asyncio/_runner.py b/pytest_asyncio/_runner.py new file mode 100644 index 00000000..114ca8a0 --- /dev/null +++ b/pytest_asyncio/_runner.py @@ -0,0 +1,89 @@ +import asyncio +from typing import Awaitable, TypeVar, Union + +import pytest + +_R = TypeVar("_R") + + +class Runner: + def __init__(self, loop: asyncio.AbstractEventLoop) -> None: + self._loop = loop + self._task = None + self._timeout_hande = None + self._timeout_reached = False + + def run(self, coro: Awaitable[_R]) -> _R: + return self._loop.run_until_complete(self._async_wrapper(coro)) + + def run_test(self, coro: Awaitable[None]) -> None: + task = asyncio.ensure_future(coro, loop=self._loop) + try: + self.run(task) + except BaseException: + # run_until_complete doesn't get the result from exceptions + # that are not subclasses of `Exception`. Consume all + # exceptions to prevent asyncio's warning from logging. + if task.done() and not task.cancelled(): + task.exception() + raise + + def set_timer(self, timeout: Union[int, float]) -> None: + if self._timeout_hande is not None: + self._timeout_hande.cancel() + self._timeout_reached = False + self._timeout_hande = self._loop.call_later(timeout, self._on_timeout) + + def cancel_timer(self) -> None: + if self._timeout_hande is not None: + self._timeout_hande.cancel() + self._timeout_reached = False + self._timeout_hande = None + + async def _async_wrapper(self, coro: Awaitable[_R]) -> _R: + if self._timeout_reached: + # timeout can happen in a gap between tasks execution, + # it should be handled anyway + raise asyncio.TimeoutError() + task = asyncio.current_task() + assert self._task is None + self._task = task + try: + return await coro + except asyncio.CancelledError: + if self._timeout_reached: + raise asyncio.TimeoutError() + finally: + self._task = None + + def _on_timeout(self) -> None: + # the plugin is optional, + # pytest-asyncio should work fine without pytest-timeout + # That's why the lazy import is required here + import pytest_timeout + + if pytest_timeout.is_debugging(): + return + self._timeout_reached = True + if self._task is not None: + self._task.cancel() + + +def _install_runner(item: pytest.Item, loop: asyncio.AbstractEventLoop) -> None: + item._pytest_asyncio_runner = Runner(loop) + + +def _get_runner(item: pytest.Item) -> Runner: + runner = getattr(item, "_pytest_asyncio_runner", None) + if runner is not None: + return runner + else: + parent = item.parent + if parent is not None: + parent_runner = _get_runner(parent) + runner = item._pytest_asyncio_runner = Runner(parent_runner._loop) + return runner + else: # pragma: no cover + # can happen only if the plugin is broken and no event_loop fixture + # dependency was installed. + raise RuntimeError(f"There is no event_loop associated with {item}") diff --git a/pytest_asyncio/plugin.py b/pytest_asyncio/plugin.py index d3922ade..6639c515 100644 --- a/pytest_asyncio/plugin.py +++ b/pytest_asyncio/plugin.py @@ -8,6 +8,7 @@ import sys import warnings from typing import ( + TYPE_CHECKING, Any, AsyncIterator, Awaitable, @@ -20,17 +21,31 @@ Set, TypeVar, Union, - cast, overload, ) import pytest +from pluggy import PluginValidationError + +from ._runner import Runner, _get_runner, _install_runner if sys.version_info >= (3, 8): from typing import Literal else: from typing_extensions import Literal + +if TYPE_CHECKING: + from pytest_timeout import Settings + + +try: + pass + + HAS_TIMEOUT_SUPPORT = True +except ImportError: + HAS_TIMEOUT_SUPPORT = False + _R = TypeVar("_R") _ScopeName = Literal["session", "package", "module", "class", "function"] @@ -279,6 +294,8 @@ def pytest_fixture_setup( if fixturedef.argname == "event_loop": outcome = yield loop = outcome.get_result() + print("\ninstall runner", request.node, id(request.node), id(loop)) + _install_runner(request.node, loop) policy = asyncio.get_event_loop_policy() try: old_loop = policy.get_event_loop() @@ -331,11 +348,10 @@ def pytest_fixture_setup( fixture_stripper.add(FixtureStripper.EVENT_LOOP) def wrapper(*args, **kwargs): - loop = fixture_stripper.get_and_strip_from( - FixtureStripper.EVENT_LOOP, kwargs - ) + fixture_stripper.get_and_strip_from(FixtureStripper.EVENT_LOOP, kwargs) gen_obj = generator(*args, **kwargs) + runner = _get_runner(request.node) async def setup(): res = await gen_obj.__anext__() @@ -354,9 +370,9 @@ async def async_finalizer(): msg += "Yield only once." raise ValueError(msg) - loop.run_until_complete(async_finalizer()) + runner.run(async_finalizer()) - result = loop.run_until_complete(setup()) + result = runner.run(setup()) request.addfinalizer(finalizer) return result @@ -368,15 +384,14 @@ async def async_finalizer(): fixture_stripper.add(FixtureStripper.EVENT_LOOP) def wrapper(*args, **kwargs): - loop = fixture_stripper.get_and_strip_from( - FixtureStripper.EVENT_LOOP, kwargs - ) + fixture_stripper.get_and_strip_from(FixtureStripper.EVENT_LOOP, kwargs) + runner = _get_runner(request.node) async def setup(): res = await coro(*args, **kwargs) return res - return loop.run_until_complete(setup()) + return runner.run(setup()) fixturedef.func = wrapper yield @@ -391,17 +406,16 @@ def pytest_pyfunc_call(pyfuncitem: pytest.Function) -> Optional[object]: where the wrapped test coroutine is executed in an event loop. """ if "asyncio" in pyfuncitem.keywords: - funcargs: Dict[str, object] = pyfuncitem.funcargs # type: ignore[name-defined] - loop = cast(asyncio.AbstractEventLoop, funcargs["event_loop"]) + runner = _get_runner(pyfuncitem) if _is_hypothesis_test(pyfuncitem.obj): pyfuncitem.obj.hypothesis.inner_test = wrap_in_sync( pyfuncitem.obj.hypothesis.inner_test, - _loop=loop, + __runner=runner, ) else: pyfuncitem.obj = wrap_in_sync( pyfuncitem.obj, - _loop=loop, + __runner=runner, ) yield @@ -410,7 +424,7 @@ def _is_hypothesis_test(function: Any) -> bool: return getattr(function, "is_hypothesis_test", False) -def wrap_in_sync(func: Callable[..., Awaitable[Any]], _loop: asyncio.AbstractEventLoop): +def wrap_in_sync(func: Callable[..., Awaitable[Any]], __runner: Runner): """Return a sync wrapper around an async function executing it in the current event loop.""" @@ -425,16 +439,11 @@ def wrap_in_sync(func: Callable[..., Awaitable[Any]], _loop: asyncio.AbstractEve def inner(**kwargs): coro = func(**kwargs) if coro is not None: - task = asyncio.ensure_future(coro, loop=_loop) - try: - _loop.run_until_complete(task) - except BaseException: - # run_until_complete doesn't get the result from exceptions - # that are not subclasses of `Exception`. Consume all - # exceptions to prevent asyncio's warning from logging. - if task.done() and not task.cancelled(): - task.exception() - raise + # FIXME: add a warning if non-async function is marked + # with @pytest.mark.async. + # To automark please use async_mode = auto instead + # Maybe do nothing in legacy mode + __runner.run_test(coro) inner._raw_test_func = func # type: ignore[attr-defined] return inner @@ -459,6 +468,34 @@ def pytest_runtest_setup(item: pytest.Item) -> None: ) +if HAS_TIMEOUT_SUPPORT: + # Install hooks only if pytest-timeout is installed + try: + + @pytest.mark.tryfirst + def pytest_timeout_set_timer( + item: pytest.Item, settings: "Settings" + ) -> Optional[object]: + if item.get_closest_marker("asyncio") is None: + return None + runner = _get_runner(item) + runner.set_timer(settings.timeout) + return True + + @pytest.mark.tryfirst + def pytest_timeout_cancel_timer(item: pytest.Item) -> Optional[object]: + if item.get_closest_marker("asyncio") is None: + return None + runner = _get_runner(item) + runner.cancel_timer() + return True + + except PluginValidationError: # pragma: no cover + raise RuntimeError( + "pytest-asyncio requires pytest-timeout>=2.1.0, please upgrade" + ) + + @pytest.fixture def event_loop(request: "pytest.FixtureRequest") -> Iterator[asyncio.AbstractEventLoop]: """Create an instance of the default event loop for each test case.""" diff --git a/setup.cfg b/setup.cfg index 952a1dbe..c22b6443 100644 --- a/setup.cfg +++ b/setup.cfg @@ -46,6 +46,7 @@ testing = coverage==6.2 hypothesis >= 5.7.1 flaky >= 3.5.0 + pytest-timeout == 2.1.0 mypy == 0.931 [options.entry_points] diff --git a/tests/test_timeout.py b/tests/test_timeout.py new file mode 100644 index 00000000..9754cda7 --- /dev/null +++ b/tests/test_timeout.py @@ -0,0 +1,64 @@ +from textwrap import dedent + +pytest_plugins = "pytester" + + +def test_timeout_ok(pytester): + pytester.makepyfile( + dedent( + """\ + import asyncio + import pytest + + pytest_plugins = ['pytest_asyncio'] + + @pytest.mark.xfail(strict=True, raises=asyncio.TimeoutError) + @pytest.mark.timeout(0.01) + @pytest.mark.asyncio + async def test_a(): + await asyncio.sleep(1) + """ + ) + ) + result = pytester.runpytest("--asyncio-mode=strict") + result.assert_outcomes(xfailed=1) + + +def test_timeout_disabled(pytester): + pytester.makepyfile( + dedent( + """\ + import asyncio + import pytest + + pytest_plugins = ['pytest_asyncio'] + + @pytest.mark.timeout(0) + @pytest.mark.asyncio + async def test_a(): + await asyncio.sleep(0.01) + """ + ) + ) + result = pytester.runpytest("--asyncio-mode=strict") + result.assert_outcomes(passed=1) + + +def test_timeout_cmdline(pytester): + pytester.makepyfile( + dedent( + """\ + import asyncio + import pytest + + pytest_plugins = ['pytest_asyncio'] + + @pytest.mark.asyncio + @pytest.mark.xfail(strict=True, raises=asyncio.TimeoutError) + async def test_a(): + await asyncio.sleep(1) + """ + ) + ) + result = pytester.runpytest("--timeout=0.01", "--asyncio-mode=strict") + result.assert_outcomes(xfailed=1)