Coverage for core/data_source/factories/client_factory.py: 67.63%

139 statements  

« prev     ^ index     » next       coverage.py v7.10.7, created at 2025-10-13 18:58 +0000

1""" 

2统一数据源客户端工厂 

3负责所有券商客户端的统一管理和缓存 

4""" 

5 

6import threading 

7import time 

8from abc import ABC, abstractmethod 

9from typing import Any, Dict, Optional 

10 

11from core.data_source.factories.config_factory import (ConfigFactory, 

12 DataSourceType, 

13 UnifiedDataSourceConfig, 

14 unified_config_factory) 

15 

16 

17class DataSourceClientManager(ABC): 

18 """数据源客户端管理器抽象基类""" 

19 

20 @abstractmethod 

21 def create_client(self, config: Dict[str, Any]) -> Optional[Dict[str, Any]]: 

22 """创建客户端""" 

23 pass 

24 

25 @abstractmethod 

26 def validate_client(self, client: Dict[str, Any]) -> bool: 

27 """验证客户端是否有效""" 

28 pass 

29 

30 @abstractmethod 

31 def close_client(self, client: Dict[str, Any]): 

32 """关闭客户端""" 

33 pass 

34 

35 

36class LongPortClientManager(DataSourceClientManager): 

37 """长桥客户端管理器""" 

38 

39 def create_client(self, config: Dict[str, Any]) -> Optional[Dict[str, Any]]: 

40 """创建长桥客户端""" 

41 try: 

42 from longport.openapi import (Config, Language, QuoteContext, 

43 TradeContext) 

44 

45 # 转换语言设置 

46 language_str = config.get("language", "zh-CN") 

47 language_map = { 

48 "zh-CN": Language.ZH_CN, 

49 "zh-HK": Language.ZH_HK, 

50 "en": Language.EN, 

51 } 

52 language = language_map.get(language_str, Language.EN) 

53 

54 # 创建配置 

55 longport_config = Config( 

56 app_key=config["app_key"], 

57 app_secret=config["app_secret"], 

58 access_token=config["access_token"], 

59 language=language, 

60 enable_overnight=config.get("enable_overnight", False), 

61 ) 

62 

63 # 创建上下文 

64 quote_ctx = QuoteContext(longport_config) 

65 trade_ctx = TradeContext(longport_config) 

66 

67 return { 

68 "quote_ctx": quote_ctx, 

69 "trade_ctx": trade_ctx, 

70 "config": longport_config, 

71 "created_at": time.time(), 

72 "type": "longport", 

73 } 

74 

75 except Exception as e: 

76 print(f"❌ 创建长桥客户端失败: {e}") 

77 return None 

78 

79 def validate_client(self, client: Dict[str, Any]) -> bool: 

80 """验证长桥客户端是否有效""" 

81 try: 

82 # 尝试简单的健康检查 

83 quote_ctx = client.get("quote_ctx") 

84 if quote_ctx: 

85 # 这里可以添加更复杂的验证逻辑 

86 return True 

87 return False 

88 except Exception: 

89 return False 

90 

91 def close_client(self, client: Dict[str, Any]): 

92 """关闭长桥客户端""" 

93 # 长桥SDK会自动管理连接 

94 print(f"🔌 长桥客户端已关闭") 

95 

96 

97class FutuClientManager(DataSourceClientManager): 

98 """富途客户端管理器(预留)""" 

99 

100 def create_client(self, config: Dict[str, Any]) -> Optional[Dict[str, Any]]: 

101 print("🚧 富途客户端创建功能待实现") 

102 return None 

103 

104 def validate_client(self, client: Dict[str, Any]) -> bool: 

105 return False 

106 

107 def close_client(self, client: Dict[str, Any]): 

108 pass 

109 

110 

111class TigerClientManager(DataSourceClientManager): 

112 """老虎客户端管理器(预留)""" 

113 

114 def create_client(self, config: Dict[str, Any]) -> Optional[Dict[str, Any]]: 

115 print("🚧 老虎客户端创建功能待实现") 

116 return None 

117 

118 def validate_client(self, client: Dict[str, Any]) -> bool: 

119 return False 

120 

121 def close_client(self, client: Dict[str, Any]): 

122 pass 

123 

124 

125class ClientFactory: 

126 """统一数据源客户端工厂 - 单例模式""" 

127 

128 _instance = None 

129 _lock = threading.Lock() 

130 

131 def __new__(cls): 

132 if cls._instance is None: 

133 with cls._lock: 

134 if cls._instance is None: 

135 cls._instance = super().__new__(cls) 

136 cls._instance._initialized = False 

137 return cls._instance 

138 

139 def __init__(self): 

140 if not self._initialized: 

141 self._config_factory = unified_config_factory 

142 

143 # 客户端管理器:DataSourceType -> DataSourceClientManager 

144 self._client_managers: Dict[DataSourceType, DataSourceClientManager] = { 

145 DataSourceType.LONGPORT: LongPortClientManager(), 

146 DataSourceType.FUTU: FutuClientManager(), 

147 DataSourceType.TIGER: TigerClientManager(), 

148 } 

149 

150 # 客户端缓存:config_hash -> client 

151 self._client_cache: Dict[str, Dict[str, Any]] = {} 

152 

153 # 缓存元数据:config_hash -> {'last_used': timestamp, 'user_id': str} 

154 self._cache_metadata: Dict[str, Dict[str, Any]] = {} 

