Coverage for api/v1/endpoints/trade_test.py: 70.30%

330 statements  

« prev     ^ index     » next       coverage.py v7.10.7, created at 2025-10-13 18:58 +0000

1""" 

2交易测试API端点 

3提供长桥OpenAPI交易功能的测试接口 

4""" 

5 

6from datetime import date, datetime 

7from decimal import Decimal 

8from typing import Any, Dict, List, Optional, Union 

9 

10from fastapi import APIRouter, Depends, HTTPException 

11from pydantic import BaseModel, Field, field_validator 

12 

13from core.data_source.adapters.trade_adapter import TradeDataSourceAdapter 

14from core.middleware.auth_middleware import get_current_user 

15from core.models.user import User 

16 

# Router that every trade-test endpoint in this module is registered on.
router = APIRouter(
    prefix="/api-test/trade",
    tags=["交易测试"],
)

18 

19 

# Request model definitions
class SubmitOrderRequest(BaseModel):
    """Request body for ``POST /submit-order``.

    Numeric fields accept floats, ints or numeric strings and are normalised
    to ``Decimal`` by ``validate_decimal_fields``. ``expire_date`` accepts a
    ``date``, a ``datetime`` or an ISO-8601 date/datetime string and is
    reduced to a ``date`` by ``validate_expire_date``.
    """

    symbol: str = Field(..., description="股票代码", example="YINN.US")
    order_type: str = Field(..., description="订单类型", example="LO")
    side: str = Field(..., description="买卖方向", example="Buy")
    submitted_quantity: Union[float, int, str] = Field(
        ..., description="委托数量", example=100
    )
    time_in_force: str = Field(..., description="订单有效期", example="Day")
    submitted_price: Optional[Union[float, int, str]] = Field(
        None, description="委托价格", example=50.0
    )
    trigger_price: Optional[Union[float, int, str]] = Field(
        None, description="触发价格", example=45.0
    )
    limit_offset: Optional[Union[float, int, str]] = Field(
        None, description="限价偏移", example=0.1
    )
    trailing_amount: Optional[Union[float, int, str]] = Field(
        None, description="追踪金额", example=0.5
    )
    trailing_percent: Optional[Union[float, int, str]] = Field(
        None, description="追踪百分比", example=0.01
    )
    expire_date: Optional[Union[date, str]] = Field(
        None, description="过期日期", example=date(2025, 1, 1)
    )
    outside_rth: Optional[str] = Field(
        None, description="盘前盘后交易", example="RTH_ONLY"
    )
    remark: Optional[str] = Field(None, description="备注", example="测试订单")

    @field_validator(
        "submitted_quantity",
        "submitted_price",
        "trigger_price",
        "limit_offset",
        "trailing_amount",
        "trailing_percent",
    )
    @classmethod
    def validate_decimal_fields(cls, v):
        """Convert a numeric field to ``Decimal``; ``None`` passes through.

        The value goes through ``str(v)`` so a float such as ``0.1`` becomes
        ``Decimal("0.1")`` rather than its exact binary expansion.

        Raises:
            ValueError: if ``v`` is not numeric or is NaN/Infinity. ``decimal``
                signals bad strings with ``InvalidOperation``, which is not a
                ``ValueError``. Pydantic only converts ``ValueError`` /
                ``AssertionError`` into validation errors, so the re-raise is
                needed for bad input to be reported as a request-validation
                error instead of escaping as an unhandled exception.
        """
        if v is None:
            return v
        from decimal import InvalidOperation  # only this validator needs it

        try:
            value = Decimal(str(v))
        except InvalidOperation:
            raise ValueError(f"Invalid numeric value: {v!r}") from None
        # Prices/quantities of NaN or Infinity are never meaningful for an order.
        if not value.is_finite():
            raise ValueError(f"Numeric value must be finite: {v!r}")
        return value

    @field_validator("expire_date", mode="before")
    @classmethod
    def validate_expire_date(cls, v):
        """Reduce raw ``expire_date`` input to a ``date`` before type validation.

        - Strings: anything from ``"T"`` onwards (the time part) is dropped and
          the rest is parsed with ``date.fromisoformat``. Malformed input raises
          ``ValueError``, which pydantic reports as a validation error.
        - Objects exposing ``.date()`` (e.g. ``datetime``) are converted with it.
        - Anything else, including ``None``, is returned unchanged.
        """
        if v is not None:
            if isinstance(v, str):
                # ISO datetime string: keep only the date part
                if "T" in v:
                    date_part = v.split("T")[0]
                    return date.fromisoformat(date_part)
                else:
                    return date.fromisoformat(v)
            elif hasattr(v, "date"):
                # datetime object -> its calendar date
                return v.date()
        return v

