Coverage for api/v1/endpoints/stock_import.py: 77.86%

140 statements  

« prev     ^ index     » next       coverage.py v7.10.7, created at 2025-10-13 18:58 +0000

1""" 

2股票数据导入API端点 

3""" 

4 

5import asyncio 

6import logging 

7from datetime import date 

8from typing import Optional 

9 

10from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException 

11from pydantic import BaseModel, Field 

12 

13from core.data_source.adapters.data_adapter import DataAdapter 

14from core.middleware.auth_middleware import get_current_user 

15from core.models.user import User 

16from core.repositories.stock_repository import StockRepository 

17from infrastructure.database.redis_client import get_redis 

18 

19router = APIRouter() 

20logger = logging.getLogger(__name__) 

21 

22 

class StockImportRequest(BaseModel):
    """Request payload for launching a stock-data import job."""

    # Ticker symbol of the instrument to import, e.g. "YINN.US".
    symbol: str = Field(..., description="股票代码", example="YINN.US")
    # Start of the requested range as an ISO-8601 datetime string.
    start_date: str = Field(
        ..., description="开始日期(ISO格式)", example="2025-09-16T00:00:00-04:00"
    )
    # End of the requested range as an ISO-8601 datetime string.
    end_date: str = Field(
        ..., description="结束日期(ISO格式)", example="2025-09-16T23:59:00-04:00"
    )
    # Timezone name used to interpret the range; defaults to US Eastern.
    timezone: str = Field(default="America/New_York", description="时区")

34 

35 

class StockImportResponse(BaseModel):
    """Response returned when an import job has been scheduled."""

    # Identifier for polling the job's status afterwards.
    task_id: str = Field(..., description="任务ID")
    # Human-readable confirmation message.
    message: str = Field(..., description="响应消息")

41 

42 

class StockImportStatus(BaseModel):
    """Snapshot of an import job's progress and statistics."""

    # Identifier of the job being reported on.
    task_id: str = Field(..., description="任务ID")
    # Lifecycle state: running / completed / failed / cancelled.
    status: str = Field(..., description="任务状态")
    # Completion percentage, 0-100.
    progress: int = Field(..., description="进度百分比")
    # Latest human-readable status line.
    message: str = Field(..., description="状态消息")
    # Number of rows fetched so far.
    data_count: int = Field(default=0, description="已获取数据条数")
    # Rows that overwrote existing database records.
    overwritten_count: int = Field(default=0, description="覆盖的数据条数")
    # Rows newly inserted.
    new_count: int = Field(default=0, description="新增的数据条数")
    # Rows dropped by the requested time-range filter.
    filtered_count: int = Field(default=0, description="过滤掉的数据条数")
    # Average rows per trading day in the imported range.
    avg_daily_count: int = Field(default=0, description="平均每天数据条数")

55 

56 

# In-memory store of per-task status dicts, keyed by task id.
# NOTE(review): process-local only — in production this should live in Redis
# (get_redis is already imported above) so status survives restarts and is
# visible across worker processes.
task_status = {}

59 

60 

def get_data_adapter(user_id: str) -> DataAdapter:
    """Build a per-user DataAdapter used to pull stock data."""
    return DataAdapter(user_id)

64 

65 

