Coverage for core/trading/backtest/backtest_engine.py: 13.17%
243 statements
« prev ^ index » next coverage.py v7.10.7, created at 2025-10-13 18:58 +0000
« prev ^ index » next coverage.py v7.10.7, created at 2025-10-13 18:58 +0000
1"""
2回测引擎
3"""
5import asyncio
6import time
7from datetime import date, datetime, timedelta
8from decimal import Decimal
9from typing import Any, Callable, Dict, List, Optional
11from core.data_source.adapters.quote_adapter import QuoteAdapter
12from core.models.trading import (AssetMode, MarketData, PerformanceMetrics,
13 RiskConfig, StrategyContext, TradingMode)
14from core.trading.engines import create_trading_engine
15from core.trading.strategies import create_strategy
class BacktestEngine:
    """Minute-by-minute backtest engine.

    Drives a strategy and a simulated trading engine over
    ``[start_date, end_date]`` one minute at a time. It records an account
    snapshot at each trading-day close (see ``_TRADING_DAY_END``) and derives
    summary performance metrics when the run finishes.
    """

    # Wall-clock waits kept from the original design; the processing hooks
    # assume the strategy/engine may need a moment to settle after each call.
    # NOTE(review): on_market_data / process_market_data are invoked
    # synchronously here, so these waits may be unnecessary. They dominate
    # runtime on long backtests (at least ~20 ms per simulated minute).
    _PROCESSING_WAIT_SECONDS = 0.01
    _LOOP_YIELD_SECONDS = 0.001
    # Poll interval used while the backtest is paused.
    _PAUSE_POLL_SECONDS = 0.05
    # Clock time (hour, minute) treated as the end of a trading day.
    _TRADING_DAY_END = (15, 0)

    def __init__(self, user_id: str, session_id: str, initial_capital: Decimal):
        """Create an idle engine for one user/session.

        Args:
            user_id: Owner of the backtest; also passed to ``QuoteAdapter``.
            session_id: Session id forwarded to the trading engine and the
                strategy context.
            initial_capital: Starting cash of the simulated account.
        """
        self.user_id = user_id
        self.session_id = session_id
        self.initial_capital = initial_capital

        # Backtest time range (configured via set_time_range).
        self.start_date: Optional[datetime] = None
        self.end_date: Optional[datetime] = None
        # Stored for callers; not otherwise used by this class.
        self.timezone: str = "UTC"

        # Market data source.
        self.quote_adapter = QuoteAdapter(user_id)

        # Run state.
        self.is_running = False
        self.current_time: Optional[datetime] = None
        self.is_paused = False

        # Results.
        self.performance_metrics: Optional[PerformanceMetrics] = None
        # NOTE(review): nothing in this class appends to trade_history; it is
        # presumably filled by external code.
        self.trade_history: List[Dict[str, Any]] = []
        self.daily_returns: List[Dict[str, Any]] = []

        # Optional callbacks.
        self.on_progress: Optional[Callable] = None
        self.on_complete: Optional[Callable] = None
        self.on_error: Optional[Callable] = None

    def set_time_range(
        self, start_date: datetime, end_date: datetime, timezone: str = "UTC"
    ):
        """Set the backtest window and rewind the simulation clock to its start.

        Raises:
            ValueError: if ``end_date`` is earlier than ``start_date``.
        """
        if end_date < start_date:
            raise ValueError("结束时间不能早于开始时间")
        self.start_date = start_date
        self.end_date = end_date
        self.timezone = timezone
        self.current_time = start_date

    def set_callbacks(
        self,
        on_progress: Optional[Callable] = None,
        on_complete: Optional[Callable] = None,
        on_error: Optional[Callable] = None,
    ):
        """Register callbacks (each may be None).

        ``on_progress(fraction, message)``, ``on_complete(result_dict)`` and
        ``on_error(error_message)``.
        """
        self.on_progress = on_progress
        self.on_complete = on_complete
        self.on_error = on_error

    async def run_backtest(
        self, strategy_name: str, strategy_config: Dict[str, Any], symbols: List[str]
    ) -> Dict[str, Any]:
        """Run a complete backtest and return its results.

        Creates and starts a BACKTEST/SIMULATION trading engine, builds and
        initialises the named strategy, runs the minute loop, computes
        performance metrics, stops the engine, and returns a result dict.
        Results of a previous run on this instance are cleared first.

        Raises:
            RuntimeError: if a backtest is already running, or (wrapping the
                original exception) if any step of the run fails.
            ValueError: if the time range has not been set.
        """
        if self.is_running:
            raise RuntimeError("回测已在运行中")

        if not self.start_date or not self.end_date:
            raise ValueError("请先设置回测时间范围")

        trading_engine = None
        engine_stopped = False
        try:
            self.is_running = True
            self.is_paused = False
            self.current_time = self.start_date
            # Start each run from a clean slate so results of an earlier run
            # on the same instance do not leak into this one.
            self.trade_history = []
            self.daily_returns = []
            self.performance_metrics = None

            # Create the simulated trading engine.
            trading_engine = create_trading_engine(
                user_id=self.user_id,
                session_id=self.session_id,
                trading_mode=TradingMode.BACKTEST,
                asset_mode=AssetMode.SIMULATION,
                initial_capital=self.initial_capital,
            )

            if not trading_engine.initialize():
                raise RuntimeError("交易引擎初始化失败")

            if not trading_engine.start():
                raise RuntimeError("交易引擎启动失败")

            strategy = create_strategy(strategy_name, strategy_config)
            if not strategy:
                raise RuntimeError(f"策略创建失败: {strategy_name}")

            context = StrategyContext(
                user_id=self.user_id,
                trading_session_id=self.session_id,
                start_time=self.start_date,
                current_time=self.current_time,
                market_data_cache={},
                portfolio_cache=None,
                config={},
            )
            context.trading_engine = trading_engine
            context.quote_adapter = self.quote_adapter

            strategy.initialize(context)

            await self._run_backtest_loop(trading_engine, strategy, symbols)

            self.performance_metrics = self._calculate_performance_metrics(
                trading_engine
            )

            trading_engine.stop()
            engine_stopped = True

            result = {
                "success": True,
                "performance_metrics": self.performance_metrics,
                "trade_history": self.trade_history,
                "daily_returns": self.daily_returns,
                "start_date": self.start_date,
                "end_date": self.end_date,
                "initial_capital": self.initial_capital,
            }

            if self.on_complete:
                self.on_complete(result)

            return result

        except Exception as e:
            # Do not leave the engine running when the run fails part-way.
            if trading_engine is not None and not engine_stopped:
                self._stop_engine_quietly(trading_engine)
            error_msg = f"回测运行失败: {str(e)}"
            if self.on_error:
                self.on_error(error_msg)
            raise RuntimeError(error_msg) from e
        finally:
            self.is_running = False

    @staticmethod
    def _stop_engine_quietly(trading_engine) -> None:
        """Stop the engine on a best-effort basis (error path); never raises."""
        try:
            trading_engine.stop()
        except Exception as stop_error:
            print(f"❌ 交易引擎停止失败: {stop_error}")

    async def _run_backtest_loop(self, trading_engine, strategy, symbols: List[str]):
        """Advance the simulation one minute at a time until end_date (inclusive).

        Each minute is handled by ``_process_time_step``. An error in one minute
        is printed and the clock still advances. While ``is_paused`` is set
        the loop idles without advancing; ``stop()`` (which clears
        ``is_running``) ends it early.
        """
        total_minutes = int((self.end_date - self.start_date).total_seconds() / 60)
        current_minute = 0

        while self.is_running and self.current_time <= self.end_date:
            if self.is_paused:
                await asyncio.sleep(self._PAUSE_POLL_SECONDS)
                continue

            try:
                await self._process_time_step(trading_engine, strategy, symbols)
            except Exception as e:
                print(f"❌ 回测循环错误: {e}")

            # Always advance. Skipping this on error made a failure that
            # recurs every minute spin forever on the same bar.
            self.current_time += timedelta(minutes=1)
            current_minute += 1

            try:
                self._report_progress(current_minute, total_minutes)
            except Exception as e:
                print(f"❌ 回测循环错误: {e}")

            # Brief yield so the loop does not monopolise the event loop.
            await asyncio.sleep(self._LOOP_YIELD_SECONDS)

    async def _process_time_step(self, trading_engine, strategy, symbols: List[str]):
        """Process one simulated minute.

        Feeds each symbol's bar to the strategy and then to the engine, runs
        the settlement hooks, and records a snapshot at the trading-day close.
        """
        market_data_list = await self._get_market_data_at_time(
            self.current_time, symbols
        )

        if market_data_list:
            # The clock is the same for every symbol in this minute.
            strategy.context.current_time = self.current_time
            for market_data in market_data_list:
                await self._wait_for_strategy_processing(strategy, market_data)
                await self._wait_for_trading_processing(trading_engine, market_data)

        await self._wait_for_trading_calculations(trading_engine)
        await self._process_trading_callbacks(trading_engine, strategy)

        if self._is_trading_day_end():
            await self._record_daily_returns(trading_engine)

    def _report_progress(self, current_minute: int, total_minutes: int) -> None:
        """Call on_progress once per simulated hour with a fraction in [0, 1]."""
        if not self.on_progress or current_minute % 60 != 0:
            return
        # The loop is inclusive of end_date, so current_minute can reach
        # total_minutes + 1; clamp so callers never see a value above 1.
        if total_minutes > 0:
            progress = min(current_minute / total_minutes, 1.0)
        else:
            progress = 1.0
        self.on_progress(
            progress, f"回测进度: {current_minute}/{total_minutes} 分钟"
        )

    async def _wait_for_strategy_processing(self, strategy, market_data: MarketData):
        """Hand one bar to the strategy, then wait briefly; errors are printed."""
        try:
            strategy.on_market_data(market_data)
            # No completion signal exists; a short wait stands in for one.
            await asyncio.sleep(self._PROCESSING_WAIT_SECONDS)
        except Exception as e:
            print(f"❌ 策略处理失败: {e}")

    async def _wait_for_trading_processing(
        self, trading_engine, market_data: MarketData
    ):
        """Hand one bar to the trading engine, then wait briefly; errors are printed."""
        try:
            trading_engine.process_market_data(market_data)
            # Intended to give order-fill logic time to finish.
            await asyncio.sleep(self._PROCESSING_WAIT_SECONDS)
        except Exception as e:
            print(f"❌ 交易处理失败: {e}")

    async def _wait_for_trading_calculations(self, trading_engine):
        """Placeholder hook for position/cash settlement; currently only waits."""
        try:
            await asyncio.sleep(self._PROCESSING_WAIT_SECONDS)
        except Exception as e:
            print(f"❌ 交易计算失败: {e}")

    async def _process_trading_callbacks(self, trading_engine, strategy):
        """Placeholder hook for post-trade callbacks; currently only waits."""
        try:
            await asyncio.sleep(self._PROCESSING_WAIT_SECONDS)
        except Exception as e:
            print(f"❌ 交易回调处理失败: {e}")

    def _is_trading_day_end(self) -> bool:
        """Return True when the simulation clock is exactly at the daily close time."""
        hour, minute = self._TRADING_DAY_END
        return self.current_time.hour == hour and self.current_time.minute == minute

    async def _get_market_data_at_time(
        self, current_time: datetime, symbols: List[str]
    ) -> List[MarketData]:
        """Build one MarketData bar per symbol for ``current_time``.

        Uses historical data when available and otherwise falls back to
        simulated data. A symbol that fails is printed and skipped.
        """
        market_data_list = []

        for symbol in symbols:
            try:
                historical_data = await self._get_historical_data(symbol, current_time)

                if historical_data:
                    market_data = MarketData(
                        symbol=symbol,
                        timestamp=current_time,
                        open=Decimal(str(historical_data.get("open", 150))),
                        high=Decimal(str(historical_data.get("high", 155))),
                        low=Decimal(str(historical_data.get("low", 148))),
                        close=Decimal(str(historical_data.get("close", 152))),
                        volume=int(historical_data.get("volume", 1000)),
                    )
                    print(
                        f"📊 获取历史数据: {symbol} at {current_time} = ${market_data.close:.2f}"
                    )
                else:
                    market_data = self._generate_simulated_data(symbol, current_time)
                    print(
                        f"🎭 使用模拟数据: {symbol} at {current_time} = ${market_data.close:.2f}"
                    )

                market_data_list.append(market_data)

            except Exception as e:
                print(f"❌ 获取市场数据失败: {symbol} - {e}")
                continue

        return market_data_list

    def log_message(self, message: str, log_type: str = "info") -> None:
        """Print a log line; used as the quote adapter's log callback."""
        print(f"[{log_type}] {message}")

    async def _get_historical_data(
        self, symbol: str, timestamp: datetime
    ) -> Optional[Dict[str, Any]]:
        """Return the OHLCV values of the stored bar closest to ``timestamp``.

        Queries the quote adapter for a ±1 minute window around ``timestamp``.
        Returns None when nothing is found or the lookup fails; the failure is
        printed.
        """
        try:
            start_time = timestamp - timedelta(minutes=1)
            end_time = timestamp + timedelta(minutes=1)

            def log_callback(message: str, log_type: str = "info"):
                self.log_message(f"【数据适配器】{message}", log_type)

            # The engine only ever creates ``self.quote_adapter``. The former
            # ``self.quote_data_adapter`` raised AttributeError on every call,
            # so historical data was silently never used.
            historical_data = self.quote_adapter.get_historical_data(
                symbol, start_time, end_time, log_callback
            )
            # NOTE(review): unclear from here whether the adapter method is
            # sync or async; accept both.
            if asyncio.iscoroutine(historical_data):
                historical_data = await historical_data

            if not historical_data:
                return None

            # Ties resolve to the first record, matching the original scan.
            target_data = min(
                historical_data,
                key=lambda data: abs((data.timestamp - timestamp).total_seconds()),
            )
            return {
                "open": float(target_data.open),
                "high": float(target_data.high),
                "low": float(target_data.low),
                "close": float(target_data.close),
                "volume": int(target_data.volume),
            }

        except Exception as e:
            print(f"❌ 获取历史数据失败: {symbol} at {timestamp} - {e}")
            return None

    def _generate_simulated_data(self, symbol: str, timestamp: datetime) -> MarketData:
        """Generate a random bar around a fixed base price for ``symbol``.

        The close is the base price moved by up to ±5% plus ±1% of uniform
        noise. High and low are ±2% of that close, and open equals close.
        """
        import random

        base_prices = {
            "AAPL.US": Decimal("150.50"),
            "MSFT.US": Decimal("300.25"),
            "TSLA.US": Decimal("200.75"),
            "GOOGL.US": Decimal("2500.00"),
            "YINN.US": Decimal("25.00"),
            "YANG.US": Decimal("15.00"),
        }
        base_price = base_prices.get(symbol, Decimal("100.00"))

        # Round the random factors to 4 places before converting to Decimal.
        daily_variation = Decimal(str(round(random.uniform(-0.05, 0.05), 4)))
        time_variation = Decimal(str(round(random.uniform(-0.01, 0.01), 4)))

        cent = Decimal("0.01")
        current_price = (
            base_price * (Decimal("1") + daily_variation + time_variation)
        ).quantize(cent)
        high_price = (current_price * Decimal("1.02")).quantize(cent)
        low_price = (current_price * Decimal("0.98")).quantize(cent)

        return MarketData(
            symbol=symbol,
            timestamp=timestamp,
            open=current_price,
            high=high_price,
            low=low_price,
            close=current_price,
            volume=random.randint(1000, 10000),
        )

    async def _record_daily_returns(self, trading_engine):
        """Append an end-of-day account snapshot to ``daily_returns``.

        ``daily_return`` is measured against the previous snapshot, or against
        the initial capital for the first one. ``cumulative_return`` is
        measured against the initial capital. The value used is
        ``balance.total_cash``. Errors are printed, not raised.
        """
        try:
            balance = trading_engine.get_account_balance()
            total_value = balance.total_cash

            prev_value = (
                self.daily_returns[-1]["total_value"]
                if self.daily_returns
                else self.initial_capital
            )
            daily_return = (total_value - prev_value) / prev_value
            cumulative_return = (
                total_value - self.initial_capital
            ) / self.initial_capital

            self.daily_returns.append(
                {
                    "date": self.current_time.date(),
                    "total_value": total_value,
                    "daily_return": daily_return,
                    "cumulative_return": cumulative_return,
                }
            )

        except Exception as e:
            print(f"❌ 记录每日收益失败: {e}")

    @staticmethod
    def _annualize_return(total_return, days: int) -> Decimal:
        """Annualize ``total_return`` over ``days`` calendar days (365.25 per year).

        Returns 0 when ``days <= 0``. Returns -1 when the capital was wiped out
        (growth factor <= 0), where a fractional power is undefined. All
        arithmetic is done in Decimal, because ``Decimal ** float`` raises
        TypeError.
        """
        if days <= 0:
            return Decimal("0")
        growth = Decimal("1") + Decimal(str(total_return))
        if growth <= 0:
            return Decimal("-1")
        years = Decimal(days) / Decimal("365.25")
        return growth ** (Decimal("1") / years) - Decimal("1")

    def _count_winning_trades(self) -> int:
        """Count trades whose ``pnl`` is strictly positive (missing pnl counts as 0)."""
        return sum(1 for trade in self.trade_history if trade.get("pnl", 0) > 0)

    def _calculate_performance_metrics(self, trading_engine) -> PerformanceMetrics:
        """Compute summary metrics from the final balance and recorded history.

        On any error the problem is printed and an all-zero
        ``PerformanceMetrics`` is returned.
        """
        try:
            final_balance = trading_engine.get_account_balance()

            total_return = (
                final_balance.total_cash - self.initial_capital
            ) / self.initial_capital

            days = (self.end_date - self.start_date).days
            annualized_return = self._annualize_return(total_return, days)

            max_drawdown = self._calculate_max_drawdown()
            sharpe_ratio = self._calculate_sharpe_ratio()
            win_rate = self._calculate_win_rate()

            total_trades = len(self.trade_history)
            winning_trades = self._count_winning_trades()
            # Break-even trades are counted as losing, as before.
            losing_trades = total_trades - winning_trades

            return PerformanceMetrics(
                total_return=total_return,
                annualized_return=annualized_return,
                max_drawdown=max_drawdown,
                sharpe_ratio=sharpe_ratio,
                win_rate=win_rate,
                total_trades=total_trades,
                winning_trades=winning_trades,
                losing_trades=losing_trades,
            )

        except Exception as e:
            print(f"❌ 计算性能指标失败: {e}")
            return PerformanceMetrics(
                total_return=Decimal("0"),
                annualized_return=Decimal("0"),
                max_drawdown=Decimal("0"),
                sharpe_ratio=Decimal("0"),
                win_rate=Decimal("0"),
                total_trades=0,
                winning_trades=0,
                losing_trades=0,
            )

    def _calculate_max_drawdown(self) -> Decimal:
        """Largest peak-to-trough drop of recorded ``total_value`` as a fraction.

        The running peak starts at the initial capital. Non-positive peaks are
        skipped to avoid division by zero.
        """
        if not self.daily_returns:
            return Decimal("0")

        peak = Decimal(str(self.initial_capital))
        max_drawdown = Decimal("0")
        for record in self.daily_returns:
            value = Decimal(str(record["total_value"]))
            peak = max(peak, value)
            if peak <= 0:
                continue
            max_drawdown = max(max_drawdown, (peak - value) / peak)
        return max_drawdown

    def _calculate_sharpe_ratio(self) -> Decimal:
        """Simplified Sharpe ratio: mean / population std of snapshot-to-snapshot returns.

        The risk-free rate is 0 and the result is not annualized. Values are
        normalised to Decimal and ``Decimal.sqrt`` is used, because the former
        ``variance ** 0.5`` raised TypeError on Decimal input.
        """
        if len(self.daily_returns) < 2:
            return Decimal("0")

        values = [Decimal(str(record["total_value"])) for record in self.daily_returns]
        period_returns = [
            (curr - prev) / prev
            for prev, curr in zip(values, values[1:])
            if prev != 0
        ]
        if not period_returns:
            return Decimal("0")

        count = Decimal(len(period_returns))
        avg_return = sum(period_returns, Decimal("0")) / count
        variance = (
            sum(((r - avg_return) ** 2 for r in period_returns), Decimal("0")) / count
        )
        std_dev = variance.sqrt()

        if std_dev == 0:
            return Decimal("0")
        return avg_return / std_dev

    def _calculate_win_rate(self) -> Decimal:
        """Fraction of trades in ``trade_history`` with positive pnl (0 if none)."""
        if not self.trade_history:
            return Decimal("0")
        return Decimal(self._count_winning_trades()) / Decimal(len(self.trade_history))

    def pause(self):
        """Pause the loop; it idles without advancing the clock until resume()."""
        self.is_paused = True
        print("⏸️ 回测已暂停")

    def resume(self):
        """Resume a paused loop."""
        self.is_paused = False
        print("▶️ 回测已恢复")

    def stop(self):
        """Request that the loop end after the current minute.

        ``run_backtest`` then computes metrics over the partial run.
        """
        self.is_running = False
        self.is_paused = False
        print("🛑 回测已停止")

    def get_progress(self) -> Dict[str, Any]:
        """Return elapsed fraction of the window (clamped to [0, 1]) plus run state.

        Progress is measured in seconds so intraday backtests report
        meaningful values; it was previously measured in whole days.
        """
        if not self.start_date or not self.end_date or not self.current_time:
            return {"progress": 0, "status": "not_started"}

        total_seconds = (self.end_date - self.start_date).total_seconds()
        elapsed_seconds = (self.current_time - self.start_date).total_seconds()
        if total_seconds > 0:
            progress = min(max(elapsed_seconds / total_seconds, 0.0), 1.0)
        else:
            progress = 0

        return {
            "progress": progress,
            "current_time": self.current_time,
            "start_date": self.start_date,
            "end_date": self.end_date,
            "is_running": self.is_running,
            "is_paused": self.is_paused,
        }