🇨🇳 简体中文
🇺🇸 English
🇯🇵 日本語
Skip to the content.

Kronos 使用指南

1. 快速开始

1.1 安装依赖

首先确保您的Python版本≥3.10,然后安装必要的依赖:

cd Kronos
pip install -r requirements.txt

主要依赖包括:

1.2 第一个预测示例

以下是使用Kronos进行K线预测的最简单示例:

import pandas as pd
from model import Kronos, KronosTokenizer, KronosPredictor

# 步骤1:加载模型和分词器
tokenizer = KronosTokenizer.from_pretrained("NeoQuasar/Kronos-Tokenizer-base")
model = Kronos.from_pretrained("NeoQuasar/Kronos-small")

# 步骤2:初始化预测器
predictor = KronosPredictor(
    model=model,
    tokenizer=tokenizer,
    device="cuda:0",  # 使用GPU,如果没有GPU可以改为"cpu"
    max_context=512
)

# 步骤3:准备数据
# 假设您有一个CSV文件包含K线数据
df = pd.read_csv("your_data.csv")
df['timestamps'] = pd.to_datetime(df['timestamps'])

# 定义历史窗口和预测长度
lookback = 200  # 使用200个历史K线
pred_len = 24   # 预测未来24个K线

# 准备输入
x_df = df.iloc[:lookback][['open', 'high', 'low', 'close', 'volume']]
x_timestamp = df.iloc[:lookback]['timestamps']
y_timestamp = df.iloc[lookback:lookback+pred_len]['timestamps']

# 步骤4:进行预测
pred_df = predictor.predict(
    df=x_df,
    x_timestamp=x_timestamp,
    y_timestamp=y_timestamp,
    pred_len=pred_len,
    T=1.0,
    top_p=0.9,
    sample_count=1
)

print("预测结果:")
print(pred_df)

2. 详细使用教程

2.1 数据准备

2.1.1 数据格式要求

Kronos期望的输入数据格式如下:

# 必需的列
required_columns = ['open', 'high', 'low', 'close']

# 可选的列(如果没有会自动填充)
optional_columns = ['volume', 'amount']

# 示例数据结构
data = {
    'timestamps': pd.date_range('2024-01-01', periods=500, freq='1h'),
    'open': [100.0, 101.0, ...],
    'high': [102.0, 103.0, ...],
    'low': [99.0, 100.0, ...],
    'close': [101.0, 102.0, ...],
    'volume': [1000, 1100, ...],
    'amount': [101000, 112200, ...]
}
df = pd.DataFrame(data)

2.1.2 处理缺失数据

如果您的数据缺少volume或amount列:

# 方法1:让模型自动处理(会填充0)
pred_df = predictor.predict(df=df_without_volume, ...)

# 方法2:手动填充
df['volume'] = 0.0  # 或使用其他填充策略
df['amount'] = df['volume'] * df[['open', 'high', 'low', 'close']].mean(axis=1)

2.2 单序列预测

2.2.1 基础预测

def predict_single_series(df, lookback, pred_len):
    """
    对单个时间序列进行预测

    Args:
        df: 包含K线数据的DataFrame
        lookback: 历史窗口长度
        pred_len: 预测长度
    """
    # 准备数据
    x_df = df.iloc[:lookback][['open', 'high', 'low', 'close', 'volume']]
    x_timestamp = df.iloc[:lookback]['timestamps']
    y_timestamp = df.iloc[lookback:lookback+pred_len]['timestamps']

    # 进行预测
    pred_df = predictor.predict(
        df=x_df,
        x_timestamp=x_timestamp,
        y_timestamp=y_timestamp,
        pred_len=pred_len,
        T=1.0,        # 温度参数
        top_p=0.9,    # 核采样阈值
        sample_count=5 # 生成5条路径取平均
    )

    return pred_df

2.2.2 调整预测参数

不同的参数组合适用于不同的场景:

# 场景1:短期精确预测(低温度,少采样)
conservative_pred = predictor.predict(
    df=x_df,
    x_timestamp=x_timestamp,
    y_timestamp=y_timestamp,
    pred_len=12,  # 短期预测
    T=0.5,        # 低温度,更确定性
    top_p=0.95,   # 高阈值
    sample_count=1
)

