Coverage for core/trading/strategies/base_strategy.py: 39.67%

121 statements  

« prev     ^ index     » next       coverage.py v7.10.7, created at 2025-10-13 18:58 +0000

1""" 

2基础策略类 

3""" 

4 

5from abc import ABC, abstractmethod 

6from datetime import datetime 

7from decimal import Decimal 

8from typing import Any, Dict, List, Optional 

9 

10from core.models.trading import (MarketData, Order, OrderSide, OrderType, 

11 Portfolio, Position, RiskLimits, 

12 StrategyContext) 

13 

14 

class BaseStrategy(ABC):
    """Abstract base class for trading strategies.

    Subclasses implement ``initialize``, ``on_market_data`` and
    ``on_trade_update``. The base class keeps the shared runtime state
    (positions, orders, trades, performance counters) and offers helpers for
    historical data, portfolio snapshots, order submission, risk checks and
    logging. The collaborators ``trading_engine``, ``quote_adapter`` and
    ``strategy_engine`` start as ``None`` and are expected to be assigned
    from outside this class.
    """

    def __init__(self, name: str, config: Dict[str, Any]):
        """Store the strategy identity and initialise all runtime state.

        Args:
            name: Strategy name; used in log messages and the performance
                summary.
            config: Strategy configuration, stored unchanged on
                ``self.config``.
        """
        self.name = name
        self.config = config
        self.is_active = False
        self.is_initialized = False  # strategy initialisation state
        self.context: Optional[StrategyContext] = None
        self.trading_engine = None
        self.quote_adapter = None
        # Reference to the StrategyEngine, used for unified log management.
        self.strategy_engine = None

        # Strategy state
        self.positions: Dict[str, Position] = {}
        self.orders: List[Order] = []
        self.trades: List[Any] = []

        # Performance metrics
        self.total_return = Decimal("0")
        self.max_drawdown = Decimal("0")
        self.win_rate = Decimal("0")
        self.total_trades = 0
        self.winning_trades = 0

    @abstractmethod
    def initialize(self, context: StrategyContext):
        """Initialise the strategy with the given context."""

    @abstractmethod
    def on_market_data(self, market_data: MarketData):
        """Handle an incoming market data update."""

    @abstractmethod
    def on_trade_update(self, trade):
        """Handle a trade update."""

    # NOTE: strategies are fully passive; timer callbacks are no longer
    # needed. All logic is handled in on_market_data.

    def get_historical_data(
        self, symbol: str, start_time: datetime, end_time: datetime
    ) -> List[MarketData]:
        """Fetch historical data for ``symbol`` through the quote adapter.

        Args:
            symbol: Instrument symbol to query.
            start_time: Start of the requested range.
            end_time: End of the requested range.

        Returns:
            Whatever ``quote_adapter.get_historical_data`` returns, or an
            empty list when no adapter is set or the adapter raises.
        """
        if not self.quote_adapter:
            return []

        # Passed to the adapter so that its log output goes through the
        # strategy's unified logging, tagged as coming from the adapter.
        def log_callback(message: str, log_type: str = "info"):
            self.log_message(f"【数据适配器】{message}", log_type)

        try:
            return self.quote_adapter.get_historical_data(
                symbol, start_time, end_time, log_callback
            )
        except Exception as e:
            # Best effort: a data failure is logged and reported as "no data".
            self.log_message(f"❌ 获取历史数据失败: {symbol} - {e}", "error")
            return []

    @staticmethod
    def _empty_portfolio() -> Portfolio:
        """Build the all-zero portfolio returned when no real data is available."""
        zero = Decimal("0")
        return Portfolio(
            total_value=zero,
            cash=zero,
            positions=[],
            unrealized_pnl=zero,
            realized_pnl=zero,
        )

    def get_portfolio(self) -> Portfolio:
        """Build a portfolio snapshot from the trading engine.

        Returns:
            A ``Portfolio`` whose ``total_value`` is the engine's total cash
            plus the market value of every position, ``cash`` is the
            available cash and ``unrealized_pnl`` is the sum over positions.
            ``realized_pnl`` is always zero here. An all-zero portfolio is
            returned when no trading engine is set or when fetching positions
            or balance raises.
        """
        if not self.trading_engine:
            self.log_message("🔧 get_portfolio: 交易引擎未设置,返回空投资组合")
            return self._empty_portfolio()

        try:
            # Fetch positions
            positions = self.trading_engine.get_positions()
            self.log_message(f"🔧 get_portfolio: 获取到 {len(positions)} 个持仓")

            # Fetch account balance
            balance = self.trading_engine.get_account_balance()
            self.log_message(
                f"🔧 get_portfolio: 账户余额: 总现金=${balance.total_cash}, 可用现金=${balance.available_cash}"
            )
        except Exception as e:
            # Logged at "error" level, like the failure path of
            # get_historical_data.
            self.log_message(f"❌ get_portfolio: 获取持仓或余额失败: {e}", "error")
            return self._empty_portfolio()

        # Total value = total cash + market value of all positions
        total_value = balance.total_cash
        unrealized_pnl = Decimal("0")
        for position in positions:
            total_value += position.market_value
            unrealized_pnl += position.unrealized_pnl

        portfolio = Portfolio(
            total_value=total_value,
            cash=balance.available_cash,
            positions=positions,
            unrealized_pnl=unrealized_pnl,
            realized_pnl=Decimal("0"),  # needs to be computed from trade history
        )

        self.log_message(f"🔧 get_portfolio: 投资组合总价值=${portfolio.total_value}")
        return portfolio

    def submit_order(self, order: Order) -> bool:
        """Submit an order through the trading engine.

        Args:
            order: The order to submit.

        Returns:
            ``True`` when the engine reports the order as FILLED or PENDING,
            in which case the order is also appended to ``self.orders``;
            ``False`` otherwise or when no trading engine is set. Exceptions
            raised by the engine are not caught here.
        """
        if not self.trading_engine:
            return False

        result = self.trading_engine.submit_order(order)
        # Compare against the enum members rather than raw values.
        from core.models.trading import OrderStatus

        success = result.status in (OrderStatus.FILLED, OrderStatus.PENDING)
        if success:
            self.orders.append(order)

        return success

    def get_risk_limits(self) -> RiskLimits:
        """Return the risk limits for this strategy.

        Values come from ``self.context.config["risk_config"]`` when a
        context is set; any missing key, or a missing context, falls back to
        the defaults 0.1 (max position ratio), 0.05 (stop loss ratio),
        0.15 (max drawdown) and an empty allowed-symbols list.
        """
        # Without a context every lookup below falls through to its default.
        risk_config = (
            self.context.config.get("risk_config", {}) if self.context else {}
        )
        return RiskLimits(
            max_position_ratio=risk_config.get("max_position_ratio", 0.1),
            stop_loss_ratio=risk_config.get("stop_loss_ratio", 0.05),
            max_drawdown=risk_config.get("max_drawdown", 0.15),
            allowed_symbols=risk_config.get("allowed_symbols", []),
        )

    def calculate_position_size(
        self, symbol: str, price: Decimal, risk_ratio: float = 0.1
    ) -> Decimal:
        """Compute a suggested position size for ``symbol``.

        Args:
            symbol: Instrument symbol; only used in the log message emitted
                for an invalid price.
            price: Price per unit.
            risk_ratio: Fraction of total portfolio value to commit; capped
                by the configured ``max_position_ratio``.

        Returns:
            ``total_value * min(risk_ratio, max_position_ratio) / price``, or
            ``Decimal("0")`` when ``price`` is zero or negative.
        """
        # A non-positive price would either raise on division or produce a
        # negative size, so it is rejected before touching the engine.
        if price <= 0:
            self.log_message(
                f"❌ calculate_position_size: 无效价格 {price} ({symbol}),返回 0",
                "error",
            )
            return Decimal("0")

        portfolio = self.get_portfolio()
        risk_limits = self.get_risk_limits()

        # Maximum capital that may be committed. The ratio is converted via
        # str() so the float does not leak binary noise into the Decimal.
        effective_ratio = min(risk_ratio, risk_limits.max_position_ratio)
        max_investment = portfolio.total_value * Decimal(str(effective_ratio))

        return max_investment / price

    def check_risk_limits(
        self, symbol: str, quantity: Decimal, price: Decimal
    ) -> Dict[str, Any]:
        """Check a prospective position against the risk limits.

        Args:
            symbol: Instrument symbol.
            quantity: Number of units.
            price: Price per unit.

        Returns:
            A dict with ``"valid"`` (``True`` when no errors were found),
            ``"errors"`` (list of messages) and ``"warnings"`` (always empty
            in this implementation).
        """
        errors = []
        warnings = []

        portfolio = self.get_portfolio()
        risk_limits = self.get_risk_limits()

        # Symbol whitelist; an empty whitelist allows every symbol.
        if risk_limits.allowed_symbols and symbol not in risk_limits.allowed_symbols:
            errors.append(f"股票 {symbol} 不在允许交易列表中")

        # Maximum share of the portfolio a single position may take.
        position_value = quantity * price
        if portfolio.total_value > 0:
            position_ratio = position_value / portfolio.total_value
        else:
            # No portfolio value to compare against; the cash check below
            # still rejects anything unaffordable.
            position_ratio = Decimal("0")

        if position_ratio > risk_limits.max_position_ratio:
            errors.append(
                f"持仓比例 {position_ratio:.2%} 超过最大限制 {risk_limits.max_position_ratio:.2%}"
            )

        # Sufficient cash
        if position_value > portfolio.cash:
            errors.append("资金不足")

        return {"valid": not errors, "errors": errors, "warnings": warnings}

    def update_performance_metrics(self):
        """Recompute ``win_rate`` from the trade counters.

        Simplified implementation: only the win rate is updated, and it is
        left unchanged when no trades have been recorded.
        """
        if self.total_trades > 0:
            self.win_rate = Decimal(str(self.winning_trades)) / Decimal(
                str(self.total_trades)
            )

    def get_performance_summary(self) -> Dict[str, Any]:
        """Return the current performance figures as a plain dict."""
        return {
            "strategy_name": self.name,
            "total_return": self.total_return,
            "max_drawdown": self.max_drawdown,
            "win_rate": self.win_rate,
            "total_trades": self.total_trades,
            "winning_trades": self.winning_trades,
            "losing_trades": self.total_trades - self.winning_trades,
            "is_active": self.is_active,
        }

    def stop(self):
        """Mark the strategy as inactive and log it."""
        self.is_active = False
        self.log_message(f"🛑 策略已停止: {self.name}")

    def start(self):
        """Mark the strategy as active and log it."""
        self.is_active = True
        self.log_message(f"▶️ 策略已启动: {self.name}")

    def reset(self):
        """Clear positions, orders and trades and zero all performance metrics."""
        self.positions.clear()
        self.orders.clear()
        self.trades.clear()
        self.total_return = Decimal("0")
        self.max_drawdown = Decimal("0")
        self.win_rate = Decimal("0")
        self.total_trades = 0
        self.winning_trades = 0
        self.log_message(f"🔄 策略已重置: {self.name}")

    def log_message(self, message: str, log_type: str = "log"):
        """Log through the StrategyEngine, falling back to ``print``.

        Args:
            message: Text to log.
            log_type: Log category forwarded to the engine; ignored by the
                ``print`` fallback.
        """
        if self.strategy_engine:
            self.strategy_engine.log_message(message, log_type, f"策略_{self.name}")
        else:
            # No StrategyEngine reference available: use print as a fallback.
            print(f"📝 [{self.name}] {message}")