Coverage for core/trading/engines/time_series_controller.py: 15.75%

146 statements  

« prev     ^ index     » next       coverage.py v7.10.7, created at 2025-10-13 18:58 +0000

1""" 

2时序控制器 - 用于回测模式的时间序列控制 

3""" 

4 

5import asyncio 

6from datetime import datetime, timedelta 

7from decimal import Decimal 

8from typing import Any, Callable, Dict, List, Optional 

9 

10from core.data_source.adapters.data_source_adapter import \ 

11 QuoteDataSourceAdapter 

12from core.models.trading import MarketData 

13from core.trading.strategies.base_strategy import BaseStrategy 

14 

15 

16class TimeSeriesController: 

17 """时间序列控制器""" 

18 

19 def __init__( 

20 self, 

21 start_time: datetime, 

22 end_time: datetime, 

23 timezone: str, 

24 strategy_engine=None, 

25 ): 

26 self.start_time = start_time 

27 self.end_time = end_time 

28 self.timezone = timezone 

29 self.current_time = start_time 

30 self.is_paused = False 

31 

32 # 数据源适配器 

33 self.quote_data_adapter: Optional[QuoteDataSourceAdapter] = None 

34 

35 # 策略引擎引用,用于统一日志管理 

36 self.strategy_engine = strategy_engine 

37 

38 # 回调函数 

39 self.on_progress: Optional[Callable] = None 

40 self.on_complete: Optional[Callable] = None 

41 self.on_error: Optional[Callable] = None 

42 

43 def set_quote_adapter(self, quote_adapter: QuoteDataSourceAdapter): 

44 """设置数据源适配器""" 

45 self.quote_data_adapter = quote_adapter 

46 

47 def set_callbacks( 

48 self, 

49 on_progress: Optional[Callable] = None, 

50 on_complete: Optional[Callable] = None, 

51 on_error: Optional[Callable] = None, 

52 ): 

53 """设置回调函数""" 

54 self.on_progress = on_progress 

55 self.on_complete = on_complete 

56 self.on_error = on_error 

57 

58 def log_message(self, message: str, log_type: str = "log"): 

59 """统一日志接口 - 通过StrategyEngine记录日志""" 

60 if self.strategy_engine: 

61 self.strategy_engine.log_message(message, log_type, "时序控制器") 

62 else: 

63 # 如果没有StrategyEngine引用,使用print作为备用 

64 print(f"📝 [时序控制器] {message}") 

65 

66 async def run_backtest(self, strategy: BaseStrategy, symbols: List[str]): 

67 """运行回测""" 

68 try: 

69 self.log_message(f"🚀 开始回测: {self.start_time} - {self.end_time}") 

70 self.log_message(f"📊 回测股票: {symbols}") 

71 

72 total_minutes = int((self.end_time - self.start_time).total_seconds() / 60) 

73 current_minute = 0 

74 

75 while self.current_time <= self.end_time and not self.is_paused: 

76 try: 

77 # 1. 获取当前时间点的市场数据 

78 market_data_list = await self._get_market_data_at_time( 

79 self.current_time, symbols 

80 ) 

81 

82 if market_data_list: 

83 # 2. 处理每个股票的市场数据 

84 for market_data in market_data_list: 

85 # 更新策略上下文 

86 if hasattr(strategy, "context") and strategy.context: 

87 strategy.context.current_time = self.current_time 

88 

89 # 通知策略 - 等待策略处理完成 

90 await self._wait_for_strategy_processing( 

91 strategy, market_data 

92 ) 

93 

94 # 3. 等待所有交易计算完成 

95 await self._wait_for_trading_calculations(strategy) 

96 

97 # 4. 处理交易回调 

98 await self._process_trading_callbacks(strategy) 

99 

100 # 5. 推进到下一个时间点(按分钟) 

101 self.current_time += timedelta(minutes=1) 

102 current_minute += 1 

103 

104 # 6. 报告进度 

105 if current_minute % 60 == 0: # 每小时报告一次 

106 progress = (current_minute / total_minutes) * 100 

107 self.log_message( 

108 f"📈 回测进度: {progress:.1f}% ({current_minute}/{total_minutes}分钟)" 

109 ) 

110 

