Coverage for core/trading/engines/simulation_engine.py: 71.91%
388 statements
« prev ^ index » next coverage.py v7.10.7, created at 2025-10-13 18:58 +0000
« prev ^ index » next coverage.py v7.10.7, created at 2025-10-13 18:58 +0000
1"""
2模拟交易引擎
3"""
5import json
6import uuid
7from datetime import datetime, timedelta
8from decimal import Decimal
9from typing import Any, Dict, List, Optional
11from core.data_source.adapters.quote_adapter import QuoteAdapter
12from core.models.broker import FeeConfig, TradingFee
13from core.models.trading import (AccountBalance, MarketData, Order,
14 OrderResult, OrderSide, OrderStatus,
15 OrderType, Position, Trade)
16from core.repositories.trading_repository import TradingRepository
17from core.trading.utils.fee_calculator import FeeCalculator
20class SimulationEngine:
21 """模拟交易引擎"""
23 def __init__(
24 self,
25 user_id: str,
26 session_id: str,
27 initial_capital: Decimal,
28 strategy_engine=None,
29 currency: str = "USD",
30 fee_config: Optional[FeeConfig] = None,
31 ):
32 self.user_id = user_id
33 self.session_id = session_id
34 self.initial_capital = initial_capital
35 self.current_capital = initial_capital
36 self.currency = currency # 支持多币种
37 self.positions: Dict[str, Position] = {}
38 self.orders: Dict[str, Order] = {}
39 self.trades: List[Trade] = []
40 self.trading_repo = TradingRepository()
41 self.quote_adapter = QuoteAdapter(user_id)
42 self.strategy_engine = strategy_engine # 引用StrategyEngine,用于统一日志管理
44 # Redis客户端(用于订单持久化)
45 from infrastructure.database.redis_client import get_redis
47 self.redis_client = get_redis()
49 # 费用计算器(使用传入的费用配置,如果没有则使用默认配置)
50 if fee_config is None:
51 fee_config = FeeConfig() # 使用默认的长桥费率
52 self.fee_calculator = FeeCalculator(fee_config)
54 # 模拟交易参数(已废弃,使用费用配置替代)
55 self.commission_rate = Decimal("0.001") # 0.1% 手续费(已废弃)
56 self.min_commission = Decimal("1.0") # 最小手续费(已废弃)
57 self.slippage_rate = Decimal("0.0005") # 0.05% 滑点
59 # 市场数据缓存
60 self.market_data_cache: Dict[str, MarketData] = {}
61 self.last_update_time = datetime.now()
63 # 订单时间窗口缓存(用于精确的成交判定)
64 self.order_time_windows: Dict[str, List[MarketData]] = {}
66 def log_message(self, message: str, log_type: str = "log"):
67 """统一日志接口 - 通过StrategyEngine记录日志"""
68 if self.strategy_engine:
69 self.strategy_engine.log_message(message, log_type, "模拟交易模块")
70 else:
71 # 如果没有StrategyEngine引用,使用print作为备用
72 print(f"📝 [模拟交易模块] {message}")
74 def initialize(self) -> bool:
75 """初始化模拟交易引擎"""
76 try:
77 # 初始化账户余额
78 self._initialize_account_balance()
80 # 加载现有持仓(如果有)
81 self._load_existing_positions()
83 return True
84 except Exception as e:
85 self.log_message(f"❌ 模拟交易引擎初始化失败: {e}", "error")
86 return False
88 def submit_order(self, order: Order) -> OrderResult:
89 """提交模拟订单"""
90 try:
91 # 验证订单
92 validation_result = self._validate_order(order)
93 if not validation_result["valid"]:
94 return OrderResult(
95 order_id="",
96 status=OrderStatus.REJECTED,
97 message=f"订单验证失败: {', '.join(validation_result['errors'])}",
98 )
100 # 生成订单ID
101 order_id = str(uuid.uuid4())
102 order.order_id = order_id
104 # 存储订单到内存
105 self.orders[order_id] = order
107 # 保存订单到Redis数据库
108 self._save_order_to_redis(order)
110 # 处理订单
111 if order.order_type == OrderType.MARKET:
112 # 市价单立即成交
113 return self._process_market_order(order)
114 elif order.order_type == OrderType.LIMIT:
115 # 限价单等待成交
116 return self._process_limit_order(order)
117 else:
118 return OrderResult(
119 order_id=order_id,
120 status=OrderStatus.REJECTED,
121 message="不支持的订单类型",
122 )
124 except Exception as e:
125 self.log_message(f"❌ 提交订单失败: {e}", "error")
126 return OrderResult(
127 order_id="",
128 status=OrderStatus.REJECTED,
129 message=f"提交订单失败: {str(e)}",
130 )
132 def process_market_data(self, market_data: MarketData):
133 """处理市场数据,检查订单成交"""
134 try:
135 # 更新市场数据缓存
136 self.market_data_cache[market_data.symbol] = market_data
137 self.last_update_time = datetime.now()
139 # 更新订单时间窗口缓存
140 self._update_order_time_windows(market_data)
142 # 检查限价单成交(基于时间窗口)
143 self._check_limit_order_execution_with_time_window(market_data)
145 # 更新持仓市值
146 self._update_positions_market_value()
148 except Exception as e:
149 self.log_message(f"❌ 处理市场数据失败: {e}", "error")
151 def get_account_balance(self) -> AccountBalance:
152 """获取模拟账户余额"""
153 # 计算当前总资产
154 total_value = self.current_capital
155 for position in self.positions.values():
156 total_value += position.market_value
158 return AccountBalance(
159 total_cash=total_value,
160 available_cash=self.current_capital,
161 frozen_cash=Decimal("0"),
162 currency=self.currency,
163 )
165 def get_positions(self) -> List[Position]:
166 """获取模拟持仓"""
167 return list(self.positions.values())
169 def get_pending_orders(self) -> List[Order]:
170 """获取待处理订单"""
171 return [
172 order
173 for order in self.orders.values()
174 if order.status == OrderStatus.PENDING
175 ]
177 def get_trade_history(self) -> List[Trade]:
178 """获取成交历史"""
179 return self.trades.copy()
181 def cancel_order(self, order_id: str) -> bool:
182 """取消订单"""
183 if order_id not in self.orders:
184 return False
186 order = self.orders[order_id]
187 if order.status != OrderStatus.PENDING:
188 return False
190 order.status = OrderStatus.CANCELLED
191 return True
193 def get_performance_metrics(self) -> Dict[str, Any]:
194 """获取性能指标"""
195 try:
196 # 计算总收益率
197 total_return = (
198 self.get_account_balance().total_cash - self.initial_capital
199 ) / self.initial_capital
201 # 计算胜率
202 winning_trades = len(
203 [trade for trade in self.trades if trade.side == OrderSide.SELL]
204 )
205 total_trades = len(self.trades)
206 win_rate = (
207 Decimal(str(winning_trades)) / Decimal(str(total_trades))
208 if total_trades > 0
209 else Decimal("0")
210 )
212 # 计算最大回撤(简化计算)
213 max_drawdown = self._calculate_max_drawdown()
215 return {
216 "total_return": total_return,
217 "win_rate": win_rate,
218 "total_trades": total_trades,
219 "max_drawdown": max_drawdown,
220 "current_capital": self.current_capital,
221 "total_assets": self.get_account_balance().total_cash,
222 "positions_count": len(self.positions),
223 }
224 except Exception as e:
225 self.log_message(f"❌ 计算性能指标失败: {e}", "error")
226 return {}
228 # ===== 私有方法 =====
230 def _save_order_to_redis(self, order: Order):
231 """保存订单到Redis数据库"""
232 try:
233 # 构建订单数据(过滤掉None值)
234 order_data = {
235 "id": order.order_id,
236 "session_id": order.session_id,
237 "symbol": order.symbol,
238 "side": order.side.value,
239 "order_type": order.order_type.value,
240 "quantity": str(order.quantity),
241 "status": order.status.value,
242 "submitted_at": datetime.now().isoformat(),
243 "commission": "0", # 初始手续费为0,成交时会更新
244 }
246 # 只有当price不为None时才添加
247 if order.price is not None:
248 order_data["price"] = str(order.price)
250 # 保存订单详情(使用与TradingRepository一致的键格式)
251 order_key = f"trading_order:{order.order_id}"
252 self.redis_client.client.hset(order_key, mapping=order_data)
254 # 将会话订单ID添加到会话订单集合
255 session_orders_key = f"session_orders:{order.session_id}"
256 self.redis_client.client.sadd(session_orders_key, order.order_id)
258 except Exception as e:
259 self.log_message(f"❌ 保存订单到Redis失败: {e}", "error")
261 def _update_order_in_redis(self, order: Order):
262 """更新Redis中的订单状态"""
263 try:
264 order_key = f"trading_order:{order.order_id}"
266 # 更新订单状态
267 self.redis_client.client.hset(order_key, "status", order.status.value)
269 # 如果订单已成交,更新成交信息
270 if order.status == OrderStatus.FILLED:
271 # 从trades列表中查找对应的成交记录
272 for trade in self.trades:
273 if trade.order_id == order.order_id:
274 self.redis_client.client.hset(
275 order_key, "filled_at", trade.timestamp.isoformat()
276 )
277 self.redis_client.client.hset(
278 order_key, "filled_price", str(trade.price)
279 )
280 self.redis_client.client.hset(
281 order_key, "filled_quantity", str(trade.quantity)
282 )
283 self.redis_client.client.hset(
284 order_key, "commission", str(trade.commission)
285 )
286 # 保存交易费用信息
287 if trade.trading_fee:
288 import json
289 from decimal import Decimal
291 # 自定义JSON编码器处理Decimal类型
292 class DecimalEncoder(json.JSONEncoder):
293 def default(self, obj):
294 if isinstance(obj, Decimal):
295 return float(obj)
296 return super().default(obj)
298 self.redis_client.client.hset(
299 order_key,
300 "trading_fee",
301 json.dumps(trade.trading_fee, cls=DecimalEncoder),
302 )
303 break
305 # 订单成交后,更新持仓数据
306 self._save_positions_to_redis()
308 except Exception as e:
309 self.log_message(f"❌ 更新订单到Redis失败: {e}", "error")
311 def _save_positions_to_redis(self):
312 """保存持仓数据到Redis"""
313 try:
314 import uuid
315 from datetime import datetime
317 # 为每个持仓创建历史记录
318 for symbol, position in self.positions.items():
319 if position.quantity > 0: # 只保存有持仓的记录
320 position_id = str(uuid.uuid4())
322 # 保存持仓详情
323 position_data = {
324 "id": position_id,
325 "session_id": self.session_id,
326 "symbol": position.symbol,
327 "quantity": str(position.quantity),
328 "avg_price": str(position.avg_price),
329 "current_price": str(position.current_price),
330 "market_value": str(position.market_value),
331 "unrealized_pnl": str(position.unrealized_pnl),
332 "realized_pnl": str(position.realized_pnl),
333 "timestamp": datetime.now().isoformat(),
334 }
336 position_key = f"position_history:{position_id}"
337 self.redis_client.client.hset(position_key, mapping=position_data)
339 # 添加到会话持仓集合
340 session_positions_key = f"session_positions:{self.session_id}"
341 self.redis_client.client.sadd(session_positions_key, position_id)
343 except Exception as e:
344 self.log_message(f"❌ 保存持仓到Redis失败: {e}", "error")
346 def _initialize_account_balance(self):
347 """初始化账户余额"""
348 self.current_capital = self.initial_capital
350 def _load_existing_positions(self):
351 """加载现有持仓"""
352 try:
353 # 从数据库加载最新持仓
354 latest_positions = self.trading_repo.get_latest_positions(self.session_id)
356 for symbol, position_data in latest_positions.items():
357 position = Position(
358 symbol=symbol,
359 quantity=position_data.quantity,
360 avg_price=position_data.avg_price,
361 current_price=position_data.avg_price, # 初始使用成本价
362 market_value=position_data.market_value,
363 unrealized_pnl=position_data.unrealized_pnl,
364 realized_pnl=position_data.realized_pnl,
365 )
366 self.positions[symbol] = position
368 except Exception as e:
369 self.log_message(f"❌ 加载现有持仓失败: {e}", "error")
371 def _validate_order(self, order: Order) -> Dict[str, Any]:
372 """验证订单"""
373 errors = []
374 warnings = []
376 # 检查资金充足性(买入订单)
377 if order.side == OrderSide.BUY:
378 if order.order_type == OrderType.MARKET:
379 # 市价单需要估算价格
380 estimated_price = self._get_estimated_price(order.symbol)
381 if estimated_price:
382 required_amount = order.quantity * estimated_price
383 if required_amount > self.current_capital:
384 error_msg = f"资金不足: 需要 ${required_amount:.2f}, 可用 ${self.current_capital:.2f}"
385 errors.append("资金不足")
386 self.log_message(f"❌ 订单验证失败: {error_msg}", "warn")
387 else:
388 errors.append("无法获取股票价格")
389 self.log_message(
390 f"❌ 订单验证失败: 无法获取股票价格 {order.symbol}", "error"
391 )
392 elif order.order_type == OrderType.LIMIT and order.price:
393 required_amount = order.quantity * order.price
394 if required_amount > self.current_capital:
395 error_msg = f"资金不足: 需要 ${required_amount:.2f}, 可用 ${self.current_capital:.2f}"
396 errors.append("资金不足")
397 self.log_message(f"❌ 订单验证失败: {error_msg}", "warn")
399 # 检查持仓充足性(卖出订单)
400 elif order.side == OrderSide.SELL:
401 if order.symbol not in self.positions:
402 errors.append("没有该股票的持仓")
403 self.log_message(f"❌ 订单验证失败: 没有 {order.symbol} 的持仓", "warn")
404 else:
405 current_position = self.positions[order.symbol]
406 if order.quantity > current_position.quantity:
407 error_msg = f"持仓数量不足: 需要 {order.quantity}, 持有 {current_position.quantity}"
408 errors.append("持仓数量不足")
409 self.log_message(f"❌ 订单验证失败: {error_msg}", "warn")
411 return {"valid": len(errors) == 0, "errors": errors, "warnings": warnings}
413 def _process_market_order(self, order: Order) -> OrderResult:
414 """处理市价单"""
415 try:
416 # 获取当前市场价格
417 current_price = self._get_current_price(order.symbol)
418 if not current_price:
419 return OrderResult(
420 order_id=order.order_id,
421 status=OrderStatus.REJECTED,
422 message="无法获取市场价格",
423 )
425 # 应用滑点
426 execution_price = self._apply_slippage(current_price, order.side)
428 # 计算交易费用(使用新的费用计算器)
429 trading_fee = self._calculate_trading_fee(
430 order.quantity, execution_price, order.side, order.symbol
431 )
433 # 计算旧的手续费(兼容性)
434 commission = self._calculate_commission(order.quantity, execution_price)
436 # 执行交易
437 trade = Trade(
438 order_id=order.order_id,
439 symbol=order.symbol,
440 side=order.side,
441 quantity=order.quantity,
442 price=execution_price,
443 timestamp=datetime.now(),
444 commission=commission,
445 trading_fee=trading_fee.model_dump() if trading_fee else None,
446 )
448 # 更新持仓和资金(包含费用扣除)
449 self._update_position_and_capital(trade)
451 # 更新订单的费用信息
452 order.trading_fee = trading_fee.model_dump() if trading_fee else None
454 # 记录成交
455 self.trades.append(trade)
456 order.status = OrderStatus.FILLED
458 # 更新Redis中的订单状态
459 self._update_order_in_redis(order)
461 # 记录费用信息到日志
462 if trading_fee:
463 self.log_message(
464 f"📊 订单 {order.order_id[:8]} 成交费用: {trading_fee.total_fee} {trading_fee.currency} "
465 f"(平台费: {trading_fee.platform_fee}, 活动费: {trading_fee.activity_fee}, "
466 f"交收费: {trading_fee.clearing_fee}, 审计费: {trading_fee.audit_fee})",
467 "log",
468 )
470 return OrderResult(
471 order_id=order.order_id,
472 status=OrderStatus.FILLED,
473 message="订单已成交",
474 filled_price=execution_price,
475 filled_quantity=order.quantity,
476 trading_fee=trading_fee.model_dump() if trading_fee else None,
477 )
479 except Exception as e:
480 self.log_message(f"❌ 处理市价单失败: {e}", "error")
481 return OrderResult(
482 order_id=order.order_id,
483 status=OrderStatus.REJECTED,
484 message=f"处理市价单失败: {str(e)}",
485 )
487 def _process_limit_order(self, order: Order) -> OrderResult:
488 """处理限价单"""
489 try:
490 # 检查是否可以立即成交
491 current_price = self._get_current_price(order.symbol)
492 if current_price and self._can_execute_limit_order(order, current_price):
493 # 立即成交
494 return self._execute_limit_order(order, current_price)
495 else:
496 # 等待成交
497 order.status = OrderStatus.PENDING
498 return OrderResult(
499 order_id=order.order_id,
500 status=OrderStatus.PENDING,
501 message="订单已提交,等待成交",
502 )
504 except Exception as e:
505 self.log_message(f"❌ 处理限价单失败: {e}", "error")
506 return OrderResult(
507 order_id=order.order_id,
508 status=OrderStatus.REJECTED,
509 message=f"处理限价单失败: {str(e)}",
510 )
512 def _update_order_time_windows(self, market_data: MarketData):
513 """更新订单时间窗口缓存"""
514 symbol = market_data.symbol
516 # 为每个股票维护时间窗口数据
517 if symbol not in self.order_time_windows:
518 self.order_time_windows[symbol] = []
520 # 添加新的市场数据
521 self.order_time_windows[symbol].append(market_data)
523 # 保持时间窗口大小(例如:保留最近60分钟的数据)
524 max_window_size = 60
525 if len(self.order_time_windows[symbol]) > max_window_size:
526 self.order_time_windows[symbol] = self.order_time_windows[symbol][
527 -max_window_size:
528 ]
530 def _check_limit_order_execution_with_time_window(self, market_data: MarketData):
531 """基于时间窗口检查限价单成交"""
532 pending_orders = [
533 order
534 for order in self.orders.values()
535 if order.status == OrderStatus.PENDING
536 and order.symbol == market_data.symbol
537 ]
539 for order in pending_orders:
540 # 使用时间窗口数据进行精确的成交判定
541 execution_price = self._check_order_execution_with_time_window(
542 order, market_data
543 )
544 if execution_price:
545 self._execute_limit_order(order, execution_price)
547 def _check_order_execution_with_time_window(
548 self, order: Order, market_data: MarketData
549 ) -> Optional[Decimal]:
550 """基于时间窗口检查订单成交"""
551 if not order.price:
552 return None
554 symbol = order.symbol
556 # 获取该股票的时间窗口数据
557 time_window_data = self.order_time_windows.get(symbol, [])
559 if not time_window_data:
560 # 如果没有时间窗口数据,使用当前K线数据
561 return self._check_order_execution_with_candlestick(order, market_data)
563 # 在时间窗口内查找是否满足成交条件
564 if order.side == OrderSide.BUY:
565 # 买入限价单:检查时间窗口内是否有价格触及限价
566 min_price = min([data.low for data in time_window_data])
567 if min_price <= order.price:
568 # 使用当前收盘价作为成交价
569 return market_data.close
570 else:
571 # 卖出限价单:检查时间窗口内是否有价格触及限价
572 max_price = max([data.high for data in time_window_data])
573 if max_price >= order.price:
574 # 使用当前收盘价作为成交价
575 return market_data.close
577 return None
579 def _check_limit_order_execution(self, market_data: MarketData):
580 """检查限价单成交 - 基于K线数据和时间窗口的精确判定"""
581 pending_orders = [
582 order
583 for order in self.orders.values()
584 if order.status == OrderStatus.PENDING
585 and order.symbol == market_data.symbol
586 ]
588 for order in pending_orders:
589 # 使用K线数据进行精确的成交判定
590 execution_price = self._check_order_execution_with_candlestick(
591 order, market_data
592 )
593 if execution_price:
594 self._execute_limit_order(order, execution_price)
596 def _check_order_execution_with_candlestick(
597 self, order: Order, market_data: MarketData
598 ) -> Optional[Decimal]:
599 """基于K线数据检查订单成交 - 考虑时间窗口内的价格变化"""
600 if not order.price:
601 return None
603 # 获取K线数据(开盘价、最高价、最低价、收盘价)
604 open_price = market_data.open
605 high_price = market_data.high
606 low_price = market_data.low
607 close_price = market_data.close
609 if order.side == OrderSide.BUY:
610 # 买入限价单:在挂单当前分钟或未来时间结束的任何时刻,价格触及或低于限价
611 # 检查最低价是否触及限价
612 if low_price <= order.price:
613 # 成交价格:使用收盘价作为成交价(更符合当前市场价格概念)
614 return close_price
615 else:
616 # 卖出限价单:在挂单当前分钟或未来时间结束的任何时刻,价格触及或高于限价
617 # 检查最高价是否触及限价
618 if high_price >= order.price:
619 # 成交价格:使用收盘价作为成交价(更符合当前市场价格概念)
620 return close_price
622 return None
624 def _can_execute_limit_order(self, order: Order, current_price: Decimal) -> bool:
625 """检查限价单是否可以成交(简化版本,用于兼容性)"""
626 if not order.price:
627 return False
629 if order.side == OrderSide.BUY:
630 # 买入限价单:当前价格 <= 限价
631 return current_price <= order.price
632 else:
633 # 卖出限价单:当前价格 >= 限价
634 return current_price >= order.price
636 def _execute_limit_order(
637 self, order: Order, execution_price: Decimal
638 ) -> OrderResult:
639 """执行限价单"""
640 try:
641 # 计算交易费用(使用新的费用计算器)
642 trading_fee = self._calculate_trading_fee(
643 order.quantity, execution_price, order.side, order.symbol
644 )
646 # 计算旧的手续费(兼容性)
647 commission = self._calculate_commission(order.quantity, execution_price)
649 # 执行交易
650 trade = Trade(
651 order_id=order.order_id,
652 symbol=order.symbol,
653 side=order.side,
654 quantity=order.quantity,
655 price=execution_price,
656 timestamp=datetime.now(),
657 commission=commission,
658 trading_fee=trading_fee.model_dump() if trading_fee else None,
659 )
661 # 更新持仓和资金(包含费用扣除)
662 self._update_position_and_capital(trade)
664 # 更新订单的费用信息
665 order.trading_fee = trading_fee.model_dump() if trading_fee else None
667 # 记录成交
668 self.trades.append(trade)
669 order.status = OrderStatus.FILLED
671 # 更新Redis中的订单状态
672 self._update_order_in_redis(order)
674 # 记录费用信息到日志
675 if trading_fee:
676 self.log_message(
677 f"📊 订单 {order.order_id[:8]} 成交费用: {trading_fee.total_fee} {trading_fee.currency} "
678 f"(平台费: {trading_fee.platform_fee}, 活动费: {trading_fee.activity_fee}, "
679 f"交收费: {trading_fee.clearing_fee}, 审计费: {trading_fee.audit_fee})",
680 "log",
681 )
683 return OrderResult(
684 order_id=order.order_id,
685 status=OrderStatus.FILLED,
686 message="限价单已成交",
687 filled_price=execution_price,
688 filled_quantity=order.quantity,
689 trading_fee=trading_fee.model_dump() if trading_fee else None,
690 )
692 except Exception as e:
693 self.log_message(f"❌ 执行限价单失败: {e}", "error")
694 return OrderResult(
695 order_id=order.order_id,
696 status=OrderStatus.REJECTED,
697 message=f"执行限价单失败: {str(e)}",
698 )
700 def _update_position_and_capital(self, trade: Trade):
701 """更新持仓和资金"""
702 symbol = trade.symbol
703 quantity = trade.quantity
704 price = trade.price
705 commission = trade.commission
707 # 获取交易费用(优先使用新的费用结构)
708 total_fee = Decimal("0")
709 if trade.trading_fee and isinstance(trade.trading_fee, dict):
710 total_fee = Decimal(str(trade.trading_fee.get("total_fee", "0")))
711 else:
712 # 兼容旧的手续费字段
713 total_fee = commission
715 if trade.side == OrderSide.BUY:
716 # 买入:减少现金(交易金额 + 费用),增加持仓
717 total_cost = quantity * price + total_fee
718 self.current_capital -= total_cost
720 self.log_message(
721 f"💰 买入扣款: 交易金额 {quantity * price:.2f} + 费用 {total_fee:.4f} = {total_cost:.2f} {self.currency}",
722 "log",
723 )
725 if symbol in self.positions:
726 # 更新现有持仓
727 position = self.positions[symbol]
728 total_quantity = position.quantity + quantity
729 total_cost_basis = (
730 position.avg_price * position.quantity + quantity * price
731 )
732 new_avg_price = total_cost_basis / total_quantity
734 position.quantity = total_quantity
735 position.avg_price = new_avg_price
736 position.current_price = price
737 position.market_value = total_quantity * price
738 position.unrealized_pnl = position.market_value - (
739 new_avg_price * total_quantity
740 )
741 else:
742 # 新建持仓
743 position = Position(
744 symbol=symbol,
745 quantity=quantity,
746 avg_price=price,
747 current_price=price,
748 market_value=quantity * price,
749 unrealized_pnl=Decimal("0"),
750 realized_pnl=Decimal("0"),
751 )
752 self.positions[symbol] = position
754 else:
755 # 卖出:增加现金(交易金额 - 费用),减少持仓
756 if symbol not in self.positions:
757 raise ValueError("没有该股票的持仓")
759 position = self.positions[symbol]
760 if quantity > position.quantity:
761 raise ValueError("持仓数量不足")
763 # 计算已实现盈亏(卖出价格 - 成本价格)* 数量 - 费用
764 realized_pnl = (price - position.avg_price) * quantity - total_fee
765 position.realized_pnl += realized_pnl
767 # 更新资金:增加交易金额,扣除费用
768 cash_received = quantity * price - total_fee
769 self.current_capital += cash_received
771 self.log_message(
772 f"💰 卖出入账: 交易金额 {quantity * price:.2f} - 费用 {total_fee:.4f} = {cash_received:.2f} {self.currency}",
773 "log",
774 )
776 # 更新持仓
777 position.quantity -= quantity
778 if position.quantity == 0:
779 # 完全卖出,删除持仓
780 del self.positions[symbol]
781 else:
782 # 部分卖出,更新市值
783 position.current_price = price
784 position.market_value = position.quantity * price
785 position.unrealized_pnl = position.market_value - (
786 position.avg_price * position.quantity
787 )
789 def _get_current_price(self, symbol: str) -> Optional[Decimal]:
790 """获取当前价格"""
791 # 优先从缓存获取
792 if symbol in self.market_data_cache:
793 return self.market_data_cache[symbol].close
795 # 如果缓存中没有,尝试从数据源获取(实时交易模式)
796 try:
797 quotes = self.quote_adapter.get_quote([symbol])
798 if quotes and len(quotes) > 0:
799 quote = quotes[0]
800 if hasattr(quote, "last_done") and quote.last_done:
801 return Decimal(str(quote.last_done))
802 else:
803 self.log_message(f"⚠️ 无法获取 {symbol} 的价格数据", "warn")
804 else:
805 self.log_message(f"⚠️ 未获取到 {symbol} 的报价数据", "warn")
806 except Exception as e:
807 self.log_message(f"❌ 获取价格失败: {symbol} - {e}", "error")
809 return None
811 def _get_estimated_price(self, symbol: str) -> Optional[Decimal]:
812 """获取估算价格(用于市价单)"""
813 return self._get_current_price(symbol)
815 def _apply_slippage(self, price: Decimal, side: OrderSide) -> Decimal:
816 """应用滑点"""
817 if side == OrderSide.BUY:
818 # 买入时价格向上滑点
819 return price * (1 + self.slippage_rate)
820 else:
821 # 卖出时价格向下滑点
822 return price * (1 - self.slippage_rate)
824 def _calculate_commission(self, quantity: Decimal, price: Decimal) -> Decimal:
825 """计算手续费(已废弃,保留用于兼容性)"""
826 commission = quantity * price * self.commission_rate
827 return max(commission, self.min_commission)
829 def _calculate_trading_fee(
830 self, quantity: Decimal, price: Decimal, side: OrderSide, symbol: str = None
831 ) -> Optional[TradingFee]:
832 """计算交易费用
834 Args:
835 quantity: 交易数量
836 price: 交易价格
837 side: 交易方向(买入/卖出)
838 symbol: 股票代码(用于确定市场类型)
840 Returns:
841 TradingFee对象,包含详细的费用分解
842 """
843 try:
844 # 根据股票代码确定市场类型和货币类型
845 market, currency = self._get_market_and_currency_from_symbol(symbol)
847 # 根据市场类型选择对应的计算方法
848 if market == "US":
849 return self.fee_calculator.calculate_us_fees(
850 quantity, price, side.value
851 )
852 elif market == "HK":
853 return self.fee_calculator.calculate_hk_fees(
854 quantity, price, side.value
855 )
856 else:
857 # 默认使用美股费率
858 self.log_message(f"⚠️ 未知市场类型 {market},使用美股费率", "warning")
859 return self.fee_calculator.calculate_us_fees(
860 quantity, price, side.value
861 )
862 except Exception as e:
863 self.log_message(f"❌ 计算交易费用失败: {e}", "error")
864 return None
866 def _get_market_and_currency_from_symbol(self, symbol: str) -> tuple[str, str]:
867 """根据股票代码确定市场类型和货币类型
869 Args:
870 symbol: 股票代码(如 AAPL.US, 700.HK)
872 Returns:
873 tuple: (市场类型, 货币类型)
874 """
875 if not symbol or "." not in symbol:
876 # 如果没有股票代码或格式不正确,使用会话级别的货币
877 return ("US" if self.currency == "USD" else "HK", self.currency)
879 # 从股票代码中提取市场后缀
880 market_suffix = symbol.split(".")[-1].upper()
882 # 市场类型到货币类型的映射
883 market_currency_mapping = {
884 "US": ("US", "USD"),
885 "HK": ("HK", "HKD"),
886 "SZ": ("CN", "CNY"), # 深圳
887 "SH": ("CN", "CNY"), # 上海
888 "SG": ("SG", "SGD"), # 新加坡
889 }
891 market, currency = market_currency_mapping.get(market_suffix, ("US", "USD"))
892 return market, currency
894 def _update_positions_market_value(self):
895 """更新持仓市值"""
896 for symbol, position in self.positions.items():
897 current_price = self._get_current_price(symbol)
898 if current_price:
899 position.current_price = current_price
900 position.market_value = position.quantity * current_price
901 position.unrealized_pnl = position.market_value - (
902 position.avg_price * position.quantity
903 )
905 def _calculate_max_drawdown(self) -> Decimal:
906 """计算最大回撤(简化版本)"""
907 if not self.trades:
908 return Decimal("0")
910 # 简化计算:基于当前总资产和初始资金
911 current_total = self.get_account_balance().total_cash
912 if current_total < self.initial_capital:
913 return (self.initial_capital - current_total) / self.initial_capital
914 else:
915 return Decimal("0")