Coverage for core/trading/strategies/macd_strategy.py: 28.21%

234 statements  

« prev     ^ index     » next       coverage.py v7.10.7, created at 2025-10-13 18:58 +0000

1""" 

2MACD策略 

3""" 

4 

5from datetime import datetime 

6from decimal import Decimal 

7from typing import Any, Dict, List, Optional 

8 

9from core.models.trading import (MarketData, Order, OrderSide, OrderType, 

10 StrategyContext) 

11 

12from .base_strategy import BaseStrategy 

13 

14 

15class MACDStrategy(BaseStrategy): 

16 """MACD策略 - 移动平均收敛散度策略""" 

17 

18 def __init__(self, name: str, config: Dict[str, Any]): 

19 super().__init__(name, config) 

20 

21 # MACD参数 

22 self.fast_period = config.get("fast_period", 5) # 快速EMA周期 (5分钟) 

23 self.slow_period = config.get("slow_period", 13) # 慢速EMA周期 (13分钟) 

24 self.signal_period = config.get("signal_period", 4) # 信号线周期 (4分钟) 

25 self.position_size = config.get("position_size", 0.04) # 仓位大小 

26 

27 # 价格历史 - 动态管理,不预设symbols 

28 self.price_history: Dict[str, List[Decimal]] = {} 

29 self.ema_fast: Dict[str, List[Decimal]] = {} 

30 self.ema_slow: Dict[str, List[Decimal]] = {} 

31 self.macd_line: Dict[str, List[Decimal]] = {} 

32 self.signal_line: Dict[str, List[Decimal]] = {} 

33 self.histogram: Dict[str, List[Decimal]] = {} 

34 self.last_signals: Dict[str, str] = {} # "buy", "sell", "hold" 

35 

36 def initialize(self, context: StrategyContext): 

37 """策略初始化""" 

38 self.context = context 

39 self.trading_engine = context.trading_engine 

40 self.quote_adapter = context.quote_adapter 

41 

42 self.log_message(f"✅ MACD策略已初始化: {self.name}") 

43 self.log_message(f" 快速EMA周期: {self.fast_period}") 

44 self.log_message(f" 慢速EMA周期: {self.slow_period}") 

45 self.log_message(f" 信号线周期: {self.signal_period}") 

46 self.log_message(f" 仓位大小: {self.position_size}") 

47 

48 # 标记策略初始化完成 

49 self.is_initialized = True 

50 

51 def _load_historical_data(self): 

52 """加载历史数据用于技术指标计算""" 

53 try: 

54 from datetime import timedelta 

55 

56 # 计算需要的历史数据时间范围 

57 # MACD需要慢速EMA周期 + 信号线周期的历史数据 

58 days_needed = max(self.slow_period + self.signal_period, 60) # 至少60天 

59 

60 # 获取回测开始时间,避免未来函数 

61 if self.context and hasattr(self.context, "config"): 

62 config = self.context.config 

63 # 从配置中获取回测开始时间 

64 if "backtest_start_time" in config: 

65 backtest_start_time = config["backtest_start_time"] 

66 if isinstance(backtest_start_time, str): 

67 from datetime import datetime 

68 

69 backtest_start_time = datetime.fromisoformat( 

70 backtest_start_time.replace("Z", "+00:00") 

71 ) 

72 end_time = backtest_start_time 

73 start_time = end_time - timedelta(days=days_needed) 

74 self.log_message( 

75 f"📊 MACD策略加载历史数据: {days_needed}天 (回测模式,结束时间: {end_time})" 

76 ) 

77 else: 

78 # 如果没有回测时间配置,使用当前时间(实时模式) 

79 end_time = datetime.now() 

80 start_time = end_time - timedelta(days=days_needed) 

81 self.log_message( 

82 f"📊 MACD策略加载历史数据: {days_needed}天 (实时模式)" 

83 ) 

84 else: 

