Coverage for core/trading/strategies/ma_crossover_strategy.py: 26.67%
105 statements
« prev ^ index » next coverage.py v7.10.7, created at 2025-10-13 18:58 +0000
1"""
2移动平均线交叉策略
3"""
5from decimal import Decimal
6from typing import Any, Dict, List, Optional
8import numpy as np
9import pandas as pd
11from core.models.trading import (MarketData, Order, OrderSide, OrderType,
12 StrategyContext)
14from .base_strategy import BaseStrategy
class MovingAverageCrossoverStrategy(BaseStrategy):
    """Moving-average crossover strategy.

    Keeps a bounded history of close prices per symbol and emits a market
    order when the short-period simple moving average crosses the
    long-period one:

    * golden cross (short MA goes from <= long MA to > long MA) -> "buy"
    * death cross  (short MA goes from >= long MA to < long MA) -> "sell"

    Config keys read in ``__init__``:
        short_period: window of the short MA (default 5).
        long_period: window of the long MA (default 20).
        position_size: multiplied by 1000 and truncated to an int to obtain
            the buy quantity (default 0.04 -> 40 units).
    """

    def __init__(self, name: str, config: Dict[str, Any]):
        super().__init__(name, config)
        self.description = "基于短期和长期移动平均线交叉的交易策略"

        # Strategy parameters.
        self.short_period = config.get("short_period", 5)
        self.long_period = config.get("long_period", 20)
        self.position_size = config.get("position_size", 0.04)

        # Per-symbol internal state; entries are created lazily the first
        # time a symbol is seen in on_market_data.
        self.price_history: Dict[str, List[float]] = {}
        self.positions: Dict[str, Decimal] = {}
        self.last_signal: Dict[str, str] = {}

        self.log_message(f"✅ {self.name} 初始化完成")
        self.log_message(f" 短期均线周期: {self.short_period}")
        self.log_message(f" 长期均线周期: {self.long_period}")
        self.log_message(f" 仓位大小: {self.position_size}")
        self.log_message(f" 策略描述: {self.description}")

    def on_market_data(self, market_data: MarketData) -> List[Order]:
        """Ingest one market-data update and return the orders it triggers.

        Appends ``market_data.close`` to the symbol's price history, trims the
        history to a bounded window, and — once at least ``long_period``
        prices are available — asks ``_calculate_signal`` for a crossover.
        An order is only created when the signal differs from the last
        signal that produced an order for this symbol.

        Returns:
            A list with zero or one ``Order``.
        """
        symbol = market_data.symbol

        # Lazily create state for a symbol seen for the first time.
        if symbol not in self.price_history:
            self.price_history[symbol] = []
            self.positions[symbol] = Decimal("0")
            self.last_signal[symbol] = "hold"
            self.log_message(f"🆕 初始化股票数据: {symbol}")

        orders: List[Order] = []

        history = self.price_history[symbol]
        history.append(float(market_data.close))

        # Keep enough history for both moving averages plus a small margin;
        # trim in place so the dict keeps referring to the same list.
        max_history = max(self.short_period, self.long_period) + 10
        if len(history) > max_history:
            del history[:-max_history]

        if len(history) >= self.long_period:
            signal = self._calculate_signal(symbol, market_data.close)

            # Only act on a change of signal to avoid repeated orders.
            if signal and signal != self.last_signal.get(symbol):
                order = self._create_order(symbol, signal, market_data.close)
                if order:
                    orders.append(order)
                    self.last_signal[symbol] = signal

        return orders

    def _calculate_signal(self, symbol: str, current_price: float) -> Optional[str]:
        """Return "buy" / "sell" on an MA crossover at the latest bar, else None.

        Compares the short/long simple moving averages of the latest bar with
        those of the previous bar. ``current_price`` is not used; it is kept
        for interface compatibility.

        Any exception is logged and turned into ``None`` (best effort).
        """
        try:
            prices = self.price_history[symbol]

            # Both the current and the previous-bar averages must be taken
            # over full windows, which needs one bar more than the longest
            # window. (Previously this required long_period + 2 bars, so the
            # first possible crossover was skipped.)
            if len(prices) < max(self.short_period, self.long_period) + 1:
                return None

            short_ma = np.mean(prices[-self.short_period :])
            long_ma = np.mean(prices[-self.long_period :])
            prev_short_ma = np.mean(prices[-self.short_period - 1 : -1])
            prev_long_ma = np.mean(prices[-self.long_period - 1 : -1])

            # Golden cross: short MA crosses above the long MA.
            if prev_short_ma <= prev_long_ma and short_ma > long_ma:
                return "buy"

            # Death cross: short MA crosses below the long MA.
            if prev_short_ma >= prev_long_ma and short_ma < long_ma:
                return "sell"

            return None

        except Exception as e:
            self.log_message(f"❌ 计算信号失败 {symbol}: {e}", "error")
            return None

    def _create_order(
        self, symbol: str, signal: str, current_price: float
    ) -> Optional[Order]:
        """Build a market order for ``signal`` or return None if not applicable.

        * "buy" when the tracked position is <= 0: buys
          ``int(position_size * 1000)`` units.
        * "sell" when the tracked position is > 0: sells the whole position.

        The tracked position is updated immediately when the order is built.
        ``current_price`` is unused (orders are MARKET with ``price=None``).

        NOTE(review): ``on_trade_update`` also adjusts ``self.positions`` on
        fills. If the framework delivers fill updates for these orders, the
        position is counted twice (buy -> 2x quantity, sell -> negative).
        Confirm which side should own position bookkeeping.
        """
        try:
            current_position = self.positions.get(symbol, Decimal("0"))

            if signal == "buy" and current_position <= 0:
                quantity = Decimal(str(int(self.position_size * 1000)))  # simplified sizing
                if quantity <= 0:
                    # position_size too small to yield a whole unit; never
                    # emit a zero/negative-quantity order.
                    return None

                order = Order(
                    symbol=symbol,
                    side=OrderSide.BUY,
                    order_type=OrderType.MARKET,
                    quantity=quantity,
                    price=None,
                    session_id=self.session_id,
                )

                self.positions[symbol] = current_position + quantity

                self.log_message(f"📈 买入信号: {symbol} 数量: {quantity}", "info")
                return order

            if signal == "sell" and current_position > 0:
                quantity = current_position

                order = Order(
                    symbol=symbol,
                    side=OrderSide.SELL,
                    order_type=OrderType.MARKET,
                    quantity=quantity,
                    price=None,
                    session_id=self.session_id,
                )

                self.positions[symbol] = Decimal("0")

                self.log_message(f"📉 卖出信号: {symbol} 数量: {quantity}", "info")
                return order

            return None

        except Exception as e:
            self.log_message(f"❌ 创建订单失败 {symbol}: {e}", "error")
            return None

    def on_trade_update(self, trade_data: Dict[str, Any]) -> None:
        """Apply a trade update dict to the tracked positions.

        Reads the keys "symbol", "side", "quantity" and "status". Only
        updates with status "filled" for a symbol in ``self.symbols`` are
        applied: "buy" adds the quantity, "sell" subtracts it. Errors are
        logged, not raised.

        NOTE(review): ``self.symbols`` comes from the base class. Symbols
        that are only tracked lazily in ``on_market_data`` may not be listed
        there, in which case their fills are ignored — confirm.
        """
        try:
            symbol = trade_data.get("symbol")
            side = trade_data.get("side")
            # Treat a missing or None quantity as 0 instead of failing on
            # Decimal("None").
            quantity = Decimal(str(trade_data.get("quantity") or 0))
            status = trade_data.get("status")

            if status == "filled" and symbol in self.symbols:
                if side == "buy":
                    self.positions[symbol] = (
                        self.positions.get(symbol, Decimal("0")) + quantity
                    )
                elif side == "sell":
                    self.positions[symbol] = (
                        self.positions.get(symbol, Decimal("0")) - quantity
                    )

                self.log_message(
                    f"📊 交易更新: {symbol} {side} {quantity} 持仓: {self.positions.get(symbol, Decimal('0'))}",
                    "info",
                )

        except Exception as e:
            self.log_message(f"❌ 处理交易更新失败: {e}", "error")

    # Note: the strategy is fully passive and needs no timer callback;
    # all logic runs in on_market_data.

    def get_strategy_info(self) -> Dict[str, Any]:
        """Return a summary of parameters, positions and history lengths."""
        return {
            "name": self.name,
            "description": self.description,
            "parameters": {
                "short_period": self.short_period,
                "long_period": self.long_period,
                "position_size": self.position_size,
                "symbols": self.symbols,
            },
            "positions": {symbol: float(pos) for symbol, pos in self.positions.items()},
            "price_history_length": {
                symbol: len(prices) for symbol, prices in self.price_history.items()
            },
        }

    def _clear_state(self) -> None:
        """Drop all per-symbol price history, positions and last signals."""
        self.price_history.clear()
        self.positions.clear()
        self.last_signal.clear()

    def initialize(self, context: StrategyContext) -> None:
        """Store ``context`` and start from empty per-symbol state."""
        self.context = context
        self._clear_state()
        self.log_message(f"🔄 {self.name} 已初始化")

    def reset(self) -> None:
        """Clear all per-symbol state."""
        self._clear_state()
        self.log_message(f"🔄 {self.name} 状态已重置")