Coverage for core/trading/engines/time_series_controller.py: 15.75%
146 statements
« prev ^ index » next coverage.py v7.10.7, created at 2025-10-13 18:58 +0000
« prev ^ index » next coverage.py v7.10.7, created at 2025-10-13 18:58 +0000
1"""
2时序控制器 - 用于回测模式的时间序列控制
3"""
5import asyncio
6from datetime import datetime, timedelta
7from decimal import Decimal
8from typing import Any, Callable, Dict, List, Optional
10from core.data_source.adapters.data_source_adapter import \
11 QuoteDataSourceAdapter
12from core.models.trading import MarketData
13from core.trading.strategies.base_strategy import BaseStrategy
16class TimeSeriesController:
17 """时间序列控制器"""
19 def __init__(
20 self,
21 start_time: datetime,
22 end_time: datetime,
23 timezone: str,
24 strategy_engine=None,
25 ):
26 self.start_time = start_time
27 self.end_time = end_time
28 self.timezone = timezone
29 self.current_time = start_time
30 self.is_paused = False
32 # 数据源适配器
33 self.quote_data_adapter: Optional[QuoteDataSourceAdapter] = None
35 # 策略引擎引用,用于统一日志管理
36 self.strategy_engine = strategy_engine
38 # 回调函数
39 self.on_progress: Optional[Callable] = None
40 self.on_complete: Optional[Callable] = None
41 self.on_error: Optional[Callable] = None
43 def set_quote_adapter(self, quote_adapter: QuoteDataSourceAdapter):
44 """设置数据源适配器"""
45 self.quote_data_adapter = quote_adapter
47 def set_callbacks(
48 self,
49 on_progress: Optional[Callable] = None,
50 on_complete: Optional[Callable] = None,
51 on_error: Optional[Callable] = None,
52 ):
53 """设置回调函数"""
54 self.on_progress = on_progress
55 self.on_complete = on_complete
56 self.on_error = on_error
58 def log_message(self, message: str, log_type: str = "log"):
59 """统一日志接口 - 通过StrategyEngine记录日志"""
60 if self.strategy_engine:
61 self.strategy_engine.log_message(message, log_type, "时序控制器")
62 else:
63 # 如果没有StrategyEngine引用,使用print作为备用
64 print(f"📝 [时序控制器] {message}")
66 async def run_backtest(self, strategy: BaseStrategy, symbols: List[str]):
67 """运行回测"""
68 try:
69 self.log_message(f"🚀 开始回测: {self.start_time} - {self.end_time}")
70 self.log_message(f"📊 回测股票: {symbols}")
72 total_minutes = int((self.end_time - self.start_time).total_seconds() / 60)
73 current_minute = 0
75 while self.current_time <= self.end_time and not self.is_paused:
76 try:
77 # 1. 获取当前时间点的市场数据
78 market_data_list = await self._get_market_data_at_time(
79 self.current_time, symbols
80 )
82 if market_data_list:
83 # 2. 处理每个股票的市场数据
84 for market_data in market_data_list:
85 # 更新策略上下文
86 if hasattr(strategy, "context") and strategy.context:
87 strategy.context.current_time = self.current_time
89 # 通知策略 - 等待策略处理完成
90 await self._wait_for_strategy_processing(
91 strategy, market_data
92 )
94 # 3. 等待所有交易计算完成
95 await self._wait_for_trading_calculations(strategy)
97 # 4. 处理交易回调
98 await self._process_trading_callbacks(strategy)
100 # 5. 推进到下一个时间点(按分钟)
101 self.current_time += timedelta(minutes=1)
102 current_minute += 1
104 # 6. 报告进度
105 if current_minute % 60 == 0: # 每小时报告一次
106 progress = (current_minute / total_minutes) * 100
107 self.log_message(
108 f"📈 回测进度: {progress:.1f}% ({current_minute}/{total_minutes}分钟)"
109 )
111 if self.on_progress:
112 self.on_progress(progress, self.current_time)
114 # 7. 短暂休眠,避免过度占用CPU
115 await asyncio.sleep(0.01)
117 except Exception as e:
118 self.log_message(
119 f"❌ 处理时间点失败: {self.current_time} - {e}", "error"
120 )
121 if self.on_error:
122 self.on_error(e)
123 continue
125 # 回测完成
126 self.log_message(f"✅ 回测完成: {self.current_time}")
127 if self.on_complete:
128 self.on_complete()
130 except Exception as e:
131 self.log_message(f"❌ 回测失败: {e}", "error")
132 if self.on_error:
133 self.on_error(e)
135 async def _get_market_data_at_time(
136 self, current_time: datetime, symbols: List[str]
137 ) -> List[MarketData]:
138 """获取指定时间点的市场数据"""
139 market_data_list = []
141 for symbol in symbols:
142 try:
143 # 尝试从历史数据中获取真实数据
144 historical_data = await self._get_historical_data(symbol, current_time)
146 if historical_data:
147 market_data = MarketData(
148 symbol=symbol,
149 timestamp=current_time,
150 open=Decimal(str(historical_data.get("open", 150))),
151 high=Decimal(str(historical_data.get("high", 155))),
152 low=Decimal(str(historical_data.get("low", 148))),
153 close=Decimal(str(historical_data.get("close", 152))),
154 volume=int(historical_data.get("volume", 1000)),
155 )
156 print(
157 f"【时序控制器】📊 获取历史数据: {symbol} at {current_time} = ${market_data.close:.2f}"
158 )
159 else:
160 # 如果没有历史数据,使用模拟数据
161 market_data = self._generate_simulated_data(symbol, current_time)
162 print(
163 f"【时序控制器】🎭 使用模拟数据: {symbol} at {current_time} = ${market_data.close:.2f}"
164 )
166 market_data_list.append(market_data)
168 except Exception as e:
169 print(f"【时序控制器】❌ 获取市场数据失败: {symbol} - {e}")
170 continue
172 return market_data_list
174 async def _get_historical_data(
175 self, symbol: str, timestamp: datetime
176 ) -> Optional[Dict[str, Any]]:
177 """从数据库获取历史数据"""
178 try:
179 if not self.quote_data_adapter:
180 return None
182 # 使用QuoteDataSourceAdapter的抽象接口获取历史数据
183 # 注意:这里需要获取一个时间范围的数据,而不是单个时间点
184 from datetime import timedelta
186 start_time = timestamp - timedelta(minutes=1)
187 end_time = timestamp + timedelta(minutes=1)
189 # 获取历史数据
190 def log_callback(message: str, log_type: str = "info"):
191 self.log_message(f"【数据适配器】{message}", log_type)
193 historical_data = self.quote_data_adapter.get_historical_data(
194 symbol, start_time, end_time, log_callback
195 )
197 if historical_data and len(historical_data) > 0:
198 # 找到最接近目标时间的数据
199 target_data = None
200 min_diff = float("inf")
201 for data in historical_data:
202 time_diff = abs((data.timestamp - timestamp).total_seconds())
203 if time_diff < min_diff:
204 min_diff = time_diff
205 target_data = data
207 if target_data:
208 return {
209 "open": float(target_data.open),
210 "high": float(target_data.high),
211 "low": float(target_data.low),
212 "close": float(target_data.close),
213 "volume": int(target_data.volume),
214 }
216 return None
218 except Exception as e:
219 print(f"【时序控制器】❌ 获取历史数据失败: {symbol} at {timestamp} - {e}")
220 return None
222 def _generate_simulated_data(self, symbol: str, timestamp: datetime) -> MarketData:
223 """生成模拟市场数据"""
224 import random
226 # 基础价格
227 base_prices = {
228 "AAPL.US": Decimal("150.50"),
229 "MSFT.US": Decimal("300.25"),
230 "TSLA.US": Decimal("200.75"),
231 "GOOGL.US": Decimal("2500.00"),
232 "YINN.US": Decimal("25.00"),
233 "YANG.US": Decimal("15.00"),
234 }
236 base_price = base_prices.get(symbol, Decimal("100.00"))
238 # 添加时间相关的波动
239 time_factor = (timestamp.hour * 60 + timestamp.minute) / (24 * 60) # 0-1之间
240 daily_variation = Decimal(
241 str(round(random.uniform(-0.05, 0.05), 4))
242 ) # ±5% 日波动,保留4位小数
243 time_variation = Decimal(
244 str(round(random.uniform(-0.01, 0.01), 4))
245 ) # ±1% 时间波动,保留4位小数
247 current_price = base_price * (Decimal("1") + daily_variation + time_variation)
248 # 将价格四舍五入到2位小数
249 current_price = current_price.quantize(Decimal("0.01"))
251 # 计算最高价和最低价,并四舍五入到2位小数
252 high_price = (current_price * Decimal("1.02")).quantize(Decimal("0.01"))
253 low_price = (current_price * Decimal("0.98")).quantize(Decimal("0.01"))
255 return MarketData(
256 symbol=symbol,
257 timestamp=timestamp,
258 open=current_price,
259 high=high_price,
260 low=low_price,
261 close=current_price,
262 volume=random.randint(1000, 10000),
263 )
265 async def _wait_for_strategy_processing(
266 self, strategy: BaseStrategy, market_data: MarketData
267 ):
268 """等待策略处理完成"""
269 try:
270 # 通知策略处理市场数据
271 strategy.on_market_data(market_data)
273 # 等待策略完成所有分析和交易决策
274 # 这里可以添加策略处理完成的确认机制
275 await asyncio.sleep(0.01) # 短暂等待,确保策略处理完成
277 except Exception as e:
278 print(f"【时序控制器】❌ 策略处理失败: {e}")
280 async def _wait_for_trading_calculations(self, strategy: BaseStrategy):
281 """等待所有交易计算完成"""
282 try:
283 # 等待持仓更新、资金计算等完成
284 await asyncio.sleep(0.01) # 短暂等待,确保计算完成
286 except Exception as e:
287 print(f"【时序控制器】❌ 交易计算失败: {e}")
289 async def _process_trading_callbacks(self, strategy: BaseStrategy):
290 """处理交易回调"""
291 try:
292 # 处理交易结果回调
293 await asyncio.sleep(0.01) # 短暂等待,确保回调处理完成
295 except Exception as e:
296 print(f"【时序控制器】❌ 交易回调处理失败: {e}")
298 def pause(self):
299 """暂停回测"""
300 self.is_paused = True
301 print(f"【时序控制器】⏸️ 回测已暂停")
303 def resume(self):
304 """恢复回测"""
305 self.is_paused = False
306 print(f"【时序控制器】▶️ 回测已恢复")
308 def stop(self):
309 """停止回测"""
310 self.is_paused = True
311 print(f"【时序控制器】⏹️ 回测已停止")
313 def get_progress(self) -> Dict[str, Any]:
314 """获取回测进度"""
315 total_minutes = int((self.end_time - self.start_time).total_seconds() / 60)
316 current_minute = int((self.current_time - self.start_time).total_seconds() / 60)
317 progress = (current_minute / total_minutes) * 100 if total_minutes > 0 else 0
319 return {
320 "current_time": self.current_time.isoformat(),
321 "start_time": self.start_time.isoformat(),
322 "end_time": self.end_time.isoformat(),
323 "progress_percentage": progress,
324 "current_minute": current_minute,
325 "total_minutes": total_minutes,
326 "is_paused": self.is_paused,
327 }