85 # 默认使用当前时间 

86 end_time = datetime.now() 

87 start_time = end_time - timedelta(days=days_needed) 

88 self.log_message(f"📊 MACD策略加载历史数据: {days_needed}天 (默认模式)") 

89 

90 # 从策略上下文获取可交易股票列表 

91 if not self.context or not hasattr(self.context, "trading_engine"): 

92 self.log_message(f"⚠️ 策略上下文未初始化,跳过历史数据加载") 

93 return 

94 

95 # 获取可交易股票列表 - 从策略上下文配置获取 

96 tradable_symbols = [] 

97 if self.context and hasattr(self.context, "config"): 

98 # 从配置中获取可交易股票 

99 config = self.context.config 

100 if "tradable_symbols" in config: 

101 tradable_symbols = config["tradable_symbols"] 

102 elif ( 

103 "risk_config" in config 

104 and "allowed_symbols" in config["risk_config"] 

105 ): 

106 tradable_symbols = config["risk_config"]["allowed_symbols"] 

107 

108 if not tradable_symbols: 

109 self.log_message(f"⚠️ 未获取到可交易股票列表,跳过历史数据加载") 

110 return 

111 

112 for symbol in tradable_symbols: 

113 # 获取历史数据 

114 historical_data = self.get_historical_data(symbol, start_time, end_time) 

115 

116 if historical_data: 

117 # 初始化价格历史 

118 self.price_history[symbol] = [ 

119 data.close for data in historical_data 

120 ] 

121 self.log_message( 

122 f"✅ 加载 {symbol} 历史数据: {len(historical_data)}" 

123 ) 

124 

125 # 预计算技术指标 

126 self._calculate_initial_indicators(symbol) 

127 else: 

128 self.log_message(f"⚠️ 未找到 {symbol} 历史数据,等待实时数据") 

129 self.price_history[symbol] = [] 

130 

131 except Exception as e: 

132 self.log_message(f"❌ 加载历史数据失败: {e}") 

133 

134 def _calculate_initial_indicators(self, symbol: str): 

135 """计算初始技术指标""" 

136 try: 

137 if len(self.price_history[symbol]) < self.slow_period + self.signal_period: 

138 return 

139 

140 # 计算EMA 

141 prices = self.price_history[symbol] 

142 self.ema_fast[symbol] = self._calculate_ema(prices, self.fast_period) 

143 self.ema_slow[symbol] = self._calculate_ema(prices, self.slow_period) 

144 

145 # 计算MACD 

146 self._calculate_macd(symbol) 

147 

148 self.log_message(f"✅ 预计算 {symbol} 技术指标完成") 

149 

150 except Exception as e: 

151 self.log_message(f"❌ 计算初始技术指标失败: {symbol} - {e}") 

152 

153 def on_market_data(self, market_data: MarketData): 

154 """市场数据回调""" 

155 symbol = market_data.symbol 

156 

157 # 动态初始化股票数据 

158 if symbol not in self.price_history: 

159 self.price_history[symbol] = [] 

160 self.ema_fast[symbol] = [] 

161 self.ema_slow[symbol] = [] 

162 self.macd_line[symbol] = [] 

163 self.signal_line[symbol] = [] 

164 self.histogram[symbol] = [] 

165 self.last_signals[symbol] = "hold" 

166 self.log_message(f"🆕 初始化股票数据: {symbol}") 

167 

168 # 更新价格历史 

169 self._update_price_history(market_data) 

170 

171 # 检查是否足够的历史数据 

172 current_history_length = len(self.price_history[market_data.symbol]) 

173 required_length = self.slow_period + self.signal_period 

174 

175 self.log_message( 

176 f"📊 MACD数据检查: {market_data.symbol} 当前历史数据: {current_history_length}条, 需要: {required_length}" 

177 ) 

178 

179 if current_history_length < required_length: 

