Coverage for core/data_source/adapters/data_source_adapter.py: 49.24%
463 statements
« prev ^ index » next coverage.py v7.10.7, created at 2025-10-13 18:58 +0000
« prev ^ index » next coverage.py v7.10.7, created at 2025-10-13 18:58 +0000
1"""
2统一数据源适配器
3提供统一的抽象接口,屏蔽具体券商SDK差异
4"""
6from abc import ABC, abstractmethod
7from datetime import date, datetime
8from decimal import Decimal
9from typing import Any, Dict, List, Optional
11from core.data_source.factories.client_factory import unified_client_factory
class DataSourceAdapter(ABC):
    """Abstract base class for per-user data-source adapters.

    Stores the owning ``user_id`` and lazily caches the client returned by
    ``unified_client_factory``.
    """

    def __init__(self, user_id: str):
        self.user_id = user_id
        # Filled in on first use by _get_client().
        self._client = None

    def _get_client(self, force_refresh: bool = False) -> Optional[Dict[str, Any]]:
        """Return the data-source client, fetching it from the factory if needed.

        Args:
            force_refresh: When True, always ask the factory for a new client
                (the flag is forwarded to ``unified_client_factory.get_client``).

        Returns:
            The cached client, or whatever the factory returns (may be None).
        """
        # Reuse a truthy cached client unless a refresh was explicitly requested.
        if self._client and not force_refresh:
            return self._client
        self._client = unified_client_factory.get_client(self.user_id, force_refresh)
        return self._client

    @abstractmethod
    def is_available(self) -> bool:
        """Return whether this adapter can currently serve requests."""
        pass
