Coverage for core/services/websocket_service_factory.py: 45.34%

161 statements  

« prev     ^ index     » next       coverage.py v7.10.7, created at 2025-10-13 18:58 +0000

1""" 

2WebSocket服务工厂 

3为每个任务/会话创建独立的WebSocket服务实例 

4支持任何类型的ID作为key(量化交易会话ID、数据导入任务ID、API测试ID等) 

5""" 

6 

7import asyncio 

8import json 

9import logging 

10from datetime import datetime 

11from typing import Dict, Optional 

12 

13from fastapi import WebSocket 

14 

15logger = logging.getLogger(__name__) 

16 

17 

18class TaskWebSocketService: 

19 """任务专用的WebSocket服务""" 

20 

21 def __init__(self, task_id: str): 

22 self.task_id = task_id 

23 self.websocket: Optional[WebSocket] = None 

24 self.is_connected = False 

25 self.log_queue = asyncio.Queue() 

26 self.log_sequence = 0 

27 self._processing_task: Optional[asyncio.Task] = None 

28 self._start_processing() 

29 

30 def _start_processing(self): 

31 """启动日志处理任务""" 

32 if self._processing_task is None or self._processing_task.done(): 

33 self._processing_task = asyncio.create_task(self._process_log_queue()) 

34 

35 async def _process_log_queue(self): 

36 """处理日志队列""" 

37 while True: 

38 try: 

39 # 从队列获取日志消息 

40 message, log_type = await asyncio.wait_for( 

41 self.log_queue.get(), timeout=1.0 

42 ) 

43 

44 # 发送日志 

45 await self._send_log_async(message, log_type) 

46 

47 except asyncio.TimeoutError: 

48 continue 

49 except Exception as e: 

50 logger.error(f"处理日志队列失败: {e}") 

51 await asyncio.sleep(1) 

52 

53 async def _send_log_async(self, message: str, log_type: str = "log"): 

54 """异步发送日志""" 

55 if not self.is_connected or not self.websocket: 

56 return 

57 

58 try: 

59 # 添加时间戳和序号 

60 timestamp = datetime.now().isoformat(timespec="milliseconds") 

61 log_message = f"[{timestamp}] {message}" 

62 

63 # 发送到WebSocket 

64 await self.websocket.send_text( 

65 json.dumps( 

66 { 

67 "type": log_type, 

68 "message": log_message, 

69 "timestamp": timestamp, 

70 "task_id": self.task_id, 

71 "sequence": self.log_sequence, 

72 } 

73 ) 

74 ) 

75 

76 self.log_sequence += 1 

77 

78 except Exception as e: 

79 logger.error(f"WebSocket发送日志失败: {e}") 

80 self.is_connected = False 

81 

82 async def connect(self, websocket: WebSocket): 

83 """建立WebSocket连接""" 

84 await websocket.accept() 

85 self.websocket = websocket 

86 self.is_connected = True 

87 logger.info(f"WebSocket连接已建立,任务ID: {self.task_id}") 

88 

89 async def disconnect(self): 

90 """断开WebSocket连接""" 

91 self.is_connected = False 

92 if self.websocket: 

93 try: 

94 await self.websocket.close() 

95 except Exception: 

96 pass 

97 self.websocket = None 

98 

99 # 停止处理任务 

100 if self._processing_task and not self._processing_task.done(): 

101 self._processing_task.cancel() 

102 try: 

103 await self._processing_task 

104 except asyncio.CancelledError: 

105 pass 

106 

107 logger.info(f"WebSocket连接已断开,任务ID: {self.task_id}") 

108 

109 def send_log(self, message: str, log_type: str = "log"): 

110 """发送日志消息(线程安全)""" 

111 try: 

112 # 将日志消息放入队列 

113 # 检查是否有运行中的事件循环 

114 try: 

115 loop = asyncio.get_running_loop() 

116 # 如果有运行中的事件循环,创建任务 

117 asyncio.create_task(self.log_queue.put((message, log_type))) 

118 except RuntimeError: 

119 # 如果没有运行中的事件循环,创建新的 

120 loop = asyncio.new_event_loop() 

121 asyncio.set_event_loop(loop) 

122 loop.run_until_complete(self.log_queue.put((message, log_type))) 

123 loop.close() 

124 except Exception as e: 

125 logger.error(f"发送日志到队列失败: {e}") 

126 

127 def send_status(self, status: str, progress: int, data_count: int = 0): 

128 """发送状态更新""" 

129 status_message = f"状态: {status}, 进度: {progress}%, 数据量: {data_count}" 

130 self.send_log(status_message, "status") 

131 

132 def send_error(self, error_message: str): 

133 """发送错误消息""" 

134 self.send_log(f"错误: {error_message}", "error") 

135 

136 def send_success(self, success_message: str): 

137 """发送成功消息""" 

138 self.send_log(f"成功: {success_message}", "success") 

139 

140 

141class WebSocketServiceFactory: 

142 """WebSocket服务工厂 - 支持任何类型的任务ID""" 

143 

144 def __init__(self): 

145 self._services: Dict[str, TaskWebSocketService] = {} 

146 self._lock = asyncio.Lock() 

