forked from linyqh/NarratoAI
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathstate.py
122 lines (98 loc) · 3.06 KB
/
state.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
import ast
from abc import ABC, abstractmethod
from app.config import config
from app.models import const
# Abstract interface shared by every task-state backend
class BaseState(ABC):
    """Contract that each task-state store in this module implements."""

    @abstractmethod
    def update_task(self, task_id: str, state: int, progress: int = 0, **kwargs):
        """Record ``state``, ``progress`` and any extra fields for ``task_id``."""

    @abstractmethod
    def get_task(self, task_id: str):
        """Return whatever is stored for ``task_id``."""
# In-process (dict-backed) state management
class MemoryState(BaseState):
    """Task-state store that keeps every record in a plain dict on the instance."""

    def __init__(self):
        # Maps task_id -> dict of fields written by update_task.
        self._tasks = {}

    def update_task(
        self,
        task_id: str,
        state: int = const.TASK_STATE_PROCESSING,
        progress: int = 0,
        **kwargs,
    ):
        """Overwrite the record for ``task_id``.

        ``progress`` is coerced with ``int()`` and capped at 100 (there is no
        lower bound). The previous record, if any, is replaced as a whole, so
        extra fields from earlier calls are not carried over.
        """
        capped = min(int(progress), 100)
        record = {"state": state, "progress": capped}
        record.update(kwargs)
        self._tasks[task_id] = record

    def get_task(self, task_id: str):
        """Return the stored record for ``task_id``, or ``None`` if absent."""
        return self._tasks.get(task_id)

    def delete_task(self, task_id: str):
        """Drop the record for ``task_id``; unknown ids are ignored."""
        self._tasks.pop(task_id, None)
# Redis state management
class RedisState(BaseState):
    """Task-state store that keeps each task as a Redis hash keyed by task id.

    Field values are written as ``str(value)`` and parsed back on read by
    ``_convert_to_original_type``. The round trip is therefore lossy: a
    string that happens to look like a Python literal (``"123"``, ``"None"``,
    ``"[1, 2]"``) is returned as that literal, not as the original string.
    """

    def __init__(self, host="localhost", port=6379, db=0, password=None):
        # Imported here, so the redis package is only loaded when a
        # RedisState is actually constructed.
        import redis

        self._redis = redis.StrictRedis(host=host, port=port, db=db, password=password)

    def update_task(
        self,
        task_id: str,
        state: int = const.TASK_STATE_PROCESSING,
        progress: int = 0,
        **kwargs,
    ):
        """Write ``state``, ``progress`` and any extra fields into the task hash.

        ``progress`` is coerced with ``int()`` and capped at 100. Fields that
        are already in the hash but not passed in this call are left as-is.
        """
        progress = min(int(progress), 100)
        fields = {
            "state": state,
            "progress": progress,
            **kwargs,
        }
        # Queue every HSET on one pipeline so the update costs a single round
        # trip and readers never observe a half-written set of fields.
        pipe = self._redis.pipeline()
        for field, value in fields.items():
            pipe.hset(task_id, field, str(value))
        pipe.execute()

    def get_task(self, task_id: str):
        """Return the task hash as a dict of parsed values, or ``None`` if empty."""
        task_data = self._redis.hgetall(task_id)
        if not task_data:
            return None
        return {
            key.decode("utf-8"): self._convert_to_original_type(value)
            for key, value in task_data.items()
        }

    def delete_task(self, task_id: str):
        """Delete the whole hash stored under ``task_id``."""
        self._redis.delete(task_id)

    @staticmethod
    def _convert_to_original_type(value):
        """Convert a UTF-8 byte string read from Redis back to a Python value.

        The text is first parsed as a Python literal (numbers, lists, dicts,
        ``None``, booleans, quoted strings). If that fails, an all-digit
        string is parsed with ``int()``. Anything else is returned as the
        decoded string.
        """
        value_str = value.decode("utf-8")
        try:
            return ast.literal_eval(value_str)
        except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
            # literal_eval can raise any of these on malformed input
            # (e.g. "{[1]: 2}" -> TypeError); none of them should make
            # get_task fail, so fall through to the plain-text handling.
            pass
        if value_str.isdigit():
            # Reached for digit strings literal_eval rejects, e.g. "007".
            # isdigit() also accepts characters such as superscripts that
            # int() refuses, so guard the conversion.
            try:
                return int(value_str)
            except ValueError:
                pass
        return value_str
# Module-wide state store, chosen once at import time from config.app
_enable_redis = config.app.get("enable_redis", False)
_redis_host = config.app.get("redis_host", "localhost")
_redis_port = config.app.get("redis_port", 6379)
_redis_db = config.app.get("redis_db", 0)
_redis_password = config.app.get("redis_password", None)

if _enable_redis:
    # Redis-backed store, built from the connection settings read above.
    state = RedisState(
        host=_redis_host,
        port=_redis_port,
        db=_redis_db,
        password=_redis_password,
    )
else:
    # Default: keep task state in this process's memory.
    state = MemoryState()