class AssetDataSourceAdapter(DataSourceAdapter):
    """Asset data-source adapter: account balance and positions via ``trade_ctx``."""

    # NOTE(review): longport's FundPosition appears to expose holding_units /
    # current_net_asset_value / cost_net_asset_value rather than quantity /
    # current_price / cost_price. The secondary names are only consulted when
    # the primary attribute is missing or None — verify against the installed
    # SDK version.
    _QUANTITY_ATTRS = ("quantity", "holding_units")
    _CURRENT_PRICE_ATTRS = ("current_price", "current_net_asset_value")
    _COST_PRICE_ATTRS = ("cost_price", "cost_net_asset_value")

    def is_available(self) -> bool:
        """Return True when a client can be obtained for this user."""
        return self._get_client() is not None

    def get_account_balance(self) -> Optional[Dict[str, Any]]:
        """Fetch and parse the account balance.

        Returns:
            The dict produced by ``_parse_account_balance`` for the first balance
            entry, or None when the client / ``trade_ctx`` is unavailable, the SDK
            returns no balances, or the call raises (the error is printed).
        """
        client = self._get_client()
        if not client:
            return None

        try:
            trade_ctx = client.get("trade_ctx")
            if trade_ctx:
                balance_list = trade_ctx.account_balance()
                if balance_list:
                    # Only the first balance entry is used.
                    return self._parse_account_balance(balance_list[0])
        except Exception as e:
            print(f"❌ 获取账户余额失败: {e}")

        return None

    def get_positions(self) -> List[Dict[str, Any]]:
        """Fetch stock and fund positions as one flat list of dicts.

        Stock positions are valued with live prices from ``quote_ctx.quote``
        when available (falling back to cost price); fund positions are parsed
        without live quotes. Stocks come first, then funds.

        Returns:
            The parsed positions; an empty list when the client / ``trade_ctx``
            is unavailable or the SDK calls raise (the error is printed).
            Individual positions that fail to parse are skipped.
        """
        client = self._get_client()
        if not client:
            return []

        try:
            trade_ctx = client.get("trade_ctx")
            if not trade_ctx:
                return []

            stock_positions = trade_ctx.stock_positions()
            fund_positions = trade_ctx.fund_positions()

            # Both responses expose their positions grouped under `channels`.
            stock_items = self._flatten_channel_positions(stock_positions)
            current_prices = self._fetch_current_prices(client, stock_items)

            all_positions = []
            for pos in stock_items:
                parsed_pos = self._parse_position_with_price(
                    pos, "STOCK", current_prices
                )
                if parsed_pos:
                    all_positions.append(parsed_pos)

            for pos in self._flatten_channel_positions(fund_positions):
                parsed_pos = self._parse_position(pos, "FUND")
                if parsed_pos:
                    all_positions.append(parsed_pos)

            return all_positions
        except Exception as e:
            print(f"❌ 获取持仓信息失败: {e}")

        return []

    @staticmethod
    def _flatten_channel_positions(response) -> List[Any]:
        """Collect ``response.channels[*].positions`` into a single list.

        Missing or empty ``channels`` / ``positions`` attributes contribute nothing.
        """
        items: List[Any] = []
        channels = getattr(response, "channels", None)
        if not channels:
            return items
        for channel in channels:
            positions = getattr(channel, "positions", None)
            if positions:
                items.extend(positions)
        return items

    @staticmethod
    def _fetch_current_prices(client, positions) -> Dict[str, Any]:
        """Best-effort lookup of ``symbol -> quote.last_done`` via ``quote_ctx``.

        Failures are printed and yield whatever prices were collected so far.
        """
        # A symbol can appear in several channels; query each one once, in order.
        symbols = list(dict.fromkeys(pos.symbol for pos in positions if pos.symbol))
        current_prices: Dict[str, Any] = {}
        if not symbols:
            return current_prices

        try:
            quote_ctx = client.get("quote_ctx")
            if quote_ctx:
                for quote in quote_ctx.quote(symbols):
                    current_prices[quote.symbol] = quote.last_done
                print(f"🔍 获取到实时价格: {len(current_prices)} 个")
        except Exception as e:
            print(f"❌ 获取实时价格失败: {e}")

        return current_prices

    @staticmethod
    def _to_decimal(value, default: str = "0") -> Decimal:
        """Convert ``value`` to Decimal via ``str``; None becomes ``Decimal(default)``."""
        if value is None:
            return Decimal(default)
        return Decimal(str(value))

    @staticmethod
    def _first_present(obj, names):
        """Return the first attribute among ``names`` that exists and is not None."""
        for name in names:
            value = getattr(obj, name, None)
            if value is not None:
                return value
        return None

    @staticmethod
    def _has_any(obj, names) -> bool:
        """Return True if ``obj`` has at least one of the attributes in ``names``."""
        return any(hasattr(obj, name) for name in names)

    def _parse_account_balance(self, balance) -> Dict[str, Any]:
        """Convert an SDK balance object into a plain dict.

        ``total_cash`` is the sum of ``available_cash`` over all ``cash_infos``.
        None amounts are treated as zero; a missing currency defaults to "USD".
        """
        total_cash = Decimal("0")
        cash_details = []

        for cash_info in balance.cash_infos or []:
            available_cash = self._to_decimal(cash_info.available_cash)
            total_cash += available_cash
            cash_details.append(
                {
                    "currency": cash_info.currency or "USD",
                    "available_cash": available_cash,
                    "withdraw_cash": self._to_decimal(cash_info.withdraw_cash),
                    "frozen_cash": self._to_decimal(cash_info.frozen_cash),
                    "settling_cash": self._to_decimal(
                        cash_info.settling_cash, default="0.00"
                    ),
                }
            )

        return {
            "total_cash": total_cash,
            "net_assets": balance.net_assets or Decimal("0"),
            "currency": "USD",  # default currency
            "cash_details": cash_details,
        }

    def _parse_position(self, position, position_type: str) -> Optional[Dict[str, Any]]:
        """Parse a position without live quotes (used for fund positions).

        Numeric fields that are missing or None become ``Decimal("0")``.
        ``market_value`` is taken from ``position.market_value`` when set;
        otherwise it is quantity * current price when both kinds of attribute
        exist; otherwise None.

        Returns:
            The position dict, or None if parsing raised (the error is printed).
        """
        try:
            quantity = self._to_decimal(
                self._first_present(position, self._QUANTITY_ATTRS)
            )
            current_price = self._to_decimal(
                self._first_present(position, self._CURRENT_PRICE_ATTRS)
            )
            cost_price = self._to_decimal(
                self._first_present(position, self._COST_PRICE_ATTRS)
            )
            available_quantity = self._to_decimal(
                getattr(position, "available_quantity", None)
            )

            raw_market_value = getattr(position, "market_value", None)
            if raw_market_value is not None:
                market_value = Decimal(str(raw_market_value))
            elif self._has_any(position, self._QUANTITY_ATTRS) and self._has_any(
                position, self._CURRENT_PRICE_ATTRS
            ):
                market_value = quantity * current_price
            else:
                market_value = None

            return {
                "symbol": position.symbol,
                "symbol_name": self._get_symbol_name(position),
                "asset_type": position_type,
                "quantity": quantity,
                "available_quantity": available_quantity,
                "cost_price": cost_price,
                "current_price": current_price,
                "market_value": market_value,
                "market": str(getattr(position, "market", "UNKNOWN")),
                "currency": str(getattr(position, "currency", "USD")),
            }
        except Exception as e:
            print(f"❌ 解析持仓失败: {e}")
            return None

    def _parse_position_with_price(
        self, position, position_type: str, current_prices: Dict[str, Any]
    ) -> Optional[Dict[str, Any]]:
        """Parse a stock position, valuing it with a live price when available.

        Args:
            position: SDK position object (reads symbol, symbol_name, quantity,
                available_quantity, cost_price, and optionally market/currency).
            position_type: Value stored under ``asset_type``.
            current_prices: Mapping of symbol to last traded price.

        Returns:
            The position dict, or None if parsing raised (the error is printed).
        """
        try:
            quantity = self._to_decimal(position.quantity)
            # A missing available quantity defaults to the full quantity.
            available_quantity = (
                self._to_decimal(position.available_quantity)
                if position.available_quantity is not None
                else quantity
            )
            cost_price = self._to_decimal(position.cost_price)

            # Use the live price when one was fetched, otherwise the cost price.
            live_price = current_prices.get(position.symbol)
            current_price = (
                self._to_decimal(live_price) if live_price is not None else cost_price
            )
            market_value = current_price * quantity

            return {
                "symbol": position.symbol or "",
                "symbol_name": position.symbol_name or "",
                "asset_type": position_type,
                "quantity": quantity,
                "available_quantity": available_quantity,
                "cost_price": cost_price,
                "current_price": current_price,
                "market": str(getattr(position, "market", "UNKNOWN")),
                "currency": str(getattr(position, "currency", "USD")),
                "market_value": market_value,
            }
        except Exception as e:
            print(f"❌ 解析持仓失败: {e}")
            return None

    def _get_symbol_name(self, position) -> str:
        """Return the best available display name for a position.

        Preference order: ``symbol_name_cn``, ``symbol_name_en``,
        ``symbol_name``, then ``symbol`` (or "UNKNOWN").
        """
        for attr in ("symbol_name_cn", "symbol_name_en", "symbol_name"):
            name = getattr(position, attr, None)
            if name:
                return name
        return getattr(position, "symbol", "UNKNOWN")
