Coverage for api/v1/endpoints/trade_test.py: 70.30%
330 statements
« prev ^ index » next coverage.py v7.10.7, created at 2025-10-13 18:58 +0000
« prev ^ index » next coverage.py v7.10.7, created at 2025-10-13 18:58 +0000
1"""
2交易测试API端点
3提供长桥OpenAPI交易功能的测试接口
4"""
6from datetime import date, datetime
7from decimal import Decimal
8from typing import Any, Dict, List, Optional, Union
10from fastapi import APIRouter, Depends, HTTPException
11from pydantic import BaseModel, Field, field_validator
13from core.data_source.adapters.trade_adapter import TradeDataSourceAdapter
14from core.middleware.auth_middleware import get_current_user
15from core.models.user import User
17router = APIRouter(prefix="/api-test/trade", tags=["交易测试"])
20# 请求模型定义
class SubmitOrderRequest(BaseModel):
    """Request body for submitting an order.

    Numeric fields accept a float, an int or a numeric string and are
    converted to ``Decimal`` by ``validate_decimal_fields``. ``expire_date``
    accepts a ``date``, an object exposing ``.date()`` (e.g. ``datetime``)
    or an ISO string, and is converted to a ``date``.
    """

    # Schema examples are passed through ``json_schema_extra``: pydantic v2
    # deprecates unknown keyword arguments such as ``example=`` on ``Field``.
    symbol: str = Field(
        ..., description="股票代码", json_schema_extra={"example": "YINN.US"}
    )
    order_type: str = Field(
        ..., description="订单类型", json_schema_extra={"example": "LO"}
    )
    side: str = Field(
        ..., description="买卖方向", json_schema_extra={"example": "Buy"}
    )
    submitted_quantity: Union[float, int, str] = Field(
        ..., description="委托数量", json_schema_extra={"example": 100}
    )
    time_in_force: str = Field(
        ..., description="订单有效期", json_schema_extra={"example": "Day"}
    )
    submitted_price: Optional[Union[float, int, str]] = Field(
        None, description="委托价格", json_schema_extra={"example": 50.0}
    )
    trigger_price: Optional[Union[float, int, str]] = Field(
        None, description="触发价格", json_schema_extra={"example": 45.0}
    )
    limit_offset: Optional[Union[float, int, str]] = Field(
        None, description="限价偏移", json_schema_extra={"example": 0.1}
    )
    trailing_amount: Optional[Union[float, int, str]] = Field(
        None, description="追踪金额", json_schema_extra={"example": 0.5}
    )
    trailing_percent: Optional[Union[float, int, str]] = Field(
        None, description="追踪百分比", json_schema_extra={"example": 0.01}
    )
    expire_date: Optional[Union[date, str]] = Field(
        None,
        description="过期日期",
        json_schema_extra={"example": date(2025, 1, 1)},
    )
    outside_rth: Optional[str] = Field(
        None, description="盘前盘后交易", json_schema_extra={"example": "RTH_ONLY"}
    )
    remark: Optional[str] = Field(
        None, description="备注", json_schema_extra={"example": "测试订单"}
    )

    @field_validator(
        "submitted_quantity",
        "submitted_price",
        "trigger_price",
        "limit_offset",
        "trailing_amount",
        "trailing_percent",
    )
    @classmethod
    def validate_decimal_fields(cls, v):
        """Convert a numeric input to ``Decimal``; ``None`` passes through.

        Raises:
            ValueError: if the value cannot be parsed as a decimal number or
                is not finite (NaN / Infinity).
        """
        if v is None:
            return v
        try:
            # Convert via ``str`` so a float keeps its short repr (0.1 -> "0.1").
            number = Decimal(str(v))
        except ArithmeticError:
            # decimal.InvalidOperation derives from ArithmeticError, not
            # ValueError; re-raise as ValueError so pydantic reports a
            # validation error instead of letting the exception escape.
            raise ValueError(f"Invalid numeric value: {v!r}") from None
        if not number.is_finite():
            raise ValueError(f"Numeric value must be finite: {v!r}")
        return number

    @field_validator("expire_date", mode="before")
    @classmethod
    def validate_expire_date(cls, v):
        """Normalise the raw ``expire_date`` input to a ``date``.

        Strings are parsed with ``date.fromisoformat`` after dropping anything
        from the first ``"T"`` onwards; objects exposing ``.date()`` are
        reduced to their date part; every other value is returned unchanged.
        """
        if isinstance(v, str):
            # "2025-01-01T10:00:00" -> "2025-01-01"; a plain date string has
            # no "T", so the split leaves it intact.
            return date.fromisoformat(v.split("T")[0])
        if hasattr(v, "date"):
            # datetime-like object: keep only the calendar date.
            return v.date()
        return v
