Coverage for core/repositories/broker_repository.py: 18.24%

170 statements  

« prev     ^ index     » next       coverage.py v7.10.7, created at 2025-10-13 18:58 +0000

1""" 

2券商数据仓库 

3""" 

4 

5import uuid 

6from datetime import datetime 

7from typing import List, Optional 

8 

9from core.models.broker import (Broker, BrokerCreate, BrokerUpdate, DataSource, 

10 DataSourceCreate, DataSourceUpdate, FeeConfig) 

11from infrastructure.database.redis_client import get_redis 

12 

13 

class BrokerRepository:
    """Redis-backed repository for brokers, their data sources and fee configs.

    Keys written by this class:

    * ``broker:{broker_id}`` - serialized broker record (30-day TTL)
    * ``broker_by_code:{code}`` - broker ID looked up by broker code (30-day TTL)
    * ``brokers`` - set of all broker IDs (no TTL is set here)
    * ``data_source:{data_source_id}`` - serialized data source record (30-day TTL)
    * ``data_sources`` - set of all data source IDs (no TTL is set here)
    * ``broker:fee_config:{broker_id}`` - cached copy of a broker's fee config (1-hour TTL)
    * ``trading:session:fee_config:{session_id}`` - per-session fee config (24-hour TTL)

    Records expire, but the ID sets never get a TTL here. The sets can
    therefore hold IDs whose record is gone; the ``get_all_*`` methods skip
    those IDs.

    NOTE(review): values are passed to ``setex`` as dicts, and ``get`` is
    expected to return dicts. The client from ``get_redis()`` presumably
    (de)serializes values itself; confirm in
    ``infrastructure.database.redis_client``.

    NOTE(review): timestamps use naive ``datetime.utcnow()``, which is
    deprecated since Python 3.12. Switching to aware datetimes would change
    the stored ISO strings, so it is left as-is.
    """

    # TTL for broker / data-source records and the broker-code index: 30 days.
    _RECORD_TTL_SECONDS = 86400 * 30
    # TTL for the per-broker fee-config cache: 1 hour.
    _FEE_CONFIG_CACHE_TTL_SECONDS = 3600
    # TTL for per-trading-session fee configs: 24 hours.
    _SESSION_FEE_CONFIG_TTL_SECONDS = 86400

    def __init__(self):
        self.redis = get_redis()
        # Redis sets that index every stored broker / data source ID.
        self.brokers_key = "brokers"
        self.data_sources_key = "data_sources"

    # ------------------------------------------------------------------
    # Key helpers
    # ------------------------------------------------------------------
    def _get_broker_key(self, broker_id: str) -> str:
        """Return the key that stores a broker record."""
        return f"broker:{broker_id}"

    def _get_data_source_key(self, data_source_id: str) -> str:
        """Return the key that stores a data source record."""
        return f"data_source:{data_source_id}"

    def _get_broker_by_code_key(self, code: str) -> str:
        """Return the key of the broker-code -> broker-ID index entry."""
        return f"broker_by_code:{code}"

    def _get_fee_config_key(self, broker_id: str) -> str:
        """Return the key of the cached fee config for a broker."""
        return f"broker:fee_config:{broker_id}"

    def _get_session_fee_config_key(self, session_id: str) -> str:
        """Return the key of the fee config cached for a trading session."""
        return f"trading:session:fee_config:{session_id}"

    # ------------------------------------------------------------------
    # Serialization / persistence helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _dump_with_iso_timestamps(model) -> dict:
        """Dump a model to a dict with ``created_at``/``updated_at`` as ISO strings."""
        data = model.model_dump()
        data["created_at"] = data["created_at"].isoformat()
        data["updated_at"] = data["updated_at"].isoformat()
        return data

    @staticmethod
    def _parse_iso_timestamps(data: dict) -> dict:
        """Convert ``created_at``/``updated_at`` ISO strings back to datetimes, in place."""
        data["created_at"] = datetime.fromisoformat(data["created_at"])
        data["updated_at"] = datetime.fromisoformat(data["updated_at"])
        return data

    @staticmethod
    def _apply_update(target, update) -> None:
        """Copy every field explicitly set on *update* onto *target*.

        Values are read as attributes rather than through ``model_dump()``.
        ``model_dump()`` turns nested models (e.g. a broker's ``config``)
        into plain dicts, which breaks later ``.model_dump()`` calls on them.
        """
        for field in update.model_fields_set:
            setattr(target, field, getattr(update, field))

    def _save_broker(self, broker: Broker) -> None:
        """Write a broker record and its code index, restarting both TTLs.

        Both keys are written together so the code index never expires
        before the record it points at.
        """
        broker_dict = self._dump_with_iso_timestamps(broker)
        # Dump the nested config explicitly from the instance. The guard
        # stops a broker without a config from raising here.
        if broker.config is not None:
            broker_dict["config"] = broker.config.model_dump()
        self.redis.setex(
            self._get_broker_key(broker.id), self._RECORD_TTL_SECONDS, broker_dict
        )
        self.redis.setex(
            self._get_broker_by_code_key(broker.code),
            self._RECORD_TTL_SECONDS,
            broker.id,
        )

    def _remove_code_index(self, code: str, broker_id: str) -> None:
        """Delete the index entry for *code* only if it still points at *broker_id*.

        This repository does not enforce unique codes, so another broker may
        have taken over the entry. Deleting it unconditionally would orphan
        that broker. NOTE(review): the read-then-delete is not atomic.
        """
        code_key = self._get_broker_by_code_key(code)
        if self.redis.get(code_key) == broker_id:
            self.redis.delete(code_key)

    def _save_data_source(self, data_source: DataSource) -> None:
        """Write a data source record, restarting its TTL."""
        self.redis.setex(
            self._get_data_source_key(data_source.id),
            self._RECORD_TTL_SECONDS,
            self._dump_with_iso_timestamps(data_source),
        )

    # ------------------------------------------------------------------
    # Broker management
    # ------------------------------------------------------------------
    def create_broker(self, broker_data: BrokerCreate) -> Broker:
        """Create and store a broker with a freshly generated UUID.

        Args:
            broker_data: Fields of the new broker.

        Returns:
            Broker: The stored broker.
        """
        now = datetime.utcnow()
        broker = Broker(
            id=str(uuid.uuid4()),
            name=broker_data.name,
            code=broker_data.code,
            type=broker_data.type,
            status=broker_data.status,
            config=broker_data.config,
            created_at=now,
            updated_at=now,
        )
        self._save_broker(broker)
        self.redis.sadd(self.brokers_key, broker.id)
        return broker

    def get_broker_by_id(self, broker_id: str) -> Optional[Broker]:
        """Return the broker with *broker_id*, or None if missing/expired."""
        data = self.redis.get(self._get_broker_key(broker_id))
        if not (data and isinstance(data, dict)):
            return None
        return Broker(**self._parse_iso_timestamps(data))

    def get_broker_by_code(self, code: str) -> Optional[Broker]:
        """Return the broker indexed under *code*, or None."""
        broker_id = self.redis.get(self._get_broker_by_code_key(code))
        if not broker_id:
            return None
        return self.get_broker_by_id(broker_id)

    def get_all_brokers(self) -> List[Broker]:
        """Return every broker whose record still exists."""
        brokers = []
        for broker_id in self.redis.smembers(self.brokers_key):
            broker = self.get_broker_by_id(broker_id)
            if broker:
                brokers.append(broker)
        return brokers

    def get_brokers_by_user(self, user_id: str) -> List[Broker]:
        """Return the brokers whose ``user_id`` equals *user_id*.

        NOTE(review): ``create_broker`` never sets ``user_id``. This relies
        on ``Broker`` declaring the field; confirm against the model.
        """
        return [broker for broker in self.get_all_brokers() if broker.user_id == user_id]

    def update_broker(
        self, broker_id: str, broker_data: BrokerUpdate
    ) -> Optional[Broker]:
        """Apply the explicitly-set fields of *broker_data* to a broker.

        The code index is kept consistent when ``code`` changes. The cached
        fee config is invalidated when ``config`` is replaced.

        Args:
            broker_id: ID of the broker to update.
            broker_data: Update payload; only fields that were explicitly
                set are applied.

        Returns:
            Broker: The updated broker, or None if it does not exist.
        """
        broker = self.get_broker_by_id(broker_id)
        if not broker:
            return None

        previous_code = broker.code
        self._apply_update(broker, broker_data)
        broker.updated_at = datetime.utcnow()

        self._save_broker(broker)  # also (re)indexes the current code
        if broker.code != previous_code:
            # Drop the stale entry so the old code no longer resolves here.
            self._remove_code_index(previous_code, broker_id)
        if "config" in broker_data.model_fields_set:
            # The fee config lives inside `config`; its cached copy is now stale.
            self.redis.delete(self._get_fee_config_key(broker_id))

        return broker

    def delete_broker(self, broker_id: str) -> bool:
        """Delete a broker, its index entries, fee cache and data sources.

        Returns:
            bool: True if the broker existed and was deleted, else False.
        """
        broker = self.get_broker_by_id(broker_id)
        if not broker:
            return False

        self.redis.delete(self._get_broker_key(broker_id))
        self.redis.srem(self.brokers_key, broker_id)
        self._remove_code_index(broker.code, broker_id)
        self.redis.delete(self._get_fee_config_key(broker_id))
        self._delete_broker_data_sources(broker_id)
        return True

    def _delete_broker_data_sources(self, broker_id: str):
        """Delete every data source that belongs to *broker_id*."""
        for data_source in self.get_data_sources_by_broker(broker_id):
            self.delete_data_source(data_source.id)

    # ------------------------------------------------------------------
    # Data source management
    # ------------------------------------------------------------------
    def create_data_source(self, data_source_data: DataSourceCreate) -> DataSource:
        """Create and store a data source with a freshly generated UUID.

        Args:
            data_source_data: Fields of the new data source.

        Returns:
            DataSource: The stored data source.
        """
        now = datetime.utcnow()
        data_source = DataSource(
            id=str(uuid.uuid4()),
            name=data_source_data.name,
            type=data_source_data.type,
            broker_id=data_source_data.broker_id,
            status=data_source_data.status,
            priority=data_source_data.priority,
            created_at=now,
            updated_at=now,
        )
        self._save_data_source(data_source)
        self.redis.sadd(self.data_sources_key, data_source.id)
        return data_source

    def get_data_source_by_id(self, data_source_id: str) -> Optional[DataSource]:
        """Return the data source with *data_source_id*, or None if missing/expired."""
        data = self.redis.get(self._get_data_source_key(data_source_id))
        if not (data and isinstance(data, dict)):
            return None
        return DataSource(**self._parse_iso_timestamps(data))

    def get_all_data_sources(self) -> List[DataSource]:
        """Return every data source whose record still exists."""
        data_sources = []
        for data_source_id in self.redis.smembers(self.data_sources_key):
            data_source = self.get_data_source_by_id(data_source_id)
            if data_source:
                data_sources.append(data_source)
        return data_sources

    def get_data_sources_by_broker(self, broker_id: str) -> List[DataSource]:
        """Return the data sources belonging to *broker_id* (scans all data sources)."""
        return [ds for ds in self.get_all_data_sources() if ds.broker_id == broker_id]

    def update_data_source(
        self, data_source_id: str, data_source_data: DataSourceUpdate
    ) -> Optional[DataSource]:
        """Apply the explicitly-set fields of *data_source_data* to a data source.

        Returns:
            DataSource: The updated data source, or None if it does not exist.
        """
        data_source = self.get_data_source_by_id(data_source_id)
        if not data_source:
            return None

        self._apply_update(data_source, data_source_data)
        data_source.updated_at = datetime.utcnow()
        self._save_data_source(data_source)
        return data_source

    def delete_data_source(self, data_source_id: str) -> bool:
        """Delete a data source and remove it from the ID set.

        Returns:
            bool: True if the data source existed and was deleted, else False.
        """
        if not self.get_data_source_by_id(data_source_id):
            return False

        self.redis.delete(self._get_data_source_key(data_source_id))
        self.redis.srem(self.data_sources_key, data_source_id)
        return True

    # ------------------------------------------------------------------
    # Fee configuration management
    # ------------------------------------------------------------------
    def get_fee_config(self, broker_id: str) -> Optional[FeeConfig]:
        """
        Return a broker's fee configuration.

        Args:
            broker_id: Broker ID.

        Returns:
            FeeConfig: The broker's ``config.fee_config`` if set. A
            default-constructed ``FeeConfig`` if the broker has a config
            without one. None if the broker or its config is missing.
        """
        broker = self.get_broker_by_id(broker_id)
        if not broker or not broker.config:
            return None

        fee_config = getattr(broker.config, "fee_config", None)
        if fee_config:
            return fee_config

        # No fee config stored: fall back to the defaults.
        return FeeConfig()

    def update_fee_config(
        self, broker_id: str, fee_config: FeeConfig
    ) -> Optional[Broker]:
        """
        Replace a broker's fee configuration and refresh its cached copy.

        Args:
            broker_id: Broker ID.
            fee_config: New fee configuration.

        Returns:
            Broker: The updated broker, or None if it does not exist.
        """
        broker = self.get_broker_by_id(broker_id)
        if not broker:
            return None

        # NOTE(review): raises AttributeError if the broker has no config.
        broker.config.fee_config = fee_config
        broker.updated_at = datetime.utcnow()
        self._save_broker(broker)

        # Cache the fee config separately for quick lookups.
        self.redis.setex(
            self._get_fee_config_key(broker_id),
            self._FEE_CONFIG_CACHE_TTL_SECONDS,
            fee_config.model_dump(),
        )
        return broker

    def cache_fee_config_for_session(self, session_id: str, fee_config: FeeConfig):
        """
        Cache the fee configuration used by a trading session.

        Args:
            session_id: Trading session ID.
            fee_config: Fee configuration to cache for 24 hours.
        """
        self.redis.setex(
            self._get_session_fee_config_key(session_id),
            self._SESSION_FEE_CONFIG_TTL_SECONDS,
            fee_config.model_dump(),
        )

    def get_fee_config_from_session_cache(self, session_id: str) -> Optional[FeeConfig]:
        """
        Return the fee configuration cached for a trading session.

        Args:
            session_id: Trading session ID.

        Returns:
            FeeConfig: The cached configuration, or None if not cached.
        """
        data = self.redis.get(self._get_session_fee_config_key(session_id))
        if data and isinstance(data, dict):
            return FeeConfig(**data)
        return None