Coverage for api/v1/endpoints/backtest.py: 41.51%
106 statements
« prev ^ index » next coverage.py v7.10.7, created at 2025-10-13 18:58 +0000
« prev ^ index » next coverage.py v7.10.7, created at 2025-10-13 18:58 +0000
1"""
2回测API端点
3"""
5from datetime import datetime
6from decimal import Decimal
7from typing import Any, Dict, List, Optional
9from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException
10from pydantic import BaseModel
12from core.middleware.auth_middleware import get_current_user
13from core.models.user import User
14from core.trading.backtest import BacktestEngine
# Router on which the backtest endpoints of this module are registered.
router = APIRouter()

# Module-level registry of backtest engines, keyed by backtest ID.
# Entries are added by the start endpoint and removed by the stop endpoint
# and by run_backtest_task.
backtest_engines: Dict[str, BacktestEngine] = {}
class BacktestRequest(BaseModel):
    """Request body accepted by the endpoint that starts a backtest."""

    # Forwarded to the BacktestEngine constructor and embedded in the
    # generated backtest ID.
    session_id: str
    # Strategy name and configuration, handed to the engine's run_backtest().
    strategy_name: str
    strategy_config: Dict[str, Any]
    # Symbols handed to the engine's run_backtest().
    symbols: List[str]
    # Time window passed to the engine's set_time_range().
    start_date: datetime
    end_date: datetime
    # Starting capital passed to the BacktestEngine constructor.
    initial_capital: Decimal
    # Timezone name passed to set_time_range(); defaults to "UTC".
    timezone: str = "UTC"
class BacktestResponse(BaseModel):
    """Response body returned by the endpoint that starts a backtest."""

    # False when starting the backtest raised an exception.
    success: bool
    # Human-readable status or error message.
    message: str
    # ID of the newly registered backtest; left as None on failure.
    backtest_id: Optional[str] = None
    # Optional metrics payload; the start endpoint does not populate it.
    performance_metrics: Optional[Dict[str, Any]] = None
class BacktestProgressResponse(BaseModel):
    """Response body returned by the backtest progress endpoint."""

    # Value of the "progress" key reported by the engine's get_progress().
    progress: float
    # Optional timestamps copied from get_progress() when present.
    current_time: Optional[datetime] = None
    start_date: Optional[datetime] = None
    end_date: Optional[datetime] = None
    # Engine state flags copied from get_progress().
    is_running: bool = False
    is_paused: bool = False
@router.post("/start", response_model=BacktestResponse)
async def start_backtest(
    request: BacktestRequest,
    background_tasks: BackgroundTasks,
    current_user: User = Depends(get_current_user),
):
    """Create a backtest engine and schedule the backtest as a background task.

    The engine is configured from ``request``, stored in the module-level
    ``backtest_engines`` registry under a freshly generated ID, and
    ``run_backtest_task`` is queued on ``background_tasks``.

    Args:
        request: Backtest parameters (session, strategy, symbols, time range,
            initial capital, timezone).
        background_tasks: FastAPI task queue the backtest run is added to.
        current_user: Authenticated user; becomes the owner of the engine.

    Returns:
        A ``BacktestResponse`` with ``success=True`` and the new
        ``backtest_id``, or ``success=False`` and an error message if any
        step raised. Failures are reported in the body rather than as an
        HTTP error.
    """
    try:
        # The ID is "<user>_<session>_<unix seconds>".
        base_id = (
            f"{current_user.id}_{request.session_id}_{int(datetime.now().timestamp())}"
        )

        # The timestamp only has one-second resolution, so two starts within
        # the same second for the same user/session would produce the same
        # key and the second engine would silently replace the first one in
        # the registry. Disambiguate with a numeric suffix instead.
        backtest_id = base_id
        attempt = 1
        while backtest_id in backtest_engines:
            backtest_id = f"{base_id}_{attempt}"
            attempt += 1

        # Build the engine for this user and session.
        backtest_engine = BacktestEngine(
            user_id=current_user.id,
            session_id=request.session_id,
            initial_capital=request.initial_capital,
        )

        # Configure the simulated time window.
        backtest_engine.set_time_range(
            start_date=request.start_date,
            end_date=request.end_date,
            timezone=request.timezone,
        )

        # Callbacks simply print to stdout.
        def on_progress(progress: float, message: str) -> None:
            print(f"回测进度: {progress:.2%} - {message}")

        def on_complete(result: Dict[str, Any]) -> None:
            print(f"回测完成: {result}")

        def on_error(error: str) -> None:
            print(f"回测错误: {error}")

        backtest_engine.set_callbacks(
            on_progress=on_progress, on_complete=on_complete, on_error=on_error
        )

        # Register the engine so the other endpoints can find it by ID.
        backtest_engines[backtest_id] = backtest_engine

        # Queue the actual run on the background task list.
        background_tasks.add_task(
            run_backtest_task,
            backtest_engine,
            request.strategy_name,
            request.strategy_config,
            request.symbols,
            backtest_id,
        )

        return BacktestResponse(
            success=True, message="回测已启动", backtest_id=backtest_id
        )

    except Exception as e:
        # Deliberate best-effort: report the failure in the response body.
        return BacktestResponse(success=False, message=f"启动回测失败: {str(e)}")
