Skip to content

Commit 3d0f15f

Browse files
committed
🦄 refactor: RedisStreamStore
1 parent 8e47987 commit 3d0f15f

File tree

3 files changed

+248
-60
lines changed

3 files changed

+248
-60
lines changed

src/usepy_plugin_redis/__init__.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
11
from .store import RedisStore as useRedis
2-
from .stream import RedisStreamStore as useRedisStream
2+
from .stream import RedisStreamStore as useRedisStream, RedisStreamMessage
33
from .lock import Lock as useRedisLock
44

55
useRedisStreamStore = useRedisStream
66

77
__all__ = [
88
"useRedis",
99
"useRedisStreamStore",
10+
"RedisStreamMessage",
1011
"useRedisStream",
1112
"useRedisLock"
1213
]

src/usepy_plugin_redis/stream.py

+208-42
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
1+
import json
12
import logging
3+
import threading
24
import time
5+
from typing import Optional, List
36

47
import redis
58

@@ -8,12 +11,78 @@
811
logger = logging.getLogger(__name__)
912

1013

11-
class RedisStreamStore(RedisStore):
14+
class RedisStreamMessage:
15+
"""
16+
A message from the Redis stream.
17+
"""
18+
19+
def __init__(self, message_id, message_body, stream_name, consumer_group, consumer_name: Optional[str] = None):
20+
self.stream_name = stream_name
21+
self.consumer_group = consumer_group
22+
self.consumer_name = consumer_name
23+
self.message_id = message_id
24+
self.body = message_body
25+
26+
@staticmethod
27+
def from_xread(raw_messages, *, consumer_group, consumer_name: Optional[str] = None):
28+
"""
29+
Parse a raw message from the Redis stream xread[group] respond.
30+
"""
31+
# xread = [['test_consumer_stream', [('1693561850564-0', {'foo': 'bar'}), ('1693561905479-0', {'foo': 'bar'})]]]
32+
33+
if not consumer_group:
34+
raise ValueError("consumer_group is required")
35+
36+
result = []
37+
for stream_name, messages in raw_messages:
38+
result.extend(
39+
RedisStreamMessage(message_id, message_body, stream_name, consumer_group, consumer_name)
40+
for message_id, message_body in messages
41+
)
42+
43+
return result
44+
45+
@staticmethod
46+
def from_xclaim(raw_messages, *, stream_name, consumer_group, consumer_name: Optional[str] = None):
47+
"""
48+
Parse a raw message from the Redis stream xclaim respond.
49+
"""
50+
# xclaim = [('1693561850564-0', {'foo': 'bar'}), ('1693561905479-0', {'foo': 'bar'})]
51+
52+
if not stream_name:
53+
raise ValueError("stream_name is required")
54+
if not consumer_group:
55+
raise ValueError("consumer_group is required")
56+
57+
return [RedisStreamMessage(message_id, message_body, stream_name, consumer_group, consumer_name)
58+
for message_id, message_body in raw_messages]
59+
60+
def to_dict(self):
61+
return {
62+
"stream_name": self.stream_name,
63+
"consumer_group": self.consumer_group,
64+
"consumer_name": self.consumer_name,
65+
"message_id": self.message_id,
66+
"message_body": self.body
67+
}
68+
69+
def to_json(self):
70+
return json.dumps(self.to_dict())
71+
72+
def __str__(self):
73+
return self.to_json()
74+
75+
def __repr__(self):
76+
return f"RedisStreamMessage({self.stream_name}, {self.consumer_group}, {self.message_id}, {self.body})"
77+
1278