class ReplaceOrderRequest(BaseModel):
    """Request body for amending an existing order.

    ``quantity``, ``price`` and ``trigger_price`` accept a float, an int or a
    numeric string and are converted to ``Decimal``.
    """

    # Schema examples are passed through ``json_schema_extra``: pydantic v2
    # deprecates unknown keyword arguments such as ``example=`` on ``Field``.
    order_id: str = Field(
        ...,
        description="订单ID",
        json_schema_extra={"example": "709043056541253632"},
    )
    quantity: Union[float, int, str] = Field(
        ..., description="修改数量", json_schema_extra={"example": 150}
    )
    price: Optional[Union[float, int, str]] = Field(
        None, description="修改价格", json_schema_extra={"example": 55.0}
    )
    trigger_price: Optional[Union[float, int, str]] = Field(
        None, description="修改触发价格", json_schema_extra={"example": 50.0}
    )
    remark: Optional[str] = Field(
        None, description="备注", json_schema_extra={"example": "修改订单"}
    )

    @field_validator("quantity", "price", "trigger_price")
    @classmethod
    def validate_decimal_fields(cls, v):
        """Convert a numeric input to ``Decimal``; ``None`` passes through.

        Raises:
            ValueError: if the value cannot be parsed as a decimal number or
                is not finite (NaN / Infinity).
        """
        if v is None:
            return v
        try:
            # Convert via ``str`` so a float keeps its short repr (0.1 -> "0.1").
            number = Decimal(str(v))
        except ArithmeticError:
            # decimal.InvalidOperation derives from ArithmeticError, not
            # ValueError; re-raise as ValueError so pydantic reports a
            # validation error instead of letting the exception escape.
            raise ValueError(f"Invalid numeric value: {v!r}") from None
        if not number.is_finite():
            raise ValueError(f"Numeric value must be finite: {v!r}")
        return number
class CancelOrderRequest(BaseModel):
    """Request body for cancelling an order, identified by ``order_id``."""

    # ``example=`` is a deprecated extra keyword on pydantic v2 ``Field``;
    # ``json_schema_extra`` produces the same schema entry without the warning.
    order_id: str = Field(
        ...,
        description="订单ID",
        json_schema_extra={"example": "709043056541253632"},
    )
class GetOrdersRequest(BaseModel):
    """Filter criteria for listing today's orders; every field is optional.

    ``status`` may be given as a list of strings or as one comma-separated
    string, which is split into a list.
    """

    # Schema examples are passed through ``json_schema_extra``: pydantic v2
    # deprecates unknown keyword arguments such as ``example=`` on ``Field``.
    symbol: Optional[str] = Field(
        None, description="股票代码", json_schema_extra={"example": "YINN.US"}
    )
    status: Optional[Union[str, List[str]]] = Field(
        None,
        description="订单状态",
        json_schema_extra={"example": ["New", "Filled"]},
    )
    side: Optional[str] = Field(
        None, description="买卖方向", json_schema_extra={"example": "Buy"}
    )
    market: Optional[str] = Field(
        None, description="市场", json_schema_extra={"example": "US"}
    )
    order_id: Optional[str] = Field(
        None,
        description="订单ID",
        json_schema_extra={"example": "709043056541253632"},
    )

    @field_validator("status")
    @classmethod
    def validate_status(cls, v):
        """Split a comma-separated status string into a list of names.

        Blank items are dropped and surrounding whitespace is stripped.
        Non-string values (a list or ``None``) are returned unchanged.
        """
        if not isinstance(v, str):
            return v
        pieces = (piece.strip() for piece in v.split(","))
        return [piece for piece in pieces if piece]
