Coverage for core/data_source/adapters/data_adapter.py: 82.22%

270 statements  

« prev     ^ index     » next       coverage.py v7.10.7, created at 2025-10-13 18:58 +0000

1""" 

2数据拉取业务适配器 

3统一数据拉取逻辑,替代StockDataService 

4保留原有的复杂分页逻辑不变,只替换数据源访问层 

5""" 

6 

7import asyncio 

8import logging 

9from concurrent.futures import ThreadPoolExecutor 

10from datetime import date, datetime, timedelta 

11from decimal import Decimal 

12from typing import Any, Dict, List, Optional 

13 

14import pandas as pd 

15import pytz 

16from longport.openapi import AdjustType, Period, TradeSessions 

17 

18from core.data_source.factories.client_factory import unified_client_factory 

19from core.repositories.stock_repository import StockRepository 

20from infrastructure.database.redis_client import get_redis 

21 

22logger = logging.getLogger(__name__) 

23 

24 

25class DataAdapter: 

26 """数据拉取业务适配器 - 保留原有复杂分页逻辑,只替换数据源访问""" 

27 

28 def __init__(self, user_id: str): 

29 self.user_id = user_id 

30 redis_client = get_redis() 

31 self.stock_repo = StockRepository(redis_client) 

32 self.quote_ctx = None # 保持与原StockDataService相同的属性 

33 

34 async def fetch_stock_data( 

35 self, 

36 symbol: str, 

37 start_date: date, 

38 end_date: date, 

39 timezone: str = "America/New_York", 

40 progress_callback: Optional[callable] = None, 

41 ) -> Dict[str, Any]: 

42 """ 

43 获取股票数据(保留原有复杂分页逻辑) 

44 

45 Args: 

46 symbol: 股票代码 

47 start_date: 开始日期 

48 end_date: 结束日期 

49 timezone: 时区 

50 progress_callback: 进度回调函数 

51 

52 Returns: 

53 Dict: 包含成功状态、数据条数、错误信息等 

54 """ 

55 try: 

56 logger.info(f"开始获取股票数据: {symbol}, 用户ID: {self.user_id}") 

57 

58 # 🔄 新的数据源初始化方法(替代旧的数据源工厂) 

59 if not await self._initialize_data_source_client(): 

60 error_msg = "数据源客户端初始化失败" 

61 logger.error(error_msg) 

62 if progress_callback: 

63 progress_callback(f"{error_msg}") 

64 return {"success": False, "error": error_msg, "data_count": 0} 

65 

66 # 检查日期范围(允许起始和结束日期相同,代表获取全天数据) 

67 if start_date > end_date: 

68 error_msg = "起始日期不能晚于结束日期,请重新选择" 

69 logger.error(error_msg) 

70 if progress_callback: 

71 progress_callback(f"{error_msg}") 

72 return {"success": False, "error": error_msg, "data_count": 0} 

73 

74 # 转换股票代码格式 

75 longport_symbol = self._convert_symbol_format(symbol) 

76 

77 # 固定使用1分钟间隔 

78 longport_period = Period.Min_1 

79 

80 # 固定获取所有交易时段(盘前、盘中、盘后、夜盘) 

81 trade_sessions = TradeSessions.All 

82 

83 # 解析ISO格式的日期字符串 

84 if isinstance(start_date, str): 

85 # 解析ISO格式字符串,如 "2025-09-17T00:00:00-04:00" 

86 start_date = datetime.fromisoformat( 

87 start_date.replace("Z", "+00:00") 

88 ).date() 

89 if isinstance(end_date, str): 

90 end_date = datetime.fromisoformat( 

91 end_date.replace("Z", "+00:00") 

92 ).date() 

93 

94 # 🚀 保持原有的复杂分页逻辑开始 

95 

96 # 分批获取数据 

97 all_candlesticks = [] 

98 current_date = start_date 

99 batch_count = 0 

100 success_days = 0 

101 skip_days = 0 

102 

103 # 设置时区 

104 tz = pytz.timezone(timezone) 

105 current_time = datetime.now(tz) 

106 

107 if progress_callback: 

