🇨🇳 简体中文
🇺🇸 English
🇯🇵 日本語
Skip to the content.

WebSocket服务架构文档

📋 概述

WebSocket服务是量化交易系统的重要基础设施组件,负责提供实时数据推送、日志监控、进度反馈等功能。本文档详细描述了WebSocket服务的架构设计、实现原理和使用方式。

🏗️ 架构设计

1. 服务架构图

graph TB
    subgraph "前端层"
        FE[React前端]
        WS_CLIENT[WebSocket客户端]
    end
    
    subgraph "WebSocket服务层"
        WS_FACTORY[WebSocketServiceFactory<br/>WebSocket服务工厂]
        TASK_WS[TaskWebSocketService<br/>任务专用WebSocket服务]
        ASYNC_QUEUE[异步队列管理]
        CONN_MGR[连接管理器]
    end
    
    subgraph "业务层"
        TASK1[数据导入任务]
        TASK2[交易监控任务]
        TASK3[策略引擎日志]
        TASK4[其他长时间任务]
    end
    
    subgraph "存储层"
        MEMORY[内存日志存储]
        REDIS[(Redis缓存)]
    end
    
    FE --> WS_CLIENT
    WS_CLIENT <--> WS_FACTORY
    WS_FACTORY --> TASK_WS
    TASK_WS --> ASYNC_QUEUE
    TASK_WS --> CONN_MGR
    TASK_WS --> MEMORY
    
    TASK1 --> WS_FACTORY
    TASK2 --> WS_FACTORY
    TASK3 --> WS_FACTORY
    TASK4 --> WS_FACTORY
    
    TASK_WS --> REDIS

2. 核心组件

组件 职责 实现位置
WebSocketServiceFactory WebSocket服务工厂管理 core/services/websocket_service_factory.py
TaskWebSocketService 任务专用WebSocket服务 core/services/websocket_service_factory.py
连接管理器 WebSocket连接生命周期管理 内置在服务中
异步安全队列 避免事件循环冲突 使用asyncio.Queue
内存日志存储 任务日志临时存储 内置在服务中

🔌 核心接口定义

1. WebSocket服务工厂

class WebSocketServiceFactory:
    """WebSocket服务工厂 - 管理多个任务专用的WebSocket服务"""
    
    def __init__(self):
        self._services: Dict[str, TaskWebSocketService] = {}
        self._lock = asyncio.Lock()
    
    def get_or_create_service(self, task_id: str) -> TaskWebSocketService:
        """获取或创建任务专用的WebSocket服务"""
        pass
    
    def remove_service(self, task_id: str):
        """移除任务专用的WebSocket服务"""
        pass

class TaskWebSocketService:
    """任务专用WebSocket服务 - 每个任务独立的WebSocket连接和队列"""
    
    def __init__(self, task_id: str):
        self.task_id = task_id
        self._websocket: Optional[WebSocket] = None
        self._log_queue: asyncio.Queue = asyncio.Queue()
        self._running = False
        self._task = None
    
    async def connect(self, websocket: WebSocket):
        """建立WebSocket连接"""
        pass
    
    async def disconnect(self):
        """断开WebSocket连接"""
        pass
    
    async def send_log(self, message: str, log_type: str = "log"):
        """发送日志消息"""
        pass
    
    async def _process_log_queue(self):
        """处理日志队列"""
        pass
    
    async def send_status(self, task_id: str, status: str, progress: int, data_count: int):
        """发送状态更新"""
        pass
    
    async def send_error(self, task_id: str, error_message: str):
        """发送错误消息"""
        pass
    
    async def send_success(self, task_id: str, success_message: str):
        """发送成功消息"""
        pass

2. 便捷函数接口

# 全局便捷函数
def send_websocket_log(task_id: str, message: str, log_type: str = "log"):
    """发送WebSocket日志消息 - 线程安全"""
    pass

def send_websocket_status(task_id: str, status: str, progress: int, data_count: int):
    """发送WebSocket状态更新"""
    pass

async def send_websocket_error(task_id: str, error_message: str):
    """发送WebSocket错误消息"""
    pass

async def send_websocket_success(task_id: str, success_message: str):
    """发送WebSocket成功消息"""
    pass

🔄 工作流程

1. 连接建立流程

sequenceDiagram
    participant FE as 前端
    participant WS as WebSocket服务
    participant TASK as 业务任务
    
    FE->>WS: 建立WebSocket连接
    WS->>WS: 注册连接
    WS->>FE: 发送连接确认
    
    TASK->>WS: 发送初始日志
    WS->>FE: 推送初始日志
    
    TASK->>WS: 发送进度更新
    WS->>FE: 推送进度更新
    
    TASK->>WS: 发送完成消息
    WS->>FE: 推送完成消息
    FE->>WS: 关闭连接
    WS->>WS: 清理连接

2. 消息推送流程

graph TD
    A[业务任务] --> B[调用便捷函数]
    B --> C[WebSocket服务]
    C --> D[检查连接状态]
    D --> E{连接存在?}
    E -->|是| F[发送消息]
    E -->|否| G[存储到内存]
    F --> H[前端接收]
    G --> I[等待连接建立]
    I --> F

