Coverage for core/services/websocket_service_factory.py: 45.34%
161 statements
« prev ^ index » next coverage.py v7.10.7, created at 2025-10-13 18:58 +0000
« prev ^ index » next coverage.py v7.10.7, created at 2025-10-13 18:58 +0000
1"""
2WebSocket服务工厂
3为每个任务/会话创建独立的WebSocket服务实例
4支持任何类型的ID作为key(量化交易会话ID、数据导入任务ID、API测试ID等)
5"""
7import asyncio
8import json
9import logging
10from datetime import datetime
11from typing import Dict, Optional
13from fastapi import WebSocket
15logger = logging.getLogger(__name__)
18class TaskWebSocketService:
19 """任务专用的WebSocket服务"""
21 def __init__(self, task_id: str):
22 self.task_id = task_id
23 self.websocket: Optional[WebSocket] = None
24 self.is_connected = False
25 self.log_queue = asyncio.Queue()
26 self.log_sequence = 0
27 self._processing_task: Optional[asyncio.Task] = None
28 self._start_processing()
30 def _start_processing(self):
31 """启动日志处理任务"""
32 if self._processing_task is None or self._processing_task.done():
33 self._processing_task = asyncio.create_task(self._process_log_queue())
35 async def _process_log_queue(self):
36 """处理日志队列"""
37 while True:
38 try:
39 # 从队列获取日志消息
40 message, log_type = await asyncio.wait_for(
41 self.log_queue.get(), timeout=1.0
42 )
44 # 发送日志
45 await self._send_log_async(message, log_type)
47 except asyncio.TimeoutError:
48 continue
49 except Exception as e:
50 logger.error(f"处理日志队列失败: {e}")
51 await asyncio.sleep(1)
53 async def _send_log_async(self, message: str, log_type: str = "log"):
54 """异步发送日志"""
55 if not self.is_connected or not self.websocket:
56 return
58 try:
59 # 添加时间戳和序号
60 timestamp = datetime.now().isoformat(timespec="milliseconds")
61 log_message = f"[{timestamp}] {message}"
63 # 发送到WebSocket
64 await self.websocket.send_text(
65 json.dumps(
66 {
67 "type": log_type,
68 "message": log_message,
69 "timestamp": timestamp,
70 "task_id": self.task_id,
71 "sequence": self.log_sequence,
72 }
73 )
74 )
76 self.log_sequence += 1
78 except Exception as e:
79 logger.error(f"WebSocket发送日志失败: {e}")
80 self.is_connected = False
82 async def connect(self, websocket: WebSocket):
83 """建立WebSocket连接"""
84 await websocket.accept()
85 self.websocket = websocket
86 self.is_connected = True
87 logger.info(f"WebSocket连接已建立,任务ID: {self.task_id}")
89 async def disconnect(self):
90 """断开WebSocket连接"""
91 self.is_connected = False
92 if self.websocket:
93 try:
94 await self.websocket.close()
95 except Exception:
96 pass
97 self.websocket = None
99 # 停止处理任务
100 if self._processing_task and not self._processing_task.done():
101 self._processing_task.cancel()
102 try:
103 await self._processing_task
104 except asyncio.CancelledError:
105 pass
107 logger.info(f"WebSocket连接已断开,任务ID: {self.task_id}")
109 def send_log(self, message: str, log_type: str = "log"):
110 """发送日志消息(线程安全)"""
111 try:
112 # 将日志消息放入队列
113 # 检查是否有运行中的事件循环
114 try:
115 loop = asyncio.get_running_loop()
116 # 如果有运行中的事件循环,创建任务
117 asyncio.create_task(self.log_queue.put((message, log_type)))
118 except RuntimeError:
119 # 如果没有运行中的事件循环,创建新的
120 loop = asyncio.new_event_loop()
121 asyncio.set_event_loop(loop)
122 loop.run_until_complete(self.log_queue.put((message, log_type)))
123 loop.close()
124 except Exception as e:
125 logger.error(f"发送日志到队列失败: {e}")
127 def send_status(self, status: str, progress: int, data_count: int = 0):
128 """发送状态更新"""
129 status_message = f"状态: {status}, 进度: {progress}%, 数据量: {data_count}"
130 self.send_log(status_message, "status")
132 def send_error(self, error_message: str):
133 """发送错误消息"""
134 self.send_log(f"错误: {error_message}", "error")
136 def send_success(self, success_message: str):
137 """发送成功消息"""
138 self.send_log(f"成功: {success_message}", "success")
141class WebSocketServiceFactory:
142 """WebSocket服务工厂 - 支持任何类型的任务ID"""
144 def __init__(self):
145 self._services: Dict[str, TaskWebSocketService] = {}
146 self._lock = asyncio.Lock()
148 async def get_or_create_service(self, task_id: str) -> TaskWebSocketService:
149 """获取或创建WebSocket服务实例"""
150 async with self._lock:
151 if task_id not in self._services:
152 self._services[task_id] = TaskWebSocketService(task_id)
153 logger.info(f"创建新的WebSocket服务实例: {task_id}")
154 return self._services[task_id]
156 async def remove_service(self, task_id: str):
157 """移除WebSocket服务实例"""
158 async with self._lock:
159 if task_id in self._services:
160 service = self._services[task_id]
161 await service.disconnect()
162 del self._services[task_id]
163 logger.info(f"移除WebSocket服务实例: {task_id}")
165 def get_service(self, task_id: str) -> Optional[TaskWebSocketService]:
166 """获取WebSocket服务实例(不创建)"""
167 return self._services.get(task_id)
169 def has_service(self, task_id: str) -> bool:
170 """检查是否存在WebSocket服务实例"""
171 return task_id in self._services
173 async def cleanup_all(self):
174 """清理所有服务实例"""
175 async with self._lock:
176 for task_id in list(self._services.keys()):
177 await self.remove_service(task_id)
179 def get_active_tasks(self) -> list:
180 """获取所有活跃的任务ID"""
181 return list(self._services.keys())
183 def get_service_count(self) -> int:
184 """获取当前服务实例数量"""
185 return len(self._services)
188# 全局WebSocket服务工厂实例
189websocket_service_factory = WebSocketServiceFactory()
192# 便捷函数
193async def get_websocket_service(task_id: str) -> TaskWebSocketService:
194 """获取WebSocket服务实例"""
195 return await websocket_service_factory.get_or_create_service(task_id)
198def send_websocket_log(task_id: str, message: str, log_type: str = "log"):
199 """发送WebSocket日志的便捷函数"""
200 service = websocket_service_factory.get_service(task_id)
201 if service:
202 service.send_log(message, log_type)
203 else:
204 # 如果服务不存在,尝试创建(用于日志缓存)
205 try:
206 import asyncio
208 # 检查是否已经有运行中的事件循环
209 try:
210 loop = asyncio.get_running_loop()
211 # 如果有运行中的事件循环,使用线程池执行
212 import concurrent.futures
214 with concurrent.futures.ThreadPoolExecutor() as executor:
215 future = executor.submit(
216 _create_service_sync, task_id, message, log_type
217 )
218 future.result(timeout=1) # 1秒超时
219 except RuntimeError:
220 # 如果没有运行中的事件循环,创建新的
221 loop = asyncio.new_event_loop()
222 asyncio.set_event_loop(loop)
223 service = loop.run_until_complete(
224 websocket_service_factory.get_or_create_service(task_id)
225 )
226 service.send_log(message, log_type)
227 loop.close()
229 except Exception as e:
230 logger.warning(f"创建WebSocket服务失败: {task_id} - {e}")
231 # 如果创建失败,至少记录日志
232 logger.info(f"[{task_id}] {message}")
235def _create_service_sync(task_id: str, message: str, log_type: str):
236 """同步创建服务的辅助函数"""
237 import asyncio
239 loop = asyncio.new_event_loop()
240 asyncio.set_event_loop(loop)
241 try:
242 service = loop.run_until_complete(
243 websocket_service_factory.get_or_create_service(task_id)
244 )
245 service.send_log(message, log_type)
246 finally:
247 loop.close()
250async def websocket_endpoint(websocket: WebSocket, task_id: str):
251 """WebSocket端点 - 支持任何类型的任务ID"""
252 logger.info(f"WebSocket端点被调用: {task_id}")
254 # 获取或创建服务实例
255 service = await get_websocket_service(task_id)
257 try:
258 # 建立连接
259 await service.connect(websocket)
261 # 保持连接活跃
262 while True:
263 try:
264 # 等待客户端发送消息
265 await asyncio.wait_for(websocket.receive_text(), timeout=60.0)
266 except asyncio.TimeoutError:
267 # 超时后发送ping消息保持连接
268 try:
269 await websocket.send_text(
270 json.dumps(
271 {
272 "type": "ping",
273 "timestamp": datetime.now().isoformat(),
274 "task_id": task_id,
275 }
276 )
277 )
278 logger.debug(f"发送ping消息到 {task_id}")
279 except Exception as e:
280 logger.error(f"发送ping失败: {e}")
281 break
282 except Exception as e:
283 logger.error(f"WebSocket错误: {e}")
284 break
286 except Exception as e:
287 logger.error(f"WebSocket端点错误: {e}")
288 finally:
289 # 断开连接
290 await service.disconnect()