108 progress_callback( 

109 f"开始获取 {symbol} 数据,时间范围: {start_date}{end_date}" 

110 ) 

111 progress_callback( 

112 f"开始拉取时间: {current_time.strftime('%H:%M:%S')} ({timezone})" 

113 ) 

114 progress_callback( 

115 f"调试信息: start_date={start_date}, end_date={end_date}, current_date={current_date}" 

116 ) 

117 

118 # 计算总天数和预计数据量 

119 total_days = (end_date - start_date).days + 1 

120 estimated_total = total_days * 1440 # 每天最多1440条 

121 if progress_callback: 

122 progress_callback( 

123 f"预计获取 {total_days} 天的数据,每天最多1440条(分2次获取,每次720条)" 

124 ) 

125 progress_callback(f"预计总数据量: {estimated_total}") 

126 

127 while current_date <= end_date: 

128 batch_count += 1 

129 

130 # 计算当前批次的结束日期(不包含下一天) 

131 current_end = current_date 

132 

133 # 分别处理两次获取,确保即使第一次失败也会尝试第二次 

134 morning_candlesticks = [] 

135 afternoon_candlesticks = [] 

136 

137 # 🚀 第一次获取:从用户选择的日期00:00开始,获取12小时数据 (720条记录) 

138 try: 

139 # 使用用户选择的时区,从当前日期的00:00开始 

140 morning_time = tz.localize( 

141 datetime.combine(current_date, datetime.min.time()) 

142 ) 

143 if progress_callback: 

144 progress_callback( 

145 f"📊 批次 {batch_count}/{total_days}: 获取 {current_date} 的数据" 

146 ) 

147 progress_callback( 

148 f"🕐 第1次获取 (00:00-12:00): {morning_time.isoformat()}" 

149 ) 

150 progress_callback(f"⏳ 正在调用LongPort API...") 

151 

152 morning_candlesticks = await self._get_candlesticks_async( 

153 longport_symbol, 

154 longport_period, 

155 morning_time, 

156 720, 

157 trade_sessions, 

158 forward=True, 

159 progress_callback=progress_callback, 

160 ) 

161 

162 if progress_callback: 

163 if len(morning_candlesticks) > 0: 

164 first_time = morning_candlesticks[0].timestamp 

165 last_time = morning_candlesticks[-1].timestamp 

166 # 转换为目标时区的ISO格式显示 

167 first_time_tz = first_time.astimezone(tz) 

168 last_time_tz = last_time.astimezone(tz) 

169 first_time_iso = first_time_tz.isoformat() 

170 last_time_iso = last_time_tz.isoformat() 

171 

172 progress_callback( 

173 f"✅ 第1次获取完成: {len(morning_candlesticks)} 条数据, 时间范围: {first_time_iso}{last_time_iso}, 价格: {morning_candlesticks[0].open}-{morning_candlesticks[0].close}{morning_candlesticks[-1].open}-{morning_candlesticks[-1].close}" 

174 ) 

175 else: 

176 progress_callback( 

177 f"✅ 第1次获取完成: 0 条数据 (返回空数据)" 

178 ) 

179 

180 except Exception as e: 

181 if progress_callback: 

182 progress_callback( 

183 f"❌ 第1次获取失败: {type(e).__name__}: {str(e)}" 

184 ) 

185 

186 # 🚀 第二次获取:当前日期12:00开始,获取12小时数据 (720条记录) 

187 try: 

188 # 使用用户选择的时区,从当前日期的12:00开始 

189 afternoon_time = tz.localize( 

190 datetime.combine( 

191 current_date, datetime.min.time().replace(hour=12) 

192 ) 

193 ) 

194 if progress_callback: 

195 progress_callback( 

196 f"🕐 第2次获取 (12:00-24:00): {afternoon_time.isoformat()}" 

197 ) 

198 progress_callback(f"⏳ 正在调用LongPort API...") 

199 

200 afternoon_candlesticks = await self._get_candlesticks_async( 

201 longport_symbol, 

202 longport_period, 

203 afternoon_time, 

204 720, 

205 trade_sessions, 

206 forward=True, 

207 progress_callback=progress_callback, 

208 ) 

