Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Thread-safe implementation of the sentinel connection pool #4

Open
wants to merge 8 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,12 @@ Accessing redis-py's Sentinel instance
Change log
----------

v2.1.0
~~~~~~

* Thread-safe implementation of the sentinel connection pool, so only one pool per process is now used.
* Added `disconnect()` method for resetting the connection pool

v2.0.1
~~~~~~

Expand Down
250 changes: 221 additions & 29 deletions flask_redis_sentinel.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,17 +14,181 @@

import six
import inspect
import random
import threading
import logging
import weakref
import redis
import redis.sentinel
import redis_sentinel_url
from redis._compat import nativestr
from flask import current_app
from werkzeug.local import Local, LocalProxy
from redis.exceptions import ConnectionError, ReadOnlyError
from werkzeug.local import LocalProxy
from werkzeug.utils import import_string

logger = logging.getLogger(__name__)


_EXTENSION_KEY = 'redissentinel'


class SentinelManagedConnection(redis.Connection):
def __init__(self, **kwargs):
self.connection_pool = kwargs.pop('connection_pool')
self.to_be_disconnected = False
super(SentinelManagedConnection, self).__init__(**kwargs)

def __repr__(self):
pool = self.connection_pool
s = '%s<service=%s%%s>' % (type(self).__name__, pool.service_name)
if self.host:
host_info = ',host=%s,port=%s' % (self.host, self.port)
s = s % host_info
return s

def connect_to(self, address):
self.host, self.port = address
super(SentinelManagedConnection, self).connect()
if self.connection_pool.check_connection:
self.send_command('PING')
if nativestr(self.read_response()) != 'PONG':
raise ConnectionError('PING failed')

def connect(self):
if self._sock:
return # already connected
if self.connection_pool.is_master:
self.connect_to(self.connection_pool.get_master_address())
else:
for slave in self.connection_pool.rotate_slaves():
try:
return self.connect_to(slave)
except ConnectionError:
continue
raise SlaveNotFoundError # Never be here

def read_response(self):
try:
return super(SentinelManagedConnection, self).read_response()
except ReadOnlyError:
if self.connection_pool.is_master:
# When talking to a master, a ReadOnlyError when likely
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think 'when' after ReadOnlyError should be removed

# indicates that the previous master that we're still connected
# to has been demoted to a slave and there's a new master.
self.to_be_disconnected = True
raise ConnectionError('The previous master is now a slave')
raise


class SentinelConnectionPool(redis.ConnectionPool):
"""
Sentinel backed connection pool.

If ``check_connection`` flag is set to True, SentinelManagedConnection
sends a PING command right after establishing the connection.
"""

def __init__(self, service_name, sentinel_manager, **kwargs):
kwargs['connection_class'] = kwargs.get(
'connection_class', SentinelManagedConnection)
self.is_master = kwargs.pop('is_master', True)
self.check_connection = kwargs.pop('check_connection', False)
super(SentinelConnectionPool, self).__init__(**kwargs)
self.connection_kwargs['connection_pool'] = weakref.proxy(self)
self.service_name = service_name
self.sentinel_manager = sentinel_manager

def __repr__(self):
return "%s<service=%s(%s)" % (
type(self).__name__,
self.service_name,
self.is_master and 'master' or 'slave',
)

def reset(self):
super(SentinelConnectionPool, self).reset()
self.master_address = None
self.slave_rr_counter = None

def get_master_address(self):
"""Get the address of the current master"""
master_address = self.sentinel_manager.discover_master(
self.service_name)
if self.is_master:
if master_address != self.master_address:
self.master_address = master_address
return master_address

def rotate_slaves(self):
"Round-robin slave balancer"
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

triple quotes should be used for doc string (on multiple methods)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was copying the surrounding and didn't notice this. 🤦‍♂️

slaves = self.sentinel_manager.discover_slaves(self.service_name)
if slaves:
if self.slave_rr_counter is None:
self.slave_rr_counter = random.randint(0, len(slaves) - 1)
for _ in xrange(len(slaves)):
self.slave_rr_counter = (
self.slave_rr_counter + 1) % len(slaves)
slave = slaves[self.slave_rr_counter]
yield slave
# Fallback to the master connection
try:
yield self.get_master_address()
except MasterNotFoundError:
pass
raise SlaveNotFoundError('No slave found for %r' % (self.service_name))
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I cant find MasterNotFoundError and SlaveNotFoundError definied anywhere

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also redundant parentheses


def _check_connection(self, connection):
if connection.to_be_disconnected:
connection.disconnect()
self.get_master_address()
return False
if self.is_master:
if self.master_address != (connection.host, connection.port):
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am not sure here but, is it intentional to use self.master_address instead of self.get_master_address()? Maybe self.master_address should also be a private?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, that's intentional. self.get_master_address() connects to redis sentinel and resolves the current master, self.master_address is a cache from the last call. The resolution is not something that should happen every single time you take/return a connection from the pool.

connection.disconnect()
return False
return True

def get_connection(self, command_name, *keys, **options):
"Get a connection from the pool"
self._checkpid()
while True:
try:
connection = self._available_connections.pop()
except IndexError:
connection = self.make_connection()
else:
if not self._check_connection(connection):
continue
self._in_use_connections.add(connection)
return connection

def release(self, connection):
"Releases the connection back to the pool"
self._checkpid()
if connection.pid != self.pid:
return
self._in_use_connections.remove(connection)
if not self._check_connection(connection):
return
self._available_connections.append(connection)


class Sentinel(redis.sentinel.Sentinel):

