Coverage for core/data_source/adapters/data_source_adapter.py: 49.24%

463 statements  

« prev     ^ index     » next       coverage.py v7.10.7, created at 2025-10-13 18:58 +0000

1""" 

2统一数据源适配器 

3提供统一的抽象接口,屏蔽具体券商SDK差异 

4""" 

5 

6from abc import ABC, abstractmethod 

7from datetime import date, datetime 

8from decimal import Decimal 

9from typing import Any, Dict, List, Optional 

10 

11from core.data_source.factories.client_factory import unified_client_factory 

12 

13 

class DataSourceAdapter(ABC):
    """Abstract base class for data source adapters.

    Holds the user id and lazily caches the client object obtained from
    ``unified_client_factory``.
    """

    def __init__(self, user_id: str):
        """Remember ``user_id``; the client is fetched on first use."""
        self._client = None
        self.user_id = user_id

    def _get_client(self, force_refresh: bool = False) -> Optional[Dict[str, Any]]:
        """Return the cached client, fetching it from the factory if needed.

        Args:
            force_refresh: When true, always ask the factory again (the flag
                is forwarded to the factory as well).

        Returns:
            Whatever the factory returned; a falsy value is not treated as
            cached and triggers another factory call next time.
        """
        must_fetch = force_refresh or not self._client
        if must_fetch:
            self._client = unified_client_factory.get_client(
                self.user_id, force_refresh
            )
        return self._client

    @abstractmethod
    def is_available(self) -> bool:
        """Report whether this adapter can currently be used."""

33 

34 

