diff --git a/.bumpversion.cfg b/.bumpversion.cfg index e39827d..e9785bf 100644 --- a/.bumpversion.cfg +++ b/.bumpversion.cfg @@ -1,5 +1,5 @@ [bumpversion] -current_version = 0.7.3 +current_version = 0.8.0 commit = True tag = True diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 7b2a23c..0925da7 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -9,9 +9,10 @@ jobs: strategy: matrix: python-version: - - 3.7 - - 3.8 - 3.9 + - "3.10" + - "3.11" + - "3.12" services: redis: diff --git a/aioredlock/__init__.py b/aioredlock/__init__.py index 843671f..2d3b8f9 100644 --- a/aioredlock/__init__.py +++ b/aioredlock/__init__.py @@ -4,10 +4,10 @@ from aioredlock.sentinel import Sentinel __all__ = ( - 'Aioredlock', - 'Lock', - 'LockError', - 'LockAcquiringError', - 'LockRuntimeError', - 'Sentinel' + "Aioredlock", + "Lock", + "LockError", + "LockAcquiringError", + "LockRuntimeError", + "Sentinel", ) diff --git a/aioredlock/algorithm.py b/aioredlock/algorithm.py index b9ae01c..083dbaa 100644 --- a/aioredlock/algorithm.py +++ b/aioredlock/algorithm.py @@ -57,8 +57,7 @@ def log(self): return logging.getLogger(__name__) async def _set_lock(self, resource, lock_identifier, lease_time): - - error = RuntimeError('Retry count less then one') + error = RuntimeError("Retry count less then one") # Proportional drift time to the length of the lock # See https://redis.io/topics/distlock#is-the-algorithm-asynchronous for more info @@ -67,22 +66,23 @@ async def _set_lock(self, resource, lock_identifier, lease_time): try: # global try/except to catch CancelledError for n in range(self.retry_count): - self.log.debug('Acquiring lock "%s" try %d/%d', - resource, n + 1, self.retry_count) + self.log.debug( + 'Acquiring lock "%s" try %d/%d', resource, n + 1, self.retry_count + ) if n != 0: - delay = random.uniform(self.retry_delay_min, - self.retry_delay_max) + delay = random.uniform(self.retry_delay_min, self.retry_delay_max) await 
asyncio.sleep(delay) try: - elapsed_time = await self.redis.set_lock(resource, lock_identifier, lease_time) + elapsed_time = await self.redis.set_lock( + resource, lock_identifier, lease_time + ) except LockError as exc: error = exc continue if lease_time - elapsed_time - drift <= 0: - error = LockError('Lock timeout') - self.log.debug('Timeout in acquiring the lock "%s"', - resource) + error = LockError("Lock timeout") + self.log.debug('Timeout in acquiring the lock "%s"', resource) continue error = None @@ -114,8 +114,7 @@ async def _auto_extend(self, lock): try: await self.extend(lock) except Exception: - self.log.debug('Error in extending the lock "%s"', - lock.resource) + self.log.debug('Error in extending the lock "%s"', lock.resource) self._watchdogs[lock.resource] = asyncio.ensure_future(self._auto_extend(lock)) @@ -143,7 +142,9 @@ async def lock(self, resource, lock_timeout=None, lock_identifier=None): lock = Lock(self, resource, lock_identifier, lock_timeout, valid=True) if lock_timeout is None: - self._watchdogs[lock.resource] = asyncio.ensure_future(self._auto_extend(lock)) + self._watchdogs[lock.resource] = asyncio.ensure_future( + self._auto_extend(lock) + ) self._locks[resource] = lock return lock @@ -161,7 +162,7 @@ async def extend(self, lock, lock_timeout=None): self.log.debug('Extending lock "%s"', lock.resource) if not lock.valid: - raise RuntimeError('Lock is not valid') + raise RuntimeError("Lock is not valid") if lock_timeout is not None and lock_timeout <= 0: raise ValueError("Lock timeout must be greater than 0 seconds.") @@ -218,8 +219,9 @@ async def is_locked(self, resource_or_lock): resource = resource_or_lock else: raise TypeError( - 'Argument should be ether aioredlock.Lock instance or string, ' - '%s is given.', type(resource_or_lock) + "Argument should be ether aioredlock.Lock instance or string, " + "%s is given.", + type(resource_or_lock), ) return await self.redis.is_locked(resource) @@ -228,7 +230,7 @@ async def destroy(self): 
""" cancel all _watchdogs, unlock all locks and Clear all the redis connections """ - self.log.debug('Destroying %s', repr(self)) + self.log.debug("Destroying %s", self) for resource, lock in self._locks.copy().items(): if lock.valid: diff --git a/aioredlock/lock.py b/aioredlock/lock.py index 58f5b09..8c3d811 100644 --- a/aioredlock/lock.py +++ b/aioredlock/lock.py @@ -3,7 +3,6 @@ @attr.s class Lock: - lock_manager = attr.ib() resource = attr.ib() id = attr.ib() diff --git a/aioredlock/redis.py b/aioredlock/redis.py index 27958c3..5d23460 100644 --- a/aioredlock/redis.py +++ b/aioredlock/redis.py @@ -2,10 +2,12 @@ import logging import re import time -from distutils.version import StrictVersion +import functools from itertools import groupby +from typing import Optional +from redis.asyncio import ConnectionPool, Redis as AIORedis -import aioredis +from redis.exceptions import RedisError, ResponseError from aioredlock.errors import LockError, LockAcquiringError, LockRuntimeError from aioredlock.sentinel import Sentinel @@ -29,7 +31,6 @@ def raise_error(results, default_message): class Instance: - # KEYS[1] - lock resource key # ARGS[1] - lock unique identifier # ARGS[2] - expiration time in milliseconds @@ -64,28 +65,39 @@ class Instance: return redis.error_reply('ERROR') end""" + @staticmethod + def ensure_connection(func): + @functools.wraps(func) + async def wrapper(self, *args, **kwargs): + # Ensure connection is established before proceeding + if self._client is None: + await self.connect() + # Call the original function + return await func(self, *args, **kwargs) + + return wrapper + def __init__(self, connection): """ Redis instance constructor - Constructor takes single argument - a redis host address The address can be one of the following: - * a dict - {'host': 'localhost', 'port': 6379, - 'db': 0, 'password': 'pass'} - all keys except host and port will be passed as kwargs to - the aioredis.create_redis_pool(); - * an aioredlock.redis.Sentinel object; + * 
a dict - { + 'host': 'localhost', + 'port': 6379, + 'db': 0, + 'password': 'pass' + } + in this case redis.asyncio.Redis will be used; * a Redis URI - "redis://host:6379/0?encoding=utf-8"; - * a (host, port) tuple - ('localhost', 6379); - * or a unix domain socket path string - "/path/to/redis.sock". + * or a unix domain socket path string - "unix://path/to/redis.sock". * a redis connection pool. - :param connection: redis host address (dict, tuple or str) """ self.connection = connection - self._pool = None + self._client: Optional[AIORedis] = None self._lock = asyncio.Lock() self.set_lock_script_sha1 = None @@ -101,86 +113,84 @@ def __repr__(self): return "<%s(connection='%s'>" % (self.__class__.__name__, connection_details) @staticmethod - async def _create_redis_pool(*args, **kwargs): - """ - Adapter to support both aioredis-0.3.0 and aioredis-1.0.0 - For aioredis-1.0.0 and later calls: - aioredis.create_redis_pool(*args, **kwargs) - For aioredis-0.3.0 calls: - aioredis.create_pool(*args, **kwargs) - """ + async def _create_redis(*args, **kwargs) -> AIORedis: + if args[0] is None: + return AIORedis(**kwargs) - if StrictVersion(aioredis.__version__) >= StrictVersion('1.0.0'): # pragma no cover - return await aioredis.create_redis_pool(*args, **kwargs) - else: # pragma no cover - return await aioredis.create_pool(*args, **kwargs) + return AIORedis.from_url(*args, **kwargs) - async def _register_scripts(self, redis): + async def _register_scripts(self): tasks = [] for script in [ - self.SET_LOCK_SCRIPT, - self.UNSET_LOCK_SCRIPT, - self.GET_LOCK_TTL_SCRIPT, + self.SET_LOCK_SCRIPT, + self.UNSET_LOCK_SCRIPT, + self.GET_LOCK_TTL_SCRIPT, ]: - script = re.sub(r'^\s+', '', script, flags=re.M).strip() - tasks.append(redis.script_load(script)) + script = re.sub(r"^\s+", "", script, flags=re.M).strip() + tasks.append(self._client.script_load(script)) + results = await asyncio.gather(*tasks) ( self.set_lock_script_sha1, self.unset_lock_script_sha1, 
self.get_lock_ttl_script_sha1, - ) = (r.decode() if isinstance(r, bytes) else r for r in await asyncio.gather(*tasks)) + ) = (r if isinstance(r, str) else r.decode("utf-8") for r in results) - async def connect(self): + async def connect(self) -> AIORedis: """ - Get an connection for the self instance + Get a connection for the self instance """ - address, redis_kwargs = (), {} + address, redis_kwargs = None, {} if isinstance(self.connection, Sentinel): - self._pool = await self.connection.get_master() + self._client = self.connection.get_master() elif isinstance(self.connection, dict): - # a dict like {'host': 'localhost', 'port': 6379, - # 'db': 0, 'password': 'pass'} - kwargs = self.connection.copy() - address = ( - kwargs.pop('host', 'localhost'), - kwargs.pop('port', 6379) - ) - redis_kwargs = kwargs - elif isinstance(self.connection, aioredis.Redis): - self._pool = self.connection + assert "host" in self.connection, "Host is not specified" + redis_kwargs = self.connection + elif isinstance(self.connection, ConnectionPool): + conn_kwargs = self.connection.connection_kwargs + url = f"redis://{conn_kwargs['host']}:{conn_kwargs['port']}/{conn_kwargs['db']}?encoding=utf-8" + self._client = AIORedis.from_url(url) + elif isinstance(self.connection, tuple): + # a tuple ('localhost', 6379, 0, 'pass'), db and password are optional + redis_kwargs = { + "host": self.connection[0], + "port": self.connection[1], + "db": self.connection[2] if len(self.connection) > 2 else 0, + "password": self.connection[3] if len(self.connection) > 3 else None, + } else: - # a tuple or list ('localhost', 6379) # a string "redis://host:6379/0?encoding=utf-8" or - # a unix domain socket path "/path/to/redis.sock" + # a unix domain socket path "unix:///path/to/redis.sock" address = self.connection - if self._pool is None: - if 'minsize' not in redis_kwargs: - redis_kwargs['minsize'] = 1 - if 'maxsize' not in redis_kwargs: - redis_kwargs['maxsize'] = 100 + if self._client is None: + 
redis_kwargs["max_connections"] = redis_kwargs.get("max_connections", 100) async with self._lock: - if self._pool is None: - self.log.debug('Connecting %s', repr(self)) - self._pool = await self._create_redis_pool(address, **redis_kwargs) + if self._client is None: + self.log.debug("Connecting %s", repr(self)) + self._client = await self._create_redis(address, **redis_kwargs) if self.set_lock_script_sha1 is None or self.unset_lock_script_sha1 is None: - with await self._pool as redis: - await self._register_scripts(redis) + await self._register_scripts() - return await self._pool + return self._client - async def close(self): + async def aclose(self): """ Closes connection and resets pool """ - if self._pool is not None and not isinstance(self.connection, aioredis.Redis): - self._pool.close() - await self._pool.wait_closed() - self._pool = None - - async def set_lock(self, resource, lock_identifier, lock_timeout, register_scripts=False): + if self._client is not None and not isinstance(self.connection, AIORedis): + try: + await self._client.aclose() + except AttributeError: + await self._client.close() + + self._client = None + + @ensure_connection + async def set_lock( + self, resource, lock_identifier, lock_timeout, register_scripts=False + ): """ Lock this instance and set lock expiration time to lock_timeout :param resource: redis key to set @@ -192,35 +202,33 @@ async def set_lock(self, resource, lock_identifier, lock_timeout, register_scrip lock_timeout_ms = int(lock_timeout * 1000) try: - with await self.connect() as redis: - if register_scripts is True: - await self._register_scripts(redis) - await redis.evalsha( - self.set_lock_script_sha1, - keys=[resource], - args=[lock_identifier, lock_timeout_ms] + if register_scripts is True: + await self._register_scripts() + await self._client.evalsha( + self.set_lock_script_sha1, 1, resource, lock_identifier, lock_timeout_ms + ) + except ResponseError as exc: # script fault + if exc.args[0].startswith("NOSCRIPT"): 
+ return await self.set_lock( + resource, lock_identifier, lock_timeout, register_scripts=True ) - except aioredis.errors.ReplyError as exc: # script fault - if exc.args[0].startswith('NOSCRIPT'): - return await self.set_lock(resource, lock_identifier, lock_timeout, register_scripts=True) - self.log.debug('Can not set lock "%s" on %s', - resource, repr(self)) - raise LockAcquiringError('Can not set lock') from exc - except (aioredis.errors.RedisError, OSError) as exc: - self.log.error('Can not set lock "%s" on %s: %s', - resource, repr(self), repr(exc)) - raise LockRuntimeError('Can not set lock') from exc + self.log.debug('Can not set lock "%s" on %s', resource, repr(self)) + raise LockAcquiringError("Can not set lock") from exc + except (RedisError, OSError) as exc: + self.log.error( + 'Can not set lock "%s" on %s: %s', resource, repr(self), repr(exc) + ) + raise LockRuntimeError("Can not set lock") from exc except asyncio.CancelledError: - self.log.debug('Lock "%s" is cancelled on %s', - resource, repr(self)) + self.log.debug('Lock "%s" is cancelled on %s', resource, repr(self)) raise except Exception: - self.log.exception('Can not set lock "%s" on %s', - resource, repr(self)) + self.log.exception('Can not set lock "%s" on %s', resource, repr(self)) raise else: self.log.debug('Lock "%s" is set on %s', resource, repr(self)) + @ensure_connection async def get_lock_ttl(self, resource, lock_identifier, register_scripts=False): """ Fetch this instance and set lock expiration time to lock_timeout @@ -230,36 +238,35 @@ async def get_lock_ttl(self, resource, lock_identifier, register_scripts=False): :raises: LockError if lock is not available """ try: - with await self.connect() as redis: - if register_scripts is True: - await self._register_scripts(redis) - ttl = await redis.evalsha( - self.get_lock_ttl_script_sha1, - keys=[resource], - args=[lock_identifier] + if register_scripts is True: + await self._register_scripts() + + ttl = await self._client.evalsha( + 
self.get_lock_ttl_script_sha1, 1, resource, lock_identifier + ) + except ResponseError as exc: # script fault + if exc.args[0].startswith("NOSCRIPT"): + return await self.get_lock_ttl( + resource, lock_identifier, register_scripts=True ) - except aioredis.errors.ReplyError as exc: # script fault - if exc.args[0].startswith('NOSCRIPT'): - return await self.get_lock_ttl(resource, lock_identifier, register_scripts=True) - self.log.debug('Can not get lock "%s" on %s', - resource, repr(self)) - raise LockAcquiringError('Can not get lock') from exc - except (aioredis.errors.RedisError, OSError) as exc: - self.log.error('Can not get lock "%s" on %s: %s', - resource, repr(self), repr(exc)) - raise LockRuntimeError('Can not get lock') from exc + self.log.debug('Can not get lock "%s" on %s', resource, repr(self)) + raise LockAcquiringError("Can not get lock") from exc + except (RedisError, OSError) as exc: + self.log.error( + 'Can not get lock "%s" on %s: %s', resource, repr(self), repr(exc) + ) + raise LockRuntimeError("Can not get lock") from exc except asyncio.CancelledError: - self.log.debug('Lock "%s" is cancelled on %s', - resource, repr(self)) + self.log.debug('Lock "%s" is cancelled on %s', resource, repr(self)) raise except Exception: - self.log.exception('Can not get lock "%s" on %s', - resource, repr(self)) + self.log.exception('Can not get lock "%s" on %s', resource, repr(self)) raise else: self.log.debug('Lock "%s" with TTL %s is on %s', resource, ttl, repr(self)) return ttl + @ensure_connection async def unset_lock(self, resource, lock_identifier, register_scripts=False): """ Unlock this instance @@ -268,35 +275,33 @@ async def unset_lock(self, resource, lock_identifier, register_scripts=False): :raises: LockError if the lock resource acquired with different lock_identifier """ try: - with await self.connect() as redis: - if register_scripts is True: - await self._register_scripts(redis) - await redis.evalsha( - self.unset_lock_script_sha1, - keys=[resource], - 
args=[lock_identifier] + if register_scripts is True: + await self._register_scripts() + await self._client.evalsha( + self.unset_lock_script_sha1, 1, resource, lock_identifier + ) + except ResponseError as exc: # script fault + if exc.args[0].startswith("NOSCRIPT"): + return await self.unset_lock( + resource, lock_identifier, register_scripts=True ) - except aioredis.errors.ReplyError as exc: # script fault - if exc.args[0].startswith('NOSCRIPT'): - return await self.unset_lock(resource, lock_identifier, register_scripts=True) - self.log.debug('Can not unset lock "%s" on %s', - resource, repr(self)) - raise LockAcquiringError('Can not unset lock') from exc - except (aioredis.errors.RedisError, OSError) as exc: - self.log.error('Can not unset lock "%s" on %s: %s', - resource, repr(self), repr(exc)) - raise LockRuntimeError('Can not unset lock') from exc + self.log.debug('Can not unset lock "%s" on %s', resource, repr(self)) + raise LockAcquiringError("Can not unset lock") from exc + except (RedisError, OSError) as exc: + self.log.error( + 'Can not unset lock "%s" on %s: %s', resource, repr(self), repr(exc) + ) + raise LockRuntimeError("Can not unset lock") from exc except asyncio.CancelledError: - self.log.debug('Lock "%s" unset is cancelled on %s', - resource, repr(self)) + self.log.debug('Lock "%s" unset is cancelled on %s', resource, repr(self)) raise except Exception: - self.log.exception('Can not unset lock "%s" on %s', - resource, repr(self)) + self.log.exception('Can not unset lock "%s" on %s', resource, repr(self)) raise else: self.log.debug('Lock "%s" is unset on %s', resource, repr(self)) + @ensure_connection async def is_locked(self, resource): """ Checks if the resource is locked by any redlock instance. 
@@ -305,8 +310,7 @@ async def is_locked(self, resource): :returns: True if locked else False """ - with await self.connect() as redis: - lock_identifier = await redis.get(resource) + lock_identifier = await self._client.get(resource) if lock_identifier: return True else: @@ -314,9 +318,7 @@ async def is_locked(self, resource): class Redis: - def __init__(self, redis_connections): - self.instances = [] for connection in redis_connections: self.instances.append(Instance(connection)) @@ -339,17 +341,25 @@ async def set_lock(self, resource, lock_identifier, lock_timeout=10.0): """ start_time = time.monotonic() - successes = await asyncio.gather(*[ - i.set_lock(resource, lock_identifier, lock_timeout) for - i in self.instances - ], return_exceptions=True) + successes = await asyncio.gather( + *[ + i.set_lock(resource, lock_identifier, lock_timeout) + for i in self.instances + ], + return_exceptions=True, + ) successful_sets = sum(s is None for s in successes) elapsed_time = time.monotonic() - start_time locked = successful_sets >= int(len(self.instances) / 2) + 1 - self.log.debug('Lock "%s" is set on %d/%d instances in %s seconds', - resource, successful_sets, len(self.instances), elapsed_time) + self.log.debug( + 'Lock "%s" is set on %d/%d instances in %s seconds', + resource, + successful_sets, + len(self.instances), + elapsed_time, + ) if not locked: raise_error(successes, 'Can not acquire the lock "%s"' % resource) @@ -367,18 +377,23 @@ async def get_lock_ttl(self, resource, lock_identifier=None): been set to at least (N/2 + 1) instances """ start_time = time.monotonic() - successes = await asyncio.gather(*[ - i.get_lock_ttl(resource, lock_identifier) for - i in self.instances - ], return_exceptions=True) - successful_list = [s for s in successes if not isinstance(s, Exception)] + successes = await asyncio.gather( + *[i.get_lock_ttl(resource, lock_identifier) for i in self.instances], + return_exceptions=True, + ) + successful_list = [s for s in successes if not 
isinstance(s, BaseException)] # should check if all the value are approx. the same with math.isclose... locked = len(successful_list) >= int(len(self.instances) / 2) + 1 success = all_equal(successful_list) and locked elapsed_time = time.monotonic() - start_time - self.log.debug('Lock "%s" is set on %d/%d instances in %s seconds', - resource, len(successful_list), len(self.instances), elapsed_time) + self.log.debug( + 'Lock "%s" is set on %d/%d instances in %s seconds', + resource, + len(successful_list), + len(self.instances), + elapsed_time, + ) if not success: raise_error(successes, 'Could not fetch the TTL for lock "%s"' % resource) @@ -397,24 +412,29 @@ async def unset_lock(self, resource, lock_identifier): """ if not self.instances: - return .0 + return 0.0 start_time = time.monotonic() - successes = await asyncio.gather(*[ - i.unset_lock(resource, lock_identifier) for - i in self.instances - ], return_exceptions=True) + successes = await asyncio.gather( + *[i.unset_lock(resource, lock_identifier) for i in self.instances], + return_exceptions=True, + ) successful_removes = sum(s is None for s in successes) elapsed_time = time.monotonic() - start_time unlocked = successful_removes >= int(len(self.instances) / 2) + 1 - self.log.debug('Lock "%s" is unset on %d/%d instances in %s seconds', - resource, successful_removes, len(self.instances), elapsed_time) + self.log.debug( + 'Lock "%s" is unset on %d/%d instances in %s seconds', + resource, + successful_removes, + len(self.instances), + elapsed_time, + ) if not unlocked: - raise_error(successes, 'Can not release the lock') + raise_error(successes, "Can not release the lock") return elapsed_time @@ -426,20 +446,18 @@ async def is_locked(self, resource): :returns: True if locked else False """ - successes = await asyncio.gather(*[ - i.is_locked(resource) for - i in self.instances - ], return_exceptions=True) + successes = await asyncio.gather( + *[i.is_locked(resource) for i in self.instances], 
return_exceptions=True + ) successful_sets = sum(s is True for s in successes) return successful_sets >= int(len(self.instances) / 2) + 1 async def clear_connections(self): - - self.log.debug('Clearing connection') + self.log.debug("Clearing connection") if self.instances: coros = [] while self.instances: - coros.append(self.instances.pop().close()) - await asyncio.gather(*(coros)) + coros.append(self.instances.pop().aclose()) + await asyncio.gather(*coros) diff --git a/aioredlock/sentinel.py b/aioredlock/sentinel.py index 4c5fa2d..a8d11fd 100644 --- a/aioredlock/sentinel.py +++ b/aioredlock/sentinel.py @@ -2,26 +2,26 @@ import ssl import urllib.parse -import aioredis.sentinel +from redis.asyncio import Sentinel as RedisSentinel, Redis as AIORedis class SentinelConfigError(Exception): - ''' + """ Exception raised if Configuration is not valid when instantiating a Sentinel object. - ''' + """ class Sentinel: - - def __init__(self, connection, master=None, password=None, db=None, ssl_context=None): - ''' + def __init__( + self, connection, master=None, password=None, db=None, ssl_context=None + ): + """ The connection address can be one of the following: * a dict - {'host': 'localhost', 'port': 6379} * a Redis URI - "redis://host:6379/0?encoding=utf-8&master=mymaster"; * a (host, port) tuple - ('localhost', 6379); * or a unix domain socket path string - "/path/to/redis.sock". - * a redis connection pool. :param connection: The connection address can be one of the following: @@ -46,27 +46,31 @@ def __init__(self, connection, master=None, password=None, db=None, ssl_context= For example, if 'master' is specified in the connection dictionary, but also specified as the master kwarg, the master kwarg will be used instead. 
- ''' + """ address, kwargs = (), {} if isinstance(connection, dict): kwargs.update(connection) - address = [(kwargs.pop('host'), kwargs.pop('port', 26379))] - elif isinstance(connection, str) and re.match(r'^rediss?://.*\:\d+/\d?\??.*$', connection): + address = [(kwargs.pop("host"), kwargs.pop("port", 26379))] + elif isinstance(connection, str) and re.match( + r"^rediss?://.*\:\d+/\d?\??.*$", connection + ): url = urllib.parse.urlparse(connection) - query = {key: value[0] for key, value in urllib.parse.parse_qs(url.query).items()} + query = { + key: value[0] for key, value in urllib.parse.parse_qs(url.query).items() + } address = [(url.hostname, url.port or 6379)] - dbnum = url.path.strip('/') + dbnum = url.path.strip("/") - if url.scheme == 'rediss': - kwargs['ssl'] = ssl.create_default_context() - verify_mode = query.pop('ssl_cert_reqs', None) + if url.scheme == "rediss": + kwargs["ssl"] = ssl.create_default_context() + verify_mode = query.pop("ssl_cert_reqs", None) if verify_mode is not None and hasattr(ssl, verify_mode.upper()): - if verify_mode == 'CERT_NONE': - kwargs['ssl'].check_hostname = False - kwargs['ssl'].verify_mode = getattr(ssl, verify_mode.upper()) + if verify_mode == "CERT_NONE": + kwargs["ssl"].check_hostname = False + kwargs["ssl"].verify_mode = getattr(ssl, verify_mode.upper()) - kwargs['db'] = int(dbnum) if dbnum.isdigit() else 0 - kwargs['password'] = url.password + kwargs["db"] = int(dbnum) if dbnum.isdigit() else 0 + kwargs["password"] = url.password kwargs.update(query) elif isinstance(connection, tuple): @@ -74,42 +78,43 @@ def __init__(self, connection, master=None, password=None, db=None, ssl_context= elif isinstance(connection, list): address = connection else: - raise SentinelConfigError('Invalid Sentinel Configuration') + raise SentinelConfigError("Invalid Sentinel Configuration") if db is not None: - kwargs['db'] = db + kwargs["db"] = db if password is not None: - kwargs['password'] = password + kwargs["password"] = password if 
ssl_context is True: - kwargs['ssl'] = ssl.create_default_context() + kwargs["ssl"] = ssl.create_default_context() elif ssl_context is not None: - kwargs['ssl'] = ssl_context + kwargs["ssl"] = ssl_context - self.master = kwargs.pop('master', None) + self.master = kwargs.pop("master", None) if master: self.master = master if self.master is None: - raise SentinelConfigError('Master name required for sentinel to be configured') + raise SentinelConfigError( + "Master name required for sentinel to be configured" + ) - kwargs['minsize'] = 1 if 'minsize' not in kwargs else int(kwargs['minsize']) - kwargs['maxsize'] = 100 if 'maxsize' not in kwargs else int(kwargs['maxsize']) + kwargs["max_connections"] = int(kwargs.get("max_connections", 100)) self.connection = address self.redis_kwargs = kwargs - async def get_sentinel(self): - ''' + def get_sentinel(self): + """ Retrieve sentinel object from aioredis. - ''' - return await aioredis.sentinel.create_sentinel( + """ + return RedisSentinel( sentinels=self.connection, **self.redis_kwargs, ) - async def get_master(self): - ''' + def get_master(self) -> AIORedis: + """ Get ``Redis`` instance for specified ``master`` - ''' - sentinel = await self.get_sentinel() - return await sentinel.master_for(self.master) + """ + sentinel = self.get_sentinel() + return sentinel.master_for(self.master) diff --git a/examples/basic_lock.py b/examples/basic_lock.py index bcb3963..203d1c5 100755 --- a/examples/basic_lock.py +++ b/examples/basic_lock.py @@ -5,22 +5,19 @@ async def basic_lock(): - lock_manager = Aioredlock([{ - 'host': 'localhost', - 'port': 6379, - 'db': 0, - 'password': None - }]) + lock_manager = Aioredlock( + [{"host": "localhost", "port": 6379, "db": 0, "password": None}] + ) if await lock_manager.is_locked("resource"): - print('The resource is already acquired') + print("The resource is already acquired") try: lock = await lock_manager.lock("resource") except LockAcquiringError: - print('Something happened during normal 
operation. We just log it.') + print("Something happened during normal operation. We just log it.") except LockError: - print('Something is really wrong and we prefer to raise the exception') + print("Something is really wrong and we prefer to raise the exception") raise assert lock.valid is True assert await lock_manager.is_locked("resource") is True diff --git a/examples/lock_context.py b/examples/lock_context.py index 7f7e7de..7f55de4 100755 --- a/examples/lock_context.py +++ b/examples/lock_context.py @@ -5,15 +5,17 @@ async def lock_context(): - lock_manager = Aioredlock([ - 'redis://localhost:6379/0', - 'redis://localhost:6379/1', - 'redis://localhost:6379/2', - 'redis://localhost:6379/3', - ]) + lock_manager = Aioredlock( + [ + "redis://localhost:6379/0", + "redis://localhost:6379/1", + "redis://localhost:6379/2", + "redis://localhost:6379/3", + ] + ) if await lock_manager.is_locked("resource"): - print('The resource is already acquired') + print("The resource is already acquired") try: # if you dont set your lock's lock_timeout, its lifetime will be automatically extended @@ -30,9 +32,9 @@ async def lock_context(): assert lock.valid is False # lock will be released by context manager except LockAcquiringError: - print('Something happened during normal operation. We just log it.') + print("Something happened during normal operation. We just log it.") except LockError: - print('Something is really wrong and we prefer to raise the exception') + print("Something is really wrong and we prefer to raise the exception") raise assert lock.valid is False diff --git a/examples/sentinel.py b/examples/sentinel.py index e79fe23..1d4d637 100755 --- a/examples/sentinel.py +++ b/examples/sentinel.py @@ -33,6 +33,7 @@ .. _Sentinels: https://redis.io/topics/sentinel .. 
_TunTap: https://github.com/AlmirKadric-Published/docker-tuntap-osx """ + import asyncio import logging @@ -48,21 +49,25 @@ async def get_container(name): async def get_container_ip(name, network=None): container = await get_container(name) - return container['NetworkSettings']['Networks'][network or 'aioredlock_backend']['IPAddress'] + return container["NetworkSettings"]["Networks"][network or "aioredlock_backend"][ + "IPAddress" + ] async def lock_context(): - sentinel_ip = await get_container_ip('aioredlock_sentinel_1') + sentinel_ip = await get_container_ip("aioredlock_sentinel_1") - lock_manager = Aioredlock([ - Sentinel('redis://{0}:26379/0?master=leader'.format(sentinel_ip)), - Sentinel('redis://{0}:26379/1?master=leader'.format(sentinel_ip)), - Sentinel('redis://{0}:26379/2?master=leader'.format(sentinel_ip)), - Sentinel('redis://{0}:26379/3?master=leader'.format(sentinel_ip)), - ]) + lock_manager = Aioredlock( + [ + Sentinel("redis://{0}:26379/0?master=leader".format(sentinel_ip)), + Sentinel("redis://{0}:26379/1?master=leader".format(sentinel_ip)), + Sentinel("redis://{0}:26379/2?master=leader".format(sentinel_ip)), + Sentinel("redis://{0}:26379/3?master=leader".format(sentinel_ip)), + ] + ) if await lock_manager.is_locked("resource"): - print('The resource is already acquired') + print("The resource is already acquired") try: # if you dont set your lock's lock_timeout, its lifetime will be automatically extended @@ -71,7 +76,7 @@ async def lock_context(): assert await lock_manager.is_locked("resource") is True # pause leader to simulate a failing node and cause a failover - container = await get_container('aioredlock_leader_1') + container = await get_container("aioredlock_leader_1") await container.pause() # Do your stuff having the lock @@ -85,9 +90,9 @@ async def lock_context(): assert lock.valid is False # lock will be released by context manager except LockAcquiringError: - print('Something happened during normal operation. 
We just log it.') + print("Something happened during normal operation. We just log it.") except LockError: - print('Something is really wrong and we prefer to raise the exception') + print("Something is really wrong and we prefer to raise the exception") raise assert lock.valid is False diff --git a/setup.py b/setup.py index a823d78..e68ac28 100644 --- a/setup.py +++ b/setup.py @@ -5,48 +5,43 @@ here = path.abspath(path.dirname(__file__)) -with open(path.join(here, 'README.rst'), encoding='utf-8') as f: +with open(path.join(here, "README.rst"), encoding="utf-8") as f: long_description = f.read() setup( - name='aioredlock', - - version='0.7.3', - - description='Asyncio implementation of Redis distributed locks', + name="aioredlock", + version="0.8.0", + description="Asyncio implementation of Redis distributed locks", long_description=long_description, - - url='https://github.com/joanvila/aioredlock', - - author='Joan Vilà Cuñat', - author_email='vila.joan94@gmail.com', - - license='MIT', - + url="https://github.com/joanvila/aioredlock", + author="Joan Vilà Cuñat", + author_email="vila.joan94@gmail.com", + license="MIT", classifiers=[ - 'Development Status :: 5 - Production/Stable', - - 'Intended Audience :: Developers', - 'Topic :: Software Development :: Libraries :: Python Modules', - - 'License :: OSI Approved :: MIT License', - - 'Programming Language :: Python :: 3', - 'Programming Language :: Python :: 3.7', - 'Programming Language :: Python :: 3.8', - 'Programming Language :: Python :: 3.9', + "Development Status :: 5 - Production/Stable", + "Intended Audience :: Developers", + "Topic :: Software Development :: Libraries :: Python Modules", + "License :: OSI Approved :: MIT License", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", ], - - keywords='redis redlock distributed locks asyncio', - + 
keywords="redis redlock distributed locks asyncio", packages=find_packages(), - - python_requires='>=3.6', - install_requires=['aioredis<2.0.0', 'attrs >= 17.4.0'], + python_requires=">=3.9", + install_requires=["redis>=5.0.0", "attrs >= 17.4.0"], extras_require={ - 'test': ['pytest==6.1.0', 'pytest-asyncio', 'pytest-mock', 'pytest-cov', 'flake8'], - 'cicd': ['codecov'], - 'package': ['bump2version', 'twine', 'wheel'], - 'examples': ['aiodocker'], + "test": [ + "pytest<=8.0.0", + "pytest-asyncio", + "pytest-mock", + "pytest-cov", + "flake8", + ], + "cicd": ["codecov"], + "package": ["bump2version", "twine", "wheel"], + "examples": ["aiodocker"], }, ) diff --git a/tests/acceptance/test_aioredlock.py b/tests/acceptance/test_aioredlock.py index d711c9f..b60ee89 100644 --- a/tests/acceptance/test_aioredlock.py +++ b/tests/acceptance/test_aioredlock.py @@ -2,7 +2,7 @@ import unittest.mock import uuid -import aioredis +from redis.asyncio import ConnectionPool, Redis as AIORedis import pytest from aioredlock import Aioredlock, LockError @@ -10,19 +10,18 @@ @pytest.fixture def redis_one_connection(): - return [{'host': 'localhost', 'port': 6379, 'db': 0}] + return [{"host": "localhost", "port": 6379, "db": 0}] @pytest.fixture def redis_two_connections(): return [ - {'host': 'localhost', 'port': 6379, 'db': 0}, - {'host': 'localhost', 'port': 6379, 'db': 1} + {"host": "localhost", "port": 6379, "db": 0}, + {"host": "localhost", "port": 6379, "db": 1}, ] class TestAioredlock: - async def check_simple_lock(self, lock_manager): resource = str(uuid.uuid4()) @@ -90,68 +89,57 @@ async def check_two_locks_on_same_resource(self, lock_manager): await lock_manager.destroy() @pytest.mark.asyncio - async def test_simple_aioredlock_one_instance( - self, - redis_one_connection): - + async def test_simple_aioredlock_one_instance(self, redis_one_connection): await self.check_simple_lock(Aioredlock(redis_one_connection)) @pytest.mark.asyncio - async def 
test_simple_aioredlock_one_instance_pool( - self, - redis_one_connection): - address = 'redis://{host}:{port}/{db}'.format(**redis_one_connection[0]) - pool = await aioredis.create_redis_pool(address=address, encoding='utf-8') - await self.check_simple_lock(Aioredlock([pool])) + async def test_simple_aioredlock_one_instance_pool(self, redis_one_connection): + address = "redis://{host}:{port}/{db}".format(**redis_one_connection[0]) + client = ConnectionPool.from_url(address) + await self.check_simple_lock(Aioredlock([client])) @pytest.mark.asyncio async def test_aioredlock_two_locks_on_different_resources_one_instance( - self, - redis_one_connection): - - await self.check_two_locks_on_different_resources(Aioredlock(redis_one_connection)) + self, redis_one_connection + ): + await self.check_two_locks_on_different_resources( + Aioredlock(redis_one_connection) + ) @pytest.mark.asyncio async def test_aioredlock_two_locks_on_same_resource_one_instance( - self, - redis_one_connection): - + self, redis_one_connection + ): await self.check_two_locks_on_same_resource(Aioredlock(redis_one_connection)) @pytest.mark.asyncio - async def test_simple_aioredlock_two_instances( - self, - redis_two_connections): - + async def test_simple_aioredlock_two_instances(self, redis_two_connections): await self.check_simple_lock(Aioredlock(redis_two_connections)) @pytest.mark.asyncio async def test_aioredlock_two_locks_on_different_resources_two_instances( - self, - redis_two_connections): - - await self.check_two_locks_on_different_resources(Aioredlock(redis_two_connections)) + self, redis_two_connections + ): + await self.check_two_locks_on_different_resources( + Aioredlock(redis_two_connections) + ) @pytest.mark.asyncio async def test_aioredlock_two_locks_on_same_resource_two_instances( - self, - redis_two_connections): - + self, redis_two_connections + ): await self.check_two_locks_on_same_resource(Aioredlock(redis_two_connections)) @pytest.mark.asyncio async def 
test_aioredlock_lock_with_first_failed_try_two_instances( - self, - redis_two_connections + self, redis_two_connections ): - lock_manager = Aioredlock(redis_two_connections) resource = str(uuid.uuid4()) - garbage_value = 'garbage' + garbage_value = "garbage" - first_redis = await aioredis.create_redis( - (redis_two_connections[0]['host'], - redis_two_connections[0]['port']) + first_redis = AIORedis.from_url( + f"redis://{redis_two_connections[0]['host']}:{redis_two_connections[0]['port']}/0" ) # write garbage to resource key in first instance @@ -159,11 +147,10 @@ async def test_aioredlock_lock_with_first_failed_try_two_instances( is_garbage = True # this patched sleep function will remove garbage from - # frist instance before second try + # first instance before second try real_sleep = asyncio.sleep async def fake_sleep(delay): - nonlocal is_garbage # remove garbage on sleep @@ -171,7 +158,6 @@ async def fake_sleep(delay): await first_redis.delete(resource) is_garbage = False - # print('fake_sleep(%s), value %s' % (delay, value)) await real_sleep(delay) # here we will try to lock while first redis instance still have @@ -186,5 +172,4 @@ async def fake_sleep(delay): assert lock.valid is False await lock_manager.destroy() - first_redis.close() - await first_redis.wait_closed() + await first_redis.close() diff --git a/tests/ut/conftest.py b/tests/ut/conftest.py index ad66030..70c53dc 100644 --- a/tests/ut/conftest.py +++ b/tests/ut/conftest.py @@ -1,7 +1,7 @@ import asyncio import ssl import uuid -from unittest.mock import MagicMock +from unittest.mock import MagicMock, AsyncMock from unittest.mock import patch import pytest @@ -31,18 +31,17 @@ def unlocked_lock(lock_manager_redis_patched): @pytest.fixture def lock_manager_redis_patched(): - with patch("aioredlock.algorithm.Redis") as mock_redis, \ - patch("asyncio.sleep", dummy_sleep): - mock_redis.set_lock.return_value = asyncio.Future() - mock_redis.set_lock.return_value.set_result(0.005) - 
mock_redis.unset_lock.return_value = asyncio.Future() - mock_redis.unset_lock.return_value.set_result(0.005) - mock_redis.is_locked.return_value = asyncio.Future() - mock_redis.is_locked.return_value.set_result(False) - mock_redis.clear_connections.return_value = asyncio.Future() - mock_redis.clear_connections.return_value.set_result(MagicMock()) - mock_redis.get_lock_ttl.return_value = asyncio.Future() - mock_redis.get_lock_ttl.return_value.set_result(Lock(None, "resource_name", 1, -1, True)) + with ( + patch("aioredlock.algorithm.Redis") as mock_redis, + patch("asyncio.sleep", dummy_sleep), + ): + mock_redis.set_lock = AsyncMock(return_value=0.005) + mock_redis.unset_lock = AsyncMock(return_value=0.005) + mock_redis.is_locked = AsyncMock(return_value=False) + mock_redis.clear_connections = AsyncMock() + mock_redis.get_lock_ttl = AsyncMock( + return_value=Lock(None, "resource_name", 1, -1, True) + ) lock_manager = Aioredlock(internal_lock_timeout=1.0) lock_manager.redis = mock_redis @@ -52,13 +51,14 @@ def lock_manager_redis_patched(): @pytest.fixture def aioredlock_patched(): - with patch("aioredlock.algorithm.Aioredlock", MagicMock) as mock_aioredlock, \ - patch("asyncio.sleep", dummy_sleep): + with ( + patch("aioredlock.algorithm.Aioredlock", MagicMock) as mock_aioredlock, + patch("asyncio.sleep", dummy_sleep), + ): async def dummy_lock(resource): lock_identifier = str(uuid.uuid4()) - return Lock(mock_aioredlock, resource, - lock_identifier, valid=True) + return Lock(mock_aioredlock, resource, lock_identifier, valid=True) mock_aioredlock.lock = MagicMock(side_effect=dummy_lock) mock_aioredlock.extend = MagicMock(return_value=asyncio.Future()) @@ -74,12 +74,5 @@ async def dummy_lock(resource): @pytest.fixture def ssl_context(): context = ssl.create_default_context() - with patch('ssl.create_default_context', return_value=context): + with patch("ssl.create_default_context", return_value=context): yield context - - -@pytest.fixture -def fake_coro(): - async def 
func(thing): - return thing - return func diff --git a/tests/ut/test_algorithm.py b/tests/ut/test_algorithm.py index 5ac5a7b..72fce5a 100644 --- a/tests/ut/test_algorithm.py +++ b/tests/ut/test_algorithm.py @@ -1,6 +1,6 @@ import asyncio import sys -from unittest.mock import ANY, call, MagicMock, patch +from unittest.mock import ANY, call, MagicMock, AsyncMock, patch import pytest @@ -9,11 +9,17 @@ real_sleep = asyncio.sleep -@pytest.mark.parametrize('method,exc_message', [ - ('_validate_retry_count', "Retry count must be greater or equal 1."), - ('_validate_retry_delay', "Retry delay must be greater than 0 seconds."), - ('_validate_internal_lock_timeout', "Internal lock_timeout must be greater than 0 seconds.") -]) +@pytest.mark.parametrize( + "method,exc_message", + [ + ("_validate_retry_count", "Retry count must be greater or equal 1."), + ("_validate_retry_delay", "Retry delay must be greater than 0 seconds."), + ( + "_validate_internal_lock_timeout", + "Internal lock_timeout must be greater than 0 seconds.", + ), + ], +) def test_validator(method, exc_message): with pytest.raises(ValueError) as exc_info: getattr(Aioredlock, method)(None, None, -1) @@ -21,40 +27,34 @@ def test_validator(method, exc_message): class TestAioredlock: - def test_default_initialization(self): with patch("aioredlock.algorithm.Redis.__init__") as mock_redis: mock_redis.return_value = None lock_manager = Aioredlock() mock_redis.assert_called_once_with( - [{'host': 'localhost', 'port': 6379}], - + [{"host": "localhost", "port": 6379}], ) assert lock_manager.redis def test_initialization_with_params(self): with patch("aioredlock.algorithm.Redis.__init__") as mock_redis: mock_redis.return_value = None - lock_manager = Aioredlock([{'host': '::1', 'port': 1}]) + lock_manager = Aioredlock([{"host": "::1", "port": 1}]) mock_redis.assert_called_once_with( - [{'host': '::1', 'port': 1}], + [{"host": "::1", "port": 1}], ) assert lock_manager.redis - @pytest.mark.parametrize('param', [ - 
'retry_count', - 'retry_delay_min', - 'retry_delay_max', - 'internal_lock_timeout' - ]) - @pytest.mark.parametrize('value,exc_type', [ - (-1, ValueError), - (0, ValueError), - ('string', ValueError), - (None, TypeError) - ]) + @pytest.mark.parametrize( + "param", + ["retry_count", "retry_delay_min", "retry_delay_max", "internal_lock_timeout"], + ) + @pytest.mark.parametrize( + "value,exc_type", + [(-1, ValueError), (0, ValueError), ("string", ValueError), (None, TypeError)], + ) def test_initialization_with_invalid_params(self, param, value, exc_type): lock_manager = None with pytest.raises(exc_type): @@ -65,14 +65,10 @@ def test_initialization_with_invalid_params(self, param, value, exc_type): async def test_lock(self, lock_manager_redis_patched, locked_lock): lock_manager, redis = lock_manager_redis_patched - lock = await lock_manager.lock('resource', 1.0) + lock = await lock_manager.lock("resource", 1.0) - redis.set_lock.assert_called_once_with( - 'resource', - ANY, - 1.0 - ) - assert lock.resource == 'resource' + redis.set_lock.assert_called_once_with("resource", ANY, 1.0) + assert lock.resource == "resource" assert lock.id == ANY assert lock.valid is True @@ -87,81 +83,81 @@ async def test_lock_one_retry(self, lock_manager_redis_patched, locked_lock): lock_manager, redis = lock_manager_redis_patched future = asyncio.Future() future.set_result(0.001) - redis.set_lock = MagicMock(side_effect=[ - LockError('Can not lock'), - future, - ]) + redis.set_lock = MagicMock( + side_effect=[ + LockError("Can not lock"), + future, + ] + ) - lock = await lock_manager.lock('resource', 1.0) + lock = await lock_manager.lock("resource", 1.0) - calls = [ - call('resource', ANY, 1.0), - call('resource', ANY, 1.0) - ] + calls = [call("resource", ANY, 1.0), call("resource", ANY, 1.0)] redis.set_lock.assert_has_calls(calls) redis.unset_lock.assert_not_called() - assert lock.resource == 'resource' + assert lock.resource == "resource" assert lock.id == ANY assert lock.valid is True 
@pytest.mark.asyncio async def test_lock_expire_retries(self, lock_manager_redis_patched, locked_lock): lock_manager, redis = lock_manager_redis_patched - redis.set_lock = MagicMock(side_effect=[ - LockError('Can not lock'), - LockError('Can not lock'), - LockError('Can not lock') - ]) + redis.set_lock = MagicMock( + side_effect=[ + LockError("Can not lock"), + LockError("Can not lock"), + LockError("Can not lock"), + ] + ) with pytest.raises(LockError): - await lock_manager.lock('resource', 1.0) + await lock_manager.lock("resource", 1.0) await real_sleep(0.1) # wait until cleaning is completed calls = [ - call('resource', ANY, 1.0), - call('resource', ANY, 1.0), - call('resource', ANY, 1.0) + call("resource", ANY, 1.0), + call("resource", ANY, 1.0), + call("resource", ANY, 1.0), ] redis.set_lock.assert_has_calls(calls) - redis.unset_lock.assert_called_once_with('resource', ANY) + redis.unset_lock.assert_called_once_with("resource", ANY) @pytest.mark.asyncio - async def test_lock_one_timeout(self, fake_coro, lock_manager_redis_patched, locked_lock): + async def test_lock_one_timeout(self, lock_manager_redis_patched, locked_lock): lock_manager, redis = lock_manager_redis_patched - redis.set_lock.side_effect = [fake_coro(1.5), fake_coro(0.001)] + redis.set_lock.side_effect = [1.5, 0.001] - lock = await lock_manager.lock('resource', 1.0) + lock = await lock_manager.lock("resource", 1.0) - calls = [ - call('resource', ANY, 1.0), - call('resource', ANY, 1.0) - ] + calls = [call("resource", ANY, 1.0), call("resource", ANY, 1.0)] redis.set_lock.assert_has_calls(calls) redis.unset_lock.assert_not_called() - assert lock.resource == 'resource' + assert lock.resource == "resource" assert lock.id == ANY assert lock.valid is True @pytest.mark.asyncio - async def test_lock_expire_retries_for_timeouts(self, fake_coro, lock_manager_redis_patched, locked_lock): + async def test_lock_expire_retries_for_timeouts( + self, lock_manager_redis_patched, locked_lock + ): lock_manager, 
redis = lock_manager_redis_patched - redis.set_lock.side_effect = [fake_coro(1.100), fake_coro(1.001), fake_coro(2.000)] + redis.set_lock.side_effect = [1.100, 1.001, 2.000] with pytest.raises(LockError): - await lock_manager.lock('resource', 1.0) + await lock_manager.lock("resource", 1.0) await real_sleep(0.1) # wait until cleaning is completed calls = [ - call('resource', ANY, 1.0), - call('resource', ANY, 1.0), - call('resource', ANY, 1.0) + call("resource", ANY, 1.0), + call("resource", ANY, 1.0), + call("resource", ANY, 1.0), ] redis.set_lock.assert_has_calls(calls) - redis.unset_lock.assert_called_once_with('resource', ANY) + redis.unset_lock.assert_called_once_with("resource", ANY) @pytest.mark.asyncio async def test_cancel_lock_(self, lock_manager_redis_patched): @@ -173,29 +169,26 @@ async def mock_set_lock(*args, **kwargs): redis.set_lock = MagicMock(side_effect=mock_set_lock) with pytest.raises(asyncio.TimeoutError): - await asyncio.wait_for(lock_manager.lock('resource', 1.0), 0.1) + await asyncio.wait_for(lock_manager.lock("resource", 1.0), 0.1) # The exception handling of the cancelled lock is run in background and # can not be awaited, so we have to sleep until the unset_lock has done. 
await real_sleep(0.1) - redis.set_lock.assert_called_once_with('resource', ANY, 1.0) - redis.unset_lock.assert_called_once_with('resource', ANY) + redis.set_lock.assert_called_once_with("resource", ANY, 1.0) + redis.unset_lock.assert_called_once_with("resource", ANY) @pytest.mark.asyncio async def test_extend_lock(self, lock_manager_redis_patched, locked_lock): lock_manager, redis = lock_manager_redis_patched - lock = await lock_manager.lock('resource', 1.0) + lock = await lock_manager.lock("resource", 1.0) await lock_manager.extend(lock) - calls = [ - call('resource', ANY, 1.0), - call('resource', ANY, 1.0) - ] + calls = [call("resource", ANY, 1.0), call("resource", ANY, 1.0)] redis.set_lock.assert_has_calls(calls) - assert lock.resource == 'resource' + assert lock.resource == "resource" assert lock.id == ANY assert lock.valid is True @@ -213,9 +206,9 @@ async def test_extend_with_invalid_param(self, lock_manager_redis_patched): @pytest.mark.asyncio async def test_extend_lock_error(self, lock_manager_redis_patched, locked_lock): lock_manager, redis = lock_manager_redis_patched - lock = await lock_manager.lock('resource') + lock = await lock_manager.lock("resource") - redis.set_lock = MagicMock(side_effect=LockError('Can not lock')) + redis.set_lock = MagicMock(side_effect=LockError("Can not lock")) with pytest.raises(LockError): await lock_manager.extend(lock) @@ -226,10 +219,7 @@ async def test_unlock(self, lock_manager_redis_patched, locked_lock): await lock_manager.unlock(locked_lock) - redis.unset_lock.assert_called_once_with( - locked_lock.resource, - locked_lock.id - ) + redis.unset_lock.assert_called_once_with(locked_lock.resource, locked_lock.id) assert locked_lock.valid is False @pytest.mark.asyncio @@ -242,10 +232,11 @@ async def test_unlock_type_error(self, lock_manager_redis_patched): @pytest.mark.asyncio @pytest.mark.parametrize("by_resource", [True, False]) @pytest.mark.parametrize("locked", [True, False]) - async def test_is_locked(self, 
lock_manager_redis_patched, locked_lock, by_resource, locked): + async def test_is_locked( + self, lock_manager_redis_patched, locked_lock, by_resource, locked + ): lock_manager, redis = lock_manager_redis_patched - redis.is_locked.return_value = asyncio.Future() - redis.is_locked.return_value.set_result(locked) + redis.is_locked = AsyncMock(return_value=locked) Lock.valid = locked resource = locked_lock.resource @@ -267,25 +258,22 @@ async def test_is_locked_type_error(self, lock_manager_redis_patched): async def test_context_manager(self, lock_manager_redis_patched): lock_manager, redis = lock_manager_redis_patched - async with await lock_manager.lock('resource', 1.0) as lock: - assert lock.resource == 'resource' + async with await lock_manager.lock("resource", 1.0) as lock: + assert lock.resource == "resource" assert lock.id == ANY assert lock.valid is True await lock.extend() assert lock.valid is False - calls = [ - call('resource', ANY, 1.0), - call('resource', ANY, 1.0) - ] + calls = [call("resource", ANY, 1.0), call("resource", ANY, 1.0)] redis.set_lock.assert_has_calls(calls) - redis.unset_lock.assert_called_once_with('resource', ANY) + redis.unset_lock.assert_called_once_with("resource", ANY) @pytest.mark.asyncio async def test_destroy_lock_manager(self, lock_manager_redis_patched): lock_manager, redis = lock_manager_redis_patched - lock_manager.unlock = MagicMock(side_effect=LockError('Can not lock')) + lock_manager.unlock = MagicMock(side_effect=LockError("Can not lock")) await lock_manager.lock("resource", 1.0) await lock_manager.destroy() @@ -295,12 +283,9 @@ async def test_destroy_lock_manager(self, lock_manager_redis_patched): @pytest.mark.asyncio async def test_auto_extend(self): with patch("aioredlock.algorithm.Redis") as mock_redis: - mock_redis.set_lock.return_value = asyncio.Future() - mock_redis.set_lock.return_value.set_result(0.005) - mock_redis.unset_lock.return_value = asyncio.Future() - mock_redis.unset_lock.return_value.set_result(0.005) 
- mock_redis.clear_connections.return_value = asyncio.Future() - mock_redis.clear_connections.return_value.set_result(MagicMock()) + mock_redis.set_lock = AsyncMock(return_value=0.005) + mock_redis.unset_lock = AsyncMock(return_value=0.005) + mock_redis.clear_connections = AsyncMock() lock_manager = Aioredlock(internal_lock_timeout=1) lock_manager.redis = mock_redis @@ -308,8 +293,10 @@ async def test_auto_extend(self): await real_sleep(lock_manager.internal_lock_timeout * 3) - calls = [call('resource', lock.id, lock_manager.internal_lock_timeout) - for _ in range(5)] + calls = [ + call("resource", lock.id, lock_manager.internal_lock_timeout) + for _ in range(5) + ] mock_redis.set_lock.assert_has_calls(calls) await lock_manager.destroy() @@ -318,30 +305,24 @@ async def test_auto_extend(self): @pytest.mark.asyncio async def test_auto_extend_with_extend_failed(self): with patch("aioredlock.algorithm.Redis") as mock_redis: - mock_redis.set_lock.return_value = asyncio.Future() - mock_redis.set_lock.return_value.set_result(0.005) - mock_redis.unset_lock.return_value = asyncio.Future() - mock_redis.unset_lock.return_value.set_result(0.005) - mock_redis.clear_connections.return_value = asyncio.Future() - mock_redis.clear_connections.return_value.set_result(MagicMock()) + mock_redis.set_lock = AsyncMock(return_value=0.005) + mock_redis.unset_lock = AsyncMock(return_value=0.005) + mock_redis.clear_connections = AsyncMock() lock_manager = Aioredlock(internal_lock_timeout=1.0) lock_manager.redis = mock_redis lock = await lock_manager.lock("resource") lock.valid = False await real_sleep(lock_manager.internal_lock_timeout * 3) - calls = [call('resource', lock.id, lock_manager.internal_lock_timeout)] + calls = [call("resource", lock.id, lock_manager.internal_lock_timeout)] mock_redis.set_lock.assert_has_calls(calls) @pytest.mark.asyncio async def test_unlock_with_watchdog_failed(self): with patch("aioredlock.algorithm.Redis") as mock_redis: - mock_redis.set_lock.return_value = 
asyncio.Future() - mock_redis.set_lock.return_value.set_result(0.005) - mock_redis.unset_lock.return_value = asyncio.Future() - mock_redis.unset_lock.return_value.set_result(0.005) - mock_redis.clear_connections.return_value = asyncio.Future() - mock_redis.clear_connections.return_value.set_result(MagicMock()) + mock_redis.set_lock = AsyncMock(return_value=0.005) + mock_redis.unset_lock = AsyncMock(return_value=0.005) + mock_redis.clear_connections = AsyncMock() lock_manager = Aioredlock(internal_lock_timeout=1.0) lock_manager.redis = mock_redis @@ -362,10 +343,11 @@ async def test_unlock_with_watchdog_failed(self): assert lock.valid is False @pytest.mark.asyncio - async def test_get_active_locks(self, lock_manager_redis_patched, locked_lock, unlocked_lock): + async def test_get_active_locks( + self, lock_manager_redis_patched, locked_lock, unlocked_lock + ): lock_manager, redis = lock_manager_redis_patched - redis.is_locked.return_value = asyncio.Future() - redis.is_locked.return_value.set_result(True) + redis.is_locked = AsyncMock(return_value=True) locks = await lock_manager.get_active_locks() @@ -375,8 +357,7 @@ async def test_get_active_locks(self, lock_manager_redis_patched, locked_lock, u @pytest.mark.asyncio async def test_get_lock(self, lock_manager_redis_patched, locked_lock): lock_manager, redis = lock_manager_redis_patched - redis.get_lock_ttl.return_value = asyncio.Future() - redis.get_lock_ttl.return_value.set_result(-1) + redis.get_lock_ttl = AsyncMock(return_value=-1) lock = await lock_manager.get_lock("resource_name", 1) assert lock == locked_lock diff --git a/tests/ut/test_lock.py b/tests/ut/test_lock.py index 4f0daaf..80cf294 100644 --- a/tests/ut/test_lock.py +++ b/tests/ut/test_lock.py @@ -4,7 +4,6 @@ class TestLock: - def test_lock(self): lock_manager = Aioredlock() lock = Lock(lock_manager, "potato", 1, 1.0) @@ -14,12 +13,12 @@ def test_lock(self): @pytest.mark.asyncio async def test_extend(self, aioredlock_patched): - lock = await 
aioredlock_patched.lock('foo') + lock = await aioredlock_patched.lock("foo") await lock.extend() aioredlock_patched.extend.assert_called_once_with(lock) @pytest.mark.asyncio async def test_release(self, aioredlock_patched): - lock = await aioredlock_patched.lock('foo') + lock = await aioredlock_patched.lock("foo") await lock.release() aioredlock_patched.unlock.assert_called_once_with(lock) diff --git a/tests/ut/test_redis.py b/tests/ut/test_redis.py index 693d93a..d34da8b 100644 --- a/tests/ut/test_redis.py +++ b/tests/ut/test_redis.py @@ -1,9 +1,10 @@ +from __future__ import annotations import asyncio import hashlib -import sys -from unittest.mock import MagicMock, call, patch +from unittest.mock import call, patch, AsyncMock, Mock, AsyncMockMixin -import aioredis +from redis.asyncio import ConnectionPool +from redis.exceptions import ResponseError import pytest from aioredlock.errors import LockError, LockAcquiringError, LockRuntimeError @@ -11,68 +12,62 @@ from aioredlock.sentinel import Sentinel -def callculate_sha1(text): +def calculate_sha1(text): sha1 = hashlib.sha1() sha1.update(text.encode()) digest = sha1.hexdigest() return digest -EVAL_OK = b'OK' -EVAL_ERROR = aioredis.errors.ReplyError('ERROR') -CANCELLED = asyncio.CancelledError('CANCELLED') -CONNECT_ERROR = OSError('ERROR') -RANDOM_ERROR = Exception('FAULT') +EVAL_OK = b"OK" +EVAL_ERROR = ResponseError("ERROR") +CANCELLED = asyncio.CancelledError("CANCELLED") +CONNECT_ERROR = OSError("ERROR") +RANDOM_ERROR = Exception("FAULT") -class FakePool: +@pytest.fixture +async def fake_client() -> FakeClient: + _fake_pool = FakeClient() + return _fake_pool - SET_IF_NOT_EXIST = 'SET_IF_NOT_EXIST' - def __init__(self): +class FakeClient(AsyncMockMixin): + SET_IF_NOT_EXIST = "SET_IF_NOT_EXIST" + def __init__(self): + super().__init__() self.script_cache = {} - - self.evalsha = MagicMock(return_value=asyncio.Future()) - self.evalsha.return_value.set_result(True) - self.get = MagicMock(return_value=asyncio.Future()) 
- self.get.return_value.set_result(False) - self.script_load = MagicMock(side_effect=self._fake_script_load) - self.execute = MagicMock(side_effect=self._fake_execute) - self.close = MagicMock(return_value=asyncio.Future()) - self.close.return_value.set_result(True) - - def __await__(self): - yield - return self - - def __enter__(self): - return self - - def __exit__(self, *args, **kwargs): - pass - - def __call__(self): - return self + self.connection_kwargs = {} + self.evalsha = AsyncMock(return_value=True) + self.get = AsyncMock(return_value=False) + self.script_load = AsyncMock(side_effect=self._fake_script_load) + self.execute_command = AsyncMock(side_effect=self._fake_execute_command) + self.aclose = AsyncMock(return_value=True) + self.release = AsyncMock() def is_fake(self): # Only for development purposes return True - async def _fake_script_load(self, script): - - digest = callculate_sha1(script) + def _fake_script_load(self, script): + digest = calculate_sha1(script) self.script_cache[digest] = script return digest.encode() - async def _fake_execute(self, *args): - cmd = b' '.join(args[:2]) - if cmd == b'SCRIPT LOAD': - return await self._fake_script_load(args[-1]) + def _fake_execute_command(self, *args): + cmd = b" ".join(args[:2]) + if cmd == b"SCRIPT LOAD": + return self._fake_script_load(args[-1]) + def _fake_execute(self, *args): + cmd = b" ".join(args[:2]) + if cmd == b"SCRIPT LOAD": + return self._fake_script_load(args[-1]) -def fake_create_redis_pool(fake_pool): + +def fake_create_redis(fake_client): """ Original Redis pool have magick method __await__ to create exclusive connection. MagicMock sees this method and thinks that Redis pool @@ -80,291 +75,306 @@ def fake_create_redis_pool(fake_pool): To avoit this behavior we are using this constructor with Mock.side_effect instead of Mock.return_value. 
""" - async def create_redis_pool(*args, **kwargs): - return fake_pool - return create_redis_pool + async def create_redis(*args, **kwargs): + return fake_client -class TestInstance: + return create_redis - script_names = ['SET_LOCK_SCRIPT', 'UNSET_LOCK_SCRIPT', 'GET_LOCK_TTL_SCRIPT'] - def test_initialization(self): +class TestInstance: + script_names = ["SET_LOCK_SCRIPT", "UNSET_LOCK_SCRIPT", "GET_LOCK_TTL_SCRIPT"] - instance = Instance(('localhost', 6379)) + def test_initialization(self): + instance = Instance(("localhost", 6379)) - assert instance.connection == ('localhost', 6379) - assert instance._pool is None + assert instance.connection == ("localhost", 6379) + assert instance._client is None assert isinstance(instance._lock, asyncio.Lock) # scripts for name in self.script_names: - assert getattr(instance, '%s_sha1' % name.lower()) is None - - @pytest.mark.parametrize("connection, address, redis_kwargs", [ - (('localhost', 6379), ('localhost', 6379), {}), - ({'host': 'localhost', 'port': 6379, 'db': 0, 'password': 'pass'}, - ('localhost', 6379), {'db': 0, 'password': 'pass'}), - ("redis://host:6379/0?encoding=utf-8", - "redis://host:6379/0?encoding=utf-8", {}) - ]) - @pytest.mark.asyncio - async def test_connect_pool_not_created(self, connection, address, redis_kwargs): - with patch('aioredlock.redis.Instance._create_redis_pool') as \ - create_redis_pool: + assert getattr(instance, "%s_sha1" % name.lower()) is None - fake_pool = FakePool() - create_redis_pool.side_effect = fake_create_redis_pool(fake_pool) + @pytest.mark.parametrize( + "connection, expected_address, expected_kwargs", + [ + ( + ("localhost", 6379), + None, + {"host": "localhost", "port": 6379, "db": 0, "password": None}, + ), + ( + {"host": "localhost", "port": 6379, "db": 0, "password": "pass"}, + None, + {"host": "localhost", "port": 6379, "db": 0, "password": "pass"}, + ), + ( + "redis://host:6379/0?encoding=utf-8", + "redis://host:6379/0?encoding=utf-8", + {}, + ), + ], + ) + 
@pytest.mark.asyncio + async def test_connect_pool_not_created( + self, connection, expected_address, expected_kwargs, fake_client + ): + with patch( + "aioredlock.redis.Instance._create_redis", + AsyncMock(return_value=fake_client), + ) as create_redis_pool: instance = Instance(connection) - assert instance._pool is None + assert instance._client is None pool = await instance.connect() create_redis_pool.assert_called_once_with( - address, **redis_kwargs, - minsize=1, maxsize=100) - assert pool is fake_pool - assert instance._pool is fake_pool + expected_address, **expected_kwargs, max_connections=100 + ) + assert pool is fake_client + assert instance._client is fake_client # scripts assert pool.script_load.call_count == len(self.script_names) for name in self.script_names: - digest = getattr(instance, '%s_sha1' % name.lower()) + digest = getattr(instance, "%s_sha1" % name.lower()) assert digest assert digest in pool.script_cache - await fake_pool.close() + await fake_client.aclose() @pytest.mark.asyncio - async def test_connect_pool_not_created_with_minsize_and_maxsize(self): - connection = {'host': 'localhost', 'port': 6379, 'db': 0, 'password': 'pass', 'minsize': 2, 'maxsize': 5} - address = ('localhost', 6379) - redis_kwargs = {'db': 0, 'password': 'pass'} - with patch('aioredlock.redis.Instance._create_redis_pool') as \ - create_redis_pool: - - fake_pool = FakePool() - create_redis_pool.side_effect = fake_create_redis_pool(fake_pool) + async def test_connect_pool_not_created_with_max_connections(self, fake_client): + connection = { + "host": "localhost", + "port": 6379, + "db": 0, + "password": "pass", + "max_connections": 5, + } + with patch( + "aioredlock.redis.Instance._create_redis", + AsyncMock(return_value=fake_client), + ) as create_redis_pool: instance = Instance(connection) - assert instance._pool is None + assert instance._client is None pool = await instance.connect() - create_redis_pool.assert_called_once_with(address, **redis_kwargs, minsize=2, 
maxsize=5) - assert pool is fake_pool - assert instance._pool is fake_pool - - @pytest.mark.asyncio - async def test_connect_pool_already_created(self): - - with patch('aioredlock.redis.Instance._create_redis_pool') as \ - create_redis_pool: - instance = Instance(('localhost', 6379)) - fake_pool = FakePool() - instance._pool = fake_pool - + create_redis_pool.assert_called_once_with(None, **connection) + assert pool is fake_client + assert instance._client is fake_client + + async def test_connect_pool_already_created(self, fake_client): + with patch( + "aioredlock.redis.Instance._create_redis", + AsyncMock(return_value=fake_client), + ) as create_redis_pool: + instance = Instance(("localhost", 6379)) + fake_client = FakeClient() + instance._client = fake_client pool = await instance.connect() assert not create_redis_pool.called - assert pool is fake_pool + assert pool is fake_client assert pool.script_load.called is True @pytest.mark.asyncio - async def test_connect_pool_aioredis_instance(self): - - def awaiter(self): - yield from [] - - pool = FakePool() - redis_connection = aioredis.Redis(pool) - instance = Instance(redis_connection) - - assert instance._pool is None + async def test_connect_pool_aioredis_instance(self, mocker, fake_client): + pool = AsyncMock(spec=ConnectionPool) + pool.connection_kwargs = { + "host": "localhost", + "port": 6379, + "db": 0, + "password": "secret", + } + mocker.patch("redis.asyncio.ConnectionPool", return_value=pool) + + mocker.patch("redis.asyncio.Redis.from_url", return_value=fake_client) + instance = Instance(pool) + + assert instance._client is None await instance.connect() - assert pool.execute.call_count == len(self.script_names) + assert fake_client.script_load.call_count == len(self.script_names) assert instance.set_lock_script_sha1 is not None assert instance.unset_lock_script_sha1 is not None @pytest.mark.asyncio - async def test_connect_pool_aioredis_instance_with_sentinel(self): - - sentinel = Sentinel(('127.0.0.1', 
26379), master='leader') - pool = FakePool() - redis_connection = aioredis.Redis(pool) - with patch.object(sentinel, 'get_master', return_value=asyncio.Future()) as mock_redis: - if sys.version_info < (3, 8, 0): - mock_redis.return_value.set_result(redis_connection) - else: - mock_redis.return_value = redis_connection + async def test_connect_pool_aioredis_instance_with_sentinel(self, fake_client): + sentinel = Sentinel(("127.0.0.1", 26379), master="leader") + with patch("redis.asyncio.Sentinel.master_for", Mock(return_value=fake_client)): instance = Instance(sentinel) - assert instance._pool is None + assert instance._client is None await instance.connect() - assert pool.execute.call_count == len(self.script_names) + + assert fake_client.script_load.call_count == len(self.script_names) assert instance.set_lock_script_sha1 is not None assert instance.unset_lock_script_sha1 is not None @pytest.fixture - def fake_instance(self): - with patch('aioredlock.redis.Instance._create_redis_pool') as \ - create_redis_pool: - fake_pool = FakePool() - create_redis_pool.side_effect = fake_create_redis_pool(fake_pool) - instance = Instance(('localhost', 6379)) + def fake_instance(self, fake_client): + with patch( + "aioredlock.redis.Instance._create_redis", + AsyncMock(return_value=fake_client), + ) as create_redis: + create_redis.side_effect = fake_create_redis(fake_client) + instance = Instance(("localhost", 6379)) yield instance @pytest.mark.asyncio - async def test_lock(self, fake_instance): + async def test_lock(self, fake_instance: Instance): instance = fake_instance await instance.connect() - pool = instance._pool + redis_client = instance._client - await instance.set_lock('resource', 'lock_id', 10.0) + await instance.set_lock("resource", "lock_id", 10.0) - pool.evalsha.assert_called_once_with( - instance.set_lock_script_sha1, - keys=['resource'], - args=['lock_id', 10000] + redis_client.evalsha.assert_called_once_with( + instance.set_lock_script_sha1, 1, "resource", 
"lock_id", 10000 ) @pytest.mark.asyncio - async def test_get_lock_ttl(self, fake_instance): + async def test_get_lock_ttl(self, fake_instance: Instance): instance = fake_instance await instance.connect() - pool = instance._pool + redis_client = instance._client - await instance.get_lock_ttl('resource', 'lock_id') - pool.evalsha.assert_called_with( - instance.get_lock_ttl_script_sha1, - keys=['resource'], - args=['lock_id'] + await instance.get_lock_ttl("resource", "lock_id") + redis_client.evalsha.assert_called_with( + instance.get_lock_ttl_script_sha1, 1, "resource", "lock_id" ) @pytest.mark.asyncio - async def test_lock_sleep(self, fake_instance, event_loop): + async def test_lock_sleep(self, fake_instance: Instance): + loop = asyncio.get_running_loop() instance = fake_instance async def hold_lock(instance): async with instance._lock: - await asyncio.sleep(.1) - instance._pool = FakePool() + await asyncio.sleep(0.1) + instance._client = FakeClient() - event_loop.create_task(hold_lock(instance)) - await asyncio.sleep(.1) + await loop.create_task(hold_lock(instance)) + await asyncio.sleep(0.1) await instance.connect() - pool = instance._pool + redis_client = instance._client - await instance.set_lock('resource', 'lock_id', 10.0) + await instance.set_lock("resource", "lock_id", 10.0) - pool.evalsha.assert_called_once_with( - instance.set_lock_script_sha1, - keys=['resource'], - args=['lock_id', 10000] + redis_client.evalsha.assert_called_once_with( + instance.set_lock_script_sha1, 1, "resource", "lock_id", 10000 ) - instance._pool = None - await instance.close() - assert pool.close.called is False + instance._client = None + await instance.aclose() + assert redis_client.aclose.called is False @pytest.mark.asyncio @pytest.mark.parametrize( - 'func,args,expected_keys,expected_args', + "func,args,expected_keys,expected_args", ( - ('set_lock', ('resource', 'lock_id', 10.0), ['resource'], ['lock_id', 10000]), - ('unset_lock', ('resource', 'lock_id'), ['resource'], 
['lock_id']), - ('get_lock_ttl', ('resource', 'lock_id'), ['resource'], ['lock_id']), - ) + ( + "set_lock", + ("resource", "lock_id", 10.0), + ["resource"], + ["lock_id", 10000], + ), + ("unset_lock", ("resource", "lock_id"), ["resource"], ["lock_id"]), + ("get_lock_ttl", ("resource", "lock_id"), ["resource"], ["lock_id"]), + ), ) - async def test_lock_without_scripts(self, fake_coro, fake_instance, func, args, expected_keys, expected_args): + async def test_lock_without_scripts( + self, fake_instance: Instance, func, args, expected_keys, expected_args + ): instance = fake_instance await instance.connect() - pool = instance._pool - pool.evalsha.side_effect = [aioredis.errors.ReplyError('NOSCRIPT'), fake_coro(True)] + redis_client = instance._client + redis_client.evalsha.side_effect = [ + ResponseError("NOSCRIPT"), + AsyncMock(return_value=True), + ] await getattr(instance, func)(*args) - assert pool.evalsha.call_count == 2 - assert pool.script_load.call_count == 6 # for 3 scripts. + assert redis_client.evalsha.call_count == 2 + assert redis_client.script_load.call_count == 6 # for 3 scripts. 
- pool.evalsha.assert_called_with( - getattr(instance, '{0}_script_sha1'.format(func)), - keys=expected_keys, - args=expected_args, + redis_client.evalsha.assert_called_with( + getattr(instance, "{0}_script_sha1".format(func)), + 1, + *expected_keys, + *expected_args, ) @pytest.mark.asyncio - async def test_unset_lock(self, fake_instance): + async def test_unset_lock(self, fake_instance: Instance): instance = fake_instance await instance.connect() - pool = instance._pool + redis_client = instance._client - await instance.unset_lock('resource', 'lock_id') + await instance.unset_lock("resource", "lock_id") - pool.evalsha.assert_called_once_with( - instance.unset_lock_script_sha1, - keys=['resource'], - args=['lock_id'] + redis_client.evalsha.assert_called_once_with( + instance.unset_lock_script_sha1, 1, "resource", "lock_id" ) @pytest.mark.asyncio - @pytest.mark.parametrize("get_return_value,locked", [ - (b'lock_identifier', True), - (None, False), - ]) - async def test_is_locked(self, fake_instance, get_return_value, locked): + @pytest.mark.parametrize( + "get_return_value,locked", + [ + (b"lock_identifier", True), + (None, False), + ], + ) + async def test_is_locked(self, fake_instance: Instance, get_return_value, locked): instance = fake_instance await instance.connect() - pool = instance._pool + redis_client = instance._client - pool.get.return_value = asyncio.Future() - pool.get.return_value.set_result(get_return_value) + redis_client.get = AsyncMock(return_value=get_return_value) - res = await instance.is_locked('resource') + res = await instance.is_locked("resource") assert res == locked - pool.get.assert_called_once_with('resource') + redis_client.get.assert_called_once_with("resource") @pytest.fixture def redis_two_connections(): - return [ - {'host': 'localhost', 'port': 6379}, - {'host': '127.0.0.1', 'port': 6378} - ] + return [{"host": "localhost", "port": 6379}, {"host": "127.0.0.1", "port": 6378}] @pytest.fixture def redis_three_connections(): return [ 
- {'host': 'localhost', 'port': 6379}, - {'host': '127.0.0.1', 'port': 6378}, - {'host': '8.8.8.8', 'port': 6377} + {"host": "localhost", "port": 6379}, + {"host": "127.0.0.1", "port": 6378}, + {"host": "8.8.8.8", "port": 6377}, ] @pytest.fixture -def mock_redis_two_instances(redis_two_connections): - pool = FakePool() +def mock_redis_two_instances(redis_two_connections, fake_client): redis = Redis(redis_two_connections) for instance in redis.instances: - instance._pool = pool + instance._client = fake_client - yield redis, pool + return redis @pytest.fixture -def mock_redis_three_instances(redis_three_connections): - pool = FakePool() +def mock_redis_three_instances(redis_three_connections, fake_client): redis = Redis(redis_three_connections) for instance in redis.instances: - instance._pool = pool + instance._client = fake_client - yield redis, pool + return redis class TestRedis: - def test_initialization(self, redis_two_connections): with patch("aioredlock.redis.Instance.__init__") as mock_instance: mock_instance.return_value = None @@ -372,172 +382,175 @@ def test_initialization(self, redis_two_connections): redis = Redis(redis_two_connections) calls = [ - call({'host': 'localhost', 'port': 6379}), - call({'host': '127.0.0.1', 'port': 6378}) + call({"host": "localhost", "port": 6379}), + call({"host": "127.0.0.1", "port": 6378}), ] mock_instance.assert_has_calls(calls) assert len(redis.instances) == 2 - parametrize_methods = pytest.mark.parametrize("method_name, call_args", [ - ('set_lock', {'keys': ['resource'], 'args':['lock_id', 10000]}), - ('unset_lock', {'keys': ['resource'], 'args':['lock_id']}), - ('get_lock_ttl', {'keys': ['resource'], 'args':['lock_id']}), - ]) + parametrize_methods = pytest.mark.parametrize( + "method_name, call_args", + [ + ("set_lock", (1, "resource", "lock_id", 10000)), + ("unset_lock", (1, "resource", "lock_id")), + ("get_lock_ttl", (1, "resource", "lock_id")), + ], + ) @pytest.mark.asyncio @parametrize_methods async def 
test_lock( - self, mock_redis_two_instances, - method_name, call_args + self, mock_redis_two_instances, fake_client, method_name, call_args ): - redis, pool = mock_redis_two_instances + redis = mock_redis_two_instances method = getattr(redis, method_name) - await method('resource', 'lock_id') + await method("resource", "lock_id") - script_sha1 = getattr(redis.instances[0], '%s_script_sha1' % method_name) + script_sha1 = getattr(redis.instances[0], "%s_script_sha1" % method_name) - calls = [call(script_sha1, **call_args)] * 2 - pool.evalsha.assert_has_calls(calls) + calls = [call(script_sha1, *call_args)] * 2 + fake_client.evalsha.assert_has_calls(calls, any_order=True) @pytest.mark.asyncio - @pytest.mark.parametrize("get_return_value,locked", [ - (b'lock_identifier', True), - (None, False), - ]) - async def test_is_locked(self, mock_redis_two_instances, get_return_value, locked): - redis, pool = mock_redis_two_instances + @pytest.mark.parametrize( + "get_return_value,locked", + [ + (b"lock_identifier", True), + (None, False), + ], + ) + async def test_is_locked( + self, mock_redis_two_instances, fake_client, get_return_value, locked + ): + redis = mock_redis_two_instances - pool.get.return_value = asyncio.Future() - pool.get.return_value.set_result(get_return_value) + fake_client.get = AsyncMock(return_value=get_return_value) - res = await redis.is_locked('resource') + res = await redis.is_locked("resource") - calls = [call('resource')] * 2 - pool.get.assert_has_calls(calls) + calls = [call("resource")] * 2 + fake_client.get.assert_has_calls(calls) assert res == locked @pytest.mark.asyncio @parametrize_methods async def test_lock_one_of_two_instances_failed( - self, fake_coro, mock_redis_two_instances, - method_name, call_args + self, mock_redis_two_instances, fake_client, method_name, call_args ): - redis, pool = mock_redis_two_instances - pool.evalsha = MagicMock(side_effect=[EVAL_ERROR, EVAL_OK]) + redis = mock_redis_two_instances + fake_client.evalsha = 
AsyncMock(side_effect=[EVAL_ERROR, EVAL_OK]) method = getattr(redis, method_name) with pytest.raises(LockError): - await method('resource', 'lock_id') + await method("resource", "lock_id") - script_sha1 = getattr(redis.instances[0], '%s_script_sha1' % method_name) + script_sha1 = getattr(redis.instances[0], "%s_script_sha1" % method_name) - calls = [call(script_sha1, **call_args)] * 2 - pool.evalsha.assert_has_calls(calls) + calls = [call(script_sha1, *call_args)] * 2 + fake_client.evalsha.assert_has_calls(calls) @pytest.mark.asyncio - @pytest.mark.parametrize("redis_result, success", [ - ([EVAL_OK, EVAL_OK, EVAL_OK], True), - ([EVAL_OK, EVAL_OK, EVAL_ERROR], True), - ([EVAL_OK, EVAL_ERROR, CONNECT_ERROR], False), - ([EVAL_ERROR, EVAL_ERROR, CONNECT_ERROR], False), - ([EVAL_ERROR, CONNECT_ERROR, RANDOM_ERROR], False), - ([CANCELLED, CANCELLED, CANCELLED], False), - ]) + @pytest.mark.parametrize( + "redis_result, success", + [ + ([EVAL_OK, EVAL_OK, EVAL_OK], True), + ([EVAL_OK, EVAL_OK, EVAL_ERROR], True), + ([EVAL_OK, EVAL_ERROR, CONNECT_ERROR], False), + ([EVAL_ERROR, EVAL_ERROR, CONNECT_ERROR], False), + ([EVAL_ERROR, CONNECT_ERROR, RANDOM_ERROR], False), + ([CANCELLED, CANCELLED, CANCELLED], False), + ], + ) @parametrize_methods async def test_three_instances_combination( - self, - fake_coro, - mock_redis_three_instances, - redis_result, - success, - method_name, call_args, + self, + fake_client, + mock_redis_three_instances, + redis_result, + success, + method_name, + call_args, ): - redis, pool = mock_redis_three_instances - redis_result = [fake_coro(result) if isinstance(result, bytes) else result for result in redis_result] - pool.evalsha = MagicMock(side_effect=redis_result) + redis = mock_redis_three_instances + fake_client.evalsha = AsyncMock(side_effect=redis_result) method = getattr(redis, method_name) if success: - await method('resource', 'lock_id') + await method("resource", "lock_id") else: with pytest.raises(LockError) as exc_info: - await 
method('resource', 'lock_id') - assert hasattr(exc_info.value, '__cause__') + await method("resource", "lock_id") + assert hasattr(exc_info.value, "__cause__") assert isinstance(exc_info.value.__cause__, BaseException) - script_sha1 = getattr(redis.instances[0], - '%s_script_sha1' % method_name) + script_sha1 = getattr(redis.instances[0], "%s_script_sha1" % method_name) - calls = [call(script_sha1, **call_args)] * 3 - pool.evalsha.assert_has_calls(calls) + calls = [call(script_sha1, *call_args)] * 3 + fake_client.evalsha.assert_has_calls(calls) @pytest.mark.asyncio - @pytest.mark.parametrize("redis_result, error", [ - ([EVAL_OK, EVAL_ERROR, CONNECT_ERROR], LockRuntimeError), - ([EVAL_ERROR, EVAL_ERROR, CONNECT_ERROR], LockRuntimeError), - ([EVAL_ERROR, CONNECT_ERROR, RANDOM_ERROR], LockRuntimeError), - ([EVAL_ERROR, EVAL_ERROR, EVAL_OK], LockAcquiringError), - ([CANCELLED, CANCELLED, CANCELLED], LockError), - ([RANDOM_ERROR, CANCELLED, CANCELLED], LockError), - ]) + @pytest.mark.parametrize( + "redis_result, error", + [ + ([EVAL_OK, EVAL_ERROR, CONNECT_ERROR], LockRuntimeError), + ([EVAL_ERROR, EVAL_ERROR, CONNECT_ERROR], LockRuntimeError), + ([EVAL_ERROR, CONNECT_ERROR, RANDOM_ERROR], LockRuntimeError), + ([EVAL_ERROR, EVAL_ERROR, EVAL_OK], LockAcquiringError), + ([CANCELLED, CANCELLED, CANCELLED], LockError), + ([RANDOM_ERROR, CANCELLED, CANCELLED], LockError), + ], + ) @parametrize_methods async def test_three_instances_combination_errors( - self, - fake_coro, - mock_redis_three_instances, - redis_result, - error, - method_name, call_args, + self, + fake_client, + mock_redis_three_instances, + redis_result, + error, + method_name, + call_args, ): - redis, pool = mock_redis_three_instances - redis_result = [fake_coro(result) if isinstance(result, bytes) else result for result in redis_result] - pool.evalsha = MagicMock(side_effect=redis_result) + redis = mock_redis_three_instances + fake_client.evalsha = AsyncMock(side_effect=redis_result) method = getattr(redis, 
method_name) with pytest.raises(error) as exc_info: - await method('resource', 'lock_id') + await method("resource", "lock_id") - assert hasattr(exc_info.value, '__cause__') + assert hasattr(exc_info.value, "__cause__") assert isinstance(exc_info.value.__cause__, BaseException) - script_sha1 = getattr(redis.instances[0], - '%s_script_sha1' % method_name) + script_sha1 = getattr(redis.instances[0], "%s_script_sha1" % method_name) - calls = [call(script_sha1, **call_args)] * 3 - pool.evalsha.assert_has_calls(calls) + calls = [call(script_sha1, *call_args)] * 3 + fake_client.evalsha.assert_has_calls(calls) @pytest.mark.asyncio - async def test_clear_connections(self, mock_redis_two_instances): - redis, pool = mock_redis_two_instances - pool.close = MagicMock() - pool.wait_closed = MagicMock(return_value=asyncio.Future()) - pool.wait_closed.return_value.set_result(True) + async def test_clear_connections(self, mock_redis_two_instances, fake_client): + redis = mock_redis_two_instances + fake_client.aclose = AsyncMock(return_value=True) await redis.clear_connections() - pool.close.assert_has_calls([call(), call()]) - pool.wait_closed.assert_has_calls([call(), call()]) - - pool.close = MagicMock() - pool.wait_closed = MagicMock(return_value=asyncio.Future()) + fake_client.aclose.assert_has_calls([call(), call()]) + fake_client.aclose.reset_mock() await redis.clear_connections() - assert pool.close.called is False + assert fake_client.aclose.called is False @pytest.mark.asyncio - async def test_get_lock(self, mock_redis_two_instances, ): - redis, pool = mock_redis_two_instances + async def test_get_lock(self, mock_redis_two_instances, fake_client): + redis = mock_redis_two_instances - await redis.get_lock_ttl('resource', 'lock_id') + await redis.get_lock_ttl("resource", "lock_id") - script_sha1 = getattr(redis.instances[0], 'get_lock_ttl_script_sha1') + script_sha1 = getattr(redis.instances[0], "get_lock_ttl_script_sha1") - calls = [call(script_sha1, keys=['resource'], 
args=['lock_id'])] - pool.evalsha.assert_has_calls(calls) - # assert 0 + calls = [call(script_sha1, 1, "resource", "lock_id")] + fake_client.evalsha.assert_has_calls(calls) diff --git a/tests/ut/test_sentinel.py b/tests/ut/test_sentinel.py index 9fdaae1..dd9b92f 100644 --- a/tests/ut/test_sentinel.py +++ b/tests/ut/test_sentinel.py @@ -1,10 +1,6 @@ -import asyncio -import contextlib import ssl -import sys from unittest import mock -import aioredlock.sentinel from aioredlock.sentinel import Sentinel from aioredlock.sentinel import SentinelConfigError @@ -13,154 +9,146 @@ pytestmark = [pytest.mark.asyncio] -@contextlib.contextmanager -def mock_aioredis_sentinel(): - if sys.version_info < (3, 8, 0): - mock_obj = mock.MagicMock() - mock_obj.master_for.return_value = asyncio.Future() - mock_obj.master_for.return_value.set_result(True) - else: - mock_obj = mock.AsyncMock() - mock_obj.master_for.return_value = True - with mock.patch.object(aioredlock.sentinel.aioredis.sentinel, 'create_sentinel') as mock_sentinel: - if sys.version_info < (3, 8, 0): - mock_sentinel.return_value = asyncio.Future() - mock_sentinel.return_value.set_result(mock_obj) - else: - mock_sentinel.return_value = mock_obj - yield mock_sentinel +@pytest.fixture +def mocked_redis_sentinel(mocker): + mock_sentinel = mocker.patch( + "aioredlock.sentinel.RedisSentinel", + mock.Mock( + return_value=mock.AsyncMock(master_for=mock.AsyncMock(return_value=True)) + ), + ) + return mock_sentinel @pytest.mark.parametrize( - 'connection,kwargs,expected_kwargs,expected_master,with_ssl', ( + "connection,kwargs,expected_kwargs,expected_master,with_ssl", + ( ( - {'host': '127.0.0.1', 'port': 26379, 'master': 'leader'}, + {"host": "127.0.0.1", "port": 26379, "master": "leader"}, {}, - {'sentinels': [('127.0.0.1', 26379)], 'minsize': 1, 'maxsize': 100}, - 'leader', + {"sentinels": [("127.0.0.1", 26379)], "max_connections": 100}, + "leader", {}, ), ( - 
'redis://:password@localhost:12345/0?master=whatever&encoding=utf-8&minsize=2&maxsize=5', + "redis://:password@localhost:12345/0?master=whatever&encoding=utf-8", {}, { - 'sentinels': [('localhost', 12345)], - 'db': 0, - 'encoding': 'utf-8', - 'password': 'password', - 'minsize': 2, - 'maxsize': 5, + "sentinels": [("localhost", 12345)], + "db": 0, + "encoding": "utf-8", + "password": "password", + "max_connections": 100, }, - 'whatever', + "whatever", {}, ), ( - 'redis://:password@localhost:12345/0?master=whatever&encoding=utf-8', - {'master': 'everything', 'password': 'newpass', 'db': 3}, + "redis://:password@localhost:12345/0?master=whatever&encoding=utf-8", + {"master": "everything", "password": "newpass", "db": 3}, { - 'sentinels': [('localhost', 12345)], - 'db': 3, - 'encoding': 'utf-8', - 'password': 'newpass', - 'minsize': 1, - 'maxsize': 100, + "sentinels": [("localhost", 12345)], + "db": 3, + "encoding": "utf-8", + "password": "newpass", + "max_connections": 100, }, - 'everything', + "everything", {}, ), ( - 'rediss://:password@localhost:12345/2?master=whatever&encoding=utf-8', + "rediss://:password@localhost:12345/2?master=whatever&encoding=utf-8", {}, { - 'sentinels': [('localhost', 12345)], - 'db': 2, - 'encoding': 'utf-8', - 'password': 'password', - 'minsize': 1, - 'maxsize': 100, + "sentinels": [("localhost", 12345)], + "db": 2, + "encoding": "utf-8", + "password": "password", + "max_connections": 100, }, - 'whatever', - {'verify_mode': ssl.CERT_REQUIRED, 'check_hostname': True}, + "whatever", + {"verify_mode": ssl.CERT_REQUIRED, "check_hostname": True}, ), ( - 'rediss://:password@localhost:12345/2?master=whatever&encoding=utf-8&ssl_cert_reqs=CERT_NONE', + "rediss://:password@localhost:12345/2?master=whatever&encoding=utf-8&ssl_cert_reqs=CERT_NONE", {}, { - 'sentinels': [('localhost', 12345)], - 'db': 2, - 'encoding': 'utf-8', - 'password': 'password', - 'minsize': 1, - 'maxsize': 100, + "sentinels": [("localhost", 12345)], + "db": 2, + "encoding": 
"utf-8", + "password": "password", + "max_connections": 100, }, - 'whatever', - {'verify_mode': ssl.CERT_NONE, 'check_hostname': False}, + "whatever", + {"verify_mode": ssl.CERT_NONE, "check_hostname": False}, ), ( - 'rediss://localhost:12345/2?master=whatever&encoding=utf-8&ssl_cert_reqs=CERT_OPTIONAL', + "rediss://localhost:12345/2?master=whatever&encoding=utf-8&ssl_cert_reqs=CERT_OPTIONAL", {}, { - 'sentinels': [('localhost', 12345)], - 'db': 2, - 'encoding': 'utf-8', - 'password': None, - 'minsize': 1, - 'maxsize': 100, + "sentinels": [("localhost", 12345)], + "db": 2, + "encoding": "utf-8", + "password": None, + "max_connections": 100, }, - 'whatever', - {'verify_mode': ssl.CERT_OPTIONAL, 'check_hostname': True}, + "whatever", + {"verify_mode": ssl.CERT_OPTIONAL, "check_hostname": True}, ), ( - ('127.0.0.1', 1234), - {'master': 'blah', 'ssl_context': True}, + ("127.0.0.1", 1234), + {"master": "blah", "ssl_context": True}, { - 'sentinels': [('127.0.0.1', 1234)], - 'minsize': 1, - 'maxsize': 100, + "sentinels": [("127.0.0.1", 1234)], + "max_connections": 100, }, - 'blah', + "blah", {}, ), ( - [('127.0.0.1', 1234), ('blah', 4829)], - {'master': 'blah', 'ssl_context': False}, + [("127.0.0.1", 1234), ("blah", 4829)], + {"master": "blah", "ssl_context": False}, { - 'sentinels': [('127.0.0.1', 1234), ('blah', 4829)], - 'minsize': 1, - 'maxsize': 100, - 'ssl': False, + "sentinels": [("127.0.0.1", 1234), ("blah", 4829)], + "max_connections": 100, + "ssl": False, }, - 'blah', + "blah", {}, ), - ) + ), ) -async def test_sentinel(ssl_context, connection, kwargs, expected_kwargs, expected_master, with_ssl): - with mock_aioredis_sentinel() as mock_sentinel: - sentinel = Sentinel(connection, **kwargs) - assert await sentinel.get_master() - assert mock_sentinel.called - if with_ssl or kwargs.get('ssl_context') is True: - expected_kwargs['ssl'] = ssl_context - mock_sentinel.assert_called_with(**expected_kwargs) - if sys.version_info < (3, 8, 0): - result = 
mock_sentinel.return_value.result() - else: - result = mock_sentinel.return_value +async def test_sentinel( + ssl_context, + connection, + kwargs, + expected_kwargs, + expected_master, + with_ssl, + mocked_redis_sentinel, +): + sentinel = Sentinel(connection, **kwargs) + result = await sentinel.get_master() + assert result is True + assert mocked_redis_sentinel.called + if with_ssl or kwargs.get("ssl_context") is True: + expected_kwargs["ssl"] = ssl_context + mocked_redis_sentinel.assert_called_with(**expected_kwargs) + result = mocked_redis_sentinel.return_value assert result.master_for.called result.master_for.assert_called_with(expected_master) if with_ssl: - assert ssl_context.check_hostname is with_ssl['check_hostname'] - assert ssl_context.verify_mode is with_ssl['verify_mode'] + assert ssl_context.check_hostname is with_ssl["check_hostname"] + assert ssl_context.verify_mode is with_ssl["verify_mode"] @pytest.mark.parametrize( - 'connection', + "connection", ( - 'redis://localhost:1234/0', - 'redis://localhost:1234/blah', + "redis://localhost:1234/0", + "redis://localhost:1234/blah", object(), - ) + ), ) async def test_sentinel_config_errors(connection): with pytest.raises(SentinelConfigError): diff --git a/tests/ut/test_utility.py b/tests/ut/test_utility.py index 3dae2cd..01924ac 100644 --- a/tests/ut/test_utility.py +++ b/tests/ut/test_utility.py @@ -15,14 +15,19 @@ def test_cleans_details_with_password(): details = {"foo": "bar", "password": "topsecret"} cleaned = clean_password(details) - assert json.loads(cleaned.replace("'", "\"")) == {'foo': 'bar', 'password': '*******'} + assert json.loads(cleaned.replace("'", '"')) == { + "foo": "bar", + "password": "*******", + } def test_cleans_details_with_password_in_list(): details = [{"foo": "bar", "password": "topsecret"}] cleaned = clean_password(details) - assert json.loads(cleaned.replace("'", "\"")) == [{'foo': 'bar', 'password': '*******'}] + assert json.loads(cleaned.replace("'", '"')) == [ + {"foo": 
"bar", "password": "*******"} + ] def test_ignores_non_dsn_string():