147 

148 async def get_or_create_service(self, task_id: str) -> TaskWebSocketService: 

149 """获取或创建WebSocket服务实例""" 

150 async with self._lock: 

151 if task_id not in self._services: 

152 self._services[task_id] = TaskWebSocketService(task_id) 

153 logger.info(f"创建新的WebSocket服务实例: {task_id}") 

154 return self._services[task_id] 

155 

156 async def remove_service(self, task_id: str): 

157 """移除WebSocket服务实例""" 

158 async with self._lock: 

159 if task_id in self._services: 

160 service = self._services[task_id] 

161 await service.disconnect() 

162 del self._services[task_id] 

163 logger.info(f"移除WebSocket服务实例: {task_id}") 

164 

165 def get_service(self, task_id: str) -> Optional[TaskWebSocketService]: 

166 """获取WebSocket服务实例(不创建)""" 

167 return self._services.get(task_id) 

168 

169 def has_service(self, task_id: str) -> bool: 

170 """检查是否存在WebSocket服务实例""" 

171 return task_id in self._services 

172 

173 async def cleanup_all(self): 

174 """清理所有服务实例""" 

175 async with self._lock: 

176 for task_id in list(self._services.keys()): 

177 await self.remove_service(task_id) 

178 

179 def get_active_tasks(self) -> list: 

180 """获取所有活跃的任务ID""" 

181 return list(self._services.keys()) 

182 

183 def get_service_count(self) -> int: 

184 """获取当前服务实例数量""" 

185 return len(self._services) 

186 

187 

188# 全局WebSocket服务工厂实例 

189websocket_service_factory = WebSocketServiceFactory() 

190 

191 

192# 便捷函数 

193async def get_websocket_service(task_id: str) -> TaskWebSocketService: 

194 """获取WebSocket服务实例""" 

195 return await websocket_service_factory.get_or_create_service(task_id) 

196 

197 

198def send_websocket_log(task_id: str, message: str, log_type: str = "log"): 

199 """发送WebSocket日志的便捷函数""" 

200 service = websocket_service_factory.get_service(task_id) 

201 if service: 

202 service.send_log(message, log_type) 

203 else: 

204 # 如果服务不存在,尝试创建(用于日志缓存) 

205 try: 

206 import asyncio 

207 

208 # 检查是否已经有运行中的事件循环 

209 try: 

210 loop = asyncio.get_running_loop() 

211 # 如果有运行中的事件循环,使用线程池执行 

212 import concurrent.futures 

213 

214 with concurrent.futures.ThreadPoolExecutor() as executor: 

215 future = executor.submit( 

216 _create_service_sync, task_id, message, log_type 

217 ) 

218 future.result(timeout=1) # 1秒超时 

219 except RuntimeError: 

220 # 如果没有运行中的事件循环,创建新的 

221 loop = asyncio.new_event_loop() 

222 asyncio.set_event_loop(loop) 

223 service = loop.run_until_complete( 

224 websocket_service_factory.get_or_create_service(task_id) 

225 ) 

226 service.send_log(message, log_type) 

227 loop.close() 

228 

229 except Exception as e: 

230 logger.warning(f"创建WebSocket服务失败: {task_id} - {e}") 

231 # 如果创建失败,至少记录日志 

232 logger.info(f"[{task_id}] {message}") 

233 

234 

235def _create_service_sync(task_id: str, message: str, log_type: str): 

236 """同步创建服务的辅助函数""" 

237 import asyncio 

238 

239 loop = asyncio.new_event_loop() 

240 asyncio.set_event_loop(loop) 

241 try: 

242 service = loop.run_until_complete( 

243 websocket_service_factory.get_or_create_service(task_id) 

244 ) 

245 service.send_log(message, log_type) 

246 finally: 

247 loop.close() 

248 

249 

250async def websocket_endpoint(websocket: WebSocket, task_id: str): 

251 """WebSocket端点 - 支持任何类型的任务ID""" 

252 logger.info(f"WebSocket端点被调用: {task_id}") 

253 

254 # 获取或创建服务实例 

255 service = await get_websocket_service(task_id) 

256 

257 try: 

258 # 建立连接 

259 await service.connect(websocket) 

260 

261 # 保持连接活跃 

262 while True: 

263 try: 

264 # 等待客户端发送消息 

265 await asyncio.wait_for(websocket.receive_text(), timeout=60.0) 

266 except asyncio.TimeoutError: 

267 # 超时后发送ping消息保持连接 

268 try: 

269 await websocket.send_text( 

270 json.dumps( 

271 { 

272 "type": "ping", 

273 "timestamp": datetime.now().isoformat(), 

274 "task_id": task_id, 

275 } 

276 ) 

277 ) 

278 logger.debug(f"发送ping消息到 {task_id}") 

279 except Exception as e: 

280 logger.error(f"发送ping失败: {e}") 

281 break 

282 except Exception as e: 

283 logger.error(f"WebSocket错误: {e}") 

284 break 

285 

286 except Exception as e: 

287 logger.error(f"WebSocket端点错误: {e}") 

288 finally: 

289 # 断开连接 

290 await service.disconnect()