Coverage for core/trading/strategies/rsi_strategy.py: 18.52%

135 statements  

« prev     ^ index     » next       coverage.py v7.10.7, created at 2025-10-13 18:58 +0000

1""" 

2RSI策略 

3""" 

4 

5from datetime import datetime 

6from decimal import Decimal 

7from typing import Any, Dict, List, Optional 

8 

9from core.models.trading import (MarketData, Order, OrderSide, OrderType, 

10 StrategyContext) 

11 

12from .base_strategy import BaseStrategy 

13 

14 

class RSIStrategy(BaseStrategy):
    """RSI (Relative Strength Index) strategy.

    Buys a symbol when its RSI falls to or below ``oversold`` and sells the
    whole position when its RSI rises to or above ``overbought``. Every
    market-data update also triggers a stop-loss sweep over the portfolio.

    Config keys (defaults in parentheses):
        period (14): number of price changes in the RSI window.
        overbought (70): RSI level at/above which a sell signal fires.
        oversold (30): RSI level at/below which a buy signal fires.
        position_size (0.04): sizing parameter forwarded to
            ``calculate_position_size`` when a buy is placed.
    """

    # Number of most recent RSI readings retained per symbol.
    _MAX_RSI_HISTORY = 100

    def __init__(self, name: str, config: Dict[str, Any]):
        super().__init__(name, config)

        # RSI parameters
        self.period = config.get("period", 14)  # RSI window length
        self.overbought = config.get("overbought", 70)  # overbought threshold
        self.oversold = config.get("oversold", 30)  # oversold threshold
        self.position_size = config.get("position_size", 0.04)  # sizing parameter

        # Per-symbol state, created lazily the first time a symbol is seen.
        self.price_history: Dict[str, List[Decimal]] = {}
        self.rsi_values: Dict[str, List[Decimal]] = {}
        self.last_signals: Dict[str, str] = {}  # "buy", "sell" or "hold"

    def initialize(self, context: StrategyContext):
        """Bind the strategy to its runtime context and log its parameters."""
        self.context = context
        self.trading_engine = context.trading_engine
        self.quote_adapter = context.quote_adapter

        self.log_message(f"✅ RSI策略已初始化: {self.name}")
        self.log_message(f" RSI周期: {self.period}")
        self.log_message(f" 超买阈值: {self.overbought}")
        self.log_message(f" 超卖阈值: {self.oversold}")
        self.log_message(f" 仓位大小: {self.position_size}")

    def on_market_data(self, market_data: MarketData):
        """Handle one market-data tick.

        Records the close, recomputes the symbol's RSI once enough history is
        available, acts on any resulting buy/sell signal, and finally runs the
        stop-loss sweep. The stop-loss sweep runs on *every* tick, including
        while the symbol is still warming up.
        """
        symbol = market_data.symbol

        # Lazily initialise per-symbol state.
        if symbol not in self.price_history:
            self.price_history[symbol] = []
            self.rsi_values[symbol] = []
            self.last_signals[symbol] = "hold"
            self.log_message(f"🆕 初始化股票数据: {symbol}")

        self._update_price_history(market_data)

        # _calculate_rsi returns None until period + 1 prices are available.
        rsi = self._calculate_rsi(symbol)
        if rsi is not None:
            self._record_rsi(symbol, rsi)
            signal = self._generate_signal(symbol, rsi)
            if signal != "hold":
                self._execute_signal(symbol, signal, market_data.close)

        # The strategy is fully passive: stop-loss is checked on every update.
        # NOTE(review): if submit_order does not update the portfolio
        # synchronously, a position just sold by a signal above may be seen
        # again here and sold twice — TODO confirm engine semantics.
        self._check_stop_loss()

    def on_trade_update(self, trade):
        """Record a trade fill and update the trade statistics."""
        self.log_message(
            f"📊 RSI策略交易更新: {trade.symbol} {trade.side} {trade.quantity} @ {trade.price}"
        )
        self.trades.append(trade)
        self.total_trades += 1

        # NOTE: simplification kept from the original — every sell is counted
        # as a winning trade; realised P&L is not inspected here.
        if trade.side == OrderSide.SELL:
            self.winning_trades += 1

    # The strategy is fully passive and needs no timer callback;
    # stop-loss checks happen in on_market_data.

    def _update_price_history(self, market_data: MarketData):
        """Append the latest close and keep at most ``period * 3`` prices."""
        symbol = market_data.symbol
        history = self.price_history.setdefault(symbol, [])
        history.append(market_data.close)

        max_history = self.period * 3
        if len(history) > max_history:
            self.price_history[symbol] = history[-max_history:]

    def _record_rsi(self, symbol: str, rsi: Decimal):
        """Append an RSI reading, keeping only the most recent readings."""
        readings = self.rsi_values.setdefault(symbol, [])
        readings.append(rsi)
        if len(readings) > self._MAX_RSI_HISTORY:
            self.rsi_values[symbol] = readings[-self._MAX_RSI_HISTORY :]

    def _calculate_rsi(self, symbol: str) -> Optional[Decimal]:
        """Compute the RSI of ``symbol`` over its last ``period`` price changes.

        Uses simple averaging (Cutler's RSI). Average gain and average loss are
        both taken over the whole window, with steps that did not rise
        (respectively fall) contributing zero.

        Returns:
            A value in [0, 100]. Returns ``Decimal("100")`` when the window had
            gains but no losses, and ``Decimal("50")`` when prices did not move
            at all. Returns ``None`` when fewer than ``period + 1`` prices are
            recorded or the window is empty.
        """
        prices = self.price_history.get(symbol)
        if not prices or len(prices) < self.period + 1:
            return None

        window = prices[-(self.period + 1) :]
        changes = [curr - prev for prev, curr in zip(window, window[1:])]
        if not changes:
            return None

        # Divide by the window length, not by the number of up/down steps;
        # averaging only the non-zero moves ignores how often prices rose
        # versus fell (e.g. [+1, -1, -1, -1] would give 50 instead of 25).
        window_len = Decimal(len(changes))
        avg_gain = sum((c for c in changes if c > 0), Decimal("0")) / window_len
        avg_loss = sum((-c for c in changes if c < 0), Decimal("0")) / window_len

        if avg_loss == 0:
            # No losses: fully overbought if there were gains, neutral if flat.
            return Decimal("100") if avg_gain > 0 else Decimal("50")

        rs = avg_gain / avg_loss
        return Decimal("100") - (Decimal("100") / (Decimal("1") + rs))

    def _generate_signal(self, symbol: str, rsi: Decimal) -> str:
        """Map an RSI reading to "buy", "sell" or "hold".

        A signal is emitted only when it differs from the last one recorded
        for the symbol, so repeated oversold/overbought readings do not
        re-fire. NOTE: the signal is recorded here, before the order is
        attempted, so a blocked or failed order is not retried.
        """
        last_signal = self.last_signals.get(symbol, "hold")

        if rsi <= self.oversold and last_signal != "buy":
            # Oversold: buy signal
            self.last_signals[symbol] = "buy"
            return "buy"
        if rsi >= self.overbought and last_signal != "sell":
            # Overbought: sell signal
            self.last_signals[symbol] = "sell"
            return "sell"
        return "hold"

    def _execute_signal(self, symbol: str, signal: str, current_price: Decimal):
        """Dispatch a "buy"/"sell" signal; no-op without a trading engine."""
        if not getattr(self, "trading_engine", None):
            return

        if signal == "buy":
            self._execute_buy(symbol, current_price)
        elif signal == "sell":
            self._execute_sell(symbol, current_price)

    def _execute_buy(self, symbol: str, current_price: Decimal):
        """Size a market buy, run it through the risk checks, and submit it."""
        position_size = self.calculate_position_size(
            symbol, current_price, self.position_size
        )
        if position_size <= 0:
            return

        # Risk-check the quantity actually being ordered; a fixed quantity of 1
        # under-reports the order to any size-dependent limit.
        risk_check = self.check_risk_limits(
            symbol, Decimal(str(position_size)), current_price
        )
        if not risk_check["valid"]:
            error_details = ", ".join(risk_check["errors"])
            self.log_message(
                f"❌ RSI买入信号被风险控制阻止: {error_details}", "warn"
            )
            return

        order = self._market_order(symbol, OrderSide.BUY, position_size)
        if self.submit_order(order):
            self.log_message(
                f"🟢 RSI买入信号执行: {symbol} {position_size} @ {current_price}",
                "info",
            )
        else:
            self.log_message(f"❌ RSI买入订单失败: {symbol}", "error")

    def _execute_sell(self, symbol: str, current_price: Decimal):
        """Submit a market sell for the entire current position in ``symbol``."""
        portfolio = self.get_portfolio()
        current_position = next(
            (p for p in portfolio.positions if p.symbol == symbol), None
        )
        if current_position is None or current_position.quantity <= 0:
            return

        order = self._market_order(symbol, OrderSide.SELL, current_position.quantity)
        if self.submit_order(order):
            self.log_message(
                f"🔴 RSI卖出信号执行: {symbol} {current_position.quantity} @ {current_price}",
                "info",
            )
        else:
            self.log_message(f"❌ RSI卖出订单失败: {symbol}", "error")

    def _check_stop_loss(self):
        """Sell any long position whose unrealised loss reaches the stop-loss ratio.

        The loss ratio is ``|unrealized_pnl| / (avg_price * quantity)``.
        Positions with non-positive quantity or cost basis are skipped, since
        the ratio is undefined for them.
        """
        portfolio = self.get_portfolio()
        risk_limits = self.get_risk_limits()

        for position in portfolio.positions:
            if position.quantity <= 0 or position.unrealized_pnl >= 0:
                continue

            cost_basis = position.avg_price * position.quantity
            if cost_basis <= 0:
                continue  # ratio undefined; avoid dividing by zero

            loss_ratio = abs(position.unrealized_pnl) / cost_basis
            if loss_ratio < risk_limits.stop_loss_ratio:
                continue

            # Stop-loss triggered: close the whole position at market.
            order = self._market_order(
                position.symbol, OrderSide.SELL, position.quantity
            )
            if self.submit_order(order):
                self.log_message(
                    f"🛑 RSI止损触发: {position.symbol} {position.quantity}",
                    "warn",
                )
            else:
                self.log_message(f"❌ RSI止损订单失败: {position.symbol}", "error")

    def _market_order(self, symbol: str, side: OrderSide, quantity) -> Order:
        """Build a market order tagged with the current trading session id."""
        return Order(
            symbol=symbol,
            side=side,
            order_type=OrderType.MARKET,
            quantity=quantity,
            price=None,
            session_id=self.context.trading_session_id,
        )

    def get_rsi_value(self, symbol: str) -> Optional[Decimal]:
        """Return the most recent RSI reading for ``symbol``, or None."""
        readings = self.rsi_values.get(symbol)
        return readings[-1] if readings else None

    def get_rsi_history(self, symbol: str) -> List[Decimal]:
        """Return a copy of the retained RSI readings for ``symbol`` (oldest first)."""
        return list(self.rsi_values.get(symbol, []))