🇨🇳 简体中文
🇺🇸 English
🇯🇵 日本語
Skip to the content.

Kronos API 文档

概述

Kronos是一个专门为金融市场K线数据设计的基础模型(Foundation Model)。它是第一个开源的金融K线序列预训练模型,基于来自全球45个交易所的数据训练而成。Kronos采用创新的两阶段框架:首先使用专门的分词器将连续的多维K线数据(OHLCV)量化为分层离散令牌,然后使用大型自回归Transformer模型进行预训练。

核心组件

1. KronosTokenizer - K线数据分词器

KronosTokenizer用于将原始K线数据转换为模型可以理解的离散令牌。

类定义

class KronosTokenizer(nn.Module, PyTorchModelHubMixin)

初始化参数

参数 类型 描述
d_in int 输入维度(默认为6:OHLCV + Amount)
d_model int 模型维度
n_heads int 注意力头数
ff_dim int 前馈网络维度
n_enc_layers int 编码器层数
n_dec_layers int 解码器层数
ffn_dropout_p float 前馈网络的Dropout概率
attn_dropout_p float 注意力机制的Dropout概率
resid_dropout_p float 残差连接的Dropout概率
s1_bits int 第一阶段量化的比特数
s2_bits int 第二阶段量化的比特数
beta float BSQuantizer的Beta参数
gamma0 float BSQuantizer的Gamma0参数
gamma float BSQuantizer的Gamma参数
zeta float BSQuantizer的Zeta参数
group_size int BSQuantizer的组大小参数

主要方法

from_pretrained(model_name: str) -> KronosTokenizer

从Hugging Face Hub加载预训练的分词器。

tokenizer = KronosTokenizer.from_pretrained("NeoQuasar/Kronos-Tokenizer-base")
encode(x: torch.Tensor, half: bool = False) -> torch.Tensor

将输入数据编码为量化索引。

decode(z_indices: torch.Tensor, half: bool = False) -> torch.Tensor

将量化索引解码回原始空间。

2. Kronos - 主预测模型

Kronos是用于K线序列预测的主模型。

类定义

class Kronos(nn.Module, PyTorchModelHubMixin)

初始化参数

参数 类型 描述
s1_bits int 第一阶段令牌的比特数
s2_bits int 第二阶段令牌的比特数
n_layers int Transformer层数
d_model int 模型维度
n_heads int 注意力头数
learn_te bool 是否学习时间嵌入
ff_dim int 前馈网络维度
ffn_dropout_p float 前馈网络的Dropout概率
attn_dropout_p float 注意力机制的Dropout概率
resid_dropout_p float 残差连接的Dropout概率
token_dropout_p float 令牌级别的Dropout概率

主要方法

from_pretrained(model_name: str) -> Kronos

从Hugging Face Hub加载预训练模型。

model = Kronos.from_pretrained("NeoQuasar/Kronos-small")
forward(s1_ids, s2_ids, stamp=None, padding_mask=None, use_teacher_forcing=False, s1_targets=None)

模型的前向传播。

3. KronosPredictor - 高级预测接口

KronosPredictor提供了一个简化的接口,用于处理原始数据到预测结果的完整流程。

类定义

class KronosPredictor

初始化参数

参数 类型 默认值 描述
model Kronos - Kronos模型实例
tokenizer KronosTokenizer - KronosTokenizer实例
device str “cuda:0” 计算设备
max_context int 512 最大上下文长度
clip float 5 数据裁剪阈值

主要方法

predict(df, x_timestamp, y_timestamp, pred_len, T=1.0, top_k=0, top_p=0.9, sample_count=1, verbose=True)

对单个时间序列进行预测。

参数:

返回值:

示例:

pred_df = predictor.predict(
    df=historical_df,
    x_timestamp=historical_timestamps,
    y_timestamp=future_timestamps,
    pred_len=24,
    T=1.0,
    top_p=0.9,
    sample_count=3
)
predict_batch(df_list, x_timestamp_list, y_timestamp_list, pred_len, T=1.0, top_k=0, top_p=0.9, sample_count=1, verbose=True)

对多个时间序列进行并行批量预测。

参数:

返回值:

注意事项:

示例:

# 准备多个数据集
df_list = [df1, df2, df3]
x_timestamp_list = [x_ts1, x_ts2, x_ts3]
y_timestamp_list = [y_ts1, y_ts2, y_ts3]

# 批量预测
predictions = predictor.predict_batch(
    df_list=df_list,
    x_timestamp_list=x_timestamp_list,
    y_timestamp_list=y_timestamp_list,
    pred_len=24,
    T=1.0,
    top_p=0.9,
    sample_count=1
)

模型规格

可用模型

模型 分词器 上下文长度 参数量 Hugging Face地址
Kronos-mini Kronos-Tokenizer-2k 2048 4.1M NeoQuasar/Kronos-mini
Kronos-small Kronos-Tokenizer-base 512 24.7M NeoQuasar/Kronos-small
Kronos-base Kronos-Tokenizer-base 512 102.3M NeoQuasar/Kronos-base
Kronos-large Kronos-Tokenizer-base 512 499.2M (未开源)

数据格式要求

输入数据格式

KronosPredictor期望的输入DataFrame必须包含以下列:

必需列:

可选列:

时间戳格式

时间戳应为pd.DatetimeIndexpd.Series类型,包含datetime对象。模型会自动提取以下时间特征:

采样参数说明

温度(Temperature)- T

控制预测的随机性:

Top-k 采样

只从概率最高的k个令牌中采样:

Top-p 采样(核采样)

从累积概率达到p的最小令牌集合中采样:

Sample Count

生成多条预测路径并取平均,提高预测的稳定性:

性能优化建议

  1. 批量处理: 使用predict_batch方法同时处理多个序列,充分利用GPU并行计算

  2. 上下文长度:
    • Kronos-small和Kronos-base的最大上下文为512
    • 建议输入长度不超过此限制以获得最佳性能
  3. GPU使用:
    • 确保模型和数据在同一设备上
    • 使用CUDA加速计算
  4. 采样策略:
    • 对于确定性预测:使用低温度(T=0.5-0.8)
    • 对于探索性预测:使用高温度(T=1.2-1.5)
    • 结合top-p采样(0.9-0.95)通常效果较好

错误处理

KronosPredictor会进行以下验证:

注意事项

  1. 模型专门针对金融K线数据设计,不适用于其他类型的时间序列
  2. 预测结果会自动进行反归一化,返回原始尺度的值
  3. 模型使用分层量化技术,可能存在一定的重构误差
  4. 长期预测(>100步)的不确定性会显著增加