Coverage for api/v1/endpoints/stock_import.py: 77.86%
140 statements
« prev ^ index » next coverage.py v7.10.7, created at 2025-10-13 18:58 +0000
« prev ^ index » next coverage.py v7.10.7, created at 2025-10-13 18:58 +0000
1"""
2股票数据导入API端点
3"""
5import asyncio
6import logging
7from datetime import date
8from typing import Optional
10from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException
11from pydantic import BaseModel, Field
13from core.data_source.adapters.data_adapter import DataAdapter
14from core.middleware.auth_middleware import get_current_user
15from core.models.user import User
16from core.repositories.stock_repository import StockRepository
17from infrastructure.database.redis_client import get_redis
19router = APIRouter()
20logger = logging.getLogger(__name__)
class StockImportRequest(BaseModel):
    """Request body for launching a stock-data import task."""

    # Ticker symbol, e.g. "YINN.US".
    symbol: str = Field(..., description="股票代码", example="YINN.US")
    # Range start as an ISO-8601 datetime string (offset-aware in the example).
    start_date: str = Field(
        ..., description="开始日期(ISO格式)", example="2025-09-16T00:00:00-04:00"
    )
    # Range end as an ISO-8601 datetime string.
    end_date: str = Field(
        ..., description="结束日期(ISO格式)", example="2025-09-16T23:59:00-04:00"
    )
    # IANA timezone name forwarded to the data adapter.
    timezone: str = Field(default="America/New_York", description="时区")
class StockImportResponse(BaseModel):
    """Response returned when an import task has been started."""

    # Id used to poll status and to join the task's WebSocket log channel.
    task_id: str = Field(..., description="任务ID")
    message: str = Field(..., description="响应消息")
class StockImportStatus(BaseModel):
    """Status snapshot of an import task (mirrors the task_status entries)."""

    task_id: str = Field(..., description="任务ID")
    # One of: "running", "completed", "failed", "cancelled".
    status: str = Field(..., description="任务状态")
    # Progress percentage, 0-100.
    progress: int = Field(..., description="进度百分比")
    message: str = Field(..., description="状态消息")
    # Rows written to the database so far.
    data_count: int = Field(default=0, description="已获取数据条数")
    # Rows that overwrote existing records.
    overwritten_count: int = Field(default=0, description="覆盖的数据条数")
    # Rows newly inserted.
    new_count: int = Field(default=0, description="新增的数据条数")
    # Rows dropped by the time-range filter.
    filtered_count: int = Field(default=0, description="过滤掉的数据条数")
    # Average rows per day over the requested range.
    avg_daily_count: int = Field(default=0, description="平均每天数据条数")
# In-memory store of task state keyed by task_id (production should use Redis).
task_status = {}
def get_data_adapter(user_id: str) -> DataAdapter:
    """Create a data-fetch business adapter bound to the given user."""
    adapter = DataAdapter(user_id)
    return adapter