📊 使用场景

1. 数据导入流程

前端连接:

const ws = new WebSocket(`ws://localhost:8000/ws/stock-import/${taskId}`);

ws.onmessage = (event) => {
    const data = JSON.parse(event.data);
    switch(data.type) {
        case 'log':
            // 显示日志消息
            break;
        case 'status':
            // 更新进度条
            break;
        case 'error':
            // 显示错误信息
            break;
        case 'success':
            // 显示成功信息
            break;
    }
};

后端推送:

# 数据导入任务中
async def import_stock_data(symbols, task_id):
    await send_websocket_log(task_id, "开始导入股票数据", "log")
    
    for i, symbol in enumerate(symbols):
        # 导入数据
        await import_symbol_data(symbol)
        
        # 更新进度
        progress = int((i + 1) / len(symbols) * 100)
        await send_websocket_status(task_id, "导入中", progress, i + 1)
    
    await send_websocket_success(task_id, "数据导入完成")

2. 交易监控流程

实时监控:

# 交易监控任务中
async def monitor_trading_session(session_id, task_id):
    await send_websocket_log(task_id, "开始监控交易会话", "log")
    
    while trading_active:
        # 获取交易状态
        status = get_trading_status(session_id)
        
        # 推送状态更新
        await send_websocket_status(
            task_id, 
            status["status"], 
            status["progress"], 
            status["trade_count"]
        )
        
        await asyncio.sleep(1)

3. 其他长时间任务

批量操作:

# 批量操作任务中
async def batch_operation(items, task_id):
    await send_websocket_log(task_id, f"开始批量操作,共{len(items)}项", "log")
    
    for i, item in enumerate(items):
        try:
            # 执行操作
            await process_item(item)
            
            # 更新进度
            progress = int((i + 1) / len(items) * 100)
            await send_websocket_status(task_id, "处理中", progress, i + 1)
            
        except Exception as e:
            await send_websocket_error(task_id, f"处理项目失败: {e}")
    
    await send_websocket_success(task_id, "批量操作完成")

🛠️ 实现细节

1. 线程安全设计

class UnifiedWebSocketService:
    def __init__(self):
        self._connections: Dict[str, WebSocket] = {}
        self._task_logs: Dict[str, List[Dict]] = {}
        self._queue: asyncio.Queue = asyncio.Queue()
        self._running = False
        self._lock = asyncio.Lock()  # 异步锁
    
    async def send_log(self, task_id: str, message: str, log_type: str = "log"):
        """线程安全的日志发送"""
        async with self._lock:
            # 检查连接状态
            if task_id in self._connections:
                websocket = self._connections[task_id]
                try:
                    await websocket.send_json({
                        "type": log_type,
                        "message": message,
                        "timestamp": datetime.now().isoformat()
                    })
                except Exception as e:
                    # 连接断开,移除连接
                    await self.remove_connection(task_id)
            else:
                # 连接不存在,存储到内存
                await self._store_log(task_id, message, log_type)

2. 连接管理

async def add_connection(self, task_id: str, websocket: WebSocket):
    """添加WebSocket连接"""
    async with self._lock:
        self._connections[task_id] = websocket
        
        # 发送缓存的日志
        if task_id in self._task_logs:
            for log in self._task_logs[task_id]:
                try:
                    await websocket.send_json(log)
                except Exception as e:
                    print(f"发送缓存日志失败: {e}")
            
            # 清空缓存
            del self._task_logs[task_id]

async def remove_connection(self, task_id: str):
    """移除WebSocket连接"""
    async with self._lock:
        if task_id in self._connections:
            del self._connections[task_id]

3. 消息格式

# 日志消息格式
{
    "type": "log",
    "message": "开始导入股票数据",
    "timestamp": "2025-01-01T12:00:00.000Z"
}

# 状态消息格式
{
    "type": "status",
    "status": "导入中",
    "progress": 50,
    "data_count": 100,
    "timestamp": "2025-01-01T12:00:00.000Z"
}

# 错误消息格式
{
    "type": "error",
    "message": "导入失败: 网络错误",
    "timestamp": "2025-01-01T12:00:00.000Z"
}

# 成功消息格式
{
    "type": "success",
    "message": "数据导入完成",
    "timestamp": "2025-01-01T12:00:00.000Z"
}

🔧 最佳实践

1. 连接管理

# ✅ 正确:使用任务ID管理连接
task_id = f"import_{uuid.uuid4()}"
ws_url = f"ws://localhost:8000/ws/stock-import/{task_id}"

# ❌ 错误:使用固定连接ID
ws_url = "ws://localhost:8000/ws/stock-import/fixed_id"

2. 错误处理

# ✅ 正确:完善的错误处理
try:
    await send_websocket_log(task_id, "开始处理", "log")
    # 业务逻辑
    await send_websocket_success(task_id, "处理完成")
