Coverage for core/repositories/broker_repository.py: 18.24%
170 statements
« prev ^ index » next coverage.py v7.10.7, created at 2025-10-13 18:58 +0000
« prev ^ index » next coverage.py v7.10.7, created at 2025-10-13 18:58 +0000
1"""
2券商数据仓库
3"""
5import uuid
6from datetime import datetime
7from typing import List, Optional
9from core.models.broker import (Broker, BrokerCreate, BrokerUpdate, DataSource,
10 DataSourceCreate, DataSourceUpdate, FeeConfig)
11from infrastructure.database.redis_client import get_redis
class BrokerRepository:
    """Redis-backed repository for brokers, their data sources and fee configs.

    Keys written by this class:

    * ``broker:{id}`` -- serialized broker record, 30-day TTL
    * ``broker_by_code:{code}`` -- broker id for a broker code, 30-day TTL
    * ``brokers`` -- set of all broker ids
    * ``data_source:{id}`` -- serialized data source record, 30-day TTL
    * ``data_sources`` -- set of all data source ids
    * ``broker:fee_config:{id}`` -- cached fee config of a broker, 1-hour TTL
    * ``trading:session:fee_config:{session_id}`` -- fee config of a trading
      session, 24-hour TTL

    NOTE(review): records are handed to the Redis client as plain dicts and
    read back as dicts, so the client returned by ``get_redis()`` presumably
    (de)serializes values itself -- confirm in
    ``infrastructure.database.redis_client``.
    """

    # Lifetime of broker / data-source records and of the broker code index.
    ENTITY_TTL_SECONDS = 86400 * 30
    # Lifetime of the per-broker fee-config cache entry.
    FEE_CONFIG_CACHE_TTL_SECONDS = 3600
    # Lifetime of a trading session's cached fee config.
    SESSION_FEE_CONFIG_TTL_SECONDS = 86400

    def __init__(self):
        self.redis = get_redis()
        self.brokers_key = "brokers"
        self.data_sources_key = "data_sources"

    # ------------------------------------------------------------------
    # Key helpers
    # ------------------------------------------------------------------
    def _get_broker_key(self, broker_id: str) -> str:
        """Return the Redis key holding a broker record."""
        return f"broker:{broker_id}"

    def _get_data_source_key(self, data_source_id: str) -> str:
        """Return the Redis key holding a data source record."""
        return f"data_source:{data_source_id}"

    def _get_broker_by_code_key(self, code: str) -> str:
        """Return the Redis key of the broker-code -> broker-id index entry."""
        return f"broker_by_code:{code}"

    def _get_fee_config_cache_key(self, broker_id: str) -> str:
        """Return the Redis key of a broker's cached fee config."""
        return f"broker:fee_config:{broker_id}"

    def _get_session_fee_config_key(self, session_id: str) -> str:
        """Return the Redis key of a trading session's cached fee config."""
        return f"trading:session:fee_config:{session_id}"

    # ------------------------------------------------------------------
    # (De)serialization and persistence helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _parse_timestamps(data: dict) -> dict:
        """Turn the ISO-8601 ``created_at``/``updated_at`` strings of a stored
        record back into datetimes (mutates and returns *data*)."""
        data["created_at"] = datetime.fromisoformat(data["created_at"])
        data["updated_at"] = datetime.fromisoformat(data["updated_at"])
        return data

    @staticmethod
    def _serialize_broker(broker: Broker) -> dict:
        """Convert *broker* to a dict with ISO-8601 timestamp strings."""
        data = broker.model_dump()
        data["created_at"] = data["created_at"].isoformat()
        data["updated_at"] = data["updated_at"].isoformat()
        # Dump the nested config model explicitly (skipped when config is None).
        if broker.config is not None:
            data["config"] = broker.config.model_dump()
        return data

    @staticmethod
    def _serialize_data_source(data_source: DataSource) -> dict:
        """Convert *data_source* to a dict with ISO-8601 timestamp strings."""
        data = data_source.model_dump()
        data["created_at"] = data["created_at"].isoformat()
        data["updated_at"] = data["updated_at"].isoformat()
        return data

    def _save_broker(self, broker_id: str, broker: Broker) -> None:
        """Write *broker* under ``broker:{broker_id}``, restarting its TTL."""
        self.redis.setex(
            self._get_broker_key(broker_id),
            self.ENTITY_TTL_SECONDS,
            self._serialize_broker(broker),
        )

    def _save_data_source(self, data_source_id: str, data_source: DataSource) -> None:
        """Write *data_source* under ``data_source:{id}``, restarting its TTL."""
        self.redis.setex(
            self._get_data_source_key(data_source_id),
            self.ENTITY_TTL_SECONDS,
            self._serialize_data_source(data_source),
        )

    def _remove_code_index(self, code: str, broker_id: str) -> None:
        """Delete the code index entry for *code*, but only while it still
        points at *broker_id* (another broker may have been indexed under the
        same code since)."""
        code_key = self._get_broker_by_code_key(code)
        if self.redis.get(code_key) == broker_id:
            self.redis.delete(code_key)

    def _refresh_code_index(self, broker_id: str, old_code: str, new_code: str) -> None:
        """Point ``broker_by_code:{new_code}`` at *broker_id* with a fresh TTL,
        dropping the entry for *old_code* when the code changed.

        Refreshing on every save keeps the index alive as long as the broker
        record, whose TTL is restarted on each write.
        """
        if old_code != new_code:
            self._remove_code_index(old_code, broker_id)
        self.redis.setex(
            self._get_broker_by_code_key(new_code),
            self.ENTITY_TTL_SECONDS,
            broker_id,
        )

    # ------------------------------------------------------------------
    # Broker management
    # ------------------------------------------------------------------
    def create_broker(self, broker_data: BrokerCreate) -> Broker:
        """Create and persist a new broker and index it by code.

        Args:
            broker_data: Fields of the broker to create.

        Returns:
            Broker: The new broker, with a random UUID4 id and ``created_at``
            / ``updated_at`` both set to the current (naive) UTC time.
        """
        broker_id = str(uuid.uuid4())
        now = datetime.utcnow()

        broker = Broker(
            id=broker_id,
            name=broker_data.name,
            code=broker_data.code,
            type=broker_data.type,
            status=broker_data.status,
            config=broker_data.config,
            created_at=now,
            updated_at=now,
        )

        self._save_broker(broker_id, broker)
        self.redis.sadd(self.brokers_key, broker_id)
        self.redis.setex(
            self._get_broker_by_code_key(broker_data.code),
            self.ENTITY_TTL_SECONDS,
            broker_id,
        )
        return broker

    def get_broker_by_id(self, broker_id: str) -> Optional[Broker]:
        """Load a broker by id.

        Returns:
            Broker, or None when the key is missing/expired or the stored
            value is not a dict.
        """
        data = self.redis.get(self._get_broker_key(broker_id))
        if data and isinstance(data, dict):
            return Broker(**self._parse_timestamps(data))
        return None

    def get_broker_by_code(self, code: str) -> Optional[Broker]:
        """Load a broker through the ``broker_by_code`` index, or return None."""
        broker_id = self.redis.get(self._get_broker_by_code_key(code))
        if broker_id:
            return self.get_broker_by_id(broker_id)
        return None

    def get_all_brokers(self) -> List[Broker]:
        """Return every broker listed in the ``brokers`` set.

        Ids whose record is missing (e.g. expired) are skipped; they are not
        removed from the set.
        """
        broker_ids = self.redis.smembers(self.brokers_key)
        loaded = (self.get_broker_by_id(broker_id) for broker_id in broker_ids)
        return [broker for broker in loaded if broker]

    def get_brokers_by_user(self, user_id: str) -> List[Broker]:
        """Return the brokers whose ``user_id`` equals *user_id*.

        Scans all brokers (O(number of brokers) Redis reads).
        NOTE(review): create_broker never sets ``user_id``; confirm the Broker
        model defines it, otherwise this raises AttributeError.
        """
        return [
            broker for broker in self.get_all_brokers() if broker.user_id == user_id
        ]

    def update_broker(
        self, broker_id: str, broker_data: BrokerUpdate
    ) -> Optional[Broker]:
        """Apply the explicitly-set fields of *broker_data* to a broker.

        Also re-points the code index when ``code`` changes, and invalidates
        the cached fee config when ``config`` changes.

        Returns:
            The updated Broker, or None if no broker exists for *broker_id*.
        """
        broker = self.get_broker_by_id(broker_id)
        if not broker:
            return None

        old_code = broker.code
        updated_fields = broker_data.model_fields_set
        # Copy attribute values rather than model_dump() output: model_dump()
        # turns nested models (such as config) into plain dicts, which would
        # make the later broker.config.model_dump() call fail.
        for field in updated_fields:
            setattr(broker, field, getattr(broker_data, field))

        broker.updated_at = datetime.utcnow()

        self._save_broker(broker_id, broker)
        self._refresh_code_index(broker_id, old_code, broker.code)
        if "config" in updated_fields:
            # The cached fee config may no longer match the new config.
            self.redis.delete(self._get_fee_config_cache_key(broker_id))
        return broker

    def delete_broker(self, broker_id: str) -> bool:
        """Delete a broker together with its code index, cached fee config
        and all of its data sources.

        Returns:
            True if the broker existed and was deleted, False otherwise.
        """
        broker = self.get_broker_by_id(broker_id)
        if not broker:
            return False

        self.redis.delete(self._get_broker_key(broker_id))
        self.redis.srem(self.brokers_key, broker_id)
        self._remove_code_index(broker.code, broker_id)
        self.redis.delete(self._get_fee_config_cache_key(broker_id))
        self._delete_broker_data_sources(broker_id)
        return True

    def _delete_broker_data_sources(self, broker_id: str):
        """Delete every data source whose ``broker_id`` is *broker_id*."""
        for data_source in self.get_data_sources_by_broker(broker_id):
            self.delete_data_source(data_source.id)

    # ------------------------------------------------------------------
    # Data source management
    # ------------------------------------------------------------------
    def create_data_source(self, data_source_data: DataSourceCreate) -> DataSource:
        """Create and persist a new data source.

        Returns:
            DataSource: The new data source, with a random UUID4 id and
            ``created_at`` / ``updated_at`` set to the current (naive) UTC time.
        """
        data_source_id = str(uuid.uuid4())
        now = datetime.utcnow()

        data_source = DataSource(
            id=data_source_id,
            name=data_source_data.name,
            type=data_source_data.type,
            broker_id=data_source_data.broker_id,
            status=data_source_data.status,
            priority=data_source_data.priority,
            created_at=now,
            updated_at=now,
        )

        self._save_data_source(data_source_id, data_source)
        self.redis.sadd(self.data_sources_key, data_source_id)
        return data_source

    def get_data_source_by_id(self, data_source_id: str) -> Optional[DataSource]:
        """Load a data source by id, or None if missing or not a dict."""
        data = self.redis.get(self._get_data_source_key(data_source_id))
        if data and isinstance(data, dict):
            return DataSource(**self._parse_timestamps(data))
        return None

    def get_all_data_sources(self) -> List[DataSource]:
        """Return every data source listed in the ``data_sources`` set.

        Ids whose record is missing (e.g. expired) are skipped; they are not
        removed from the set.
        """
        data_source_ids = self.redis.smembers(self.data_sources_key)
        loaded = (self.get_data_source_by_id(ds_id) for ds_id in data_source_ids)
        return [data_source for data_source in loaded if data_source]

    def get_data_sources_by_broker(self, broker_id: str) -> List[DataSource]:
        """Return the data sources belonging to *broker_id* (full scan)."""
        return [ds for ds in self.get_all_data_sources() if ds.broker_id == broker_id]

    def update_data_source(
        self, data_source_id: str, data_source_data: DataSourceUpdate
    ) -> Optional[DataSource]:
        """Apply the explicitly-set fields of *data_source_data* to a data source.

        Returns:
            The updated DataSource, or None if it does not exist.
        """
        data_source = self.get_data_source_by_id(data_source_id)
        if not data_source:
            return None

        # Copy attribute values (not model_dump() output) so nested models
        # keep their types, consistent with update_broker.
        for field in data_source_data.model_fields_set:
            setattr(data_source, field, getattr(data_source_data, field))

        data_source.updated_at = datetime.utcnow()

        self._save_data_source(data_source_id, data_source)
        return data_source

    def delete_data_source(self, data_source_id: str) -> bool:
        """Delete a data source.

        Returns:
            True if it existed and was deleted, False otherwise.
        """
        data_source = self.get_data_source_by_id(data_source_id)
        if not data_source:
            return False

        self.redis.delete(self._get_data_source_key(data_source_id))
        self.redis.srem(self.data_sources_key, data_source_id)
        return True

    # ------------------------------------------------------------------
    # Fee configuration management
    # ------------------------------------------------------------------
    def get_fee_config(self, broker_id: str) -> Optional[FeeConfig]:
        """
        Get a broker's fee configuration.

        Args:
            broker_id: Broker ID.

        Returns:
            FeeConfig: ``broker.config.fee_config`` when set, otherwise a
            default ``FeeConfig()``; None if the broker does not exist or has
            no config.
        """
        broker = self.get_broker_by_id(broker_id)
        if not broker or not broker.config:
            return None

        if hasattr(broker.config, "fee_config") and broker.config.fee_config:
            return broker.config.fee_config

        # No fee config stored: fall back to the model defaults.
        return FeeConfig()

    def update_fee_config(
        self, broker_id: str, fee_config: FeeConfig
    ) -> Optional[Broker]:
        """
        Replace a broker's fee configuration.

        The broker record is re-saved (restarting its TTL, along with the
        code index's TTL) and the fee config is also cached under
        ``broker:fee_config:{broker_id}`` for one hour.

        Args:
            broker_id: Broker ID.
            fee_config: New fee configuration.

        Returns:
            Broker: The updated broker, or None if it does not exist.
        """
        broker = self.get_broker_by_id(broker_id)
        if not broker:
            return None

        # NOTE(review): raises AttributeError if broker.config is None.
        broker.config.fee_config = fee_config
        broker.updated_at = datetime.utcnow()

        self._save_broker(broker_id, broker)
        self._refresh_code_index(broker_id, broker.code, broker.code)

        # Cache the fee config for fast lookups.
        self.redis.setex(
            self._get_fee_config_cache_key(broker_id),
            self.FEE_CONFIG_CACHE_TTL_SECONDS,
            fee_config.model_dump(),
        )
        return broker

    def cache_fee_config_for_session(self, session_id: str, fee_config: FeeConfig):
        """
        Cache the fee configuration used by a trading session (24-hour TTL).

        Args:
            session_id: Trading session ID.
            fee_config: Fee configuration to cache.
        """
        self.redis.setex(
            self._get_session_fee_config_key(session_id),
            self.SESSION_FEE_CONFIG_TTL_SECONDS,
            fee_config.model_dump(),
        )

    def get_fee_config_from_session_cache(self, session_id: str) -> Optional[FeeConfig]:
        """
        Read a trading session's cached fee configuration.

        Args:
            session_id: Trading session ID.

        Returns:
            FeeConfig: The cached config, or None if absent/expired or the
            stored value is not a dict.
        """
        data = self.redis.get(self._get_session_fee_config_key(session_id))
        if data and isinstance(data, dict):
            return FeeConfig(**data)
        return None