
Fix get_node_from_slot to handle resharding #3182

Open · wants to merge 3 commits into master
1 change: 1 addition & 0 deletions CHANGES
@@ -1,3 +1,4 @@
* Fix get_node_from_slot for cluster and asyncio cluster during reshard (#2988)
* Move doctests (doc code examples) to main branch
* Update `ResponseT` type hint
* Allow to control the minimum SSL version
26 changes: 11 additions & 15 deletions redis/asyncio/cluster.py
@@ -43,7 +43,6 @@
REPLICA,
SLOT_ID,
AbstractRedisCluster,
LoadBalancer,
block_pipeline_command,
get_node_name,
parse_cluster_slots,
@@ -68,6 +67,7 @@
TimeoutError,
TryAgainError,
)
from redis.load_balancer import LoadBalancer
from redis.typing import AnyKeyT, EncodableT, KeyT
from redis.utils import (
deprecated_function,
@@ -1244,23 +1244,19 @@ def _update_moved_slots(self) -> None:
def get_node_from_slot(
self, slot: int, read_from_replicas: bool = False
) -> "ClusterNode":
"""
Gets a node that serves this hash slot
"""
if self._moved_exception:
self._update_moved_slots()

try:
if read_from_replicas:
# get the server index in a Round-Robin manner
primary_name = self.slots_cache[slot][0].name
node_idx = self.read_load_balancer.get_server_index(
primary_name, len(self.slots_cache[slot])
)
return self.slots_cache[slot][node_idx]
return self.slots_cache[slot][0]
except (IndexError, TypeError):
raise SlotNotCoveredError(
f'Slot "{slot}" not covered by the cluster. '
f'"require_full_coverage={self.require_full_coverage}"'
)
slot_nodes = self.slots_cache.get(slot, None)
if slot_nodes is None or len(slot_nodes) == 0:
raise SlotNotCoveredError(f'Slot "{slot}" not covered by the cluster.')
return self.read_load_balancer.get_node_from_slot(
slot_nodes,
read_from_replicas,
)

def get_nodes_by_server_type(self, server_type: str) -> List["ClusterNode"]:
return [
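For context on the fix, here is a minimal sketch with simplified stand-ins (not the real redis-py classes): the old balancer returned its stored index and only wrapped the *next* one against the list size seen at call time, so a reshard that shrank a slot's node list could leave a stale out-of-range index, and the old except (IndexError, TypeError) branch then misreported the still-covered slot as SlotNotCoveredError (#2988). The new balancer wraps against the current list size on every read:

nodes = ["primary", "replica-1", "replica-2"]
primary_to_idx = {}

def old_get_server_index(primary: str, list_size: int) -> int:
    # Old behavior: return the stored index, then store the next one
    # wrapped against the list size seen *now*.
    server_index = primary_to_idx.setdefault(primary, 0)
    primary_to_idx[primary] = (server_index + 1) % list_size
    return server_index

old_get_server_index("primary", len(nodes))  # returns 0, stores 1
old_get_server_index("primary", len(nodes))  # returns 1, stores 2
nodes.pop()  # reshard drops a replica; valid indices are now 0 and 1

try:
    nodes[old_get_server_index("primary", len(nodes))]  # stored 2 -> IndexError
except IndexError:
    pass  # the old get_node_from_slot surfaced this as SlotNotCoveredError

last_used = {}

def new_get_server_index(primary: str, list_size: int) -> int:
    # New behavior: increment the last-used index and wrap against the
    # current list size, so the result is always in range.
    server_index = (last_used.get(primary, -1) + 1) % list_size
    last_used[primary] = server_index
    return server_index

assert new_get_server_index("primary", len(nodes)) == 0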
58 changes: 12 additions & 46 deletions redis/cluster.py
@@ -31,6 +31,7 @@
TimeoutError,
TryAgainError,
)
from redis.load_balancer import LoadBalancer
from redis.lock import Lock
from redis.retry import Retry
from redis.utils import (
@@ -1319,25 +1320,6 @@ def invalidate_key_from_cache(self, key):
self.redis_connection.invalidate_key_from_cache(key)


class LoadBalancer:
"""
Round-Robin Load Balancing
"""

def __init__(self, start_index: int = 0) -> None:
self.primary_to_idx = {}
self.start_index = start_index

def get_server_index(self, primary: str, list_size: int) -> int:
server_index = self.primary_to_idx.setdefault(primary, self.start_index)
# Update the index
self.primary_to_idx[primary] = (server_index + 1) % list_size
return server_index

def reset(self) -> None:
self.primary_to_idx.clear()


class NodesManager:
def __init__(
self,
@@ -1426,40 +1408,24 @@ def _update_moved_slots(self):
# Reset moved_exception
self._moved_exception = None

def get_node_from_slot(self, slot, read_from_replicas=False, server_type=None):
def get_node_from_slot(
self, slot: int, read_from_replicas: bool = False
) -> "ClusterNode":
"""
Gets a node that servers this hash slot
Gets a node that serves this hash slot
"""
if self._moved_exception:
with self._lock:
if self._moved_exception:
self._update_moved_slots()

if self.slots_cache.get(slot) is None or len(self.slots_cache[slot]) == 0:
raise SlotNotCoveredError(
f'Slot "{slot}" not covered by the cluster. '
f'"require_full_coverage={self._require_full_coverage}"'
)

if read_from_replicas is True:
# get the server index in a Round-Robin manner
primary_name = self.slots_cache[slot][0].name
node_idx = self.read_load_balancer.get_server_index(
primary_name, len(self.slots_cache[slot])
)
elif (
server_type is None
or server_type == PRIMARY
or len(self.slots_cache[slot]) == 1
):
# return a primary
node_idx = 0
else:
# return a replica
# randomly choose one of the replicas
node_idx = random.randint(1, len(self.slots_cache[slot]) - 1)

return self.slots_cache[slot][node_idx]
slot_nodes = self.slots_cache.get(slot, None)
if slot_nodes is None or len(slot_nodes) == 0:
raise SlotNotCoveredError(f'Slot "{slot}" not covered by the cluster.')
return self.read_load_balancer.get_node_from_slot(
slot_nodes,
read_from_replicas,
)

def get_nodes_by_server_type(self, server_type):
"""
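Note that the synchronous signature drops the server_type parameter along with the old random-replica branch, matching the asyncio version. A hypothetical caller migration (names assumed for illustration, not taken from this diff):

# Before: a random replica could be requested explicitly:
#     node = nodes_manager.get_node_from_slot(slot, server_type=REPLICA)
# After: replica reads opt into the round-robin balancer, which rotates
# over the primary and its replicas:
node = nodes_manager.get_node_from_slot(slot, read_from_replicas=True)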
39 changes: 39 additions & 0 deletions redis/load_balancer.py
@@ -0,0 +1,39 @@
from typing import TYPE_CHECKING, Dict, List, Union

if TYPE_CHECKING:
from redis.asyncio.cluster import ClusterNode as AsyncioClusterNode # noqa: F401
from redis.cluster import ClusterNode


class LoadBalancer:
"""
Round-Robin Load Balancing
"""

def __init__(self) -> None:
self.primary_name_to_last_used_index: Dict[str, int] = {}

def get_node_from_slot(
self,
slot_nodes: Union[List["ClusterNode"], List["AsyncioClusterNode"]],
read_from_replicas: bool,
) -> Union["ClusterNode", "AsyncioClusterNode"]:
assert len(slot_nodes) > 0
if not read_from_replicas:
return slot_nodes[0]

primary_name = slot_nodes[0].name
node_idx = self.get_server_index(primary_name, len(slot_nodes))
return slot_nodes[node_idx]

def get_server_index(self, primary: str, list_size: int) -> int:
# default to -1 if not found, so after incrementing it will be 0
server_index = (
self.primary_name_to_last_used_index.get(primary, -1) + 1
) % list_size
# Update the index
self.primary_name_to_last_used_index[primary] = server_index
return server_index

def reset(self) -> None:
self.primary_name_to_last_used_index.clear()
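A quick usage sketch of the new balancer, assuming the module is importable as redis.load_balancer per the diff above:

from redis.load_balancer import LoadBalancer

lb = LoadBalancer()

# Round-robin per primary name: 0, 1, 2, then wrap back to 0.
assert [lb.get_server_index("primary-a", 3) for _ in range(4)] == [0, 1, 2, 0]

# If the slot's node list shrinks to two entries between calls, the modulo
# is taken against the current size: the last used index was 0, so the
# next one is 1 (no IndexError during resharding).
assert lb.get_server_index("primary-a", 2) == 1

# reset() clears per-primary state, e.g. after a topology refresh.
lb.reset()
assert lb.get_server_index("primary-a", 3) == 0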
66 changes: 50 additions & 16 deletions tests/test_asyncio/test_cluster.py
@@ -2355,9 +2355,8 @@ class TestNodesManager:
Tests for the NodesManager class
"""

async def test_load_balancer(self, r: RedisCluster) -> None:
async def test_get_node_from_slot(self, r: RedisCluster) -> None:
n_manager = r.nodes_manager
lb = n_manager.read_load_balancer
slot_1 = 1257
slot_2 = 8975
node_1 = ClusterNode(default_host, 6379, PRIMARY)
@@ -2369,23 +2368,58 @@ async def test_load_balancer(self, r: RedisCluster) -> None:
slot_1: [node_1, node_2, node_3],
slot_2: [node_4, node_5],
}
primary1_name = n_manager.slots_cache[slot_1][0].name
primary2_name = n_manager.slots_cache[slot_2][0].name
list1_size = len(n_manager.slots_cache[slot_1])
list2_size = len(n_manager.slots_cache[slot_2])
# slot 1
assert lb.get_server_index(primary1_name, list1_size) == 0
assert lb.get_server_index(primary1_name, list1_size) == 1
assert lb.get_server_index(primary1_name, list1_size) == 2
assert lb.get_server_index(primary1_name, list1_size) == 0
assert n_manager.get_node_from_slot(slot_1, read_from_replicas=True) == node_1
assert n_manager.get_node_from_slot(slot_1, read_from_replicas=True) == node_2
assert n_manager.get_node_from_slot(slot_1, read_from_replicas=True) == node_3
assert n_manager.get_node_from_slot(slot_1, read_from_replicas=True) == node_1
# slot 2
assert lb.get_server_index(primary2_name, list2_size) == 0
assert lb.get_server_index(primary2_name, list2_size) == 1
assert lb.get_server_index(primary2_name, list2_size) == 0
assert n_manager.get_node_from_slot(slot_2, read_from_replicas=True) == node_4
assert n_manager.get_node_from_slot(slot_2, read_from_replicas=True) == node_5
assert n_manager.get_node_from_slot(slot_2, read_from_replicas=True) == node_4

lb.reset()
assert lb.get_server_index(primary1_name, list1_size) == 0
assert lb.get_server_index(primary2_name, list2_size) == 0
n_manager.read_load_balancer.reset()
assert n_manager.get_node_from_slot(slot_1, read_from_replicas=True) == node_1
assert n_manager.get_node_from_slot(slot_2, read_from_replicas=True) == node_4

async def test_get_node_from_slot_no_replicas(self, r: RedisCluster) -> None:
n_manager = r.nodes_manager
slot_1 = 1257
slot_2 = 8975
node_1 = ClusterNode(default_host, 6379, PRIMARY)
node_2 = ClusterNode(default_host, 6378, REPLICA)
node_3 = ClusterNode(default_host, 6377, REPLICA)
node_4 = ClusterNode(default_host, 6376, PRIMARY)
node_5 = ClusterNode(default_host, 6375, REPLICA)
n_manager.slots_cache = {
slot_1: [node_1, node_2, node_3],
slot_2: [node_4, node_5],
}
assert n_manager.get_node_from_slot(slot_1, read_from_replicas=False) == node_1
assert n_manager.get_node_from_slot(slot_1, read_from_replicas=False) == node_1
assert n_manager.get_node_from_slot(slot_1, read_from_replicas=False) == node_1

async def test_get_node_from_slot_changed_slot_replicas(
self, r: RedisCluster
) -> None:
n_manager = r.nodes_manager
slot_1 = 1257
node_1 = ClusterNode(default_host, 6379, PRIMARY)
node_2 = ClusterNode(default_host, 6378, REPLICA)
node_3 = ClusterNode(default_host, 6377, REPLICA)
n_manager.slots_cache = {
slot_1: [node_1, node_2, node_3],
}
# slot 1
assert n_manager.get_node_from_slot(slot_1, read_from_replicas=True) == node_1
assert n_manager.get_node_from_slot(slot_1, read_from_replicas=True) == node_2
assert n_manager.get_node_from_slot(slot_1, read_from_replicas=True) == node_3
assert n_manager.get_node_from_slot(slot_1, read_from_replicas=True) == node_1
assert n_manager.get_node_from_slot(slot_1, read_from_replicas=True) == node_2

# adjust lb-size. This can happen during resharding, see
# https://github.com/redis/redis-py/issues/2988. After the pop the last
# used index is 1 (node_2), so the next wrapped index is (1 + 1) % 2 == 0,
# i.e. node_1.
n_manager.slots_cache[slot_1].pop()
assert n_manager.get_node_from_slot(slot_1, read_from_replicas=True) == node_1

async def test_init_slots_cache_not_all_slots_covered(self) -> None:
"""
67 changes: 51 additions & 16 deletions tests/test_cluster.py
@@ -2501,9 +2501,8 @@ class TestNodesManager:
Tests for the NodesManager class
"""

def test_load_balancer(self, r):
def test_get_node_from_slot(self, r: RedisCluster) -> None:
n_manager = r.nodes_manager
lb = n_manager.read_load_balancer
slot_1 = 1257
slot_2 = 8975
node_1 = ClusterNode(default_host, 6379, PRIMARY)
@@ -2515,23 +2514,59 @@ def test_load_balancer(self, r):
slot_1: [node_1, node_2, node_3],
slot_2: [node_4, node_5],
}
primary1_name = n_manager.slots_cache[slot_1][0].name
primary2_name = n_manager.slots_cache[slot_2][0].name
list1_size = len(n_manager.slots_cache[slot_1])
list2_size = len(n_manager.slots_cache[slot_2])
# slot 1
assert lb.get_server_index(primary1_name, list1_size) == 0
assert lb.get_server_index(primary1_name, list1_size) == 1
assert lb.get_server_index(primary1_name, list1_size) == 2
assert lb.get_server_index(primary1_name, list1_size) == 0
assert n_manager.get_node_from_slot(slot_1, read_from_replicas=True) == node_1
assert n_manager.get_node_from_slot(slot_1, read_from_replicas=True) == node_2
assert n_manager.get_node_from_slot(slot_1, read_from_replicas=True) == node_3
assert n_manager.get_node_from_slot(slot_1, read_from_replicas=True) == node_1
# slot 2
assert lb.get_server_index(primary2_name, list2_size) == 0
assert lb.get_server_index(primary2_name, list2_size) == 1
assert lb.get_server_index(primary2_name, list2_size) == 0
assert n_manager.get_node_from_slot(slot_2, read_from_replicas=True) == node_4
assert n_manager.get_node_from_slot(slot_2, read_from_replicas=True) == node_5
assert n_manager.get_node_from_slot(slot_2, read_from_replicas=True) == node_4

lb.reset()
assert lb.get_server_index(primary1_name, list1_size) == 0
assert lb.get_server_index(primary2_name, list2_size) == 0
n_manager.read_load_balancer.reset()
assert n_manager.get_node_from_slot(slot_1, read_from_replicas=True) == node_1
assert n_manager.get_node_from_slot(slot_2, read_from_replicas=True) == node_4

def test_get_node_from_slot_no_replicas(self, r: RedisCluster) -> None:
n_manager = r.nodes_manager
slot_1 = 1257
slot_2 = 8975
node_1 = ClusterNode(default_host, 6379, PRIMARY)
node_2 = ClusterNode(default_host, 6378, REPLICA)
node_3 = ClusterNode(default_host, 6377, REPLICA)
node_4 = ClusterNode(default_host, 6376, PRIMARY)
node_5 = ClusterNode(default_host, 6375, REPLICA)
n_manager.slots_cache = {
slot_1: [node_1, node_2, node_3],
slot_2: [node_4, node_5],
}
assert n_manager.get_node_from_slot(slot_1, read_from_replicas=False) == node_1
assert n_manager.get_node_from_slot(slot_1, read_from_replicas=False) == node_1
assert n_manager.get_node_from_slot(slot_1, read_from_replicas=False) == node_1

def test_get_node_from_slot_changed_slot_replicas(
self, r: RedisCluster
) -> None:
n_manager = r.nodes_manager
slot_1 = 1257
node_1 = ClusterNode(default_host, 6379, PRIMARY)
node_2 = ClusterNode(default_host, 6378, REPLICA)
node_3 = ClusterNode(default_host, 6377, REPLICA)
n_manager.slots_cache = {
slot_1: [node_1, node_2, node_3],
}
# slot 1
assert n_manager.get_node_from_slot(slot_1, read_from_replicas=True) == node_1
assert n_manager.get_node_from_slot(slot_1, read_from_replicas=True) == node_2
assert n_manager.get_node_from_slot(slot_1, read_from_replicas=True) == node_3
assert n_manager.get_node_from_slot(slot_1, read_from_replicas=True) == node_1
assert n_manager.get_node_from_slot(slot_1, read_from_replicas=True) == node_2

# adjust lb-size. This can happen during resharding,
# see https://github.com/redis/redis-py/issues/2988
n_manager.slots_cache[slot_1].pop()
assert n_manager.get_node_from_slot(slot_1, read_from_replicas=True) == node_1

def test_init_slots_cache_not_all_slots_covered(self):
"""