Skip to content

Commit 3a097c5

Browse files
committed
Added info-server.
Signed-off-by: Pavel Kirilin <[email protected]>
1 parent 6bd1550 commit 3a097c5

File tree

4 files changed

+236
-0
lines changed

4 files changed

+236
-0
lines changed

taskiq/cli/info_server/__init__.py

Whitespace-only changes.

taskiq/cli/info_server/client.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
import asyncio
2+
import ipaddress
3+
import json
4+
import socket
5+
from typing import Any, Dict
6+
import contextlib
7+
8+
9+
def sync_send_request(host: str, port: int, data: Dict[str, Any]) -> str:
10+
with contextlib.suppress(ValueError):
11+
host = ipaddress.ip_address(host)
12+
if isinstance(host, ipaddress.IPv6Address):
13+
addr_family = socket.AF_INET6
14+
else:
15+
addr_family = socket.AF_INET
16+
info = socket.getaddrinfo(
17+
host,
18+
port,
19+
family=addr_family,
20+
type=socket.SOCK_STREAM,
21+
)
22+
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
23+
sock.connect((host, port))
24+
encoded = json.dumps(data).encode("utf-8")
25+
sock.sendall(len(encoded).to_bytes(10, "big") + encoded)
26+
27+
28+
class TaskiqServerClient:
29+
def __init__(self, host: str, port: int) -> None:
30+
self.host = host
31+
self.port = port
32+
33+
async def read_response(self, reader: asyncio.StreamReader) -> Dict[str, Any]:
34+
response = await reader.read(10)
35+
if response == b"":
36+
raise ConnectionError("Connection closed")
37+
body_len = int.from_bytes(response, "big")
38+
buffer = b""
39+
while len(buffer) < body_len:
40+
buffer += await reader.read(1024)
41+
return json.loads(buffer[:body_len])
42+
43+
async def send_request(self, data: Dict[str, Any]) -> str:
44+
reader, writer = await asyncio.open_connection(self.host, self.port)
45+
body = json.dumps(data)
46+
body_len = len(body)
47+
writer.write(body_len.to_bytes(10, "big") + body.encode("utf-8"))
48+
await writer.drain()
49+
return await self.read_response(reader)

taskiq/cli/info_server/models.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
import enum
2+
3+
4+
@enum.unique
5+
class WorkerState(int, enum.Enum):
6+
"""Worker state enumeration."""
7+
8+
READY = enum.auto()
9+
IDLE = enum.auto()
10+
BUSY = enum.auto()
11+
STOPPING = enum.auto()
12+
STOPPED = enum.auto()

taskiq/cli/info_server/server.py

