diff --git a/redis/connection.py b/redis/connection.py index f745ecc1d5..5e9edb4e70 100644 --- a/redis/connection.py +++ b/redis/connection.py @@ -356,87 +356,154 @@ def on_connect(self): ) auth_args = cred_provider.get_credentials() - # if resp version is specified and we have auth args, - # we need to send them via HELLO - if auth_args and self.protocol not in [2, "2"]: - if isinstance(self._parser, _RESP2Parser): - self.set_parser(_RESP3Parser) - # update cluster exception classes - self._parser.EXCEPTION_CLASSES = parser.EXCEPTION_CLASSES - self._parser.on_connect(self) - if len(auth_args) == 1: - auth_args = ["default", auth_args[0]] - self.send_command("HELLO", self.protocol, "AUTH", *auth_args) - response = self.read_response() - # if response.get(b"proto") != self.protocol and response.get( - # "proto" - # ) != self.protocol: - # raise ConnectionError("Invalid RESP version") - elif auth_args: - # avoid checking health here -- PING will fail if we try - # to check the health prior to the AUTH - self.send_command("AUTH", *auth_args, check_health=False) + # try to send HELLO command (for Redis 6.0 and above) + try: + # if resp version is specified and we have auth args, + # we need to send them via HELLO + if auth_args and self.protocol not in [2, "2"]: + if isinstance(self._parser, _RESP2Parser): + self.set_parser(_RESP3Parser) + # update cluster exception classes + self._parser.EXCEPTION_CLASSES = parser.EXCEPTION_CLASSES + self._parser.on_connect(self) + if len(auth_args) == 1: + auth_args = ["default", auth_args[0]] + self.send_command("HELLO", self.protocol, "AUTH", *auth_args) + else: + self.send_command("HELLO", self.protocol) - try: - auth_response = self.read_response() - except AuthenticationWrongNumberOfArgsError: - # a username and password were specified but the Redis - # server seems to be < 6.0.0 which expects a single password - # arg. retry auth with just the password. - # https://github.com/andymccurdy/redis-py/issues/1274 - self.send_command("AUTH", auth_args[-1], check_health=False) - auth_response = self.read_response() - - if str_if_bytes(auth_response) != "OK": - raise AuthenticationError("Invalid Username or Password") - - # if resp version is specified, switch to it - elif self.protocol not in [2, "2"]: - if isinstance(self._parser, _RESP2Parser): - self.set_parser(_RESP3Parser) - # update cluster exception classes - self._parser.EXCEPTION_CLASSES = parser.EXCEPTION_CLASSES - self._parser.on_connect(self) - self.send_command("HELLO", self.protocol) - response = self.read_response() - if ( - response.get(b"proto") != self.protocol - and response.get("proto") != self.protocol - ): + self.read_response() + + except Exception as e: + if str(e) == "Invalid RESP version": raise ConnectionError("Invalid RESP version") + # fall back to AUTH command (for Redis versions less than 6.0) + else: + # avoid checking health here -- PING will fail if we try + # to check the health prior to the AUTH + if auth_args: + # check if only password is provided and RESP version < 6 + if not self.username and self.password and self.protocol in [2, "2"]: + self.send_command("AUTH", self.password, check_health=False) + else: + self.send_command("AUTH", *auth_args, check_health=False) + + # start a transaction block with MULTI + try: + self.send_command('MULTI') + self.read_response() - # if a client_name is given, set it - if self.client_name: - self.send_command("CLIENT", "SETNAME", self.client_name) - if str_if_bytes(self.read_response()) != "OK": - raise ConnectionError("Error setting client name") + # if a client_name is given, set it + if self.client_name: + self.send_command("CLIENT", "SETNAME", self.client_name) - try: # set the library name and version if self.lib_name: self.send_command("CLIENT", "SETINFO", "LIB-NAME", self.lib_name) - self.read_response() - except ResponseError: - pass - - try: if self.lib_version: self.send_command("CLIENT", "SETINFO", "LIB-VER", self.lib_version) - self.read_response() - except ResponseError: - pass - # if a database is specified, switch to it - if self.db: - self.send_command("SELECT", self.db) - if str_if_bytes(self.read_response()) != "OK": - raise ConnectionError("Invalid Database") + # if a database is specified, switch to it + if self.db: + self.send_command("SELECT", self.db) + + # if client caching is enabled, start tracking + if self.client_cache: + self.send_command("CLIENT", "TRACKING", "ON") + + # execute the MULTI block + self.send_command('EXEC') + responses = self._read_exec_responses() + self._handle_responses(responses, auth_args) + except AuthenticationError as e: + if str(e) == "Invalid Username or Password": + raise AuthenticationError("Invalid Username or Password") from e + except Exception: + raise ConnectionError("Error during EXEC handling") + + def _read_exec_responses(self): + # read the response for EXEC which should be a list + response = self.read_response() + if response == b'OK': + # EXEC did not execute correctly, likely due to previous error + raise ConnectionError("EXEC command did not execute correctly") + while response == b'QUEUED': + response = self.read_response() + if not isinstance(response, list): + raise ConnectionError(f"EXEC command did not return a list: {response}") + return response - # if client caching is enabled, start tracking - if self.client_cache: - self.send_command("CLIENT", "TRACKING", "ON") - self.read_response() - self._parser.set_invalidation_push_handler(self._cache_invalidation_process) + def _handle_responses(self, responses, auth_args): + if not isinstance(responses, list): + raise ConnectionError(f"EXEC command did not return a list: {responses}") + + response_iter = iter(responses) + + try: + # handle HELLO + AUTH + if auth_args and self.protocol not in [2, "2"]: + response = next(response_iter, None) + if isinstance(response, dict) and ( + response.get(b"proto") != self.protocol and response.get("proto") != self.protocol): + raise ConnectionError("Invalid RESP version") + + response = next(response_iter, None) + if isinstance(response, bytes) and str_if_bytes(response) != "OK": + raise AuthenticationError("Invalid Username or Password") + elif auth_args: + response = next(response_iter, None) + if isinstance(response, bytes) and str_if_bytes(response) != "OK": + try: + # a username and password were specified but the Redis + # server seems to be < 6.0.0 which expects a single password + # arg. retry auth with just the password. + # https://github.com/andymccurdy/redis-py/issues/1274 + self.send_command("AUTH", auth_args[-1], check_health=False) + auth_response = self.read_response() + if isinstance(auth_response, bytes) and str_if_bytes( + auth_response) != "OK": + raise AuthenticationError("Invalid Username or Password") + # add the retry response to the responses list for further processing + responses = [auth_response] + list(response_iter) + response_iter = iter(responses) + except AuthenticationWrongNumberOfArgsError: + raise AuthenticationError("Invalid Username or Password") + + # handle CLIENT SETNAME + if self.client_name: + response = next(response_iter, None) + if isinstance(response, bytes) and str_if_bytes(response) != "OK": + raise ConnectionError("Error setting client name") + + # handle CLIENT SETINFO LIB-NAME + if self.lib_name: + response = next(response_iter, None) + if isinstance(response, bytes) and str_if_bytes(response) != "OK": + raise ConnectionError("Error setting client library name") + + # handle CLIENT SETINFO LIB-VER + if self.lib_version: + response = next(response_iter, None) + if isinstance(response, bytes) and str_if_bytes(response) != "OK": + raise ConnectionError("Error setting client library version") + + # handle SELECT + if self.db: + response = next(response_iter, None) + if isinstance(response, bytes) and str_if_bytes(response) != "OK": + raise ConnectionError("Invalid Database") + + # handle CLIENT TRACKING ON + if self.client_cache: + response = next(response_iter, None) + if isinstance(response, bytes) and str_if_bytes(response) != "OK": + raise ConnectionError("Error enabling client tracking") + self._parser.set_invalidation_push_handler( + self._cache_invalidation_process) + except (AuthenticationError, ConnectionError): + raise + except Exception as e: + raise ConnectionError("Error during response handling") from e def disconnect(self, *args): "Disconnects from the Redis server" diff --git a/tests/test_connection.py b/tests/test_connection.py index bff249559e..a724c2e5b9 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -1,8 +1,9 @@ +import itertools import socket import types +from unittest import TestCase from unittest import mock -from unittest.mock import patch - +from unittest.mock import patch, MagicMock import pytest import redis from redis import ConnectionPool, Redis @@ -13,6 +14,8 @@ SSLConnection, UnixDomainSocketConnection, parse_url, + UsernamePasswordCredentialProvider, + AuthenticationError ) from redis.exceptions import ConnectionError, InvalidResponse, TimeoutError from redis.retry import Retry @@ -55,7 +58,7 @@ def inner(): # assert mod.get('fookey') == d -class TestConnection: +class TestConnection(TestCase): def test_disconnect(self): conn = Connection() mock_sock = mock.Mock() @@ -131,6 +134,189 @@ def test_connect_timeout_error_without_retry(self): assert str(e.value) == "Timeout connecting to server" self.clear(conn) + @patch.object(Connection, 'send_command') + @patch.object(Connection, 'read_response') + def test_on_connect(self, mock_read_response, mock_send_command): + """Test that the on_connect function sends the correct commands""" + conn = Connection() + + conn._parser = MagicMock() + conn._parser.on_connect.return_value = None + conn.credential_provider = None + conn.username = "myuser" + conn.password = "password" + conn.protocol = 3 + conn.client_name = "test-client" + conn.lib_name = "test" + conn.lib_version = "1234" + conn.db = 0 + conn.client_cache = True + + # command response + mock_read_response.side_effect = itertools.cycle([ + b'QUEUED', # MULTI + b'QUEUED', # HELLO + b'QUEUED', # AUTH + b'QUEUED', # CLIENT SETNAME + b'QUEUED', # CLIENT SETINFO LIB-NAME + b'QUEUED', # CLIENT SETINFO LIB-VER + b'QUEUED', # SELECT + b'QUEUED', # CLIENT TRACKING ON + [ # EXEC response list + {"proto": 3, "version": "6"}, + b'OK', + b'OK', + b'OK', + b'OK', + b'OK', + b'OK', + b'OK' + ] + ]) + + conn.on_connect() + + mock_read_response.side_effect = itertools.repeat("OK") + + @patch.object(Connection, 'send_command') + @patch.object(Connection, 'read_response') + def test_on_connect_fail_hello(self, mock_read_response, mock_send_command): + """Test that on_connect handles connection failure HELLO command""" + conn = Connection() + + conn._parser = MagicMock() + conn._parser.on_connect.return_value = None + conn.credential_provider = None + conn.username = "myuser" + conn.password = "password" + conn.protocol = -1 # invalid protocol + conn.client_name = "test-client" + conn.lib_name = "test" + conn.lib_version = "1234" + conn.db = 0 + conn.client_cache = True + + # simulate a failure in the HELLO command response + mock_read_response.side_effect = itertools.cycle([ + Exception("Invalid RESP version"), # HELLO (fails) + b'QUEUED', # MULTI + ]) + + with self.assertRaises(ConnectionError): + conn.on_connect() + + mock_send_command.assert_any_call('HELLO', -1, 'AUTH', 'myuser', 'password'), + + mock_send_command.assert_called() + mock_read_response.assert_called() + + @patch.object(Connection, 'send_command') + @patch.object(Connection, 'read_response') + def test_on_connect_fail_auth(self, mock_read_response, mock_send_command): + """Test that on_connect handles connection failure AUTH command""" + conn = Connection() + + conn._parser = MagicMock() + conn._parser.on_connect.return_value = None + conn.credential_provider = None + conn.username = "myuser" + conn.password = "wrong-password" + conn.protocol = 3 + conn.client_name = "test-client" + conn.lib_name = "test" + conn.lib_version = "1234" + conn.db = 1 + conn.client_cache = True + + # simulate a failure in the HELLO command response + mock_read_response.side_effect = itertools.cycle([ + {"proto": 3, "version": "6"}, # HELLO + b'QUEUED', # MULTI + b'QUEUED', # AUTH + b'QUEUED', # CLIENT SETNAME + b'QUEUED', # CLIENT SETINFO LIB-NAME + b'QUEUED', # CLIENT SETINFO LIB-VER + b'QUEUED', # SELECT + b'QUEUED', # CLIENT TRACKING ON + [ + {"proto": 3, "version": "6"}, # HELLO response + b'ERR invalid password', # AUTH response + b'OK', # CLIENT SETNAME response + b'OK', # CLIENT SETINFO LIB-NAME response + b'OK', # CLIENT SETINFO LIB-VER response + b'OK', # SELECT response + b'OK' # CLIENT TRACKING ON response + ] + ]) + + with self.assertRaises(AuthenticationError): + conn.on_connect() + + mock_send_command.assert_any_call( + 'HELLO', 3, 'AUTH', 'myuser', 'wrong-password'), + mock_send_command.assert_any_call('CLIENT', 'SETNAME', 'test-client'), + mock_send_command.assert_any_call('CLIENT', 'SETINFO', 'LIB-NAME', 'test'), + mock_send_command.assert_any_call('CLIENT', 'SETINFO', 'LIB-VER', '1234'), + mock_send_command.assert_any_call('SELECT', 1), + mock_send_command.assert_any_call('CLIENT', 'TRACKING', 'ON'), + mock_send_command.assert_any_call('EXEC') + + mock_send_command.assert_called() + mock_read_response.assert_called() + + @patch.object(Connection, 'send_command') + @patch.object(Connection, 'read_response') + def test_on_connect_auth_with_password_only( + self, mock_read_response, mock_send_command): + """Test on_connect handling of password-only AUTH for Redis versions below 6.0.0 without HELLO command""" + conn = Connection() + + conn._parser = MagicMock() + conn._parser.on_connect.return_value = None + conn.credential_provider = None + conn.username = None + conn.password = "password" + conn.protocol = 1 + conn.client_name = "test-client" + conn.lib_name = "test" + conn.lib_version = "1234" + conn.db = 1 + conn.client_cache = True + + # command response to simulate Redis < 6.0.0 behavior + mock_read_response.side_effect = itertools.cycle([ + Exception("ERR HELLO"), # HELLO (fails) + b'QUEUED', # MULTI + b'QUEUED', # AUTH + b'QUEUED', # CLIENT SETNAME + b'QUEUED', # CLIENT SETINFO LIB-NAME + b'QUEUED', # CLIENT SETINFO LIB-VER + b'QUEUED', # SELECT + b'QUEUED', # CLIENT TRACKING ON + [ + b'OK', # AUTH response + b'OK', # CLIENT SETNAME response + b'OK', # CLIENT SETINFO LIB-NAME response + b'OK', # CLIENT SETINFO LIB-VER response + b'OK', # SELECT response + b'OK' # CLIENT TRACKING ON response + ] + ]) + + conn.on_connect() + + mock_send_command.assert_any_call('HELLO', 1, 'AUTH', 'default', 'password'), + mock_send_command.assert_any_call('MULTI'), + mock_send_command.assert_any_call( + 'AUTH', 'default', 'password', check_health=False) + mock_send_command.assert_any_call('CLIENT', 'SETNAME', 'test-client') + mock_send_command.assert_any_call('CLIENT', 'SETINFO', 'LIB-NAME', 'test') + mock_send_command.assert_any_call('CLIENT', 'SETINFO', 'LIB-VER', '1234') + mock_send_command.assert_any_call('SELECT', 1) + mock_send_command.assert_any_call('CLIENT', 'TRACKING', 'ON') + mock_send_command.assert_any_call('EXEC') + mock_read_response.assert_called() + @pytest.mark.onlynoncluster @pytest.mark.parametrize(