class GetHistoryOrdersRequest(BaseModel):
    """Filter criteria for listing historical orders; every field is optional.

    ``status`` may be a list or a comma-separated string. ``start_at`` and
    ``end_at`` are normalised to ``datetime`` (a bare ``date`` becomes
    midnight of that day).
    """

    # Schema examples are passed through ``json_schema_extra``: pydantic v2
    # deprecates unknown keyword arguments such as ``example=`` on ``Field``.
    symbol: Optional[str] = Field(
        None, description="股票代码", json_schema_extra={"example": "YINN.US"}
    )
    status: Optional[Union[str, List[str]]] = Field(
        None,
        description="订单状态",
        json_schema_extra={"example": ["Filled", "Canceled"]},
    )
    side: Optional[str] = Field(
        None, description="买卖方向", json_schema_extra={"example": "Buy"}
    )
    market: Optional[str] = Field(
        None, description="市场", json_schema_extra={"example": "US"}
    )
    start_at: Optional[Union[date, datetime]] = Field(None, description="开始时间")
    end_at: Optional[Union[date, datetime]] = Field(None, description="结束时间")

    @field_validator("status")
    @classmethod
    def validate_status(cls, v):
        """Split a comma-separated status string into a list of names.

        Blank items are dropped and surrounding whitespace is stripped.
        Non-string values (a list or ``None``) are returned unchanged.
        """
        if not isinstance(v, str):
            return v
        pieces = (piece.strip() for piece in v.split(","))
        return [piece for piece in pieces if piece]

    @field_validator("start_at", "end_at")
    @classmethod
    def validate_datetime(cls, v):
        """Normalise a time bound to ``datetime``.

        * ``str``: parsed as an ISO date-time (a trailing ``"Z"`` is rewritten
          to ``"+00:00"``); if that fails, only the part before ``"T"`` is
          parsed as a date and combined with midnight.
        * ``datetime``: returned as is.
        * ``date``: combined with midnight.
        * anything else (including ``None``): returned unchanged.

        Raises:
            ValueError: if a string can be parsed neither way.
        """
        # NOTE(review): this validator is registered in the default "after"
        # mode, so pydantic has presumably already turned strings into
        # date/datetime before it runs and the ``str`` branch would only be
        # hit by direct calls — confirm whether mode="before" was intended.
        if isinstance(v, str):
            # Replace a trailing "Z" with an explicit UTC offset.
            text = v[:-1] + "+00:00" if v.endswith("Z") else v
            try:
                return datetime.fromisoformat(text)
            except ValueError:
                pass
            # Fall back to the date part only.
            try:
                day = date.fromisoformat(text.split("T")[0])
            except ValueError:
                raise ValueError(f"Invalid date format: {text}") from None
            return datetime.combine(day, datetime.min.time())
        # ``datetime`` must be tested first: it is a subclass of ``date``.
        if isinstance(v, datetime):
            return v
        if isinstance(v, date):
            return datetime.combine(v, datetime.min.time())
        return v
class GetExecutionsRequest(BaseModel):
    """Filter criteria for listing today's executions; both fields optional."""

    # ``example=`` is a deprecated extra keyword on pydantic v2 ``Field``;
    # ``json_schema_extra`` produces the same schema entry without the warning.
    symbol: Optional[str] = Field(
        None, description="股票代码", json_schema_extra={"example": "YINN.US"}
    )
    order_id: Optional[str] = Field(
        None,
        description="订单ID",
        json_schema_extra={"example": "709043056541253632"},
    )