async def fetch_stock_data_task(
    task_id: str,
    user_id: str,
    symbol: str,
    start_date: str,
    end_date: str,
    timezone: str = "America/New_York",
):
    """Background task: fetch stock data for *symbol* and track progress.

    Args:
        task_id: Key into the module-level ``task_status`` dict and the
            WebSocket log channel.
        user_id: Owner of the data-fetch adapter instance.
        symbol: Ticker symbol, e.g. "YINN.US".
        start_date: Range start as an ISO-8601 string (trailing "Z" accepted).
        end_date: Range end as an ISO-8601 string.
        timezone: IANA timezone name passed through to the adapter.

    Every outcome (success, failure, exception) is recorded in
    ``task_status[task_id]``; nothing is returned.

    Note: the dates were previously annotated as ``datetime.date`` although
    callers pass ISO strings and the body uses ``str.replace`` on them — the
    annotations are now ``str`` to match actual behavior.
    """
    # Hoisted out of the per-message callback so the lookups happen once.
    import re
    from datetime import datetime

    try:
        # Mark the task as running before any work starts.
        task_status[task_id] = {
            "status": "running",
            "progress": 0,
            "message": f"开始获取 {symbol} 数据,时间范围: {start_date} 到 {end_date}",
            "data_count": 0,
            "overwritten_count": 0,
            "new_count": 0,
            "filtered_count": 0,
            "avg_daily_count": 0,
        }

        # Imported lazily — presumably to avoid a circular import at module
        # load time (TODO confirm against core.services.websocket_service).
        from core.services.websocket_service import (send_websocket_log,
                                                     send_websocket_status,
                                                     websocket_service)

        # Initial log lines for the task's log channel.
        send_websocket_log(
            task_id,
            f"开始获取 {symbol} 数据,时间范围: {start_date} 到 {end_date}",
            "log",
        )
        send_websocket_log(task_id, f"任务ID: {task_id}", "log")
        send_websocket_log(task_id, f"时区: {timezone}", "log")

        # Give the client up to 5 seconds (50 * 0.1s) to open the WebSocket.
        wait_count = 0
        while not websocket_service.is_connected(task_id) and wait_count < 50:
            await asyncio.sleep(0.1)
            wait_count += 1

        if websocket_service.is_connected(task_id):
            send_websocket_log(task_id, "WebSocket连接已建立,开始数据获取", "log")
        else:
            # Proceed anyway — status polling still works without the socket.
            send_websocket_log(task_id, "WebSocket连接未建立,但继续执行任务", "log")

        # Per-user business adapter that performs the actual fetch.
        adapter = get_data_adapter(user_id)

        # Fail fast on malformed date strings (a trailing "Z" is normalized
        # to "+00:00") before handing them to the adapter.
        datetime.fromisoformat(start_date.replace("Z", "+00:00"))
        datetime.fromisoformat(end_date.replace("Z", "+00:00"))

        def progress_callback(message: str):
            """Relay an adapter progress message and update task statistics."""
            print(f"[进度回调] {message}")
            logger.info(f"[进度回调] {message}")

            # Forward every message over the unified WebSocket log channel.
            send_websocket_log(task_id, message, "log")

            # Only statistics are kept in task_status; full logs go over the
            # WebSocket. Guard against the entry having been removed.
            if task_id not in task_status:
                return

            task_status[task_id]["message"] = message
            progress_updated = False

            # Day-level progress: "处理日期 YYYY-MM-DD (X/total)".
            date_match = re.search(
                r"处理日期 \d{4}-\d{2}-\d{2} \((\d+)/(\d+)\)", message
            )
            if date_match:
                current_batch = int(date_match.group(1))
                total_batches = int(date_match.group(2))
                # The fetch phase maps onto 0-85% of the progress bar; the
                # remaining 15% is reserved for the save/summary phase.
                task_status[task_id]["progress"] = int(
                    (current_batch / total_batches) * 85
                )
                progress_updated = True

            # Rows written to the database ("成功保存 N 条数据到数据库").
            if "成功保存" in message and "条数据到数据库" in message:
                saved_match = re.search(r"成功保存 (\d+) 条数据到数据库", message)
                if saved_match:
                    task_status[task_id]["data_count"] = int(saved_match.group(1))
                    # Data persisted — jump the bar to 90%.
                    task_status[task_id]["progress"] = 90
                    progress_updated = True

            # Overwrite / insert / filter breakdown ("数据统计: ...").
            if "数据统计:" in message:
                overwritten_match = re.search(r"覆盖 (\d+) 条", message)
                if overwritten_match:
                    task_status[task_id]["overwritten_count"] = int(
                        overwritten_match.group(1)
                    )
                    progress_updated = True

                new_match = re.search(r"新增 (\d+) 条", message)
                if new_match:
                    task_status[task_id]["new_count"] = int(new_match.group(1))

                filtered_match = re.search(r"过滤 (\d+) 条", message)
                if filtered_match:
                    task_status[task_id]["filtered_count"] = int(
                        filtered_match.group(1)
                    )

            # Average rows per day ("平均每天 N 条数据").
            if "平均每天" in message and "条数据" in message:
                avg_match = re.search(r"平均每天 (\d+) 条数据", message)
                if avg_match:
                    task_status[task_id]["avg_daily_count"] = int(
                        avg_match.group(1)
                    )

            # Push a status frame whenever progress or counts changed.
            if progress_updated:
                send_websocket_status(
                    task_id,
                    task_status[task_id]["status"],
                    task_status[task_id]["progress"],
                    task_status[task_id].get("data_count", 0),
                )

        # Run the actual fetch; progress flows back through the callback.
        result = await adapter.fetch_stock_data(
            symbol=symbol,
            start_date=start_date,
            end_date=end_date,
            timezone=timezone,
            progress_callback=progress_callback,
        )

        if result["success"]:
            overwritten_count = result.get("overwritten_count", 0)
            new_count = result.get("new_count", 0)
            filtered_count = result.get("filtered_count", 0)
            avg_daily_count = result.get("avg_daily_count", 0)

            # Compose a summary, omitting zero-valued counters.
            message_parts = [
                f"✅ 数据导入完成!共写入数据库 {result['data_count']} 条数据"
            ]
            if overwritten_count > 0:
                message_parts.append(f"覆盖已存在数据: {overwritten_count} 条")
            if new_count > 0:
                message_parts.append(f"新增数据: {new_count} 条")
            if filtered_count > 0:
                message_parts.append(f"因时间区间过滤: {filtered_count} 条")
            if avg_daily_count > 0:
                message_parts.append(f"平均每天: {avg_daily_count} 条")

            task_status[task_id] = {
                "status": "completed",
                "progress": 100,
                "message": " | ".join(message_parts),
                "data_count": result["data_count"],
                "overwritten_count": overwritten_count,
                "new_count": new_count,
                "filtered_count": filtered_count,
                "avg_daily_count": avg_daily_count,
            }
        else:
            task_status[task_id] = {
                "status": "failed",
                "progress": 0,
                "message": f"数据获取失败: {result.get('error', '未知错误')}",
                "data_count": 0,
                "overwritten_count": 0,
                "new_count": 0,
                "filtered_count": 0,
                "avg_daily_count": 0,
            }

    except Exception as e:
        # Any unexpected error (including date-format validation) lands here
        # and is surfaced through the task status.
        logger.error(f"股票数据获取任务失败: {e}")
        task_status[task_id] = {
            "status": "failed",
            "progress": 0,
            "message": f"任务执行失败: {str(e)}",
            "data_count": 0,
            "overwritten_count": 0,
            "new_count": 0,
            "filtered_count": 0,
            "avg_daily_count": 0,
        }