84 

85 

class ReplaceOrderRequest(BaseModel):
    """Request body for ``POST /replace-order`` (modify an existing order).

    ``quantity``, ``price`` and ``trigger_price`` accept floats, ints or
    numeric strings and are normalised to ``Decimal``.
    """

    order_id: str = Field(..., description="订单ID", example="709043056541253632")
    quantity: Union[float, int, str] = Field(..., description="修改数量", example=150)
    price: Optional[Union[float, int, str]] = Field(
        None, description="修改价格", example=55.0
    )
    trigger_price: Optional[Union[float, int, str]] = Field(
        None, description="修改触发价格", example=50.0
    )
    remark: Optional[str] = Field(None, description="备注", example="修改订单")

    @field_validator("quantity", "price", "trigger_price")
    @classmethod
    def validate_decimal_fields(cls, v):
        """Convert a numeric field to ``Decimal``; ``None`` passes through.

        Raises:
            ValueError: if ``v`` is not numeric or is NaN/Infinity. The
                ``decimal`` module's ``InvalidOperation`` is re-raised as
                ``ValueError`` because pydantic only reports ``ValueError`` /
                ``AssertionError`` as validation errors.
        """
        if v is None:
            return v
        from decimal import InvalidOperation  # only this validator needs it

        try:
            value = Decimal(str(v))
        except InvalidOperation:
            raise ValueError(f"Invalid numeric value: {v!r}") from None
        if not value.is_finite():
            raise ValueError(f"Numeric value must be finite: {v!r}")
        return value

105 

106 

class CancelOrderRequest(BaseModel):
    """Request body for ``POST /cancel-order``: the order to cancel."""

    order_id: str = Field(
        default=...,
        description="订单ID",
        example="709043056541253632",
    )

111 

112 

class GetOrdersRequest(BaseModel):
    """Optional filters for querying today's orders.

    ``status`` may be given as a list or as a comma-separated string; a string
    is split on commas, each part stripped, and empty parts dropped.
    """

    symbol: Optional[str] = Field(
        default=None, description="股票代码", example="YINN.US"
    )
    status: Optional[Union[str, List[str]]] = Field(
        default=None, description="订单状态", example=["New", "Filled"]
    )
    side: Optional[str] = Field(default=None, description="买卖方向", example="Buy")
    market: Optional[str] = Field(default=None, description="市场", example="US")
    order_id: Optional[str] = Field(
        default=None, description="订单ID", example="709043056541253632"
    )

    @field_validator("status")
    @classmethod
    def validate_status(cls, v):
        """Turn a comma-separated status string into a list; pass others through."""
        if not isinstance(v, str):
            return v
        pieces = (piece.strip() for piece in v.split(","))
        return [piece for piece in pieces if piece]

132 

133 

class GetHistoryOrdersRequest(BaseModel):
    """Optional filters for querying historical orders.

    ``status`` may be a list or a comma-separated string. ``start_at`` and
    ``end_at`` accept anything pydantic can parse as a ``date`` or
    ``datetime`` (e.g. ISO-8601 strings) and are normalised to ``datetime``.
    """

    symbol: Optional[str] = Field(None, description="股票代码", example="YINN.US")
    status: Optional[Union[str, List[str]]] = Field(
        None, description="订单状态", example=["Filled", "Canceled"]
    )
    side: Optional[str] = Field(None, description="买卖方向", example="Buy")
    market: Optional[str] = Field(None, description="市场", example="US")
    start_at: Optional[Union[date, datetime]] = Field(None, description="开始时间")
    end_at: Optional[Union[date, datetime]] = Field(None, description="结束时间")

    @field_validator("status")
    @classmethod
    def validate_status(cls, v):
        """Split a comma-separated status string into a stripped, non-empty list."""
        if isinstance(v, str):
            return [s.strip() for s in v.split(",") if s.strip()]
        return v

    @field_validator("start_at", "end_at")
    @classmethod
    def validate_datetime(cls, v):
        """Normalise ``start_at`` / ``end_at`` to ``datetime``.

        This is an "after" validator (pydantic's default mode), so pydantic has
        already parsed the input into ``date``/``datetime`` (or ``None``).
        String inputs never reach this function, which is why it has no string
        handling. A bare ``date`` becomes midnight of that day; a ``datetime``
        is returned unchanged.

        NOTE(review): promoted dates are naive while parsed datetimes may be
        timezone-aware; confirm the adapter accepts both.
        """
        # datetime subclasses date, so it must be tested first.
        if isinstance(v, datetime):
            return v
        if isinstance(v, date):
            return datetime.combine(v, datetime.min.time())
        return v