class GetHistoryExecutionsRequest(BaseModel):
    """Filter criteria for listing historical executions; all fields optional.

    ``start_at`` and ``end_at`` are normalised to ``datetime`` (a bare
    ``date`` becomes midnight of that day).
    """

    # ``example=`` is a deprecated extra keyword on pydantic v2 ``Field``;
    # ``json_schema_extra`` produces the same schema entry without the warning.
    symbol: Optional[str] = Field(
        None, description="股票代码", json_schema_extra={"example": "YINN.US"}
    )
    start_at: Optional[Union[date, datetime]] = Field(None, description="开始时间")
    end_at: Optional[Union[date, datetime]] = Field(None, description="结束时间")

    @field_validator("start_at", "end_at")
    @classmethod
    def validate_datetime(cls, v):
        """Normalise a time bound to ``datetime``.

        * ``str``: parsed as an ISO date-time (a trailing ``"Z"`` is rewritten
          to ``"+00:00"``); if that fails, only the part before ``"T"`` is
          parsed as a date and combined with midnight.
        * ``datetime``: returned as is.
        * ``date``: combined with midnight.
        * anything else (including ``None``): returned unchanged.

        Raises:
            ValueError: if a string can be parsed neither way.
        """
        # NOTE(review): this validator is registered in the default "after"
        # mode, so pydantic has presumably already turned strings into
        # date/datetime before it runs and the ``str`` branch would only be
        # hit by direct calls — confirm whether mode="before" was intended.
        if isinstance(v, str):
            # Replace a trailing "Z" with an explicit UTC offset.
            text = v[:-1] + "+00:00" if v.endswith("Z") else v
            try:
                return datetime.fromisoformat(text)
            except ValueError:
                pass
            # Fall back to the date part only.
            try:
                day = date.fromisoformat(text.split("T")[0])
            except ValueError:
                raise ValueError(f"Invalid date format: {text}") from None
            return datetime.combine(day, datetime.min.time())
        # ``datetime`` must be tested first: it is a subclass of ``date``.
        if isinstance(v, datetime):
            return v
        if isinstance(v, date):
            return datetime.combine(v, datetime.min.time())
        return v
class GetOrderDetailRequest(BaseModel):
    """Request body for fetching one order's detail, identified by ``order_id``."""

    # ``example=`` is a deprecated extra keyword on pydantic v2 ``Field``;
    # ``json_schema_extra`` produces the same schema entry without the warning.
    order_id: str = Field(
        ...,
        description="订单ID",
        json_schema_extra={"example": "709043056541253632"},
    )
class GetAccountBalanceRequest(BaseModel):
    """Request body for querying account balances, optionally by currency."""

    # ``example=`` is a deprecated extra keyword on pydantic v2 ``Field``;
    # ``json_schema_extra`` produces the same schema entry without the warning.
    currency: Optional[str] = Field(
        None, description="货币", json_schema_extra={"example": "USD"}
    )
class EstimateMaxPurchaseQuantityRequest(BaseModel):
    """Request body for estimating the maximum purchasable quantity.

    ``price`` accepts a float, an int or a numeric string and is converted
    to ``Decimal``.
    """

    # Schema examples are passed through ``json_schema_extra``: pydantic v2
    # deprecates unknown keyword arguments such as ``example=`` on ``Field``.
    symbol: str = Field(
        ..., description="股票代码", json_schema_extra={"example": "YINN.US"}
    )
    order_type: str = Field(
        ..., description="订单类型", json_schema_extra={"example": "LO"}
    )
    side: str = Field(
        ..., description="买卖方向", json_schema_extra={"example": "Buy"}
    )
    price: Optional[Union[float, int, str]] = Field(
        None, description="估算价格", json_schema_extra={"example": 50.0}
    )
    currency: Optional[str] = Field(
        None, description="货币", json_schema_extra={"example": "USD"}
    )
    order_id: Optional[str] = Field(
        None,
        description="订单ID(修改订单时使用)",
        json_schema_extra={"example": "709043056541253632"},
    )
    fractional_shares: bool = Field(
        False, description="是否支持碎股", json_schema_extra={"example": False}
    )

    @field_validator("price")
    @classmethod
    def validate_price(cls, v):
        """Convert the price to ``Decimal``; ``None`` passes through.

        Raises:
            ValueError: if the value cannot be parsed as a decimal number or
                is not finite (NaN / Infinity).
        """
        if v is None:
            return v
        try:
            # Convert via ``str`` so a float keeps its short repr (0.1 -> "0.1").
            number = Decimal(str(v))
        except ArithmeticError:
            # decimal.InvalidOperation derives from ArithmeticError, not
            # ValueError; re-raise as ValueError so pydantic reports a
            # validation error instead of letting the exception escape.
            raise ValueError(f"Invalid numeric value: {v!r}") from None
        if not number.is_finite():
            raise ValueError(f"Numeric value must be finite: {v!r}")
        return number
