Kronos API 文档
概述
Kronos是一个专门为金融市场K线数据设计的基础模型(Foundation Model)。它是第一个开源的金融K线序列预训练模型,基于来自全球45个交易所的数据训练而成。Kronos采用创新的两阶段框架:首先使用专门的分词器将连续的多维K线数据(OHLCV)量化为分层离散令牌,然后使用大型自回归Transformer模型进行预训练。
核心组件
1. KronosTokenizer - K线数据分词器
KronosTokenizer用于将原始K线数据转换为模型可以理解的离散令牌。
类定义
class KronosTokenizer(nn.Module, PyTorchModelHubMixin)
初始化参数
| 参数 | 类型 | 描述 |
|---|---|---|
d_in |
int | 输入维度(默认为6:OHLCV + Amount) |
d_model |
int | 模型维度 |
n_heads |
int | 注意力头数 |
ff_dim |
int | 前馈网络维度 |
n_enc_layers |
int | 编码器层数 |
n_dec_layers |
int | 解码器层数 |
ffn_dropout_p |
float | 前馈网络的Dropout概率 |
attn_dropout_p |
float | 注意力机制的Dropout概率 |
resid_dropout_p |
float | 残差连接的Dropout概率 |
s1_bits |
int | 第一阶段量化的比特数 |
s2_bits |
int | 第二阶段量化的比特数 |
beta |
float | BSQuantizer的Beta参数 |
gamma0 |
float | BSQuantizer的Gamma0参数 |
gamma |
float | BSQuantizer的Gamma参数 |
zeta |
float | BSQuantizer的Zeta参数 |
group_size |
int | BSQuantizer的组大小参数 |
主要方法
from_pretrained(model_name: str) -> KronosTokenizer
从Hugging Face Hub加载预训练的分词器。
tokenizer = KronosTokenizer.from_pretrained("NeoQuasar/Kronos-Tokenizer-base")
encode(x: torch.Tensor, half: bool = False) -> torch.Tensor
将输入数据编码为量化索引。
- 输入:
x- 形状为(batch_size, seq_len, d_in)的输入张量 - 输出: 量化后的索引张量
decode(z_indices: torch.Tensor, half: bool = False) -> torch.Tensor
将量化索引解码回原始空间。
- 输入:
z_indices- 量化索引张量 - 输出: 重构的数据张量
2. Kronos - 主预测模型
Kronos是用于K线序列预测的主模型。
类定义
class Kronos(nn.Module, PyTorchModelHubMixin)
初始化参数
| 参数 | 类型 | 描述 |
|---|---|---|
s1_bits |
int | 第一阶段令牌的比特数 |
s2_bits |
int | 第二阶段令牌的比特数 |
n_layers |
int | Transformer层数 |
d_model |
int | 模型维度 |
n_heads |
int | 注意力头数 |
learn_te |
bool | 是否学习时间嵌入 |
ff_dim |
int | 前馈网络维度 |
ffn_dropout_p |
float | 前馈网络的Dropout概率 |
attn_dropout_p |
float | 注意力机制的Dropout概率 |
resid_dropout_p |
float | 残差连接的Dropout概率 |
token_dropout_p |
float | 令牌级别的Dropout概率 |
主要方法
from_pretrained(model_name: str) -> Kronos
从Hugging Face Hub加载预训练模型。
model = Kronos.from_pretrained("NeoQuasar/Kronos-small")
forward(s1_ids, s2_ids, stamp=None, padding_mask=None, use_teacher_forcing=False, s1_targets=None)
模型的前向传播。
- 输入:
s1_ids: s1令牌ID张量,形状[batch_size, seq_len]s2_ids: s2令牌ID张量,形状[batch_size, seq_len]stamp: 时间戳张量(可选)padding_mask: 填充掩码(可选)use_teacher_forcing: 是否使用教师强制s1_targets: 教师强制的目标s1令牌ID
- 输出:
s1_logits: s1令牌预测的logitss2_logits: 基于s1条件的s2令牌预测的logits
3. KronosPredictor - 高级预测接口
KronosPredictor提供了一个简化的接口,用于处理原始数据到预测结果的完整流程。
类定义
class KronosPredictor
初始化参数
| 参数 | 类型 | 默认值 | 描述 |
|---|---|---|---|
model |
Kronos | - | Kronos模型实例 |
tokenizer |
KronosTokenizer | - | KronosTokenizer实例 |
device |
str | “cuda:0” | 计算设备 |
max_context |
int | 512 | 最大上下文长度 |
clip |
float | 5 | 数据裁剪阈值 |
主要方法
predict(df, x_timestamp, y_timestamp, pred_len, T=1.0, top_k=0, top_p=0.9, sample_count=1, verbose=True)
对单个时间序列进行预测。
参数:
df(pd.DataFrame): 包含历史K线数据的DataFrame,必须包含['open', 'high', 'low', 'close']列x_timestamp(pd.Series): 历史数据对应的时间戳y_timestamp(pd.Series): 要预测的未来时间戳pred_len(int): 预测长度T(float): 采样温度,控制预测的随机性(默认1.0)top_k(int): Top-k过滤阈值(默认0,不启用)top_p(float): Top-p(核采样)阈值(默认0.9)sample_count(int): 生成并平均的预测路径数(默认1)verbose(bool): 是否显示进度(默认True)
返回值:
pd.DataFrame: 包含预测值的DataFrame,包含open,high,low,close,volume,amount列
示例:
pred_df = predictor.predict(
df=historical_df,
x_timestamp=historical_timestamps,
y_timestamp=future_timestamps,
pred_len=24,
T=1.0,
top_p=0.9,
sample_count=3
)
predict_batch(df_list, x_timestamp_list, y_timestamp_list, pred_len, T=1.0, top_k=0, top_p=0.9, sample_count=1, verbose=True)
对多个时间序列进行并行批量预测。
参数:
df_list(List[pd.DataFrame]): 输入DataFrame列表x_timestamp_list(List[pd.Series]): 历史时间戳列表y_timestamp_list(List[pd.Series]): 未来时间戳列表pred_len(int): 预测长度(所有序列必须相同)- 其他参数与
predict方法相同
返回值:
List[pd.DataFrame]: 预测结果列表,顺序与输入相同
注意事项:
- 所有序列必须具有相同的历史长度(lookback窗口)
- 所有序列必须具有相同的预测长度
- 批量预测利用GPU并行计算,效率更高
示例:
# 准备多个数据集
df_list = [df1, df2, df3]
x_timestamp_list = [x_ts1, x_ts2, x_ts3]
y_timestamp_list = [y_ts1, y_ts2, y_ts3]
# 批量预测
predictions = predictor.predict_batch(
df_list=df_list,
x_timestamp_list=x_timestamp_list,
y_timestamp_list=y_timestamp_list,
pred_len=24,
T=1.0,
top_p=0.9,
sample_count=1
)
模型规格
可用模型
| 模型 | 分词器 | 上下文长度 | 参数量 | Hugging Face地址 |
|---|---|---|---|---|
| Kronos-mini | Kronos-Tokenizer-2k | 2048 | 4.1M | NeoQuasar/Kronos-mini |
| Kronos-small | Kronos-Tokenizer-base | 512 | 24.7M | NeoQuasar/Kronos-small |
| Kronos-base | Kronos-Tokenizer-base | 512 | 102.3M | NeoQuasar/Kronos-base |
| Kronos-large | Kronos-Tokenizer-base | 512 | 499.2M | (未开源) |
数据格式要求
输入数据格式
KronosPredictor期望的输入DataFrame必须包含以下列:
必需列:
open: 开盘价high: 最高价low: 最低价close: 收盘价
可选列:
volume: 成交量(如果缺失会自动填充0)amount: 成交额(如果缺失会基于价格和成交量估算)
时间戳格式
时间戳应为pd.DatetimeIndex或pd.Series类型,包含datetime对象。模型会自动提取以下时间特征:
minute: 分钟(0-59)hour: 小时(0-23)weekday: 星期几(0-6)day: 日期(1-31)month: 月份(1-12)
采样参数说明
温度(Temperature)- T
控制预测的随机性:
T < 1.0: 更确定性的预测,倾向于高概率事件T = 1.0: 标准采样T > 1.0: 更随机的预测,增加多样性
Top-k 采样
只从概率最高的k个令牌中采样:
top_k = 0: 不使用top-k过滤top_k > 0: 只考虑概率最高的k个选项
Top-p 采样(核采样)
从累积概率达到p的最小令牌集合中采样:
top_p = 0.9: 从累积概率达到90%的令牌中采样top_p = 1.0: 考虑所有令牌
Sample Count
生成多条预测路径并取平均,提高预测的稳定性:
sample_count = 1: 单路径预测sample_count > 1: 多路径平均,更稳定但计算成本更高
性能优化建议
-
批量处理: 使用
predict_batch方法同时处理多个序列,充分利用GPU并行计算 - 上下文长度:
- Kronos-small和Kronos-base的最大上下文为512
- 建议输入长度不超过此限制以获得最佳性能
- GPU使用:
- 确保模型和数据在同一设备上
- 使用CUDA加速计算
- 采样策略:
- 对于确定性预测:使用低温度(T=0.5-0.8)
- 对于探索性预测:使用高温度(T=1.2-1.5)
- 结合top-p采样(0.9-0.95)通常效果较好
错误处理
KronosPredictor会进行以下验证:
- 检查必需的价格列是否存在
- 检查NaN值
- 自动处理缺失的volume/amount列
- 验证时间戳长度与数据长度的一致性
- 批量预测时验证所有序列的长度一致性
注意事项
- 模型专门针对金融K线数据设计,不适用于其他类型的时间序列
- 预测结果会自动进行反归一化,返回原始尺度的值
- 模型使用分层量化技术,可能存在一定的重构误差
- 长期预测(>100步)的不确定性会显著增加