except Exception as e:
    await send_websocket_error(task_id, f"处理失败: {e}")

3. 进度更新

# ✅ 正确:合理的进度更新频率
for i, item in enumerate(items):
    if i % 10 == 0:  # 每10个项目更新一次
        progress = int((i + 1) / len(items) * 100)
        await send_websocket_status(task_id, "处理中", progress, i + 1)

4. 消息类型

# ✅ 正确:使用合适的消息类型
await send_websocket_log(task_id, "开始处理", "log")      # 普通日志
await send_websocket_status(task_id, "处理中", 50, 100)   # 状态更新
await send_websocket_error(task_id, "处理失败")           # 错误信息
await send_websocket_success(task_id, "处理完成")         # 成功信息

📈 性能优化

1. 连接池管理

class ConnectionPool:
    """WebSocket连接池"""
    
    def __init__(self, max_connections: int = 100):
        self._connections: Dict[str, WebSocket] = {}
        self._max_connections = max_connections
        self._lock = asyncio.Lock()
    
    async def get_connection(self, task_id: str) -> Optional[WebSocket]:
        """获取连接"""
        async with self._lock:
            return self._connections.get(task_id)
    
    async def add_connection(self, task_id: str, websocket: WebSocket):
        """添加连接"""
        async with self._lock:
            if len(self._connections) >= self._max_connections:
                # 清理最旧的连接
                oldest_task_id = next(iter(self._connections))
                await self.remove_connection(oldest_task_id)
            
            self._connections[task_id] = websocket

2. 消息批处理

class MessageBatcher:
    """消息批处理器"""
    
    def __init__(self, batch_size: int = 10, batch_timeout: float = 1.0):
        self._batch_size = batch_size
        self._batch_timeout = batch_timeout
        self._pending_messages: List[Dict] = []
        self._last_batch_time = time.time()
    
    async def add_message(self, message: Dict):
        """添加消息到批次"""
        self._pending_messages.append(message)
        
        if (len(self._pending_messages) >= self._batch_size or 
            time.time() - self._last_batch_time >= self._batch_timeout):
            await self._flush_batch()
    
    async def _flush_batch(self):
        """刷新批次"""
        if self._pending_messages:
            # 发送批次消息
            await self._send_batch(self._pending_messages)
            self._pending_messages.clear()
            self._last_batch_time = time.time()

🚀 扩展建议

1. 消息持久化

class PersistentWebSocketService(UnifiedWebSocketService):
    """支持消息持久化的WebSocket服务"""
    
    def __init__(self, redis_client):
        super().__init__()
        self._redis_client = redis_client
    
    async def send_log(self, task_id: str, message: str, log_type: str = "log"):
        """发送日志并持久化"""
        # 发送到WebSocket
        await super().send_log(task_id, message, log_type)
        
        # 持久化到Redis
        log_data = {
            "type": log_type,
            "message": message,
            "timestamp": datetime.now().isoformat()
        }
        await self._redis_client.lpush(f"task_logs:{task_id}", json.dumps(log_data))
        await self._redis_client.ltrim(f"task_logs:{task_id}", 0, 999)  # 保留最近1000条

2. 消息过滤

class FilteredWebSocketService(UnifiedWebSocketService):
    """支持消息过滤的WebSocket服务"""
    
    def __init__(self, filters: List[Callable] = None):
        super().__init__()
        self._filters = filters or []
    
    async def send_log(self, task_id: str, message: str, log_type: str = "log"):
        """发送过滤后的日志"""
        # 应用过滤器
        for filter_func in self._filters:
            if not filter_func(task_id, message, log_type):
                return  # 消息被过滤
        
        await super().send_log(task_id, message, log_type)

3. 消息路由

class RoutedWebSocketService(UnifiedWebSocketService):
    """支持消息路由的WebSocket服务"""
    
    def __init__(self):
        super().__init__()
        self._routes: Dict[str, List[str]] = {}
    
    def add_route(self, pattern: str, task_ids: List[str]):
        """添加路由规则"""
        self._routes[pattern] = task_ids
    
    async def send_log(self, task_id: str, message: str, log_type: str = "log"):
        """发送路由后的日志"""
        # 查找匹配的路由
        for pattern, target_task_ids in self._routes.items():
            if re.match(pattern, task_id):
                # 发送到所有目标任务
                for target_task_id in target_task_ids:
                    await super().send_log(target_task_id, message, log_type)
                return
        
        # 没有匹配的路由,发送到原始任务
        await super().send_log(task_id, message, log_type)

📚 总结

核心特性

  1. 统一管理: 全局单例服务,统一管理所有WebSocket连接
  2. 线程安全: 使用异步锁和队列机制,避免事件循环冲突
  3. 连接管理: 自动管理连接生命周期,支持连接断开重连
  4. 消息缓存: 连接断开时缓存消息,连接恢复后自动发送
  5. 便捷接口: 提供简单的函数接口,便于业务代码使用

使用场景

扩展能力

通过WebSocket服务,系统实现了高效的实时通信机制,为用户提供了良好的交互体验和实时反馈。