Coverage for api/v1/endpoints/backtest.py: 41.51%

106 statements  

« prev     ^ index     » next       coverage.py v7.10.7, created at 2025-10-13 18:58 +0000

1""" 

2回测API端点 

3""" 

4 

import uuid
from datetime import datetime
from decimal import Decimal
from typing import Any, Dict, List, Optional

from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException
from pydantic import BaseModel

from core.middleware.auth_middleware import get_current_user
from core.models.user import User
from core.trading.backtest import BacktestEngine

15 

# Router that the backtest endpoints below are registered on.
router = APIRouter()

# Registry of backtest engines, keyed by backtest ID.
# NOTE(review): this is a plain module-level dict, so it lives only in the
# current process and is lost on restart; it is not shared between workers.
backtest_engines: Dict[str, BacktestEngine] = dict()

20 

21 

class BacktestRequest(BaseModel):
    """Request body for starting a backtest."""

    # Client-supplied session identifier; also embedded in the backtest ID.
    session_id: str
    # Strategy name forwarded to BacktestEngine.run_backtest.
    strategy_name: str
    # Strategy parameters forwarded to BacktestEngine.run_backtest.
    strategy_config: Dict[str, Any]
    # Symbols forwarded to BacktestEngine.run_backtest.
    symbols: List[str]
    # Start of the time range passed to BacktestEngine.set_time_range.
    start_date: datetime
    # End of the time range passed to BacktestEngine.set_time_range.
    end_date: datetime
    # Initial capital passed to the BacktestEngine constructor.
    initial_capital: Decimal
    # Timezone string passed to BacktestEngine.set_time_range.
    timezone: str = "UTC"

33 

34 

class BacktestResponse(BaseModel):
    """Backtest response model (returned by the ``/start`` endpoint)."""

    # Whether the operation succeeded.
    success: bool
    # Human-readable status or error message.
    message: str
    # Identifier of the started backtest; None when starting failed.
    backtest_id: Optional[str] = None
    # Optional performance metrics; the /start endpoint leaves this as None.
    performance_metrics: Optional[Dict[str, Any]] = None

42 

43 

class BacktestProgressResponse(BaseModel):
    """Snapshot of a backtest's progress and run state."""

    # Progress value as reported by BacktestEngine.get_progress();
    # presumably a fraction in [0, 1] — TODO confirm against the engine.
    progress: float
    # Current simulated time of the backtest, if reported.
    current_time: Optional[datetime] = None
    # Start of the backtest time range, if reported.
    start_date: Optional[datetime] = None
    # End of the backtest time range, if reported.
    end_date: Optional[datetime] = None
    # Whether the engine reports that it is currently running.
    is_running: bool = False
    # Whether the engine reports that it is currently paused.
    is_paused: bool = False

53 

54 

@router.post("/start", response_model=BacktestResponse)
async def start_backtest(
    request: BacktestRequest,
    background_tasks: BackgroundTasks,
    current_user: User = Depends(get_current_user),
):
    """Start a backtest for the current user and run it in the background.

    Creates a ``BacktestEngine``, configures its time range and callbacks,
    registers it in ``backtest_engines`` under a newly generated ID and
    schedules ``run_backtest_task`` as a FastAPI background task.

    Args:
        request: Strategy, symbols, time range, timezone and initial capital.
        background_tasks: FastAPI background task queue the run is added to.
        current_user: Authenticated user; recorded as the engine's owner.

    Returns:
        ``BacktestResponse`` with ``success=True`` and the new ``backtest_id``.
        Any exception raised during setup is caught and reported as
        ``success=False`` with the error text in ``message``; the HTTP status
        is still the route's default (200).
    """
    try:
        # Generate the backtest ID. The user/session/timestamp prefix keeps IDs
        # readable; the random suffix keeps two starts for the same user and
        # session within the same second from colliding. A collision would
        # overwrite the first engine in ``backtest_engines`` while its
        # background task keeps running.
        timestamp = int(datetime.now().timestamp())
        backtest_id = (
            f"{current_user.id}_{request.session_id}_{timestamp}"
            f"_{uuid.uuid4().hex[:12]}"
        )

        # Create the backtest engine.
        backtest_engine = BacktestEngine(
            user_id=current_user.id,
            session_id=request.session_id,
            initial_capital=request.initial_capital,
        )

        # Configure the time range.
        backtest_engine.set_time_range(
            start_date=request.start_date,
            end_date=request.end_date,
            timezone=request.timezone,
        )

        # Callbacks that report engine events on stdout.
        def on_progress(progress: float, message: str):
            print(f"回测进度: {progress:.2%} - {message}")

        def on_complete(result: Dict[str, Any]):
            print(f"回测完成: {result}")

        def on_error(error: str):
            print(f"回测错误: {error}")

        backtest_engine.set_callbacks(
            on_progress=on_progress, on_complete=on_complete, on_error=on_error
        )

        # Register the engine so the other endpoints can look it up by ID.
        backtest_engines[backtest_id] = backtest_engine

        # Run the backtest as a background task.
        background_tasks.add_task(
            run_backtest_task,
            backtest_engine,
            request.strategy_name,
            request.strategy_config,
            request.symbols,
            backtest_id,
        )

        return BacktestResponse(
            success=True, message="回测已启动", backtest_id=backtest_id
        )

    except Exception as e:
        return BacktestResponse(success=False, message=f"启动回测失败: {str(e)}")

115 

116 