178 

179 

class GetExecutionsRequest(BaseModel):
    """Optional filters (symbol and/or order) for today's trade executions."""

    symbol: Optional[str] = Field(
        default=None,
        description="股票代码",
        example="YINN.US",
    )
    order_id: Optional[str] = Field(
        default=None,
        description="订单ID",
        example="709043056541253632",
    )

187 

188 

class GetHistoryExecutionsRequest(BaseModel):
    """Optional filters for querying historical trade executions.

    ``start_at`` and ``end_at`` accept anything pydantic can parse as a
    ``date`` or ``datetime`` (e.g. ISO-8601 strings) and are normalised to
    ``datetime``.
    """

    symbol: Optional[str] = Field(None, description="股票代码", example="YINN.US")
    start_at: Optional[Union[date, datetime]] = Field(None, description="开始时间")
    end_at: Optional[Union[date, datetime]] = Field(None, description="结束时间")

    @field_validator("start_at", "end_at")
    @classmethod
    def validate_datetime(cls, v):
        """Normalise ``start_at`` / ``end_at`` to ``datetime``.

        This is an "after" validator (pydantic's default mode), so the value is
        already a ``date``/``datetime`` (or ``None``). The former string-parsing
        branch was unreachable and has been removed. A bare ``date`` becomes
        midnight of that day; a ``datetime`` is returned unchanged.

        NOTE(review): promoted dates are naive while parsed datetimes may be
        timezone-aware; confirm the adapter accepts both.
        """
        # datetime subclasses date, so it must be tested first.
        if isinstance(v, datetime):
            return v
        if isinstance(v, date):
            return datetime.combine(v, datetime.min.time())
        return v

221 

222 

class GetOrderDetailRequest(BaseModel):
    """Request body for ``POST /order-detail``: the order to look up."""

    order_id: str = Field(
        default=...,
        description="订单ID",
        example="709043056541253632",
    )

227 

228 

class GetAccountBalanceRequest(BaseModel):
    """Request body for ``POST /account-balance``; ``currency`` is optional."""

    currency: Optional[str] = Field(
        default=None,
        description="货币",
        example="USD",
    )

233 

234 

class EstimateMaxPurchaseQuantityRequest(BaseModel):
    """Request body for ``POST /estimate-max-purchase``.

    ``price`` accepts a float, int or numeric string and is normalised to
    ``Decimal``; ``order_id`` is only used when estimating for an order
    modification.
    """

    symbol: str = Field(..., description="股票代码", example="YINN.US")
    order_type: str = Field(..., description="订单类型", example="LO")
    side: str = Field(..., description="买卖方向", example="Buy")
    price: Optional[Union[float, int, str]] = Field(
        None, description="估算价格", example=50.0
    )
    currency: Optional[str] = Field(None, description="货币", example="USD")
    order_id: Optional[str] = Field(
        None, description="订单ID(修改订单时使用)", example="709043056541253632"
    )
    fractional_shares: bool = Field(False, description="是否支持碎股", example=False)

    @field_validator("price")
    @classmethod
    def validate_price(cls, v):
        """Convert ``price`` to ``Decimal``; ``None`` passes through.

        Raises:
            ValueError: if ``price`` is not numeric or is NaN/Infinity. The
                ``decimal`` module's ``InvalidOperation`` is re-raised as
                ``ValueError`` because pydantic only reports ``ValueError`` /
                ``AssertionError`` as validation errors.
        """
        if v is None:
            return v
        from decimal import InvalidOperation  # only this validator needs it

        try:
            value = Decimal(str(v))
        except InvalidOperation:
            raise ValueError(f"Invalid numeric value: {v!r}") from None
        if not value.is_finite():
            raise ValueError(f"Numeric value must be finite: {v!r}")
        return value

256 