258# 响应模型
class TradeTestResponse(BaseModel):
    """Uniform response envelope returned by every trade-test endpoint."""

    # Outcome flag and human-readable summary are mandatory.
    success: bool = Field(description="是否成功")
    message: str = Field(description="人类可读的结果描述")
    # Raw payload and error text are optional and default to None.
    data: Optional[Any] = Field(default=None, description="原始数据")
    error: Optional[str] = Field(default=None, description="错误信息")
def format_trade_response(
    success: bool, message: str, data: Any = None, error: Optional[str] = None
) -> TradeTestResponse:
    """Build a ``TradeTestResponse`` from its four components.

    Args:
        success: Whether the operation succeeded.
        message: Human-readable summary of the outcome.
        data: Optional raw payload to return to the caller.
        error: Optional error text; ``None`` when there is no error.

    Returns:
        The populated ``TradeTestResponse``.
    """
    return TradeTestResponse(success=success, message=message, data=data, error=error)
275# 交易测试API端点
@router.get("/supported-order-types/{market}", response_model=TradeTestResponse)
async def get_supported_order_types(
    market: str, current_user: User = Depends(get_current_user)
):
    """Return the order types the trade adapter reports for ``market``.

    Any exception is caught and reported as a ``success=False`` response
    carrying the exception text.
    """
    try:
        supported = TradeDataSourceAdapter(
            current_user.id
        ).get_supported_order_types(market)
        return format_trade_response(True, f"市场 {market} 支持的订单类型", supported)
    except Exception as exc:
        failure = f"获取订单类型失败: {exc}"
        return format_trade_response(False, failure, error=str(exc))
@router.get("/order-type-fields/{order_type}", response_model=TradeTestResponse)
async def get_order_type_required_fields(
    order_type: str, current_user: User = Depends(get_current_user)
):
    """Return the required fields the trade adapter reports for ``order_type``.

    Any exception is caught and reported as a ``success=False`` response
    carrying the exception text.
    """
    try:
        fields = TradeDataSourceAdapter(
            current_user.id
        ).get_order_type_required_fields(order_type)
        return format_trade_response(True, f"订单类型 {order_type} 的必填字段", fields)
    except Exception as exc:
        failure = f"获取必填字段失败: {exc}"
        return format_trade_response(False, failure, error=str(exc))
@router.post("/submit-order", response_model=TradeTestResponse)
async def test_submit_order(
    request: SubmitOrderRequest, current_user: User = Depends(get_current_user)
):
    """Submit an order through the trade adapter and summarise the result.

    The response is successful only when the adapter returns a non-empty
    result; the same check drives both the ``success`` flag and the message
    so the two can never disagree. Any exception is caught and reported as a
    ``success=False`` response carrying the exception text.
    """
    try:
        adapter = TradeDataSourceAdapter(current_user.id)
        result = adapter.submit_order(
            symbol=request.symbol,
            order_type=request.order_type,
            side=request.side,
            submitted_quantity=request.submitted_quantity,
            time_in_force=request.time_in_force,
            submitted_price=request.submitted_price,
            trigger_price=request.trigger_price,
            limit_offset=request.limit_offset,
            trailing_amount=request.trailing_amount,
            trailing_percent=request.trailing_percent,
            expire_date=request.expire_date,
            outside_rth=request.outside_rth,
            remark=request.remark,
        )

        # One truthiness check for flag and message: previously an empty
        # (falsy but non-None) result produced success=True together with
        # the failure message.
        succeeded = bool(result)
        if succeeded:
            order_id = result.get("order_id", "N/A")
            message = (
                "委托下单成功"
                f"\n• 订单ID: {order_id}"
                f"\n• 标的: {request.symbol}"
                f"\n• 方向: {request.side}"
                f"\n• 数量: {request.submitted_quantity}"
                f"\n• 价格: {request.submitted_price or '市价'}"
            )
        else:
            message = "委托下单失败"

        return format_trade_response(succeeded, message, result)

    except Exception as e:
        return format_trade_response(False, "委托下单失败", error=str(e))