209 

210 if progress_callback: 

211 if len(afternoon_candlesticks) > 0: 

212 first_time = afternoon_candlesticks[0].timestamp 

213 last_time = afternoon_candlesticks[-1].timestamp 

214 # 转换为目标时区的ISO格式显示 

215 first_time_tz = first_time.astimezone(tz) 

216 last_time_tz = last_time.astimezone(tz) 

217 first_time_iso = first_time_tz.isoformat() 

218 last_time_iso = last_time_tz.isoformat() 

219 

220 progress_callback( 

221 f"✅ 第2次获取完成: {len(afternoon_candlesticks)} 条数据, 时间范围: {first_time_iso}{last_time_iso}, 价格: {afternoon_candlesticks[0].open}-{afternoon_candlesticks[0].close}{afternoon_candlesticks[-1].open}-{afternoon_candlesticks[-1].close}" 

222 ) 

223 else: 

224 progress_callback( 

225 f"✅ 第2次获取完成: 0 条数据 (返回空数据)" 

226 ) 

227 

228 except Exception as e: 

229 if progress_callback: 

230 progress_callback( 

231 f"❌ 第2次获取失败: {type(e).__name__}: {str(e)}" 

232 ) 

233 

234 # 🚀 合并当天的数据 

235 daily_candlesticks = morning_candlesticks + afternoon_candlesticks 

236 

237 # 🚀 过滤掉不属于当天的数据(丢弃多余的数据) 

238 # 考虑到时区转换,先转换时间戳再过滤 

239 filtered_candlesticks = [] 

240 for candle in daily_candlesticks: 

241 # 转换时间戳到指定时区 

242 beijing_timestamp = candle.timestamp 

243 if beijing_timestamp.tzinfo is None: 

244 beijing_tz = pytz.timezone("Asia/Shanghai") 

245 beijing_timestamp = beijing_tz.localize(beijing_timestamp) 

246 

247 # 转换为指定时区 

248 target_timestamp = beijing_timestamp.astimezone(tz) 

249 candle_date = target_timestamp.date() 

250 

251 # 只保留目标日期的数据 

252 if candle_date == current_date: 

253 filtered_candlesticks.append(candle) 

254 

255 daily_candlesticks = filtered_candlesticks 

256 

257 if progress_callback: 

258 progress_callback( 

259 f"🔄 数据处理: 合并 {len(morning_candlesticks)}+{len(afternoon_candlesticks)} 条原始数据, 过滤后保留 {len(daily_candlesticks)}" 

260 ) 

261 

262 if daily_candlesticks: 

263 all_candlesticks.extend(daily_candlesticks) 

264 success_days += 1 

265 if progress_callback: 

266 # 显示当前日期 

267 current_date_str = current_date.strftime("%Y-%m-%d") 

268 progress_callback( 

269 f"📅 处理日期 {current_date_str} ({batch_count}/{total_days})" 

270 ) 

271 progress_callback( 

272 f"✅ 成功获取 {len(daily_candlesticks)} 条数据" 

273 ) 

274 # 显示当天数据的时间范围(转换为指定时区) 

275 if daily_candlesticks: 

276 first_time = daily_candlesticks[0].timestamp 

277 last_time = daily_candlesticks[-1].timestamp 

278 # 转换为指定时区显示,使用ISO格式 

279 first_time_tz = first_time.astimezone(tz) 

280 last_time_tz = last_time.astimezone(tz) 

281 progress_callback( 

282 f" 数据时间范围: {first_time_tz.isoformat()}{last_time_tz.isoformat()}" 

283 ) 

284 else: 

285 skip_days += 1 

286 if progress_callback: 

287 current_date_str = current_date.strftime("%Y-%m-%d") 

288 progress_callback( 

289 f"📅 处理日期 {current_date_str} ({batch_count}/{total_days})" 

290 ) 

291 progress_callback(f"⚠️ 无交易数据") 

292 

293 current_date += timedelta(days=1) 

294 

295 # 🚀 保持原有复杂分页逻辑结束 

296 

297 if not all_candlesticks: 

