Coverage for core/repositories/stock_repository.py: 44.74%

152 statements  

« prev     ^ index     » next       coverage.py v7.10.7, created at 2025-10-13 18:58 +0000

1""" 

2股票数据仓储层 

3""" 

4 

import json
import logging
from datetime import datetime
from typing import Any, Dict, List, Optional

import redis

from core.models.stock import (StockCodeList, StockData, StockDataFilter,
                               StockDataResponse)

13 

14 

class StockRepository:
    """Redis-backed repository for stock price records.

    Key layout used by this class (``CODE`` is the normalized symbol returned
    by ``_convert_symbol_format``):

    * ``stock_data:{CODE}:{timestamp}`` -- hash holding one record's fields
      (open/high/low/close/volume/turnover/trade_session, ...);
      ``create_stock_data`` gives it a 30-day TTL.
    * ``stock_data:time_index:{CODE}`` -- sorted set whose members and scores
      are the record timestamps. No TTL is set on it here, so it can outlive
      the hashes it points to; readers skip entries whose hash is gone.
    * ``stock_codes`` -- set of every normalized code that was written.

    NOTE(review): hash fields are looked up with ``str`` keys
    (``data.get("open")``), so this assumes the client was created with
    ``decode_responses=True``; with a bytes-returning client every field would
    silently fall back to its default. Confirm against the client factory.

    Every public method catches all exceptions, logs them, and returns a
    falsy/empty result instead of raising.
    """

    # TTL applied to each record hash, in seconds (30 days).
    _DATA_TTL_SECONDS = 30 * 24 * 3600

    _logger = logging.getLogger(__name__)

    def __init__(self, redis_client: redis.Redis):
        """Bind the repository to an existing Redis client."""
        self.redis = redis_client
        self.stock_data_prefix = "stock_data:"
        self.stock_codes_key = "stock_codes"

    # ------------------------------------------------------------------ keys

    def _get_stock_key(self, code: str, timestamp: int) -> str:
        """Return the hash key for one record: ``stock_data:{code}:{timestamp}``.

        The raw UNIX timestamp is embedded directly in the key.
        """
        return f"{self.stock_data_prefix}{code}:{timestamp}"

    def _get_stock_codes_key(self, code: str) -> str:
        """Return ``stock_data:codes:{code}``.

        NOTE(review): not used by any method of this class (the code set lives
        under ``self.stock_codes_key``); kept in case something outside this
        class relies on it.
        """
        return f"{self.stock_data_prefix}codes:{code}"

    def _get_time_index_key(self, code: str) -> str:
        """Return the sorted-set key that indexes the timestamps stored for ``code``."""
        return f"{self.stock_data_prefix}time_index:{code}"

    # ----------------------------------------------------------- conversions

    @staticmethod
    def _decode(value: Any) -> Any:
        """Decode ``bytes`` returned by Redis as UTF-8; pass anything else through."""
        return value.decode("utf-8") if isinstance(value, bytes) else value

    def _convert_symbol_format(self, symbol: str) -> str:
        """Normalize a symbol to upper case with a market suffix.

        Must stay consistent with DataAdapter's conversion, because the result
        is part of every Redis key. A symbol that already ends in ``.US``,
        ``.HK`` or ``.CN`` (case-insensitively) is only upper-cased, which
        makes the conversion idempotent. Otherwise the suffix is chosen as:

        * leading ``0``/``1``/``2``/``3``/``6`` -> ``.CN`` (mainland China)
        * any other 4-character, all-digit symbol -> ``.HK`` (Hong Kong)
        * everything else -> ``.US``

        NOTE(review): the ``.CN`` prefix test runs first, so 4-digit HK codes
        starting with 0-3 (e.g. ``"0700"``) become ``"0700.CN"``. Left as-is
        on purpose -- changing it would orphan keys already written -- but
        confirm this matches DataAdapter.
        """
        symbol_upper = symbol.upper()
        if symbol_upper.endswith((".US", ".HK", ".CN")):
            return symbol_upper
        if symbol_upper.startswith(("0", "1", "2", "3", "6")):
            return symbol_upper + ".CN"
        if len(symbol_upper) == 4 and symbol_upper.isdigit():
            return symbol_upper + ".HK"
        return symbol_upper + ".US"

    def _convert_trade_session(self, trade_session: str) -> str:
        """Map a stored trade-session label to its short form.

        Accepts both the upper-case long names (``"PRE_MARKET"``, ...) and the
        short forms; any unrecognized value falls back to ``"Intraday"``.
        """
        mapping = {
            "PRE_MARKET": "Pre",
            "POST_MARKET": "Post",
            "INTRADAY": "Intraday",
            "OVERNIGHT": "Overnight",
            "Pre": "Pre",
            "Post": "Post",
            "Intraday": "Intraday",
            "Overnight": "Overnight",
        }
        return mapping.get(trade_session, "Intraday")

    def _hash_to_stock_data(
        self, data: Dict[str, Any], fallback_code: str, timestamp: int
    ) -> StockData:
        """Build a ``StockData`` from a stored record hash.

        ``code`` comes from the hash's ``symbol`` field when present, else
        ``fallback_code``. Numeric fields missing from the hash default to 0
        and a missing ``trade_session`` to ``"Intraday"``. ``timestamp`` is
        supplied by the caller (from the key/index), not read from the hash.
        """
        return StockData(
            code=data.get("symbol", fallback_code),
            open=float(data.get("open", 0)),
            high=float(data.get("high", 0)),
            low=float(data.get("low", 0)),
            close=float(data.get("close", 0)),
            volume=int(data.get("volume", 0)),
            turnover=float(data.get("turnover", 0)),
            timestamp=timestamp,
            trade_session=self._convert_trade_session(
                data.get("trade_session", "Intraday")
            ),
        )

    # --------------------------------------------------------------- helpers

    def _collect_rows(
        self,
        time_index_key: str,
        data_code: str,
        min_score: Any,
        max_score: Any,
        trade_session: Any,
    ) -> List[StockData]:
        """Load the records indexed in ``time_index_key`` within a score range.

        Args:
            time_index_key: Sorted-set key to scan.
            data_code: Normalized code used to build each record's hash key.
            min_score: Inclusive lower score bound.
            max_score: Inclusive upper score bound (``"+inf"`` for unbounded).
            trade_session: When truthy, only records whose converted
                ``trade_session`` equals it are kept.

        Returns:
            Records in ascending score order. Index entries whose hash no
            longer exists (e.g. expired) are skipped.
        """
        entries = self.redis.zrangebyscore(
            time_index_key, min_score, max_score, withscores=True
        )
        rows: List[StockData] = []
        for member, _score in entries:
            ts = int(member)
            data = self.redis.hgetall(self._get_stock_key(data_code, ts))
            if not data:
                continue
            stock_data = self._hash_to_stock_data(data, data_code, ts)
            if trade_session and stock_data.trade_session != trade_session:
                continue
            rows.append(stock_data)
        return rows

    @staticmethod
    def _paginate(
        rows: List[StockData], filter_params: StockDataFilter
    ) -> StockDataResponse:
        """Optionally sort ``rows`` by timestamp, then cut out the requested page.

        ``page`` is 1-based. ``rows`` is sorted in place when
        ``sort_by == "timestamp"`` (descending when ``sort_order == "desc"``).
        """
        if filter_params.sort_by == "timestamp":
            rows.sort(
                key=lambda row: row.timestamp,
                reverse=(filter_params.sort_order == "desc"),
            )

        total = len(rows)
        page_size = filter_params.page_size
        start_idx = (filter_params.page - 1) * page_size
        return StockDataResponse(
            data=rows[start_idx:start_idx + page_size],
            total=total,
            page=filter_params.page,
            page_size=page_size,
            total_pages=(total + page_size - 1) // page_size,
        )

    # ------------------------------------------------------------ public API

    async def create_stock_data(self, stock_data: StockData) -> bool:
        """Store one record and register it in the code set and time index.

        Writes the record hash (with a 30-day TTL), adds the normalized code
        to ``stock_codes``, and adds the timestamp (as both member and score)
        to the code's time index.

        Returns:
            True on success, False if any Redis call raised.
        """
        try:
            actual_code = self._convert_symbol_format(stock_data.code)
            stock_key = self._get_stock_key(actual_code, stock_data.timestamp)

            data_dict = stock_data.model_dump()
            # Keep the timestamp as the raw UNIX value.
            data_dict["timestamp"] = stock_data.timestamp
            # redis-py rejects None values in an hset mapping (DataError);
            # drop them so optional fields don't fail the whole write --
            # readers in this class fall back to defaults for missing fields.
            mapping = {k: v for k, v in data_dict.items() if v is not None}
            self.redis.hset(stock_key, mapping=mapping)
            self.redis.expire(stock_key, self._DATA_TTL_SECONDS)

            self.redis.sadd(self.stock_codes_key, actual_code)
            self.redis.zadd(
                self._get_time_index_key(actual_code),
                {stock_data.timestamp: stock_data.timestamp},
            )
            return True
        except Exception as e:
            self._logger.exception("创建股票数据失败: %s", e)
            return False

    async def get_stock_data(self, code: str, timestamp: int) -> Optional[StockData]:
        """Return the record for ``code`` at exactly ``timestamp``.

        Returns:
            The record, or None if no hash exists for that key or on error.
        """
        try:
            actual_code = self._convert_symbol_format(code)
            data = self.redis.hgetall(self._get_stock_key(actual_code, timestamp))
            if not data:
                return None
            return self._hash_to_stock_data(data, actual_code, timestamp)
        except Exception as e:
            self._logger.exception("获取股票数据失败: %s", e)
            return None

    def get_stock_data_by_date(self, code: str, date) -> List[StockData]:
        """Return every stored record for ``code`` on the calendar day ``date``.

        Synchronous, unlike the other data accessors in this class.

        Args:
            code: Symbol in any form accepted by ``_convert_symbol_format``.
            date: A ``datetime.date`` (a ``datetime`` also works; only its
                date part is used by ``datetime.combine``).

        Returns:
            Matching records in ascending time-index order, or ``[]`` on error.
        """
        try:
            from datetime import time

            actual_code = self._convert_symbol_format(code)

            # NOTE(review): combine() yields naive datetimes, and .timestamp()
            # interprets naive values in the server's local timezone -- confirm
            # that is how day boundaries are meant to be drawn.
            start_timestamp = int(datetime.combine(date, time.min).timestamp())
            end_timestamp = int(datetime.combine(date, time.max).timestamp())

            return self._collect_rows(
                self._get_time_index_key(actual_code),
                actual_code,
                start_timestamp,
                end_timestamp,
                None,
            )
        except Exception as e:
            self._logger.exception("按日期获取股票数据失败: %s", e)
            return []

    async def get_stock_data_list(
        self, filter_params: StockDataFilter
    ) -> StockDataResponse:
        """Query records by code, time range and trade session, with paging.

        When ``filter_params.code`` is empty, every code in ``stock_codes`` is
        scanned. A missing ``start_timestamp`` means 0 and a missing
        ``end_timestamp`` means unbounded; both bounds are inclusive.

        Returns:
            The requested page; an empty response (``total=0``) on error.
        """
        try:
            # Explicit None checks so a legitimate bound of 0 is honored
            # instead of being replaced by the default.
            start_score = (
                0
                if filter_params.start_timestamp is None
                else filter_params.start_timestamp
            )
            end_score = (
                "+inf"
                if filter_params.end_timestamp is None
                else filter_params.end_timestamp
            )
            session = filter_params.trade_session

            if filter_params.code:
                actual_code = self._convert_symbol_format(filter_params.code)
                rows = self._collect_rows(
                    self._get_time_index_key(actual_code),
                    actual_code,
                    start_score,
                    end_score,
                    session,
                )
            else:
                rows = []
                for raw_code in self.redis.smembers(self.stock_codes_key):
                    code_str = self._decode(raw_code)
                    # Set members written by create_stock_data are already
                    # normalized, so converting again is a no-op for them.
                    rows.extend(
                        self._collect_rows(
                            self._get_time_index_key(code_str),
                            self._convert_symbol_format(code_str),
                            start_score,
                            end_score,
                            session,
                        )
                    )

            return self._paginate(rows, filter_params)
        except Exception as e:
            self._logger.exception("获取股票数据列表失败: %s", e)
            return StockDataResponse(
                data=[],
                total=0,
                page=filter_params.page,
                page_size=filter_params.page_size,
                total_pages=0,
            )

    async def get_stock_codes(self) -> StockCodeList:
        """Return every registered stock code, sorted; an empty list on error."""
        try:
            raw_codes = self.redis.smembers(self.stock_codes_key)
            return StockCodeList(codes=sorted(self._decode(c) for c in raw_codes))
        except Exception as e:
            self._logger.exception("获取股票代码失败: %s", e)
            return StockCodeList(codes=[])

    async def delete_stock_data(
        self, code: str, start_timestamp: int, end_timestamp: int
    ) -> bool:
        """Delete records for ``code`` whose index score lies in [start, end].

        Every index entry in the range is removed, including entries whose
        hash had already expired, so stale entries do not accumulate.

        Returns:
            True if at least one record hash was deleted, else False
            (also False on error).
        """
        try:
            actual_code = self._convert_symbol_format(code)
            time_index_key = self._get_time_index_key(actual_code)

            entries = self.redis.zrangebyscore(
                time_index_key, start_timestamp, end_timestamp, withscores=True
            )

            deleted_count = 0
            for member, _score in entries:
                if self.redis.delete(self._get_stock_key(actual_code, int(member))):
                    deleted_count += 1
                # Remove the index entry whether or not the hash still existed.
                self.redis.zrem(time_index_key, member)

            return deleted_count > 0
        except Exception as e:
            self._logger.exception("删除股票数据失败: %s", e)
            return False