Lines changed: 175 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,175 @@
1+
import ipaddress
2+
import json
3+
import logging
4+
import os
5+
import socket
6+
import threading
7+
from dataclasses import dataclass
8+
from multiprocessing.pool import ThreadPool
9+
import time
10+
from typing import Any, Dict, List, Optional
11+
12+
from pydantic import BaseModel, Field
13+
14+
from taskiq.cli.info_server.models import WorkerState
15+
16+
logger = logging.getLogger("taskiq.worker.info_server")
17+
18+
19+
class ServerState(BaseModel):
20+
"""State of the taskiq server."""
21+
22+
workers_state: List[WorkerState] = Field(default_factory=list)
23+
workers_count: int
24+
active_tasks: Dict[str, List[Dict[str, Any]]] = Field(default_factory=dict)
25+
26+
27+
class TaskiqInfoServer(threading.Thread):
28+
def __init__(
29+
self,
30+
host: str,
31+
port: int,
32+
ready_event: Optional[threading.Event] = None,
33+
workers_count: int = 0,
34+
) -> None:
35+
super().__init__()
36+
try:
37+
addr = ipaddress.ip_address(host)
38+
except ValueError:
39+
addr = host
40+
if isinstance(addr, ipaddress.IPv6Address):
41+
addr_family = socket.AF_INET6
42+
else:
43+
addr_family = socket.AF_INET
44+
info = socket.getaddrinfo(
45+
host,
46+
port,
47+
family=addr_family,
48+
type=socket.SOCK_STREAM,
49+
)
50+
self.addr_family, self.sock_kind, self.sock_proto, _, self.bind_info = info[0]
51+
self.stop_event = threading.Event()
52+
self.state = {}
53+
self.ready_event = ready_event
54+
self.state = ServerState(workers_count=workers_count)
55+
self.methods = {
56+
"update_state": self.update_state,
57+
}
58+
59+
def wait_started(self) -> None:
60+
if self.ready_event is None:
61+
return
62+
while not self.ready_event.is_set():
63+
time.sleep(0.1)
64+
if not self.is_alive():
65+
raise RuntimeError("Failed to start server")
66+
67+
def wait_workers(self, timeout: Optional[float] = None) -> None:
68+
start = time.monotonic()
69+
while True:
70+
for state in self.state.workers_state:
71+
if state == WorkerState.READY:
72+
break
73+
74+
if self.state.workers_count == len(self.state.workers_state):
75+
break
76+
77+
if timeout is not None and time.monotonic() - start > timeout:
78+
raise TimeoutError("Failed to start workers")
79+
print(self.state)
80+
time.sleep(0.1)
81+
82+
def kill(self) -> None:
83+
self.stop_event.set()
84+
85+
def run(self) -> None:
86+
server = socket.socket(
87+
self.addr_family,
88+
self.sock_kind,
89+
self.sock_proto,
90+
)
91+
if os.name != "nt":
92+
server.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
93+
server.bind(self.bind_info)
94+
server.settimeout(1)
95+
server.listen()
96+
if self.ready_event is not None:
97+
self.ready_event.set()
98+
while True:
99+
try:
100+
client, addr = server.accept()
101+
logger.info(f"Accepted connection from {addr[0]}:{addr[1]}")
102+
threading.Thread(target=self._handle_client, args=(client,)).start()
103+
except TimeoutError:
104+
if self.stop_event.is_set():
105+
break
106+
except Exception as exc:
107+
logger.warning(
108+
"Exception found when processing request %s",
109+
exc,
110+
exc_info=True,
111+
)
112+
server.close()
113+
114+
def _receive_request(self, client: socket.socket) -> Dict[str, Any]:
115+
body_len = client.recv(10)
116+
body_len = int.from_bytes(body_len, "big")
117+
buffer = b""
118+
while len(buffer) < body_len:
119+
buffer += client.recv(1024)
120+
buffer = buffer[:body_len]
121+
return json.loads(buffer)
122+
123+
def _send_response(self, client: socket.socket, data: Dict[str, Any]) -> None:
124+
encoded = json.dumps(data).encode("utf-8")
125+
body_len = len(encoded)
126+
client.sendall(body_len.to_bytes(10, "big") + encoded)
127+
128+
def _handle_client(self, client: socket.socket) -> None:
129+
empty_response = {"status": "ok", "data": {}}
130+
try:
131+
request = self._receive_request(client)
132+
except ValueError as exc:
133+
self._send_response(
134+
client,
135+
{"status": "error", "description": str(exc)},
136+
)
137+
return
138+
if "method" not in request and "params" not in request:
139+
response = {
140+
"status": "error",
141+
"description": "Invalid request",
142+
}
143+
elif request["method"] not in self.methods:
144+
response = {
145+
"status": "error",
146+
"description": "Unknown method",
147+
}
148+
else:
149+
try:
150+
response = self.methods[request["method"]](request) or empty_response
151+
except Exception as exc:
152+
logger.warning(
153+
"Exception found when processing request %s",
154+
exc,
155+
exc_info=True,
156+
)
157+
response = {
158+
"status": "error",
159+
"description": str(exc),
160+
}
161+
self._send_response(client, response)
162+
163+
def update_state(self, request: Dict[str, Any]) -> None:
164+
worker_id = request["params"]["worker_id"]
165+
state = request["params"]["state"]
166+
self.state.workers_state[worker_id] = WorkerState(state)
167+
168+
169+
if __name__ == "__main__":
170+
logging.basicConfig(level=logging.INFO)
171+
ev = threading.Event()
172+
server = TaskiqInfoServer("127.0.0.1", 2332, ev)
173+
server.start()
174+
server.wait_started()
175+
server.join()

0 commit comments

Comments
 (0)