class AssetDataSourceAdapter(DataSourceAdapter):
    """Adapter for account balance and position data."""

    def is_available(self) -> bool:
        """Return True when a client could be obtained."""
        return self._get_client() is not None

    def get_account_balance(self) -> Optional[Dict[str, Any]]:
        """Fetch the account balance and return it as a plain dict.

        Returns:
            The parsed first entry of ``trade_ctx.account_balance()``, or
            ``None`` when the client/trade context is missing, the list is
            empty, or any error occurs (the error is printed).
        """
        client = self._get_client()
        if not client:
            return None

        try:
            trade_ctx = client.get("trade_ctx")
            if trade_ctx:
                balance_list = trade_ctx.account_balance()
                if balance_list:
                    # Only the first balance entry is used.
                    return self._parse_account_balance(balance_list[0])
        except Exception as e:
            print(f"❌ 获取账户余额失败: {e}")

        return None

    def get_positions(self) -> List[Dict[str, Any]]:
        """Fetch stock and fund positions as a list of plain dicts.

        Stock positions are priced with live quotes when a quote context is
        available (falling back to the cost price); fund positions are parsed
        from the fields on the position object itself.

        Returns:
            The parsed positions, or ``[]`` when the client/trade context is
            missing or any error occurs (the error is printed).
        """
        client = self._get_client()
        if not client:
            return []

        try:
            trade_ctx = client.get("trade_ctx")
            if trade_ctx:
                stock_positions = trade_ctx.stock_positions()
                fund_positions = trade_ctx.fund_positions()
                print(f"🔍 基金持仓原始数据: {fund_positions}")

                # Materialise once: the same positions are needed both for
                # the quote lookup and for parsing.
                stock_items = list(self._iter_channel_positions(stock_positions))
                symbols = [pos.symbol for pos in stock_items if pos.symbol]

                # Live prices are best-effort; a failure here must not
                # prevent the positions themselves from being returned.
                current_prices = {}
                if symbols:
                    try:
                        quote_ctx = client.get("quote_ctx")
                        if quote_ctx:
                            for quote in quote_ctx.quote(symbols):
                                current_prices[quote.symbol] = quote.last_done
                            print(f"🔍 获取到实时价格: {len(current_prices)}")
                    except Exception as e:
                        print(f"❌ 获取实时价格失败: {e}")

                all_positions = []
                for pos in stock_items:
                    parsed_pos = self._parse_position_with_price(
                        pos, "STOCK", current_prices
                    )
                    if parsed_pos:
                        all_positions.append(parsed_pos)

                for pos in self._iter_channel_positions(fund_positions):
                    parsed_pos = self._parse_position(pos, "FUND")
                    if parsed_pos:
                        all_positions.append(parsed_pos)

                return all_positions
        except Exception as e:
            print(f"❌ 获取持仓信息失败: {e}")

        return []

    @staticmethod
    def _iter_channel_positions(response):
        """Yield every position under ``response.channels[*].positions``.

        Missing or empty ``channels`` / ``positions`` attributes are treated
        as "no positions".
        """
        for channel in getattr(response, "channels", None) or []:
            yield from getattr(channel, "positions", None) or []

    @staticmethod
    def _to_decimal(value, default: Decimal = Decimal("0")) -> Decimal:
        """Coerce ``value`` to ``Decimal``; ``None`` becomes ``default``.

        Going through ``str`` avoids ``TypeError`` when a float is later
        combined with a ``Decimal`` in arithmetic.
        """
        if value is None:
            return default
        return value if isinstance(value, Decimal) else Decimal(str(value))

    def _parse_account_balance(self, balance) -> Dict[str, Any]:
        """Convert an SDK balance object into a dict.

        ``total_cash`` is the sum of ``available_cash`` over all entries in
        ``balance.cash_infos``. The top-level ``currency`` is always "USD".
        """
        total_cash = Decimal("0")
        cash_details = []

        for cash_info in balance.cash_infos or []:
            cash_amount = self._to_decimal(cash_info.available_cash)
            total_cash += cash_amount

            # Kept as "0.00" (not "0") when missing, matching prior output.
            settling_cash = (
                "0.00"
                if cash_info.settling_cash is None
                else str(cash_info.settling_cash)
            )

            cash_details.append(
                {
                    "currency": cash_info.currency or "USD",
                    "available_cash": cash_amount,
                    "withdraw_cash": self._to_decimal(cash_info.withdraw_cash),
                    "frozen_cash": self._to_decimal(cash_info.frozen_cash),
                    "settling_cash": Decimal(settling_cash),
                }
            )

        return {
            "total_cash": total_cash,
            "net_assets": balance.net_assets or Decimal("0"),
            "currency": "USD",  # default currency
            "cash_details": cash_details,
        }

    def _parse_position(self, position, position_type: str) -> Optional[Dict[str, Any]]:
        """Convert a position object into a dict using only its own fields.

        ``market_value`` comes from the position when present; otherwise it is
        ``quantity * current_price`` if both attributes exist, else ``None``.
        Numeric fields are normalised to ``Decimal`` so that mixed float /
        Decimal inputs no longer raise during the multiplication.

        Returns:
            The parsed dict, or ``None`` if parsing fails (error is printed).
        """
        try:
            quantity = self._to_decimal(getattr(position, "quantity", None))
            available_quantity = self._to_decimal(
                getattr(position, "available_quantity", None)
            )
            current_price = self._to_decimal(
                getattr(position, "current_price", None)
            )
            cost_price = self._to_decimal(getattr(position, "cost_price", None))

            market_value = None
            if getattr(position, "market_value", None) is not None:
                market_value = Decimal(str(position.market_value))
            elif hasattr(position, "quantity") and hasattr(position, "current_price"):
                market_value = quantity * current_price

            return {
                "symbol": position.symbol,
                "symbol_name": self._get_symbol_name(position),
                "asset_type": position_type,
                "quantity": quantity,
                "available_quantity": available_quantity,
                "cost_price": cost_price,
                "current_price": current_price,
                "market_value": market_value,
                "market": str(getattr(position, "market", "UNKNOWN")),
                "currency": str(getattr(position, "currency", "USD")),
            }
        except Exception as e:
            print(f"❌ 解析持仓失败: {e}")
            return None

    def _parse_position_with_price(
        self, position, position_type: str, current_prices: Dict[str, Any]
    ) -> Optional[Dict[str, Any]]:
        """Convert a position object into a dict, priced from ``current_prices``.

        Args:
            position: Object exposing ``symbol``, ``symbol_name``,
                ``quantity``, ``available_quantity`` and ``cost_price``.
            position_type: Value stored under ``asset_type``.
            current_prices: Mapping of symbol to latest price. When the symbol
                is absent (or its price is ``None``) the cost price is used.

        Returns:
            The parsed dict, or ``None`` if parsing fails (error is printed).
        """
        try:
            quantity = self._to_decimal(position.quantity)
            # An unknown available quantity defaults to the full quantity.
            available_quantity = self._to_decimal(
                position.available_quantity, default=quantity
            )
            cost_price = self._to_decimal(position.cost_price)

            # Normalise to Decimal before multiplying: a float price times a
            # Decimal quantity raises TypeError, which previously caused the
            # whole position to be dropped.
            current_price = self._to_decimal(
                current_prices.get(position.symbol), default=cost_price
            )
            market_value = current_price * quantity

            return {
                "symbol": position.symbol or "",
                "symbol_name": position.symbol_name or "",
                "asset_type": position_type,
                "quantity": quantity,
                "available_quantity": available_quantity,
                "cost_price": cost_price,
                "current_price": current_price,
                "market": str(getattr(position, "market", "UNKNOWN")),
                "currency": str(getattr(position, "currency", "USD")),
                "market_value": market_value,
            }
        except Exception as e:
            print(f"❌ 解析持仓失败: {e}")
            return None

    def _get_symbol_name(self, position) -> str:
        """Return the first non-empty display name, else the symbol.

        Preference order: ``symbol_name_cn``, ``symbol_name_en``,
        ``symbol_name``; falls back to ``symbol`` or "UNKNOWN".
        """
        for attr in ("symbol_name_cn", "symbol_name_en", "symbol_name"):
            name = getattr(position, attr, None)
            if name:
                return name
        return getattr(position, "symbol", "UNKNOWN")