# 场景2:长期趋势探索(高温度,多采样)
exploratory_pred = predictor.predict(
    df=x_df,
    x_timestamp=x_timestamp,
    y_timestamp=y_timestamp,
    pred_len=48,   # 长期预测
    T=1.2,         # 高温度,更多样性
    top_p=0.9,     # 标准阈值
    sample_count=10 # 多路径平均
)

# 场景3:集成预测(多组参数)
def ensemble_predict(x_df, x_timestamp, y_timestamp, pred_len):
    """使用不同参数组合进行集成预测"""
    predictions = []

    # 不同的温度和采样策略
    configs = [
        {'T': 0.8, 'top_p': 0.95, 'sample_count': 3},
        {'T': 1.0, 'top_p': 0.9, 'sample_count': 3},
        {'T': 1.2, 'top_p': 0.85, 'sample_count': 3}
    ]

    for config in configs:
        pred = predictor.predict(
            df=x_df,
            x_timestamp=x_timestamp,
            y_timestamp=y_timestamp,
            pred_len=pred_len,
            **config
        )
        predictions.append(pred)

    # 平均所有预测
    ensemble_pred = pd.concat(predictions).groupby(level=0).mean()
    return ensemble_pred

2.3 批量预测

对多个资产或时间段进行并行预测:

def batch_predict_multiple_assets(asset_data_dict, lookback, pred_len):
    """
    批量预测多个资产

    Args:
        asset_data_dict: {asset_name: DataFrame}的字典
        lookback: 历史窗口长度
        pred_len: 预测长度
    """
    df_list = []
    x_timestamp_list = []
    y_timestamp_list = []
    asset_names = []

    for asset_name, df in asset_data_dict.items():
        # 准备每个资产的数据
        x_df = df.iloc[:lookback][['open', 'high', 'low', 'close', 'volume']]
        x_timestamp = df.iloc[:lookback]['timestamps']
        y_timestamp = df.iloc[lookback:lookback+pred_len]['timestamps']

        df_list.append(x_df)
        x_timestamp_list.append(x_timestamp)
        y_timestamp_list.append(y_timestamp)
        asset_names.append(asset_name)

    # 批量预测
    predictions = predictor.predict_batch(
        df_list=df_list,
        x_timestamp_list=x_timestamp_list,
        y_timestamp_list=y_timestamp_list,
        pred_len=pred_len,
        T=1.0,
        top_p=0.9,
        sample_count=3,
        verbose=True
    )

    # 返回结果字典
    return dict(zip(asset_names, predictions))

# 使用示例
asset_data = {
    'BTC/USDT': btc_df,
    'ETH/USDT': eth_df,
    'BNB/USDT': bnb_df
}

predictions = batch_predict_multiple_assets(asset_data, lookback=200, pred_len=24)

2.4 滑动窗口预测

实现滑动窗口进行连续预测:

def sliding_window_prediction(df, window_size=200, pred_len=24, step=1):
    """
    使用滑动窗口进行连续预测

    Args:
        df: 完整的历史数据
        window_size: 历史窗口大小
        pred_len: 每次预测的长度
        step: 滑动步长
    """
    predictions = []

    # 计算可以进行预测的次数
    n_predictions = (len(df) - window_size - pred_len) // step + 1

    for i in range(n_predictions):
        start_idx = i * step
        end_idx = start_idx + window_size

        # 准备当前窗口的数据
        x_df = df.iloc[start_idx:end_idx][['open', 'high', 'low', 'close', 'volume']]
        x_timestamp = df.iloc[start_idx:end_idx]['timestamps']
        y_timestamp = df.iloc[end_idx:end_idx+pred_len]['timestamps']

        # 预测
        pred = predictor.predict(
            df=x_df,
            x_timestamp=x_timestamp,
            y_timestamp=y_timestamp,
            pred_len=pred_len,
            T=1.0,
            top_p=0.9,
            sample_count=1,
            verbose=False
        )

        predictions.append({
            'window_start': start_idx,
            'window_end': end_idx,
            'prediction': pred
        })

    return predictions