298 return { 

299 "success": False, 

300 "error": f"未获取到 {symbol} 的任何数据", 

301 "data_count": 0, 

302 } 

303 

304 # 转换为DataFrame并保存到数据库 

305 df = self._candlesticks_to_dataframe( 

306 all_candlesticks, symbol, longport_symbol 

307 ) 

308 

309 # 保存到数据库 

310 save_stats = await self._save_stock_data_to_redis(df) 

311 saved_count = save_stats["saved_count"] 

312 overwritten_count = save_stats["overwritten_count"] 

313 new_count = save_stats["new_count"] 

314 

315 # 计算过滤掉的数据 

316 filtered_count = len(all_candlesticks) - len(df) 

317 

318 # 计算每天平均数据量 

319 days_with_data = success_days if success_days > 0 else 1 

320 avg_daily_count = len(df) // days_with_data 

321 

322 if progress_callback: 

323 progress_callback( 

324 f"🔄 数据获取完成,共获取 {len(all_candlesticks)} 条原始数据" 

325 ) 

326 if filtered_count > 0: 

327 progress_callback(f"🔍 因时间区间过滤掉 {filtered_count} 条数据") 

328 progress_callback(f"💾 正在保存数据到数据库...") 

329 progress_callback(f"✅ 成功保存 {saved_count} 条数据到数据库") 

330 if overwritten_count > 0: 

331 progress_callback( 

332 f"📊 数据统计: 新增 {new_count} 条,覆盖 {overwritten_count} 条,过滤 {filtered_count}" 

333 ) 

334 progress_callback(f"📈 平均每天 {avg_daily_count} 条数据") 

335 

336 # 显示每天的数据统计 

337 if success_days > 0: 

338 progress_callback( 

339 f"📅 处理总结: 成功获取 {success_days} 天数据,跳过 {skip_days}" 

340 ) 

341 progress_callback( 

342 f"🎉 数据导入完成!共写入数据库 {saved_count} 条数据" 

343 ) 

344 

345 return { 

346 "success": True, 

347 "data_count": saved_count, 

348 "overwritten_count": overwritten_count, 

349 "new_count": new_count, 

350 "filtered_count": filtered_count, 

351 "avg_daily_count": avg_daily_count, 

352 "batch_count": batch_count, 

353 "success_days": success_days, 

354 "skip_days": skip_days, 

355 } 

356 

357 except Exception as e: 

358 logger.error(f"获取股票数据失败: {e}") 

359 if progress_callback: 

360 progress_callback(f"获取数据失败: {str(e)}") 

361 return {"success": False, "error": str(e), "data_count": 0} 

362 

363 async def _initialize_data_source_client(self) -> bool: 

364 """初始化数据源客户端(使用新的统一工厂)""" 

365 try: 

366 # 🔄 使用新的统一客户端工厂 

367 client = unified_client_factory.get_client(self.user_id) 

368 if not client: 

369 logger.error(f"无法获取用户 {self.user_id} 的数据源客户端") 

370 return False 

371 

372 # 获取行情上下文(保持与原代码一致) 

373 self.quote_ctx = client.get("quote_ctx") 

374 

375 logger.info("数据源客户端初始化成功") 

376 return True 

377 

378 except Exception as e: 

379 logger.error(f"数据源客户端初始化失败: {e}") 

380 return False 

381 

382 async def _get_candlesticks_async( 

383 self, 

384 symbol: str, 

385 period: Period, 

386 time: datetime, 

387 count: int, 

388 trade_sessions: TradeSessions, 

389 forward: bool = True, 

390 progress_callback=None, 

391 ) -> List: 

392 """异步获取K线数据(保持原有逻辑)""" 

393 loop = asyncio.get_event_loop() 

394 with ThreadPoolExecutor() as executor: 

395 return await loop.run_in_executor( 

396 executor, 

397 self._get_candlesticks_sync, 

398 symbol, 

399 period, 

400 time, 

401 count, 

402 trade_sessions, 

403 forward, 

404 progress_callback, 

405 ) 

406 

407 def _get_candlesticks_sync( 

408 self, 

409 symbol: str, 

410 period: Period, 

411 time: datetime, 

412 count: int, 

413 trade_sessions: TradeSessions, 

414 forward: bool = True, 

415 progress_callback=None, 

416 ) -> List: 