286 

287 

class QuoteDataSourceAdapter(DataSourceAdapter):
    """Adapter for market quote data.

    Most methods delegate to the ``quote_ctx`` entry of the client dict and
    return an empty value (``[]``, ``None``, ``{}`` or ``False``) when the
    client or context is missing or the call raises; errors are printed.
    """

    def is_available(self) -> bool:
        """Return True when a client is available and not in backtest mode."""
        # In backtest mode (TRADING_MODE=backtest) the client is never
        # initialised, so the adapter reports itself as unavailable.
        import os

        if os.getenv("TRADING_MODE") == "backtest":
            return False
        return self._get_client() is not None

    def get_static_info(self, symbols: List[str]) -> List[Any]:
        """Return ``quote_ctx.static_info(symbols)``, or ``[]`` on failure."""
        client = self._get_client()
        if not client:
            return []

        try:
            quote_ctx = client.get("quote_ctx")
            if quote_ctx:
                return quote_ctx.static_info(symbols)
        except Exception as e:
            print(f"❌ 获取标的基础信息失败: {e}")

        return []

    def get_quote(self, symbols: List[str]) -> List[Any]:
        """Return ``quote_ctx.quote(symbols)``, or ``[]`` on failure."""
        client = self._get_client()
        if not client:
            return []

        try:
            quote_ctx = client.get("quote_ctx")
            if quote_ctx:
                return quote_ctx.quote(symbols)
        except Exception as e:
            print(f"❌ 获取实时行情失败: {e}")

        return []

    def get_depth(self, symbol: str) -> Optional[Dict[str, Any]]:
        """Return the serialized order book for ``symbol``, or ``None``."""
        client = self._get_client()
        if not client:
            return None

        try:
            quote_ctx = client.get("quote_ctx")
            if quote_ctx:
                depth = quote_ctx.depth(symbol)
                return self._serialize_depth(depth)
        except Exception as e:
            print(f"❌ 获取盘口信息失败: {e}")

        return None

    def get_trades(self, symbol: str, count: int) -> List[Any]:
        """Return ``quote_ctx.trades(symbol, count)``, or ``[]`` on failure."""
        client = self._get_client()
        if not client:
            return []

        try:
            quote_ctx = client.get("quote_ctx")
            if quote_ctx:
                return quote_ctx.trades(symbol, count)
        except Exception as e:
            print(f"❌ 获取成交明细失败: {e}")

        return []

    def get_candlesticks(
        self,
        symbol: str,
        period: str,
        count: int,
        adjust_type: str,
        trade_sessions: str = "Intraday",
    ) -> List[Any]:
        """Fetch candlestick data for ``symbol``.

        Args:
            symbol: Instrument code passed straight to the SDK.
            period: Period name, either an SDK enum member name or one of the
                short aliases handled by ``_get_period_enum``.
            count: Number of candlesticks requested.
            adjust_type: ``AdjustType`` member name.
            trade_sessions: ``TradeSessions`` member name.

        Returns:
            The SDK result, or ``[]`` on failure or when the SDK returned an
            awaitable instead of data.
        """
        client = self._get_client()
        if not client:
            return []

        try:
            quote_ctx = client.get("quote_ctx")
            if quote_ctx:
                # Translate the string arguments into SDK enum members.
                period_enum = self._get_period_enum(period)
                adjust_enum = self._get_adjust_type_enum(adjust_type)
                sessions_enum = self._get_trade_sessions_enum(trade_sessions)

                result = quote_ctx.candlesticks(
                    symbol, period_enum, count, adjust_enum, sessions_enum
                )

                # An awaitable cannot be consumed from this synchronous
                # method, so it is skipped rather than returned.
                if hasattr(result, "__await__"):
                    print(f"【数据适配器】⚠️ K线数据获取返回协程对象,跳过: {symbol}")
                    return []

                return result
        except Exception as e:
            print(f"❌ 获取K线数据失败: {e}")

        return []

    def get_historical_data(
        self, symbol: str, start_time: datetime, end_time: datetime, log_callback=None
    ) -> List:
        """Load minute-level history for ``symbol`` from Redis.

        The brokerage client is not initialised; data is read through
        ``StockRepository`` from a sorted-set time index and per-bar hashes.

        Args:
            symbol: Instrument code used to build the Redis keys.
            start_time: Inclusive lower bound.
            end_time: Inclusive upper bound.
            log_callback: Optional ``callable(message, level)`` for progress
                and error messages.

        Returns:
            ``MarketData`` objects sorted by timestamp, or ``[]`` when nothing
            is found or any error occurs.
        """
        try:
            from core.models.trading import MarketData
            from core.repositories.stock_repository import StockRepository
            from infrastructure.database.redis_client import get_redis

            log_message = f"📊 从数据库获取分钟级历史数据: {symbol} from {start_time} to {end_time}"
            if log_callback:
                log_callback(log_message, "info")

            redis_client = get_redis()
            stock_repo = StockRepository(redis_client)

            log_message = (
                f"📊 一次性获取时间范围内的所有分钟级数据: {start_time}{end_time}"
            )
            if log_callback:
                log_callback(log_message, "info")

            start_timestamp = int(start_time.timestamp())
            end_timestamp = int(end_time.timestamp())

            time_index_key = f"{stock_repo.stock_data_prefix}time_index:{symbol}"
            # With withscores=True each item is a (member, score) pair; the
            # member is what gets used as the bar timestamp below.
            # NOTE(review): presumably members are epoch seconds - confirm
            # against the writer side of the repository.
            timestamps = stock_repo.redis.zrangebyscore(
                time_index_key, start_timestamp, end_timestamp, withscores=True
            )

            historical_data = []
            for timestamp_score, _ in timestamps:
                stock_key = stock_repo._get_stock_key(symbol, int(timestamp_score))
                data = stock_repo.redis.hgetall(stock_key)
                if data:
                    historical_data.append(
                        {
                            "code": data.get("symbol", symbol),
                            "open": float(data.get("open", 0)),
                            "high": float(data.get("high", 0)),
                            "low": float(data.get("low", 0)),
                            "close": float(data.get("close", 0)),
                            "volume": int(data.get("volume", 0)),
                            "timestamp": int(timestamp_score),
                        }
                    )

            log_message = f"📊 一次性获取到 {len(historical_data)} 条分钟级数据"
            if log_callback:
                log_callback(log_message, "info")

            if not historical_data:
                print(f"【数据适配器】⚠️ 数据库中没有找到 {symbol} 的历史数据")
                return []

            market_data_list = []
            for row in historical_data:
                # Rows are the plain dicts built above, so fields must be read
                # by key. Attribute access here used to raise AttributeError,
                # which the outer handler swallowed, yielding [] every time.
                data_timestamp = datetime.fromtimestamp(row["timestamp"])

                if start_time <= data_timestamp <= end_time:
                    market_data_list.append(
                        MarketData(
                            symbol=symbol,
                            timestamp=data_timestamp,
                            open=Decimal(str(row["open"])),
                            high=Decimal(str(row["high"])),
                            low=Decimal(str(row["low"])),
                            close=Decimal(str(row["close"])),
                            volume=int(row["volume"]),
                        )
                    )

            market_data_list.sort(key=lambda x: x.timestamp)

            log_message = f"✅ 从数据库获取到 {len(market_data_list)} 条分钟级历史数据"
            if log_callback:
                log_callback(log_message, "info")
            return market_data_list

        except Exception as e:
            log_message = f"❌ 获取历史数据失败: {symbol} - {e}"
            if log_callback:
                log_callback(log_message, "error")
            return []

    def get_trading_days(
        self, market: str, begin: date, end: date
    ) -> Optional[Dict[str, Any]]:
        """Return the SDK trading-days result for ``market``, or ``None``."""
        client = self._get_client()
        if not client:
            return None

        try:
            quote_ctx = client.get("quote_ctx")
            if quote_ctx:
                market_enum = self._get_market_enum(market)
                return quote_ctx.trading_days(market_enum, begin, end)
        except Exception as e:
            print(f"❌ 获取交易日期失败: {e}")

        return None

    def get_trading_session(self) -> List[Any]:
        """Return ``quote_ctx.trading_session()``, or ``[]`` on failure."""
        client = self._get_client()
        if not client:
            return []

        try:
            quote_ctx = client.get("quote_ctx")
            if quote_ctx:
                return quote_ctx.trading_session()
        except Exception as e:
            print(f"❌ 获取交易时段失败: {e}")

        return []

    def get_calc_indexes(self, symbols: List[str], indexes: List[str]) -> List[Any]:
        """Query calculated indexes for ``symbols``.

        Index names that cannot be mapped to a ``CalcIndex`` member are
        reported and skipped; if none remain, ``[]`` is returned without
        calling the SDK.
        """
        client = self._get_client()
        if not client:
            return []

        try:
            quote_ctx = client.get("quote_ctx")
            if quote_ctx:
                calc_indexes = []
                for idx in indexes:
                    try:
                        calc_indexes.append(self._get_calc_index_enum(idx))
                    except Exception as e:
                        # Skip invalid index names instead of failing the call.
                        print(f"❌ 转换指标失败: {idx} - {e}")
                        continue

                if not calc_indexes:
                    print("❌ 没有有效的计算指标可查询")
                    return []

                return quote_ctx.calc_indexes(symbols, calc_indexes)
        except Exception as e:
            print(f"❌ 获取计算指标失败: {e}")
            import traceback

            traceback.print_exc()

        return []

    def subscribe(
        self, symbols: List[str], sub_types: List[str], is_first_push: bool = False
    ) -> bool:
        """Subscribe to quote pushes; return True when the SDK call succeeded."""
        client = self._get_client()
        if not client:
            return False

        try:
            quote_ctx = client.get("quote_ctx")
            if quote_ctx:
                sub_type_enums = [self._get_sub_type_enum(st) for st in sub_types]
                quote_ctx.subscribe(symbols, sub_type_enums, is_first_push)
                return True
        except Exception as e:
            print(f"❌ 订阅行情数据失败: {e}")

        return False

    def unsubscribe(self, symbols: List[str], sub_types: List[str]) -> bool:
        """Cancel subscriptions; return True when the SDK call succeeded."""
        client = self._get_client()
        if not client:
            return False

        try:
            quote_ctx = client.get("quote_ctx")
            if quote_ctx:
                sub_type_enums = [self._get_sub_type_enum(st) for st in sub_types]
                quote_ctx.unsubscribe(symbols, sub_type_enums)
                return True
        except Exception as e:
            print(f"❌ 取消订阅失败: {e}")

        return False

    def get_realtime_quote(self, symbols: List[str]) -> List[Any]:
        """Return live quotes for ``symbols``.

        Tries ``quote_ctx.quote`` first and falls back to
        ``quote_ctx.realtime_quote`` if that raises; returns ``[]`` when both
        fail or no client/context is available.
        """
        client = self._get_client()
        if not client:
            print(f"❌ 获取实时报价失败: 客户端不可用")
            return []

        try:
            quote_ctx = client.get("quote_ctx")
            if quote_ctx:
                print(f"🔍 获取实时报价: {symbols}")
                # realtime_quote may depend on push callbacks, so the plain
                # quote call is attempted first.
                try:
                    result = quote_ctx.quote(symbols)
                    print(
                        f"🔍 使用quote方法获取结果: {len(result) if result else 0}"
                    )
                    return result
                except Exception as quote_error:
                    print(f"❌ quote方法失败: {quote_error}")
                    try:
                        result = quote_ctx.realtime_quote(symbols)
                        print(
                            f"🔍 使用realtime_quote方法获取结果: {len(result) if result else 0}"
                        )
                        return result
                    except Exception as realtime_error:
                        print(f"❌ realtime_quote方法失败: {realtime_error}")
                        return []
            else:
                print(f"❌ 获取实时报价失败: quote_ctx不可用")
        except Exception as e:
            print(f"❌ 获取实时报价失败: {e}")

        return []

    def get_realtime_depth(self, symbol: str) -> Optional[Dict[str, Any]]:
        """Return the serialized realtime order book for ``symbol``, or ``None``."""
        client = self._get_client()
        if not client:
            return None

        try:
            quote_ctx = client.get("quote_ctx")
            if quote_ctx:
                depth = quote_ctx.realtime_depth(symbol)
                return self._serialize_depth(depth)
        except Exception as e:
            print(f"❌ 获取实时盘口失败: {e}")

        return None

    def get_realtime_trades(self, symbol: str, count: int) -> List[Any]:
        """Return recent trades for ``symbol``.

        Uses ``quote_ctx.trades`` as a substitute for a callback-based
        realtime feed. An empty SDK result is returned as-is after a warning.
        """
        client = self._get_client()
        if not client:
            return []

        try:
            quote_ctx = client.get("quote_ctx")
            if quote_ctx:
                result = quote_ctx.trades(symbol, count)

                if not result:
                    print(f"⚠️ trades 返回空结果,{symbol} 可能没有成交数据")

                return result
        except Exception as e:
            print(f"❌ 获取实时成交明细失败: {e}")
            import traceback

            traceback.print_exc()

        return []

    def get_subscription_summary(self) -> Dict[str, Any]:
        """Return current subscriptions and their count, or ``{}`` on failure."""
        client = self._get_client()
        if not client:
            return {}

        try:
            quote_ctx = client.get("quote_ctx")
            if quote_ctx:
                subscriptions = quote_ctx.subscriptions()
                return {
                    "subscriptions": subscriptions,
                    "count": len(subscriptions) if subscriptions else 0,
                }
        except Exception as e:
            print(f"❌ 获取订阅概览失败: {e}")

        return {}

    # Enum conversion helpers

    def _get_period_enum(self, period: str):
        """Map a period string to a ``Period`` member.

        Accepts either a member name or a short alias such as "1m" / "1d";
        anything unrecognised falls back to ``Period.Min_1``.
        """
        from longport.openapi import Period

        period_mapping = {
            "1m": "Min_1",
            "5m": "Min_5",
            "15m": "Min_15",
            "30m": "Min_30",
            "1h": "Min_60",
            "1d": "Day_1",
            "1w": "Week_1",
            "1M": "Month_1",
        }

        # A value that already names an enum member is used directly.
        if hasattr(Period, period):
            return getattr(Period, period)

        if period in period_mapping:
            return getattr(Period, period_mapping[period])

        return Period.Min_1

    def _get_adjust_type_enum(self, adjust_type: str):
        """Return the ``AdjustType`` member named ``adjust_type``."""
        from longport.openapi import AdjustType

        return getattr(AdjustType, adjust_type)

    def _get_trade_sessions_enum(self, trade_sessions: str):
        """Return the ``TradeSessions`` member named ``trade_sessions``."""
        from longport.openapi import TradeSessions

        return getattr(TradeSessions, trade_sessions)

    def _get_market_enum(self, market: str):
        """Return the ``Market`` member named ``market``."""
        from longport.openapi import Market

        return getattr(Market, market)

    def _get_calc_index_enum(self, calc_index: str):
        """Map an index name to a ``CalcIndex`` member.

        Raises:
            AttributeError: If the name matches neither a member nor a known
                alias; the message lists the available member names.
        """
        from longport.openapi import CalcIndex

        if hasattr(CalcIndex, calc_index):
            return getattr(CalcIndex, calc_index)

        # Aliases for commonly used spellings that differ from the SDK names.
        name_mapping = {
            "LastDone": "LastDone",
            "ChangeRate": "ChangeRate",
            "ChangeValue": "ChangeVal",
            "ChangeVal": "ChangeVal",
            "Volume": "Volume",
            "Turnover": "Turnover",
            "PeTTMRatio": "PeTtmRatio",
            "PbRatio": "PbRatio",
            "DividendRatioTTM": "DividendRatioTtm",
            "Amplitude": "Amplitude",
            "VolumeRatio": "VolumeRatio",
            "FiveDayChangeRate": "FiveDayChangeRate",
            "TenDayChangeRate": "TenDayChangeRate",
            "YtdChangeRate": "YtdChangeRate",
            "TurnoverRate": "TurnoverRate",
            "TotalMarketValue": "TotalMarketValue",
            "CapitalFlow": "CapitalFlow",
        }

        if calc_index in name_mapping:
            enum_name = name_mapping[calc_index]
            if hasattr(CalcIndex, enum_name):
                return getattr(CalcIndex, enum_name)

        available_enums = [attr for attr in dir(CalcIndex) if not attr.startswith("_")]
        raise AttributeError(
            f"'{calc_index}' not found in CalcIndex. Available enums: {available_enums}"
        )

    def _get_sub_type_enum(self, sub_type: str):
        """Return the ``SubType`` member for ``sub_type`` ("Trades" -> "Trade")."""
        from longport.openapi import SubType

        sub_type_mapping = {
            "Trades": "Trade",  # Trades -> Trade
            "Trade": "Trade",
            "Quote": "Quote",
            "Depth": "Depth",
            "Brokers": "Brokers",
        }

        mapped_type = sub_type_mapping.get(sub_type, sub_type)
        return getattr(SubType, mapped_type)

    def _serialize_depth(self, depth) -> Optional[Dict[str, Any]]:
        """Copy symbol/asks/bids/timestamp from ``depth`` into a dict.

        Returns ``None`` for a falsy ``depth`` or if reading an attribute
        raises (the error is printed).
        """
        if not depth:
            return None

        try:
            return {
                "symbol": getattr(depth, "symbol", ""),
                "asks": getattr(depth, "asks", []),
                "bids": getattr(depth, "bids", []),
                "timestamp": getattr(depth, "timestamp", None),
            }
        except Exception as e:
            print(f"❌ 序列化盘口数据失败: {e}")
            return None