@router.post("/replace-order", response_model=TradeTestResponse)
async def test_replace_order(
    request: ReplaceOrderRequest, current_user: User = Depends(get_current_user)
):
    """Amend an order through the trade adapter and summarise the outcome.

    The adapter's return value is used as the ``success`` flag. Any exception
    is caught and reported as a ``success=False`` response carrying the
    exception text.
    """
    try:
        outcome = TradeDataSourceAdapter(current_user.id).replace_order(
            order_id=request.order_id,
            quantity=request.quantity,
            price=request.price,
            trigger_price=request.trigger_price,
            remark=request.remark,
        )

        verdict = "成功" if outcome else "失败"
        summary = (
            f"改单{verdict}"
            f"\n• 订单ID: {request.order_id}"
            f"\n• 新数量: {request.quantity}"
            f"\n• 新价格: {request.price or '不变'}"
        )
        payload = {"order_id": request.order_id, "success": outcome}
        return format_trade_response(outcome, summary, payload)

    except Exception as exc:
        return format_trade_response(False, "改单失败", error=str(exc))
@router.post("/cancel-order", response_model=TradeTestResponse)
async def test_cancel_order(
    request: CancelOrderRequest, current_user: User = Depends(get_current_user)
):
    """Cancel an order through the trade adapter and summarise the outcome.

    The adapter's return value is used as the ``success`` flag. Any exception
    is caught and reported as a ``success=False`` response carrying the
    exception text.
    """
    try:
        outcome = TradeDataSourceAdapter(current_user.id).cancel_order(
            request.order_id
        )

        verdict = "成功" if outcome else "失败"
        summary = f"撤单{verdict}\n• 订单ID: {request.order_id}"
        payload = {"order_id": request.order_id, "success": outcome}
        return format_trade_response(outcome, summary, payload)

    except Exception as exc:
        return format_trade_response(False, "撤单失败", error=str(exc))
@router.post("/today-orders", response_model=TradeTestResponse)
async def test_today_orders(
    request: GetOrdersRequest, current_user: User = Depends(get_current_user)
):
    """Fetch today's orders through the trade adapter and summarise them.

    The message lists the first three orders and the count of the rest.
    An empty or missing result yields ``success=False`` with a "nothing
    found" message. Any exception is caught and reported as a
    ``success=False`` response carrying the exception text.
    """
    preview_limit = 3
    try:
        adapter = TradeDataSourceAdapter(current_user.id)
        orders = adapter.get_today_orders(
            symbol=request.symbol,
            status=request.status,
            side=request.side,
            market=request.market,
            order_id=request.order_id,
        )

        if orders:
            lines = [f"成功获取 {len(orders)} 条当日订单"]
            for order in orders[:preview_limit]:
                order_id = order.get("order_id", "N/A")
                symbol = order.get("symbol", "N/A")
                status = order.get("status", "N/A")
                side = order.get("side", "N/A")
                quantity = order.get("quantity", 0)
                lines.append(f"• {order_id}: {symbol} {side} {quantity}股 ({status})")
            if len(orders) > preview_limit:
                lines.append(f"... 还有 {len(orders) - preview_limit} 条订单")
            message = "\n".join(lines)
        else:
            message = "未获取到当日订单"

        # bool() rather than len() > 0: a None result must take the
        # "nothing found" path instead of raising TypeError.
        return format_trade_response(bool(orders), message, orders)

    except Exception as e:
        return format_trade_response(False, "获取当日订单失败", error=str(e))
@router.post("/history-orders", response_model=TradeTestResponse)
async def test_history_orders(
    request: GetHistoryOrdersRequest, current_user: User = Depends(get_current_user)
):
    """Fetch historical orders through the trade adapter and summarise them.

    The message lists the first three orders and the count of the rest.
    An empty or missing result yields ``success=False`` with a "nothing
    found" message. Any exception is caught and reported as a
    ``success=False`` response carrying the exception text.
    """
    preview_limit = 3
    try:
        adapter = TradeDataSourceAdapter(current_user.id)
        orders = adapter.get_history_orders(
            symbol=request.symbol,
            status=request.status,
            side=request.side,
            market=request.market,
            start_at=request.start_at,
            end_at=request.end_at,
        )

        if orders:
            lines = [f"成功获取 {len(orders)} 条历史订单"]
            for order in orders[:preview_limit]:
                order_id = order.get("order_id", "N/A")
                symbol = order.get("symbol", "N/A")
                status = order.get("status", "N/A")
                side = order.get("side", "N/A")
                quantity = order.get("quantity", 0)
                submitted_at = order.get("submitted_at", "N/A")
                lines.append(
                    f"• {order_id}: {symbol} {side} {quantity}股 ({status})"
                    f" - {submitted_at}"
                )
            if len(orders) > preview_limit:
                lines.append(f"... 还有 {len(orders) - preview_limit} 条订单")
            message = "\n".join(lines)
        else:
            message = "未获取到历史订单"

        # bool() rather than len() > 0: a None result must take the
        # "nothing found" path instead of raising TypeError.
        return format_trade_response(bool(orders), message, orders)

    except Exception as e:
        return format_trade_response(False, "获取历史订单失败", error=str(e))