257 

# Response models
class TradeTestResponse(BaseModel):
    """Uniform envelope returned by every trade-test endpoint.

    ``message`` is a human-readable summary, ``data`` the raw payload from the
    adapter (any type), and ``error`` the exception text when an operation
    failed.
    """

    success: bool = Field(description="是否成功")
    message: str = Field(description="人类可读的结果描述")
    data: Optional[Any] = Field(default=None, description="原始数据")
    error: Optional[str] = Field(default=None, description="错误信息")

266 

267 

def format_trade_response(
    success: bool, message: str, data: Any = None, error: Optional[str] = None
) -> TradeTestResponse:
    """Build the ``TradeTestResponse`` envelope returned by every endpoint.

    Args:
        success: Whether the operation succeeded. ``TradeTestResponse.success``
            is a ``bool`` field, so values pydantic cannot read as a boolean
            (e.g. ``None``) fail model validation.
        message: Human-readable summary of the result.
        data: Raw adapter payload, passed through unchanged.
        error: Exception text on failure; ``None`` otherwise.

    Returns:
        The populated ``TradeTestResponse``.
    """
    return TradeTestResponse(success=success, message=message, data=data, error=error)

273 

274 

# Trade-test API endpoints


@router.get("/supported-order-types/{market}", response_model=TradeTestResponse)
async def get_supported_order_types(
    market: str, current_user: User = Depends(get_current_user)
):
    """Return the order types the adapter reports as supported for ``market``.

    Any exception is turned into an unsuccessful ``TradeTestResponse`` that
    carries the exception text, rather than being raised.
    """
    try:
        supported = TradeDataSourceAdapter(current_user.id).get_supported_order_types(
            market
        )
        return format_trade_response(True, f"市场 {market} 支持的订单类型", supported)
    except Exception as exc:
        return format_trade_response(False, f"获取订单类型失败: {exc}", error=str(exc))

291 

292 

@router.get("/order-type-fields/{order_type}", response_model=TradeTestResponse)
async def get_order_type_required_fields(
    order_type: str, current_user: User = Depends(get_current_user)
):
    """Return the fields the adapter reports as required for ``order_type``.

    Failures come back as an unsuccessful ``TradeTestResponse`` carrying the
    exception text; nothing is raised to the caller.
    """
    try:
        required = TradeDataSourceAdapter(
            current_user.id
        ).get_order_type_required_fields(order_type)
        return format_trade_response(
            True, f"订单类型 {order_type} 的必填字段", required
        )
    except Exception as exc:
        return format_trade_response(False, f"获取必填字段失败: {exc}", error=str(exc))

306 

307 

@router.post("/submit-order", response_model=TradeTestResponse)
async def test_submit_order(
    request: SubmitOrderRequest, current_user: User = Depends(get_current_user)
):
    """Submit an order through the trade adapter and summarise the outcome.

    ``success`` is true only when the adapter returns a non-empty result; the
    raw result is returned as ``data``. Exceptions become an unsuccessful
    response instead of propagating.
    """
    try:
        adapter = TradeDataSourceAdapter(current_user.id)
        result = adapter.submit_order(
            symbol=request.symbol,
            order_type=request.order_type,
            side=request.side,
            submitted_quantity=request.submitted_quantity,
            time_in_force=request.time_in_force,
            submitted_price=request.submitted_price,
            trigger_price=request.trigger_price,
            limit_offset=request.limit_offset,
            trailing_amount=request.trailing_amount,
            trailing_percent=request.trailing_percent,
            expire_date=request.expire_date,
            outside_rth=request.outside_rth,
            remark=request.remark,
        )

        # One flag drives both the message and `success`, so an empty result
        # can no longer yield a "failed" message together with success=True.
        succeeded = bool(result)
        if succeeded:
            # assumes the adapter returns a mapping on success — TODO confirm
            order_id = result.get("order_id", "N/A")
            # Only a missing price means "market"; Decimal("0") is a real value.
            price_text = (
                request.submitted_price
                if request.submitted_price is not None
                else "市价"
            )
            message = (
                f"委托下单成功\n• 订单ID: {order_id}\n• 标的: {request.symbol}"
                f"\n• 方向: {request.side}\n• 数量: {request.submitted_quantity}"
                f"\n• 价格: {price_text}"
            )
        else:
            message = "委托下单失败"

        return format_trade_response(succeeded, message, result)

    except Exception as e:
        return format_trade_response(False, "委托下单失败", error=str(e))