155 

156 self._max_clients = 100 # 最大客户端数量 

157 self._cache_ttl = 1800 # 30分钟缓存有效期 

158 

159 self._initialized = True 

160 

161 def get_client( 

162 self, user_id: str, force_refresh: bool = False 

163 ) -> Optional[Dict[str, Any]]: 

164 """ 

165 获取数据源客户端 

166 

167 Args: 

168 user_id: 用户ID 

169 force_refresh: 是否强制刷新 

170 

171 Returns: 

172 客户端实例,失败返回 None 

173 """ 

174 # 获取用户配置 

175 config = self._config_factory.get_data_source_config(user_id, force_refresh) 

176 if not config: 

177 print(f"❌ 无法获取用户 {user_id} 的数据源配置") 

178 return None 

179 

180 config_hash = config.config_hash 

181 current_time = time.time() 

182 

183 # 检查客户端缓存 

184 if not force_refresh and config_hash in self._client_cache: 

185 cached_client = self._client_cache[config_hash] 

186 cached_meta = self._cache_metadata[config_hash] 

187 

188 # 检查缓存是否过期 

189 if current_time - cached_meta["last_used"] < self._cache_ttl: 

190 # 验证客户端是否有效 

191 manager = self._client_managers.get(config.type) 

192 if manager and manager.validate_client(cached_client): 

193 cached_meta["last_used"] = current_time 

194 print(f"♻️ 复用数据源客户端,用户: {user_id}, 类型: {config.type}") 

195 return cached_client 

196 else: 

197 # 客户端无效,移除缓存 

198 self._remove_client_from_cache(config_hash) 

199 print(f"⚠️ 缓存客户端无效,用户: {user_id}") 

200 

201 # 创建新客户端 

202 manager = self._client_managers.get(config.type) 

203 if not manager: 

204 print(f"❌ 不支持的数据源类型: {config.type}") 

205 return None 

206 

207 # 强制执行缓存限制 

208 self._enforce_cache_limit() 

209 

210 # 创建新客户端 

211 print(f"🆕 创建新的数据源客户端,用户: {user_id}, 类型: {config.type}") 

212 client = manager.create_client(config.config) 

213 

214 if client: 

215 # 缓存客户端 

216 client["user_id"] = user_id 

217 client["config_hash"] = config_hash 

218 

219 self._client_cache[config_hash] = client 

220 self._cache_metadata[config_hash] = { 

221 "last_used": current_time, 

222 "user_id": user_id, 

223 } 

224 

225 print(f"✅ 数据源客户端创建成功并缓存,类型: {config.type}") 

226 return client 

227 

228 print(f"❌ 数据源客户端创建失败,用户: {user_id}") 

229 return None 

230 

231 def _remove_client_from_cache(self, config_hash: str): 

232 """从缓存中移除客户端""" 

233 if config_hash in self._client_cache: 

234 client = self._client_cache[config_hash] 

235 

236 # 如果客户端有类型信息,关闭客户端 

237 client_type = client.get("type") 

238 if client_type == "longport": 

239 manager = self._client_managers[DataSourceType.LONGPORT] 

240 manager.close_client(client) 

241 

242 del self._client_cache[config_hash] 

243 if config_hash in self._cache_metadata: 

244 del self._cache_metadata[config_hash] 

245 

246 def _enforce_cache_limit(self): 

247 """强制执行缓存限制(LRU淘汰)""" 

248 if len(self._client_cache) >= self._max_clients: 

249 # 找到最久未使用的客户端 

250 oldest_hash = min( 

251 self._cache_metadata.keys(), 

252 key=lambda k: self._cache_metadata[k]["last_used"], 

253 ) 

254 

255 user_id = self._cache_metadata[oldest_hash]["user_id"] 

256 print(f"🗑️ LRU淘汰数据源客户端,用户: {user_id}") 

257 

258 self._remove_client_from_cache(oldest_hash) 

259 

260 def refresh_client(self, user_id: str) -> Optional[Dict[str, Any]]: 

261 """强制刷新客户端""" 

262 return self.get_client(user_id, force_refresh=True) 

263 

264 def clear_cache(self, user_id: Optional[str] = None): 

265 """清理客户端缓存""" 

266 if user_id: 

267 # 清理特定用户的缓存 

268 keys_to_remove = [ 

269 hash_key 

270 for hash_key, meta in self._cache_metadata.items() 

271 if meta.get("user_id") == user_id 

272 ] 

273 

274 for key in keys_to_remove: 

275 self._remove_client_from_cache(key) 

276 

277 if keys_to_remove: 

278 print(f"🗑️ 已清理用户 {user_id} 的客户端缓存") 

279 else: 

280 # 清理所有缓存 

281 for config_hash in list(self._client_cache.keys()): 

282 self._remove_client_from_cache(config_hash) 

283 print("🗑️ 已清理所有客户端缓存") 

284 

285 def get_cache_info(self) -> Dict[str, Any]: 

286 """获取缓存信息(用于调试)""" 

287 return { 

288 "cached_clients": len(self._client_cache), 

289 "max_clients": self._max_clients, 

290 "cache_usage_percent": (len(self._client_cache) / self._max_clients) * 100, 

291 "supported_types": [t.value for t in self._client_managers.keys()], 

292 } 

293 

294 

295# 全局单例实例 

296unified_client_factory = ClientFactory()