From 8fe4fc555e7e9d32054695ea6b54671051608763 Mon Sep 17 00:00:00 2001
From: mernmic <92449806+mernmic@users.noreply.github.com>
Date: Mon, 1 Apr 2024 07:31:04 -0500
Subject: [PATCH] Extend RedisSettings to include redis Retry Helper settings
 (#387)

* extend RedisSettings retry settings

* fix type and settings test

* add redis.Retry type

* fix test to allow arbitrary types

* add testing for retry settings

* update tests

* granular patch handling

* update comment

* stop patch when exists

* update retry type to asyncio

* chore: test cleanup

* fix exception type

---------

Co-authored-by: Samuel Colvin
---
 arq/connections.py   |  8 +++++
 tests/conftest.py    | 40 ++++++++++++++++++++++
 tests/test_utils.py  |  4 +--
 tests/test_worker.py | 79 +++++++++++++++++++++++++++++++++++++++++++-
 4 files changed, 128 insertions(+), 3 deletions(-)

diff --git a/arq/connections.py b/arq/connections.py
index ec11b8c7..8aac55ff 100644
--- a/arq/connections.py
+++ b/arq/connections.py
@@ -9,6 +9,7 @@
 from uuid import uuid4
 
 from redis.asyncio import ConnectionPool, Redis
+from redis.asyncio.retry import Retry
 from redis.asyncio.sentinel import Sentinel
 from redis.exceptions import RedisError, WatchError
 
@@ -47,6 +48,10 @@ class RedisSettings:
     sentinel: bool = False
     sentinel_master: str = 'mymaster'
 
+    retry_on_timeout: bool = False
+    retry_on_error: Optional[List[Exception]] = None
+    retry: Optional[Retry] = None
+
     @classmethod
     def from_dsn(cls, dsn: str) -> 'RedisSettings':
         conf = urlparse(dsn)
@@ -254,6 +259,9 @@ def pool_factory(*args: Any, **kwargs: Any) -> ArqRedis:
             ssl_ca_certs=settings.ssl_ca_certs,
             ssl_ca_data=settings.ssl_ca_data,
             ssl_check_hostname=settings.ssl_check_hostname,
+            retry=settings.retry,
+            retry_on_timeout=settings.retry_on_timeout,
+            retry_on_error=settings.retry_on_error,
         )
 
     while True:
diff --git a/tests/conftest.py b/tests/conftest.py
index 3b050be5..b9332eed 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -5,6 +5,9 @@
 
 import msgpack
 import pytest
+import redis.exceptions
+from redis.asyncio.retry import Retry
+from redis.backoff import NoBackoff
 
 from arq.connections import ArqRedis, create_pool
 from arq.worker import Worker
@@ -44,6 +47,21 @@ async def arq_redis_msgpack(loop):
     await redis_.close(close_connection_pool=True)
 
 
+@pytest.fixture
+async def arq_redis_retry(loop):
+    redis_ = ArqRedis(
+        host='localhost',
+        port=6379,
+        encoding='utf-8',
+        retry=Retry(backoff=NoBackoff(), retries=3),
+        retry_on_timeout=True,
+        retry_on_error=[redis.exceptions.ConnectionError],
+    )
+    await redis_.flushall()
+    yield redis_
+    await redis_.close(close_connection_pool=True)
+
+
 @pytest.fixture
 async def worker(arq_redis):
     worker_: Worker = None
@@ -61,6 +79,28 @@ def create(functions=[], burst=True, poll_delay=0, max_jobs=10, arq_redis=arq_re
         await worker_.close()
 
 
+@pytest.fixture
+async def worker_retry(arq_redis_retry):
+    worker_retry_: Worker = None
+
+    def create(functions=[], burst=True, poll_delay=0, max_jobs=10, arq_redis=arq_redis_retry, **kwargs):
+        nonlocal worker_retry_
+        worker_retry_ = Worker(
+            functions=functions,
+            redis_pool=arq_redis,
+            burst=burst,
+            poll_delay=poll_delay,
+            max_jobs=max_jobs,
+            **kwargs,
+        )
+        return worker_retry_
+
+    yield create
+
+    if worker_retry_:
+        await worker_retry_.close()
+
+
 @pytest.fixture(name='create_pool')
 async def fix_create_pool(loop):
     pools = []
diff --git a/tests/test_utils.py b/tests/test_utils.py
index 997c137d..c1156db4 100644
--- a/tests/test_utils.py
+++ b/tests/test_utils.py
@@ -21,7 +21,7 @@ def test_settings_changed():
         "RedisSettings(host='localhost', port=123, unix_socket_path=None, database=0, username=None, password=None, "
         "ssl=False, ssl_keyfile=None, ssl_certfile=None, ssl_cert_reqs='required', ssl_ca_certs=None, "
         'ssl_ca_data=None, ssl_check_hostname=False, conn_timeout=1, conn_retries=5, conn_retry_delay=1, '
-        "sentinel=False, sentinel_master='mymaster')"
+        "sentinel=False, sentinel_master='mymaster', retry_on_timeout=False, retry_on_error=None, retry=None)"
     ) == str(settings)
 
 
@@ -109,7 +109,7 @@ def test_typing():
 
 
 def test_redis_settings_validation():
-    class Settings(BaseModel):
+    class Settings(BaseModel, arbitrary_types_allowed=True):
         redis_settings: RedisSettings
 
         @field_validator('redis_settings', mode='before')
diff --git a/tests/test_worker.py b/tests/test_worker.py
index 192a0d87..a25f0f1d 100644
--- a/tests/test_worker.py
+++ b/tests/test_worker.py
@@ -5,10 +5,11 @@
 import signal
 import sys
 from datetime import datetime, timedelta, timezone
-from unittest.mock import MagicMock
+from unittest.mock import MagicMock, patch
 
 import msgpack
 import pytest
+import redis.exceptions
 
 from arq.connections import ArqRedis, RedisSettings
 from arq.constants import abort_jobs_ss, default_queue_name, expires_extra_ms, health_check_key_suffix, job_key_prefix
@@ -1024,3 +1025,79 @@ async def test_worker_timezone_defaults_to_system_timezone(worker):
     worker = worker(functions=[func(foobar)])
     assert worker.timezone is not None
     assert worker.timezone == datetime.now().astimezone().tzinfo
+
+
+@pytest.mark.parametrize(
+    'exception_thrown',
+    [
+        redis.exceptions.ConnectionError('Error while reading from host'),
+        redis.exceptions.TimeoutError('Timeout reading from host'),
+    ],
+)
+async def test_worker_retry(mocker, worker_retry, exception_thrown):
+    # Testing redis exceptions, with retry settings specified
+    worker = worker_retry(functions=[func(foobar)])
+
+    # patch db read_response to mimic connection exceptions
+    p = patch.object(worker.pool.connection_pool.connection_class, 'read_response', side_effect=exception_thrown)
+
+    # baseline
+    await worker.main()
+    await worker._poll_iteration()
+
+    # spy method handling call_with_retry failure
+    spy = mocker.spy(worker.pool, '_disconnect_raise')
+
+    try:
+        # start patch
+        p.start()
+
+        # assert exception thrown
+        with pytest.raises(type(exception_thrown)):
+            await worker._poll_iteration()
+
+        # assert retry counts and no exception thrown during '_disconnect_raise'
+        assert spy.call_count == 4  # retries setting + 1
+        assert spy.spy_exception is None
+
+    finally:
+        # stop patch to allow worker cleanup
+        p.stop()
+
+
+@pytest.mark.parametrize(
+    'exception_thrown',
+    [
+        redis.exceptions.ConnectionError('Error while reading from host'),
+        redis.exceptions.TimeoutError('Timeout reading from host'),
+    ],
+)
+async def test_worker_crash(mocker, worker, exception_thrown):
+    # Testing redis exceptions, no retry settings specified
+    worker = worker(functions=[func(foobar)])
+
+    # patch db read_response to mimic connection exceptions
+    p = patch.object(worker.pool.connection_pool.connection_class, 'read_response', side_effect=exception_thrown)
+
+    # baseline
+    await worker.main()
+    await worker._poll_iteration()
+
+    # spy method handling call_with_retry failure
+    spy = mocker.spy(worker.pool, '_disconnect_raise')
+
+    try:
+        # start patch
+        p.start()
+
+        # assert exception thrown
+        with pytest.raises(type(exception_thrown)):
+            await worker._poll_iteration()
+
+        # assert no retry counts and exception thrown during '_disconnect_raise'
+        assert spy.call_count == 1
+        assert spy.spy_exception == exception_thrown
+
+    finally:
+        # stop patch to allow worker cleanup
+        p.stop()
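
Usage sketch (illustrative, not part of the patch): with this change, retry behaviour configured on RedisSettings is passed straight through to the underlying redis-py client by create_pool. The backoff strategy, retry count, and the 'foobar' task name below are assumed example values, not defaults introduced by this change.

import asyncio

import redis.exceptions
from redis.asyncio.retry import Retry
from redis.backoff import ExponentialBackoff

from arq.connections import RedisSettings, create_pool

# Assumed example values: retry a failing command up to 3 times with exponential
# backoff, and also retry on timeouts and connection errors.
settings = RedisSettings(
    host='localhost',
    port=6379,
    retry=Retry(backoff=ExponentialBackoff(), retries=3),
    retry_on_timeout=True,
    retry_on_error=[redis.exceptions.ConnectionError],
)


async def main() -> None:
    # create_pool forwards retry, retry_on_timeout and retry_on_error to the redis-py pool
    redis_ = await create_pool(settings)
    await redis_.enqueue_job('foobar')  # 'foobar' is a placeholder task name


if __name__ == '__main__':
    asyncio.run(main())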