3. 高级应用示例

3.1 结合技术指标

将Kronos预测与传统技术指标结合:

import talib
import numpy as np

def predict_with_indicators(df, lookback, pred_len):
    """
    结合技术指标进行增强预测
    """
    # 计算技术指标
    df['SMA_20'] = talib.SMA(df['close'].values, timeperiod=20)
    df['RSI'] = talib.RSI(df['close'].values, timeperiod=14)
    df['MACD'], df['MACD_signal'], _ = talib.MACD(df['close'].values)

    # Kronos预测
    x_df = df.iloc[:lookback][['open', 'high', 'low', 'close', 'volume']]
    x_timestamp = df.iloc[:lookback]['timestamps']
    y_timestamp = df.iloc[lookback:lookback+pred_len]['timestamps']

    pred_df = predictor.predict(
        df=x_df,
        x_timestamp=x_timestamp,
        y_timestamp=y_timestamp,
        pred_len=pred_len,
        T=1.0,
        top_p=0.9
    )

    # 基于技术指标调整预测
    last_rsi = df.iloc[lookback-1]['RSI']

    if last_rsi > 70:  # 超买
        # 降低预测价格
        pred_df[['open', 'high', 'low', 'close']] *= 0.98
    elif last_rsi < 30:  # 超卖
        # 提高预测价格
        pred_df[['open', 'high', 'low', 'close']] *= 1.02

    return pred_df

3.2 构建交易信号

基于Kronos预测生成交易信号:

class KronosTradingStrategy:
    """基于Kronos预测的交易策略"""

    def __init__(self, predictor, lookback=200, pred_len=24):
        self.predictor = predictor
        self.lookback = lookback
        self.pred_len = pred_len

    def generate_signal(self, df):
        """
        生成交易信号

        Returns:
            1: 买入信号
            -1: 卖出信号
            0: 无信号
        """
        # 准备数据
        x_df = df.iloc[-self.lookback:][['open', 'high', 'low', 'close', 'volume']]
        x_timestamp = df.iloc[-self.lookback:]['timestamps']

        # 生成未来时间戳
        last_timestamp = x_timestamp.iloc[-1]
        freq = pd.infer_freq(x_timestamp)
        y_timestamp = pd.date_range(
            start=last_timestamp + pd.Timedelta(1, unit='h'),
            periods=self.pred_len,
            freq=freq
        )

        # 预测
        pred_df = self.predictor.predict(
            df=x_df,
            x_timestamp=x_timestamp,
            y_timestamp=y_timestamp,
            pred_len=self.pred_len,
            T=0.8,  # 较低温度,更保守
            top_p=0.95,
            sample_count=5  # 多路径平均
        )

        # 分析预测结果
        current_price = x_df.iloc[-1]['close']
        pred_mean_price = pred_df['close'].mean()
        pred_max_price = pred_df['high'].max()
        pred_min_price = pred_df['low'].min()

        # 计算预期收益和风险
        expected_return = (pred_mean_price - current_price) / current_price
        upside_potential = (pred_max_price - current_price) / current_price
        downside_risk = (current_price - pred_min_price) / current_price

        # 生成信号
        if expected_return > 0.02 and upside_potential > downside_risk * 1.5:
            return 1  # 买入
        elif expected_return < -0.02 or downside_risk > 0.05:
            return -1  # 卖出
        else:
            return 0  # 持有

    def backtest(self, df, initial_capital=10000):
        """
        简单回测
        """
        capital = initial_capital
        position = 0
        trades = []

        # 滑动窗口回测
        for i in range(self.lookback, len(df) - self.pred_len, 24):
            current_df = df.iloc[:i+1]
            signal = self.generate_signal(current_df)
            current_price = current_df.iloc[-1]['close']

            if signal == 1 and position == 0:
                # 买入
                position = capital / current_price
                capital = 0
                trades.append({
                    'timestamp': current_df.iloc[-1]['timestamps'],
                    'action': 'BUY',
                    'price': current_price,
                    'quantity': position
                })

            elif signal == -1 and position > 0:
                # 卖出
                capital = position * current_price
                position = 0
                trades.append({
                    'timestamp': current_df.iloc[-1]['timestamps'],
                    'action': 'SELL',
                    'price': current_price,
                    'capital': capital
                })

        # 计算最终价值
        final_df = df.iloc[-1]
        final_value = capital + position * final_df['close']
        total_return = (final_value - initial_capital) / initial_capital

        return {
            'trades': trades,
            'final_value': final_value,
            'total_return': total_return,
            'num_trades': len(trades)
        }