341 

342 

@router.post("/replace-order", response_model=TradeTestResponse)
async def test_replace_order(
    request: ReplaceOrderRequest, current_user: User = Depends(get_current_user)
):
    """Modify an existing order through the trade adapter.

    The adapter's return value is interpreted by truthiness. Exceptions become
    an unsuccessful response instead of propagating.
    """
    try:
        adapter = TradeDataSourceAdapter(current_user.id)
        outcome = adapter.replace_order(
            order_id=request.order_id,
            quantity=request.quantity,
            price=request.price,
            trigger_price=request.trigger_price,
            remark=request.remark,
        )
        # TradeTestResponse.success is a bool field: a None (or other non-bool)
        # adapter return would fail model validation and be reported as a
        # confusing exception, so coerce it explicitly.
        success = bool(outcome)

        # Only a missing price means "unchanged"; Decimal("0") is a real value.
        price_text = request.price if request.price is not None else "不变"
        message = (
            f"改单{'成功' if success else '失败'}\n• 订单ID: {request.order_id}"
            f"\n• 新数量: {request.quantity}\n• 新价格: {price_text}"
        )

        return format_trade_response(
            success, message, {"order_id": request.order_id, "success": success}
        )

    except Exception as e:
        return format_trade_response(False, "改单失败", error=str(e))

366 

367 

@router.post("/cancel-order", response_model=TradeTestResponse)
async def test_cancel_order(
    request: CancelOrderRequest, current_user: User = Depends(get_current_user)
):
    """Cancel an order through the trade adapter.

    The adapter's return value is interpreted by truthiness. Exceptions become
    an unsuccessful response instead of propagating.
    """
    try:
        adapter = TradeDataSourceAdapter(current_user.id)
        # TradeTestResponse.success is a bool field: a None (or other non-bool)
        # adapter return would fail model validation, so coerce it explicitly.
        success = bool(adapter.cancel_order(request.order_id))

        message = f"撤单{'成功' if success else '失败'}\n• 订单ID: {request.order_id}"

        return format_trade_response(
            success, message, {"order_id": request.order_id, "success": success}
        )

    except Exception as e:
        return format_trade_response(False, "撤单失败", error=str(e))

385 

386 

@router.post("/today-orders", response_model=TradeTestResponse)
async def test_today_orders(
    request: GetOrdersRequest, current_user: User = Depends(get_current_user)
):
    """Query today's orders through the trade adapter.

    The message previews at most the first three orders; ``data`` carries the
    raw adapter result. ``success`` is true only when at least one order is
    returned. Exceptions become an unsuccessful response.
    """
    try:
        adapter = TradeDataSourceAdapter(current_user.id)
        orders = adapter.get_today_orders(
            symbol=request.symbol,
            status=request.status,
            side=request.side,
            market=request.market,
            order_id=request.order_id,
        )

        if orders:
            message = f"成功获取 {len(orders)} 条当日订单"
            for order in orders[:3]:  # preview the first three only
                # assumes each order is a mapping — TODO confirm adapter contract
                order_id = order.get("order_id", "N/A")
                symbol = order.get("symbol", "N/A")
                status = order.get("status", "N/A")
                side = order.get("side", "N/A")
                quantity = order.get("quantity", 0)
                message += f"\n• {order_id}: {symbol} {side} {quantity}股 ({status})"
            if len(orders) > 3:
                message += f"\n... 还有 {len(orders) - 3} 条订单"
        else:
            message = "未获取到当日订单"

        # bool() instead of len() > 0: a None result takes the "no orders"
        # branch above, and len(None) would raise TypeError, turning an empty
        # result into a misleading failure response.
        return format_trade_response(bool(orders), message, orders)

    except Exception as e:
        return format_trade_response(False, "获取当日订单失败", error=str(e))

420 

421 

