Coverage for core/repositories/stock_repository.py: 44.74%
152 statements
« prev ^ index » next coverage.py v7.10.7, created at 2025-10-13 18:58 +0000
« prev ^ index » next coverage.py v7.10.7, created at 2025-10-13 18:58 +0000
1"""
2股票数据仓储层
3"""
5import json
6from datetime import datetime
7from typing import Any, Dict, List, Optional
9import redis
11from core.models.stock import (StockCodeList, StockData, StockDataFilter,
12 StockDataResponse)
class StockRepository:
    """Redis-backed repository for stock bars.

    Storage layout written by ``create_stock_data`` (all keys start with
    ``stock_data_prefix``):

    * ``stock_data:{code}:{timestamp}`` -- hash holding one bar; given a
      30-day TTL when written.
    * ``stock_data:time_index:{code}`` -- sorted set whose members and scores
      are the bar timestamps; used for time-range queries.
    * ``stock_codes`` -- set of every (converted) code that has been written.

    NOTE(review): the ``async`` methods call the client without ``await``,
    i.e. they expect a synchronous ``redis.Redis``; each Redis round-trip
    therefore blocks the event loop while it runs.
    """

    # TTL applied to each bar hash on write, in seconds (30 days).
    DATA_TTL_SECONDS = 30 * 24 * 3600

    def __init__(self, redis_client: redis.Redis):
        self.redis = redis_client
        self.stock_data_prefix = "stock_data:"
        self.stock_codes_key = "stock_codes"

    # ------------------------------------------------------------------
    # Key helpers
    # ------------------------------------------------------------------
    def _get_stock_key(self, code: str, timestamp: int) -> str:
        """Return the hash key of the bar for *code* at UNIX *timestamp*."""
        return f"{self.stock_data_prefix}{code}:{timestamp}"

    def _get_stock_codes_key(self, code: str) -> str:
        """Return the key ``stock_data:codes:{code}``.

        NOTE(review): no method in this class uses this key.
        """
        return f"{self.stock_data_prefix}codes:{code}"

    def _get_time_index_key(self, code: str) -> str:
        """Return the sorted-set key that indexes *code*'s bar timestamps."""
        return f"{self.stock_data_prefix}time_index:{code}"

    # ------------------------------------------------------------------
    # Format conversion helpers
    # ------------------------------------------------------------------
    def _convert_symbol_format(self, symbol: str) -> str:
        """Normalise *symbol* to the upper-case, suffixed form used in keys.

        This must stay consistent with DataAdapter's conversion, otherwise
        lookups miss data written elsewhere.

        NOTE(review): the ``.CN`` test runs before the ``.HK`` test, so 4-digit
        codes starting with 0/1/2/3/6 (e.g. ``"0700"``) become ``.CN``; only
        4-digit codes starting with 4/5/7/8/9 reach the ``.HK`` branch. Kept
        as-is to stay aligned with DataAdapter -- confirm before changing.
        """
        symbol_upper = symbol.upper()
        if symbol_upper.endswith((".US", ".HK", ".CN")):
            return symbol_upper
        if symbol_upper.startswith(("0", "1", "2", "3", "6")):
            return symbol_upper + ".CN"  # mainland China
        if len(symbol_upper) == 4 and symbol_upper.isdigit():
            return symbol_upper + ".HK"  # Hong Kong
        return symbol_upper + ".US"  # fallback: US

    def _convert_trade_session(self, trade_session: str) -> str:
        """Map a trade-session label to its stored form.

        Both upper-case enum-style names (``"PRE_MARKET"``) and the stored
        form (``"Pre"``) are accepted; anything else maps to ``"Intraday"``.
        """
        mapping = {
            "PRE_MARKET": "Pre",
            "POST_MARKET": "Post",
            "INTRADAY": "Intraday",
            "OVERNIGHT": "Overnight",
            "Pre": "Pre",
            "Post": "Post",
            "Intraday": "Intraday",
            "Overnight": "Overnight",
        }
        return mapping.get(trade_session, "Intraday")

    @staticmethod
    def _to_str(value: Any) -> Any:
        """Decode ``bytes`` as UTF-8; return any other value unchanged."""
        return value.decode("utf-8") if isinstance(value, bytes) else value

    @classmethod
    def _decode_hash(cls, raw: Dict[Any, Any]) -> Dict[Any, Any]:
        """Return a copy of *raw* with bytes keys and values decoded to str.

        A client created without ``decode_responses=True`` returns bytes,
        which would make every ``data.get("open")``-style lookup miss.
        """
        return {cls._to_str(k): cls._to_str(v) for k, v in raw.items()}

    def _build_stock_data(
        self, raw: Dict[Any, Any], actual_code: str, timestamp: int
    ) -> StockData:
        """Convert a bar hash read from Redis into a ``StockData`` model.

        Missing numeric fields default to 0. The code falls back to
        *actual_code* when the hash has no ``symbol`` field. *timestamp*
        (taken from the key / time index) is used instead of any stored value.
        """
        data = self._decode_hash(raw)
        return StockData(
            code=data.get("symbol", actual_code),
            open=float(data.get("open", 0)),
            high=float(data.get("high", 0)),
            low=float(data.get("low", 0)),
            close=float(data.get("close", 0)),
            volume=int(data.get("volume", 0)),
            turnover=float(data.get("turnover", 0)),
            timestamp=timestamp,
            trade_session=self._convert_trade_session(
                data.get("trade_session", "Intraday")
            ),
        )

    # ------------------------------------------------------------------
    # Query helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _score_bounds(start: Optional[int], end: Optional[int]) -> tuple:
        """Map optional timestamp bounds to ZRANGEBYSCORE ``(min, max)``.

        ``None`` means "unbounded". An explicit ``0`` is kept; the previous
        ``x or default`` form turned ``end_timestamp=0`` into +inf.
        """
        start_score = 0 if start is None else start
        end_score = float("inf") if end is None else end
        return start_score, end_score

    def _load_range(
        self,
        time_index_key: str,
        actual_code: str,
        start_score: Any,
        end_score: Any,
        trade_session: Optional[str] = None,
    ) -> List[StockData]:
        """Load the bars whose index score lies in ``[start_score, end_score]``.

        Index entries whose hash no longer exists are skipped. When
        *trade_session* is truthy, only bars in that session are returned.
        """
        entries = self.redis.zrangebyscore(
            time_index_key, start_score, end_score, withscores=True
        )
        result: List[StockData] = []
        for member, _score in entries:
            # The member is the timestamp itself (str or bytes); int() accepts both.
            timestamp = int(member)
            raw = self.redis.hgetall(self._get_stock_key(actual_code, timestamp))
            if not raw:
                continue  # hash expired or was deleted; index entry is stale
            stock_data = self._build_stock_data(raw, actual_code, timestamp)
            if trade_session and stock_data.trade_session != trade_session:
                continue
            result.append(stock_data)
        return result

    @staticmethod
    def _paginate(items: List[Any], page: int, page_size: int) -> tuple:
        """Return ``(page_items, total, total_pages)`` for the 1-based *page*.

        A page below 1 yields an empty slice rather than the wrap-around a
        negative slice index would produce. ``page_size == 0`` raises
        ``ZeroDivisionError``, as before.
        """
        total = len(items)
        start_idx = (page - 1) * page_size
        page_items = (
            items[start_idx : start_idx + page_size] if start_idx >= 0 else []
        )
        total_pages = (total + page_size - 1) // page_size
        return page_items, total, total_pages

    def _build_page(
        self, items: List[StockData], filter_params: StockDataFilter
    ) -> StockDataResponse:
        """Sort *items* in place (if requested) and wrap one page in a response."""
        if filter_params.sort_by == "timestamp":
            items.sort(
                key=lambda x: x.timestamp,
                reverse=(filter_params.sort_order == "desc"),
            )
        page_items, total, total_pages = self._paginate(
            items, filter_params.page, filter_params.page_size
        )
        return StockDataResponse(
            data=page_items,
            total=total,
            page=filter_params.page,
            page_size=filter_params.page_size,
            total_pages=total_pages,
        )

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------
    async def create_stock_data(self, stock_data: StockData) -> bool:
        """Store one bar, register its code and index its timestamp.

        Returns True on success, or False if any Redis call raised (the error
        is printed).
        """
        try:
            # Store under the converted symbol so readers can find it.
            actual_code = self._convert_symbol_format(stock_data.code)
            stock_key = self._get_stock_key(actual_code, stock_data.timestamp)

            data_dict = stock_data.model_dump()
            data_dict["timestamp"] = stock_data.timestamp  # keep UNIX timestamp
            # redis-py rejects None values in HSET; omit them so readers fall
            # back to their defaults instead of the whole write failing.
            mapping = {k: v for k, v in data_dict.items() if v is not None}
            self.redis.hset(stock_key, mapping=mapping)
            self.redis.expire(stock_key, self.DATA_TTL_SECONDS)

            self.redis.sadd(self.stock_codes_key, actual_code)
            # Member and score are both the timestamp, enabling range scans.
            self.redis.zadd(
                self._get_time_index_key(actual_code),
                {stock_data.timestamp: stock_data.timestamp},
            )
            return True
        except Exception as e:
            print(f"创建股票数据失败: {e}")
            return False

    async def get_stock_data(self, code: str, timestamp: int) -> Optional[StockData]:
        """Return the bar for *code* at *timestamp*, or None if absent or on error."""
        try:
            actual_code = self._convert_symbol_format(code)
            raw = self.redis.hgetall(self._get_stock_key(actual_code, timestamp))
            if not raw:
                return None
            return self._build_stock_data(raw, actual_code, timestamp)
        except Exception as e:
            print(f"获取股票数据失败: {e}")
            return None

    def get_stock_data_by_date(self, code: str, date) -> List:
        """Return every bar of *code* whose timestamp falls on calendar *date*.

        The day window is built from naive datetimes, so ``.timestamp()``
        interprets it in the server's local timezone. Returns ``[]`` on error
        (the error is printed).
        """
        try:
            from datetime import datetime, time

            actual_code = self._convert_symbol_format(code)
            start_timestamp = int(datetime.combine(date, time.min).timestamp())
            end_timestamp = int(datetime.combine(date, time.max).timestamp())
            return self._load_range(
                self._get_time_index_key(actual_code),
                actual_code,
                start_timestamp,
                end_timestamp,
            )
        except Exception as e:
            print(f"按日期获取股票数据失败: {e}")
            return []

    async def get_stock_data_list(
        self, filter_params: StockDataFilter
    ) -> StockDataResponse:
        """Query bars for one code (or all codes), then filter, sort and paginate.

        On error the exception is printed and an empty response is returned.
        """
        try:
            start_score, end_score = self._score_bounds(
                filter_params.start_timestamp, filter_params.end_timestamp
            )
            if filter_params.code:
                actual_code = self._convert_symbol_format(filter_params.code)
                items = self._load_range(
                    self._get_time_index_key(actual_code),
                    actual_code,
                    start_score,
                    end_score,
                    filter_params.trade_session,
                )
            else:
                items = []
                for code in self.redis.smembers(self.stock_codes_key):
                    code_str = self._to_str(code)
                    # The index key uses the stored code while the hash keys use
                    # its converted form; both match for codes written by
                    # create_stock_data (conversion is idempotent there).
                    items.extend(
                        self._load_range(
                            self._get_time_index_key(code_str),
                            self._convert_symbol_format(code_str),
                            start_score,
                            end_score,
                            filter_params.trade_session,
                        )
                    )
            return self._build_page(items, filter_params)
        except Exception as e:
            print(f"获取股票数据列表失败: {e}")
            return StockDataResponse(
                data=[],
                total=0,
                page=filter_params.page,
                page_size=filter_params.page_size,
                total_pages=0,
            )

    async def get_stock_codes(self) -> StockCodeList:
        """Return every registered stock code, sorted; empty list on error."""
        try:
            codes = self.redis.smembers(self.stock_codes_key)
            return StockCodeList(codes=sorted(self._to_str(c) for c in codes))
        except Exception as e:
            print(f"获取股票代码失败: {e}")
            return StockCodeList(codes=[])

    async def delete_stock_data(
        self, code: str, start_timestamp: int, end_timestamp: int
    ) -> bool:
        """Delete bars of *code* with timestamps in ``[start_timestamp, end_timestamp]``.

        Returns True if at least one bar hash was deleted, False otherwise or
        on error (the error is printed). Every matched time-index entry is
        removed, including ones whose hash had already expired.
        """
        try:
            actual_code = self._convert_symbol_format(code)
            time_index_key = self._get_time_index_key(actual_code)

            members = self.redis.zrangebyscore(
                time_index_key, start_timestamp, end_timestamp
            )
            if not members:
                return False

            stock_keys = [self._get_stock_key(actual_code, int(m)) for m in members]
            deleted_count = self.redis.delete(*stock_keys)
            # Remove the index entries regardless of whether the hash still
            # existed; otherwise stale entries accumulate forever.
            self.redis.zrem(time_index_key, *members)
            return deleted_count > 0
        except Exception as e:
            print(f"删除股票数据失败: {e}")
            return False