Coverage for core/trading/strategies/macd_strategy.py: 28.21%
234 statements
« prev ^ index » next coverage.py v7.10.7, created at 2025-10-13 18:58 +0000
« prev ^ index » next coverage.py v7.10.7, created at 2025-10-13 18:58 +0000
1"""
2MACD策略
3"""
5from datetime import datetime
6from decimal import Decimal
7from typing import Any, Dict, List, Optional
9from core.models.trading import (MarketData, Order, OrderSide, OrderType,
10 StrategyContext)
12from .base_strategy import BaseStrategy
15class MACDStrategy(BaseStrategy):
16 """MACD策略 - 移动平均收敛散度策略"""
18 def __init__(self, name: str, config: Dict[str, Any]):
19 super().__init__(name, config)
21 # MACD参数
22 self.fast_period = config.get("fast_period", 5) # 快速EMA周期 (5分钟)
23 self.slow_period = config.get("slow_period", 13) # 慢速EMA周期 (13分钟)
24 self.signal_period = config.get("signal_period", 4) # 信号线周期 (4分钟)
25 self.position_size = config.get("position_size", 0.04) # 仓位大小
27 # 价格历史 - 动态管理,不预设symbols
28 self.price_history: Dict[str, List[Decimal]] = {}
29 self.ema_fast: Dict[str, List[Decimal]] = {}
30 self.ema_slow: Dict[str, List[Decimal]] = {}
31 self.macd_line: Dict[str, List[Decimal]] = {}
32 self.signal_line: Dict[str, List[Decimal]] = {}
33 self.histogram: Dict[str, List[Decimal]] = {}
34 self.last_signals: Dict[str, str] = {} # "buy", "sell", "hold"
36 def initialize(self, context: StrategyContext):
37 """策略初始化"""
38 self.context = context
39 self.trading_engine = context.trading_engine
40 self.quote_adapter = context.quote_adapter
42 self.log_message(f"✅ MACD策略已初始化: {self.name}")
43 self.log_message(f" 快速EMA周期: {self.fast_period}")
44 self.log_message(f" 慢速EMA周期: {self.slow_period}")
45 self.log_message(f" 信号线周期: {self.signal_period}")
46 self.log_message(f" 仓位大小: {self.position_size}")
48 # 标记策略初始化完成
49 self.is_initialized = True
51 def _load_historical_data(self):
52 """加载历史数据用于技术指标计算"""
53 try:
54 from datetime import timedelta
56 # 计算需要的历史数据时间范围
57 # MACD需要慢速EMA周期 + 信号线周期的历史数据
58 days_needed = max(self.slow_period + self.signal_period, 60) # 至少60天
60 # 获取回测开始时间,避免未来函数
61 if self.context and hasattr(self.context, "config"):
62 config = self.context.config
63 # 从配置中获取回测开始时间
64 if "backtest_start_time" in config:
65 backtest_start_time = config["backtest_start_time"]
66 if isinstance(backtest_start_time, str):
67 from datetime import datetime
69 backtest_start_time = datetime.fromisoformat(
70 backtest_start_time.replace("Z", "+00:00")
71 )
72 end_time = backtest_start_time
73 start_time = end_time - timedelta(days=days_needed)
74 self.log_message(
75 f"📊 MACD策略加载历史数据: {days_needed}天 (回测模式,结束时间: {end_time})"
76 )
77 else:
78 # 如果没有回测时间配置,使用当前时间(实时模式)
79 end_time = datetime.now()
80 start_time = end_time - timedelta(days=days_needed)
81 self.log_message(
82 f"📊 MACD策略加载历史数据: {days_needed}天 (实时模式)"
83 )
84 else:
85 # 默认使用当前时间
86 end_time = datetime.now()
87 start_time = end_time - timedelta(days=days_needed)
88 self.log_message(f"📊 MACD策略加载历史数据: {days_needed}天 (默认模式)")
90 # 从策略上下文获取可交易股票列表
91 if not self.context or not hasattr(self.context, "trading_engine"):
92 self.log_message(f"⚠️ 策略上下文未初始化,跳过历史数据加载")
93 return
95 # 获取可交易股票列表 - 从策略上下文配置获取
96 tradable_symbols = []
97 if self.context and hasattr(self.context, "config"):
98 # 从配置中获取可交易股票
99 config = self.context.config
100 if "tradable_symbols" in config:
101 tradable_symbols = config["tradable_symbols"]
102 elif (
103 "risk_config" in config
104 and "allowed_symbols" in config["risk_config"]
105 ):
106 tradable_symbols = config["risk_config"]["allowed_symbols"]
108 if not tradable_symbols:
109 self.log_message(f"⚠️ 未获取到可交易股票列表,跳过历史数据加载")
110 return
112 for symbol in tradable_symbols:
113 # 获取历史数据
114 historical_data = self.get_historical_data(symbol, start_time, end_time)
116 if historical_data:
117 # 初始化价格历史
118 self.price_history[symbol] = [
119 data.close for data in historical_data
120 ]
121 self.log_message(
122 f"✅ 加载 {symbol} 历史数据: {len(historical_data)} 条"
123 )
125 # 预计算技术指标
126 self._calculate_initial_indicators(symbol)
127 else:
128 self.log_message(f"⚠️ 未找到 {symbol} 历史数据,等待实时数据")
129 self.price_history[symbol] = []
131 except Exception as e:
132 self.log_message(f"❌ 加载历史数据失败: {e}")
134 def _calculate_initial_indicators(self, symbol: str):
135 """计算初始技术指标"""
136 try:
137 if len(self.price_history[symbol]) < self.slow_period + self.signal_period:
138 return
140 # 计算EMA
141 prices = self.price_history[symbol]
142 self.ema_fast[symbol] = self._calculate_ema(prices, self.fast_period)
143 self.ema_slow[symbol] = self._calculate_ema(prices, self.slow_period)
145 # 计算MACD
146 self._calculate_macd(symbol)
148 self.log_message(f"✅ 预计算 {symbol} 技术指标完成")
150 except Exception as e:
151 self.log_message(f"❌ 计算初始技术指标失败: {symbol} - {e}")
153 def on_market_data(self, market_data: MarketData):
154 """市场数据回调"""
155 symbol = market_data.symbol
157 # 动态初始化股票数据
158 if symbol not in self.price_history:
159 self.price_history[symbol] = []
160 self.ema_fast[symbol] = []
161 self.ema_slow[symbol] = []
162 self.macd_line[symbol] = []
163 self.signal_line[symbol] = []
164 self.histogram[symbol] = []
165 self.last_signals[symbol] = "hold"
166 self.log_message(f"🆕 初始化股票数据: {symbol}")
168 # 更新价格历史
169 self._update_price_history(market_data)
171 # 检查是否足够的历史数据
172 current_history_length = len(self.price_history[market_data.symbol])
173 required_length = self.slow_period + self.signal_period
175 self.log_message(
176 f"📊 MACD数据检查: {market_data.symbol} 当前历史数据: {current_history_length}条, 需要: {required_length}条"
177 )
179 if current_history_length < required_length:
180 self.log_message(f"⏳ MACD数据不足,跳过信号生成: {market_data.symbol}")
181 return
183 # 计算MACD
184 self.log_message(f"🔢 开始计算MACD指标: {market_data.symbol}")
185 macd_data = self._calculate_macd(market_data.symbol)
186 if macd_data is None:
187 self.log_message(f"❌ MACD计算失败: {market_data.symbol}")
188 return
190 self.log_message(
191 f"📈 MACD计算结果: {market_data.symbol} MACD={macd_data['macd']:.4f}, Signal={macd_data['signal']:.4f}, Histogram={macd_data['histogram']:.4f}"
192 )
194 # 生成交易信号
195 signal = self._generate_signal(market_data.symbol, macd_data)
196 self.log_message(f"🎯 生成交易信号: {market_data.symbol} = {signal}")
198 # 执行交易
199 if signal != "hold":
200 self._execute_signal(market_data.symbol, signal, market_data.close)
202 # 检查止损(策略完全被动化,在每次市场数据更新时检查)
203 self._check_stop_loss()
205 def on_trade_update(self, trade):
206 """交易更新回调"""
207 self.log_message(
208 f"📊 MACD策略交易更新: {trade.symbol} {trade.side} {trade.quantity} @ {trade.price}"
209 )
210 self.trades.append(trade)
211 self.total_trades += 1
213 # 更新胜率统计
214 if trade.side == OrderSide.SELL:
215 # 简化:假设卖出盈利
216 self.winning_trades += 1
218 # 注意:策略完全被动化,不再需要定时器回调
219 # 止损检查在 on_market_data 中处理
221 def _update_price_history(self, market_data: MarketData):
222 """更新价格历史"""
223 symbol = market_data.symbol
224 if symbol not in self.price_history:
225 self.price_history[symbol] = []
227 # 检查是否已经存在这个时间点的数据(避免重复添加)
228 # 由于我们使用分钟级数据,简单检查最后一个价格是否相同
229 if (
230 self.price_history[symbol]
231 and self.price_history[symbol][-1] == market_data.close
232 ):
233 # 价格相同,可能是重复数据,不添加
234 return
236 self.price_history[symbol].append(market_data.close)
238 # 保持历史数据长度
239 max_history = self.slow_period * 3
240 if len(self.price_history[symbol]) > max_history:
241 self.price_history[symbol] = self.price_history[symbol][-max_history:]
243 def _calculate_ema(self, prices: List[Decimal], period: int) -> List[Decimal]:
244 """计算指数移动平均线"""
245 if len(prices) < period:
246 return []
248 ema_values = []
249 multiplier = Decimal("2") / Decimal(str(period + 1))
251 # 第一个EMA值使用SMA
252 sma = sum(prices[:period]) / Decimal(str(period))
253 ema_values.append(sma)
255 # 计算后续EMA值
256 for i in range(period, len(prices)):
257 ema = (prices[i] * multiplier) + (
258 ema_values[-1] * (Decimal("1") - multiplier)
259 )
260 ema_values.append(ema)
262 return ema_values
264 def _calculate_macd(self, symbol: str) -> Optional[Dict[str, Decimal]]:
265 """计算MACD指标"""
266 if symbol not in self.price_history:
267 return None
269 prices = self.price_history[symbol]
270 if len(prices) < self.slow_period + self.signal_period:
271 return None
273 # 计算EMA
274 ema_fast_values = self._calculate_ema(prices, self.fast_period)
275 ema_slow_values = self._calculate_ema(prices, self.slow_period)
277 # EMA值数量检查:EMA值数量 = len(prices) - period + 1
278 # 我们需要至少signal_period个EMA值来计算信号线
279 if (
280 len(ema_fast_values) < self.signal_period
281 or len(ema_slow_values) < self.signal_period
282 ):
283 self.log_message(
284 f"❌ EMA值不足: 快速EMA={len(ema_fast_values)}条, 慢速EMA={len(ema_slow_values)}条, 需要={self.signal_period}条"
285 )
286 return None
288 # 计算MACD线
289 macd_line_values = []
290 min_length = min(len(ema_fast_values), len(ema_slow_values))
292 for i in range(min_length):
293 macd = ema_fast_values[i] - ema_slow_values[i]
294 macd_line_values.append(macd)
296 if len(macd_line_values) < self.signal_period:
297 self.log_message(
298 f"❌ MACD线值不足: {len(macd_line_values)}条, 需要={self.signal_period}条"
299 )
300 return None
302 # 计算信号线
303 signal_line_values = self._calculate_ema(macd_line_values, self.signal_period)
305 if not signal_line_values:
306 return None
308 # 计算柱状图
309 histogram_values = []
310 for i in range(len(signal_line_values)):
311 hist = macd_line_values[i + self.signal_period - 1] - signal_line_values[i]
312 histogram_values.append(hist)
314 # 更新历史数据
315 self.ema_fast[symbol] = ema_fast_values
316 self.ema_slow[symbol] = ema_slow_values
317 self.macd_line[symbol] = macd_line_values
318 self.signal_line[symbol] = signal_line_values
319 self.histogram[symbol] = histogram_values
321 # 返回最新值
322 return {
323 "macd": macd_line_values[-1],
324 "signal": signal_line_values[-1],
325 "histogram": histogram_values[-1] if histogram_values else Decimal("0"),
326 }
328 def _generate_signal(self, symbol: str, macd_data: Dict[str, Decimal]) -> str:
329 """生成交易信号"""
330 last_signal = self.last_signals.get(symbol, "hold")
332 macd = macd_data["macd"]
333 signal = macd_data["signal"]
334 histogram = macd_data["histogram"]
336 # MACD金叉:MACD线上穿信号线
337 if macd > signal and histogram > 0 and last_signal != "buy":
338 self.last_signals[symbol] = "buy"
339 return "buy"
340 # MACD死叉:MACD线下穿信号线
341 elif macd < signal and histogram < 0 and last_signal != "sell":
342 self.last_signals[symbol] = "sell"
343 return "sell"
344 else:
345 return "hold"
347 def _execute_signal(self, symbol: str, signal: str, current_price: Decimal):
348 """执行交易信号"""
349 self.log_message(f"🔧 开始执行交易信号: {symbol} {signal} @ {current_price}")
351 if not self.trading_engine:
352 self.log_message(f"❌ 交易引擎未设置,无法执行信号: {symbol} {signal}")
353 return
355 portfolio = self.get_portfolio()
356 self.log_message(
357 f"🔧 获取投资组合信息: 总价值=${portfolio.total_value}, 现金=${portfolio.cash}"
358 )
360 if signal == "buy":
361 # 买入逻辑
362 # 计算买入数量
363 position_size = self.calculate_position_size(
364 symbol, current_price, self.position_size
365 )
366 self.log_message(f"🔧 计算买入数量: {position_size}")
368 # 使用实际买入数量检查风险限制
369 if position_size > 0:
370 risk_check = self.check_risk_limits(
371 symbol, position_size, current_price
372 )
373 if not risk_check["valid"]:
374 error_details = ", ".join(risk_check["errors"])
375 self.log_message(
376 f"❌ MACD买入信号被风险控制阻止: {error_details}", "warn"
377 )
378 return
380 # 创建订单
381 order = Order(
382 symbol=symbol,
383 side=OrderSide.BUY,
384 order_type=OrderType.MARKET,
385 quantity=position_size,
386 price=None,
387 session_id=self.context.trading_session_id,
388 )
390 success = self.submit_order(order)
391 if success:
392 self.log_message(
393 f"🟢 MACD买入信号执行: {symbol} {position_size} @ {current_price}",
394 "info",
395 )
396 else:
397 self.log_message(f"❌ MACD买入订单失败: {symbol}", "error")
399 elif signal == "sell":
400 # 卖出逻辑
401 # 查找当前持仓
402 current_position = None
403 for position in portfolio.positions:
404 if position.symbol == symbol:
405 current_position = position
406 break
408 self.log_message(
409 f"🔧 查找当前持仓: {symbol} = {current_position.quantity if current_position else 0}"
410 )
412 if current_position and current_position.quantity > 0:
413 # 卖出全部持仓
414 order = Order(
415 symbol=symbol,
416 side=OrderSide.SELL,
417 order_type=OrderType.MARKET,
418 quantity=current_position.quantity,
419 price=None,
420 session_id=self.context.trading_session_id,
421 )
423 success = self.submit_order(order)
424 if success:
425 self.log_message(
426 f"🔴 MACD卖出信号执行: {symbol} {current_position.quantity} @ {current_price}",
427 "info",
428 )
429 else:
430 self.log_message(f"❌ MACD卖出订单失败: {symbol}", "error")
432 def _check_stop_loss(self):
433 """检查止损"""
434 portfolio = self.get_portfolio()
435 risk_limits = self.get_risk_limits()
437 for position in portfolio.positions:
438 if position.unrealized_pnl < 0:
439 # 计算亏损比例
440 loss_ratio = abs(position.unrealized_pnl) / (
441 position.avg_price * position.quantity
442 )
444 if loss_ratio >= risk_limits.stop_loss_ratio:
445 # 触发止损
446 order = Order(
447 symbol=position.symbol,
448 side=OrderSide.SELL,
449 order_type=OrderType.MARKET,
450 quantity=position.quantity,
451 price=None,
452 session_id=self.context.trading_session_id,
453 )
455 success = self.submit_order(order)
456 if success:
457 self.log_message(
458 f"🛑 MACD止损触发: {position.symbol} {position.quantity}",
459 "warn",
460 )
462 def get_macd_data(self, symbol: str) -> Optional[Dict[str, Decimal]]:
463 """获取当前MACD数据"""
464 if (
465 symbol in self.macd_line
466 and self.macd_line[symbol]
467 and symbol in self.signal_line
468 and self.signal_line[symbol]
469 and symbol in self.histogram
470 and self.histogram[symbol]
471 ):
473 return {
474 "macd": self.macd_line[symbol][-1],
475 "signal": self.signal_line[symbol][-1],
476 "histogram": self.histogram[symbol][-1],
477 }
478 return None
480 def get_macd_history(self, symbol: str) -> Dict[str, List[Decimal]]:
481 """获取MACD历史数据"""
482 return {
483 "macd": self.macd_line.get(symbol, []),
484 "signal": self.signal_line.get(symbol, []),
485 "histogram": self.histogram.get(symbol, []),
486 }