@router.post("/today-executions", response_model=TradeTestResponse)
async def test_today_executions(
    request: GetExecutionsRequest, current_user: User = Depends(get_current_user)
):
    """Fetch today's executions through the trade adapter and summarise them.

    The message lists the first three executions and the count of the rest.
    An empty or missing result yields ``success=False`` with a "nothing
    found" message. Any exception is caught and reported as a
    ``success=False`` response carrying the exception text.
    """
    preview_limit = 3
    try:
        adapter = TradeDataSourceAdapter(current_user.id)
        executions = adapter.get_today_executions(
            symbol=request.symbol, order_id=request.order_id
        )

        if executions:
            lines = [f"成功获取 {len(executions)} 条当日成交明细"]
            for execution in executions[:preview_limit]:
                trade_id = execution.get("trade_id", "N/A")
                symbol = execution.get("symbol", "N/A")
                quantity = execution.get("quantity", 0)
                price = execution.get("price", 0)
                trade_done_at = execution.get("trade_done_at", "N/A")
                lines.append(
                    f"• {trade_id}: {symbol} {quantity}股 @${price} ({trade_done_at})"
                )
            if len(executions) > preview_limit:
                lines.append(f"... 还有 {len(executions) - preview_limit} 条成交")
            message = "\n".join(lines)
        else:
            message = "未获取到当日成交明细"

        # bool() rather than len() > 0: a None result must take the
        # "nothing found" path instead of raising TypeError.
        return format_trade_response(bool(executions), message, executions)

    except Exception as e:
        return format_trade_response(False, "获取当日成交明细失败", error=str(e))
@router.post("/history-executions", response_model=TradeTestResponse)
async def test_history_executions(
    request: GetHistoryExecutionsRequest, current_user: User = Depends(get_current_user)
):
    """Fetch historical executions through the trade adapter and summarise them.

    The message lists the first three executions and the count of the rest.
    An empty or missing result yields ``success=False`` with a "nothing
    found" message. Any exception is caught and reported as a
    ``success=False`` response carrying the exception text.
    """
    preview_limit = 3
    try:
        adapter = TradeDataSourceAdapter(current_user.id)
        executions = adapter.get_history_executions(
            symbol=request.symbol, start_at=request.start_at, end_at=request.end_at
        )

        if executions:
            lines = [f"成功获取 {len(executions)} 条历史成交明细"]
            for execution in executions[:preview_limit]:
                trade_id = execution.get("trade_id", "N/A")
                symbol = execution.get("symbol", "N/A")
                quantity = execution.get("quantity", 0)
                price = execution.get("price", 0)
                trade_done_at = execution.get("trade_done_at", "N/A")
                lines.append(
                    f"• {trade_id}: {symbol} {quantity}股 @${price} ({trade_done_at})"
                )
            if len(executions) > preview_limit:
                lines.append(f"... 还有 {len(executions) - preview_limit} 条成交")
            message = "\n".join(lines)
        else:
            message = "未获取到历史成交明细"

        # bool() rather than len() > 0: a None result must take the
        # "nothing found" path instead of raising TypeError.
        return format_trade_response(bool(executions), message, executions)

    except Exception as e:
        return format_trade_response(False, "获取历史成交明细失败", error=str(e))
