diff --git a/redis/_parsers/base.py b/redis/_parsers/base.py index 69d7b585dd..f2670e43b0 100644 --- a/redis/_parsers/base.py +++ b/redis/_parsers/base.py @@ -3,6 +3,12 @@ from asyncio import IncompleteReadError, StreamReader, TimeoutError from typing import Callable, List, Optional, Protocol, Union +from redis.maintenance_events import ( + NodeMigratedEvent, + NodeMigratingEvent, + NodeMovingEvent, +) + if sys.version_info.major >= 3 and sys.version_info.minor >= 11: from asyncio import timeout as async_timeout else: @@ -123,9 +129,10 @@ def __del__(self): def on_connect(self, connection): "Called when the socket connects" self._sock = connection._sock - self._buffer = SocketBuffer( - self._sock, self.socket_read_size, connection.socket_timeout - ) + timeout = connection.socket_timeout + if connection.tmp_relax_timeout != -1: + timeout = connection.tmp_relax_timeout + self._buffer = SocketBuffer(self._sock, self.socket_read_size, timeout) self.encoder = connection.encoder def on_disconnect(self): @@ -158,7 +165,19 @@ async def read_response( raise NotImplementedError() -_INVALIDATION_MESSAGE = [b"invalidate", "invalidate"] +_INVALIDATION_MESSAGE = (b"invalidate", "invalidate") +_MOVING_MESSAGE = (b"MOVING", "MOVING") +_MIGRATING_MESSAGE = (b"MIGRATING", "MIGRATING") +_MIGRATED_MESSAGE = (b"MIGRATED", "MIGRATED") +_FAILING_OVER_MESSAGE = (b"FAILING_OVER", "FAILING_OVER") +_FAILED_OVER_MESSAGE = (b"FAILED_OVER", "FAILED_OVER") + +_MAINTENANCE_MESSAGES = ( + *_MIGRATING_MESSAGE, + *_MIGRATED_MESSAGE, + *_FAILING_OVER_MESSAGE, + *_FAILED_OVER_MESSAGE, +) class PushNotificationsParser(Protocol): @@ -166,16 +185,44 @@ class PushNotificationsParser(Protocol): pubsub_push_handler_func: Callable invalidation_push_handler_func: Optional[Callable] = None + node_moving_push_handler_func: Optional[Callable] = None + maintenance_push_handler_func: Optional[Callable] = None def handle_pubsub_push_response(self, response): """Handle pubsub push responses""" raise 
NotImplementedError() def handle_push_response(self, response, **kwargs): - if response[0] not in _INVALIDATION_MESSAGE: + msg_type = response[0] + if msg_type not in ( + *_INVALIDATION_MESSAGE, + *_MAINTENANCE_MESSAGES, + *_MOVING_MESSAGE, + ): return self.pubsub_push_handler_func(response) - if self.invalidation_push_handler_func: + if msg_type in _INVALIDATION_MESSAGE and self.invalidation_push_handler_func: return self.invalidation_push_handler_func(response) + if msg_type in _MOVING_MESSAGE and self.node_moving_push_handler_func: + if msg_type in _MOVING_MESSAGE: + host, port = response[2].decode().split(":") + ttl = response[1] + id = 1 # Hardcoded value for sync parser + notification = NodeMovingEvent(id, host, port, ttl) + return self.node_moving_push_handler_func(notification) + if msg_type in _MAINTENANCE_MESSAGES and self.maintenance_push_handler_func: + if msg_type in _MIGRATING_MESSAGE: + ttl = response[1] + id = 2 # Hardcoded value for sync parser + notification = NodeMigratingEvent(id, ttl) + elif msg_type in _MIGRATED_MESSAGE: + id = 3 # Hardcoded value for sync parser + notification = NodeMigratedEvent(id) + else: + notification = None + if notification is not None: + return self.maintenance_push_handler_func(notification) + else: + return None def set_pubsub_push_handler(self, pubsub_push_handler_func): self.pubsub_push_handler_func = pubsub_push_handler_func @@ -183,12 +230,20 @@ def set_pubsub_push_handler(self, pubsub_push_handler_func): def set_invalidation_push_handler(self, invalidation_push_handler_func): self.invalidation_push_handler_func = invalidation_push_handler_func + def set_node_moving_push_handler(self, node_moving_push_handler_func): + self.node_moving_push_handler_func = node_moving_push_handler_func + + def set_maintenance_push_handler(self, maintenance_push_handler_func): + self.maintenance_push_handler_func = maintenance_push_handler_func + class AsyncPushNotificationsParser(Protocol): """Protocol defining async 
RESP3-specific parsing functionality""" pubsub_push_handler_func: Callable invalidation_push_handler_func: Optional[Callable] = None + node_moving_push_handler_func: Optional[Callable] = None + maintenance_push_handler_func: Optional[Callable] = None async def handle_pubsub_push_response(self, response): """Handle pubsub push responses asynchronously""" @@ -196,10 +251,31 @@ async def handle_pubsub_push_response(self, response): async def handle_push_response(self, response, **kwargs): """Handle push responses asynchronously""" - if response[0] not in _INVALIDATION_MESSAGE: + msg_type = response[0] + if msg_type not in ( + *_INVALIDATION_MESSAGE, + *_MAINTENANCE_MESSAGES, + *_MOVING_MESSAGE, + ): return await self.pubsub_push_handler_func(response) - if self.invalidation_push_handler_func: + if msg_type in _INVALIDATION_MESSAGE and self.invalidation_push_handler_func: return await self.invalidation_push_handler_func(response) + if msg_type in _MOVING_MESSAGE and self.node_moving_push_handler_func: + # push notification from enterprise cluster for node moving + host, port = response[2].split(":") + ttl = response[1] + id = 1 # Hardcoded value for async parser + notification = NodeMovingEvent(id, host, port, ttl) + return await self.node_moving_push_handler_func(notification) + if msg_type in _MAINTENANCE_MESSAGES and self.maintenance_push_handler_func: + if msg_type in _MIGRATING_MESSAGE: + ttl = response[1] + id = 2 # Hardcoded value for async parser + notification = NodeMigratingEvent(id, ttl) + elif msg_type in _MIGRATED_MESSAGE: + id = 3 # Hardcoded value for async parser + notification = NodeMigratedEvent(id) + return await self.maintenance_push_handler_func(notification) def set_pubsub_push_handler(self, pubsub_push_handler_func): """Set the pubsub push handler function""" @@ -209,6 +285,12 @@ def set_invalidation_push_handler(self, invalidation_push_handler_func): """Set the invalidation push handler function""" self.invalidation_push_handler_func = 
invalidation_push_handler_func + def set_node_moving_push_handler_func(self, node_moving_push_handler_func): + self.node_moving_push_handler_func = node_moving_push_handler_func + + def set_maintenance_push_handler(self, maintenance_push_handler_func): + self.maintenance_push_handler_func = maintenance_push_handler_func + class _AsyncRESPBase(AsyncBaseParser): """Base class for async resp parsing""" diff --git a/redis/_parsers/hiredis.py b/redis/_parsers/hiredis.py index 521a58b26c..e9df314a8c 100644 --- a/redis/_parsers/hiredis.py +++ b/redis/_parsers/hiredis.py @@ -47,6 +47,8 @@ def __init__(self, socket_read_size): self.socket_read_size = socket_read_size self._buffer = bytearray(socket_read_size) self.pubsub_push_handler_func = self.handle_pubsub_push_response + self.node_moving_push_handler_func = None + self.maintenance_push_handler_func = None self.invalidation_push_handler_func = None self._hiredis_PushNotificationType = None @@ -141,13 +143,15 @@ def read_response(self, disable_decoding=False, push_request=False): response, self._hiredis_PushNotificationType ): response = self.handle_push_response(response) - if not push_request: - return self.read_response( - disable_decoding=disable_decoding, push_request=push_request - ) - else: + + # if this is a push request return the push response + if push_request: return response - return response + + return self.read_response( + disable_decoding=disable_decoding, + push_request=push_request, + ) if disable_decoding: response = self._reader.gets(False) @@ -169,12 +173,13 @@ def read_response(self, disable_decoding=False, push_request=False): response, self._hiredis_PushNotificationType ): response = self.handle_push_response(response) - if not push_request: - return self.read_response( - disable_decoding=disable_decoding, push_request=push_request - ) - else: + if push_request: return response + return self.read_response( + disable_decoding=disable_decoding, + push_request=push_request, + ) + elif ( 
isinstance(response, list) and response diff --git a/redis/_parsers/resp3.py b/redis/_parsers/resp3.py index 42c6652e31..72957b464c 100644 --- a/redis/_parsers/resp3.py +++ b/redis/_parsers/resp3.py @@ -18,6 +18,8 @@ class _RESP3Parser(_RESPBase, PushNotificationsParser): def __init__(self, socket_read_size): super().__init__(socket_read_size) self.pubsub_push_handler_func = self.handle_pubsub_push_response + self.node_moving_push_handler_func = None + self.maintenance_push_handler_func = None self.invalidation_push_handler_func = None def handle_pubsub_push_response(self, response): @@ -117,17 +119,21 @@ def _read_response(self, disable_decoding=False, push_request=False): for _ in range(int(response)) ] response = self.handle_push_response(response) - if not push_request: - return self._read_response( - disable_decoding=disable_decoding, push_request=push_request - ) - else: + + # if this is a push request return the push response + if push_request: return response + + return self._read_response( + disable_decoding=disable_decoding, + push_request=push_request, + ) else: raise InvalidResponse(f"Protocol Error: {raw!r}") if isinstance(response, bytes) and disable_decoding is False: response = self.encoder.decode(response) + return response diff --git a/redis/asyncio/connection.py b/redis/asyncio/connection.py index 95390bd66c..d7dfe4babb 100644 --- a/redis/asyncio/connection.py +++ b/redis/asyncio/connection.py @@ -1302,6 +1302,8 @@ def __init__( ) self._condition = asyncio.Condition() self.timeout = timeout + self._in_maintenance = False + self._locked = False @deprecated_args( args_to_warn=["*"], diff --git a/redis/client.py b/redis/client.py index 0e05b6f542..a6c96c3882 100755 --- a/redis/client.py +++ b/redis/client.py @@ -56,6 +56,10 @@ WatchError, ) from redis.lock import Lock +from redis.maintenance_events import ( + MaintenanceEventPoolHandler, + MaintenanceEventsConfig, +) from redis.retry import Retry from redis.utils import ( _set_info_logger, @@ -244,6 
+248,7 @@ def __init__( cache: Optional[CacheInterface] = None, cache_config: Optional[CacheConfig] = None, event_dispatcher: Optional[EventDispatcher] = None, + maintenance_events_config: Optional[MaintenanceEventsConfig] = None, ) -> None: """ Initialize a new Redis client. @@ -368,6 +373,23 @@ def __init__( ]: raise RedisError("Client caching is only supported with RESP version 3") + if maintenance_events_config and self.connection_pool.get_protocol() not in [ + 3, + "3", + ]: + raise RedisError( + "Push handlers on connection are only supported with RESP version 3" + ) + if maintenance_events_config and maintenance_events_config.enabled: + self.maintenance_events_pool_handler = MaintenanceEventPoolHandler( + self.connection_pool, maintenance_events_config + ) + self.connection_pool.set_maintenance_events_pool_handler( + self.maintenance_events_pool_handler + ) + else: + self.maintenance_events_pool_handler = None + self.single_connection_lock = threading.RLock() self.connection = None self._single_connection_client = single_connection_client @@ -565,8 +587,15 @@ def monitor(self): return Monitor(self.connection_pool) def client(self): + maintenance_events_config = ( + None + if self.maintenance_events_pool_handler is None + else self.maintenance_events_pool_handler.config + ) return self.__class__( - connection_pool=self.connection_pool, single_connection_client=True + connection_pool=self.connection_pool, + single_connection_client=True, + maintenance_events_config=maintenance_events_config, ) def __enter__(self): @@ -635,7 +664,11 @@ def _execute_command(self, *args, **options): ), lambda _: self._close_connection(conn), ) + finally: + if conn and conn.should_reconnect(): + self._close_connection(conn) + conn.connect() if self._single_connection_client: self.single_connection_lock.release() if not self.connection: @@ -686,11 +719,7 @@ def __init__(self, connection_pool): self.connection = self.connection_pool.get_connection() def __enter__(self): - 
self.connection.send_command("MONITOR") - # check that monitor returns 'OK', but don't return it to user - response = self.connection.read_response() - if not bool_ok(response): - raise RedisError(f"MONITOR failed: {response}") + self._start_monitor() return self def __exit__(self, *args): @@ -700,8 +729,13 @@ def __exit__(self, *args): def next_command(self): """Parse the response from a monitor command""" response = self.connection.read_response() + + if response is None: + return None + if isinstance(response, bytes): response = self.connection.encoder.decode(response, force=True) + command_time, command_data = response.split(" ", 1) m = self.monitor_re.match(command_data) db_id, client_info, command = m.groups() @@ -737,6 +771,14 @@ def listen(self): while True: yield self.next_command() + def _start_monitor(self): + self.connection.send_command("MONITOR") + # check that monitor returns 'OK', but don't return it to user + response = self.connection.read_response() + + if not bool_ok(response): + raise RedisError(f"MONITOR failed: {response}") + class PubSub: """ @@ -881,7 +923,7 @@ def clean_health_check_responses(self) -> None: """ ttl = 10 conn = self.connection - while self.health_check_response_counter > 0 and ttl > 0: + while conn and self.health_check_response_counter > 0 and ttl > 0: if self._execute(conn, conn.can_read, timeout=conn.socket_timeout): response = self._execute(conn, conn.read_response) if self.is_health_check_response(response): @@ -911,10 +953,15 @@ def _execute(self, conn, command, *args, **kwargs): called by the # connection to resubscribe us to any channels and patterns we were previously listening to """ - return conn.retry.call_with_retry( + + response = conn.retry.call_with_retry( lambda: command(*args, **kwargs), lambda _: self._reconnect(conn), ) + if conn.should_reconnect(): + self._reconnect(conn) + + return response def parse_response(self, block=True, timeout=0): """Parse the response from a publish/subscribe command""" @@ 
-1148,6 +1195,7 @@ def handle_message(self, response, ignore_subscribe_messages=False): return None if isinstance(response, bytes): response = [b"pong", response] if response != b"PONG" else [b"pong", b""] + message_type = str_if_bytes(response[0]) if message_type == "pmessage": message = { @@ -1351,6 +1399,7 @@ def reset(self) -> None: # clean up the other instance attributes self.watching = False self.explicit_transaction = False + # we can safely return the connection to the pool here since we're # sure we're no longer WATCHing anything if self.connection: @@ -1510,6 +1559,7 @@ def _execute_transaction( if command_name in self.response_callbacks: r = self.response_callbacks[command_name](r, **options) data.append(r) + return data def _execute_pipeline(self, connection, commands, raise_on_error): @@ -1517,16 +1567,17 @@ def _execute_pipeline(self, connection, commands, raise_on_error): all_cmds = connection.pack_commands([args for args, _ in commands]) connection.send_packed_command(all_cmds) - response = [] + responses = [] for args, options in commands: try: - response.append(self.parse_response(connection, args[0], **options)) + responses.append(self.parse_response(connection, args[0], **options)) except ResponseError as e: - response.append(e) + responses.append(e) if raise_on_error: - self.raise_first_error(commands, response) - return response + self.raise_first_error(commands, responses) + + return responses def raise_first_error(self, commands, response): for i, r in enumerate(response): @@ -1611,6 +1662,8 @@ def execute(self, raise_on_error: bool = True) -> List[Any]: lambda error: self._disconnect_raise_on_watching(conn, error), ) finally: + # in reset() the connection is disconnected before returned to the pool if + # it is marked for reconnect. 
self.reset() def discard(self): diff --git a/redis/connection.py b/redis/connection.py index d457b1015c..3ff8fbc0c0 100644 --- a/redis/connection.py +++ b/redis/connection.py @@ -1,4 +1,5 @@ import copy +import logging import os import socket import sys @@ -19,6 +20,7 @@ CacheInterface, CacheKey, ) +from redis.typing import Number from ._parsers import Encoder, _HiredisParser, _RESP2Parser, _RESP3Parser from .auth.token import TokenInterface @@ -35,6 +37,12 @@ ResponseError, TimeoutError, ) +from .maintenance_events import ( + MaintenanceEventConnectionHandler, + MaintenanceEventPoolHandler, + MaintenanceEventsConfig, + MaintenanceState, +) from .retry import Retry from .utils import ( CRYPTOGRAPHY_AVAILABLE, @@ -158,6 +166,10 @@ def deregister_connect_callback(self, callback): def set_parser(self, parser_class): pass + @abstractmethod + def set_maintenance_event_pool_handler(self, maintenance_event_pool_handler): + pass + @abstractmethod def get_protocol(self): pass @@ -221,6 +233,26 @@ def set_re_auth_token(self, token: TokenInterface): def re_auth(self): pass + @abstractmethod + def mark_for_reconnect(self): + pass + + @abstractmethod + def should_reconnect(self): + pass + + @abstractmethod + def update_current_socket_timeout(self, relax_timeout: Optional[float] = None): + pass + + @abstractmethod + def update_tmp_settings( + self, + tmp_host_address: Optional[str] = None, + tmp_relax_timeout: Optional[float] = None, + ): + pass + class AbstractConnection(ConnectionInterface): "Manages communication to and from a Redis server" @@ -249,6 +281,11 @@ def __init__( protocol: Optional[int] = 2, command_packer: Optional[Callable[[], None]] = None, event_dispatcher: Optional[EventDispatcher] = None, + maintenance_events_pool_handler: Optional[MaintenanceEventPoolHandler] = None, + maintenance_events_config: Optional[MaintenanceEventsConfig] = None, + tmp_host_address: Optional[str] = None, + tmp_relax_timeout: Optional[float] = -1, + maintenance_state: 
"MaintenanceState" = MaintenanceState.NONE, ): """ Initialize a new Connection. @@ -304,7 +341,6 @@ def __init__( self.handshake_metadata = None self._sock = None self._socket_read_size = socket_read_size - self.set_parser(parser_class) self._connect_callbacks = [] self._buffer_cutoff = 6000 self._re_auth_token: Optional[TokenInterface] = None @@ -319,7 +355,27 @@ def __init__( raise ConnectionError("protocol must be either 2 or 3") # p = DEFAULT_RESP_VERSION self.protocol = p + if self.protocol == 3 and parser_class == DefaultParser: + parser_class = _RESP3Parser + self.set_parser(parser_class) + + if maintenance_events_config and maintenance_events_config.enabled: + if maintenance_events_pool_handler: + self._parser.set_node_moving_push_handler( + maintenance_events_pool_handler.handle_event + ) + self._maintenance_event_connection_handler = ( + MaintenanceEventConnectionHandler(self, maintenance_events_config) + ) + self._parser.set_maintenance_push_handler( + self._maintenance_event_connection_handler.handle_event + ) + self._command_packer = self._construct_command_packer(command_packer) + self._should_reconnect = False + self.tmp_host_address = tmp_host_address + self.tmp_relax_timeout = tmp_relax_timeout + self.maintenance_state = maintenance_state def __repr__(self): repr_args = ",".join([f"{k}={v}" for k, v in self.repr_pieces()]) @@ -374,6 +430,24 @@ def set_parser(self, parser_class): """ self._parser = parser_class(socket_read_size=self._socket_read_size) + def set_maintenance_event_pool_handler( + self, maintenance_event_pool_handler: MaintenanceEventPoolHandler + ): + self._parser.set_node_moving_push_handler( + maintenance_event_pool_handler.handle_event + ) + + # Initialize maintenance event connection handler if it doesn't exist + if not hasattr(self, "_maintenance_event_connection_handler"): + self._maintenance_event_connection_handler = ( + MaintenanceEventConnectionHandler( + self, maintenance_event_pool_handler.config + ) + ) + 
self._parser.set_maintenance_push_handler( + self._maintenance_event_connection_handler.handle_event + ) + def connect(self): "Connects to the Redis server if not already connected" self.connect_check_health(check_health=True) @@ -543,6 +617,8 @@ def disconnect(self, *args): conn_sock = self._sock self._sock = None + # reset the reconnect flag + self._should_reconnect = False if conn_sock is None: return @@ -620,6 +696,7 @@ def can_read(self, timeout=0): try: return self._parser.can_read(timeout) + except OSError as e: self.disconnect() raise ConnectionError(f"Error while reading from {host_error}: {e.args}") @@ -726,6 +803,38 @@ def re_auth(self): self.read_response() self._re_auth_token = None + def mark_for_reconnect(self): + self._should_reconnect = True + + def should_reconnect(self): + return self._should_reconnect + + def update_current_socket_timeout(self, relax_timeout: Optional[float] = None): + if self._sock: + timeout = relax_timeout if relax_timeout != -1 else self.socket_timeout + self._sock.settimeout(timeout) + self.update_parser_buffer_timeout(timeout) + + def update_parser_buffer_timeout(self, timeout: Optional[float] = None): + if self._parser and self._parser._buffer: + self._parser._buffer.socket_timeout = timeout + + def update_tmp_settings( + self, + tmp_host_address: Optional[Union[str, object]] = SENTINEL, + tmp_relax_timeout: Optional[Union[float, object]] = SENTINEL, + ): + """ + The value of SENTINEL is used to indicate that the property should not be updated. 
+ """ + if tmp_host_address is not SENTINEL: + self.tmp_host_address = tmp_host_address + if tmp_relax_timeout is not SENTINEL: + self.tmp_relax_timeout = tmp_relax_timeout + + def set_maintenance_state(self, state: "MaintenanceState"): + self.maintenance_state = state + class Connection(AbstractConnection): "Manages TCP communication to and from a Redis server" @@ -758,8 +867,10 @@ def _connect(self): # ipv4/ipv6, but we want to set options prior to calling # socket.connect() err = None + host = self.tmp_host_address or self.host + for res in socket.getaddrinfo( - self.host, self.port, self.socket_type, socket.SOCK_STREAM + host, self.port, self.socket_type, socket.SOCK_STREAM ): family, socktype, proto, canonname, socket_address = res sock = None @@ -775,13 +886,19 @@ def _connect(self): sock.setsockopt(socket.IPPROTO_TCP, k, v) # set the socket_connect_timeout before we connect - sock.settimeout(self.socket_connect_timeout) + if self.tmp_relax_timeout != -1: + sock.settimeout(self.tmp_relax_timeout) + else: + sock.settimeout(self.socket_connect_timeout) # connect sock.connect(socket_address) # set the socket_timeout now that we're connected - sock.settimeout(self.socket_timeout) + if self.tmp_relax_timeout != -1: + sock.settimeout(self.tmp_relax_timeout) + else: + sock.settimeout(self.socket_timeout) return sock except OSError as _: @@ -1408,6 +1525,14 @@ def __init__( connection_kwargs.pop("cache", None) connection_kwargs.pop("cache_config", None) + if connection_kwargs.get( + "maintenance_events_pool_handler" + ) or connection_kwargs.get("maintenance_events_config"): + if connection_kwargs.get("protocol") not in [3, "3"]: + raise RedisError( + "Push handlers on connection are only supported with RESP version 3" + ) + self._event_dispatcher = self.connection_kwargs.get("event_dispatcher", None) if self._event_dispatcher is None: self._event_dispatcher = EventDispatcher() @@ -1442,6 +1567,42 @@ def get_protocol(self): """ return 
self.connection_kwargs.get("protocol", None) + def maintenance_events_pool_handler_enabled(self): + """ + Returns: + True if the maintenance events pool handler is enabled, False otherwise. + """ + maintenance_events_config = self.connection_kwargs.get( + "maintenance_events_config", False + ) + + return maintenance_events_config and maintenance_events_config.enabled + + def set_maintenance_events_pool_handler( + self, maintenance_events_pool_handler: MaintenanceEventPoolHandler + ): + self.connection_kwargs.update( + { + "maintenance_events_pool_handler": maintenance_events_pool_handler, + "maintenance_events_config": maintenance_events_pool_handler.config, + } + ) + + self._update_maintenance_events_configs_for_connections( + maintenance_events_pool_handler + ) + + def _update_maintenance_events_configs_for_connections( + self, maintenance_events_pool_handler + ): + with self._lock: + for conn in self._available_connections: + conn.set_maintenance_event_pool_handler(maintenance_events_pool_handler) + conn.maintenance_events_config = maintenance_events_pool_handler.config + for conn in self._in_use_connections: + conn.set_maintenance_event_pool_handler(maintenance_events_pool_handler) + conn.maintenance_events_config = maintenance_events_pool_handler.config + def reset(self) -> None: self._created_connections = 0 self._available_connections = [] @@ -1529,7 +1690,11 @@ def get_connection(self, command_name=None, *keys, **options) -> "Connection": # pool before all data has been read or the socket has been # closed. either way, reconnect and verify everything is good. 
try: - if connection.can_read() and self.cache is None: + if ( + connection.can_read() + and self.cache is None + and not self.maintenance_events_pool_handler_enabled() + ): raise ConnectionError("Connection has data") except (ConnectionError, TimeoutError, OSError): connection.disconnect() @@ -1541,7 +1706,6 @@ def get_connection(self, command_name=None, *keys, **options) -> "Connection": # leak it self.release(connection) raise - return connection def get_encoder(self) -> Encoder: @@ -1559,12 +1723,18 @@ def make_connection(self) -> "ConnectionInterface": raise ConnectionError("Too many connections") self._created_connections += 1 + # Pass current maintenance_state to new connections + maintenance_state = self.connection_kwargs.get( + "maintenance_state", MaintenanceState.NONE + ) + kwargs = dict(self.connection_kwargs) + kwargs["maintenance_state"] = maintenance_state + if self.cache is not None: return CacheProxyConnection( - self.connection_class(**self.connection_kwargs), self.cache, self._lock + self.connection_class(**kwargs), self.cache, self._lock ) - - return self.connection_class(**self.connection_kwargs) + return self.connection_class(**kwargs) def release(self, connection: "Connection") -> None: "Releases the connection back to the pool" @@ -1578,6 +1748,8 @@ def release(self, connection: "Connection") -> None: return if self.owns_connection(connection): + if connection.should_reconnect(): + connection.disconnect() self._available_connections.append(connection) self._event_dispatcher.dispatch( AfterConnectionReleasedEvent(connection) @@ -1639,6 +1811,146 @@ def re_auth_callback(self, token: TokenInterface): for conn in self._in_use_connections: conn.set_re_auth_token(token) + def update_connection_kwargs_with_tmp_settings( + self, + tmp_host_address: Optional[str] = None, + tmp_relax_timeout: Optional[float] = None, + ): + """ + Update the connection kwargs with the temporary host address and the + relax timeout(if enabled). 
+ This is used when a cluster node is rebound to a different address. + + When this method is called the pool will already be locked, so getting the pool lock inside is not needed. + This new address will be used to create new connections until the old node is decommissioned. + + :param tmp_host_address: The temporary host address to use for the connection. + :param tmp_relax_timeout: The relax timeout to use for the connection. + If -1 is provided - the relax timeout is disabled, so the tmp property is not set + """ + self.connection_kwargs.update({"tmp_host_address": tmp_host_address}) + self.connection_kwargs.update({"tmp_relax_timeout": tmp_relax_timeout}) + + def update_connections_tmp_settings( + self, + tmp_host_address: Optional[str] = None, + tmp_relax_timeout: Optional[float] = None, + ): + """ + Update the tmp settings for all connections in the pool. + This is used when a cluster node is rebound to a different address. + + When this method is called the pool will already be locked, so getting the pool lock inside is not needed. + + :param tmp_host_address: The temporary host address to use for the connection. + :param tmp_relax_timeout: The relax timeout to use for the connection. + """ + with self._lock: + for conn in self._available_connections: + self._update_connection_tmp_settings( + conn, tmp_host_address, tmp_relax_timeout + ) + for conn in self._in_use_connections: + self._update_connection_tmp_settings( + conn, tmp_host_address, tmp_relax_timeout + ) + + def update_active_connections_for_reconnect( + self, + tmp_host_address: Optional[str] = None, + tmp_relax_timeout: Optional[float] = None, + ): + """ + Mark all active connections for reconnect. + This is used when a cluster node is migrated to a different address. + + When this method is called the pool will already be locked, so getting the pool lock inside is not needed. + + :param tmp_host_address: The temporary host address to use for the connection. 
+ """ + for conn in self._in_use_connections: + self._update_connection_for_reconnect( + conn, tmp_host_address, tmp_relax_timeout + ) + + def disconnect_and_reconfigure_free_connections( + self, + tmp_host_address: Optional[str] = None, + tmp_relax_timeout: Optional[float] = None, + ): + """ + Disconnect all free/available connections. + This is used when a cluster node is migrated to a different address. + + When this method is called the pool will already be locked, so getting the pool lock inside is not needed. + + :param tmp_host_address: The temporary host address to use for the connection. + :param tmp_relax_timeout: The relax timeout to use for the connection. + """ + + for conn in self._available_connections: + self._disconnect_and_update_connection_for_reconnect( + conn, tmp_host_address, tmp_relax_timeout + ) + + def update_connections_current_timeout( + self, + relax_timeout: Optional[float], + include_free_connections: bool = False, + ): + """ + Update the timeout either for all connections in the pool or just for the ones in use. + This is used when a cluster node is migrated to a different address. + + When this method is called the pool will already be locked, so getting the pool lock inside is not needed. + + :param relax_timeout: The relax timeout to use for the connection. + If -1 is provided - the relax timeout is disabled. + :param include_free_connections: Whether to include available connections in the update. 
+ """ + for conn in self._in_use_connections: + self._update_connection_timeout(conn, relax_timeout) + + if include_free_connections: + for conn in self._available_connections: + self._update_connection_timeout(conn, relax_timeout) + + def _update_connection_for_reconnect( + self, + connection: "Connection", + tmp_host_address: Optional[str] = None, + tmp_relax_timeout: Optional[float] = None, + ): + connection.mark_for_reconnect() + self._update_connection_tmp_settings( + connection, tmp_host_address, tmp_relax_timeout + ) + + def _disconnect_and_update_connection_for_reconnect( + self, + connection: "Connection", + tmp_host_address: Optional[str] = None, + tmp_relax_timeout: Optional[float] = None, + ): + connection.disconnect() + self._update_connection_tmp_settings( + connection, tmp_host_address, tmp_relax_timeout + ) + + def _update_connection_tmp_settings( + self, + connection: "Connection", + tmp_host_address: Optional[str] = None, + tmp_relax_timeout: Optional[float] = None, + ): + connection.tmp_host_address = tmp_host_address + connection.tmp_relax_timeout = tmp_relax_timeout + + def _update_connection_timeout( + self, connection: "Connection", relax_timeout: Optional[Number] + ): + connection.update_current_socket_timeout(relax_timeout) + async def _mock(self, error: RedisError): """ Dummy functions, needs to be passed as error callback to retry object. 
@@ -1647,6 +1959,16 @@ async def _mock(self, error: RedisError): """ pass + def set_maintenance_state_for_all(self, state: "MaintenanceState"): + with self._lock: + for conn in self._available_connections: + conn.set_maintenance_state(state) + for conn in self._in_use_connections: + conn.set_maintenance_state(state) + + def set_maintenance_state_in_kwargs(self, state: "MaintenanceState"): + self.connection_kwargs["maintenance_state"] = state + class BlockingConnectionPool(ConnectionPool): """ @@ -1692,6 +2014,8 @@ def __init__( ): self.queue_class = queue_class self.timeout = timeout + self._in_maintenance = False + self._locked = False super().__init__( connection_class=connection_class, max_connections=max_connections, @@ -1700,16 +2024,27 @@ def __init__( def reset(self): # Create and fill up a thread safe queue with ``None`` values. - self.pool = self.queue_class(self.max_connections) - while True: - try: - self.pool.put_nowait(None) - except Full: - break + try: + if self._in_maintenance: + self._lock.acquire() + self._locked = True + self.pool = self.queue_class(self.max_connections) + while True: + try: + self.pool.put_nowait(None) + except Full: + break - # Keep a list of actual connection instances so that we can - # disconnect them later. - self._connections = [] + # Keep a list of actual connection instances so that we can + # disconnect them later. + self._connections = [] + finally: + if self._locked: + try: + self._lock.release() + except Exception: + pass + self._locked = False # this must be the last operation in this method. while reset() is # called when holding _fork_lock, other threads in this process @@ -1724,14 +2059,33 @@ def reset(self): def make_connection(self): "Make a fresh connection." 
- if self.cache is not None: - connection = CacheProxyConnection( - self.connection_class(**self.connection_kwargs), self.cache, self._lock + try: + if self._in_maintenance: + self._lock.acquire() + self._locked = True + # Pass current maintenance_state to new connections + maintenance_state = self.connection_kwargs.get( + "maintenance_state", MaintenanceState.NONE ) - else: - connection = self.connection_class(**self.connection_kwargs) - self._connections.append(connection) - return connection + kwargs = dict(self.connection_kwargs) + kwargs["maintenance_state"] = maintenance_state + if self.cache is not None: + connection = CacheProxyConnection( + self.connection_class(**kwargs), + self.cache, + self._lock, + ) + else: + connection = self.connection_class(**kwargs) + self._connections.append(connection) + return connection + finally: + if self._locked: + try: + self._lock.release() + except Exception: + pass + self._locked = False @deprecated_args( args_to_warn=["*"], @@ -1757,16 +2111,27 @@ def get_connection(self, command_name=None, *keys, **options): # self.timeout then raise a ``ConnectionError``. connection = None try: - connection = self.pool.get(block=True, timeout=self.timeout) - except Empty: - # Note that this is not caught by the redis client and will be - # raised unless handled by application code. If you want never to - raise ConnectionError("No connection available.") - - # If the ``connection`` is actually ``None`` then that's a cue to make - # a new connection to add to the pool. - if connection is None: - connection = self.make_connection() + if self._in_maintenance: + self._lock.acquire() + self._locked = True + try: + connection = self.pool.get(block=True, timeout=self.timeout) + except Empty: + # Note that this is not caught by the redis client and will be + # raised unless handled by application code. 
If you want never to + raise ConnectionError("No connection available.") + + # If the ``connection`` is actually ``None`` then that's a cue to make + # a new connection to add to the pool. + if connection is None: + connection = self.make_connection() + finally: + if self._locked: + try: + self._lock.release() + except Exception: + pass + self._locked = False try: # ensure this connection is connected to Redis @@ -1794,25 +2159,140 @@ def release(self, connection): "Releases the connection back to the pool." # Make sure we haven't changed process. self._checkpid() - if not self.owns_connection(connection): - # pool doesn't own this connection. do not add it back - # to the pool. instead add a None value which is a placeholder - # that will cause the pool to recreate the connection if - # its needed. - connection.disconnect() - self.pool.put_nowait(None) - return - # Put the connection back into the pool. try: - self.pool.put_nowait(connection) - except Full: - # perhaps the pool has been reset() after a fork? regardless, - # we don't want this connection - pass + if self._in_maintenance: + self._lock.acquire() + self._locked = True + if not self.owns_connection(connection): + # pool doesn't own this connection. do not add it back + # to the pool. instead add a None value which is a placeholder + # that will cause the pool to recreate the connection if + # its needed. + connection.disconnect() + self.pool.put_nowait(None) + return + if connection.should_reconnect(): + connection.disconnect() + # Put the connection back into the pool. + try: + self.pool.put_nowait(connection) + except Full: + # perhaps the pool has been reset() after a fork? regardless, + # we don't want this connection + pass + finally: + if self._locked: + try: + self._lock.release() + except Exception: + pass + self._locked = False def disconnect(self): "Disconnects all connections in the pool." 
self._checkpid() - for connection in self._connections: - connection.disconnect() + try: + if self._in_maintenance: + self._lock.acquire() + self._locked = True + for connection in self._connections: + connection.disconnect() + finally: + if self._locked: + try: + self._lock.release() + except Exception: + pass + self._locked = False + + def update_active_connections_for_reconnect( + self, + tmp_host_address: Optional[str] = None, + tmp_relax_timeout: Optional[float] = None, + ): + with self._lock: + connections_in_queue = {conn for conn in self.pool.queue if conn} + for conn in self._connections: + if conn not in connections_in_queue: + self._update_connection_for_reconnect( + conn, tmp_host_address, tmp_relax_timeout + ) + + def disconnect_and_reconfigure_free_connections( + self, + tmp_host_address: Optional[str] = None, + tmp_relax_timeout: Optional[Number] = None, + ): + with self._lock: + existing_connections = self.pool.queue + + for conn in existing_connections: + if conn: + self._disconnect_and_update_connection_for_reconnect( + conn, tmp_host_address, tmp_relax_timeout + ) + + def update_connections_current_timeout( + self, + relax_timeout: Optional[float] = None, + include_free_connections: bool = False, + ): + logging.debug( + f"***** Blocking Pool --> Updating timeouts. 
relax_timeout: {relax_timeout}" + ) + + with self._lock: + if include_free_connections: + for conn in tuple(self._connections): + self._update_connection_timeout(conn, relax_timeout) + else: + connections_in_queue = {conn for conn in self.pool.queue if conn} + for conn in self._connections: + if conn not in connections_in_queue: + self._update_connection_timeout(conn, relax_timeout) + + def update_connections_tmp_settings( + self, + tmp_host_address: Optional[str] = None, + tmp_relax_timeout: Optional[float] = None, + ): + with self._lock: + for conn in tuple(self._connections): + self._update_connection_tmp_settings( + conn, tmp_host_address, tmp_relax_timeout + ) + + def _update_maintenance_events_config_for_connections( + self, maintenance_events_config + ): + with self._lock: + for conn in tuple(self._connections): + conn.maintenance_events_config = maintenance_events_config + + def _update_maintenance_events_configs_for_connections( + self, maintenance_events_pool_handler + ): + """Override base class method to work with BlockingConnectionPool's structure.""" + with self._lock: + for conn in tuple(self._connections): + if conn: # conn can be None in BlockingConnectionPool + conn.set_maintenance_event_pool_handler( + maintenance_events_pool_handler + ) + conn.maintenance_events_config = ( + maintenance_events_pool_handler.config + ) + + def set_in_maintenance(self, in_maintenance: bool): + """Set the maintenance mode for the connection pool.""" + self._in_maintenance = in_maintenance + + def set_maintenance_state_for_all(self, state: "MaintenanceState"): + with self._lock: + for conn in getattr(self, "_connections", []): + if conn: + conn.set_maintenance_state(state) + + def set_maintenance_state_in_kwargs(self, state: "MaintenanceState"): + self.connection_kwargs["maintenance_state"] = state diff --git a/redis/maintenance_events.py b/redis/maintenance_events.py new file mode 100644 index 0000000000..dd62602105 --- /dev/null +++ b/redis/maintenance_events.py @@ 
-0,0 +1,453 @@ +import enum +import logging +import threading +import time +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING, Optional, Union + +from redis.typing import Number + + +class MaintenanceState(enum.Enum): + NONE = "none" + MOVING = "moving" + MIGRATING = "migrating" + + +if TYPE_CHECKING: + from redis.connection import ( + BlockingConnectionPool, + ConnectionInterface, + ConnectionPool, + ) + + +class MaintenanceEvent(ABC): + """ + Base class for maintenance events sent through push messages by Redis server. + + This class provides common functionality for all maintenance events including + unique identification and TTL (Time-To-Live) functionality. + + Attributes: + id (int): Unique identifier for this event + ttl (int): Time-to-live in seconds for this notification + creation_time (float): Timestamp when the notification was created/read + """ + + def __init__(self, id: int, ttl: int): + """ + Initialize a new MaintenanceEvent with unique ID and TTL functionality. + + Args: + id (int): Unique identifier for this event + ttl (int): Time-to-live in seconds for this notification + """ + self.id = id + self.ttl = ttl + self.creation_time = time.monotonic() + self.expire_at = self.creation_time + self.ttl + + def is_expired(self) -> bool: + """ + Check if this event has expired based on its TTL + and creation time. + + Returns: + bool: True if the event has expired, False otherwise + """ + return time.monotonic() > (self.creation_time + self.ttl) + + @abstractmethod + def __repr__(self) -> str: + """ + Return a string representation of the maintenance event. + + This method must be implemented by all concrete subclasses. + + Returns: + str: String representation of the event + """ + pass + + @abstractmethod + def __eq__(self, other) -> bool: + """ + Compare two maintenance events for equality. + + This method must be implemented by all concrete subclasses. 
+ Events are typically considered equal if they have the same id + and are of the same type. + + Args: + other: The other object to compare with + + Returns: + bool: True if the events are equal, False otherwise + """ + pass + + @abstractmethod + def __hash__(self) -> int: + """ + Return a hash value for the maintenance event. + + This method must be implemented by all concrete subclasses to allow + instances to be used in sets and as dictionary keys. + + Returns: + int: Hash value for the event + """ + pass + + +class NodeMovingEvent(MaintenanceEvent): + """ + This event is received when a node is replaced with a new node + during cluster rebalancing or maintenance operations. + """ + + def __init__(self, id: int, new_node_host: str, new_node_port: int, ttl: int): + """ + Initialize a new NodeMovingEvent. + + Args: + id (int): Unique identifier for this event + new_node_host (str): Hostname or IP address of the new replacement node + new_node_port (int): Port number of the new replacement node + ttl (int): Time-to-live in seconds for this notification + """ + super().__init__(id, ttl) + self.new_node_host = new_node_host + self.new_node_port = new_node_port + + def __repr__(self) -> str: + expiry_time = self.expire_at + remaining = max(0, expiry_time - time.monotonic()) + + return ( + f"{self.__class__.__name__}(" + f"id={self.id}, " + f"new_node_host='{self.new_node_host}', " + f"new_node_port={self.new_node_port}, " + f"ttl={self.ttl}, " + f"creation_time={self.creation_time}, " + f"expires_at={expiry_time}, " + f"remaining={remaining:.1f}s, " + f"expired={self.is_expired()}" + f")" + ) + + def __eq__(self, other) -> bool: + """ + Two NodeMovingEvent events are considered equal if they have the same + id, new_node_host, and new_node_port. 
+ """ + if not isinstance(other, NodeMovingEvent): + return False + return ( + self.id == other.id + and self.new_node_host == other.new_node_host + and self.new_node_port == other.new_node_port + ) + + def __hash__(self) -> int: + """ + Return a hash value for the event to allow + instances to be used in sets and as dictionary keys. + + Returns: + int: Hash value based on event type, id, new_node_host, and new_node_port + """ + return hash((self.__class__, self.id, self.new_node_host, self.new_node_port)) + + +class NodeMigratingEvent(MaintenanceEvent): + """ + Event for when a Redis cluster node is in the process of migrating slots. + + This event is received when a node starts migrating its slots to another node + during cluster rebalancing or maintenance operations. + + Args: + id (int): Unique identifier for this event + ttl (int): Time-to-live in seconds for this notification + """ + + def __init__(self, id: int, ttl: int): + super().__init__(id, ttl) + + def __repr__(self) -> str: + expiry_time = self.creation_time + self.ttl + remaining = max(0, expiry_time - time.monotonic()) + return ( + f"{self.__class__.__name__}(" + f"id={self.id}, " + f"ttl={self.ttl}, " + f"creation_time={self.creation_time}, " + f"expires_at={expiry_time}, " + f"remaining={remaining:.1f}s, " + f"expired={self.is_expired()}" + f")" + ) + + def __eq__(self, other) -> bool: + """ + Two NodeMigratingEvent events are considered equal if they have the same + id and are of the same type. + """ + if not isinstance(other, NodeMigratingEvent): + return False + return self.id == other.id and type(self) is type(other) + + def __hash__(self) -> int: + """ + Return a hash value for the event to allow + instances to be used in sets and as dictionary keys. + + Returns: + int: Hash value based on event type and id + """ + return hash((self.__class__, self.id)) + + +class NodeMigratedEvent(MaintenanceEvent): + """ + Event for when a Redis cluster node has completed migrating slots. 
+ + This event is received when a node has finished migrating all its slots + to other nodes during cluster rebalancing or maintenance operations. + + Args: + id (int): Unique identifier for this event + """ + + DEFAULT_TTL = 5 + + def __init__(self, id: int): + super().__init__(id, NodeMigratedEvent.DEFAULT_TTL) + + def __repr__(self) -> str: + expiry_time = self.creation_time + self.ttl + remaining = max(0, expiry_time - time.monotonic()) + return ( + f"{self.__class__.__name__}(" + f"id={self.id}, " + f"ttl={self.ttl}, " + f"creation_time={self.creation_time}, " + f"expires_at={expiry_time}, " + f"remaining={remaining:.1f}s, " + f"expired={self.is_expired()}" + f")" + ) + + def __eq__(self, other) -> bool: + """ + Two NodeMigratedEvent events are considered equal if they have the same + id and are of the same type. + """ + if not isinstance(other, NodeMigratedEvent): + return False + return self.id == other.id and type(self) is type(other) + + def __hash__(self) -> int: + """ + Return a hash value for the event to allow + instances to be used in sets and as dictionary keys. + + Returns: + int: Hash value based on event type and id + """ + return hash((self.__class__, self.id)) + + +class MaintenanceEventsConfig: + """ + Configuration class for maintenance events handling behaviour. Events are received through + push notifications. + + This class defines how the Redis client should react to different push notifications + such as node moving, migrations, etc. in a Redis cluster. + + """ + + def __init__( + self, + enabled: bool = False, + proactive_reconnect: bool = True, + relax_timeout: Optional[Number] = 20, + ): + """ + Initialize a new MaintenanceEventsConfig. + + Args: + enabled (bool): Whether to enable maintenance events handling. + Defaults to False. + proactive_reconnect (bool): Whether to proactively reconnect when a node is replaced. + Defaults to True. + relax_timeout (Number): The relax timeout to use for the connection during maintenance. 
+ If -1 is provided - the relax timeout is disabled. Defaults to 20. + + """ + self.enabled = enabled + self.relax_timeout = relax_timeout + self.proactive_reconnect = proactive_reconnect + + def __repr__(self) -> str: + return ( + f"{self.__class__.__name__}(" + f"enabled={self.enabled}, " + f"proactive_reconnect={self.proactive_reconnect}, " + f"relax_timeout={self.relax_timeout}, " + f")" + ) + + def is_relax_timeouts_enabled(self) -> bool: + """ + Check if the relax_timeout is enabled. The '-1' value is used to disable the relax_timeout. + If relax_timeout is set to None, it will make the operation blocking + and waiting until any response is received. + + Returns: + True if the relax_timeout is enabled, False otherwise. + """ + return self.relax_timeout != -1 + + +class MaintenanceEventPoolHandler: + def __init__( + self, + pool: Union["ConnectionPool", "BlockingConnectionPool"], + config: MaintenanceEventsConfig, + ) -> None: + self.pool = pool + self.config = config + self._processed_events = set() + self._lock = threading.RLock() + + def remove_expired_notifications(self): + with self._lock: + for notification in tuple(self._processed_events): + if notification.is_expired(): + self._processed_events.remove(notification) + + def handle_event(self, notification: MaintenanceEvent): + self.remove_expired_notifications() + + if isinstance(notification, NodeMovingEvent): + return self.handle_node_moving_event(notification) + else: + logging.error(f"Unhandled notification type: {notification}") + + def handle_node_moving_event(self, event: NodeMovingEvent): + if ( + not self.config.proactive_reconnect + and not self.config.is_relax_timeouts_enabled() + ): + return + with self._lock: + if event in self._processed_events: + # nothing to do in the connection pool handling + # the event has already been handled or is expired + # just return + return + + with self.pool._lock: + if ( + self.config.proactive_reconnect + or self.config.is_relax_timeouts_enabled() + ): + 
if getattr(self.pool, "set_in_maintenance", False): + self.pool.set_in_maintenance(True) + # Set state to MOVING for all connections and in kwargs (inside pool lock, after set_in_maintenance) + self.pool.set_maintenance_state_for_all(MaintenanceState.MOVING) + self.pool.set_maintenance_state_in_kwargs(MaintenanceState.MOVING) + # edit the config for new connections until the notification expires + self.pool.update_connection_kwargs_with_tmp_settings( + tmp_host_address=event.new_node_host, + tmp_relax_timeout=self.config.relax_timeout, + ) + if self.config.is_relax_timeouts_enabled(): + # extend the timeout for all connections that are currently in use + self.pool.update_connections_current_timeout( + self.config.relax_timeout + ) + if self.config.proactive_reconnect: + # take care for the active connections in the pool + # mark them for reconnect after they complete the current command + self.pool.update_active_connections_for_reconnect( + tmp_host_address=event.new_node_host, + tmp_relax_timeout=self.config.relax_timeout, + ) + # take care for the inactive connections in the pool + # delete them and create new ones + self.pool.disconnect_and_reconfigure_free_connections( + tmp_host_address=event.new_node_host, + tmp_relax_timeout=self.config.relax_timeout, + ) + if getattr(self.pool, "set_in_maintenance", False): + self.pool.set_in_maintenance(False) + + threading.Timer(event.ttl, self.handle_node_moved_event).start() + + self._processed_events.add(event) + + def handle_node_moved_event(self): + with self._lock: + self.pool.update_connection_kwargs_with_tmp_settings( + tmp_host_address=None, + tmp_relax_timeout=-1, + ) + # Clear state to NONE in kwargs immediately after updating tmp kwargs + self.pool.set_maintenance_state_in_kwargs(MaintenanceState.NONE) + with self.pool._lock: + if self.config.is_relax_timeouts_enabled(): + # reset the timeout for existing connections + self.pool.update_connections_current_timeout( + relax_timeout=-1, 
include_free_connections=True + ) + self.pool.update_connections_tmp_settings( + tmp_host_address=None, tmp_relax_timeout=-1 + ) + # Clear state to NONE for all connections + self.pool.set_maintenance_state_for_all(MaintenanceState.NONE) + + +class MaintenanceEventConnectionHandler: + def __init__( + self, connection: "ConnectionInterface", config: MaintenanceEventsConfig + ) -> None: + self.connection = connection + self.config = config + + def handle_event(self, event: MaintenanceEvent): + if isinstance(event, NodeMigratingEvent): + return self.handle_migrating_event(event) + elif isinstance(event, NodeMigratedEvent): + return self.handle_migration_completed_event(event) + else: + logging.error(f"Unhandled event type: {event}") + + def handle_migrating_event(self, notification: NodeMigratingEvent): + if ( + self.connection.maintenance_state == MaintenanceState.MOVING + or not self.config.is_relax_timeouts_enabled() + ): + return + self.connection.set_maintenance_state(MaintenanceState.MIGRATING) + # extend the timeout for all created connections + self.connection.update_current_socket_timeout(self.config.relax_timeout) + self.connection.update_tmp_settings(tmp_relax_timeout=self.config.relax_timeout) + + def handle_migration_completed_event(self, notification: "NodeMigratedEvent"): + # Only reset timeouts if state is not MOVING and relax timeouts are enabled + if ( + self.connection.maintenance_state == MaintenanceState.MOVING + or not self.config.is_relax_timeouts_enabled() + ): + return + self.connection.set_maintenance_state(MaintenanceState.NONE) + # Node migration completed - reset the connection + # timeouts by providing -1 as the relax timeout + self.connection.update_current_socket_timeout(-1) + self.connection.update_tmp_settings(tmp_relax_timeout=-1) diff --git a/tests/test_connection_pool.py b/tests/test_connection_pool.py index 67b2fd5030..def76bc8f8 100644 --- a/tests/test_connection_pool.py +++ b/tests/test_connection_pool.py @@ -9,6 +9,7 @@ import 
redis from redis.cache import CacheConfig from redis.connection import CacheProxyConnection, Connection, to_bool +from redis.maintenance_events import MaintenanceState from redis.utils import SSL_AVAILABLE from .conftest import ( @@ -33,6 +34,9 @@ def connect(self): def can_read(self): return False + def should_reconnect(self): + return False + class TestConnectionPool: def get_pool( @@ -50,10 +54,15 @@ def get_pool( return pool def test_connection_creation(self): - connection_kwargs = {"foo": "bar", "biz": "baz"} + connection_kwargs = { + "foo": "bar", + "biz": "baz", + "maintenance_state": MaintenanceState.NONE, + } pool = self.get_pool( connection_kwargs=connection_kwargs, connection_class=DummyConnection ) + connection = pool.get_connection() assert isinstance(connection, DummyConnection) assert connection.kwargs == connection_kwargs @@ -147,7 +156,9 @@ def test_connection_creation(self, master_host): "host": master_host[0], "port": master_host[1], } + pool = self.get_pool(connection_kwargs=connection_kwargs) + connection_kwargs["maintenance_state"] = MaintenanceState.NONE connection = pool.get_connection() assert isinstance(connection, DummyConnection) assert connection.kwargs == connection_kwargs diff --git a/tests/test_maintenance_events.py b/tests/test_maintenance_events.py new file mode 100644 index 0000000000..ac7d10b51e --- /dev/null +++ b/tests/test_maintenance_events.py @@ -0,0 +1,545 @@ +import threading +from unittest.mock import Mock, patch, MagicMock +import pytest + +from redis.maintenance_events import ( + MaintenanceEvent, + NodeMovingEvent, + NodeMigratingEvent, + NodeMigratedEvent, + MaintenanceEventsConfig, + MaintenanceEventPoolHandler, + MaintenanceEventConnectionHandler, +) + + +class TestMaintenanceEvent: + """Test the base MaintenanceEvent class functionality through concrete subclasses.""" + + def test_abstract_class_cannot_be_instantiated(self): + """Test that MaintenanceEvent cannot be instantiated directly.""" + with 
patch("time.monotonic", return_value=1000): + with pytest.raises(TypeError): + MaintenanceEvent(id=1, ttl=10) # type: ignore + + def test_init_through_subclass(self): + """Test MaintenanceEvent initialization through concrete subclass.""" + with patch("time.monotonic", return_value=1000): + event = NodeMovingEvent( + id=1, new_node_host="localhost", new_node_port=6379, ttl=10 + ) + assert event.id == 1 + assert event.ttl == 10 + assert event.creation_time == 1000 + assert event.expire_at == 1010 + + def test_is_expired_false(self): + """Test is_expired returns False for non-expired event.""" + with patch("time.monotonic", return_value=1000): + event = NodeMovingEvent( + id=1, new_node_host="localhost", new_node_port=6379, ttl=10 + ) + + with patch("time.monotonic", return_value=1005): # 5 seconds later + assert not event.is_expired() + + def test_is_expired_true(self): + """Test is_expired returns True for expired event.""" + with patch("time.monotonic", return_value=1000): + event = NodeMovingEvent( + id=1, new_node_host="localhost", new_node_port=6379, ttl=10 + ) + + with patch("time.monotonic", return_value=1015): # 15 seconds later + assert event.is_expired() + + def test_is_expired_exact_boundary(self): + """Test is_expired at exact expiration boundary.""" + with patch("time.monotonic", return_value=1000): + event = NodeMovingEvent( + id=1, new_node_host="localhost", new_node_port=6379, ttl=10 + ) + + with patch("time.monotonic", return_value=1010): # Exactly at expiration + assert not event.is_expired() + + with patch("time.monotonic", return_value=1011): # 1 second past expiration + assert event.is_expired() + + +class TestNodeMovingEvent: + """Test the NodeMovingEvent class.""" + + def test_init(self): + """Test NodeMovingEvent initialization.""" + with patch("time.monotonic", return_value=1000): + event = NodeMovingEvent( + id=1, new_node_host="localhost", new_node_port=6379, ttl=10 + ) + assert event.id == 1 + assert event.new_node_host == "localhost" + 
assert event.new_node_port == 6379 + assert event.ttl == 10 + assert event.creation_time == 1000 + + def test_repr(self): + """Test NodeMovingEvent string representation.""" + with patch("time.monotonic", return_value=1000): + event = NodeMovingEvent( + id=1, new_node_host="localhost", new_node_port=6379, ttl=10 + ) + + with patch("time.monotonic", return_value=1005): # 5 seconds later + repr_str = repr(event) + assert "NodeMovingEvent" in repr_str + assert "id=1" in repr_str + assert "new_node_host='localhost'" in repr_str + assert "new_node_port=6379" in repr_str + assert "ttl=10" in repr_str + assert "remaining=5.0s" in repr_str + assert "expired=False" in repr_str + + def test_equality_same_id_host_port(self): + """Test equality for events with same id, host, and port.""" + event1 = NodeMovingEvent( + id=1, new_node_host="localhost", new_node_port=6379, ttl=10 + ) + event2 = NodeMovingEvent( + id=1, new_node_host="localhost", new_node_port=6379, ttl=20 + ) # Different TTL + assert event1 == event2 + + def test_equality_same_id_different_host(self): + """Test inequality for events with same id but different host.""" + event1 = NodeMovingEvent( + id=1, new_node_host="host1", new_node_port=6379, ttl=10 + ) + event2 = NodeMovingEvent( + id=1, new_node_host="host2", new_node_port=6379, ttl=10 + ) + assert event1 != event2 + + def test_equality_same_id_different_port(self): + """Test inequality for events with same id but different port.""" + event1 = NodeMovingEvent( + id=1, new_node_host="localhost", new_node_port=6379, ttl=10 + ) + event2 = NodeMovingEvent( + id=1, new_node_host="localhost", new_node_port=6380, ttl=10 + ) + assert event1 != event2 + + def test_equality_different_id(self): + """Test inequality for events with different id.""" + event1 = NodeMovingEvent( + id=1, new_node_host="localhost", new_node_port=6379, ttl=10 + ) + event2 = NodeMovingEvent( + id=2, new_node_host="localhost", new_node_port=6379, ttl=10 + ) + assert event1 != event2 + + def 
test_equality_different_type(self): + """Test inequality for events of different types.""" + event1 = NodeMovingEvent( + id=1, new_node_host="localhost", new_node_port=6379, ttl=10 + ) + event2 = NodeMigratingEvent(id=1, ttl=10) + assert event1 != event2 + + def test_hash_same_id_host_port(self): + """Test hash consistency for events with same id, host, and port.""" + event1 = NodeMovingEvent( + id=1, new_node_host="localhost", new_node_port=6379, ttl=10 + ) + event2 = NodeMovingEvent( + id=1, new_node_host="localhost", new_node_port=6379, ttl=20 + ) # Different TTL + assert hash(event1) == hash(event2) + + def test_hash_different_host(self): + """Test hash difference for events with different host.""" + event1 = NodeMovingEvent( + id=1, new_node_host="host1", new_node_port=6379, ttl=10 + ) + event2 = NodeMovingEvent( + id=1, new_node_host="host2", new_node_port=6379, ttl=10 + ) + assert hash(event1) != hash(event2) + + def test_hash_different_port(self): + """Test hash difference for events with different port.""" + event1 = NodeMovingEvent( + id=1, new_node_host="localhost", new_node_port=6379, ttl=10 + ) + event2 = NodeMovingEvent( + id=1, new_node_host="localhost", new_node_port=6380, ttl=10 + ) + assert hash(event1) != hash(event2) + + def test_hash_different_id(self): + """Test hash difference for events with different id.""" + event1 = NodeMovingEvent( + id=1, new_node_host="localhost", new_node_port=6379, ttl=10 + ) + event2 = NodeMovingEvent( + id=2, new_node_host="localhost", new_node_port=6379, ttl=10 + ) + assert hash(event1) != hash(event2) + + def test_set_functionality(self): + """Test that events can be used in sets correctly.""" + event1 = NodeMovingEvent( + id=1, new_node_host="localhost", new_node_port=6379, ttl=10 + ) + event2 = NodeMovingEvent( + id=1, new_node_host="localhost", new_node_port=6379, ttl=20 + ) # Same id, host, port - should be considered the same + event3 = NodeMovingEvent( + id=1, new_node_host="host2", new_node_port=6380, 
ttl=10 + ) # Same id but different host/port - should be different + event4 = NodeMovingEvent( + id=2, new_node_host="localhost", new_node_port=6379, ttl=10 + ) # Different id - should be different + + event_set = {event1, event2, event3, event4} + assert len(event_set) == 3 # event1 and event2 should be considered the same + + +class TestNodeMigratingEvent: + """Test the NodeMigratingEvent class.""" + + def test_init(self): + """Test NodeMigratingEvent initialization.""" + with patch("time.monotonic", return_value=1000): + event = NodeMigratingEvent(id=1, ttl=5) + assert event.id == 1 + assert event.ttl == 5 + assert event.creation_time == 1000 + + def test_repr(self): + """Test NodeMigratingEvent string representation.""" + with patch("time.monotonic", return_value=1000): + event = NodeMigratingEvent(id=1, ttl=5) + + with patch("time.monotonic", return_value=1002): # 2 seconds later + repr_str = repr(event) + assert "NodeMigratingEvent" in repr_str + assert "id=1" in repr_str + assert "ttl=5" in repr_str + assert "remaining=3.0s" in repr_str + assert "expired=False" in repr_str + + def test_equality_and_hash(self): + """Test equality and hash for NodeMigratingEvent.""" + event1 = NodeMigratingEvent(id=1, ttl=5) + event2 = NodeMigratingEvent(id=1, ttl=10) # Same id, different ttl + event3 = NodeMigratingEvent(id=2, ttl=5) # Different id + + assert event1 == event2 + assert event1 != event3 + assert hash(event1) == hash(event2) + assert hash(event1) != hash(event3) + + +class TestNodeMigratedEvent: + """Test the NodeMigratedEvent class.""" + + def test_init(self): + """Test NodeMigratedEvent initialization.""" + with patch("time.monotonic", return_value=1000): + event = NodeMigratedEvent(id=1) + assert event.id == 1 + assert event.ttl == NodeMigratedEvent.DEFAULT_TTL + assert event.creation_time == 1000 + + def test_default_ttl(self): + """Test that DEFAULT_TTL is used correctly.""" + assert NodeMigratedEvent.DEFAULT_TTL == 5 + event = NodeMigratedEvent(id=1) + 
assert event.ttl == 5 + + def test_repr(self): + """Test NodeMigratedEvent string representation.""" + with patch("time.monotonic", return_value=1000): + event = NodeMigratedEvent(id=1) + + with patch("time.monotonic", return_value=1001): # 1 second later + repr_str = repr(event) + assert "NodeMigratedEvent" in repr_str + assert "id=1" in repr_str + assert "ttl=5" in repr_str + assert "remaining=4.0s" in repr_str + assert "expired=False" in repr_str + + def test_equality_and_hash(self): + """Test equality and hash for NodeMigratedEvent.""" + event1 = NodeMigratedEvent(id=1) + event2 = NodeMigratedEvent(id=1) # Same id + event3 = NodeMigratedEvent(id=2) # Different id + + assert event1 == event2 + assert event1 != event3 + assert hash(event1) == hash(event2) + assert hash(event1) != hash(event3) + + +class TestMaintenanceEventsConfig: + """Test the MaintenanceEventsConfig class.""" + + def test_init_defaults(self): + """Test MaintenanceEventsConfig initialization with defaults.""" + config = MaintenanceEventsConfig() + assert config.enabled is False + assert config.proactive_reconnect is True + assert config.relax_timeout == 20 + + def test_init_custom_values(self): + """Test MaintenanceEventsConfig initialization with custom values.""" + config = MaintenanceEventsConfig( + enabled=True, proactive_reconnect=False, relax_timeout=30 + ) + assert config.enabled is True + assert config.proactive_reconnect is False + assert config.relax_timeout == 30 + + def test_repr(self): + """Test MaintenanceEventsConfig string representation.""" + config = MaintenanceEventsConfig( + enabled=True, proactive_reconnect=False, relax_timeout=30 + ) + repr_str = repr(config) + assert "MaintenanceEventsConfig" in repr_str + assert "enabled=True" in repr_str + assert "proactive_reconnect=False" in repr_str + assert "relax_timeout=30" in repr_str + + def test_is_relax_timeouts_enabled_true(self): + """Test is_relax_timeouts_enabled returns True for positive timeout.""" + config = 
MaintenanceEventsConfig(relax_timeout=20) + assert config.is_relax_timeouts_enabled() is True + + def test_is_relax_timeouts_enabled_false(self): + """Test is_relax_timeouts_enabled returns False for -1 timeout.""" + config = MaintenanceEventsConfig(relax_timeout=-1) + assert config.is_relax_timeouts_enabled() is False + + def test_is_relax_timeouts_enabled_zero(self): + """Test is_relax_timeouts_enabled returns True for zero timeout.""" + config = MaintenanceEventsConfig(relax_timeout=0) + assert config.is_relax_timeouts_enabled() is True + + def test_is_relax_timeouts_enabled_none(self): + """Test is_relax_timeouts_enabled returns True for None timeout.""" + config = MaintenanceEventsConfig(relax_timeout=None) + assert config.is_relax_timeouts_enabled() is True + + def test_relax_timeout_none_is_saved_as_none(self): + """Test that None value for relax_timeout is saved as None.""" + config = MaintenanceEventsConfig(relax_timeout=None) + assert config.relax_timeout is None + + +class TestMaintenanceEventPoolHandler: + """Test the MaintenanceEventPoolHandler class.""" + + def setup_method(self): + """Set up test fixtures.""" + self.mock_pool = Mock() + self.mock_pool._lock = MagicMock() + self.mock_pool._lock.__enter__.return_value = None + self.mock_pool._lock.__exit__.return_value = None + self.config = MaintenanceEventsConfig( + enabled=True, proactive_reconnect=True, relax_timeout=20 + ) + self.handler = MaintenanceEventPoolHandler(self.mock_pool, self.config) + + def test_init(self): + """Test MaintenanceEventPoolHandler initialization.""" + assert self.handler.pool == self.mock_pool + assert self.handler.config == self.config + assert isinstance(self.handler._processed_events, set) + assert isinstance(self.handler._lock, type(threading.RLock())) + + def test_remove_expired_notifications(self): + """Test removal of expired notifications.""" + with patch("time.monotonic", return_value=1000): + event1 = NodeMovingEvent( + id=1, new_node_host="host1", 
new_node_port=6379, ttl=10 + ) + event2 = NodeMovingEvent( + id=2, new_node_host="host2", new_node_port=6380, ttl=5 + ) + self.handler._processed_events.add(event1) + self.handler._processed_events.add(event2) + + # Move time forward but not enough to expire event2 (expires at 1005) + with patch("time.monotonic", return_value=1003): + self.handler.remove_expired_notifications() + assert event1 in self.handler._processed_events + assert event2 in self.handler._processed_events # Not expired yet + + # Move time forward to expire event2 but not event1 + with patch("time.monotonic", return_value=1006): + self.handler.remove_expired_notifications() + assert event1 in self.handler._processed_events + assert event2 not in self.handler._processed_events # Now expired + + def test_handle_event_node_moving(self): + """Test handling of NodeMovingEvent.""" + event = NodeMovingEvent( + id=1, new_node_host="localhost", new_node_port=6379, ttl=10 + ) + + with patch.object(self.handler, "handle_node_moving_event") as mock_handle: + self.handler.handle_event(event) + mock_handle.assert_called_once_with(event) + + def test_handle_event_unknown_type(self): + """Test handling of unknown event type.""" + event = NodeMigratingEvent(id=1, ttl=5) # Not handled by pool handler + + result = self.handler.handle_event(event) + assert result is None + + def test_handle_node_moving_event_disabled_config(self): + """Test node moving event handling when both features are disabled.""" + config = MaintenanceEventsConfig(proactive_reconnect=False, relax_timeout=-1) + handler = MaintenanceEventPoolHandler(self.mock_pool, config) + event = NodeMovingEvent( + id=1, new_node_host="localhost", new_node_port=6379, ttl=10 + ) + + result = handler.handle_node_moving_event(event) + assert result is None + assert event not in handler._processed_events + + def test_handle_node_moving_event_already_processed(self): + """Test node moving event handling when event already processed.""" + event = NodeMovingEvent( 
+ id=1, new_node_host="localhost", new_node_port=6379, ttl=10 + ) + self.handler._processed_events.add(event) + + result = self.handler.handle_node_moving_event(event) + assert result is None + + def test_handle_node_moving_event_success(self): + """Test successful node moving event handling.""" + event = NodeMovingEvent( + id=1, new_node_host="localhost", new_node_port=6379, ttl=10 + ) + + with ( + patch("threading.Timer") as mock_timer, + patch("time.monotonic", return_value=1000), + ): + self.handler.handle_node_moving_event(event) + + # Verify timer was started + mock_timer.assert_called_once_with( + event.ttl, self.handler.handle_node_moved_event + ) + mock_timer.return_value.start.assert_called_once() + + # Verify event was added to processed set + assert event in self.handler._processed_events + + # Verify pool methods were called + self.mock_pool.update_connection_kwargs_with_tmp_settings.assert_called_once() + + def test_handle_node_moved_event(self): + """Test handling of node moved event (cleanup).""" + self.handler.handle_node_moved_event() + + # Verify cleanup methods were called + self.mock_pool.update_connection_kwargs_with_tmp_settings.assert_called_once_with( + tmp_host_address=None, + tmp_relax_timeout=-1, + ) + + +class TestMaintenanceEventConnectionHandler: + """Test the MaintenanceEventConnectionHandler class.""" + + def setup_method(self): + """Set up test fixtures.""" + self.mock_connection = Mock() + self.config = MaintenanceEventsConfig(enabled=True, relax_timeout=20) + self.handler = MaintenanceEventConnectionHandler( + self.mock_connection, self.config + ) + + def test_init(self): + """Test MaintenanceEventConnectionHandler initialization.""" + assert self.handler.connection == self.mock_connection + assert self.handler.config == self.config + + def test_handle_event_migrating(self): + """Test handling of NodeMigratingEvent.""" + event = NodeMigratingEvent(id=1, ttl=5) + + with patch.object(self.handler, "handle_migrating_event") as 
mock_handle: + self.handler.handle_event(event) + mock_handle.assert_called_once_with(event) + + def test_handle_event_migrated(self): + """Test handling of NodeMigratedEvent.""" + event = NodeMigratedEvent(id=1) + + with patch.object( + self.handler, "handle_migration_completed_event" + ) as mock_handle: + self.handler.handle_event(event) + mock_handle.assert_called_once_with(event) + + def test_handle_event_unknown_type(self): + """Test handling of unknown event type.""" + event = NodeMovingEvent( + id=1, new_node_host="localhost", new_node_port=6379, ttl=10 + ) + + result = self.handler.handle_event(event) + assert result is None + + def test_handle_migrating_event_disabled(self): + """Test migrating event handling when relax timeouts are disabled.""" + config = MaintenanceEventsConfig(relax_timeout=-1) + handler = MaintenanceEventConnectionHandler(self.mock_connection, config) + event = NodeMigratingEvent(id=1, ttl=5) + + result = handler.handle_migrating_event(event) + assert result is None + self.mock_connection.update_current_socket_timeout.assert_not_called() + + def test_handle_migrating_event_success(self): + """Test successful migrating event handling.""" + event = NodeMigratingEvent(id=1, ttl=5) + + self.handler.handle_migrating_event(event) + + self.mock_connection.update_current_socket_timeout.assert_called_once_with(20) + self.mock_connection.update_tmp_settings.assert_called_once_with( + tmp_relax_timeout=20 + ) + + def test_handle_migration_completed_event_disabled(self): + """Test migration completed event handling when relax timeouts are disabled.""" + config = MaintenanceEventsConfig(relax_timeout=-1) + handler = MaintenanceEventConnectionHandler(self.mock_connection, config) + event = NodeMigratedEvent(id=1) + + result = handler.handle_migration_completed_event(event) + assert result is None + self.mock_connection.update_current_socket_timeout.assert_not_called() + + def test_handle_migration_completed_event_success(self): + """Test successful 
migration completed event handling.""" + event = NodeMigratedEvent(id=1) + + self.handler.handle_migration_completed_event(event) + + self.mock_connection.update_current_socket_timeout.assert_called_once_with(-1) + self.mock_connection.update_tmp_settings.assert_called_once_with( + tmp_relax_timeout=-1 + ) diff --git a/tests/test_maintenance_events_handling.py b/tests/test_maintenance_events_handling.py new file mode 100644 index 0000000000..b573a55e5f --- /dev/null +++ b/tests/test_maintenance_events_handling.py @@ -0,0 +1,1351 @@ +import socket +import threading +from typing import List +from unittest.mock import patch +import pytest +from time import sleep + +from redis import Redis +from redis.connection import ( + AbstractConnection, + ConnectionPool, + BlockingConnectionPool, + MaintenanceState, +) +from redis.maintenance_events import ( + MaintenanceEventsConfig, + NodeMigratingEvent, + MaintenanceEventPoolHandler, + NodeMovingEvent, + NodeMigratedEvent, +) + + +class MockSocket: + """Mock socket that simulates Redis protocol responses.""" + + AFTER_MOVING_ADDRESS = "1.2.3.4:6379" + DEFAULT_ADDRESS = "12.45.34.56:6379" + MOVING_TIMEOUT = 1 + + def __init__(self): + self.connected = False + self.address = None + self.sent_data = [] + self.closed = False + self.command_count = 0 + self.pending_responses = [] + # Track socket timeout changes for maintenance events validation + self.timeout = None + self.thread_timeouts = {} # Track last applied timeout per thread + self.moving_sent = False + + def connect(self, address): + """Simulate socket connection.""" + self.connected = True + self.address = address + + def send(self, data): + """Simulate sending data to Redis.""" + if self.closed: + raise ConnectionError("Socket is closed") + self.sent_data.append(data) + + # Analyze the command and prepare appropriate response + if b"HELLO" in data: + response = 
b"%7\r\n$6\r\nserver\r\n$5\r\nredis\r\n$7\r\nversion\r\n$5\r\n7.0.0\r\n$5\r\nproto\r\n:3\r\n$2\r\nid\r\n:1\r\n$4\r\nmode\r\n$10\r\nstandalone\r\n$4\r\nrole\r\n$6\r\nmaster\r\n$7\r\nmodules\r\n*0\r\n" + self.pending_responses.append(response) + elif b"SET" in data: + response = b"+OK\r\n" + + # Check if this is a key that should trigger a push message + if b"key_receive_migrating_" in data or b"key_receive_migrating" in data: + # MIGRATING push message before SET key_receive_migrating_X response + # Format: >2\r\n$9\r\nMIGRATING\r\n:10\r\n (2 elements: MIGRATING, ttl) + migrating_push = ">2\r\n$9\r\nMIGRATING\r\n:10\r\n" + response = migrating_push.encode() + response + elif b"key_receive_migrated_" in data or b"key_receive_migrated" in data: + # MIGRATED push message before SET key_receive_migrated_X response + # Format: >1\r\n$8\r\nMIGRATED\r\n (1 element: MIGRATED) + migrated_push = ">1\r\n$8\r\nMIGRATED\r\n" + response = migrated_push.encode() + response + elif b"key_receive_moving_" in data: + # MOVING push message before SET key_receive_moving_X response + # Format: >3\r\n$6\r\nMOVING\r\n:15\r\n+localhost:6379\r\n (3 elements: MOVING, ttl, host:port) + # Note: Using + instead of $ to send as simple string instead of bulk string + moving_push = f">3\r\n$6\r\nMOVING\r\n:{MockSocket.MOVING_TIMEOUT}\r\n+{MockSocket.AFTER_MOVING_ADDRESS}\r\n" + response = moving_push.encode() + response + + self.pending_responses.append(response) + elif b"GET" in data: + # Extract key and provide appropriate response + if b"hello" in data: + response = b"$5\r\nworld\r\n" + self.pending_responses.append(response) + # Handle specific keys used in tests + elif b"key_receive_moving_0" in data: + self.pending_responses.append(b"$8\r\nvalue3_0\r\n") + elif b"key_receive_migrated_0" in data: + self.pending_responses.append(b"$13\r\nmigrated_value\r\n") + elif b"key_receive_migrating" in data: + self.pending_responses.append(b"$6\r\nvalue2\r\n") + elif b"key_receive_migrated" in data: + 
self.pending_responses.append(b"$6\r\nvalue3\r\n") + elif b"key1" in data: + self.pending_responses.append(b"$6\r\nvalue1\r\n") + else: + self.pending_responses.append(b"$-1\r\n") # NULL response + else: + self.pending_responses.append(b"+OK\r\n") # Default response + + self.command_count += 1 + return len(data) + + def sendall(self, data): + """Simulate sending all data to Redis.""" + return self.send(data) + + def recv(self, bufsize): + """Simulate receiving data from Redis.""" + if self.closed: + raise ConnectionError("Socket is closed") + + # Use pending responses that were prepared when commands were sent + if self.pending_responses: + response = self.pending_responses.pop(0) + if b"MOVING" in response: + self.moving_sent = True + return response[:bufsize] # Respect buffer size + else: + # No data available - this should block or raise an exception + # For can_read checks, we should indicate no data is available + import errno + + raise BlockingIOError(errno.EAGAIN, "Resource temporarily unavailable") + + def fileno(self): + """Return a fake file descriptor for select/poll operations.""" + return 1 # Fake file descriptor + + def close(self): + """Simulate closing the socket.""" + self.closed = True + self.connected = False + self.address = None + self.timeout = None + self.thread_timeouts = {} + + def settimeout(self, timeout): + """Simulate setting socket timeout and track changes per thread.""" + self.timeout = timeout + + # Track last applied timeout with thread_id information added + thread_id = threading.current_thread().ident + self.thread_timeouts[thread_id] = timeout + + def gettimeout(self): + """Simulate getting socket timeout.""" + return self.timeout + + def setsockopt(self, level, optname, value): + """Simulate setting socket options.""" + pass + + def getpeername(self): + """Simulate getting peer name.""" + return self.address + + def getsockname(self): + """Simulate getting socket name.""" + return (self.address.split(":")[0], 12345) + + def 
shutdown(self, how): + """Simulate socket shutdown.""" + pass + + +class TestMaintenanceEventsHandling: + """Integration tests for maintenance events handling with real connection pool.""" + + def setup_method(self): + """Set up test fixtures with mocked sockets.""" + self.mock_sockets = [] + self.original_socket = socket.socket + + # Mock socket creation to return our mock sockets + def mock_socket_factory(*args, **kwargs): + mock_sock = MockSocket() + self.mock_sockets.append(mock_sock) + return mock_sock + + self.socket_patcher = patch("socket.socket", side_effect=mock_socket_factory) + self.socket_patcher.start() + + # Mock select.select to simulate data availability for reading + def mock_select(rlist, wlist, xlist, timeout=0): + # Check if any of the sockets in rlist have data available + ready_sockets = [] + for sock in rlist: + if hasattr(sock, "connected") and sock.connected and not sock.closed: + # Only return socket as ready if it actually has data to read + if hasattr(sock, "pending_responses") and sock.pending_responses: + ready_sockets.append(sock) + # Don't return socket as ready just because it received commands + # Only when there are actual responses available + return (ready_sockets, [], []) + + self.select_patcher = patch("select.select", side_effect=mock_select) + self.select_patcher.start() + + # Create maintenance events config + self.config = MaintenanceEventsConfig( + enabled=True, proactive_reconnect=True, relax_timeout=30 + ) + + def teardown_method(self): + """Clean up test fixtures.""" + self.socket_patcher.stop() + self.select_patcher.stop() + + def _get_client( + self, + pool_class, + max_connections=10, + maintenance_events_config=None, + setup_pool_handler=False, + ): + """Helper method to create a pool and Redis client with maintenance events configuration. 
+ + Args: + pool_class: The connection pool class (ConnectionPool or BlockingConnectionPool) + max_connections: Maximum number of connections in the pool (default: 10) + maintenance_events_config: Optional MaintenanceEventsConfig to use. If not provided, + uses self.config from setup_method (default: None) + setup_pool_handler: Whether to set up pool handler for moving events (default: False) + + Returns: + Redis: the test client; its pool is reachable via test_redis_client.connection_pool + """ + config = ( + maintenance_events_config + if maintenance_events_config is not None + else self.config + ) + + test_pool = pool_class( + host=MockSocket.DEFAULT_ADDRESS.split(":")[0], + port=int(MockSocket.DEFAULT_ADDRESS.split(":")[1]), + max_connections=max_connections, + protocol=3, # Required for maintenance events + maintenance_events_config=config, + ) + test_redis_client = Redis(connection_pool=test_pool) + + # Set up pool handler for moving events if requested + if setup_pool_handler: + pool_handler = MaintenanceEventPoolHandler( + test_redis_client.connection_pool, config + ) + test_redis_client.connection_pool.set_maintenance_events_pool_handler( + pool_handler + ) + + return test_redis_client + + def _validate_connection_handlers(self, conn, pool_handler, config): + """Helper method to validate connection handlers are properly set.""" + # Test that the node moving handler function is correctly set + parser_handler = conn._parser.node_moving_push_handler_func + assert parser_handler is not None + assert hasattr(parser_handler, "__self__") + assert hasattr(parser_handler, "__func__") + assert parser_handler.__self__ is pool_handler + assert parser_handler.__func__ is pool_handler.handle_event.__func__ + + # Test that the maintenance handler function is correctly set + maintenance_handler = conn._parser.maintenance_push_handler_func + assert maintenance_handler is not None + assert hasattr(maintenance_handler, "__self__") + assert hasattr(maintenance_handler, "__func__") + # The maintenance handler should be
bound to the connection's + # maintenance event connection handler + assert ( + maintenance_handler.__self__ is conn._maintenance_event_connection_handler + ) + assert ( + maintenance_handler.__func__ + is conn._maintenance_event_connection_handler.handle_event.__func__ + ) + + # Validate that the connection's maintenance handler has the same config object + assert conn._maintenance_event_connection_handler.config is config + + def _validate_current_timeout_for_thread( + self, thread_id, expected_timeout, error_msg=None + ): + """Helper method to validate the current timeout for the calling thread.""" + actual_timeout = None + # Get the actual thread ID from the current thread + current_thread_id = threading.current_thread().ident + for sock in self.mock_sockets: + if current_thread_id in sock.thread_timeouts: + actual_timeout = sock.thread_timeouts[current_thread_id] + break + + assert actual_timeout == expected_timeout, ( + error_msg, + f"Thread {thread_id}: Expected timeout ({expected_timeout}), " + f"but found timeout: {actual_timeout} for thread {thread_id}. " + f"All thread timeouts: {[sock.thread_timeouts for sock in self.mock_sockets]}", + ) + + def _validate_current_timeout(self, expected_timeout, error_msg=None): + """Helper method to validate the current timeout for the calling thread.""" + actual_timeout = None + # Get the actual thread ID from the current thread + current_thread_id = threading.current_thread().ident + for sock in self.mock_sockets: + if current_thread_id in sock.thread_timeouts: + actual_timeout = sock.thread_timeouts[current_thread_id] + break + + assert actual_timeout == expected_timeout, ( + f"{error_msg or ''}" + f"Expected timeout ({expected_timeout}), " + f"but found timeout: {actual_timeout}. 
" + f"All thread timeouts: {[sock.thread_timeouts for sock in self.mock_sockets]}", + ) + + def _validate_disconnected(self, expected_count): + """Helper method to validate all socket timeouts""" + disconnected_sockets_count = 0 + for sock in self.mock_sockets: + if sock.closed: + disconnected_sockets_count += 1 + assert disconnected_sockets_count == expected_count + + def _validate_connected(self, expected_count): + """Helper method to validate all socket timeouts""" + connected_sockets_count = 0 + for sock in self.mock_sockets: + if sock.connected: + connected_sockets_count += 1 + assert connected_sockets_count == expected_count + + def _validate_in_use_connections_state( + self, + in_use_connections: List[AbstractConnection], + expected_state=MaintenanceState.NONE, + expected_tmp_host_address=None, + expected_tmp_relax_timeout=-1, + expected_current_socket_timeout=None, + expected_current_peername=MockSocket.DEFAULT_ADDRESS.split(":")[0], + ): + """Helper method to validate state of in-use connections.""" + # validate in use connections are still working with set flag for reconnect + # and timeout is updated + for connection in in_use_connections: + assert connection._should_reconnect is True + assert connection.tmp_host_address == expected_tmp_host_address + assert connection.tmp_relax_timeout == expected_tmp_relax_timeout + assert connection._sock.gettimeout() == expected_current_socket_timeout + assert connection._sock.connected is True + assert connection.maintenance_state == expected_state + assert connection._sock.getpeername()[0] == expected_current_peername + + def _validate_free_connections_state( + self, + pool, + tmp_host_address, + relax_timeout, + should_be_connected_count, + connected_to_tmp_addres=False, + expected_state=MaintenanceState.MOVING, + ): + """Helper method to validate state of free/available connections.""" + if isinstance(pool, BlockingConnectionPool): + free_connections = [conn for conn in pool.pool.queue if conn is not None] + elif 
isinstance(pool, ConnectionPool): + free_connections = pool._available_connections + else: + raise ValueError(f"Unsupported pool type: {type(pool)}") + + connected_count = 0 + for connection in free_connections: + assert connection._should_reconnect is False + assert connection.tmp_host_address == tmp_host_address + assert connection.tmp_relax_timeout == relax_timeout + assert connection.maintenance_state == expected_state + if connection._sock is not None: + assert connection._sock.connected is True + if connected_to_tmp_addres: + assert ( + connection._sock.getpeername()[0] + == MockSocket.AFTER_MOVING_ADDRESS.split(":")[0] + ) + connected_count += 1 + assert connected_count == should_be_connected_count + + def _validate_all_timeouts(self, expected_timeout): + """Helper method to validate the timeout of every mock socket.""" + # every mock socket created during the test must report the expected timeout; + # an expected_timeout of None means no timeout is expected to be set + for mock_socket in self.mock_sockets: + if expected_timeout is None: + assert mock_socket.gettimeout() is None + else: + assert mock_socket.gettimeout() == expected_timeout + + def _validate_conn_kwargs( + self, + pool, + expected_host_address, + expected_port, + expected_tmp_host_address, + expected_tmp_relax_timeout, + ): + """Helper method to validate connection kwargs.""" + assert pool.connection_kwargs["host"] == expected_host_address + assert pool.connection_kwargs["port"] == expected_port + assert pool.connection_kwargs["tmp_host_address"] == expected_tmp_host_address + assert pool.connection_kwargs["tmp_relax_timeout"] == expected_tmp_relax_timeout + + def test_client_initialization(self): + """Test that Redis client is created with maintenance events configuration.""" + # Create a pool and Redis client with maintenance events + + test_redis_client = Redis( + protocol=3, # Required for maintenance events + maintenance_events_config=self.config, + ) + + pool_handler = test_redis_client.connection_pool.connection_kwargs.get( +
"maintenance_events_pool_handler" + ) + assert pool_handler is not None + assert pool_handler.config == self.config + + conn = test_redis_client.connection_pool.get_connection() + assert conn._should_reconnect is False + assert conn.tmp_host_address is None + assert conn.tmp_relax_timeout == -1 + + # Test that the node moving handler function is correctly set by + # comparing the underlying function and instance + parser_handler = conn._parser.node_moving_push_handler_func + assert parser_handler is not None + assert hasattr(parser_handler, "__self__") + assert hasattr(parser_handler, "__func__") + assert parser_handler.__self__ is pool_handler + assert parser_handler.__func__ is pool_handler.handle_event.__func__ + + # Test that the maintenance handler function is correctly set + maintenance_handler = conn._parser.maintenance_push_handler_func + assert maintenance_handler is not None + assert hasattr(maintenance_handler, "__self__") + assert hasattr(maintenance_handler, "__func__") + # The maintenance handler should be bound to the connection's + # maintenance event connection handler + assert ( + maintenance_handler.__self__ is conn._maintenance_event_connection_handler + ) + assert ( + maintenance_handler.__func__ + is conn._maintenance_event_connection_handler.handle_event.__func__ + ) + + # Validate that the connection's maintenance handler has the same config object + assert conn._maintenance_event_connection_handler.config is self.config + + def test_maint_handler_init_for_existing_connections(self): + """Test that maintenance event handlers are properly set on existing and new connections + when configuration is enabled after client creation.""" + + # Create a Redis client with disabled maintenance events configuration + disabled_config = MaintenanceEventsConfig(enabled=False) + test_redis_client = Redis( + protocol=3, # Required for maintenance events + maintenance_events_config=disabled_config, + ) + + # Extract an existing connection before enabling 
maintenance events + existing_conn = test_redis_client.connection_pool.get_connection() + + # Verify that maintenance events are initially disabled + assert existing_conn._parser.node_moving_push_handler_func is None + assert not hasattr(existing_conn, "_maintenance_event_connection_handler") + assert existing_conn._parser.maintenance_push_handler_func is None + + # Create a new enabled configuration and set up pool handler + enabled_config = MaintenanceEventsConfig( + enabled=True, proactive_reconnect=True, relax_timeout=30 + ) + pool_handler = MaintenanceEventPoolHandler( + test_redis_client.connection_pool, enabled_config + ) + test_redis_client.connection_pool.set_maintenance_events_pool_handler( + pool_handler + ) + + # Validate the existing connection after enabling maintenance events + # Both existing and new connections should now have full handler setup + self._validate_connection_handlers(existing_conn, pool_handler, enabled_config) + + # Create a new connection and validate it has full handlers + new_conn = test_redis_client.connection_pool.get_connection() + self._validate_connection_handlers(new_conn, pool_handler, enabled_config) + + # Clean up connections + test_redis_client.connection_pool.release(existing_conn) + test_redis_client.connection_pool.release(new_conn) + + @pytest.mark.parametrize("pool_class", [ConnectionPool, BlockingConnectionPool]) + def test_connection_pool_creation_with_maintenance_events(self, pool_class): + """Test that connection pools are created with maintenance events configuration.""" + # Create a pool and Redis client with maintenance events + max_connections = 3 if pool_class == BlockingConnectionPool else 10 + test_redis_client = self._get_client( + pool_class, max_connections=max_connections + ) + test_pool = test_redis_client.connection_pool + + try: + assert ( + test_pool.connection_kwargs.get("maintenance_events_config") + == self.config + ) + # Pool should have maintenance events enabled + assert 
test_pool.maintenance_events_pool_handler_enabled() is True + + # Create and set a pool handler + pool_handler = MaintenanceEventPoolHandler(test_pool, self.config) + test_pool.set_maintenance_events_pool_handler(pool_handler) + + # Validate that the handler is properly set on the pool + assert ( + test_pool.connection_kwargs.get("maintenance_events_pool_handler") + == pool_handler + ) + assert ( + test_pool.connection_kwargs.get("maintenance_events_config") + == pool_handler.config + ) + + # Verify that the pool handler has the correct configuration + assert pool_handler.pool == test_pool + assert pool_handler.config == self.config + + finally: + if hasattr(test_pool, "disconnect"): + test_pool.disconnect() + + @pytest.mark.parametrize("pool_class", [ConnectionPool, BlockingConnectionPool]) + def test_redis_operations_with_mock_sockets(self, pool_class): + """ + Test basic Redis operations work with mocked sockets and proper response parsing. + In effect, this test validates the mocked socket itself.
+ """ + # Create a pool and Redis client with maintenance events + test_redis_client = self._get_client(pool_class, max_connections=5) + + try: + # Perform Redis operations that should work with our improved mock responses + result_set = test_redis_client.set("hello", "world") + result_get = test_redis_client.get("hello") + + # Verify operations completed successfully + assert result_set is True + assert result_get == b"world" + + # Verify socket interactions + assert len(self.mock_sockets) >= 1 + assert self.mock_sockets[0].connected + assert len(self.mock_sockets[0].sent_data) >= 2 # HELLO, SET, GET commands + + # Verify that the connection has maintenance event handler + connection = test_redis_client.connection_pool.get_connection() + assert hasattr(connection, "_maintenance_event_connection_handler") + test_redis_client.connection_pool.release(connection) + + finally: + if hasattr(test_redis_client.connection_pool, "disconnect"): + test_redis_client.connection_pool.disconnect() + + def test_pool_handler_with_migrating_event(self): + """Test that pool handler correctly handles migrating events.""" + # Create a pool and Redis client with maintenance events + test_redis_client = self._get_client(ConnectionPool) + test_pool = test_redis_client.connection_pool + + try: + # Create and set a pool handler + pool_handler = MaintenanceEventPoolHandler(test_pool, self.config) + + # Create a migrating event (not handled by pool handler) + migrating_event = NodeMigratingEvent(id=1, ttl=5) + + # Mock the required functions + with ( + patch.object( + pool_handler, "remove_expired_notifications" + ) as mock_remove_expired, + patch.object( + pool_handler, "handle_node_moving_event" + ) as mock_handle_moving, + patch("redis.maintenance_events.logging.error") as mock_logging_error, + ): + # Pool handler should return None for migrating events (not its responsibility) + pool_handler.handle_event(migrating_event) + + # Validate that remove_expired_notifications has been called 
once + mock_remove_expired.assert_called_once() + + # Validate that handle_node_moving_event hasn't been called + mock_handle_moving.assert_not_called() + + # Validate that logging.error has been called once + mock_logging_error.assert_called_once() + + finally: + if hasattr(test_pool, "disconnect"): + test_pool.disconnect() + + @pytest.mark.parametrize("pool_class", [ConnectionPool, BlockingConnectionPool]) + def test_migration_related_events_handling_integration(self, pool_class): + """ + Test full integration of migration-related events (MIGRATING/MIGRATED) handling. + + This test validates the complete migration lifecycle: + 1. Executes 5 Redis commands sequentially + 2. Injects MIGRATING push message before command 2 (SET key_receive_migrating) + 3. Validates socket timeout is updated to relaxed value (30s) after MIGRATING + 4. Executes commands 3-4 while timeout remains relaxed + 5. Injects MIGRATED push message before command 5 (SET key_receive_migrated) + 6. Validates socket timeout is restored after MIGRATED + 7. Tests both ConnectionPool and BlockingConnectionPool implementations + 8. 
Uses proper RESP3 push message format for realistic protocol simulation + """ + # Create a pool and Redis client with maintenance events + test_redis_client = self._get_client(pool_class, max_connections=10) + + try: + # Command 1: Initial command + key1 = "key1" + value1 = "value1" + result1 = test_redis_client.set(key1, value1) + + # Validate Command 1 result + assert result1 is True, "Command 1 (SET key1) failed" + + # Command 2: This SET command will receive MIGRATING push message before response + key_migrating = "key_receive_migrating" + value_migrating = "value2" + result2 = test_redis_client.set(key_migrating, value_migrating) + + # Validate Command 2 result + assert result2 is True, "Command 2 (SET key_receive_migrating) failed" + + # Step 3: Validate timeout was updated to relaxed value after MIGRATING + self._validate_current_timeout(30, "Right after MIGRATING is received. ") + + # Command 3: Another command while timeout is still relaxed + result3 = test_redis_client.get(key1) + + # Validate Command 3 result + expected_value3 = value1.encode() + assert result3 == expected_value3, ( + f"Command 3 (GET key1) failed. Expected {expected_value3}, got {result3}" + ) + + # Command 4: Execute command (step 4) + result4 = test_redis_client.get(key_migrating) + + # Validate Command 4 result + expected_value4 = value_migrating.encode() + assert result4 == expected_value4, ( + f"Command 4 (GET key_receive_migrating) failed.
Expected {expected_value4}, got {result4}" + ) + + # Step 4: Validate socket timeout is still relaxed during commands 3-4 + self._validate_current_timeout( + 30, + "Execute a command with a connection extracted from the pool (after it has received MIGRATING)", + ) + + # Command 5: This SET command will receive + # MIGRATED push message before actual response + key_migrated = "key_receive_migrated" + value_migrated = "value3" + result5 = test_redis_client.set(key_migrated, value_migrated) + + # Validate Command 5 result + assert result5 is True, "Command 5 (SET key_receive_migrated) failed" + + # Step 6: Validate socket timeout is reversed back to original after MIGRATED + self._validate_current_timeout(None) + + # Verify maintenance events were processed correctly + # The key is that we have at least 1 socket and all operations succeeded + assert len(self.mock_sockets) >= 1, ( + f"Expected at least 1 socket for operations, got {len(self.mock_sockets)}" + ) + + finally: + if hasattr(test_redis_client.connection_pool, "disconnect"): + test_redis_client.connection_pool.disconnect() + + @pytest.mark.parametrize("pool_class", [ConnectionPool, BlockingConnectionPool]) + def test_migrating_event_with_disabled_relax_timeout(self, pool_class): + """ + Test migrating event handling when relax timeout is disabled. + + This test validates that when relax_timeout is disabled (-1): + 1. MIGRATING events are received and processed + 2. No timeout updates are applied to connections + 3. Socket timeouts remain unchanged during migration events + 4.
Tests both ConnectionPool and BlockingConnectionPool implementations + """ + # Create config with disabled relax timeout + disabled_config = MaintenanceEventsConfig( + enabled=True, + relax_timeout=-1, # This means the relax timeout is Disabled + ) + + # Create a pool and Redis client with disabled relax timeout config + test_redis_client = self._get_client( + pool_class, max_connections=5, maintenance_events_config=disabled_config + ) + + try: + # Command 1: Initial command + key1 = "key1" + value1 = "value1" + result1 = test_redis_client.set(key1, value1) + + # Validate Command 1 result + assert result1 is True, "Command 1 (SET key1) failed" + + # Command 2: This SET command will receive MIGRATING push message before response + key_migrating = "key_receive_migrating" + value_migrating = "value2" + result2 = test_redis_client.set(key_migrating, value_migrating) + + # Validate Command 2 result + assert result2 is True, "Command 2 (SET key_receive_migrating) failed" + + # Validate timeout was NOT updated (relax is disabled) + # Should remain at default timeout (None), not relaxed to 30s + self._validate_current_timeout(None) + + # Command 3: Another command to verify timeout remains unchanged + result3 = test_redis_client.get(key1) + + # Validate Command 3 result + expected_value3 = value1.encode() + assert result3 == expected_value3, ( + f"Command 3 (GET key1) failed. 
Expected: {expected_value3}, Got: {result3}" + ) + + # Verify maintenance events were processed correctly + # The key is that we have at least 1 socket and all operations succeeded + assert len(self.mock_sockets) >= 1, ( + f"Expected at least 1 socket for operations, got {len(self.mock_sockets)}" + ) + + finally: + if hasattr(test_redis_client.connection_pool, "disconnect"): + test_redis_client.connection_pool.disconnect() + + @pytest.mark.parametrize("pool_class", [ConnectionPool, BlockingConnectionPool]) + def test_moving_related_events_handling_integration(self, pool_class): + """ + Test full integration of moving-related events (MOVING) handling with Redis commands. + + This test validates the complete MOVING event lifecycle: + 1. Creates multiple connections in the pool + 2. Executes a Redis command that triggers a MOVING push message + 3. Validates that pool configuration is updated with temporary address and timeout + 4. Validates that existing connections are marked for disconnection + 5. 
Tests both ConnectionPool and BlockingConnectionPool implementations + """ + # Create a pool and Redis client with maintenance events and pool handler + test_redis_client = self._get_client( + pool_class, max_connections=10, setup_pool_handler=True + ) + + try: + # Create several connections and return them in the pool + connections = [] + for _ in range(10): + connection = test_redis_client.connection_pool.get_connection() + connections.append(connection) + + for connection in connections: + test_redis_client.connection_pool.release(connection) + + # Take 5 connections to be "in use" + in_use_connections = [] + for _ in range(5): + connection = test_redis_client.connection_pool.get_connection() + in_use_connections.append(connection) + + # Validate all connections are connected prior MOVING event + self._validate_disconnected(0) + + # Run command that will receive and handle MOVING event + key_moving = "key_receive_moving_0" + value_moving = "value3_0" + # the connection used for the command is expected to be reconnected to the new address + # before it is returned to the pool + result2 = test_redis_client.set(key_moving, value_moving) + + # Validate Command 2 result + assert result2 is True, "Command 2 (SET key_receive_moving) failed" + + # Validate pool and connections settings were updated according to MOVING event + self._validate_conn_kwargs( + test_redis_client.connection_pool, + MockSocket.DEFAULT_ADDRESS.split(":")[0], + int(MockSocket.DEFAULT_ADDRESS.split(":")[1]), + MockSocket.AFTER_MOVING_ADDRESS.split(":")[0], + self.config.relax_timeout, + ) + self._validate_disconnected(5) + self._validate_connected(6) + self._validate_in_use_connections_state( + in_use_connections, + expected_state=MaintenanceState.MOVING, + expected_tmp_host_address=MockSocket.AFTER_MOVING_ADDRESS.split(":")[0], + expected_tmp_relax_timeout=self.config.relax_timeout, + expected_current_socket_timeout=self.config.relax_timeout, + 
expected_current_peername=MockSocket.DEFAULT_ADDRESS.split(":")[ + 0 + ], # the in use connections reconnect when they complete their current task + ) + self._validate_free_connections_state( + test_redis_client.connection_pool, + MockSocket.AFTER_MOVING_ADDRESS.split(":")[0], + self.config.relax_timeout, + should_be_connected_count=1, + connected_to_tmp_addres=True, + expected_state=MaintenanceState.MOVING, + ) + # Wait for MOVING timeout to expire and the moving completed handler to run + sleep(MockSocket.MOVING_TIMEOUT + 0.5) + self._validate_all_timeouts(None) + self._validate_conn_kwargs( + test_redis_client.connection_pool, + MockSocket.DEFAULT_ADDRESS.split(":")[0], + int(MockSocket.DEFAULT_ADDRESS.split(":")[1]), + None, + -1, + ) + self._validate_free_connections_state( + test_redis_client.connection_pool, + None, + -1, + should_be_connected_count=1, + connected_to_tmp_addres=True, + expected_state=MaintenanceState.NONE, + ) + finally: + if hasattr(test_redis_client.connection_pool, "disconnect"): + test_redis_client.connection_pool.disconnect() + + @pytest.mark.parametrize("pool_class", [ConnectionPool, BlockingConnectionPool]) + def test_create_new_conn_while_moving_not_expired(self, pool_class): + """ + Test creating new connections while MOVING event is active (not expired). + + This test validates that: + 1. After MOVING event is processed, new connections are created with temporary address + 2. New connections inherit the relaxed timeout settings + 3. 
Pool configuration is properly applied to newly created connections + """ + # Create a pool and Redis client with maintenance events and pool handler + test_redis_client = self._get_client( + pool_class, max_connections=10, setup_pool_handler=True + ) + + try: + # Create several connections and return them in the pool + connections = [] + for _ in range(5): + connection = test_redis_client.connection_pool.get_connection() + connections.append(connection) + + for connection in connections: + test_redis_client.connection_pool.release(connection) + + # Take 3 connections to be "in use" + in_use_connections = [] + for _ in range(3): + connection = test_redis_client.connection_pool.get_connection() + in_use_connections.append(connection) + + # Validate all connections are connected prior MOVING event + self._validate_disconnected(0) + + # Run command that will receive and handle MOVING event + key_moving = "key_receive_moving_0" + value_moving = "value3_0" + result = test_redis_client.set(key_moving, value_moving) + + # Validate command result + assert result is True, "SET key_receive_moving command failed" + + # Validate pool and connections settings were updated according to MOVING event + self._validate_conn_kwargs( + test_redis_client.connection_pool, + MockSocket.DEFAULT_ADDRESS.split(":")[0], + int(MockSocket.DEFAULT_ADDRESS.split(":")[1]), + MockSocket.AFTER_MOVING_ADDRESS.split(":")[0], + self.config.relax_timeout, + ) + + # Now get several more connections to force creation of new ones + # This should create new connections with the temporary address + old_connections = [] + for _ in range(2): + connection = test_redis_client.connection_pool.get_connection() + old_connections.append(connection) + + new_connection = test_redis_client.connection_pool.get_connection() + + # Validate that new connections are created with temporary address and relax timeout + # and when connecting those configs are used + # get_connection() returns a connection that is already 
connected + assert ( + new_connection.tmp_host_address + == MockSocket.AFTER_MOVING_ADDRESS.split(":")[0] + ) + assert new_connection.tmp_relax_timeout == self.config.relax_timeout + # New connections should be connected to the temporary address + assert new_connection._sock is not None + assert new_connection._sock.connected is True + assert ( + new_connection._sock.getpeername()[0] + == MockSocket.AFTER_MOVING_ADDRESS.split(":")[0] + ) + assert new_connection._sock.gettimeout() == self.config.relax_timeout + + finally: + if hasattr(test_redis_client.connection_pool, "disconnect"): + test_redis_client.connection_pool.disconnect() + + @pytest.mark.parametrize("pool_class", [ConnectionPool, BlockingConnectionPool]) + def test_create_new_conn_after_moving_expires(self, pool_class): + """ + Test creating new connections after MOVING event expires. + + This test validates that: + 1. After MOVING timeout expires, new connections use original address + 2. Pool configuration is reset to original values + 3. 
New connections don't inherit temporary settings + """ + # Create a pool and Redis client with maintenance events and pool handler + test_redis_client = self._get_client( + pool_class, max_connections=10, setup_pool_handler=True + ) + + try: + # Create several connections and return them in the pool + connections = [] + for _ in range(5): + connection = test_redis_client.connection_pool.get_connection() + connections.append(connection) + + for connection in connections: + test_redis_client.connection_pool.release(connection) + + # Take 3 connections to be "in use" + in_use_connections = [] + for _ in range(3): + connection = test_redis_client.connection_pool.get_connection() + in_use_connections.append(connection) + + # Run command that will receive and handle MOVING event + key_moving = "key_receive_moving_0" + value_moving = "value3_0" + result = test_redis_client.set(key_moving, value_moving) + + # Validate command result + assert result is True, "SET key_receive_moving command failed" + + # Wait for MOVING timeout to expire + sleep(MockSocket.MOVING_TIMEOUT + 0.5) + + # Now get several new connections after expiration + old_connections = [] + for _ in range(2): + connection = test_redis_client.connection_pool.get_connection() + old_connections.append(connection) + + new_connection = test_redis_client.connection_pool.get_connection() + + # Validate that new connections are created with original address (no temporary settings) + assert new_connection.tmp_host_address is None + assert new_connection.tmp_relax_timeout == -1 + # New connections should be connected to the original address + assert new_connection._sock is not None + assert new_connection._sock.connected is True + # Socket timeout should be None (original timeout) + assert new_connection._sock.gettimeout() is None + + finally: + if hasattr(test_redis_client.connection_pool, "disconnect"): + test_redis_client.connection_pool.disconnect() + + @pytest.mark.parametrize("pool_class", [ConnectionPool, 
BlockingConnectionPool]) + def test_receive_migrated_after_moving(self, pool_class): + """ + Test receiving MIGRATED event after MOVING event. + + This test validates the complete MOVING -> MIGRATED lifecycle: + 1. MOVING event is processed and temporary settings are applied + 2. MIGRATED event is received during command execution + 3. MOVING temporary settings stay in effect after MIGRATED + 4. New connections created after MIGRATED still use the temporary address + + Note: When MIGRATED comes after MOVING and MOVING hasn't yet expired, + it should not decrease timeouts (future refactoring consideration). + """ + # Create a pool and Redis client with maintenance events and pool handler + test_redis_client = self._get_client( + pool_class, max_connections=10, setup_pool_handler=True + ) + + try: + # Create several connections and return them in the pool + connections = [] + for _ in range(5): + connection = test_redis_client.connection_pool.get_connection() + connections.append(connection) + + for connection in connections: + test_redis_client.connection_pool.release(connection) + + # Take 3 connections to be "in use" + in_use_connections = [] + for _ in range(3): + connection = test_redis_client.connection_pool.get_connection() + in_use_connections.append(connection) + + # Validate all connections are connected prior to the MOVING event + self._validate_disconnected(0) + + # Step 1: Run command that will receive and handle MOVING event + key_moving = "key_receive_moving_0" + value_moving = "value3_0" + result_moving = test_redis_client.set(key_moving, value_moving) + + # Validate MOVING command result + assert result_moving is True, "SET key_receive_moving command failed" + + # Validate pool and connections settings were updated according to MOVING event + self._validate_conn_kwargs( + test_redis_client.connection_pool, + MockSocket.DEFAULT_ADDRESS.split(":")[0], + int(MockSocket.DEFAULT_ADDRESS.split(":")[1]), + MockSocket.AFTER_MOVING_ADDRESS.split(":")[0], + self.config.relax_timeout, + ) + + #
Step 2: Run command that will receive and handle MIGRATED event + # This must not clear the MOVING temporary settings (validated in Step 3) + key_migrated = "key_receive_migrated_0" + value_migrated = "migrated_value" + result_migrated = test_redis_client.set(key_migrated, value_migrated) + + # Validate MIGRATED command result + assert result_migrated is True, "SET key_receive_migrated command failed" + + # Step 3: Validate that MIGRATED event was processed but MOVING settings remain + # (MIGRATED doesn't automatically clear MOVING settings - they are separate events) + self._validate_conn_kwargs( + test_redis_client.connection_pool, + MockSocket.DEFAULT_ADDRESS.split(":")[0], + int(MockSocket.DEFAULT_ADDRESS.split(":")[1]), + MockSocket.AFTER_MOVING_ADDRESS.split(":")[ + 0 + ], # MOVING settings still active + self.config.relax_timeout, # MOVING timeout still active + ) + + # Step 4: Create new connections after MIGRATED to verify they still use MOVING settings + # (since MOVING settings are still active) + new_connections = [] + for _ in range(2): + connection = test_redis_client.connection_pool.get_connection() + new_connections.append(connection) + + # Validate that new connections are created with MOVING settings (still active) + for connection in new_connections: + assert ( + connection.tmp_host_address + == MockSocket.AFTER_MOVING_ADDRESS.split(":")[0] + ) + # Note: New connections may not inherit the exact relax timeout value + # but they should have the temporary host address + # New connections should be connected + if connection._sock is not None: + assert connection._sock.connected is True + + # Release the new connections + for connection in new_connections: + test_redis_client.connection_pool.release(connection) + + # TODO: validate free connections state with MOVING settings still active (not implemented yet) + # Note: We'll validate with the pool's current settings rather than individual connection settings + # since new connections may have different timeout values but still use the temporary address + +
finally: + if hasattr(test_redis_client.connection_pool, "disconnect"): + test_redis_client.connection_pool.disconnect() + + @pytest.mark.parametrize("pool_class", [ConnectionPool, BlockingConnectionPool]) + def test_overlapping_moving_events(self, pool_class): + """ + Test handling of overlapping/duplicate MOVING events (e.g., two MOVING events before the first expires). + Ensures that the second MOVING event updates the pool and connections as expected, and that expiry/cleanup works. + """ + test_redis_client = self._get_client( + pool_class, max_connections=5, setup_pool_handler=True + ) + try: + # Create and release some connections + for _ in range(3): + conn = test_redis_client.connection_pool.get_connection() + test_redis_client.connection_pool.release(conn) + + # Take 2 connections to be in use + in_use_connections = [] + for _ in range(2): + conn = test_redis_client.connection_pool.get_connection() + in_use_connections.append(conn) + + # Trigger first MOVING event + key_moving1 = "key_receive_moving_0" + value_moving1 = "value3_0" + result1 = test_redis_client.set(key_moving1, value_moving1) + assert result1 is True + self._validate_conn_kwargs( + test_redis_client.connection_pool, + MockSocket.DEFAULT_ADDRESS.split(":")[0], + int(MockSocket.DEFAULT_ADDRESS.split(":")[1]), + MockSocket.AFTER_MOVING_ADDRESS.split(":")[0], + self.config.relax_timeout, + ) + # Validate all connections reflect the first MOVING event + self._validate_in_use_connections_state( + in_use_connections, + expected_state=MaintenanceState.MOVING, + expected_tmp_host_address=MockSocket.AFTER_MOVING_ADDRESS.split(":")[0], + expected_tmp_relax_timeout=self.config.relax_timeout, + expected_current_socket_timeout=self.config.relax_timeout, + expected_current_peername=MockSocket.DEFAULT_ADDRESS.split(":")[0], + ) + self._validate_free_connections_state( + test_redis_client.connection_pool, + MockSocket.AFTER_MOVING_ADDRESS.split(":")[0], + self.config.relax_timeout, + 
should_be_connected_count=1, + connected_to_tmp_addres=True, + ) + + # Before the first MOVING expires, trigger a second MOVING event (simulate new address) + # Patch MockSocket to use a new address for the second event + new_address = "5.6.7.8:6380" + orig_after_moving = MockSocket.AFTER_MOVING_ADDRESS + MockSocket.AFTER_MOVING_ADDRESS = new_address + try: + key_moving2 = "key_receive_moving_1" + value_moving2 = "value3_1" + result2 = test_redis_client.set(key_moving2, value_moving2) + assert result2 is True + self._validate_conn_kwargs( + test_redis_client.connection_pool, + MockSocket.DEFAULT_ADDRESS.split(":")[0], + int(MockSocket.DEFAULT_ADDRESS.split(":")[1]), + new_address.split(":")[0], + self.config.relax_timeout, + ) + # Validate all connections reflect the second MOVING event + self._validate_in_use_connections_state( + in_use_connections, + expected_state=MaintenanceState.MOVING, + expected_tmp_host_address=new_address.split(":")[0], + expected_tmp_relax_timeout=self.config.relax_timeout, + expected_current_socket_timeout=self.config.relax_timeout, + expected_current_peername=MockSocket.DEFAULT_ADDRESS.split(":")[0], + ) + self._validate_free_connections_state( + test_redis_client.connection_pool, + new_address.split(":")[0], + self.config.relax_timeout, + should_be_connected_count=1, + connected_to_tmp_addres=True, + ) + finally: + MockSocket.AFTER_MOVING_ADDRESS = orig_after_moving + + # Wait for both MOVING timeouts to expire + sleep(MockSocket.MOVING_TIMEOUT + 0.5) + self._validate_conn_kwargs( + test_redis_client.connection_pool, + MockSocket.DEFAULT_ADDRESS.split(":")[0], + int(MockSocket.DEFAULT_ADDRESS.split(":")[1]), + None, + -1, + ) + finally: + if hasattr(test_redis_client.connection_pool, "disconnect"): + test_redis_client.connection_pool.disconnect() + + @pytest.mark.parametrize("pool_class", [ConnectionPool, BlockingConnectionPool]) + def test_thread_safety_concurrent_event_handling(self, pool_class): + """ + Test thread-safety under 
concurrent maintenance event handling. + Simulates multiple threads triggering MOVING events and performing operations concurrently. + """ + import threading + + test_redis_client = self._get_client( + pool_class, max_connections=5, setup_pool_handler=True + ) + results = [] + errors = [] + + def worker(idx): + try: + key = f"key_receive_moving_{idx}" + value = f"value3_{idx}" + result = test_redis_client.set(key, value) + results.append(result) + except Exception as e: + errors.append(e) + + threads = [threading.Thread(target=worker, args=(i,)) for i in range(5)] + for t in threads: + t.start() + for t in threads: + t.join() + assert all(results), f"Not all threads succeeded: {results}" + assert not errors, f"Errors occurred in threads: {errors}" + # After all threads, MOVING event should have been handled safely + self._validate_conn_kwargs( + test_redis_client.connection_pool, + MockSocket.DEFAULT_ADDRESS.split(":")[0], + int(MockSocket.DEFAULT_ADDRESS.split(":")[1]), + MockSocket.AFTER_MOVING_ADDRESS.split(":")[0], + self.config.relax_timeout, + ) + if hasattr(test_redis_client.connection_pool, "disconnect"): + test_redis_client.connection_pool.disconnect() + + @pytest.mark.parametrize("pool_class", [ConnectionPool, BlockingConnectionPool]) + def test_moving_migrating_migrated_moved_state_transitions(self, pool_class): + """ + Test moving configs are not lost if the per connection events get picked up after moving is handled. + MOVING → MIGRATING → MIGRATED → MOVED + Checks the state after each event for all connections and for new connections created during each state. 
+ """ + # Setup + test_redis_client = self._get_client( + pool_class, max_connections=5, setup_pool_handler=True + ) + pool = test_redis_client.connection_pool + pool_handler = pool.connection_kwargs["maintenance_events_pool_handler"] + + # Create and release some connections + in_use_connections = [] + for _ in range(3): + in_use_connections.append(pool.get_connection()) + while len(in_use_connections) > 0: + pool.release(in_use_connections.pop()) + + # Take 2 connections to be in use + in_use_connections = [] + for _ in range(2): + conn = pool.get_connection() + in_use_connections.append(conn) + + # 1. MOVING event + tmp_address = "22.23.24.25" + moving_event = NodeMovingEvent( + id=1, new_node_host=tmp_address, new_node_port=6379, ttl=1 + ) + pool_handler.handle_event(moving_event) + self._validate_in_use_connections_state( + in_use_connections, + expected_state=MaintenanceState.MOVING, + expected_tmp_host_address=tmp_address, + expected_tmp_relax_timeout=self.config.relax_timeout, + expected_current_socket_timeout=self.config.relax_timeout, + expected_current_peername=MockSocket.DEFAULT_ADDRESS.split(":")[0], + ) + self._validate_free_connections_state( + pool, + tmp_address, + self.config.relax_timeout, + should_be_connected_count=0, + connected_to_tmp_addres=False, + expected_state=MaintenanceState.MOVING, + ) + + # 2. MIGRATING event (simulate direct connection handler call) + for conn in in_use_connections: + conn._maintenance_event_connection_handler.handle_event( + NodeMigratingEvent(id=2, ttl=1) + ) + self._validate_in_use_connections_state( + in_use_connections, + expected_state=MaintenanceState.MOVING, + expected_tmp_host_address=tmp_address, + expected_tmp_relax_timeout=self.config.relax_timeout, + expected_current_socket_timeout=self.config.relax_timeout, + expected_current_peername=MockSocket.DEFAULT_ADDRESS.split(":")[0], + ) + + # 3. 
MIGRATED event (simulate direct connection handler call) + for conn in in_use_connections: + conn._maintenance_event_connection_handler.handle_event( + NodeMigratedEvent(id=2) + ) + # State should not change for connections that are in MOVING state + self._validate_in_use_connections_state( + in_use_connections, + expected_state=MaintenanceState.MOVING, + expected_tmp_host_address=tmp_address, + expected_tmp_relax_timeout=self.config.relax_timeout, + expected_current_socket_timeout=self.config.relax_timeout, + expected_current_peername=MockSocket.DEFAULT_ADDRESS.split(":")[0], + ) + + # 4. MOVED event (simulate timer expiry) + pool_handler.handle_node_moved_event() + self._validate_in_use_connections_state( + in_use_connections, + expected_state=MaintenanceState.NONE, + expected_tmp_host_address=None, + expected_tmp_relax_timeout=-1, + expected_current_socket_timeout=None, + expected_current_peername=MockSocket.DEFAULT_ADDRESS.split(":")[0], + ) + self._validate_free_connections_state( + pool, + None, + -1, + should_be_connected_count=0, + connected_to_tmp_addres=False, + expected_state=MaintenanceState.NONE, + ) + # New connection after MOVED + new_conn_none = pool.get_connection() + assert new_conn_none.maintenance_state == MaintenanceState.NONE + pool.release(new_conn_none) + # Cleanup + for conn in in_use_connections: + pool.release(conn) + if hasattr(pool, "disconnect"): + pool.disconnect()