class QuoteDataSourceAdapter(DataSourceAdapter):
    """Market-data adapter wrapping the client's ``quote_ctx``.

    Every method is best-effort: failures are printed and an empty result
    (``[]``, ``None``, ``{}`` or ``False``) is returned instead of raising.
    """

    def is_available(self) -> bool:
        """Return False in backtest mode, otherwise whether a client is obtainable.

        Backtest mode is detected via ``TRADING_MODE=backtest`` so that no
        LongPort client gets initialized during backtests.
        """
        import os

        if os.getenv("TRADING_MODE") == "backtest":
            return False
        return self._get_client() is not None

    def get_static_info(self, symbols: List[str]) -> List[Any]:
        """Return ``quote_ctx.static_info(symbols)`` (basic security info)."""
        client = self._get_client()
        if not client:
            return []

        try:
            quote_ctx = client.get("quote_ctx")
            if quote_ctx:
                return quote_ctx.static_info(symbols)
        except Exception as e:
            print(f"❌ 获取标的基础信息失败: {e}")

        return []

    def get_quote(self, symbols: List[str]) -> List[Any]:
        """Return ``quote_ctx.quote(symbols)`` (real-time quotes)."""
        client = self._get_client()
        if not client:
            return []

        try:
            quote_ctx = client.get("quote_ctx")
            if quote_ctx:
                return quote_ctx.quote(symbols)
        except Exception as e:
            print(f"❌ 获取实时行情失败: {e}")

        return []

    def get_depth(self, symbol: str) -> Optional[Dict[str, Any]]:
        """Return the order-book depth for ``symbol`` serialized to a dict."""
        client = self._get_client()
        if not client:
            return None

        try:
            quote_ctx = client.get("quote_ctx")
            if quote_ctx:
                return self._serialize_depth(quote_ctx.depth(symbol))
        except Exception as e:
            print(f"❌ 获取盘口信息失败: {e}")

        return None

    def get_trades(self, symbol: str, count: int) -> List[Any]:
        """Return ``quote_ctx.trades(symbol, count)`` (recent trade ticks)."""
        client = self._get_client()
        if not client:
            return []

        try:
            quote_ctx = client.get("quote_ctx")
            if quote_ctx:
                return quote_ctx.trades(symbol, count)
        except Exception as e:
            print(f"❌ 获取成交明细失败: {e}")

        return []

    def get_candlesticks(
        self,
        symbol: str,
        period: str,
        count: int,
        adjust_type: str,
        trade_sessions: str = "Intraday",
    ) -> List[Any]:
        """Return candlesticks for ``symbol``.

        Args:
            period: Short alias (``"1m"``, ``"1d"``, ...) or ``Period`` member name.
            count: Number of bars requested.
            adjust_type: ``AdjustType`` member name (e.g. ``"NoAdjust"``).
            trade_sessions: ``TradeSessions`` member name.

        Enum members may also be passed directly for the enum-typed arguments.
        An awaitable result (async context) is rejected and ``[]`` returned.
        """
        client = self._get_client()
        if not client:
            return []

        try:
            quote_ctx = client.get("quote_ctx")
            if quote_ctx:
                period_enum = self._get_period_enum(period)
                adjust_enum = self._get_adjust_type_enum(adjust_type)
                sessions_enum = self._get_trade_sessions_enum(trade_sessions)

                result = quote_ctx.candlesticks(
                    symbol, period_enum, count, adjust_enum, sessions_enum
                )

                # An async quote context would hand back a coroutine; it cannot be
                # awaited from this synchronous method.
                if hasattr(result, "__await__"):
                    print(f"【数据适配器】⚠️ K线数据获取返回协程对象,跳过: {symbol}")
                    return []

                return result
        except Exception as e:
            print(f"❌ 获取K线数据失败: {e}")

        return []

    def get_historical_data(
        self, symbol: str, start_time: datetime, end_time: datetime, log_callback=None
    ) -> List:
        """Load minute-level bars for ``symbol`` from Redis (no LongPort client).

        Bars are located through the sorted set
        ``{stock_data_prefix}time_index:{symbol}`` (scored by epoch seconds)
        and read from the hash at ``stock_repo._get_stock_key(symbol, ts)``.

        Args:
            start_time / end_time: Inclusive time range.
            log_callback: Optional ``callback(message, level)``; progress is
                reported only through it.

        Returns:
            ``MarketData`` objects sorted by timestamp; ``[]`` when nothing is
            found or on error.
        """

        def _log(message: str, level: str = "info") -> None:
            # Progress logging goes only through the callback (no print fallback).
            if log_callback:
                log_callback(message, level)

        try:
            from core.models.trading import MarketData
            from core.repositories.stock_repository import StockRepository
            from infrastructure.database.redis_client import get_redis

            _log(
                f"📊 从数据库获取分钟级历史数据: {symbol} from {start_time} to {end_time}"
            )

            # Read straight from the database; no LongPort client is created here.
            stock_repo = StockRepository(get_redis())

            _log(f"📊 一次性获取时间范围内的所有分钟级数据: {start_time} 到 {end_time}")

            start_timestamp = int(start_time.timestamp())
            end_timestamp = int(end_time.timestamp())

            time_index_key = f"{stock_repo.stock_data_prefix}time_index:{symbol}"
            index_entries = stock_repo.redis.zrangebyscore(
                time_index_key, start_timestamp, end_timestamp, withscores=True
            )

            historical_data = []
            # withscores=True yields (member, score) pairs; the score is the epoch
            # timestamp the range query above filtered on.
            for _member, score in index_entries:
                timestamp = int(score)
                data = stock_repo.redis.hgetall(
                    stock_repo._get_stock_key(symbol, timestamp)
                )
                if data:
                    # NOTE(review): assumes the Redis client decodes responses to
                    # str keys — TODO confirm.
                    historical_data.append(
                        {
                            "code": data.get("symbol", symbol),
                            "open": float(data.get("open", 0)),
                            "high": float(data.get("high", 0)),
                            "low": float(data.get("low", 0)),
                            "close": float(data.get("close", 0)),
                            "volume": int(data.get("volume", 0)),
                            "timestamp": timestamp,
                        }
                    )

            _log(f"📊 一次性获取到 {len(historical_data)} 条分钟级数据")

            if not historical_data:
                print(f"【数据适配器】⚠️ 数据库中没有找到 {symbol} 的历史数据")
                return []

            market_data_list = []
            for row in historical_data:
                # Build the datetime in the caller's timezone so the comparison
                # below never mixes naive and aware values (None -> local naive).
                row_time = datetime.fromtimestamp(row["timestamp"], tz=start_time.tzinfo)
                if start_time <= row_time <= end_time:
                    market_data_list.append(
                        MarketData(
                            symbol=symbol,
                            timestamp=row_time,
                            open=Decimal(str(row["open"])),
                            high=Decimal(str(row["high"])),
                            low=Decimal(str(row["low"])),
                            close=Decimal(str(row["close"])),
                            volume=int(row["volume"]),
                        )
                    )

            market_data_list.sort(key=lambda item: item.timestamp)

            _log(f"✅ 从数据库获取到 {len(market_data_list)} 条分钟级历史数据")
            return market_data_list

        except Exception as e:
            _log(f"❌ 获取历史数据失败: {symbol} - {e}", "error")
            return []

    def get_trading_days(
        self, market: str, begin: date, end: date
    ) -> Optional[Dict[str, Any]]:
        """Return ``quote_ctx.trading_days`` for ``market`` between the two dates."""
        client = self._get_client()
        if not client:
            return None

        try:
            quote_ctx = client.get("quote_ctx")
            if quote_ctx:
                market_enum = self._get_market_enum(market)
                return quote_ctx.trading_days(market_enum, begin, end)
        except Exception as e:
            print(f"❌ 获取交易日期失败: {e}")

        return None

    def get_trading_session(self) -> List[Any]:
        """Return ``quote_ctx.trading_session()`` (trading session schedule)."""
        client = self._get_client()
        if not client:
            return []

        try:
            quote_ctx = client.get("quote_ctx")
            if quote_ctx:
                return quote_ctx.trading_session()
        except Exception as e:
            print(f"❌ 获取交易时段失败: {e}")

        return []

    def get_calc_indexes(self, symbols: List[str], indexes: List[str]) -> List[Any]:
        """Return calculated indexes for ``symbols``.

        Index names that cannot be mapped to ``CalcIndex`` are printed and
        skipped; if none remain, ``[]`` is returned without calling the SDK.
        """
        client = self._get_client()
        if not client:
            return []

        try:
            quote_ctx = client.get("quote_ctx")
            if quote_ctx:
                calc_indexes = []
                for idx in indexes:
                    try:
                        calc_indexes.append(self._get_calc_index_enum(idx))
                    except Exception as e:
                        # Skip invalid index names instead of failing the whole call.
                        print(f"❌ 转换指标失败: {idx} - {e}")
                        continue

                if not calc_indexes:
                    print("❌ 没有有效的计算指标可查询")
                    return []

                return quote_ctx.calc_indexes(symbols, calc_indexes)
        except Exception as e:
            print(f"❌ 获取计算指标失败: {e}")
            import traceback

            traceback.print_exc()

        return []

    def subscribe(
        self, symbols: List[str], sub_types: List[str], is_first_push: bool = False
    ) -> bool:
        """Subscribe ``symbols`` to the given sub types; True on success."""
        client = self._get_client()
        if not client:
            return False

        try:
            quote_ctx = client.get("quote_ctx")
            if quote_ctx:
                sub_type_enums = [self._get_sub_type_enum(st) for st in sub_types]
                quote_ctx.subscribe(symbols, sub_type_enums, is_first_push)
                return True
        except Exception as e:
            print(f"❌ 订阅行情数据失败: {e}")

        return False

    def unsubscribe(self, symbols: List[str], sub_types: List[str]) -> bool:
        """Unsubscribe ``symbols`` from the given sub types; True on success."""
        client = self._get_client()
        if not client:
            return False

        try:
            quote_ctx = client.get("quote_ctx")
            if quote_ctx:
                sub_type_enums = [self._get_sub_type_enum(st) for st in sub_types]
                quote_ctx.unsubscribe(symbols, sub_type_enums)
                return True
        except Exception as e:
            print(f"❌ 取消订阅失败: {e}")

        return False

    def get_realtime_quote(self, symbols: List[str]) -> List[Any]:
        """Return real-time quotes, trying ``quote`` first then ``realtime_quote``."""
        client = self._get_client()
        if not client:
            print("❌ 获取实时报价失败: 客户端不可用")
            return []

        try:
            quote_ctx = client.get("quote_ctx")
            if quote_ctx:
                print(f"🔍 获取实时报价: {symbols}")
                # LongPort's realtime_quote may rely on pushed data; try quote first.
                try:
                    result = quote_ctx.quote(symbols)
                    print(
                        f"🔍 使用quote方法获取结果: {len(result) if result else 0} 条"
                    )
                    return result
                except Exception as quote_error:
                    print(f"❌ quote方法失败: {quote_error}")
                    try:
                        result = quote_ctx.realtime_quote(symbols)
                        print(
                            f"🔍 使用realtime_quote方法获取结果: {len(result) if result else 0} 条"
                        )
                        return result
                    except Exception as realtime_error:
                        print(f"❌ realtime_quote方法失败: {realtime_error}")
                        return []
            else:
                print("❌ 获取实时报价失败: quote_ctx不可用")
        except Exception as e:
            print(f"❌ 获取实时报价失败: {e}")

        return []

    def get_realtime_depth(self, symbol: str) -> Optional[Dict[str, Any]]:
        """Return ``quote_ctx.realtime_depth(symbol)`` serialized to a dict."""
        client = self._get_client()
        if not client:
            return None

        try:
            quote_ctx = client.get("quote_ctx")
            if quote_ctx:
                return self._serialize_depth(quote_ctx.realtime_depth(symbol))
        except Exception as e:
            print(f"❌ 获取实时盘口失败: {e}")

        return None

    def get_realtime_trades(self, symbol: str, count: int) -> List[Any]:
        """Return recent trades for ``symbol`` via ``quote_ctx.trades``.

        True real-time trades require a push callback; the most recent trade
        records are returned as a substitute.
        """
        client = self._get_client()
        if not client:
            return []

        try:
            quote_ctx = client.get("quote_ctx")
            if quote_ctx:
                result = quote_ctx.trades(symbol, count)

                if not result:
                    print(f"⚠️ trades 返回空结果,{symbol} 可能没有成交数据")

                return result
        except Exception as e:
            print(f"❌ 获取实时成交明细失败: {e}")
            import traceback

            traceback.print_exc()

        return []

    def get_subscription_summary(self) -> Dict[str, Any]:
        """Return ``{"subscriptions": ..., "count": n}`` or ``{}`` on failure."""
        client = self._get_client()
        if not client:
            return {}

        try:
            quote_ctx = client.get("quote_ctx")
            if quote_ctx:
                subscriptions = quote_ctx.subscriptions()
                return {
                    "subscriptions": subscriptions,
                    "count": len(subscriptions) if subscriptions else 0,
                }
        except Exception as e:
            print(f"❌ 获取订阅概览失败: {e}")

        return {}

    # ----- enum conversion helpers -----
    # Each helper accepts a name string; a non-string value is assumed to already
    # be the corresponding longport enum member and is returned unchanged.

    def _get_period_enum(self, period: str):
        """Map a period alias (``"1m"``, ``"1h"``, ...) or member name to ``Period``.

        Unknown strings fall back to ``Period.Min_1``.
        """
        if not isinstance(period, str):
            return period

        from longport.openapi import Period

        period_mapping = {
            "1m": "Min_1",
            "5m": "Min_5",
            "15m": "Min_15",
            "30m": "Min_30",
            "1h": "Min_60",
            "1d": "Day_1",
            "1w": "Week_1",
            "1M": "Month_1",
        }

        # Already a Period member name.
        if hasattr(Period, period):
            return getattr(Period, period)

        if period in period_mapping:
            return getattr(Period, period_mapping[period])

        return Period.Min_1

    def _get_adjust_type_enum(self, adjust_type: str):
        """Return ``AdjustType.<adjust_type>`` (raises AttributeError if unknown)."""
        if not isinstance(adjust_type, str):
            return adjust_type

        from longport.openapi import AdjustType

        return getattr(AdjustType, adjust_type)

    def _get_trade_sessions_enum(self, trade_sessions: str):
        """Return ``TradeSessions.<trade_sessions>`` (raises AttributeError if unknown)."""
        if not isinstance(trade_sessions, str):
            return trade_sessions

        from longport.openapi import TradeSessions

        return getattr(TradeSessions, trade_sessions)

    def _get_market_enum(self, market: str):
        """Return ``Market.<market>`` (raises AttributeError if unknown)."""
        if not isinstance(market, str):
            return market

        from longport.openapi import Market

        return getattr(Market, market)

    def _get_calc_index_enum(self, calc_index: str):
        """Map an index name to ``CalcIndex``, accepting a few alternate spellings.

        Raises:
            AttributeError: listing the available members when no match exists.
        """
        if not isinstance(calc_index, str):
            return calc_index

        from longport.openapi import CalcIndex

        if hasattr(CalcIndex, calc_index):
            return getattr(CalcIndex, calc_index)

        # Alternate spellings commonly used by callers.
        name_mapping = {
            "LastDone": "LastDone",
            "ChangeRate": "ChangeRate",
            "ChangeValue": "ChangeVal",
            "ChangeVal": "ChangeVal",
            "Volume": "Volume",
            "Turnover": "Turnover",
            "PeTTMRatio": "PeTtmRatio",
            "PbRatio": "PbRatio",
            "DividendRatioTTM": "DividendRatioTtm",
            "Amplitude": "Amplitude",
            "VolumeRatio": "VolumeRatio",
            "FiveDayChangeRate": "FiveDayChangeRate",
            "TenDayChangeRate": "TenDayChangeRate",
            "YtdChangeRate": "YtdChangeRate",
            "TurnoverRate": "TurnoverRate",
            "TotalMarketValue": "TotalMarketValue",
            "CapitalFlow": "CapitalFlow",
        }

        if calc_index in name_mapping:
            enum_name = name_mapping[calc_index]
            if hasattr(CalcIndex, enum_name):
                return getattr(CalcIndex, enum_name)

        available_enums = [attr for attr in dir(CalcIndex) if not attr.startswith("_")]
        raise AttributeError(
            f"'{calc_index}' not found in CalcIndex. Available enums: {available_enums}"
        )

    def _get_sub_type_enum(self, sub_type: str):
        """Return the ``SubType`` member for ``sub_type`` ("Trades" maps to "Trade")."""
        if not isinstance(sub_type, str):
            return sub_type

        from longport.openapi import SubType

        sub_type_mapping = {
            "Trades": "Trade",
            "Trade": "Trade",
            "Quote": "Quote",
            "Depth": "Depth",
            "Brokers": "Brokers",
        }

        return getattr(SubType, sub_type_mapping.get(sub_type, sub_type))

    def _serialize_depth(self, depth) -> Optional[Dict[str, Any]]:
        """Convert a depth object into ``{symbol, asks, bids, timestamp}``; None if falsy."""
        if not depth:
            return None

        try:
            return {
                "symbol": getattr(depth, "symbol", ""),
                "asks": getattr(depth, "asks", []),
                "bids": getattr(depth, "bids", []),
                "timestamp": getattr(depth, "timestamp", None),
            }
        except Exception as e:
            print(f"❌ 序列化盘口数据失败: {e}")
            return None