180 self.log_message(f"⏳ MACD数据不足,跳过信号生成: {market_data.symbol}") 

181 return 

182 

183 # 计算MACD 

184 self.log_message(f"🔢 开始计算MACD指标: {market_data.symbol}") 

185 macd_data = self._calculate_macd(market_data.symbol) 

186 if macd_data is None: 

187 self.log_message(f"❌ MACD计算失败: {market_data.symbol}") 

188 return 

189 

190 self.log_message( 

191 f"📈 MACD计算结果: {market_data.symbol} MACD={macd_data['macd']:.4f}, Signal={macd_data['signal']:.4f}, Histogram={macd_data['histogram']:.4f}" 

192 ) 

193 

194 # 生成交易信号 

195 signal = self._generate_signal(market_data.symbol, macd_data) 

196 self.log_message(f"🎯 生成交易信号: {market_data.symbol} = {signal}") 

197 

198 # 执行交易 

199 if signal != "hold": 

200 self._execute_signal(market_data.symbol, signal, market_data.close) 

201 

202 # 检查止损(策略完全被动化,在每次市场数据更新时检查) 

203 self._check_stop_loss() 

204 

205 def on_trade_update(self, trade): 

206 """交易更新回调""" 

207 self.log_message( 

208 f"📊 MACD策略交易更新: {trade.symbol} {trade.side} {trade.quantity} @ {trade.price}" 

209 ) 

210 self.trades.append(trade) 

211 self.total_trades += 1 

212 

213 # 更新胜率统计 

214 if trade.side == OrderSide.SELL: 

215 # 简化:假设卖出盈利 

216 self.winning_trades += 1 

217 

218 # 注意:策略完全被动化,不再需要定时器回调 

219 # 止损检查在 on_market_data 中处理 

220 

221 def _update_price_history(self, market_data: MarketData): 

222 """更新价格历史""" 

223 symbol = market_data.symbol 

224 if symbol not in self.price_history: 

225 self.price_history[symbol] = [] 

226 

227 # 检查是否已经存在这个时间点的数据(避免重复添加) 

228 # 由于我们使用分钟级数据,简单检查最后一个价格是否相同 

229 if ( 

230 self.price_history[symbol] 

231 and self.price_history[symbol][-1] == market_data.close 

232 ): 

233 # 价格相同,可能是重复数据,不添加 

234 return 

235 

236 self.price_history[symbol].append(market_data.close) 

237 

238 # 保持历史数据长度 

239 max_history = self.slow_period * 3 

240 if len(self.price_history[symbol]) > max_history: 

241 self.price_history[symbol] = self.price_history[symbol][-max_history:] 

242 

243 def _calculate_ema(self, prices: List[Decimal], period: int) -> List[Decimal]: 

244 """计算指数移动平均线""" 

245 if len(prices) < period: 

246 return [] 

247 

248 ema_values = [] 

249 multiplier = Decimal("2") / Decimal(str(period + 1)) 

250 

251 # 第一个EMA值使用SMA 

252 sma = sum(prices[:period]) / Decimal(str(period)) 

253 ema_values.append(sma) 

254 

255 # 计算后续EMA值 

256 for i in range(period, len(prices)): 

257 ema = (prices[i] * multiplier) + ( 

258 ema_values[-1] * (Decimal("1") - multiplier) 

259 ) 

260 ema_values.append(ema) 

261 

262 return ema_values 

263 

264 def _calculate_macd(self, symbol: str) -> Optional[Dict[str, Decimal]]: 

265 """计算MACD指标""" 

266 if symbol not in self.price_history: 

267 return None 

268 

269 prices = self.price_history[symbol] 

270 if len(prices) < self.slow_period + self.signal_period: 

271 return None 

272 

273 # 计算EMA 

274 ema_fast_values = self._calculate_ema(prices, self.fast_period) 

275 ema_slow_values = self._calculate_ema(prices, self.slow_period) 

276 