@router.post("/history-orders", response_model=TradeTestResponse)
async def test_history_orders(
    request: GetHistoryOrdersRequest, current_user: User = Depends(get_current_user)
):
    """Query historical orders through the trade adapter.

    The message previews at most the first three orders (with submit time);
    ``data`` carries the raw adapter result. ``success`` is true only when at
    least one order is returned. Exceptions become an unsuccessful response.
    """
    try:
        adapter = TradeDataSourceAdapter(current_user.id)
        orders = adapter.get_history_orders(
            symbol=request.symbol,
            status=request.status,
            side=request.side,
            market=request.market,
            start_at=request.start_at,
            end_at=request.end_at,
        )

        if orders:
            message = f"成功获取 {len(orders)} 条历史订单"
            for order in orders[:3]:  # preview the first three only
                # assumes each order is a mapping — TODO confirm adapter contract
                order_id = order.get("order_id", "N/A")
                symbol = order.get("symbol", "N/A")
                status = order.get("status", "N/A")
                side = order.get("side", "N/A")
                quantity = order.get("quantity", 0)
                submitted_at = order.get("submitted_at", "N/A")
                message += f"\n• {order_id}: {symbol} {side} {quantity}股 ({status}) - {submitted_at}"
            if len(orders) > 3:
                message += f"\n... 还有 {len(orders) - 3} 条订单"
        else:
            message = "未获取到历史订单"

        # bool() instead of len() > 0 so a None result does not raise TypeError.
        return format_trade_response(bool(orders), message, orders)

    except Exception as e:
        return format_trade_response(False, "获取历史订单失败", error=str(e))

457 

458 

@router.post("/today-executions", response_model=TradeTestResponse)
async def test_today_executions(
    request: GetExecutionsRequest, current_user: User = Depends(get_current_user)
):
    """Query today's trade executions through the trade adapter.

    The message previews at most the first three executions; ``data`` carries
    the raw adapter result. ``success`` is true only when at least one
    execution is returned. Exceptions become an unsuccessful response.
    """
    try:
        adapter = TradeDataSourceAdapter(current_user.id)
        executions = adapter.get_today_executions(
            symbol=request.symbol, order_id=request.order_id
        )

        if executions:
            message = f"成功获取 {len(executions)} 条当日成交明细"
            for execution in executions[:3]:  # preview the first three only
                # assumes each execution is a mapping — TODO confirm
                trade_id = execution.get("trade_id", "N/A")
                symbol = execution.get("symbol", "N/A")
                quantity = execution.get("quantity", 0)
                price = execution.get("price", 0)
                trade_done_at = execution.get("trade_done_at", "N/A")
                message += (
                    f"\n• {trade_id}: {symbol} {quantity}股 @${price} ({trade_done_at})"
                )
            if len(executions) > 3:
                message += f"\n... 还有 {len(executions) - 3} 条成交"
        else:
            message = "未获取到当日成交明细"

        # bool() instead of len() > 0 so a None result does not raise TypeError.
        return format_trade_response(bool(executions), message, executions)

    except Exception as e:
        return format_trade_response(False, "获取当日成交明细失败", error=str(e))

490 

491 

@router.post("/history-executions", response_model=TradeTestResponse)
async def test_history_executions(
    request: GetHistoryExecutionsRequest, current_user: User = Depends(get_current_user)
):
    """Query historical trade executions through the trade adapter.

    The message previews at most the first three executions; ``data`` carries
    the raw adapter result. ``success`` is true only when at least one
    execution is returned. Exceptions become an unsuccessful response.
    """
    try:
        adapter = TradeDataSourceAdapter(current_user.id)
        executions = adapter.get_history_executions(
            symbol=request.symbol, start_at=request.start_at, end_at=request.end_at
        )

        if executions:
            message = f"成功获取 {len(executions)} 条历史成交明细"
            for execution in executions[:3]:  # preview the first three only
                # assumes each execution is a mapping — TODO confirm
                trade_id = execution.get("trade_id", "N/A")
                symbol = execution.get("symbol", "N/A")
                quantity = execution.get("quantity", 0)
                price = execution.get("price", 0)
                trade_done_at = execution.get("trade_done_at", "N/A")
                message += (
                    f"\n• {trade_id}: {symbol} {quantity}股 @${price} ({trade_done_at})"
                )
            if len(executions) > 3:
                message += f"\n... 还有 {len(executions) - 3} 条成交"
        else:
            message = "未获取到历史成交明细"

        # bool() instead of len() > 0 so a None result does not raise TypeError.
        return format_trade_response(bool(executions), message, executions)

    except Exception as e:
        return format_trade_response(False, "获取历史成交明细失败", error=str(e))

523 

524 