810 

811 

class DataImportDataSourceAdapter(DataSourceAdapter):
    """Adapter used by data import; wraps a ``QuoteDataSourceAdapter``."""

    def __init__(self, user_id: str):
        """Create the adapter and its underlying quote adapter."""
        super().__init__(user_id)
        self._quote_adapter = QuoteDataSourceAdapter(user_id)

    def is_available(self) -> bool:
        """Return the availability of the underlying quote adapter."""
        return self._quote_adapter.is_available()

    def fetch_stock_data(
        self, symbol: str, start_date: date, end_date: date
    ) -> Dict[str, Any]:
        """Fetch up to 1000 one-minute candlesticks for ``symbol``.

        Args:
            symbol: Instrument code; passed through ``_convert_symbol_format``.
            start_date: Accepted for interface compatibility; not currently
                used to restrict the query.
            end_date: Accepted for interface compatibility; not currently
                used to restrict the query.

        Returns:
            ``{"success": True, "data", "data_count", "symbol"}`` on success,
            or ``{"success": False, "error", "data_count": 0}`` when the data
            source is unavailable or an exception is raised. Because
            ``get_candlesticks`` returns ``[]`` on its own failures, a
            successful result may still carry zero rows.
        """
        if not self.is_available():
            return {"success": False, "error": "数据源不可用", "data_count": 0}

        try:
            longport_symbol = self._convert_symbol_format(symbol)

            # get_candlesticks takes enum *names* and resolves them itself.
            # The enum classes are not imported in this module, so referring
            # to them here raised NameError and every call reported failure.
            candlesticks = self._quote_adapter.get_candlesticks(
                longport_symbol,
                "Min_1",
                1000,  # at most 1000 rows
                "NoAdjust",
                "Intraday",
            )

            return {
                "success": True,
                "data": candlesticks,
                "data_count": len(candlesticks),
                "symbol": longport_symbol,
            }

        except Exception as e:
            return {
                "success": False,
                "error": f"获取股票数据失败: {e}",
                "data_count": 0,
            }

    def _convert_symbol_format(self, symbol: str) -> str:
        """Return ``symbol`` unchanged (conversion is not implemented yet)."""
        # TODO: implement symbol format conversion between brokers
        return symbol

861 

862 

863# 工厂方法 

def create_data_source_adapter(adapter_type: str, user_id: str):
    """Build the adapter registered under ``adapter_type`` for ``user_id``.

    Supported types are "asset", "quote", "data_import" and "trade" (the
    trade adapter module is imported only when requested).

    Raises:
        ValueError: If ``adapter_type`` is not one of the supported names.
    """
    if adapter_type == "asset":
        return AssetDataSourceAdapter(user_id)

    if adapter_type == "quote":
        return QuoteDataSourceAdapter(user_id)

    if adapter_type == "data_import":
        return DataImportDataSourceAdapter(user_id)

    if adapter_type == "trade":
        from core.data_source.adapters.trade_adapter import (
            TradeDataSourceAdapter,
        )

        return TradeDataSourceAdapter(user_id)

    raise ValueError(f"不支持的适配器类型: {adapter_type}")