forked from OpenBMB/XAgent
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmanager.py
77 lines (57 loc) · 2.81 KB
/
manager.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
from __future__ import annotations
import asyncio
import os
import threading
from typing import Dict, List
from fastapi import WebSocket, WebSocketDisconnect, status
from XAgentIO.exception import (XAgentIOWebSocketReceiveError,
XAgentIOWebSocketSendError)
from XAgentServer.envs import XAgentServerEnv
from XAgentServer.loggers.logs import Logger
from XAgentServer.response_body import WebsocketResponseBody
class Singleton(type):
    """Metaclass that gives each class using it a single shared instance.

    The first ``SomeClass(...)`` call constructs the instance and caches it,
    keyed by the class. Every later call returns that cached instance and
    ignores its arguments; ``__init__`` does not run again.
    """

    # Cache shared by every class using this metaclass: class -> instance.
    _instances = {}

    def __call__(cls, *args, **kwargs):
        cache = cls._instances
        if cls in cache:
            return cache[cls]
        instance = super().__call__(*args, **kwargs)
        cache[cls] = instance
        return instance
class WebSocketConnectionManager(metaclass=Singleton):
    """Track active websocket connections and periodically send them a pong.

    Each connection is stored in ``active_connections`` as a single-entry
    ``{websocket_id: websocket}`` dict. Constructing the manager (only once,
    because of the ``Singleton`` metaclass) starts a daemon thread. That
    thread sends a "pong" message to every active connection every
    ``PONG_INTERVAL`` seconds and drops any connection whose send fails.
    """

    # Seconds to wait between two pong broadcast rounds.
    PONG_INTERVAL: float = 20

    def __init__(self):
        self.active_connections: List[Dict[str, WebSocket]] = []
        self.logger = Logger(log_dir=os.path.join(XAgentServerEnv.base_dir, "logs"), log_file="websocket.log")
        self.create_pong_task()

    async def connect(self, websocket: WebSocket, websocket_id: str):
        """Accept *websocket* and register it under *websocket_id*."""
        await websocket.accept()
        self.logger.info(f"websocket {websocket_id} connected")
        self.active_connections.append({websocket_id: websocket})

    async def disconnect(self, websocket_id: str, websocket: WebSocket):
        """Unregister the ``websocket_id`` / ``websocket`` pair.

        The pair may already be gone, because ``broadcast_pong`` prunes
        connections whose send failed. That is tolerated here instead of
        raising ``ValueError``.
        """
        self._discard({websocket_id: websocket})
        self.logger.info(f"websocket {websocket_id} remove from active connections")

    def _discard(self, connection: Dict[str, WebSocket]) -> bool:
        """Remove *connection* from the active list; return False if it was absent."""
        try:
            self.active_connections.remove(connection)
        except ValueError:
            return False
        return True

    def is_connected(self, websocket_id: str) -> bool:
        """Return True if a connection is registered under *websocket_id*."""
        return any(websocket_id in connection for connection in self.active_connections)

    def get_connection(self, websocket_id: str) -> WebSocket | None:
        """Return the websocket registered under *websocket_id*, or None."""
        for connection in self.active_connections:
            if websocket_id in connection:
                return connection[websocket_id]
        return None

    async def broadcast_pong(self):
        """Forever: send a pong to every active connection, then sleep.

        A connection whose send raises is logged and removed from
        ``active_connections``.
        """
        while True:
            self.logger.info(f"pong broadcast for active connections: {len(self.active_connections)}")
            # Iterate over a snapshot. Failed connections are removed below,
            # and removing items from the list being iterated would skip the
            # connection that follows each removed one.
            for connection in list(self.active_connections):
                for websocket_id, websocket in connection.items():
                    try:
                        await websocket.send_text(WebsocketResponseBody(status="pong", data={"type": "pong"}, message="pong").to_text())
                    except Exception:
                        # Best-effort keepalive: any send failure drops the connection.
                        self.logger.error(f"websocket {websocket_id} is disconnected")
                        # disconnect() may already have removed it; _discard tolerates that.
                        self._discard(connection)
            await asyncio.sleep(self.PONG_INTERVAL)

    def loop_pong(self):
        """Thread target: run ``broadcast_pong`` on this thread's own event loop."""
        # NOTE(review): asyncio.run creates a new event loop in this thread,
        # while the websockets were accepted on the server's loop. Driving
        # their sends from a different loop/thread is presumably not safe for
        # the ASGI server. Confirm, and consider scheduling the broadcast on
        # the server loop instead.
        asyncio.run(self.broadcast_pong())

    def create_pong_task(self):
        """Start the pong broadcast in a daemon thread."""
        self.logger.info("Create task for pong broadcast")
        # daemon=True so this never-ending thread does not block interpreter exit.
        pong = threading.Thread(target=self.loop_pong, daemon=True)
        pong.start()