async def fetch_stock_data_task(
    task_id: str,
    user_id: str,
    symbol: str,
    start_date: str,
    end_date: str,
    timezone: str = "America/New_York",
):
    """Background task: fetch stock data for *symbol* and track progress.

    Fixes vs. previous revision: ``start_date``/``end_date`` are annotated as
    ``str`` — the endpoint forwards the raw ISO request strings and this body
    calls ``str.replace`` on them, so the old ``date`` annotation was wrong.
    The redundant local ``import asyncio`` (already imported at module level)
    and the unused ``timedelta`` import were dropped, and the two dates in the
    kickoff message are now separated instead of being concatenated.

    Args:
        task_id: Job identifier; key into the module-level ``task_status``.
        user_id: User on whose behalf data is fetched.
        symbol: Ticker symbol, e.g. "YINN.US".
        start_date: Range start as an ISO-8601 datetime string.
        end_date: Range end as an ISO-8601 datetime string.
        timezone: Timezone name forwarded to the data adapter.

    Side effects: mutates the module-level ``task_status`` dict and streams
    log/status frames through the WebSocket service.
    """
    try:
        # Seed the status record so pollers immediately see "running".
        task_status[task_id] = {
            "status": "running",
            "progress": 0,
            "message": f"开始获取 {symbol} 数据,时间范围: {start_date} - {end_date}",
            "data_count": 0,
            "overwritten_count": 0,
            "new_count": 0,
            "filtered_count": 0,
            "avg_daily_count": 0,
        }

        # Imported lazily (original style) — presumably to avoid an import
        # cycle at module load; TODO confirm before hoisting to file level.
        from core.services.websocket_service import (send_websocket_log,
                                                     send_websocket_status,
                                                     websocket_service)

        send_websocket_log(
            task_id,
            f"开始获取 {symbol} 数据,时间范围: {start_date} - {end_date}",
            "log",
        )
        send_websocket_log(task_id, f"任务ID: {task_id}", "log")
        send_websocket_log(task_id, f"时区: {timezone}", "log")

        # Give the client up to 5 s (50 × 0.1 s) to attach its WebSocket so
        # early log lines are not lost; continue either way.
        wait_count = 0
        while not websocket_service.is_connected(task_id) and wait_count < 50:
            await asyncio.sleep(0.1)
            wait_count += 1

        if websocket_service.is_connected(task_id):
            send_websocket_log(task_id, "WebSocket连接已建立,开始数据获取", "log")
        else:
            send_websocket_log(task_id, "WebSocket连接未建立,但继续执行任务", "log")

        # Per-user adapter that performs the actual fetch/persist work.
        adapter = get_data_adapter(user_id)

        # Parse the range up front so malformed date strings fail fast here
        # rather than deep inside the adapter.  total_days itself is not
        # consumed below — progress is derived from the adapter's batch logs.
        from datetime import datetime

        start_dt = datetime.fromisoformat(start_date.replace("Z", "+00:00"))
        end_dt = datetime.fromisoformat(end_date.replace("Z", "+00:00"))
        total_days = (end_dt.date() - start_dt.date()).days + 1

        import re

        def progress_callback(message: str):
            """Relay an adapter progress line to the WebSocket and update
            task_status by pattern-matching known log-message formats."""
            print(f"[进度回调] {message}")
            logger.info(f"[进度回调] {message}")

            # Stream the raw line to the client.
            send_websocket_log(task_id, message, "log")

            # Only statistics are kept in task_status, not the log history.
            if task_id not in task_status:
                return

            task_status[task_id]["message"] = message
            progress_updated = False

            # "处理日期 YYYY-MM-DD (X/N)" → map batch X/N onto 0-85 %.
            date_match = re.search(
                r"处理日期 \d{4}-\d{2}-\d{2} \((\d+)/(\d+)\)", message
            )
            if date_match:
                current_batch = int(date_match.group(1))
                total_batches = int(date_match.group(2))
                task_status[task_id]["progress"] = int(
                    (current_batch / total_batches) * 85
                )
                progress_updated = True

            # Database-save confirmation → record row count, jump to 90 %.
            if "成功保存" in message and "条数据到数据库" in message:
                saved_match = re.search(r"成功保存 (\d+) 条数据到数据库", message)
                if saved_match:
                    task_status[task_id]["data_count"] = int(saved_match.group(1))
                    task_status[task_id]["progress"] = 90
                    progress_updated = True

            # Overwrite / new / filtered statistics embedded in one log line.
            # (Only the overwrite match triggers a status frame, matching the
            # original behavior.)
            if "数据统计:" in message:
                overwritten_match = re.search(r"覆盖 (\d+) 条", message)
                if overwritten_match:
                    task_status[task_id]["overwritten_count"] = int(
                        overwritten_match.group(1)
                    )
                    progress_updated = True

                new_match = re.search(r"新增 (\d+) 条", message)
                if new_match:
                    task_status[task_id]["new_count"] = int(new_match.group(1))

                filtered_match = re.search(r"过滤 (\d+) 条", message)
                if filtered_match:
                    task_status[task_id]["filtered_count"] = int(
                        filtered_match.group(1)
                    )

            # Average rows per day.
            if "平均每天" in message and "条数据" in message:
                avg_match = re.search(r"平均每天 (\d+) 条数据", message)
                if avg_match:
                    task_status[task_id]["avg_daily_count"] = int(
                        avg_match.group(1)
                    )

            # Push a status frame whenever progress/statistics changed.
            if progress_updated:
                send_websocket_status(
                    task_id,
                    task_status[task_id]["status"],
                    task_status[task_id]["progress"],
                    task_status[task_id].get("data_count", 0),
                )

        # Perform the fetch; the adapter reports progress via the callback.
        result = await adapter.fetch_stock_data(
            symbol=symbol,
            start_date=start_date,
            end_date=end_date,
            timezone=timezone,
            progress_callback=progress_callback,
        )

        # Record the final outcome in task_status.
        if result["success"]:
            overwritten_count = result.get("overwritten_count", 0)
            new_count = result.get("new_count", 0)
            filtered_count = result.get("filtered_count", 0)
            avg_daily_count = result.get("avg_daily_count", 0)

            # Build a detailed completion message, omitting zero-valued stats.
            message_parts = [
                f"✅ 数据导入完成!共写入数据库 {result['data_count']} 条数据"
            ]
            if overwritten_count > 0:
                message_parts.append(f"覆盖已存在数据: {overwritten_count}")
            if new_count > 0:
                message_parts.append(f"新增数据: {new_count}")
            if filtered_count > 0:
                message_parts.append(f"因时间区间过滤: {filtered_count}")
            if avg_daily_count > 0:
                message_parts.append(f"平均每天: {avg_daily_count}")

            task_status[task_id] = {
                "status": "completed",
                "progress": 100,
                "message": " | ".join(message_parts),
                "data_count": result["data_count"],
                "overwritten_count": overwritten_count,
                "new_count": new_count,
                "filtered_count": filtered_count,
                "avg_daily_count": avg_daily_count,
            }
        else:
            task_status[task_id] = {
                "status": "failed",
                "progress": 0,
                "message": f"数据获取失败: {result.get('error', '未知错误')}",
                "data_count": 0,
                "overwritten_count": 0,
                "new_count": 0,
                "filtered_count": 0,
                "avg_daily_count": 0,
            }

    except Exception as e:
        # Top-level boundary: log and surface the failure through task_status.
        logger.error(f"股票数据获取任务失败: {e}")
        task_status[task_id] = {
            "status": "failed",
            "progress": 0,
            "message": f"任务执行失败: {str(e)}",
            "data_count": 0,
            "overwritten_count": 0,
            "new_count": 0,
            "filtered_count": 0,
            "avg_daily_count": 0,
        }