@router.get("/{backtest_id}/progress", response_model=BacktestProgressResponse)
async def get_backtest_progress(
    backtest_id: str, current_user: User = Depends(get_current_user)
):
    """Report the progress of a registered backtest.

    Raises:
        HTTPException: 404 when ``backtest_id`` is not in the registry,
            403 when the engine belongs to a different user.
    """
    engine = backtest_engines.get(backtest_id)
    if engine is None:
        raise HTTPException(status_code=404, detail="回测不存在")

    # Only the user who owns the engine may inspect it.
    if engine.user_id != current_user.id:
        raise HTTPException(status_code=403, detail="无权限访问此回测")

    info = engine.get_progress()

    # "progress", "is_running" and "is_paused" are read as required keys;
    # the three timestamps are optional.
    payload = {
        "progress": info["progress"],
        "current_time": info.get("current_time"),
        "start_date": info.get("start_date"),
        "end_date": info.get("end_date"),
        "is_running": info["is_running"],
        "is_paused": info["is_paused"],
    }
    return BacktestProgressResponse(**payload)
@router.post("/{backtest_id}/pause")
async def pause_backtest(
    backtest_id: str, current_user: User = Depends(get_current_user)
):
    """Pause a registered backtest.

    Raises:
        HTTPException: 404 when ``backtest_id`` is not in the registry,
            403 when the engine belongs to a different user.
    """
    engine = backtest_engines.get(backtest_id)
    if engine is None:
        raise HTTPException(status_code=404, detail="回测不存在")

    # Only the user who owns the engine may control it.
    if engine.user_id != current_user.id:
        raise HTTPException(status_code=403, detail="无权限访问此回测")

    engine.pause()
    return {"message": "回测已暂停"}
@router.post("/{backtest_id}/resume")
async def resume_backtest(
    backtest_id: str, current_user: User = Depends(get_current_user)
):
    """Resume a registered backtest.

    Raises:
        HTTPException: 404 when ``backtest_id`` is not in the registry,
            403 when the engine belongs to a different user.
    """
    engine = backtest_engines.get(backtest_id)
    if engine is None:
        raise HTTPException(status_code=404, detail="回测不存在")

    # Only the user who owns the engine may control it.
    if engine.user_id != current_user.id:
        raise HTTPException(status_code=403, detail="无权限访问此回测")

    engine.resume()
    return {"message": "回测已恢复"}
@router.post("/{backtest_id}/stop")
async def stop_backtest(
    backtest_id: str, current_user: User = Depends(get_current_user)
):
    """Stop a registered backtest and remove its engine from the registry.

    Raises:
        HTTPException: 404 when ``backtest_id`` is not in the registry,
            403 when the engine belongs to a different user.
    """
    engine = backtest_engines.get(backtest_id)
    if engine is None:
        raise HTTPException(status_code=404, detail="回测不存在")

    # Only the user who owns the engine may control it.
    if engine.user_id != current_user.id:
        raise HTTPException(status_code=403, detail="无权限访问此回测")

    engine.stop()

    # The engine is no longer needed once stopped; drop it from the registry.
    del backtest_engines[backtest_id]

    return {"message": "回测已停止"}
@router.get("/{backtest_id}/result")
async def get_backtest_result(
    backtest_id: str, current_user: User = Depends(get_current_user)
):
    """Return the results held by a registered backtest engine.

    Raises:
        HTTPException: 404 when ``backtest_id`` is not in the registry,
            403 when the engine belongs to a different user,
            400 while the engine still reports ``is_running``.
    """
    engine = backtest_engines.get(backtest_id)
    if engine is None:
        raise HTTPException(status_code=404, detail="回测不存在")

    # Only the user who owns the engine may read its results.
    if engine.user_id != current_user.id:
        raise HTTPException(status_code=403, detail="无权限访问此回测")

    # Results are only served once the engine is no longer running.
    if engine.is_running:
        raise HTTPException(status_code=400, detail="回测尚未完成")

    summary = {
        "performance_metrics": engine.performance_metrics,
        "trade_history": engine.trade_history,
        "daily_returns": engine.daily_returns,
        "start_date": engine.start_date,
        "end_date": engine.end_date,
        "initial_capital": engine.initial_capital,
    }
    return summary
async def run_backtest_task(
    backtest_engine: BacktestEngine,
    strategy_name: str,
    strategy_config: Dict[str, Any],
    symbols: List[str],
    backtest_id: str,
) -> None:
    """Run one backtest to completion as a background task.

    Args:
        backtest_engine: Engine already configured by the start endpoint.
        strategy_name: Strategy name forwarded to ``run_backtest``.
        strategy_config: Strategy configuration forwarded to ``run_backtest``.
        symbols: Symbols forwarded to ``run_backtest``.
        backtest_id: Registry key of ``backtest_engine`` in
            ``backtest_engines``; also used in the log output.

    Any exception raised by the run is caught and printed, never propagated.
    """
    try:
        await backtest_engine.run_backtest(strategy_name, strategy_config, symbols)
    except Exception as e:
        print(f"❌ 回测任务失败: {backtest_id} - {e}")

        # A failed run has nothing to serve, so drop its engine. pop() with a
        # default tolerates the stop endpoint having removed it already.
        backtest_engines.pop(backtest_id, None)
    else:
        print(f"✅ 回测完成: {backtest_id}")

        # The finished engine is intentionally left in the registry: the
        # result endpoint reads metrics and trade history from the registered
        # engine, so evicting it here would make every completed backtest
        # return 404. The stop endpoint removes it.
        # TODO: persist the result (e.g. to a database) so the engine can be
        # released right after completion.
        # NOTE(review): engines that finish and are never stopped stay in
        # memory; consider a TTL or persistence if that becomes a problem.