111 if self.on_progress: 

112 self.on_progress(progress, self.current_time) 

113 

114 # 7. 短暂休眠,避免过度占用CPU 

115 await asyncio.sleep(0.01) 

116 

117 except Exception as e: 

118 self.log_message( 

119 f"❌ 处理时间点失败: {self.current_time} - {e}", "error" 

120 ) 

121 if self.on_error: 

122 self.on_error(e) 

123 continue 

124 

125 # 回测完成 

126 self.log_message(f"✅ 回测完成: {self.current_time}") 

127 if self.on_complete: 

128 self.on_complete() 

129 

130 except Exception as e: 

131 self.log_message(f"❌ 回测失败: {e}", "error") 

132 if self.on_error: 

133 self.on_error(e) 

134 

135 async def _get_market_data_at_time( 

136 self, current_time: datetime, symbols: List[str] 

137 ) -> List[MarketData]: 

138 """获取指定时间点的市场数据""" 

139 market_data_list = [] 

140 

141 for symbol in symbols: 

142 try: 

143 # 尝试从历史数据中获取真实数据 

144 historical_data = await self._get_historical_data(symbol, current_time) 

145 

146 if historical_data: 

147 market_data = MarketData( 

148 symbol=symbol, 

149 timestamp=current_time, 

150 open=Decimal(str(historical_data.get("open", 150))), 

151 high=Decimal(str(historical_data.get("high", 155))), 

152 low=Decimal(str(historical_data.get("low", 148))), 

153 close=Decimal(str(historical_data.get("close", 152))), 

154 volume=int(historical_data.get("volume", 1000)), 

155 ) 

156 print( 

157 f"【时序控制器】📊 获取历史数据: {symbol} at {current_time} = ${market_data.close:.2f}" 

158 ) 

159 else: 

160 # 如果没有历史数据,使用模拟数据 

161 market_data = self._generate_simulated_data(symbol, current_time) 

162 print( 

163 f"【时序控制器】🎭 使用模拟数据: {symbol} at {current_time} = ${market_data.close:.2f}" 

164 ) 

165 

166 market_data_list.append(market_data) 

167 

168 except Exception as e: 

169 print(f"【时序控制器】❌ 获取市场数据失败: {symbol} - {e}") 

170 continue 

171 

172 return market_data_list 

173 

174 async def _get_historical_data( 

175 self, symbol: str, timestamp: datetime 

176 ) -> Optional[Dict[str, Any]]: 

177 """从数据库获取历史数据""" 

178 try: 

179 if not self.quote_data_adapter: 

180 return None 

181 

182 # 使用QuoteDataSourceAdapter的抽象接口获取历史数据 

183 # 注意:这里需要获取一个时间范围的数据,而不是单个时间点 

184 from datetime import timedelta 

185 

186 start_time = timestamp - timedelta(minutes=1) 

187 end_time = timestamp + timedelta(minutes=1) 

188 

189 # 获取历史数据 

190 def log_callback(message: str, log_type: str = "info"): 

191 self.log_message(f"【数据适配器】{message}", log_type) 

192 

193 historical_data = self.quote_data_adapter.get_historical_data( 

194 symbol, start_time, end_time, log_callback 

195 ) 

196 

197 if historical_data and len(historical_data) > 0: 

198 # 找到最接近目标时间的数据 

199 target_data = None 

200 min_diff = float("inf") 

201 for data in historical_data: 

202 time_diff = abs((data.timestamp - timestamp).total_seconds()) 

203 if time_diff < min_diff: 

204 min_diff = time_diff 

205 target_data = data 

206 

207 if target_data: 

208 return { 

209 "open": float(target_data.open), 

210 "high": float(target_data.high), 

211 "low": float(target_data.low), 

212 "close": float(target_data.close), 

213 "volume": int(target_data.volume), 

214 } 

215 

216 return None 

217 

218 except Exception as e: 

219 print(f"【时序控制器】❌ 获取历史数据失败: {symbol} at {timestamp} - {e}") 

220 return None 

221 

222 def _generate_simulated_data(self, symbol: str, timestamp: datetime) -> MarketData: 

223 """生成模拟市场数据""" 

224 import random 

225 

226 # 基础价格 