@router.get("/{backtest_id}/progress", response_model=BacktestProgressResponse)
async def get_backtest_progress(
    backtest_id: str, current_user: User = Depends(get_current_user)
):
    """Return the progress of a backtest owned by the calling user.

    Raises:
        HTTPException: 404 if no backtest is registered under ``backtest_id``;
            403 if it belongs to a different user.
    """
    engine = backtest_engines.get(backtest_id)
    if engine is None:
        raise HTTPException(status_code=404, detail="回测不存在")
    # Only the owner may query the backtest.
    if engine.user_id != current_user.id:
        raise HTTPException(status_code=403, detail="无权限访问此回测")

    info = engine.get_progress()
    # These keys may be absent from the engine's report; default them to None.
    optional_fields = {
        key: info.get(key) for key in ("current_time", "start_date", "end_date")
    }
    return BacktestProgressResponse(
        progress=info["progress"],
        is_running=info["is_running"],
        is_paused=info["is_paused"],
        **optional_fields,
    )

141 

142 

@router.post("/{backtest_id}/pause")
async def pause_backtest(
    backtest_id: str, current_user: User = Depends(get_current_user)
):
    """Pause a backtest owned by the calling user.

    Raises:
        HTTPException: 404 if no backtest is registered under ``backtest_id``;
            403 if it belongs to a different user.
    """
    engine = backtest_engines.get(backtest_id)
    if engine is None:
        raise HTTPException(status_code=404, detail="回测不存在")
    # Only the owner may control the backtest.
    if engine.user_id != current_user.id:
        raise HTTPException(status_code=403, detail="无权限访问此回测")

    engine.pause()
    return {"message": "回测已暂停"}

160 

161 

@router.post("/{backtest_id}/resume")
async def resume_backtest(
    backtest_id: str, current_user: User = Depends(get_current_user)
):
    """Resume a paused backtest owned by the calling user.

    Raises:
        HTTPException: 404 if no backtest is registered under ``backtest_id``;
            403 if it belongs to a different user.
    """
    engine = backtest_engines.get(backtest_id)
    if engine is None:
        raise HTTPException(status_code=404, detail="回测不存在")
    # Only the owner may control the backtest.
    if engine.user_id != current_user.id:
        raise HTTPException(status_code=403, detail="无权限访问此回测")

    engine.resume()
    return {"message": "回测已恢复"}

179 

180 

@router.post("/{backtest_id}/stop")
async def stop_backtest(
    backtest_id: str, current_user: User = Depends(get_current_user)
):
    """Stop a backtest owned by the calling user and unregister it.

    After this call the ID is no longer registered, so later requests for it
    return 404.

    Raises:
        HTTPException: 404 if no backtest is registered under ``backtest_id``;
            403 if it belongs to a different user.
    """
    engine = backtest_engines.get(backtest_id)
    if engine is None:
        raise HTTPException(status_code=404, detail="回测不存在")
    # Only the owner may control the backtest.
    if engine.user_id != current_user.id:
        raise HTTPException(status_code=403, detail="无权限访问此回测")

    engine.stop()
    # Drop the engine from the registry.
    backtest_engines.pop(backtest_id)
    return {"message": "回测已停止"}

201 

202 

@router.get("/{backtest_id}/result")
async def get_backtest_result(
    backtest_id: str, current_user: User = Depends(get_current_user)
):
    """Return the results of a backtest owned by the calling user.

    Returns:
        A dict with the engine's ``performance_metrics``, ``trade_history``,
        ``daily_returns``, ``start_date``, ``end_date`` and
        ``initial_capital`` attributes.

    Raises:
        HTTPException: 404 if no backtest is registered under ``backtest_id``;
            403 if it belongs to a different user; 400 if the engine reports
            that it is still running.
    """
    engine = backtest_engines.get(backtest_id)
    if engine is None:
        raise HTTPException(status_code=404, detail="回测不存在")
    # Only the owner may read the results.
    if engine.user_id != current_user.id:
        raise HTTPException(status_code=403, detail="无权限访问此回测")

    # NOTE(review): only ``is_running`` is checked; confirm that an engine
    # whose background run has not started yet also reports it as running.
    if engine.is_running:
        raise HTTPException(status_code=400, detail="回测尚未完成")

    result_fields = (
        "performance_metrics",
        "trade_history",
        "daily_returns",
        "start_date",
        "end_date",
        "initial_capital",
    )
    return {name: getattr(engine, name) for name in result_fields}

229 

230 

async def run_backtest_task(
    backtest_engine: BacktestEngine,
    strategy_name: str,
    strategy_config: Dict[str, Any],
    symbols: List[str],
    backtest_id: str,
):
    """Run a backtest as a background task.

    On success the engine is deliberately kept in ``backtest_engines`` so
    that ``GET /{backtest_id}/result`` can still serve its results. The entry
    is removed when the client calls ``POST /{backtest_id}/stop``. If the run
    fails or is cancelled, the entry is removed here.

    Args:
        backtest_engine: The engine registered under ``backtest_id``.
        strategy_name: Strategy name forwarded to ``run_backtest``.
        strategy_config: Strategy parameters forwarded to ``run_backtest``.
        symbols: Symbols forwarded to ``run_backtest``.
        backtest_id: Registry key of the engine; used for logging and cleanup.
    """
    succeeded = False
    try:
        await backtest_engine.run_backtest(strategy_name, strategy_config, symbols)
        succeeded = True
        print(f"✅ 回测完成: {backtest_id}")
        # TODO: persist the result (e.g. to a database) so finished engines do
        # not have to stay in memory until the client stops them.
    except Exception as e:
        # This is the top-level boundary of the background task, so report the
        # failure here instead of letting it escape.
        print(f"❌ 回测任务失败: {backtest_id} - {e}")
    finally:
        # Unregister only unsuccessful runs (including cancellation). Check
        # identity: the entry may already have been removed by the stop
        # endpoint, and it must never remove a different engine.
        if not succeeded and backtest_engines.get(backtest_id) is backtest_engine:
            del backtest_engines[backtest_id]