Coverage for core/trading/strategies/base_strategy.py: 39.67%
121 statements
« prev ^ index » next coverage.py v7.10.7, created at 2025-10-13 18:58 +0000
« prev ^ index » next coverage.py v7.10.7, created at 2025-10-13 18:58 +0000
1"""
2基础策略类
3"""
5from abc import ABC, abstractmethod
6from datetime import datetime
7from decimal import Decimal
8from typing import Any, Dict, List, Optional
10from core.models.trading import (MarketData, Order, OrderSide, OrderType,
11 Portfolio, Position, RiskLimits,
12 StrategyContext)
15class BaseStrategy(ABC):
16 """策略抽象基类"""
18 def __init__(self, name: str, config: Dict[str, Any]):
19 self.name = name
20 self.config = config
21 self.is_active = False
22 self.is_initialized = False # 策略初始化状态
23 self.context: Optional[StrategyContext] = None
24 self.trading_engine = None
25 self.quote_adapter = None
26 self.strategy_engine = None # 引用StrategyEngine,用于统一日志管理
28 # 策略状态
29 self.positions: Dict[str, Position] = {}
30 self.orders: List[Order] = []
31 self.trades: List[Any] = []
33 # 性能指标
34 self.total_return = Decimal("0")
35 self.max_drawdown = Decimal("0")
36 self.win_rate = Decimal("0")
37 self.total_trades = 0
38 self.winning_trades = 0
40 @abstractmethod
41 def initialize(self, context: StrategyContext):
42 """策略初始化"""
43 pass
45 @abstractmethod
46 def on_market_data(self, market_data: MarketData):
47 """市场数据回调"""
48 pass
50 @abstractmethod
51 def on_trade_update(self, trade):
52 """交易更新回调"""
53 pass
55 # 注意:策略完全被动化,不再需要定时器回调
56 # 所有逻辑都在 on_market_data 中处理
58 def get_historical_data(
59 self, symbol: str, start_time: datetime, end_time: datetime
60 ) -> List[MarketData]:
61 """获取历史数据 - 策略需要历史数据时调用"""
62 if not self.quote_adapter:
63 return []
65 try:
66 # 通过数据适配器从数据库/缓存获取历史数据
67 # 传递日志回调函数,确保日志通过统一日志管理
68 def log_callback(message: str, log_type: str = "info"):
69 self.log_message(f"【数据适配器】{message}", log_type)
71 return self.quote_adapter.get_historical_data(
72 symbol, start_time, end_time, log_callback
73 )
74 except Exception as e:
75 self.log_message(f"❌ 获取历史数据失败: {symbol} - {e}", "error")
76 return []
78 def get_portfolio(self) -> Portfolio:
79 """获取投资组合信息"""
80 if not self.trading_engine:
81 self.log_message(f"🔧 get_portfolio: 交易引擎未设置,返回空投资组合")
82 return Portfolio(
83 total_value=Decimal("0"),
84 cash=Decimal("0"),
85 positions=[],
86 unrealized_pnl=Decimal("0"),
87 realized_pnl=Decimal("0"),
88 )
90 try:
91 # 获取持仓
92 positions = self.trading_engine.get_positions()
93 self.log_message(f"🔧 get_portfolio: 获取到 {len(positions)} 个持仓")
95 # 获取账户余额
96 balance = self.trading_engine.get_account_balance()
97 self.log_message(
98 f"🔧 get_portfolio: 账户余额: 总现金=${balance.total_cash}, 可用现金=${balance.available_cash}"
99 )
100 except Exception as e:
101 self.log_message(f"❌ get_portfolio: 获取持仓或余额失败: {e}")
102 return Portfolio(
103 total_value=Decimal("0"),
104 cash=Decimal("0"),
105 positions=[],
106 unrealized_pnl=Decimal("0"),
107 realized_pnl=Decimal("0"),
108 )
110 # 计算总价值
111 total_value = balance.total_cash
112 unrealized_pnl = Decimal("0")
114 for position in positions:
115 total_value += position.market_value
116 unrealized_pnl += position.unrealized_pnl
118 portfolio = Portfolio(
119 total_value=total_value,
120 cash=balance.available_cash,
121 positions=positions,
122 unrealized_pnl=unrealized_pnl,
123 realized_pnl=Decimal("0"), # 需要从交易历史计算
124 )
126 self.log_message(f"🔧 get_portfolio: 投资组合总价值=${portfolio.total_value}")
127 return portfolio
129 def submit_order(self, order: Order) -> bool:
130 """提交交易订单"""
131 if not self.trading_engine:
132 return False
134 result = self.trading_engine.submit_order(order)
135 # 修复:使用正确的枚举值进行比较
136 from core.models.trading import OrderStatus
138 success = result.status in [OrderStatus.FILLED, OrderStatus.PENDING]
140 if success:
141 self.orders.append(order)
143 return success
145 def get_risk_limits(self) -> RiskLimits:
146 """获取风险限制"""
147 if not self.context:
148 return RiskLimits(
149 max_position_ratio=0.1,
150 stop_loss_ratio=0.05,
151 max_drawdown=0.15,
152 allowed_symbols=[],
153 )
155 # 从上下文获取风险配置
156 risk_config = self.context.config.get("risk_config", {})
157 return RiskLimits(
158 max_position_ratio=risk_config.get("max_position_ratio", 0.1),
159 stop_loss_ratio=risk_config.get("stop_loss_ratio", 0.05),
160 max_drawdown=risk_config.get("max_drawdown", 0.15),
161 allowed_symbols=risk_config.get("allowed_symbols", []),
162 )
164 def calculate_position_size(
165 self, symbol: str, price: Decimal, risk_ratio: float = 0.1
166 ) -> Decimal:
167 """计算建议持仓数量"""
168 portfolio = self.get_portfolio()
169 risk_limits = self.get_risk_limits()
171 # 计算最大可投入资金
172 max_investment = portfolio.total_value * Decimal(
173 str(min(risk_ratio, risk_limits.max_position_ratio))
174 )
176 # 计算持仓数量
177 position_size = max_investment / price
179 return position_size
181 def check_risk_limits(
182 self, symbol: str, quantity: Decimal, price: Decimal
183 ) -> Dict[str, Any]:
184 """检查风险限制"""
185 errors = []
186 warnings = []
188 portfolio = self.get_portfolio()
189 risk_limits = self.get_risk_limits()
191 # 检查允许交易的股票代码
192 if risk_limits.allowed_symbols and symbol not in risk_limits.allowed_symbols:
193 errors.append(f"股票 {symbol} 不在允许交易列表中")
195 # 检查单只股票最大持仓比例
196 position_value = quantity * price
197 position_ratio = (
198 position_value / portfolio.total_value if portfolio.total_value > 0 else 0
199 )
201 if position_ratio > risk_limits.max_position_ratio:
202 errors.append(
203 f"持仓比例 {position_ratio:.2%} 超过最大限制 {risk_limits.max_position_ratio:.2%}"
204 )
206 # 检查资金充足性
207 if position_value > portfolio.cash:
208 errors.append("资金不足")
210 return {"valid": len(errors) == 0, "errors": errors, "warnings": warnings}
212 def update_performance_metrics(self):
213 """更新性能指标"""
214 # 这里可以实现更复杂的性能指标计算
215 # 简化版本
216 if self.total_trades > 0:
217 self.win_rate = Decimal(str(self.winning_trades)) / Decimal(
218 str(self.total_trades)
219 )
221 def get_performance_summary(self) -> Dict[str, Any]:
222 """获取性能摘要"""
223 return {
224 "strategy_name": self.name,
225 "total_return": self.total_return,
226 "max_drawdown": self.max_drawdown,
227 "win_rate": self.win_rate,
228 "total_trades": self.total_trades,
229 "winning_trades": self.winning_trades,
230 "losing_trades": self.total_trades - self.winning_trades,
231 "is_active": self.is_active,
232 }
234 def stop(self):
235 """停止策略"""
236 self.is_active = False
237 self.log_message(f"🛑 策略已停止: {self.name}")
239 def start(self):
240 """启动策略"""
241 self.is_active = True
242 self.log_message(f"▶️ 策略已启动: {self.name}")
244 def reset(self):
245 """重置策略状态"""
246 self.positions.clear()
247 self.orders.clear()
248 self.trades.clear()
249 self.total_return = Decimal("0")
250 self.max_drawdown = Decimal("0")
251 self.win_rate = Decimal("0")
252 self.total_trades = 0
253 self.winning_trades = 0
254 self.log_message(f"🔄 策略已重置: {self.name}")
256 def log_message(self, message: str, log_type: str = "log"):
257 """统一日志接口 - 通过StrategyEngine记录日志"""
258 if self.strategy_engine:
259 self.strategy_engine.log_message(message, log_type, f"策略_{self.name}")
260 else:
261 # 如果没有StrategyEngine引用,使用print作为备用
262 print(f"📝 [{self.name}] {message}")