227 base_prices = { 

228 "AAPL.US": Decimal("150.50"), 

229 "MSFT.US": Decimal("300.25"), 

230 "TSLA.US": Decimal("200.75"), 

231 "GOOGL.US": Decimal("2500.00"), 

232 "YINN.US": Decimal("25.00"), 

233 "YANG.US": Decimal("15.00"), 

234 } 

235 

236 base_price = base_prices.get(symbol, Decimal("100.00")) 

237 

238 # 添加时间相关的波动 

239 time_factor = (timestamp.hour * 60 + timestamp.minute) / (24 * 60) # 0-1之间 

240 daily_variation = Decimal( 

241 str(round(random.uniform(-0.05, 0.05), 4)) 

242 ) # ±5% 日波动,保留4位小数 

243 time_variation = Decimal( 

244 str(round(random.uniform(-0.01, 0.01), 4)) 

245 ) # ±1% 时间波动,保留4位小数 

246 

247 current_price = base_price * (Decimal("1") + daily_variation + time_variation) 

248 # 将价格四舍五入到2位小数 

249 current_price = current_price.quantize(Decimal("0.01")) 

250 

251 # 计算最高价和最低价,并四舍五入到2位小数 

252 high_price = (current_price * Decimal("1.02")).quantize(Decimal("0.01")) 

253 low_price = (current_price * Decimal("0.98")).quantize(Decimal("0.01")) 

254 

255 return MarketData( 

256 symbol=symbol, 

257 timestamp=timestamp, 

258 open=current_price, 

259 high=high_price, 

260 low=low_price, 

261 close=current_price, 

262 volume=random.randint(1000, 10000), 

263 ) 

264 

265 async def _wait_for_strategy_processing( 

266 self, strategy: BaseStrategy, market_data: MarketData 

267 ): 

268 """等待策略处理完成""" 

269 try: 

270 # 通知策略处理市场数据 

271 strategy.on_market_data(market_data) 

272 

273 # 等待策略完成所有分析和交易决策 

274 # 这里可以添加策略处理完成的确认机制 

275 await asyncio.sleep(0.01) # 短暂等待,确保策略处理完成 

276 

277 except Exception as e: 

278 print(f"【时序控制器】❌ 策略处理失败: {e}") 

279 

280 async def _wait_for_trading_calculations(self, strategy: BaseStrategy): 

281 """等待所有交易计算完成""" 

282 try: 

283 # 等待持仓更新、资金计算等完成 

284 await asyncio.sleep(0.01) # 短暂等待,确保计算完成 

285 

286 except Exception as e: 

287 print(f"【时序控制器】❌ 交易计算失败: {e}") 

288 

289 async def _process_trading_callbacks(self, strategy: BaseStrategy): 

290 """处理交易回调""" 

291 try: 

292 # 处理交易结果回调 

293 await asyncio.sleep(0.01) # 短暂等待,确保回调处理完成 

294 

295 except Exception as e: 

296 print(f"【时序控制器】❌ 交易回调处理失败: {e}") 

297 

298 def pause(self): 

299 """暂停回测""" 

300 self.is_paused = True 

301 print(f"【时序控制器】⏸️ 回测已暂停") 

302 

303 def resume(self): 

304 """恢复回测""" 

305 self.is_paused = False 

306 print(f"【时序控制器】▶️ 回测已恢复") 

307 

308 def stop(self): 

309 """停止回测""" 

310 self.is_paused = True 

311 print(f"【时序控制器】⏹️ 回测已停止") 

312 

313 def get_progress(self) -> Dict[str, Any]: 

314 """获取回测进度""" 

315 total_minutes = int((self.end_time - self.start_time).total_seconds() / 60) 

316 current_minute = int((self.current_time - self.start_time).total_seconds() / 60) 

317 progress = (current_minute / total_minutes) * 100 if total_minutes > 0 else 0 

318 

319 return { 

320 "current_time": self.current_time.isoformat(), 

321 "start_time": self.start_time.isoformat(), 

322 "end_time": self.end_time.isoformat(), 

323 "progress_percentage": progress, 

324 "current_minute": current_minute, 

325 "total_minutes": total_minutes, 

326 "is_paused": self.is_paused, 

327 }