3.3 实时预测系统

构建实时预测系统:

import time
from datetime import datetime, timedelta
import threading
import queue

class RealTimeKronosPredictor:
    """实时Kronos预测系统"""

    def __init__(self, predictor, data_source, lookback=200, pred_len=24):
        self.predictor = predictor
        self.data_source = data_source  # 数据源接口
        self.lookback = lookback
        self.pred_len = pred_len
        self.prediction_queue = queue.Queue()
        self.is_running = False

    def fetch_latest_data(self):
        """从数据源获取最新数据"""
        # 这里应该实现真实的数据获取逻辑
        # 示例:从交易所API获取数据
        return self.data_source.get_latest_klines(limit=self.lookback)

    def prediction_loop(self):
        """预测循环"""
        while self.is_running:
            try:
                # 获取最新数据
                df = self.fetch_latest_data()

                if len(df) >= self.lookback:
                    # 准备预测
                    x_df = df.iloc[-self.lookback:][['open', 'high', 'low', 'close', 'volume']]
                    x_timestamp = df.iloc[-self.lookback:]['timestamps']

                    # 生成未来时间戳
                    last_timestamp = x_timestamp.iloc[-1]
                    y_timestamp = pd.date_range(
                        start=last_timestamp + pd.Timedelta(1, unit='h'),
                        periods=self.pred_len,
                        freq='1h'
                    )

                    # 进行预测
                    pred_df = self.predictor.predict(
                        df=x_df,
                        x_timestamp=x_timestamp,
                        y_timestamp=y_timestamp,
                        pred_len=self.pred_len,
                        T=1.0,
                        top_p=0.9,
                        sample_count=3
                    )

                    # 将预测结果放入队列
                    self.prediction_queue.put({
                        'timestamp': datetime.now(),
                        'prediction': pred_df,
                        'current_price': df.iloc[-1]['close']
                    })

                # 等待下一个周期
                time.sleep(3600)  # 每小时预测一次

            except Exception as e:
                print(f"预测错误: {e}")
                time.sleep(60)  # 错误后等待1分钟

    def start(self):
        """启动实时预测"""
        self.is_running = True
        self.prediction_thread = threading.Thread(target=self.prediction_loop)
        self.prediction_thread.start()
        print("实时预测系统已启动")

    def stop(self):
        """停止实时预测"""
        self.is_running = False
        self.prediction_thread.join()
        print("实时预测系统已停止")

    def get_latest_prediction(self, timeout=1):
        """获取最新预测结果"""
        try:
            return self.prediction_queue.get(timeout=timeout)
        except queue.Empty:
            return None

4. 可视化示例

4.1 基础可视化

import matplotlib.pyplot as plt
import matplotlib.dates as mdates

