diff --git a/CHANGES b/CHANGES index 20d7d8f51f..3b6b638dc1 100644 --- a/CHANGES +++ b/CHANGES @@ -1,3 +1,4 @@ + * Fix get_node_from_slot for cluster and asyncio cluster during reshard (#2988) * Move doctests (doc code examples) to main branch * Update `ResponseT` type hint * Allow to control the minimum SSL version diff --git a/redis/asyncio/cluster.py b/redis/asyncio/cluster.py index 40b2948a7f..1dfa320c48 100644 --- a/redis/asyncio/cluster.py +++ b/redis/asyncio/cluster.py @@ -43,7 +43,6 @@ REPLICA, SLOT_ID, AbstractRedisCluster, - LoadBalancer, block_pipeline_command, get_node_name, parse_cluster_slots, @@ -68,6 +67,7 @@ TimeoutError, TryAgainError, ) +from redis.load_balancer import LoadBalancer from redis.typing import AnyKeyT, EncodableT, KeyT from redis.utils import ( deprecated_function, @@ -1244,23 +1244,19 @@ def _update_moved_slots(self) -> None: def get_node_from_slot( self, slot: int, read_from_replicas: bool = False ) -> "ClusterNode": + """ + Gets a node that serves this hash slot + """ if self._moved_exception: self._update_moved_slots() - try: - if read_from_replicas: - # get the server index in a Round-Robin manner - primary_name = self.slots_cache[slot][0].name - node_idx = self.read_load_balancer.get_server_index( - primary_name, len(self.slots_cache[slot]) - ) - return self.slots_cache[slot][node_idx] - return self.slots_cache[slot][0] - except (IndexError, TypeError): - raise SlotNotCoveredError( - f'Slot "{slot}" not covered by the cluster. ' - f'"require_full_coverage={self.require_full_coverage}"' - ) + slot_nodes = self.slots_cache.get(slot, None) + if slot_nodes is None or len(slot_nodes) == 0: + raise SlotNotCoveredError(f'Slot "{slot}" not covered by the cluster.') + return self.read_load_balancer.get_node_from_slot( + slot_nodes, + read_from_replicas, + ) def get_nodes_by_server_type(self, server_type: str) -> List["ClusterNode"]: return [ diff --git a/redis/cluster.py b/redis/cluster.py index be7685e9a1..17df6d205d 100644 --- a/redis/cluster.py +++ b/redis/cluster.py @@ -31,6 +31,7 @@ TimeoutError, TryAgainError, ) +from redis.load_balancer import LoadBalancer from redis.lock import Lock from redis.retry import Retry from redis.utils import ( @@ -1319,25 +1320,6 @@ def invalidate_key_from_cache(self, key): self.redis_connection.invalidate_key_from_cache(key) -class LoadBalancer: - """ - Round-Robin Load Balancing - """ - - def __init__(self, start_index: int = 0) -> None: - self.primary_to_idx = {} - self.start_index = start_index - - def get_server_index(self, primary: str, list_size: int) -> int: - server_index = self.primary_to_idx.setdefault(primary, self.start_index) - # Update the index - self.primary_to_idx[primary] = (server_index + 1) % list_size - return server_index - - def reset(self) -> None: - self.primary_to_idx.clear() - - class NodesManager: def __init__( self, @@ -1426,40 +1408,24 @@ def _update_moved_slots(self): # Reset moved_exception self._moved_exception = None - def get_node_from_slot(self, slot, read_from_replicas=False, server_type=None): + def get_node_from_slot( + self, slot: int, read_from_replicas: bool = False + ) -> "ClusterNode": """ - Gets a node that servers this hash slot + Gets a node that serves this hash slot """ if self._moved_exception: with self._lock: if self._moved_exception: self._update_moved_slots() - if self.slots_cache.get(slot) is None or len(self.slots_cache[slot]) == 0: - raise SlotNotCoveredError( - f'Slot "{slot}" not covered by the cluster. ' - f'"require_full_coverage={self._require_full_coverage}"' - ) - - if read_from_replicas is True: - # get the server index in a Round-Robin manner - primary_name = self.slots_cache[slot][0].name - node_idx = self.read_load_balancer.get_server_index( - primary_name, len(self.slots_cache[slot]) - ) - elif ( - server_type is None - or server_type == PRIMARY - or len(self.slots_cache[slot]) == 1 - ): - # return a primary - node_idx = 0 - else: - # return a replica - # randomly choose one of the replicas - node_idx = random.randint(1, len(self.slots_cache[slot]) - 1) - - return self.slots_cache[slot][node_idx] + slot_nodes = self.slots_cache.get(slot, None) + if slot_nodes is None or len(slot_nodes) == 0: + raise SlotNotCoveredError(f'Slot "{slot}" not covered by the cluster.') + return self.read_load_balancer.get_node_from_slot( + slot_nodes, + read_from_replicas, + ) def get_nodes_by_server_type(self, server_type): """ diff --git a/redis/load_balancer.py b/redis/load_balancer.py new file mode 100644 index 0000000000..b206bdb869 --- /dev/null +++ b/redis/load_balancer.py @@ -0,0 +1,39 @@ +from typing import TYPE_CHECKING, List, Union + +if TYPE_CHECKING: + from redis.asyncio.cluster import ClusterNode as AsyncioClusterNode # noqa: F401 + from redis.cluster import ClusterNode + + +class LoadBalancer: + """ + Round-Robin Load Balancing + """ + + def __init__(self) -> None: + self.primary_name_to_last_used_index: dict[str, int] = {} + + def get_node_from_slot( + self, + slot_nodes: Union[List["ClusterNode"], List["AsyncioClusterNode"]], + read_from_replicas: bool, + ) -> Union["ClusterNode", "AsyncioClusterNode"]: + assert len(slot_nodes) > 0 + if not read_from_replicas: + return slot_nodes[0] + + primary_name = slot_nodes[0].name + node_idx = self.get_server_index(primary_name, len(slot_nodes)) + return slot_nodes[node_idx] + + def get_server_index(self, primary: str, list_size: int) -> int: + # default to -1 if not found, so after incrementing it will be 0 + server_index = ( + self.primary_name_to_last_used_index.get(primary, -1) + 1 + ) % list_size + # Update the index + self.primary_name_to_last_used_index[primary] = server_index + return server_index + + def reset(self) -> None: + self.primary_name_to_last_used_index.clear() diff --git a/tests/test_asyncio/test_cluster.py b/tests/test_asyncio/test_cluster.py index a36040f11b..e463439594 100644 --- a/tests/test_asyncio/test_cluster.py +++ b/tests/test_asyncio/test_cluster.py @@ -2355,9 +2355,8 @@ class TestNodesManager: Tests for the NodesManager class """ - async def test_load_balancer(self, r: RedisCluster) -> None: + async def test_get_node_from_slot(self, r: RedisCluster) -> None: n_manager = r.nodes_manager - lb = n_manager.read_load_balancer slot_1 = 1257 slot_2 = 8975 node_1 = ClusterNode(default_host, 6379, PRIMARY) @@ -2369,23 +2368,58 @@ async def test_load_balancer(self, r: RedisCluster) -> None: slot_1: [node_1, node_2, node_3], slot_2: [node_4, node_5], } - primary1_name = n_manager.slots_cache[slot_1][0].name - primary2_name = n_manager.slots_cache[slot_2][0].name - list1_size = len(n_manager.slots_cache[slot_1]) - list2_size = len(n_manager.slots_cache[slot_2]) # slot 1 - assert lb.get_server_index(primary1_name, list1_size) == 0 - assert lb.get_server_index(primary1_name, list1_size) == 1 - assert lb.get_server_index(primary1_name, list1_size) == 2 - assert lb.get_server_index(primary1_name, list1_size) == 0 + assert n_manager.get_node_from_slot(slot_1, read_from_replicas=True) == node_1 + assert n_manager.get_node_from_slot(slot_1, read_from_replicas=True) == node_2 + assert n_manager.get_node_from_slot(slot_1, read_from_replicas=True) == node_3 + assert n_manager.get_node_from_slot(slot_1, read_from_replicas=True) == node_1 # slot 2 - assert lb.get_server_index(primary2_name, list2_size) == 0 - assert lb.get_server_index(primary2_name, list2_size) == 1 - assert lb.get_server_index(primary2_name, list2_size) == 0 + assert n_manager.get_node_from_slot(slot_2, read_from_replicas=True) == node_4 + assert n_manager.get_node_from_slot(slot_2, read_from_replicas=True) == node_5 + assert n_manager.get_node_from_slot(slot_2, read_from_replicas=True) == node_4 - lb.reset() - assert lb.get_server_index(primary1_name, list1_size) == 0 - assert lb.get_server_index(primary2_name, list2_size) == 0 + n_manager.read_load_balancer.reset() + assert n_manager.get_node_from_slot(slot_1, read_from_replicas=True) == node_1 + assert n_manager.get_node_from_slot(slot_2, read_from_replicas=True) == node_4 + + async def test_get_node_from_slot_no_replicas(self, r: RedisCluster) -> None: + n_manager = r.nodes_manager + slot_1 = 1257 + slot_2 = 8975 + node_1 = ClusterNode(default_host, 6379, PRIMARY) + node_2 = ClusterNode(default_host, 6378, REPLICA) + node_3 = ClusterNode(default_host, 6377, REPLICA) + node_4 = ClusterNode(default_host, 6376, PRIMARY) + node_5 = ClusterNode(default_host, 6375, REPLICA) + n_manager.slots_cache = { + slot_1: [node_1, node_2, node_3], + slot_2: [node_4, node_5], + } + assert n_manager.get_node_from_slot(slot_1, read_from_replicas=False) == node_1 + assert n_manager.get_node_from_slot(slot_1, read_from_replicas=False) == node_1 + assert n_manager.get_node_from_slot(slot_1, read_from_replicas=False) == node_1 + + async def test_get_node_from_slot_changed_slot_replicas( + self, r: RedisCluster + ) -> None: + n_manager = r.nodes_manager + slot_1 = 1257 + node_1 = ClusterNode(default_host, 6379, PRIMARY) + node_2 = ClusterNode(default_host, 6378, REPLICA) + node_3 = ClusterNode(default_host, 6377, REPLICA) + n_manager.slots_cache = { + slot_1: [node_1, node_2, node_3], + } + # slot 1 + assert n_manager.get_node_from_slot(slot_1, read_from_replicas=True) == node_1 + assert n_manager.get_node_from_slot(slot_1, read_from_replicas=True) == node_2 + assert n_manager.get_node_from_slot(slot_1, read_from_replicas=True) == node_3 + assert n_manager.get_node_from_slot(slot_1, read_from_replicas=True) == node_1 + assert n_manager.get_node_from_slot(slot_1, read_from_replicas=True) == node_2 + + # adjust lb-size + n_manager.slots_cache[slot_1].pop() + assert n_manager.get_node_from_slot(slot_1, read_from_replicas=True) == node_1 async def test_init_slots_cache_not_all_slots_covered(self) -> None: """ diff --git a/tests/test_cluster.py b/tests/test_cluster.py index 5a32bd6a7e..1d3555f0a9 100644 --- a/tests/test_cluster.py +++ b/tests/test_cluster.py @@ -2501,9 +2501,8 @@ class TestNodesManager: Tests for the NodesManager class """ - def test_load_balancer(self, r): + async def test_get_node_from_slot(self, r: RedisCluster) -> None: n_manager = r.nodes_manager - lb = n_manager.read_load_balancer slot_1 = 1257 slot_2 = 8975 node_1 = ClusterNode(default_host, 6379, PRIMARY) @@ -2515,23 +2514,59 @@ def test_load_balancer(self, r): slot_1: [node_1, node_2, node_3], slot_2: [node_4, node_5], } - primary1_name = n_manager.slots_cache[slot_1][0].name - primary2_name = n_manager.slots_cache[slot_2][0].name - list1_size = len(n_manager.slots_cache[slot_1]) - list2_size = len(n_manager.slots_cache[slot_2]) # slot 1 - assert lb.get_server_index(primary1_name, list1_size) == 0 - assert lb.get_server_index(primary1_name, list1_size) == 1 - assert lb.get_server_index(primary1_name, list1_size) == 2 - assert lb.get_server_index(primary1_name, list1_size) == 0 + assert n_manager.get_node_from_slot(slot_1, read_from_replicas=True) == node_1 + assert n_manager.get_node_from_slot(slot_1, read_from_replicas=True) == node_2 + assert n_manager.get_node_from_slot(slot_1, read_from_replicas=True) == node_3 + assert n_manager.get_node_from_slot(slot_1, read_from_replicas=True) == node_1 # slot 2 - assert lb.get_server_index(primary2_name, list2_size) == 0 - assert lb.get_server_index(primary2_name, list2_size) == 1 - assert lb.get_server_index(primary2_name, list2_size) == 0 + assert n_manager.get_node_from_slot(slot_2, read_from_replicas=True) == node_4 + assert n_manager.get_node_from_slot(slot_2, read_from_replicas=True) == node_5 + assert n_manager.get_node_from_slot(slot_2, read_from_replicas=True) == node_4 - lb.reset() - assert lb.get_server_index(primary1_name, list1_size) == 0 - assert lb.get_server_index(primary2_name, list2_size) == 0 + n_manager.read_load_balancer.reset() + assert n_manager.get_node_from_slot(slot_1, read_from_replicas=True) == node_1 + assert n_manager.get_node_from_slot(slot_2, read_from_replicas=True) == node_4 + + async def test_get_node_from_slot_no_replicas(self, r: RedisCluster) -> None: + n_manager = r.nodes_manager + slot_1 = 1257 + slot_2 = 8975 + node_1 = ClusterNode(default_host, 6379, PRIMARY) + node_2 = ClusterNode(default_host, 6378, REPLICA) + node_3 = ClusterNode(default_host, 6377, REPLICA) + node_4 = ClusterNode(default_host, 6376, PRIMARY) + node_5 = ClusterNode(default_host, 6375, REPLICA) + n_manager.slots_cache = { + slot_1: [node_1, node_2, node_3], + slot_2: [node_4, node_5], + } + assert n_manager.get_node_from_slot(slot_1, read_from_replicas=False) == node_1 + assert n_manager.get_node_from_slot(slot_1, read_from_replicas=False) == node_1 + assert n_manager.get_node_from_slot(slot_1, read_from_replicas=False) == node_1 + + async def test_get_node_from_slot_changed_slot_replicas( + self, r: RedisCluster + ) -> None: + n_manager = r.nodes_manager + slot_1 = 1257 + node_1 = ClusterNode(default_host, 6379, PRIMARY) + node_2 = ClusterNode(default_host, 6378, REPLICA) + node_3 = ClusterNode(default_host, 6377, REPLICA) + n_manager.slots_cache = { + slot_1: [node_1, node_2, node_3], + } + # slot 1 + assert n_manager.get_node_from_slot(slot_1, read_from_replicas=True) == node_1 + assert n_manager.get_node_from_slot(slot_1, read_from_replicas=True) == node_2 + assert n_manager.get_node_from_slot(slot_1, read_from_replicas=True) == node_3 + assert n_manager.get_node_from_slot(slot_1, read_from_replicas=True) == node_1 + assert n_manager.get_node_from_slot(slot_1, read_from_replicas=True) == node_2 + + # adjust lb-size. This can happen during resharding, + # see https://github.com/redis/redis-py/issues/2988 + n_manager.slots_cache[slot_1].pop() + assert n_manager.get_node_from_slot(slot_1, read_from_replicas=True) == node_1 def test_init_slots_cache_not_all_slots_covered(self): """ diff --git a/tests/test_load_balancer.py b/tests/test_load_balancer.py new file mode 100644 index 0000000000..e081a8ad2f --- /dev/null +++ b/tests/test_load_balancer.py @@ -0,0 +1,165 @@ +import pytest +from redis.cluster import ClusterNode, LoadBalancer + +default_host = "127.0.0.1" +default_port = 6379 + + +@pytest.mark.onlycluster +class TestLoadBalancer: + """ + Tests for the LoadBalancer class + """ + + def test_get_server_index(self) -> None: + lb = LoadBalancer() + primary1_name = f"{default_host}:{default_port}" + primary2_name = f"{default_host}:{default_port+1}" + list1_size = 3 + list2_size = 2 + # slot 1 + assert lb.get_server_index(primary1_name, list1_size) == 0 + assert lb.get_server_index(primary1_name, list1_size) == 1 + assert lb.get_server_index(primary1_name, list1_size) == 2 + assert lb.get_server_index(primary1_name, list1_size) == 0 + # slot 2 + assert lb.get_server_index(primary2_name, list2_size) == 0 + assert lb.get_server_index(primary2_name, list2_size) == 1 + assert lb.get_server_index(primary2_name, list2_size) == 0 + + lb.reset() + assert lb.get_server_index(primary1_name, list1_size) == 0 + assert lb.get_server_index(primary2_name, list2_size) == 0 + + def test_get_server_index_changed_slot_replicas(self) -> None: + lb = LoadBalancer() + primary1_name = f"{default_host}:6379" + list1_size = 3 + # slot 1 + assert lb.get_server_index(primary1_name, list1_size) == 0 + assert lb.get_server_index(primary1_name, list1_size) == 1 + assert lb.get_server_index(primary1_name, list1_size) == 2 + assert lb.get_server_index(primary1_name, list1_size) == 0 + assert lb.get_server_index(primary1_name, list1_size) == 1 + + # adjust lb-size + list1_size = 2 + assert lb.get_server_index(primary1_name, list1_size) == 0 + assert lb.get_server_index(primary1_name, list1_size) == 1 + assert lb.get_server_index(primary1_name, list1_size) == 0 + + def test_get_node_from_slot_single_node_primary_only(self) -> None: + """ + Test that the load balancer handles a server with only a single primary node + """ + load_balancer = LoadBalancer() + slot_nodes = [ClusterNode(default_host, default_port)] + assert ( + load_balancer.get_node_from_slot(slot_nodes, read_from_replicas=False) + == slot_nodes[0] + ) + + def test_get_node_from_slot_multiple_node_primary_only(self) -> None: + """ + Test that the load balancer handles a server with multiple nodes, but only read + from primaries + """ + load_balancer = LoadBalancer() + slot_nodes = [ + ClusterNode(default_host, default_port), + ClusterNode(default_host, default_port + 1), + ] + assert ( + load_balancer.get_node_from_slot(slot_nodes, read_from_replicas=False) + == slot_nodes[0] + ) + assert ( + load_balancer.get_node_from_slot(slot_nodes, read_from_replicas=False) + == slot_nodes[0] + ) + + def test_get_node_from_slot_single_node_read_from_replicas(self) -> None: + """ + Test that the load balancer handles a server with only primary nodes, but also + try to read from replicas + """ + load_balancer = LoadBalancer() + slot_nodes = [ClusterNode(default_host, default_port)] + assert ( + load_balancer.get_node_from_slot(slot_nodes, read_from_replicas=True) + == slot_nodes[0] + ) + assert ( + load_balancer.get_node_from_slot(slot_nodes, read_from_replicas=True) + == slot_nodes[0] + ) + + def test_get_node_from_slot_multiple_node_read_from_replicas(self) -> None: + """ + Test that the load balancer handles a server with primary and replica nodes, + but also try to read from replicas + """ + load_balancer = LoadBalancer() + slot_nodes = [ + ClusterNode(default_host, default_port), + ClusterNode(default_host, default_port + 1), + ClusterNode(default_host, default_port + 2), + ] + assert ( + load_balancer.get_node_from_slot(slot_nodes, read_from_replicas=True) + == slot_nodes[0] + ) + assert ( + load_balancer.get_node_from_slot(slot_nodes, read_from_replicas=True) + == slot_nodes[1] + ) + assert ( + load_balancer.get_node_from_slot(slot_nodes, read_from_replicas=True) + == slot_nodes[2] + ) + + def test_get_node_from_slot_multiple_node_read_from_replicas_resizing(self) -> None: + """ + Test that the load balancer handles a server with primary and replica nodes, + but also try to read from replicas, and handles resharding slots + """ + load_balancer = LoadBalancer() + slot_nodes = [ + ClusterNode(default_host, default_port), + ClusterNode(default_host, default_port + 1), + ClusterNode(default_host, default_port + 2), + ] + assert ( + load_balancer.get_node_from_slot(slot_nodes, read_from_replicas=True) + == slot_nodes[0] + ) + assert ( + load_balancer.get_node_from_slot(slot_nodes, read_from_replicas=True) + == slot_nodes[1] + ) + assert ( + load_balancer.get_node_from_slot(slot_nodes, read_from_replicas=True) + == slot_nodes[2] + ) + assert ( + load_balancer.get_node_from_slot(slot_nodes, read_from_replicas=True) + == slot_nodes[0] + ) + assert ( + load_balancer.get_node_from_slot(slot_nodes, read_from_replicas=True) + == slot_nodes[1] + ) + + slot_nodes.pop() + assert ( + load_balancer.get_node_from_slot(slot_nodes, read_from_replicas=True) + == slot_nodes[0] + ) + assert ( + load_balancer.get_node_from_slot(slot_nodes, read_from_replicas=True) + == slot_nodes[1] + ) + assert ( + load_balancer.get_node_from_slot(slot_nodes, read_from_replicas=True) + == slot_nodes[0] + )