class DataImportDataSourceAdapter(DataSourceAdapter):
    """Data-import adapter; delegates market-data calls to a QuoteDataSourceAdapter."""

    def __init__(self, user_id: str):
        super().__init__(user_id)
        self._quote_adapter = QuoteDataSourceAdapter(user_id)

    def is_available(self) -> bool:
        """Return whether the underlying quote adapter is available."""
        return self._quote_adapter.is_available()

    def fetch_stock_data(
        self, symbol: str, start_date: date, end_date: date
    ) -> Dict[str, Any]:
        """Fetch up to 1000 one-minute, unadjusted, intraday candlesticks.

        NOTE(review): ``start_date`` / ``end_date`` are currently not used; the
        most recent bars returned by the SDK are fetched regardless of range.

        Returns:
            ``{"success": True, "data": [...], "data_count": n, "symbol": s}`` on
            success, otherwise ``{"success": False, "error": msg, "data_count": 0}``.
        """
        if not self.is_available():
            return {"success": False, "error": "数据源不可用", "data_count": 0}

        try:
            longport_symbol = self._convert_symbol_format(symbol)

            # Pass enum names as strings: get_candlesticks converts them to
            # longport enums itself (this module does not import those enums).
            candlesticks = (
                self._quote_adapter.get_candlesticks(
                    longport_symbol,
                    "1m",
                    1000,  # request at most 1000 bars
                    "NoAdjust",
                    "Intraday",
                )
                or []
            )

            return {
                "success": True,
                "data": candlesticks,
                "data_count": len(candlesticks),
                "symbol": longport_symbol,
            }

        except Exception as e:
            return {
                "success": False,
                "error": f"获取股票数据失败: {e}",
                "data_count": 0,
            }

    def _convert_symbol_format(self, symbol: str) -> str:
        """Convert a symbol to the LongPort format (currently returned unchanged)."""
        # TODO: implement per-broker symbol format conversion.
        return symbol
# Factory function
def create_data_source_adapter(adapter_type: str, user_id: str):
    """Instantiate the data-source adapter identified by ``adapter_type``.

    Supported types: ``"asset"``, ``"quote"``, ``"data_import"`` and ``"trade"``.

    Raises:
        ValueError: if ``adapter_type`` is not a supported type.
    """
    if adapter_type == "trade":
        # The trade adapter is imported only on demand (reason not visible here;
        # possibly to avoid an import cycle — TODO confirm).
        from core.data_source.adapters.trade_adapter import TradeDataSourceAdapter

        return TradeDataSourceAdapter(user_id)

    # Builders are lambdas so a class is only looked up when its key matches.
    for key, build in (
        ("asset", lambda: AssetDataSourceAdapter(user_id)),
        ("quote", lambda: QuoteDataSourceAdapter(user_id)),
        ("data_import", lambda: DataImportDataSourceAdapter(user_id)),
    ):
        if adapter_type == key:
            return build()

    raise ValueError(f"不支持的适配器类型: {adapter_type}")