@router.post("/order-detail", response_model=TradeTestResponse)
async def test_order_detail(
    request: GetOrderDetailRequest, current_user: User = Depends(get_current_user)
):
    """Fetch one order's details through the trade adapter.

    ``success`` is true only when the adapter returns a non-empty detail; the
    raw detail is returned as ``data``. Exceptions become an unsuccessful
    response.
    """
    try:
        adapter = TradeDataSourceAdapter(current_user.id)
        order_detail = adapter.get_order_detail(request.order_id)

        # One flag drives both the message and `success` so they cannot
        # disagree (an empty detail used to say "not found" with success=True).
        found = bool(order_detail)
        if found:
            # assumes the detail is a mapping — TODO confirm adapter contract
            symbol = order_detail.get("symbol", "N/A")
            status = order_detail.get("status", "N/A")
            side = order_detail.get("side", "N/A")
            quantity = order_detail.get("quantity", 0)
            executed_quantity = order_detail.get("executed_quantity", 0)
            price = order_detail.get("price", 0)
            message = f"成功获取订单详情\n• 订单ID: {request.order_id}\n• 标的: {symbol}\n• 状态: {status}\n• 方向: {side}\n• 数量: {quantity} (已成交: {executed_quantity})\n• 价格: ${price}"
        else:
            message = f"未找到订单 {request.order_id} 的详情"

        return format_trade_response(found, message, order_detail)

    except Exception as e:
        return format_trade_response(False, "获取订单详情失败", error=str(e))

549 

550 

@router.post("/account-balance", response_model=TradeTestResponse)
async def test_account_balance(
    request: GetAccountBalanceRequest, current_user: User = Depends(get_current_user)
):
    """Fetch account balances through the trade adapter.

    The message previews at most the first two balance entries; ``data``
    carries the raw adapter result. ``success`` is true only when at least one
    entry is returned. Exceptions become an unsuccessful response.
    """
    try:
        adapter = TradeDataSourceAdapter(current_user.id)
        balances = adapter.get_account_balance(currency=request.currency)

        if balances:
            message = f"成功获取 {len(balances)} 个账户的余额信息"
            for balance in balances[:2]:  # preview the first two only
                # assumes each entry is a mapping — TODO confirm adapter contract
                currency = balance.get("currency", "N/A")
                total_cash = balance.get("total_cash", 0)
                net_assets = balance.get("net_assets", 0)
                buy_power = balance.get("buy_power", 0)
                message += f"\n• {currency}: 现金${total_cash} 净资产${net_assets} 购买力${buy_power}"
            if len(balances) > 2:
                message += f"\n... 还有 {len(balances) - 2} 个账户"
        else:
            message = "未获取到账户余额信息"

        # bool() instead of len() > 0 so a None result does not raise TypeError.
        return format_trade_response(bool(balances), message, balances)

    except Exception as e:
        return format_trade_response(False, "获取账户余额失败", error=str(e))

577 

578 

@router.post("/estimate-max-purchase", response_model=TradeTestResponse)
async def test_estimate_max_purchase_quantity(
    request: EstimateMaxPurchaseQuantityRequest,
    current_user: User = Depends(get_current_user),
):
    """Estimate the maximum purchasable quantity through the trade adapter.

    ``success`` is true only when the adapter returns a non-empty estimate; the
    raw estimate is returned as ``data``. Exceptions become an unsuccessful
    response.
    """
    try:
        adapter = TradeDataSourceAdapter(current_user.id)
        result = adapter.estimate_max_purchase_quantity(
            symbol=request.symbol,
            order_type=request.order_type,
            side=request.side,
            price=request.price,
            currency=request.currency,
            order_id=request.order_id,
            fractional_shares=request.fractional_shares,
        )

        # One flag drives both the message and `success` so they cannot
        # disagree (an empty estimate used to say "failed" with success=True).
        estimated = bool(result)
        if estimated:
            # assumes the estimate is a mapping — TODO confirm adapter contract
            cash_max_qty = result.get("cash_max_qty", 0)
            margin_max_qty = result.get("margin_max_qty", 0)
            message = f"成功估算最大购买数量\n• 标的: {request.symbol}\n• 方向: {request.side}\n• 现金购买力: {cash_max_qty} 股\n• 融资购买力: {margin_max_qty}"
        else:
            message = f"估算 {request.symbol} 最大购买数量失败"

        return format_trade_response(estimated, message, result)

    except Exception as e:
        return format_trade_response(False, "估算最大购买数量失败", error=str(e))