@router.post("/import", response_model=StockImportResponse)
async def import_stock_data(
    request: StockImportRequest,
    background_tasks: BackgroundTasks,
    current_user: User = Depends(get_current_user),
):
    """Start a background stock-data import and return its task id."""
    try:
        from uuid import uuid4

        # Fresh id ties together status polling and the WebSocket log channel.
        task_id = str(uuid4())

        # Fire-and-forget: the task reports back via the task_status dict.
        background_tasks.add_task(
            fetch_stock_data_task,
            task_id=task_id,
            user_id=current_user.id,
            symbol=request.symbol,
            start_date=request.start_date,
            end_date=request.end_date,
            timezone=request.timezone,
        )

        return StockImportResponse(task_id=task_id, message="股票数据获取任务已启动")

    except Exception as e:
        logger.error(f"启动股票数据获取任务失败: {e}")
        raise HTTPException(status_code=500, detail=f"启动任务失败: {str(e)}")
@router.get("/import/{task_id}/status", response_model=StockImportStatus)
async def get_import_status(task_id: str):
    """Return the current status and statistics of an import task.

    Raises:
        HTTPException: 404 when the task id is unknown.
    """
    if task_id not in task_status:
        raise HTTPException(status_code=404, detail="任务不存在")

    status_info = task_status[task_id]
    # Bug fix: the overwrite/new/filter/average counters were previously
    # dropped here, so clients always saw the model defaults (0) even though
    # the background task records them. `.get` keeps old entries without
    # those keys working.
    return StockImportStatus(
        task_id=task_id,
        status=status_info["status"],
        progress=status_info["progress"],
        message=status_info["message"],
        data_count=status_info.get("data_count", 0),
        overwritten_count=status_info.get("overwritten_count", 0),
        new_count=status_info.get("new_count", 0),
        filtered_count=status_info.get("filtered_count", 0),
        avg_daily_count=status_info.get("avg_daily_count", 0),
    )
@router.delete("/import/{task_id}")
async def cancel_import_task(task_id: str):
    """Mark a running import task as cancelled.

    Raises:
        HTTPException: 404 for an unknown task id, 400 when the task is not
            in a cancellable (running) state.
    """
    info = task_status.get(task_id)
    if info is None:
        raise HTTPException(status_code=404, detail="任务不存在")

    # Cooperative cancellation only: the status flag is flipped but the
    # background task itself is not forcibly interrupted.
    if info["status"] != "running":
        raise HTTPException(status_code=400, detail="任务无法取消")

    info["status"] = "cancelled"
    info["message"] = "任务已取消"
    return {"message": "任务已取消"}