417 """同步获取K线数据(保持原有复杂逻辑)""" 

418 try: 

419 # 根据时间跨度计算offset值(1-720分钟) 

420 # 第一次获取:00:00-12:00 (720分钟) 

421 # 第二次获取:12:00-24:00 (720分钟) 

422 if forward: 

423 # 向前获取,从指定时间开始 

424 offset_minutes = 720 # 12小时 = 720分钟 

425 else: 

426 # 向后获取,从指定时间开始 

427 offset_minutes = 720 # 12小时 = 720分钟 

428 

429 # 确保offset在1-720范围内 

430 offset_minutes = max(1, min(720, offset_minutes)) 

431 

432 if progress_callback: 

433 progress_callback( 

434 f" === 调用LongPort API === symbol: {symbol}, period: {period}, time: {time}, offset: {offset_minutes}分钟, forward: {forward}, trade_sessions: {trade_sessions}, adjust_type: NoAdjust" 

435 ) 

436 

437 # 🚀 这里仍然直接调用quote_ctx,保持原有SDK调用逻辑 

438 result = self.quote_ctx.history_candlesticks_by_offset( 

439 symbol=symbol, 

440 period=period, 

441 adjust_type=AdjustType.NoAdjust, 

442 forward=forward, 

443 count=count, 

444 time=time, 

445 trade_sessions=trade_sessions, 

446 ) 

447 

448 if progress_callback: 

449 if len(result) > 0: 

450 progress_callback( 

451 f" === LongPort API返回 === 数据条数: {len(result)}, 时间范围: {result[0].timestamp}{result[-1].timestamp}, 价格范围: {result[0].open}-{result[0].close}{result[-1].open}-{result[-1].close}" 

452 ) 

453 else: 

454 progress_callback( 

455 f" === LongPort API返回 === 数据条数: {len(result)}" 

456 ) 

457 

458 return result 

459 

460 except Exception as e: 

461 if progress_callback: 

462 progress_callback(f" === LongPort API异常 ===") 

463 progress_callback(f" 异常类型: {type(e).__name__}") 

464 progress_callback(f" 异常信息: {str(e)}") 

465 logger.error(f"获取K线数据失败: {e}") 

466 return [] 

467 

468 def _candlesticks_to_dataframe( 

469 self, candlesticks: List, symbol: str, longport_symbol: str 

470 ) -> pd.DataFrame: 

471 """转换K线数据为DataFrame(保持原有逻辑)""" 

472 if not candlesticks: 

473 return pd.DataFrame() 

474 

475 data = [] 

476 for candle in candlesticks: 

477 try: 

478 # 转换时间戳 

479 timestamp = candle.timestamp 

480 if timestamp.tzinfo is None: 

481 # 如果没有时区信息,假设是北京时间 

482 beijing_tz = pytz.timezone("Asia/Shanghai") 

483 timestamp = beijing_tz.localize(timestamp) 

484 

485 # 转换为UTC时间戳 

486 utc_timestamp = timestamp.utctimetuple() 

487 

488 # 确定交易时段 

489 trade_session = self._determine_trade_session(timestamp) 

490 

491 data.append( 

492 { 

493 "symbol": longport_symbol, # 使用转换后的符号,确保数据一致性 

494 "timestamp": timestamp, # 直接使用带时区信息的timestamp 

495 "open": float(candle.open) if candle.open else 0.0, 

496 "high": float(candle.high) if candle.high else 0.0, 

497 "low": float(candle.low) if candle.low else 0.0, 

498 "close": float(candle.close) if candle.close else 0.0, 

499 "volume": int(candle.volume) if candle.volume else 0, 

500 "turnover": float(candle.turnover) if candle.turnover else 0.0, 

501 "trade_session": trade_session, 

502 } 

503 ) 

504 except Exception as e: 

505 logger.warning(f"转换K线数据失败: {e}") 

506 continue 

507 

508 df = pd.DataFrame(data) 

509 if not df.empty: 