def visualize_prediction(historical_df, pred_df, title="Kronos Prediction"):
    """
    可视化预测结果
    """
    fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(12, 8), sharex=True)

    # 合并历史和预测数据
    hist_close = historical_df['close']
    pred_close = pred_df['close']

    # 绘制价格
    ax1.plot(historical_df['timestamps'], hist_close,
             label='Historical', color='blue', linewidth=1.5)
    ax1.plot(pred_df.index, pred_close,
             label='Prediction', color='red', linewidth=1.5, linestyle='--')

    # 添加预测区间(如果有多次采样)
    ax1.fill_between(pred_df.index,
                     pred_df['low'],
                     pred_df['high'],
                     alpha=0.3, color='red', label='Prediction Range')

    ax1.set_ylabel('Price', fontsize=12)
    ax1.legend(loc='best')
    ax1.grid(True, alpha=0.3)
    ax1.set_title(title)

    # 绘制成交量
    if 'volume' in historical_df.columns and 'volume' in pred_df.columns:
        ax2.bar(historical_df['timestamps'], historical_df['volume'],
                color='blue', alpha=0.5, label='Historical Volume')
        ax2.bar(pred_df.index, pred_df['volume'],
                color='red', alpha=0.5, label='Predicted Volume')
        ax2.set_ylabel('Volume', fontsize=12)
        ax2.legend(loc='best')
        ax2.grid(True, alpha=0.3)

    # 格式化x轴
    ax2.xaxis.set_major_formatter(mdates.DateFormatter('%Y-%m-%d %H:%M'))
    plt.xticks(rotation=45)

    plt.tight_layout()
    plt.show()

4.2 高级可视化

import plotly.graph_objects as go
from plotly.subplots import make_subplots

def interactive_candlestick_prediction(historical_df, pred_df):
    """
    创建交互式K线图预测可视化
    """
    # 创建子图
    fig = make_subplots(
        rows=2, cols=1,
        shared_xaxes=True,
        vertical_spacing=0.03,
        subplot_titles=('Price Prediction', 'Volume'),
        row_heights=[0.7, 0.3]
    )

    # 历史K线
    fig.add_trace(
        go.Candlestick(
            x=historical_df['timestamps'],
            open=historical_df['open'],
            high=historical_df['high'],
            low=historical_df['low'],
            close=historical_df['close'],
            name='Historical',
            increasing_line_color='green',
            decreasing_line_color='red'
        ),
        row=1, col=1
    )

    # 预测K线
    fig.add_trace(
        go.Candlestick(
            x=pred_df.index,
            open=pred_df['open'],
            high=pred_df['high'],
            low=pred_df['low'],
            close=pred_df['close'],
            name='Prediction',
            increasing_line_color='lightgreen',
            decreasing_line_color='lightcoral',
            opacity=0.7
        ),
        row=1, col=1
    )

    # 成交量
    fig.add_trace(
        go.Bar(
            x=historical_df['timestamps'],
            y=historical_df['volume'],
            name='Historical Volume',
            marker_color='blue',
            opacity=0.5
        ),
        row=2, col=1
    )

    fig.add_trace(
        go.Bar(
            x=pred_df.index,
            y=pred_df['volume'],
            name='Predicted Volume',
            marker_color='red',
            opacity=0.5
        ),
        row=2, col=1
    )

    # 更新布局
    fig.update_layout(
        title='Kronos K-line Prediction',
        xaxis_rangeslider_visible=False,
        height=800,
        showlegend=True
    )

    fig.show()

5. 性能优化技巧

5.1 GPU加速

# 检查GPU可用性
import torch

if torch.cuda.is_available():
    device = f"cuda:{torch.cuda.current_device()}"
    print(f"使用GPU: {torch.cuda.get_device_name()}")
else:
    device = "cpu"
    print("GPU不可用,使用CPU")

# 优化GPU内存使用
torch.cuda.empty_cache()  # 清理缓存

# 使用混合精度加速
with torch.cuda.amp.autocast():
    predictions = predictor.predict(...)

5.2 批量处理优化

def optimize_batch_size(data_list, max_memory_gb=8):
    """
    根据GPU内存自动优化批量大小
    """
    import torch

    if not torch.cuda.is_available():
        return len(data_list)  # CPU模式,处理所有

    # 估算每个样本的内存占用
    sample_size = data_list[0].memory_usage(deep=True).sum() / 1e9  # GB

    # 计算最优批量大小
    optimal_batch_size = int(max_memory_gb * 0.7 / sample_size)  # 留30%余量

    return min(optimal_batch_size, len(data_list))

5.3 缓存策略

from functools import lru_cache
import hashlib