79+
class RedisStreamStore(RedisStore):
80+
1381
def __init__(self, stream_name, *args, **kwargs):
1482
super().__init__(*args, **kwargs)
1583
self.stream_name = stream_name
16-
84+
self.state = threading.local()
85+
1786
def send(self, message: dict):
1887
"""
1988
Send a message to the Redis stream.
@@ -23,58 +92,155 @@ def send(self, message: dict):
2392
except redis.RedisError as e:
2493
logger.error(f"Failed to send message: {e}")
2594
raise
26-
95+
2796
def _create_group(self, consumer_group, start_id="0-0"):
2897
try:
2998
self.connection.xgroup_create(self.stream_name, consumer_group, id=start_id, mkstream=True)
3099
except redis.exceptions.ResponseError as e:
31100
if "already exists" not in str(e):
32101
raise e
33-
34-
def _process_pending_messages(self, consumer_name, consumer_group, callback):
102+
103+
def _is_need_xclaim(self):
104+
"""判断是否需要 xclaim"""
105+
return time.time() * 1000 - self.state.xclaim_last_time > self.state.xclaim_interval
106+
107+
def _xautoclaim(
108+
self,
109+
consumer_name: str,
110+
consumer_group: str,
111+
count: int,
112+
min_idle_time: int
113+
) -> Optional[List[RedisStreamMessage]]:
35114
"""
36-
Process pending messages for the given consumer.
115+
Claim messages from the Redis stream.
37116
"""
38-
claim_start_id = "0-0"
39-
while True:
40-
try:
41-
pending_info = self.connection.xpending(self.stream_name, consumer_group)
42-
if pending_info.get('pending', 0) > 0:
43-
start_id = pending_info['min']
44-
end_id = pending_info['max']
45-
# pending_messages = self.connection.xpending_range(self.stream_name, consumer_group, start_id,
46-
# end_id, count=10)
47-
48-
claim_start_id, pending_messages, _ = self.connection.xautoclaim(
49-
self.stream_name, consumer_group, consumer_name,
50-
min_idle_time=0,
51-
start_id=claim_start_id,
52-
count=1
53-
)
54-
print(pending_messages)
55-
for message in pending_messages:
56-
callback([message], consumer_name, consumer_group)
57-
else:
58-
break
59-
60-
except redis.RedisError as e:
61-
logger.error(f"Error processing pending messages: {e}")
62-
time.sleep(self.RECONNECTION_DELAY)
63-
64-
def start_consuming(self, consumer_name, consumer_group, callback, prefetch=1):
117+
# TODO: 异常需要重试
118+
try:
119+
_, pending_messages, _ = self.connection.xautoclaim(
120+
self.stream_name, consumer_group, consumer_name,
121+
min_idle_time=min_idle_time,
122+
start_id="0-0",
123+
count=count
124+
)
125+
logger.debug(f"xautoclaim: {pending_messages=}")
126+
if pending_messages:
127+
return RedisStreamMessage.from_xclaim(
128+
raw_messages=pending_messages,
129+
stream_name=self.stream_name,
130+
consumer_group=consumer_group,
131+
consumer_name=consumer_name
132+
)
133+
except redis.RedisError as e:
134+
logger.error(f"Error claiming messages: {e}")
135+
time.sleep(self.RECONNECTION_DELAY)
136+
137+
def _xreadgroup(
138+
self,
139+
consumer_name: str,
140+
consumer_group: str,
141+
count: int,
142+
block: Optional[int] = None
143+
) -> Optional[List[RedisStreamMessage]]:
144+
"""
145+
Read messages from the Redis stream.
146+
"""
147+
try:
148+
raw_messages = self.connection.xreadgroup(
149+
groupname=consumer_group,
150+
consumername=consumer_name,
151+
streams={self.stream_name: ">"},
152+
count=count,
153+
block=block,
154+
)
155+
logger.debug(f"xreadgroup: {raw_messages=}")
156+
if raw_messages:
157+
return RedisStreamMessage.from_xread(
158+
raw_messages=raw_messages,
159+
consumer_group=consumer_group,
160+
consumer_name=consumer_name,
161+
)
162+
except redis.RedisError as e:
163+
logger.error(f"Error reading messages: {e}")
164+
time.sleep(self.RECONNECTION_DELAY)
165+
166+
def consume(
167+
self,
168+
consumer_group: str,
169+
consumer_name: str,
170+
prefetch: int,
171+
claim_min_idle_time: int = 3600000,
172+
force_claim: bool = False,
173+
block: Optional[int] = None
174+
) -> Optional[List[RedisStreamMessage]]:
175+
"""
176+
Consume messages from the Redis stream(order: xclaim -> xreadgroup).
177+
178+
:param consumer_name: 消费者名称
179+
:param consumer_group: 消费组名称
180+
:param prefetch: 消费数量
181+
:param block: xreadgroup 阻塞时间,单位毫秒,默认为 None,不阻塞
182+
:param claim_min_idle_time: xclaim 最小空闲的时间(即消息消费的超时时间),单位毫秒,默认为 1 小时
183+
:param force_claim: 是否强制 xclaim,True=忽略上次 xclaim 的时间和 xclaim 间隔限制,False=按周期执行 xclaim
184+
185+
:return: 消费的消息列表,消费失败时则返回 []
186+
"""
187+
# 获取 当前消费者尚未ACK 的消息,最多获取 prefetch 个
188+
pending_messages = self.connection.xpending_range(
189+
name=self.stream_name,
190+
groupname=consumer_group,
191+
min='-',
192+
max='+',
193+
count=prefetch,
194+
idle=0,
195+
consumername=consumer_name,
196+
)
197+
# 计算需获取的消息数量
198+
need_count = prefetch - len(pending_messages) if pending_messages else prefetch
199+
if need_count <= 0:
200+
return []
201+
202+
result = []
203+
# 先尝试 xclaim
204+
if force_claim or self._is_need_xclaim():
205+
pending_messages = self._xautoclaim(consumer_name, consumer_group, need_count, claim_min_idle_time)
206+
if pending_messages:
207+
result.extend(pending_messages)
208+
# 更新还需获取的消息数量
209+
need_count = need_count - len(result)
210+
211+
self.state.xclaim_last_time = time.time() * 1000
212+
213+
# 然后 xreadgroup
214+
if need_count > 0:
215+
messages = self._xreadgroup(consumer_name, consumer_group, need_count, block)
216+
if messages:
217+
result.extend(messages)
218+
219+
return result
220+
221+
def start_consuming(self, consumer_group, consumer_name, callback, prefetch=1, timeout=3600000, **kwargs):
65222
"""
66223
Start consuming messages from the Redis stream.
224+
225+
:param consumer_name: 消费者名称
226+
:param consumer_group: 消费组名称
227+
:param callback: 消费回调函数
228+
:param prefetch: 消费并发数量
229+
:param timeout: 消费超时时间(即 xclaim 最小空闲的时间),默认为 1 小时
67230
"""
68-
self._create_group(consumer_group)
69-
70-
self._process_pending_messages(consumer_name, consumer_group, callback)
71-
72-
while True:
231+
# 初始化工作
232+
self._create_group(consumer_group, kwargs.get('group_start_id', '0-0'))
233+
block = kwargs.get('xread_block', None)
234+
235+
# xclaim 时必要的参数,得线程隔离
236+
self.state.xclaim_interval = kwargs.get('xclaim_interval', 5 * 60 * 1000) # xclaim 间隔时间,单位毫秒,默认 5 分钟
237+
self.state.xclaim_last_time = 0
238+
239+
while not self._shutdown:
73240
try:
74-
messages = self.connection.xreadgroup(consumer_group, consumer_name, {self.stream_name: '>'},
75-
count=prefetch)
76-
for _, message in messages:
77-
callback(message, consumer_name, consumer_group)
241+
messages = self.consume(consumer_group, consumer_name, prefetch, timeout, block=block)
242+
for message in messages:
243+
callback(message)
78244
except redis.RedisError as e:
79245
logger.error(f"Error consuming messages: {e}")
80246
time.sleep(self.RECONNECTION_DELAY)

tests/test_stream.py

+38-17
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
import threading
2+
import time
23

34
import pytest
45

5-
from usepy_plugin_redis import useRedisStreamStore
6+
from usepy_plugin_redis import useRedisStreamStore, RedisStreamMessage
67

78

89
@pytest.fixture
@@ -28,21 +29,41 @@ def test_consumer(redis):
2829
redis.stream_name = stream
2930
send_message = {'foo': 'bar'}
3031
redis.send(send_message)
31-
32-
def callback(message, consumer_group, *args, **kwargs):
33-
assert isinstance(message, list)
34-
assert len(message) == 1
35-
first_message = message[0]
36-
message_id, message = first_message
37-
print("message_id", message_id)
38-
_message = {
39-
k.decode() if isinstance(k, bytes) else k: v.decode() if isinstance(v, bytes) else v
40-
for k, v in message.items()
41-
}
42-
assert _message == send_message
43-
redis.connection.xack(stream, consumer_group, message_id)
32+
33+
def callback(message: RedisStreamMessage, *args, **kwargs):
34+
# print("message", message.to_dict())
35+
assert message.body == send_message
36+
redis.connection.xack(message.stream_name, message.consumer_group, message.message_id)
4437
redis.shutdown()
38+
39+
# threading.Thread(target=redis.start_consuming, args=("consumer_group1", "consumer_name", callback,)).start()
40+
threading.Thread(target=redis.start_consuming, args=("consumer_group2", "consumer_name", callback,),
41+
kwargs={'prefetch': 2}).start()
42+
# threading.Thread(target=redis.start_consuming, args=("consumer_group3", "consumer_name", callback,)).start()
43+
4544

46-
threading.Thread(target=redis.start_consuming, args=("consumer_name", "consumer_group1", callback,)).start()
47-
threading.Thread(target=redis.start_consuming, args=("consumer_name", "consumer_group2", callback,)).start()
48-
threading.Thread(target=redis.start_consuming, args=("consumer_name", "consumer_group3", callback,)).start()
45+
def test_consumer_xclaim(redis):
46+
stream = 'test_consumer_xclaim_stream'
47+
redis.connection.delete(stream)
48+
49+
redis.stream_name = stream
50+
send_message = {'foo': 'bar'}
51+
redis.send(send_message)
52+
53+
def callback(message: RedisStreamMessage, *args, **kwargs):
54+
# print("message", message.to_dict())
55+
assert message.body == send_message
56+
redis.connection.xack(message.stream_name, message.consumer_group, message.message_id)
57+
redis.shutdown()
58+
59+
redis._create_group("consumer_group")
60+
redis.consume("consumer_group", "consumer1", prefetch=1, claim_min_idle_time=0, force_claim=True)
61+
time.sleep(1)
62+
job = threading.Thread(target=redis.start_consuming,
63+
args=("consumer_group", "consumer2", callback,),
64+
kwargs={'prefetch': 1, 'xclaim_interval': 5000, 'timeout': 500, 'force_claim': True}
65+
)
66+
job.start()
67+
job.join()
68+
69+

0 commit comments

Comments
 (0)