Coverage for core/trading/strategies/rsi_strategy.py: 18.52%
135 statements
« prev ^ index » next coverage.py v7.10.7, created at 2025-10-13 18:58 +0000
« prev ^ index » next coverage.py v7.10.7, created at 2025-10-13 18:58 +0000
1"""
2RSI策略
3"""
5from datetime import datetime
6from decimal import Decimal
7from typing import Any, Dict, List, Optional
9from core.models.trading import (MarketData, Order, OrderSide, OrderType,
10 StrategyContext)
12from .base_strategy import BaseStrategy
class RSIStrategy(BaseStrategy):
    """RSI (Relative Strength Index) strategy.

    Buys when a symbol's RSI falls to or below ``oversold`` and sells the whole
    position when it rises to or above ``overbought``. Signals are tracked per
    symbol so the same signal is not emitted twice in a row. A stop-loss check
    runs on every market data update.

    Config keys (all optional):
        period: RSI lookback length in bars (default 14, must be >= 1).
        overbought: RSI level at/above which a sell signal fires (default 70).
        oversold: RSI level at/below which a buy signal fires (default 30).
        position_size: value passed to ``calculate_position_size`` when buying
            (default 0.04).
    """

    # Number of most recent RSI readings kept per symbol.
    _RSI_HISTORY_LIMIT = 100

    def __init__(self, name: str, config: Dict[str, Any]):
        """Read RSI parameters from ``config`` and set up per-symbol state.

        Raises:
            ValueError: if ``period`` is less than 1.
        """
        super().__init__(name, config)

        # RSI parameters
        self.period = config.get("period", 14)  # RSI lookback length
        self.overbought = config.get("overbought", 70)  # overbought threshold
        self.oversold = config.get("oversold", 30)  # oversold threshold
        self.position_size = config.get("position_size", 0.04)  # sizing input

        # A non-positive period makes the RSI undefined (the averages would
        # divide by zero, and price_changes[-0:] would slice the whole list).
        if self.period < 1:
            raise ValueError(f"RSI period must be >= 1, got {self.period}")

        # Per-symbol state, created lazily on the first tick of each symbol.
        self.price_history: Dict[str, List[Decimal]] = {}
        self.rsi_values: Dict[str, List[Decimal]] = {}
        self.last_signals: Dict[str, str] = {}  # "buy", "sell", "hold"

    def initialize(self, context: StrategyContext):
        """Bind the strategy to its runtime context and log its parameters."""
        self.context = context
        self.trading_engine = context.trading_engine
        self.quote_adapter = context.quote_adapter

        self.log_message(f"✅ RSI策略已初始化: {self.name}")
        self.log_message(f" RSI周期: {self.period}")
        self.log_message(f" 超买阈值: {self.overbought}")
        self.log_message(f" 超卖阈值: {self.oversold}")
        self.log_message(f" 仓位大小: {self.position_size}")

    def on_market_data(self, market_data: MarketData):
        """Handle one market data tick.

        Appends the close price to the symbol's history. Once at least
        ``period + 1`` prices are available, it computes and records the RSI
        and acts on any buy/sell signal. The stop-loss check runs on every
        tick, including warm-up ticks that do not yet produce an RSI value.
        """
        symbol = market_data.symbol

        # Lazily create per-symbol state the first time a symbol is seen.
        if symbol not in self.price_history:
            self.price_history[symbol] = []
            self.rsi_values[symbol] = []
            self.last_signals[symbol] = "hold"
            self.log_message(f"🆕 初始化股票数据: {symbol}")

        self._update_price_history(market_data)

        # None until enough history has accumulated.
        rsi = self._calculate_rsi(symbol)
        if rsi is not None:
            self._record_rsi(symbol, rsi)
            signal = self._generate_signal(symbol, rsi)
            if signal != "hold":
                self._execute_signal(symbol, signal, market_data.close)

        # The strategy is fully passive (no timer callback), so the stop-loss
        # check runs on every market data update, not only once an RSI exists.
        self._check_stop_loss()

    def on_trade_update(self, trade):
        """Record a trade reported by the engine and update trade counters.

        NOTE(review): every SELL trade is counted as a winning trade (a
        simplification); realized P&L is not inspected, so ``winning_trades``
        may overstate the real win rate.
        """
        self.log_message(
            f"📊 RSI策略交易更新: {trade.symbol} {trade.side} {trade.quantity} @ {trade.price}"
        )
        self.trades.append(trade)
        self.total_trades += 1

        # Simplification: assume every sell is profitable.
        if trade.side == OrderSide.SELL:
            self.winning_trades += 1

    # Note: the strategy is fully passive and needs no timer callback;
    # stop-loss checks are handled in on_market_data.

    def _update_price_history(self, market_data: MarketData):
        """Append the tick's close price, keeping at most ``3 * period`` prices."""
        symbol = market_data.symbol
        history = self.price_history.setdefault(symbol, [])
        history.append(market_data.close)

        max_history = self.period * 3
        if len(history) > max_history:
            self.price_history[symbol] = history[-max_history:]

    def _record_rsi(self, symbol: str, rsi: Decimal):
        """Append an RSI reading, keeping only the most recent readings."""
        history = self.rsi_values.setdefault(symbol, [])
        history.append(rsi)
        if len(history) > self._RSI_HISTORY_LIMIT:
            self.rsi_values[symbol] = history[-self._RSI_HISTORY_LIMIT :]

    def _calculate_rsi(self, symbol: str) -> Optional[Decimal]:
        """Return the simple-average RSI over the last ``period`` price changes.

        Average gain and average loss are both taken over the full ``period``
        window; bars that did not move in a direction contribute zero to it.
        Returns ``None`` until ``period + 1`` prices are available, 100 when
        there were gains but no losses, and 50 for a completely flat window.
        """
        prices = self.price_history.get(symbol)
        if not prices or len(prices) < self.period + 1:
            return None

        # Only the last period + 1 prices are needed for `period` changes.
        window = prices[-(self.period + 1) :]
        changes = [curr - prev for prev, curr in zip(window, window[1:])]

        avg_gain = sum(c for c in changes if c > 0) / self.period
        avg_loss = sum(-c for c in changes if c < 0) / self.period

        if avg_loss == 0:
            # No movement at all is neutral momentum, not "overbought".
            if avg_gain == 0:
                return Decimal("50")
            return Decimal("100")

        rs = avg_gain / avg_loss
        return Decimal("100") - (Decimal("100") / (Decimal("1") + rs))

    def _generate_signal(self, symbol: str, rsi: Decimal) -> str:
        """Map an RSI value to "buy"/"sell"/"hold" without repeating a signal."""
        last_signal = self.last_signals.get(symbol, "hold")

        if rsi <= self.oversold and last_signal != "buy":
            # RSI oversold: buy signal.
            self.last_signals[symbol] = "buy"
            return "buy"
        if rsi >= self.overbought and last_signal != "sell":
            # RSI overbought: sell signal.
            self.last_signals[symbol] = "sell"
            return "sell"
        return "hold"

    def _execute_signal(self, symbol: str, signal: str, current_price: Decimal):
        """Turn a "buy"/"sell" signal into a market order (no-op without an engine)."""
        if not self.trading_engine:
            return

        if signal == "buy":
            self._execute_buy(symbol, current_price)
        elif signal == "sell":
            self._execute_sell(symbol, current_price)

    def _execute_buy(self, symbol: str, current_price: Decimal):
        """Size a market buy, check risk limits for that size, and submit it."""
        position_size = self.calculate_position_size(
            symbol, current_price, self.position_size
        )
        if position_size <= 0:
            return

        # Risk-check the quantity actually being ordered (it was previously a
        # fixed quantity of 1, which did not reflect the order size).
        # NOTE(review): assumes check_risk_limits(symbol, quantity, price)
        # takes the order quantity as a Decimal, as the old Decimal("1") did.
        risk_check = self.check_risk_limits(
            symbol, Decimal(str(position_size)), current_price
        )
        if not risk_check["valid"]:
            error_details = ", ".join(risk_check["errors"])
            self.log_message(
                f"❌ RSI买入信号被风险控制阻止: {error_details}", "warn"
            )
            return

        order = Order(
            symbol=symbol,
            side=OrderSide.BUY,
            order_type=OrderType.MARKET,
            quantity=position_size,
            price=None,
            session_id=self.context.trading_session_id,
        )

        if self.submit_order(order):
            self.log_message(
                f"🟢 RSI买入信号执行: {symbol} {position_size} @ {current_price}",
                "info",
            )
        else:
            self.log_message(f"❌ RSI买入订单失败: {symbol}", "error")

    def _execute_sell(self, symbol: str, current_price: Decimal):
        """Submit a market sell for the entire current position in ``symbol``."""
        portfolio = self.get_portfolio()
        current_position = next(
            (p for p in portfolio.positions if p.symbol == symbol), None
        )
        if current_position is None or current_position.quantity <= 0:
            return

        order = Order(
            symbol=symbol,
            side=OrderSide.SELL,
            order_type=OrderType.MARKET,
            quantity=current_position.quantity,
            price=None,
            session_id=self.context.trading_session_id,
        )

        if self.submit_order(order):
            self.log_message(
                f"🔴 RSI卖出信号执行: {symbol} {current_position.quantity} @ {current_price}",
                "info",
            )
        else:
            self.log_message(f"❌ RSI卖出订单失败: {symbol}", "error")

    def _check_stop_loss(self):
        """Sell every losing position whose loss ratio reaches the stop-loss limit.

        The loss ratio is ``|unrealized_pnl| / (avg_price * quantity)``,
        compared against ``get_risk_limits().stop_loss_ratio``. Positions with
        a non-positive cost basis (for example, zero quantity) are skipped
        because the ratio is undefined for them.

        NOTE(review): this acts on every position returned by get_portfolio(),
        not only symbols this strategy traded; confirm that is intended.
        """
        portfolio = self.get_portfolio()
        risk_limits = self.get_risk_limits()

        for position in portfolio.positions:
            if position.unrealized_pnl >= 0:
                continue

            cost_basis = position.avg_price * position.quantity
            if cost_basis <= 0:
                # Avoid division by zero for empty positions.
                continue

            loss_ratio = abs(position.unrealized_pnl) / cost_basis
            if loss_ratio < risk_limits.stop_loss_ratio:
                continue

            # Stop-loss triggered: market-sell the whole position.
            order = Order(
                symbol=position.symbol,
                side=OrderSide.SELL,
                order_type=OrderType.MARKET,
                quantity=position.quantity,
                price=None,
                session_id=self.context.trading_session_id,
            )

            if self.submit_order(order):
                self.log_message(
                    f"🛑 RSI止损触发: {position.symbol} {position.quantity}",
                    "warn",
                )

    def get_rsi_value(self, symbol: str) -> Optional[Decimal]:
        """Return the latest RSI reading for ``symbol``, or None if there is none."""
        history = self.rsi_values.get(symbol)
        return history[-1] if history else None

    def get_rsi_history(self, symbol: str) -> List[Decimal]:
        """Return a copy of the stored RSI readings for ``symbol``, oldest first."""
        return list(self.rsi_values.get(symbol, []))