class CachedKronosPredictor:
    """带缓存的Kronos预测器"""

    def __init__(self, predictor):
        self.predictor = predictor
        self.cache = {}

    def _get_cache_key(self, df, pred_len, T, top_p, sample_count):
        """生成缓存键"""
        # 使用数据的哈希值作为键
        data_hash = hashlib.md5(
            df.to_string().encode() +
            f"{pred_len}{T}{top_p}{sample_count}".encode()
        ).hexdigest()
        return data_hash

    def predict_with_cache(self, df, x_timestamp, y_timestamp,
                           pred_len, T=1.0, top_p=0.9, sample_count=1):
        """带缓存的预测"""
        cache_key = self._get_cache_key(df, pred_len, T, top_p, sample_count)

        if cache_key in self.cache:
            print("使用缓存结果")
            return self.cache[cache_key]

        # 进行预测
        result = self.predictor.predict(
            df=df,
            x_timestamp=x_timestamp,
            y_timestamp=y_timestamp,
            pred_len=pred_len,
            T=T,
            top_p=top_p,
            sample_count=sample_count
        )

        # 缓存结果
        self.cache[cache_key] = result

        return result

6. 常见问题和解决方案

6.1 内存不足

# 问题:处理大批量数据时内存不足
# 解决方案:分批处理

def process_large_dataset(df_list, batch_size=10):
    """分批处理大数据集"""
    all_predictions = []

    for i in range(0, len(df_list), batch_size):
        batch = df_list[i:i+batch_size]
        batch_predictions = predictor.predict_batch(batch, ...)
        all_predictions.extend(batch_predictions)

        # 清理内存
        torch.cuda.empty_cache()

    return all_predictions

6.2 预测不稳定

# 问题:多次预测结果差异较大
# 解决方案:增加采样次数,降低温度

def stable_prediction(df, x_timestamp, y_timestamp, pred_len):
    """稳定的预测策略"""
    return predictor.predict(
        df=df,
        x_timestamp=x_timestamp,
        y_timestamp=y_timestamp,
        pred_len=pred_len,
        T=0.7,          # 降低温度
        top_p=0.95,     # 提高阈值
        sample_count=10 # 增加采样次数
    )

6.3 长期预测失真

# 问题:长期预测偏离太大
# 解决方案:递归预测,每次预测较短周期

def recursive_long_term_prediction(df, total_pred_len, step_len=24):
    """
    递归进行长期预测

    Args:
        df: 初始数据
        total_pred_len: 总预测长度
        step_len: 每步预测长度
    """
    all_predictions = []
    current_df = df.copy()

    steps = total_pred_len // step_len

    for step in range(steps):
        # 预测一步
        x_df = current_df.iloc[-200:][['open', 'high', 'low', 'close', 'volume']]
        x_timestamp = current_df.iloc[-200:]['timestamps']

        # 生成时间戳
        last_time = x_timestamp.iloc[-1]
        y_timestamp = pd.date_range(
            start=last_time + pd.Timedelta(1, unit='h'),
            periods=step_len,
            freq='1h'
        )

        # 预测
        pred = predictor.predict(
            df=x_df,
            x_timestamp=x_timestamp,
            y_timestamp=y_timestamp,
            pred_len=step_len,
            T=0.9,
            top_p=0.9,
            sample_count=5
        )

        # 将预测结果添加到数据中
        pred['timestamps'] = y_timestamp
        current_df = pd.concat([current_df, pred], ignore_index=True)
        all_predictions.append(pred)

    return pd.concat(all_predictions)

7. 总结

Kronos为金融K线预测提供了强大而灵活的工具。通过合理配置参数和使用策略,可以在不同场景下获得良好的预测效果。关键要点:

  1. 数据准备:确保数据格式正确,处理缺失值
  2. 参数调优:根据具体需求调整温度、采样参数
  3. 批量处理:利用批量预测提高效率
  4. 稳定性:通过多路径平均提高预测稳定性
  5. 性能优化:合理使用GPU和缓存策略

记住,Kronos是一个概率模型,预测结果应该作为决策支持而非唯一依据。在实际应用中,建议结合其他技术分析工具和风险管理策略。