@router.post("/order-detail", response_model=TradeTestResponse)
async def test_order_detail(
    request: GetOrderDetailRequest, current_user: User = Depends(get_current_user)
):
    """Fetch one order's detail through the trade adapter and summarise it.

    The response is successful only when the adapter returns a non-empty
    detail; the same check drives both the ``success`` flag and the message
    so the two can never disagree. Any exception is caught and reported as a
    ``success=False`` response carrying the exception text.
    """
    try:
        adapter = TradeDataSourceAdapter(current_user.id)
        order_detail = adapter.get_order_detail(request.order_id)

        # One truthiness check for flag and message: previously an empty
        # (falsy but non-None) detail produced success=True together with
        # the "not found" message.
        found = bool(order_detail)
        if found:
            symbol = order_detail.get("symbol", "N/A")
            status = order_detail.get("status", "N/A")
            side = order_detail.get("side", "N/A")
            quantity = order_detail.get("quantity", 0)
            executed_quantity = order_detail.get("executed_quantity", 0)
            price = order_detail.get("price", 0)
            message = (
                "成功获取订单详情"
                f"\n• 订单ID: {request.order_id}"
                f"\n• 标的: {symbol}"
                f"\n• 状态: {status}"
                f"\n• 方向: {side}"
                f"\n• 数量: {quantity} (已成交: {executed_quantity})"
                f"\n• 价格: ${price}"
            )
        else:
            message = f"未找到订单 {request.order_id} 的详情"

        return format_trade_response(found, message, order_detail)

    except Exception as e:
        return format_trade_response(False, "获取订单详情失败", error=str(e))
@router.post("/account-balance", response_model=TradeTestResponse)
async def test_account_balance(
    request: GetAccountBalanceRequest, current_user: User = Depends(get_current_user)
):
    """Fetch account balances through the trade adapter and summarise them.

    The message lists the first two balances and the count of the rest.
    An empty or missing result yields ``success=False`` with a "nothing
    found" message. Any exception is caught and reported as a
    ``success=False`` response carrying the exception text.
    """
    preview_limit = 2
    try:
        adapter = TradeDataSourceAdapter(current_user.id)
        balances = adapter.get_account_balance(currency=request.currency)

        if balances:
            lines = [f"成功获取 {len(balances)} 个账户的余额信息"]
            for balance in balances[:preview_limit]:
                currency = balance.get("currency", "N/A")
                total_cash = balance.get("total_cash", 0)
                net_assets = balance.get("net_assets", 0)
                buy_power = balance.get("buy_power", 0)
                lines.append(
                    f"• {currency}: 现金${total_cash} 净资产${net_assets}"
                    f" 购买力${buy_power}"
                )
            if len(balances) > preview_limit:
                lines.append(f"... 还有 {len(balances) - preview_limit} 个账户")
            message = "\n".join(lines)
        else:
            message = "未获取到账户余额信息"

        # bool() rather than len() > 0: a None result must take the
        # "nothing found" path instead of raising TypeError.
        return format_trade_response(bool(balances), message, balances)

    except Exception as e:
        return format_trade_response(False, "获取账户余额失败", error=str(e))
@router.post("/estimate-max-purchase", response_model=TradeTestResponse)
async def test_estimate_max_purchase_quantity(
    request: EstimateMaxPurchaseQuantityRequest,
    current_user: User = Depends(get_current_user),
):
    """Estimate the maximum purchasable quantity through the trade adapter.

    The response is successful only when the adapter returns a non-empty
    result; the same check drives both the ``success`` flag and the message
    so the two can never disagree. Any exception is caught and reported as a
    ``success=False`` response carrying the exception text.
    """
    try:
        adapter = TradeDataSourceAdapter(current_user.id)
        result = adapter.estimate_max_purchase_quantity(
            symbol=request.symbol,
            order_type=request.order_type,
            side=request.side,
            price=request.price,
            currency=request.currency,
            order_id=request.order_id,
            fractional_shares=request.fractional_shares,
        )

        # One truthiness check for flag and message: previously an empty
        # (falsy but non-None) result produced success=True together with
        # the failure message.
        succeeded = bool(result)
        if succeeded:
            cash_max_qty = result.get("cash_max_qty", 0)
            margin_max_qty = result.get("margin_max_qty", 0)
            message = (
                "成功估算最大购买数量"
                f"\n• 标的: {request.symbol}"
                f"\n• 方向: {request.side}"
                f"\n• 现金购买力: {cash_max_qty} 股"
                f"\n• 融资购买力: {margin_max_qty} 股"
            )
        else:
            message = f"估算 {request.symbol} 最大购买数量失败"

        return format_trade_response(succeeded, message, result)

    except Exception as e:
        return format_trade_response(False, "估算最大购买数量失败", error=str(e))