277 # EMA值数量检查:EMA值数量 = len(prices) - period + 1 

278 # 我们需要至少signal_period个EMA值来计算信号线 

279 if ( 

280 len(ema_fast_values) < self.signal_period 

281 or len(ema_slow_values) < self.signal_period 

282 ): 

283 self.log_message( 

284 f"❌ EMA值不足: 快速EMA={len(ema_fast_values)}条, 慢速EMA={len(ema_slow_values)}条, 需要={self.signal_period}" 

285 ) 

286 return None 

287 

288 # 计算MACD线 

289 macd_line_values = [] 

290 min_length = min(len(ema_fast_values), len(ema_slow_values)) 

291 

292 for i in range(min_length): 

293 macd = ema_fast_values[i] - ema_slow_values[i] 

294 macd_line_values.append(macd) 

295 

296 if len(macd_line_values) < self.signal_period: 

297 self.log_message( 

298 f"❌ MACD线值不足: {len(macd_line_values)}条, 需要={self.signal_period}" 

299 ) 

300 return None 

301 

302 # 计算信号线 

303 signal_line_values = self._calculate_ema(macd_line_values, self.signal_period) 

304 

305 if not signal_line_values: 

306 return None 

307 

308 # 计算柱状图 

309 histogram_values = [] 

310 for i in range(len(signal_line_values)): 

311 hist = macd_line_values[i + self.signal_period - 1] - signal_line_values[i] 

312 histogram_values.append(hist) 

313 

314 # 更新历史数据 

315 self.ema_fast[symbol] = ema_fast_values 

316 self.ema_slow[symbol] = ema_slow_values 

317 self.macd_line[symbol] = macd_line_values 

318 self.signal_line[symbol] = signal_line_values 

319 self.histogram[symbol] = histogram_values 

320 

321 # 返回最新值 

322 return { 

323 "macd": macd_line_values[-1], 

324 "signal": signal_line_values[-1], 

325 "histogram": histogram_values[-1] if histogram_values else Decimal("0"), 

326 } 

327 

328 def _generate_signal(self, symbol: str, macd_data: Dict[str, Decimal]) -> str: 

329 """生成交易信号""" 

330 last_signal = self.last_signals.get(symbol, "hold") 

331 

332 macd = macd_data["macd"] 

333 signal = macd_data["signal"] 

334 histogram = macd_data["histogram"] 

335 

336 # MACD金叉:MACD线上穿信号线 

337 if macd > signal and histogram > 0 and last_signal != "buy": 

338 self.last_signals[symbol] = "buy" 

339 return "buy" 

340 # MACD死叉:MACD线下穿信号线 

341 elif macd < signal and histogram < 0 and last_signal != "sell": 

342 self.last_signals[symbol] = "sell" 

343 return "sell" 

344 else: 

345 return "hold" 

346 

347 def _execute_signal(self, symbol: str, signal: str, current_price: Decimal): 

348 """执行交易信号""" 

349 self.log_message(f"🔧 开始执行交易信号: {symbol} {signal} @ {current_price}") 

350 

351 if not self.trading_engine: 

352 self.log_message(f"❌ 交易引擎未设置,无法执行信号: {symbol} {signal}") 

353 return 

354 

355 portfolio = self.get_portfolio() 

356 self.log_message( 

357 f"🔧 获取投资组合信息: 总价值=${portfolio.total_value}, 现金=${portfolio.cash}" 

358 ) 

359 

360 if signal == "buy": 

361 # 买入逻辑 

362 # 计算买入数量 

363 position_size = self.calculate_position_size( 

364 symbol, current_price, self.position_size 

365 ) 

366 self.log_message(f"🔧 计算买入数量: {position_size}") 

367 

368 # 使用实际买入数量检查风险限制 

369 if position_size > 0: 

370 risk_check = self.check_risk_limits( 

371 symbol, position_size, current_price 

372 ) 

