Coverage for core/trading/strategies/ma_crossover_strategy.py: 26.67%

105 statements  

« prev     ^ index     » next       coverage.py v7.10.7, created at 2025-10-13 18:58 +0000

1""" 

2移动平均线交叉策略 

3""" 

4 

5from decimal import Decimal 

6from typing import Any, Dict, List, Optional 

7 

8import numpy as np 

9import pandas as pd 

10 

11from core.models.trading import (MarketData, Order, OrderSide, OrderType, 

12 StrategyContext) 

13 

14from .base_strategy import BaseStrategy 

15 

16 

class MovingAverageCrossoverStrategy(BaseStrategy):
    """Moving-average crossover strategy.

    Emits a market BUY when the short-window simple moving average of close
    prices crosses above the long-window average ("golden cross"), and a
    market SELL of the whole tracked position when it crosses below
    ("death cross").

    Config keys read here (all optional):
        short_period: window of the short average (default 5).
        long_period: window of the long average (default 20).
        position_size: buy sizing factor; the buy quantity is
            ``int(position_size * 1000)`` (default 0.04, i.e. 40 units).
    """

    # Multiplier turning ``position_size`` into an integral buy quantity.
    _QUANTITY_SCALE = Decimal(1000)

    def __init__(self, name: str, config: Dict[str, Any]):
        super().__init__(name, config)
        self.description = "基于短期和长期移动平均线交叉的交易策略"

        # Strategy parameters
        self.short_period = config.get("short_period", 5)
        self.long_period = config.get("long_period", 20)
        self.position_size = config.get("position_size", 0.04)

        # Per-symbol state, created lazily when a symbol is first seen.
        self.price_history: Dict[str, List[float]] = {}
        self.positions: Dict[str, Decimal] = {}
        self.last_signal: Dict[str, str] = {}
        # Signed quantity already booked into ``positions`` by _create_order
        # but not yet confirmed by a fill in on_trade_update. A fill for our
        # own order must consume this instead of being added a second time.
        self._unconfirmed: Dict[str, Decimal] = {}

        self.log_message(f"{self.name} 初始化完成")
        self.log_message(f" 短期均线周期: {self.short_period}")
        self.log_message(f" 长期均线周期: {self.long_period}")
        self.log_message(f" 仓位大小: {self.position_size}")
        self.log_message(f" 策略描述: {self.description}")

    def on_market_data(self, market_data: MarketData) -> List[Order]:
        """Process one market-data bar and return any orders to submit.

        Appends the bar's close to the symbol's bounded price history,
        evaluates the crossover signal, and creates at most one order when
        the signal differs from the last signal that produced an order.
        """
        symbol = market_data.symbol

        # Lazily initialise state for a newly seen symbol. setdefault keeps
        # a position that a trade update may already have recorded.
        if symbol not in self.price_history:
            self.price_history[symbol] = []
            self.positions.setdefault(symbol, Decimal("0"))
            self.last_signal[symbol] = "hold"
            self.log_message(f"🆕 初始化股票数据: {symbol}")

        history = self.price_history[symbol]
        history.append(float(market_data.close))

        # Keep only as much history as the averages need (plus a margin).
        max_history = max(self.short_period, self.long_period) + 10
        if len(history) > max_history:
            del history[:-max_history]

        orders: List[Order] = []
        if len(history) >= self.long_period:
            signal = self._calculate_signal(symbol, market_data.close)
            if signal and signal != self.last_signal.get(symbol):
                order = self._create_order(symbol, signal, market_data.close)
                if order:
                    orders.append(order)
                    # Only remember the signal once it actually produced an order.
                    self.last_signal[symbol] = signal
        return orders

    def _calculate_signal(self, symbol: str, current_price: float) -> Optional[str]:
        """Detect a moving-average crossover on the most recent bar.

        Args:
            symbol: symbol whose ``price_history`` is evaluated.
            current_price: unused; the latest close is already the last
                element of ``price_history``. Kept for call compatibility.

        Returns:
            "buy" on a golden cross, "sell" on a death cross, otherwise None.
            None is also returned when there is not enough history yet, or
            after logging an error.
        """
        try:
            prices = self.price_history[symbol]

            # The current and previous-bar averages together need one more
            # price than the longer window.
            if len(prices) <= max(self.short_period, self.long_period):
                return None

            short_ma = np.mean(prices[-self.short_period :])
            long_ma = np.mean(prices[-self.long_period :])
            prev_short_ma = np.mean(prices[-self.short_period - 1 : -1])
            prev_long_ma = np.mean(prices[-self.long_period - 1 : -1])

            # Golden cross: short average crosses above the long average.
            if prev_short_ma <= prev_long_ma and short_ma > long_ma:
                return "buy"
            # Death cross: short average crosses below the long average.
            if prev_short_ma >= prev_long_ma and short_ma < long_ma:
                return "sell"
            return None

        except Exception as e:
            self.log_message(f"❌ 计算信号失败 {symbol}: {e}", "error")
            return None

    def _buy_quantity(self) -> Decimal:
        """Return the integral buy quantity derived from ``position_size``."""
        # Go through str -> Decimal so values like 0.57 give 570, not the
        # 569 that truncating the float product 0.57 * 1000 would yield.
        scaled = Decimal(str(self.position_size)) * self._QUANTITY_SCALE
        return Decimal(int(scaled))  # int() truncates toward zero

    def _create_order(
        self, symbol: str, signal: str, current_price: float
    ) -> Optional[Order]:
        """Build a market order for ``signal`` and book it into ``positions``.

        A "buy" opens a position only when none is held; a "sell" closes the
        whole held position. The booked quantity is also recorded in
        ``_unconfirmed`` so the matching fill is not counted twice.

        Returns:
            The order, or None when no order applies, the quantity is not
            positive, or creation fails (the error is logged).
        """
        try:
            current_position = self.positions.get(symbol, Decimal("0"))
            pending = self._unconfirmed.get(symbol, Decimal("0"))

            if signal == "buy" and current_position <= 0:
                quantity = self._buy_quantity()
                if quantity <= 0:
                    self.log_message(f"❌ 订单数量无效 {symbol}: {quantity}", "error")
                    return None

                order = Order(
                    symbol=symbol,
                    side=OrderSide.BUY,
                    order_type=OrderType.MARKET,
                    quantity=quantity,
                    price=None,
                    session_id=self.session_id,
                )

                # Book the position optimistically; the fill will confirm it.
                self.positions[symbol] = current_position + quantity
                self._unconfirmed[symbol] = pending + quantity

                self.log_message(f"📈 买入信号: {symbol} 数量: {quantity}", "info")
                return order

            if signal == "sell" and current_position > 0:
                quantity = current_position

                order = Order(
                    symbol=symbol,
                    side=OrderSide.SELL,
                    order_type=OrderType.MARKET,
                    quantity=quantity,
                    price=None,
                    session_id=self.session_id,
                )

                # Book the close optimistically; the fill will confirm it.
                self.positions[symbol] = Decimal("0")
                self._unconfirmed[symbol] = pending - quantity

                self.log_message(f"📉 卖出信号: {symbol} 数量: {quantity}", "info")
                return order

            return None

        except Exception as e:
            self.log_message(f"❌ 创建订单失败 {symbol}: {e}", "error")
            return None

    def _apply_fill(self, symbol: str, signed_qty: Decimal) -> None:
        """Apply a confirmed fill (buys positive, sells negative) to ``positions``.

        Any part of the fill that matches a quantity already booked by
        _create_order in the same direction is consumed from ``_unconfirmed``
        instead of being added again. Only the remainder (for example, trades
        not originating from this strategy) changes the position.
        """
        pending = self._unconfirmed.get(symbol, Decimal("0"))
        absorbed = Decimal("0")
        if pending * signed_qty > 0:  # same direction as the booked order
            absorbed = signed_qty if abs(signed_qty) <= abs(pending) else pending
        self._unconfirmed[symbol] = pending - absorbed
        self.positions[symbol] = (
            self.positions.get(symbol, Decimal("0")) + signed_qty - absorbed
        )

    def on_trade_update(self, trade_data: Dict[str, Any]) -> None:
        """Reconcile positions with a trade update.

        Reads the "symbol", "side" ("buy"/"sell"), "quantity" and "status"
        keys. Only updates with status "filled" for a symbol in
        ``self.symbols`` are applied. Errors are logged, not raised.
        """
        try:
            symbol = trade_data.get("symbol")
            side = trade_data.get("side")
            quantity = Decimal(str(trade_data.get("quantity", 0)))
            status = trade_data.get("status")

            if status == "filled" and symbol in self.symbols:
                if side == "buy":
                    self._apply_fill(symbol, quantity)
                elif side == "sell":
                    self._apply_fill(symbol, -quantity)

                held = self.positions.get(symbol, Decimal("0"))
                self.log_message(
                    f"📊 交易更新: {symbol} {side} {quantity} 持仓: {held}",
                    "info",
                )

        except Exception as e:
            self.log_message(f"❌ 处理交易更新失败: {e}", "error")

    # NOTE: the strategy is fully passive and needs no timer callback;
    # all decision logic runs inside on_market_data.

    def get_strategy_info(self) -> Dict[str, Any]:
        """Return a serialisable snapshot of parameters and per-symbol state."""
        return {
            "name": self.name,
            "description": self.description,
            "parameters": {
                "short_period": self.short_period,
                "long_period": self.long_period,
                "position_size": self.position_size,
                "symbols": self.symbols,
            },
            "positions": {symbol: float(pos) for symbol, pos in self.positions.items()},
            "price_history_length": {
                symbol: len(prices) for symbol, prices in self.price_history.items()
            },
        }

    def _clear_state(self) -> None:
        """Drop all per-symbol history, positions, signals and pending fills."""
        self.price_history.clear()
        self.positions.clear()
        self.last_signal.clear()
        self._unconfirmed.clear()

    def initialize(self, context: StrategyContext) -> None:
        """Store ``context`` and start from a clean per-symbol state."""
        self.context = context
        self._clear_state()
        self.log_message(f"🔄 {self.name} 已初始化")

    def reset(self) -> None:
        """Reset all per-symbol strategy state."""
        self._clear_state()
        self.log_message(f"🔄 {self.name} 状态已重置")