@@ -23,6 +23,54 @@ class SlaveNotFoundError(ConnectionError):
23
23
pass
24
24
25
25
26
+ class AsyncSentinelConnectionPoolProxy :
27
+ def __init__ (
28
+ self ,
29
+ connection_pool ,
30
+ is_master ,
31
+ check_connection ,
32
+ service_name ,
33
+ sentinel_manager ,
34
+ ):
35
+ self .connection_pool_ref = weakref .ref (connection_pool )
36
+ self .is_master = is_master
37
+ self .check_connection = check_connection
38
+ self .service_name = service_name
39
+ self .sentinel_manager = sentinel_manager
40
+ self .reset ()
41
+
42
+ def reset (self ):
43
+ self .master_address = None
44
+ self .slave_rr_counter = None
45
+
46
+ async def get_master_address (self ):
47
+ master_address = await self .sentinel_manager .discover_master (self .service_name )
48
+ if self .is_master and self .master_address != master_address :
49
+ self .master_address = master_address
50
+ # disconnect any idle connections so that they reconnect
51
+ # to the new master the next time that they are used.
52
+ connection_pool = self .connection_pool_ref ()
53
+ if connection_pool is not None :
54
+ await connection_pool .disconnect (inuse_connections = False )
55
+ return master_address
56
+
57
+ async def rotate_slaves (self ) -> AsyncIterator :
58
+ slaves = await self .sentinel_manager .discover_slaves (self .service_name )
59
+ if slaves :
60
+ if self .slave_rr_counter is None :
61
+ self .slave_rr_counter = random .randint (0 , len (slaves ) - 1 )
62
+ for _ in range (len (slaves )):
63
+ self .slave_rr_counter = (self .slave_rr_counter + 1 ) % len (slaves )
64
+ slave = slaves [self .slave_rr_counter ]
65
+ yield slave
66
+ # Fallback to the master connection
67
+ try :
68
+ yield await self .get_master_address ()
69
+ except MasterNotFoundError :
70
+ pass
71
+ raise SlaveNotFoundError (f"No slave found for { self .service_name !r} " )
72
+
73
+
26
74
class SentinelManagedConnection (Connection ):
27
75
def __init__ (self , ** kwargs ):
28
76
self .connection_pool = kwargs .pop ("connection_pool" )
@@ -116,12 +164,17 @@ def __init__(self, service_name, sentinel_manager, **kwargs):
116
164
)
117
165
self .is_master = kwargs .pop ("is_master" , True )
118
166
self .check_connection = kwargs .pop ("check_connection" , False )
167
+ self .proxy = AsyncSentinelConnectionPoolProxy (
168
+ connection_pool = self ,
169
+ is_master = self .is_master ,
170
+ check_connection = self .check_connection ,
171
+ service_name = service_name ,
172
+ sentinel_manager = sentinel_manager ,
173
+ )
119
174
super ().__init__ (** kwargs )
120
- self .connection_kwargs ["connection_pool" ] = weakref .proxy ( self )
175
+ self .connection_kwargs ["connection_pool" ] = self .proxy
121
176
self .service_name = service_name
122
177
self .sentinel_manager = sentinel_manager
123
- self .master_address = None
124
- self .slave_rr_counter = None
125
178
126
179
def __repr__ (self ):
127
180
return (
@@ -131,8 +184,11 @@ def __repr__(self):
131
184
132
185
def reset (self ):
133
186
super ().reset ()
134
- self .master_address = None
135
- self .slave_rr_counter = None
187
+ self .proxy .reset ()
188
+
189
+ @property
190
+ def master_address (self ):
191
+ return self .proxy .master_address
136
192
137
193
def owns_connection (self , connection : Connection ):
138
194
check = not self .is_master or (
@@ -141,31 +197,12 @@ def owns_connection(self, connection: Connection):
141
197
return check and super ().owns_connection (connection )
142
198
143
199
async def get_master_address (self ):
144
- master_address = await self .sentinel_manager .discover_master (self .service_name )
145
- if self .is_master :
146
- if self .master_address != master_address :
147
- self .master_address = master_address
148
- # disconnect any idle connections so that they reconnect
149
- # to the new master the next time that they are used.
150
- await self .disconnect (inuse_connections = False )
151
- return master_address
200
+ return await self .proxy .get_master_address ()
152
201
153
202
async def rotate_slaves (self ) -> AsyncIterator :
154
203
"""Round-robin slave balancer"""
155
- slaves = await self .sentinel_manager .discover_slaves (self .service_name )
156
- if slaves :
157
- if self .slave_rr_counter is None :
158
- self .slave_rr_counter = random .randint (0 , len (slaves ) - 1 )
159
- for _ in range (len (slaves )):
160
- self .slave_rr_counter = (self .slave_rr_counter + 1 ) % len (slaves )
161
- slave = slaves [self .slave_rr_counter ]
162
- yield slave
163
- # Fallback to the master connection
164
- try :
165
- yield await self .get_master_address ()
166
- except MasterNotFoundError :
167
- pass
168
- raise SlaveNotFoundError (f"No slave found for { self .service_name !r} " )
204
+ async for x in self .proxy .rotate_slaves ():
205
+ yield x
169
206
170
207
171
208
class Sentinel (AsyncSentinelCommands ):
0 commit comments