272 

273 

@router.post("/import", response_model=StockImportResponse)
async def import_stock_data(
    request: StockImportRequest,
    background_tasks: BackgroundTasks,
    current_user: User = Depends(get_current_user),
):
    """Schedule a background stock-data import and return its task id."""
    try:
        import uuid

        new_task_id = str(uuid.uuid4())

        # Hand the long-running fetch to FastAPI's background runner; the
        # HTTP request returns immediately with the task id.
        background_tasks.add_task(
            fetch_stock_data_task,
            task_id=new_task_id,
            user_id=current_user.id,
            symbol=request.symbol,
            start_date=request.start_date,
            end_date=request.end_date,
            timezone=request.timezone,
        )

        return StockImportResponse(task_id=new_task_id, message="股票数据获取任务已启动")

    except Exception as e:
        # Boundary handler: log, then map to a 500 for the client.
        logger.error(f"启动股票数据获取任务失败: {e}")
        raise HTTPException(status_code=500, detail=f"启动任务失败: {str(e)}")

303 

304 

@router.get("/import/{task_id}/status", response_model=StockImportStatus)
async def get_import_status(task_id: str):
    """Return the current status of an import task.

    Fix: previously only ``data_count`` was copied into the response even
    though the background task records overwrite/new/filter/average stats in
    ``task_status`` and the response model declares fields for them — all
    declared fields are now populated.  ``.get`` with defaults also avoids a
    KeyError on partially-initialized status records.

    Raises:
        HTTPException: 404 when the task id is unknown.
    """
    if task_id not in task_status:
        raise HTTPException(status_code=404, detail="任务不存在")

    status_info = task_status[task_id]
    return StockImportStatus(
        task_id=task_id,
        status=status_info["status"],
        progress=status_info["progress"],
        message=status_info["message"],
        data_count=status_info.get("data_count", 0),
        overwritten_count=status_info.get("overwritten_count", 0),
        new_count=status_info.get("new_count", 0),
        filtered_count=status_info.get("filtered_count", 0),
        avg_daily_count=status_info.get("avg_daily_count", 0),
    )

319 

320 

@router.delete("/import/{task_id}")
async def cancel_import_task(task_id: str):
    """Mark a running import task as cancelled."""
    if task_id not in task_status:
        raise HTTPException(status_code=404, detail="任务不存在")

    # Best-effort cancellation: only flips the stored status — the background
    # coroutine itself is not interrupted.
    entry = task_status[task_id]
    if entry["status"] != "running":
        raise HTTPException(status_code=400, detail="任务无法取消")

    entry["status"] = "cancelled"
    entry["message"] = "任务已取消"
    return {"message": "任务已取消"}