Coverage for core/trading/backtest/backtest_engine.py: 13.17%

243 statements  

« prev     ^ index     » next       coverage.py v7.10.7, created at 2025-10-13 18:58 +0000

1""" 

2回测引擎 

3""" 

4 

5import asyncio 

6import time 

7from datetime import date, datetime, timedelta 

8from decimal import Decimal 

9from typing import Any, Callable, Dict, List, Optional 

10 

11from core.data_source.adapters.quote_adapter import QuoteAdapter 

12from core.models.trading import (AssetMode, MarketData, PerformanceMetrics, 

13 RiskConfig, StrategyContext, TradingMode) 

14from core.trading.engines import create_trading_engine 

15from core.trading.strategies import create_strategy 

16 

17 

class BacktestEngine:
    """Minute-by-minute backtest engine.

    Drives a strategy and a simulated trading engine across
    ``[start_date, end_date]`` one minute at a time. At every minute it builds
    one bar per symbol (historical when the quote adapter returns one,
    otherwise randomly simulated), feeds it to the strategy and then to the
    trading engine, and at the end computes summary performance metrics.
    """

    # Clock time treated as the end of a trading day; a daily snapshot is
    # recorded when the simulated clock hits exactly this hour/minute.
    TRADING_DAY_END_HOUR = 15
    TRADING_DAY_END_MINUTE = 0

    # Fixed delay after each processing stage. The stage calls themselves are
    # synchronous, so this is a cooperative yield, not real synchronization;
    # lower it (e.g. to 0) for faster runs.
    PROCESSING_WAIT_SECONDS = 0.01
    # Yield between simulated minutes so other event-loop tasks can run.
    LOOP_YIELD_SECONDS = 0.001
    # Poll interval while the backtest is paused.
    PAUSE_POLL_SECONDS = 0.1

    def __init__(self, user_id: str, session_id: str, initial_capital: Decimal):
        self.user_id = user_id
        self.session_id = session_id
        self.initial_capital = initial_capital

        # Backtest parameters
        self.start_date: Optional[datetime] = None
        self.end_date: Optional[datetime] = None
        # NOTE(review): stored but not consulted by any method in this class
        # (e.g. _is_trading_day_end uses the naive clock time).
        self.timezone: str = "UTC"

        # Data source
        self.quote_adapter = QuoteAdapter(user_id)

        # Run state
        self.is_running = False
        self.current_time: Optional[datetime] = None
        self.is_paused = False

        # Results
        self.performance_metrics: Optional[PerformanceMetrics] = None
        # NOTE(review): nothing in this class appends to trade_history, so
        # win-rate/trade counts stay 0 unless external code populates it.
        self.trade_history: List[Dict[str, Any]] = []
        self.daily_returns: List[Dict[str, Any]] = []

        # Callbacks
        self.on_progress: Optional[Callable] = None
        self.on_complete: Optional[Callable] = None
        self.on_error: Optional[Callable] = None

    def set_time_range(
        self, start_date: datetime, end_date: datetime, timezone: str = "UTC"
    ):
        """Set the backtest window and rewind the clock to ``start_date``.

        Raises:
            ValueError: if ``end_date`` is earlier than ``start_date``.
        """
        if end_date < start_date:
            raise ValueError("结束时间不能早于开始时间")
        self.start_date = start_date
        self.end_date = end_date
        self.timezone = timezone
        self.current_time = start_date

    def set_callbacks(
        self,
        on_progress: Optional[Callable] = None,
        on_complete: Optional[Callable] = None,
        on_error: Optional[Callable] = None,
    ):
        """Register callbacks.

        ``on_progress(fraction, message)`` is called once per simulated hour,
        ``on_complete(result_dict)`` after a successful run and
        ``on_error(message)`` when a run fails.
        """
        self.on_progress = on_progress
        self.on_complete = on_complete
        self.on_error = on_error

    async def run_backtest(
        self, strategy_name: str, strategy_config: Dict[str, Any], symbols: List[str]
    ) -> Dict[str, Any]:
        """Run a full backtest and return its results.

        Args:
            strategy_name: name passed to ``create_strategy``.
            strategy_config: configuration passed to ``create_strategy``.
            symbols: symbols to build a bar for at every simulated minute.

        Returns:
            Dict with ``success``, ``performance_metrics``, ``trade_history``,
            ``daily_returns``, ``start_date``, ``end_date``, ``initial_capital``.

        Raises:
            RuntimeError: if a backtest is already running, or if any step of
                the run fails (the underlying error is chained as __cause__).
            ValueError: if the time range has not been set.
        """
        if self.is_running:
            raise RuntimeError("回测已在运行中")

        if not self.start_date or not self.end_date:
            raise ValueError("请先设置回测时间范围")

        trading_engine = None
        engine_started = False
        try:
            self.is_running = True
            self.is_paused = False
            self.current_time = self.start_date
            # Start each run from a clean slate. Rebind rather than clear so
            # result dicts returned by earlier runs keep their own lists.
            self.trade_history = []
            self.daily_returns = []
            self.performance_metrics = None

            # Simulated trading engine for this run
            trading_engine = create_trading_engine(
                user_id=self.user_id,
                session_id=self.session_id,
                trading_mode=TradingMode.BACKTEST,
                asset_mode=AssetMode.SIMULATION,
                initial_capital=self.initial_capital,
            )

            if not trading_engine.initialize():
                raise RuntimeError("交易引擎初始化失败")

            if not trading_engine.start():
                raise RuntimeError("交易引擎启动失败")
            engine_started = True

            strategy = create_strategy(strategy_name, strategy_config)
            if not strategy:
                raise RuntimeError(f"策略创建失败: {strategy_name}")

            context = StrategyContext(
                user_id=self.user_id,
                trading_session_id=self.session_id,
                start_time=self.start_date,
                current_time=self.current_time,
                market_data_cache={},
                portfolio_cache=None,
                config={},
            )
            context.trading_engine = trading_engine
            context.quote_adapter = self.quote_adapter

            strategy.initialize(context)

            await self._run_backtest_loop(trading_engine, strategy, symbols)

            self.performance_metrics = self._calculate_performance_metrics(
                trading_engine
            )

            trading_engine.stop()
            engine_started = False

            result = {
                "success": True,
                "performance_metrics": self.performance_metrics,
                "trade_history": self.trade_history,
                "daily_returns": self.daily_returns,
                "start_date": self.start_date,
                "end_date": self.end_date,
                "initial_capital": self.initial_capital,
            }

            if self.on_complete:
                self.on_complete(result)

            return result

        except Exception as e:
            error_msg = f"回测运行失败: {str(e)}"
            if self.on_error:
                self.on_error(error_msg)
            raise RuntimeError(error_msg) from e
        finally:
            if engine_started:
                # Error path: do not leave the trading engine running.
                try:
                    trading_engine.stop()
                except Exception as stop_error:
                    print(f"❌ 停止交易引擎失败: {stop_error}")
            self.is_running = False

    async def _run_backtest_loop(self, trading_engine, strategy, symbols: List[str]):
        """Advance the simulated clock one minute at a time until end_date.

        The loop exits when the clock passes ``end_date`` or ``stop()`` clears
        ``is_running``. While ``is_paused`` is set it idles instead of exiting,
        so ``resume()`` continues from the same minute. An error in one minute
        is logged and that minute is skipped; the clock always advances, so a
        persistent error cannot pin the loop to one timestamp forever.
        """
        # Clamp to 1 so progress never divides by zero on a sub-minute range.
        total_minutes = max(
            int((self.end_date - self.start_date).total_seconds() / 60), 1
        )
        current_minute = 0

        while self.is_running and self.current_time <= self.end_date:
            if self.is_paused:
                await asyncio.sleep(self.PAUSE_POLL_SECONDS)
                continue

            try:
                await self._process_time_step(trading_engine, strategy, symbols)
            except Exception as e:
                print(f"❌ 回测循环错误: {e}")

            self.current_time += timedelta(minutes=1)
            current_minute += 1

            # Report progress once per simulated hour.
            if self.on_progress and current_minute % 60 == 0:
                progress = min(current_minute / total_minutes, 1.0)
                try:
                    self.on_progress(
                        progress, f"回测进度: {current_minute}/{total_minutes} 分钟"
                    )
                except Exception as e:
                    print(f"❌ 回测循环错误: {e}")

            await asyncio.sleep(self.LOOP_YIELD_SECONDS)

    async def _process_time_step(self, trading_engine, strategy, symbols: List[str]):
        """Process the single minute at ``self.current_time``.

        Feeds each symbol's bar to the strategy, then to the trading engine,
        runs the post-trade stages, and records a daily snapshot at day end.
        """
        market_data_list = await self._get_market_data_at_time(
            self.current_time, symbols
        )

        if market_data_list:
            for market_data in market_data_list:
                # Assumes the strategy keeps the context it was initialized
                # with as `strategy.context` — TODO confirm.
                strategy.context.current_time = self.current_time
                await self._wait_for_strategy_processing(strategy, market_data)
                await self._wait_for_trading_processing(trading_engine, market_data)

            await self._wait_for_trading_calculations(trading_engine)
            await self._process_trading_callbacks(trading_engine, strategy)

        if self._is_trading_day_end():
            await self._record_daily_returns(trading_engine)

    async def _wait_for_strategy_processing(self, strategy, market_data: "MarketData"):
        """Feed one bar to the strategy; errors are logged, not raised.

        If ``on_market_data`` returns an awaitable (an async strategy), it is
        awaited so the strategy really finishes before trading proceeds.
        """
        import inspect

        try:
            outcome = strategy.on_market_data(market_data)
            if inspect.isawaitable(outcome):
                await outcome
            await asyncio.sleep(self.PROCESSING_WAIT_SECONDS)
        except Exception as e:
            print(f"❌ 策略处理失败: {e}")

    async def _wait_for_trading_processing(
        self, trading_engine, market_data: "MarketData"
    ):
        """Feed one bar to the trading engine; errors are logged, not raised.

        An awaitable returned by ``process_market_data`` is awaited.
        """
        import inspect

        try:
            outcome = trading_engine.process_market_data(market_data)
            if inspect.isawaitable(outcome):
                await outcome
            await asyncio.sleep(self.PROCESSING_WAIT_SECONDS)
        except Exception as e:
            print(f"❌ 交易处理失败: {e}")

    async def _wait_for_trading_calculations(self, trading_engine):
        """Placeholder stage for post-trade calculations; currently only yields.

        ``trading_engine`` is accepted for future use and is not read.
        """
        await asyncio.sleep(self.PROCESSING_WAIT_SECONDS)

    async def _process_trading_callbacks(self, trading_engine, strategy):
        """Placeholder stage for trade-result callbacks; currently only yields.

        Neither argument is read yet.
        """
        await asyncio.sleep(self.PROCESSING_WAIT_SECONDS)

    def _is_trading_day_end(self) -> bool:
        """Return True when the simulated clock is at the configured day end."""
        return (
            self.current_time.hour == self.TRADING_DAY_END_HOUR
            and self.current_time.minute == self.TRADING_DAY_END_MINUTE
        )

    async def _get_market_data_at_time(
        self, current_time: datetime, symbols: List[str]
    ) -> List["MarketData"]:
        """Build one bar per symbol for ``current_time``.

        Uses the historical bar when ``_get_historical_data`` returns one and
        falls back to ``_generate_simulated_data`` otherwise. A symbol whose
        bar cannot be built is logged and omitted from the result.
        """
        market_data_list = []

        for symbol in symbols:
            try:
                historical_data = await self._get_historical_data(symbol, current_time)

                if historical_data:
                    market_data = MarketData(
                        symbol=symbol,
                        timestamp=current_time,
                        open=Decimal(str(historical_data.get("open", 150))),
                        high=Decimal(str(historical_data.get("high", 155))),
                        low=Decimal(str(historical_data.get("low", 148))),
                        close=Decimal(str(historical_data.get("close", 152))),
                        volume=int(historical_data.get("volume", 1000)),
                    )
                    print(
                        f"📊 获取历史数据: {symbol} at {current_time} = ${market_data.close:.2f}"
                    )
                else:
                    market_data = self._generate_simulated_data(symbol, current_time)
                    print(
                        f"🎭 使用模拟数据: {symbol} at {current_time} = ${market_data.close:.2f}"
                    )

                market_data_list.append(market_data)

            except Exception as e:
                print(f"❌ 获取市场数据失败: {symbol} - {e}")
                continue

        return market_data_list

    def log_message(self, message: str, log_type: str = "info") -> None:
        """Print a log line tagged with ``log_type`` (used by data-adapter callbacks)."""
        print(f"[{log_type}] {message}")

    async def _get_historical_data(
        self, symbol: str, timestamp: datetime
    ) -> Optional[Dict[str, Any]]:
        """Return the OHLCV bar closest to ``timestamp`` as a dict, or None.

        Queries ``self.quote_adapter`` for the window ``timestamp ± 1 minute``
        and picks the bar with the smallest time distance (earliest on ties).
        Any failure is logged and yields None, which makes the caller fall
        back to simulated data.
        """
        try:
            start_time = timestamp - timedelta(minutes=1)
            end_time = timestamp + timedelta(minutes=1)

            def log_callback(message: str, log_type: str = "info"):
                self.log_message(f"【数据适配器】{message}", log_type)

            # Assumes QuoteAdapter.get_historical_data(symbol, start, end,
            # log_callback) returns bars exposing timestamp/open/high/low/
            # close/volume — TODO confirm against QuoteAdapter.
            historical_data = self.quote_adapter.get_historical_data(
                symbol, start_time, end_time, log_callback
            )

            if not historical_data:
                return None

            target_data = min(
                historical_data,
                key=lambda bar: abs((bar.timestamp - timestamp).total_seconds()),
            )
            return {
                "open": float(target_data.open),
                "high": float(target_data.high),
                "low": float(target_data.low),
                "close": float(target_data.close),
                "volume": int(target_data.volume),
            }

        except Exception as e:
            print(f"❌ 获取历史数据失败: {symbol} at {timestamp} - {e}")
            return None

    def _generate_simulated_data(self, symbol: str, timestamp: datetime) -> "MarketData":
        """Generate a random bar around a per-symbol base price.

        The price is the base price perturbed by two independent uniform draws
        (±5% and ±1%), rounded to cents; high/low are ±2% of it. Uses the
        global ``random`` state, so results are not reproducible across runs.
        """
        import random

        base_prices = {
            "AAPL.US": Decimal("150.50"),
            "MSFT.US": Decimal("300.25"),
            "TSLA.US": Decimal("200.75"),
            "GOOGL.US": Decimal("2500.00"),
            "YINN.US": Decimal("25.00"),
            "YANG.US": Decimal("15.00"),
        }

        base_price = base_prices.get(symbol, Decimal("100.00"))

        daily_variation = Decimal(str(round(random.uniform(-0.05, 0.05), 4)))
        time_variation = Decimal(str(round(random.uniform(-0.01, 0.01), 4)))

        current_price = base_price * (Decimal("1") + daily_variation + time_variation)
        current_price = current_price.quantize(Decimal("0.01"))

        high_price = (current_price * Decimal("1.02")).quantize(Decimal("0.01"))
        low_price = (current_price * Decimal("0.98")).quantize(Decimal("0.01"))

        return MarketData(
            symbol=symbol,
            timestamp=timestamp,
            open=current_price,
            high=high_price,
            low=low_price,
            close=current_price,
            volume=random.randint(1000, 10000),
        )

    @staticmethod
    def _to_decimal(value: Any) -> Decimal:
        """Return ``value`` as a Decimal (via ``str`` for non-Decimal inputs)."""
        return value if isinstance(value, Decimal) else Decimal(str(value))

    async def _record_daily_returns(self, trading_engine):
        """Append an end-of-day snapshot to ``self.daily_returns``.

        ``daily_return`` is relative to the previous snapshot (or the initial
        capital for the first one); ``cumulative_return`` is relative to the
        initial capital. Errors are logged, not raised.
        """
        try:
            balance = trading_engine.get_account_balance()
            # NOTE(review): total_cash presumably excludes open-position value;
            # confirm whether the balance exposes a total-equity figure.
            total_value = self._to_decimal(balance.total_cash)

            previous_value = (
                self._to_decimal(self.daily_returns[-1]["total_value"])
                if self.daily_returns
                else self.initial_capital
            )
            daily_return = (
                (total_value - previous_value) / previous_value
                if previous_value
                else Decimal("0")
            )
            cumulative_return = (
                (total_value - self.initial_capital) / self.initial_capital
                if self.initial_capital
                else Decimal("0")
            )

            self.daily_returns.append(
                {
                    "date": self.current_time.date(),
                    "total_value": total_value,
                    "daily_return": daily_return,
                    "cumulative_return": cumulative_return,
                }
            )

        except Exception as e:
            print(f"❌ 记录每日收益失败: {e}")

    @staticmethod
    def _annualize_return(total_return: Decimal, days: int) -> Decimal:
        """Compound ``total_return`` earned over ``days`` to a 365.25-day year.

        Returns 0 when ``days <= 0`` and -1 when the capital was wiped out
        (growth factor <= 0, where a fractional power is undefined).
        Uses Decimal arithmetic throughout; ``Decimal ** float`` raises.
        """
        if days <= 0:
            return Decimal("0")
        growth = Decimal("1") + total_return
        if growth <= 0:
            return Decimal("-1")
        return growth ** (Decimal("365.25") / Decimal(days)) - Decimal("1")

    def _calculate_performance_metrics(self, trading_engine) -> "PerformanceMetrics":
        """Compute summary metrics; on any error log it and return all-zero metrics."""
        try:
            final_balance = trading_engine.get_account_balance()
            final_value = self._to_decimal(final_balance.total_cash)

            total_return = (final_value - self.initial_capital) / self.initial_capital

            days = (self.end_date - self.start_date).days
            annualized_return = self._annualize_return(total_return, days)

            max_drawdown = self._calculate_max_drawdown()
            sharpe_ratio = self._calculate_sharpe_ratio()
            win_rate = self._calculate_win_rate()

            total_trades = len(self.trade_history)
            winning_trades = sum(
                1 for trade in self.trade_history if trade.get("pnl", 0) > 0
            )
            # Break-even trades count as neither winning nor losing.
            losing_trades = sum(
                1 for trade in self.trade_history if trade.get("pnl", 0) < 0
            )

            return PerformanceMetrics(
                total_return=total_return,
                annualized_return=annualized_return,
                max_drawdown=max_drawdown,
                sharpe_ratio=sharpe_ratio,
                win_rate=win_rate,
                total_trades=total_trades,
                winning_trades=winning_trades,
                losing_trades=losing_trades,
            )

        except Exception as e:
            print(f"❌ 计算性能指标失败: {e}")
            return PerformanceMetrics(
                total_return=Decimal("0"),
                annualized_return=Decimal("0"),
                max_drawdown=Decimal("0"),
                sharpe_ratio=Decimal("0"),
                win_rate=Decimal("0"),
                total_trades=0,
                winning_trades=0,
                losing_trades=0,
            )

    def _calculate_max_drawdown(self) -> Decimal:
        """Largest peak-to-trough fall of the daily snapshots, as a fraction.

        The running peak starts at the initial capital.
        """
        if not self.daily_returns:
            return Decimal("0")

        peak = self.initial_capital
        max_drawdown = Decimal("0")

        for record in self.daily_returns:
            value = self._to_decimal(record["total_value"])
            peak = max(peak, value)
            if peak > 0:  # avoid dividing by a zero/negative peak
                max_drawdown = max(max_drawdown, (peak - value) / peak)

        return max_drawdown

    def _calculate_sharpe_ratio(self) -> Decimal:
        """Simplified Sharpe ratio of the daily snapshots.

        Mean over population standard deviation of snapshot-to-snapshot
        returns, assuming a zero risk-free rate and no annualization.
        Uses ``Decimal.sqrt`` because ``Decimal ** 0.5`` raises TypeError.
        """
        if not self.daily_returns or len(self.daily_returns) < 2:
            return Decimal("0")

        values = [self._to_decimal(r["total_value"]) for r in self.daily_returns]
        returns = [
            (current - previous) / previous
            for previous, current in zip(values, values[1:])
            if previous != 0
        ]

        if not returns:
            return Decimal("0")

        mean = sum(returns, Decimal("0")) / len(returns)
        variance = sum(((r - mean) ** 2 for r in returns), Decimal("0")) / len(returns)
        std_dev = variance.sqrt()

        if std_dev == 0:
            return Decimal("0")

        return mean / std_dev

    def _calculate_win_rate(self) -> Decimal:
        """Fraction of recorded trades whose ``pnl`` is strictly positive."""
        if not self.trade_history:
            return Decimal("0")

        winning_trades = sum(
            1 for trade in self.trade_history if trade.get("pnl", 0) > 0
        )
        return Decimal(winning_trades) / Decimal(len(self.trade_history))

    def pause(self):
        """Pause the backtest; the loop idles at the current minute until resumed."""
        self.is_paused = True
        print("⏸️ 回测已暂停")

    def resume(self):
        """Resume a paused backtest from the minute where it was paused."""
        self.is_paused = False
        print("▶️ 回测已恢复")

    def stop(self):
        """Ask the running loop to exit.

        ``run_backtest`` then computes metrics from the minutes already processed.
        """
        self.is_running = False
        self.is_paused = False
        print("🛑 回测已停止")

    def get_progress(self) -> Dict[str, Any]:
        """Return the elapsed fraction of the time range (clamped to [0, 1]) and state."""
        if not self.start_date or not self.end_date or not self.current_time:
            return {"progress": 0, "status": "not_started"}

        # Use seconds rather than whole days so intraday ranges report progress.
        total_seconds = (self.end_date - self.start_date).total_seconds()
        elapsed_seconds = (self.current_time - self.start_date).total_seconds()
        progress = (
            min(max(elapsed_seconds / total_seconds, 0.0), 1.0)
            if total_seconds > 0
            else 0
        )

        return {
            "progress": progress,
            "current_time": self.current_time,
            "start_date": self.start_date,
            "end_date": self.end_date,
            "is_running": self.is_running,
            "is_paused": self.is_paused,
        }