373 if not risk_check["valid"]: 

374 error_details = ", ".join(risk_check["errors"]) 

375 self.log_message( 

376 f"❌ MACD买入信号被风险控制阻止: {error_details}", "warn" 

377 ) 

378 return 

379 

380 # 创建订单 

381 order = Order( 

382 symbol=symbol, 

383 side=OrderSide.BUY, 

384 order_type=OrderType.MARKET, 

385 quantity=position_size, 

386 price=None, 

387 session_id=self.context.trading_session_id, 

388 ) 

389 

390 success = self.submit_order(order) 

391 if success: 

392 self.log_message( 

393 f"🟢 MACD买入信号执行: {symbol} {position_size} @ {current_price}", 

394 "info", 

395 ) 

396 else: 

397 self.log_message(f"❌ MACD买入订单失败: {symbol}", "error") 

398 

399 elif signal == "sell": 

400 # 卖出逻辑 

401 # 查找当前持仓 

402 current_position = None 

403 for position in portfolio.positions: 

404 if position.symbol == symbol: 

405 current_position = position 

406 break 

407 

408 self.log_message( 

409 f"🔧 查找当前持仓: {symbol} = {current_position.quantity if current_position else 0}" 

410 ) 

411 

412 if current_position and current_position.quantity > 0: 

413 # 卖出全部持仓 

414 order = Order( 

415 symbol=symbol, 

416 side=OrderSide.SELL, 

417 order_type=OrderType.MARKET, 

418 quantity=current_position.quantity, 

419 price=None, 

420 session_id=self.context.trading_session_id, 

421 ) 

422 

423 success = self.submit_order(order) 

424 if success: 

425 self.log_message( 

426 f"🔴 MACD卖出信号执行: {symbol} {current_position.quantity} @ {current_price}", 

427 "info", 

428 ) 

429 else: 

430 self.log_message(f"❌ MACD卖出订单失败: {symbol}", "error") 

431 

432 def _check_stop_loss(self): 

433 """检查止损""" 

434 portfolio = self.get_portfolio() 

435 risk_limits = self.get_risk_limits() 

436 

437 for position in portfolio.positions: 

438 if position.unrealized_pnl < 0: 

439 # 计算亏损比例 

440 loss_ratio = abs(position.unrealized_pnl) / ( 

441 position.avg_price * position.quantity 

442 ) 

443 

444 if loss_ratio >= risk_limits.stop_loss_ratio: 

445 # 触发止损 

446 order = Order( 

447 symbol=position.symbol, 

448 side=OrderSide.SELL, 

449 order_type=OrderType.MARKET, 

450 quantity=position.quantity, 

451 price=None, 

452 session_id=self.context.trading_session_id, 

453 ) 

454 

455 success = self.submit_order(order) 

456 if success: 

457 self.log_message( 

458 f"🛑 MACD止损触发: {position.symbol} {position.quantity}", 

459 "warn", 

460 ) 

461 

462 def get_macd_data(self, symbol: str) -> Optional[Dict[str, Decimal]]: 

463 """获取当前MACD数据""" 

464 if ( 

465 symbol in self.macd_line 

466 and self.macd_line[symbol] 

467 and symbol in self.signal_line 

468 and self.signal_line[symbol] 

469 and symbol in self.histogram 

470 and self.histogram[symbol] 

471 ): 

472 

473 return { 

474 "macd": self.macd_line[symbol][-1], 

475 "signal": self.signal_line[symbol][-1], 

476 "histogram": self.histogram[symbol][-1], 

477 } 

478 return None 

479 

480 def get_macd_history(self, symbol: str) -> Dict[str, List[Decimal]]: 

481 """获取MACD历史数据""" 

482 return { 

483 "macd": self.macd_line.get(symbol, []), 

484 "signal": self.signal_line.get(symbol, []), 

485 "histogram": self.histogram.get(symbol, []), 

486 }