def master_for(self, service_name, redis_class=redis.StrictRedis,
connection_pool_class=SentinelConnectionPool, **kwargs):
return super(Sentinel, self).master_for(
service_name, redis_class=redis_class,
connection_pool_class=connection_pool_class, **kwargs)

def slave_for(self, service_name, redis_class=redis.StrictRedis,
connection_pool_class=SentinelConnectionPool, **kwargs):
return super(Sentinel, self).slave_for(
service_name, redis_class=redis_class,
connection_pool_class=connection_pool_class, **kwargs)


class RedisSentinelInstance(object):

def __init__(self, url, client_class, client_options, sentinel_class, sentinel_options):
Expand All @@ -33,24 +197,34 @@ def __init__(self, url, client_class, client_options, sentinel_class, sentinel_o
self.client_options = client_options
self.sentinel_class = sentinel_class
self.sentinel_options = sentinel_options
self.local = Local()
self.connection = None
self.master_connections = {}
self.slave_connections = {}
self._connect_lock = threading.Lock()
self._connect()
if self.local.connection[0] is None:
# if there is no sentinel, we don't need to use thread-local storage
self.connection = self.local.connection
self.local = self

def _connect(self):
try:
return self.local.connection
except AttributeError:
with self._connect_lock:
if self.connection is not None:
return self.connection

conn = redis_sentinel_url.connect(
self.url,
sentinel_class=self.sentinel_class, sentinel_options=self.sentinel_options,
client_class=self.client_class, client_options=self.client_options)
self.local.connection = conn
self.connection = conn
return conn

def _iter_connections(self):
if self.connection is not None:
for conn in self.connection:
if conn is not None:
yield conn
for conn in six.itervalues(self.master_connections):
yield conn
for conn in six.itervalues(self.slave_connections):
yield conn

@property
def sentinel(self):
return self._connect()[0]
Expand All @@ -60,38 +234,53 @@ def default_connection(self):
return self._connect()[1]

def master_for(self, service_name, **kwargs):
try:
return self.local.master_connections[service_name]
except AttributeError:
self.local.master_connections = {}
except KeyError:
pass
with self._connect_lock:
try:
return self.master_connections[service_name]
except KeyError:
pass

sentinel = self.sentinel
if sentinel is None:
msg = 'Cannot get master {} using non-sentinel configuration'
raise RuntimeError(msg.format(service_name))

conn = sentinel.master_for(service_name, redis_class=self.client_class, **kwargs)
self.local.master_connections[service_name] = conn
return conn
with self._connect_lock:
try:
return self.master_connections[service_name]
except KeyError:
pass

conn = sentinel.master_for(service_name, redis_class=self.client_class, **kwargs)
self.master_connections[service_name] = conn
return conn

def slave_for(self, service_name, **kwargs):
try:
return self.local.slave_connections[service_name]
except AttributeError:
self.local.slave_connections = {}
except KeyError:
pass
with self._connect_lock:
try:
return self.slave_connections[service_name]
except KeyError:
pass

sentinel = self.sentinel
if sentinel is None:
msg = 'Cannot get slave {} using non-sentinel configuration'
raise RuntimeError(msg.format(service_name))

conn = sentinel.slave_for(service_name, redis_class=self.client_class, **kwargs)
self.local.slave_connections[service_name] = conn
return conn
with self._connect_lock:
try:
return self.slave_connections[service_name]
except KeyError:
pass

conn = sentinel.slave_for(service_name, redis_class=self.client_class, **kwargs)
self.slave_connections[service_name] = conn
return conn

def disconnect(self):
with self._connect_lock:
for conn in self._iter_connections():
conn.connection_pool.disconnect()


class RedisSentinel(object):
Expand Down Expand Up @@ -127,7 +316,7 @@ def init_app(self, app, config_prefix=None, client_class=None, sentinel_class=No
client_class = self._resolve_class(
config, 'CLASS', 'client_class', client_class, redis.StrictRedis)
sentinel_class = self._resolve_class(
config, 'SENTINEL_CLASS', 'sentinel_class', sentinel_class, redis.sentinel.Sentinel)
config, 'SENTINEL_CLASS', 'sentinel_class', sentinel_class, Sentinel)

url = config.pop('URL')
client_options = self._config_from_variables(config, client_class)
Expand Down Expand Up @@ -176,5 +365,8 @@ def master_for(self, service_name, **kwargs):
def slave_for(self, service_name, **kwargs):
return LocalProxy(lambda: self.get_instance().slave_for(service_name, **kwargs))

def disconnect(self):
return self.get_instance().disconnect()


SentinelExtension = RedisSentinel # for backwards-compatibility
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ def read(fname):
setup(
name='Flask-Redis-Sentinel',
py_modules=['flask_redis_sentinel'],
version='2.0.1',
version='2.1.0',
install_requires=['Flask>=0.10.1', 'redis>=2.10.3', 'redis_sentinel_url>=1.0.0,<2.0.0', 'six'],
description='Redis-Sentinel integration for Flask',
long_description=read('README.rst'),
Expand Down
4 changes: 2 additions & 2 deletions test_flask_redis_sentinel.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,8 +302,8 @@ def test_sentinel_threads(self):
sentinel.init_app(self.app)

connections = self._check_threads(sentinel)
self.assertIsNot(connections['from_another_thread'], connections['from_main_thread'])
self.assertIsNot(connections['from_another_thread'], connections['from_main_thread_later'])
self.assertIs(connections['from_another_thread'], connections['from_main_thread'])
self.assertIs(connections['from_another_thread'], connections['from_main_thread_later'])
self.assertIs(connections['from_main_thread'], connections['from_main_thread_later'])

def test_redis_threads(self):
Expand Down