510 df = df.sort_values("timestamp").reset_index(drop=True) 

511 

512 return df 

513 

514 def _determine_trade_session(self, timestamp: datetime) -> str: 

515 """确定交易时段(保持原有逻辑)""" 

516 try: 

517 hour = timestamp.hour 

518 

519 if hour < 9 or hour >= 16: 

520 return "PRE_MARKET" # 盘前/盘后 

521 elif hour >= 9 and hour < 12: 

522 return "MORNING" # 上午 

523 elif hour >= 12 and hour < 13: 

524 return "LUNCH" # 午休 

525 else: 

526 return "AFTERNOON" # 下午 

527 

528 except Exception: 

529 return "UNKNOWN" 

530 

531 async def _save_stock_data_to_redis(self, df: pd.DataFrame) -> Dict[str, int]: 

532 """保存股票数据到Redis(保持原有逻辑)""" 

533 try: 

534 redis_client = get_redis() 

535 saved_count = 0 

536 overwritten_count = 0 

537 new_count = 0 

538 

539 for _, row in df.iterrows(): 

540 try: 

541 # 将时间戳转换为UNIX时间戳 

542 timestamp_int = int(row["timestamp"].timestamp()) 

543 

544 # 创建唯一键,使用UNIX时间戳 

545 key = f"stock_data:{row['symbol']}:{timestamp_int}" 

546 

547 # 检查是否已存在 

548 if redis_client.exists(key): 

549 overwritten_count += 1 

550 else: 

551 new_count += 1 

552 

553 # 保存数据 

554 data_dict = row.to_dict() 

555 data_dict["timestamp"] = timestamp_int # 使用UNIX时间戳 

556 redis_client.hset(key, mapping=data_dict) 

557 

558 # 设置过期时间(30天) 

559 redis_client.expire(key, 30 * 24 * 3600) 

560 

561 # 添加到时间索引(用于查询) 

562 time_index_key = f"stock_data:time_index:{row['symbol']}" 

563 # 使用UNIX时间戳作为score 

564 redis_client.zadd(time_index_key, {timestamp_int: timestamp_int}) 

565 

566 # 添加到股票代码集合 

567 redis_client.sadd("stock_codes", row["symbol"]) 

568 

569 saved_count += 1 

570 

571 except Exception as e: 

572 logger.warning(f"保存单条数据失败: {e}") 

573 continue 

574 

575 return { 

576 "saved_count": saved_count, 

577 "overwritten_count": overwritten_count, 

578 "new_count": new_count, 

579 } 

580 

581 except Exception as e: 

582 logger.error(f"保存数据到Redis失败: {e}") 

583 return {"saved_count": 0, "overwritten_count": 0, "new_count": 0} 

584 

585 def _convert_symbol_format(self, symbol: str) -> str: 

586 """转换股票代码格式(保持原有逻辑)""" 

587 # 简化的格式转换,根据实际需求扩展 

588 symbol_upper = symbol.upper() 

589 

590 # 添加常见的后缀 

591 if not any(symbol_upper.endswith(suffix) for suffix in [".US", ".HK", ".CN"]): 

592 # 根据symbol特征添加后缀 

593 if symbol_upper.startswith(("0", "1", "2", "3", "6")): 

594 symbol_upper += ".CN" # 中国股票 

595 elif len(symbol_upper) == 4 and symbol_upper.isdigit(): 

596 symbol_upper += ".HK" # 港股 

597 else: 

598 symbol_upper += ".US" # 美股 

599 

600 return symbol_upper 

601 

602 def get_daily_summary(self, symbol: str, date_range: str) -> Dict[str, Any]: 

603 """获取数据拉取摘要""" 

604 return { 

605 "user_id": self.user_id, 

606 "symbol": symbol, 

607 "date_range": date_range, 

608 "available": self.quote_ctx is not None, 

609 "capabilities": { 

610 "complex_pagination": True, # 保留复杂分页 

611 "morning_afternoon_split": True, # 上午下午分离 

612 "timezone_conversion": True, # 时区转换 

613 "trade_session_detection": True, # 交易时段检测 

614 "data_validation": True, # 数据验证 

615 }, 

616 }