Kronos 使用指南
1. 快速开始
1.1 安装依赖
首先确保您的Python版本≥3.10,并已获取Kronos源码(下面的命令假设您位于源码目录的上一级),然后安装必要的依赖:
cd Kronos
pip install -r requirements.txt
主要依赖包括:
- torch: PyTorch深度学习框架
- pandas: 数据处理
- numpy: 数值计算
- huggingface_hub: 模型下载和管理
- einops: 张量操作
- matplotlib: 可视化(可选)

此外,3.1节的示例用到 TA-Lib,4.2节的示例用到 plotly;如果 requirements.txt 中未包含它们,需要另行安装。
1.2 第一个预测示例
以下是使用Kronos进行K线预测的最简单示例:
import pandas as pd
import torch
from model import Kronos, KronosTokenizer, KronosPredictor

# Step 1: load the pretrained tokenizer and model
tokenizer = KronosTokenizer.from_pretrained("NeoQuasar/Kronos-Tokenizer-base")
model = Kronos.from_pretrained("NeoQuasar/Kronos-small")

# Step 2: build the predictor. The device is chosen automatically so the
# example also runs on machines without a GPU.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
predictor = KronosPredictor(
    model=model,
    tokenizer=tokenizer,
    device=device,
    max_context=512,
)

# Step 3: prepare the data
# Assumes a CSV file with a 'timestamps' column and open/high/low/close/volume columns.
df = pd.read_csv("your_data.csv")
df['timestamps'] = pd.to_datetime(df['timestamps'])

# History window and prediction length
lookback = 200  # use 200 historical candles
pred_len = 24   # predict the next 24 candles

# The target timestamps are sliced from the same file, so it must contain at
# least lookback + pred_len rows; otherwise y_timestamp would silently be
# shorter than pred_len.
if len(df) < lookback + pred_len:
    raise ValueError(
        f"need at least lookback + pred_len = {lookback + pred_len} rows, got {len(df)}"
    )

# Build the model inputs
x_df = df.iloc[:lookback][['open', 'high', 'low', 'close', 'volume']]
x_timestamp = df.iloc[:lookback]['timestamps']
y_timestamp = df.iloc[lookback:lookback + pred_len]['timestamps']

# Step 4: run the prediction
pred_df = predictor.predict(
    df=x_df,
    x_timestamp=x_timestamp,
    y_timestamp=y_timestamp,
    pred_len=pred_len,
    T=1.0,
    top_p=0.9,
    sample_count=1,
)
print("预测结果:")
print(pred_df)
2. 详细使用教程
2.1 数据准备
2.1.1 数据格式要求
Kronos期望的输入数据格式如下:
import numpy as np

# Required columns
required_columns = ['open', 'high', 'low', 'close']
# Optional columns (filled in automatically when missing)
optional_columns = ['volume', 'amount']

# Example data structure. Every column must have the same length as
# 'timestamps' (pd.DataFrame raises ValueError otherwise). Instead of short
# placeholder lists, build a synthetic, internally consistent 500-bar hourly
# series where high >= max(open, close) and low <= min(open, close).
n_bars = 500
rng = np.random.default_rng(42)
close = 100.0 * np.exp(np.cumsum(rng.normal(0.0, 0.01, n_bars)))  # strictly positive prices
open_ = np.concatenate(([100.0], close[:-1]))  # each bar opens at the previous close
high = np.maximum(open_, close) * (1 + rng.uniform(0.0, 0.005, n_bars))
low = np.minimum(open_, close) * (1 - rng.uniform(0.0, 0.005, n_bars))
volume = rng.integers(900, 1200, n_bars).astype(float)

data = {
    'timestamps': pd.date_range('2024-01-01', periods=n_bars, freq='1h'),
    'open': open_,
    'high': high,
    'low': low,
    'close': close,
    'volume': volume,
    'amount': volume * close,  # turnover approximated as volume x close
}
df = pd.DataFrame(data)
2.1.2 处理缺失数据
如果您的数据缺少volume或amount列:
# Method 1: let the predictor handle it (this guide states missing columns
# are filled with 0). Pass the same arguments as in the quick-start example;
# only the input frame lacks the volume/amount columns.
pred_df = predictor.predict(
    df=df_without_volume,
    x_timestamp=x_timestamp,
    y_timestamp=y_timestamp,
    pred_len=pred_len,
)

# Method 2: fill the columns yourself. Only fill what is actually missing so
# real volume/amount data is never overwritten.
if 'volume' not in df.columns:
    df['volume'] = 0.0  # or use another fill strategy
if 'amount' not in df.columns:
    # approximate turnover as volume x the bar's average price
    df['amount'] = df['volume'] * df[['open', 'high', 'low', 'close']].mean(axis=1)
2.2 单序列预测
2.2.1 基础预测
def predict_single_series(df, lookback, pred_len):
    """Predict the next ``pred_len`` candles of a single series.

    The first ``lookback`` rows of ``df`` are the model context. The
    timestamps of the following ``pred_len`` rows are the prediction targets.

    Args:
        df: DataFrame with a 'timestamps' column and
            open/high/low/close/volume columns.
        lookback: length of the history window.
        pred_len: number of future steps to predict.

    Returns:
        The DataFrame returned by ``predictor.predict`` (5 sampled paths,
        averaged).

    Raises:
        ValueError: if ``df`` has fewer than ``lookback + pred_len`` rows. In
            that case ``y_timestamp`` would be shorter than ``pred_len``.
    """
    if len(df) < lookback + pred_len:
        raise ValueError(
            f"df needs at least lookback + pred_len = {lookback + pred_len} rows, "
            f"got {len(df)}"
        )

    # Prepare the inputs
    x_df = df.iloc[:lookback][['open', 'high', 'low', 'close', 'volume']]
    x_timestamp = df.iloc[:lookback]['timestamps']
    y_timestamp = df.iloc[lookback:lookback + pred_len]['timestamps']

    # Run the prediction
    pred_df = predictor.predict(
        df=x_df,
        x_timestamp=x_timestamp,
        y_timestamp=y_timestamp,
        pred_len=pred_len,
        T=1.0,           # sampling temperature
        top_p=0.9,       # nucleus-sampling threshold
        sample_count=5,  # average over 5 sampled paths
    )
    return pred_df
2.2.2 调整预测参数
不同的参数组合适用于不同的场景:
# Scenario 1: short-term precise prediction (low temperature, few samples).
# y_timestamp must contain exactly pred_len entries, so it is sliced to the
# horizon of each scenario instead of reusing the 24-step one.
short_len = 12  # short horizon
conservative_pred = predictor.predict(
    df=x_df,
    x_timestamp=x_timestamp,
    y_timestamp=df.iloc[lookback:lookback + short_len]['timestamps'],
    pred_len=short_len,
    T=0.5,           # low temperature -> more deterministic
    top_p=0.95,      # high threshold
    sample_count=1,
)

# Scenario 2: long-term trend exploration (high temperature, many samples).
# df needs at least lookback + 48 rows for these target timestamps.
long_len = 48  # long horizon
exploratory_pred = predictor.predict(
    df=x_df,
    x_timestamp=x_timestamp,
    y_timestamp=df.iloc[lookback:lookback + long_len]['timestamps'],
    pred_len=long_len,
    T=1.2,            # high temperature -> more diversity
    top_p=0.9,        # standard threshold
    sample_count=10,  # average over several paths
)
# Scenario 3: ensemble prediction (several parameter sets)
def ensemble_predict(x_df, x_timestamp, y_timestamp, pred_len):
    """Average the forecasts obtained with several sampling configurations.

    Each (temperature, top_p, sample_count) setting is run once through
    ``predictor.predict``. The resulting frames are stacked and averaged per
    index label.
    """
    sampling_settings = (
        dict(T=0.8, top_p=0.95, sample_count=3),
        dict(T=1.0, top_p=0.9, sample_count=3),
        dict(T=1.2, top_p=0.85, sample_count=3),
    )
    member_forecasts = [
        predictor.predict(
            df=x_df,
            x_timestamp=x_timestamp,
            y_timestamp=y_timestamp,
            pred_len=pred_len,
            **settings,
        )
        for settings in sampling_settings
    ]
    stacked = pd.concat(member_forecasts)
    return stacked.groupby(level=0).mean()
2.3 批量预测
对多个资产或时间段进行并行预测:
def batch_predict_multiple_assets(asset_data_dict, lookback, pred_len):
    """Forecast several assets with a single ``predict_batch`` call.

    Args:
        asset_data_dict: dict mapping asset name -> DataFrame with a
            'timestamps' column and open/high/low/close/volume columns.
        lookback: history window length, taken from the start of each frame.
        pred_len: number of future steps to predict.

    Returns:
        dict mapping each asset name to its prediction. The order follows
        ``asset_data_dict``.
    """
    ohlcv = ['open', 'high', 'low', 'close', 'volume']
    names = list(asset_data_dict)
    frames = list(asset_data_dict.values())

    history_frames = [frame.iloc[:lookback][ohlcv] for frame in frames]
    history_times = [frame.iloc[:lookback]['timestamps'] for frame in frames]
    target_times = [
        frame.iloc[lookback:lookback + pred_len]['timestamps'] for frame in frames
    ]

    batch_results = predictor.predict_batch(
        df_list=history_frames,
        x_timestamp_list=history_times,
        y_timestamp_list=target_times,
        pred_len=pred_len,
        T=1.0,
        top_p=0.9,
        sample_count=3,
        verbose=True,
    )
    return dict(zip(names, batch_results))
# Usage example: btc_df, eth_df and bnb_df are DataFrames (with a 'timestamps'
# column and OHLCV columns) that you have loaded beforehand.
asset_data = dict(
    zip(
        ['BTC/USDT', 'ETH/USDT', 'BNB/USDT'],
        [btc_df, eth_df, bnb_df],
    )
)
predictions = batch_predict_multiple_assets(asset_data, lookback=200, pred_len=24)
2.4 滑动窗口预测
实现滑动窗口进行连续预测:
def sliding_window_prediction(df, window_size=200, pred_len=24, step=1):
    """Slide a fixed-size history window over ``df`` and forecast after each position.

    Args:
        df: full history with a 'timestamps' column and OHLCV columns.
        window_size: number of rows in each history window.
        pred_len: number of steps forecast from each window.
        step: how many rows the window advances between forecasts.

    Returns:
        list of dicts with 'window_start', 'window_end' (row positions, end
        exclusive) and 'prediction'. The list is empty when ``df`` is shorter
        than ``window_size + pred_len``.
    """
    ohlcv = ['open', 'high', 'low', 'close', 'volume']
    n_windows = (len(df) - window_size - pred_len) // step + 1

    results = []
    for start in (k * step for k in range(n_windows)):
        end = start + window_size
        window = df.iloc[start:end]
        forecast = predictor.predict(
            df=window[ohlcv],
            x_timestamp=window['timestamps'],
            y_timestamp=df.iloc[end:end + pred_len]['timestamps'],
            pred_len=pred_len,
            T=1.0,
            top_p=0.9,
            sample_count=1,
            verbose=False,
        )
        results.append({'window_start': start, 'window_end': end, 'prediction': forecast})
    return results
3. 高级应用示例
3.1 结合技术指标
将Kronos预测与传统技术指标结合:
import talib
import numpy as np
def predict_with_indicators(df, lookback, pred_len, overbought=70, oversold=30,
                            adjustment=0.02):
    """Kronos forecast adjusted by the RSI of the last historical bar.

    Indicators are computed with TA-Lib on a copy of the history window
    ``df.iloc[:lookback]``. The caller's DataFrame is not modified, and rows
    after the history window do not feed into the indicators.

    Args:
        df: DataFrame with a 'timestamps' column and OHLCV columns. It must
            have at least ``lookback + pred_len`` rows because the target
            timestamps are read from it.
        lookback: history window length.
        pred_len: number of steps to predict.
        overbought: if the last RSI is above this, predicted prices are
            scaled down.
        oversold: if the last RSI is below this, predicted prices are
            scaled up.
        adjustment: relative size of the scaling. The default 0.02 gives
            x0.98 / x1.02.

    Returns:
        The prediction DataFrame, with open/high/low/close possibly scaled.
    """
    history = df.iloc[:lookback].copy()
    # TA-Lib expects float64 input arrays
    close = history['close'].to_numpy(dtype=float)

    # Technical indicators. SMA/MACD are shown for illustration; only RSI
    # drives the adjustment below.
    history['SMA_20'] = talib.SMA(close, timeperiod=20)
    history['RSI'] = talib.RSI(close, timeperiod=14)
    history['MACD'], history['MACD_signal'], _ = talib.MACD(close)

    # Kronos prediction
    x_df = history[['open', 'high', 'low', 'close', 'volume']]
    x_timestamp = history['timestamps']
    y_timestamp = df.iloc[lookback:lookback + pred_len]['timestamps']
    pred_df = predictor.predict(
        df=x_df,
        x_timestamp=x_timestamp,
        y_timestamp=y_timestamp,
        pred_len=pred_len,
        T=1.0,
        top_p=0.9,
    )

    # Adjust the forecast using the indicator. A NaN RSI (history shorter
    # than the RSI period) fails both comparisons, so nothing is changed.
    last_rsi = history['RSI'].iloc[-1]
    price_cols = ['open', 'high', 'low', 'close']
    if last_rsi > overbought:      # overbought: lower predicted prices
        pred_df[price_cols] *= 1 - adjustment
    elif last_rsi < oversold:      # oversold: raise predicted prices
        pred_df[price_cols] *= 1 + adjustment
    return pred_df
3.2 构建交易信号
基于Kronos预测生成交易信号:
class KronosTradingStrategy:
    """Long-only trading strategy driven by Kronos forecasts.

    Args:
        predictor: object exposing ``predict(df=..., x_timestamp=...,
            y_timestamp=..., pred_len=..., T=..., top_p=..., sample_count=...)``
            that returns a DataFrame with 'close', 'high' and 'low' columns.
        lookback: number of most recent bars fed to the predictor.
        pred_len: number of future bars to forecast.
    """

    def __init__(self, predictor, lookback=200, pred_len=24):
        self.predictor = predictor
        self.lookback = lookback
        self.pred_len = pred_len

    def _future_timestamps(self, x_timestamp):
        """Return the ``pred_len`` timestamps following the last one in ``x_timestamp``.

        The bar spacing is inferred from the history rather than assumed to
        be one hour, so 15-minute or daily data get correctly spaced targets.
        If pandas cannot infer a regular frequency, the median gap between
        consecutive bars is used instead.
        """
        freq = pd.infer_freq(x_timestamp)
        if freq is None:
            freq = x_timestamp.diff().median()
        # Start at the last known bar and drop it, so the first target lies
        # exactly one period after it.
        return pd.date_range(
            start=x_timestamp.iloc[-1], periods=self.pred_len + 1, freq=freq
        )[1:]

    def generate_signal(self, df):
        """Generate a trading signal from the latest ``lookback`` bars of ``df``.

        Returns:
            1: buy signal
            -1: sell signal
            0: no signal (hold)
        """
        # Prepare the inputs
        x_df = df.iloc[-self.lookback:][['open', 'high', 'low', 'close', 'volume']]
        x_timestamp = df.iloc[-self.lookback:]['timestamps']
        y_timestamp = self._future_timestamps(x_timestamp)

        # Predict
        pred_df = self.predictor.predict(
            df=x_df,
            x_timestamp=x_timestamp,
            y_timestamp=y_timestamp,
            pred_len=self.pred_len,
            T=0.8,           # lower temperature, more conservative
            top_p=0.95,
            sample_count=5,  # average over several paths
        )

        # Analyse the forecast
        current_price = x_df.iloc[-1]['close']
        pred_mean_price = pred_df['close'].mean()
        pred_max_price = pred_df['high'].max()
        pred_min_price = pred_df['low'].min()

        # Expected return and risk, relative to the current price
        expected_return = (pred_mean_price - current_price) / current_price
        upside_potential = (pred_max_price - current_price) / current_price
        downside_risk = (current_price - pred_min_price) / current_price

        # Produce the signal
        if expected_return > 0.02 and upside_potential > downside_risk * 1.5:
            return 1   # buy
        elif expected_return < -0.02 or downside_risk > 0.05:
            return -1  # sell
        else:
            return 0   # hold

    def backtest(self, df, initial_capital=10000, step=24):
        """Simple all-in / all-out backtest (no fees or slippage).

        Every ``step`` bars, a signal is generated from all data up to and
        including the current bar. On a buy signal while flat, the whole
        capital is invested at the current close. On a sell signal while
        long, the whole position is liquidated at the current close.

        Args:
            df: DataFrame with a 'timestamps' column and OHLCV columns.
            initial_capital: starting cash.
            step: number of bars between signal evaluations (default 24, as before).

        Returns:
            dict with 'trades', 'final_value', 'total_return' and 'num_trades'.
        """
        capital = initial_capital
        position = 0
        trades = []

        # Walk forward through the data
        for i in range(self.lookback, len(df) - self.pred_len, step):
            current_df = df.iloc[:i + 1]
            signal = self.generate_signal(current_df)
            current_price = current_df.iloc[-1]['close']
            timestamp = current_df.iloc[-1]['timestamps']

            if signal == 1 and position == 0:
                # Buy with all available capital
                position = capital / current_price
                capital = 0
                trades.append({
                    'timestamp': timestamp,
                    'action': 'BUY',
                    'price': current_price,
                    'quantity': position,
                })
            elif signal == -1 and position > 0:
                # Sell the whole position
                sold_quantity = position
                capital = position * current_price
                position = 0
                trades.append({
                    'timestamp': timestamp,
                    'action': 'SELL',
                    'price': current_price,
                    'quantity': sold_quantity,
                    'capital': capital,
                })

        # Mark the remaining position to the last close
        final_value = capital + position * df.iloc[-1]['close']
        total_return = (final_value - initial_capital) / initial_capital
        return {
            'trades': trades,
            'final_value': final_value,
            'total_return': total_return,
            'num_trades': len(trades),
        }
3.3 实时预测系统
构建实时预测系统:
import time
from datetime import datetime, timedelta
import threading
import queue
class RealTimeKronosPredictor:
    """Background thread that periodically fetches klines and runs Kronos.

    Args:
        predictor: object exposing ``predict(df=..., x_timestamp=...,
            y_timestamp=..., pred_len=..., T=..., top_p=..., sample_count=...)``.
        data_source: object exposing ``get_latest_klines(limit=...)`` that
            returns a DataFrame with a 'timestamps' column and OHLCV columns.
        lookback: number of most recent bars fed to the predictor.
        pred_len: number of future bars to forecast.
        freq: spacing of the forecast timestamps (default one hour).
        interval_seconds: pause between successful rounds (default one hour).
        retry_seconds: pause after a failed round (default one minute).
    """

    def __init__(self, predictor, data_source, lookback=200, pred_len=24,
                 freq='1h', interval_seconds=3600, retry_seconds=60):
        self.predictor = predictor
        self.data_source = data_source  # data-source interface
        self.lookback = lookback
        self.pred_len = pred_len
        self.freq = freq
        self.interval_seconds = interval_seconds
        self.retry_seconds = retry_seconds
        self.prediction_queue = queue.Queue()
        self.is_running = False
        self.prediction_thread = None
        # Set by stop() so the worker wakes up immediately instead of
        # finishing a potentially hour-long sleep.
        self._stop_event = threading.Event()

    def fetch_latest_data(self):
        """Fetch the latest ``lookback`` klines from the data source."""
        # Put the real data-fetching logic here, e.g. an exchange API client.
        return self.data_source.get_latest_klines(limit=self.lookback)

    def _predict_once(self):
        """Run one fetch + predict round and enqueue the result.

        The round is skipped when fewer than ``lookback`` bars are available.
        """
        df = self.fetch_latest_data()
        if len(df) < self.lookback:
            return

        x_df = df.iloc[-self.lookback:][['open', 'high', 'low', 'close', 'volume']]
        x_timestamp = df.iloc[-self.lookback:]['timestamps']

        # Future timestamps: one period after the last bar, spaced by self.freq
        y_timestamp = pd.date_range(
            start=x_timestamp.iloc[-1], periods=self.pred_len + 1, freq=self.freq
        )[1:]

        pred_df = self.predictor.predict(
            df=x_df,
            x_timestamp=x_timestamp,
            y_timestamp=y_timestamp,
            pred_len=self.pred_len,
            T=1.0,
            top_p=0.9,
            sample_count=3,
        )

        self.prediction_queue.put({
            'timestamp': datetime.now(),
            'prediction': pred_df,
            'current_price': df.iloc[-1]['close'],
        })

    def prediction_loop(self):
        """Worker loop: predict, then wait for the next round or for stop()."""
        while self.is_running:
            try:
                self._predict_once()
                delay = self.interval_seconds
            except Exception as e:
                # Best effort: report and retry later instead of killing the thread.
                print(f"预测错误: {e}")
                delay = self.retry_seconds
            # Event.wait returns True as soon as stop() sets the event.
            if self._stop_event.wait(delay):
                break

    def start(self):
        """Start the background prediction thread (no-op if already running)."""
        if self.is_running:
            return
        self.is_running = True
        self._stop_event.clear()
        self.prediction_thread = threading.Thread(target=self.prediction_loop)
        self.prediction_thread.start()
        print("实时预测系统已启动")

    def stop(self):
        """Stop the worker and wait for it to exit (safe to call before start())."""
        self.is_running = False
        self._stop_event.set()
        if self.prediction_thread is not None:
            self.prediction_thread.join()
            self.prediction_thread = None
        print("实时预测系统已停止")

    def get_latest_prediction(self, timeout=1):
        """Return the most recent queued prediction, or None if none arrives in time.

        Blocks up to ``timeout`` seconds for the first item. Older queued
        items are then discarded so the caller always gets the newest one.
        """
        try:
            latest = self.prediction_queue.get(timeout=timeout)
        except queue.Empty:
            return None
        while True:
            try:
                latest = self.prediction_queue.get_nowait()
            except queue.Empty:
                return latest
4. 可视化示例
4.1 基础可视化
import matplotlib.pyplot as plt
import matplotlib.dates as mdates
def visualize_prediction(historical_df, pred_df, title="Kronos Prediction"):
    """Plot the history and the forecast with matplotlib.

    Top panel:
        - historical close (solid blue), plotted against
          ``historical_df['timestamps']``;
        - predicted close (dashed red), plotted against ``pred_df.index``;
        - shaded band between each predicted bar's low and high.

    Bottom panel: historical and predicted volume bars, drawn only when
    both frames contain a 'volume' column.
    """
    fig, (price_ax, volume_ax) = plt.subplots(2, 1, figsize=(12, 8), sharex=True)
    hist_times = historical_df['timestamps']
    pred_times = pred_df.index

    # Price panel
    price_ax.plot(hist_times, historical_df['close'],
                  label='Historical', color='blue', linewidth=1.5)
    price_ax.plot(pred_times, pred_df['close'],
                  label='Prediction', color='red', linewidth=1.5, linestyle='--')
    # Shade the predicted high-low range of every forecast bar
    price_ax.fill_between(pred_times, pred_df['low'], pred_df['high'],
                          alpha=0.3, color='red', label='Prediction Range')
    price_ax.set_ylabel('Price', fontsize=12)
    price_ax.legend(loc='best')
    price_ax.grid(True, alpha=0.3)
    price_ax.set_title(title)

    # Volume panel (only when both frames carry volume)
    both_have_volume = 'volume' in historical_df.columns and 'volume' in pred_df.columns
    if both_have_volume:
        volume_ax.bar(hist_times, historical_df['volume'],
                      color='blue', alpha=0.5, label='Historical Volume')
        volume_ax.bar(pred_times, pred_df['volume'],
                      color='red', alpha=0.5, label='Predicted Volume')
        volume_ax.set_ylabel('Volume', fontsize=12)
        volume_ax.legend(loc='best')
        volume_ax.grid(True, alpha=0.3)

    # Date formatting on the shared x axis
    volume_ax.xaxis.set_major_formatter(mdates.DateFormatter('%Y-%m-%d %H:%M'))
    plt.xticks(rotation=45)
    plt.tight_layout()
    plt.show()
4.2 高级可视化
import plotly.graph_objects as go
from plotly.subplots import make_subplots
def interactive_candlestick_prediction(historical_df, pred_df):
    """Interactive plotly candlestick chart of history and forecast.

    Row 1 overlays the historical candles (x = ``historical_df['timestamps']``)
    and the predicted candles (x = ``pred_df.index``). Row 2 holds volume
    bars. The bars are only added when both frames have a 'volume' column, so
    data without volume no longer raises KeyError.

    Returns:
        The plotly Figure. It is also displayed via ``fig.show()``.
    """
    # Two stacked rows sharing the x axis
    fig = make_subplots(
        rows=2, cols=1,
        shared_xaxes=True,
        vertical_spacing=0.03,
        subplot_titles=('Price Prediction', 'Volume'),
        row_heights=[0.7, 0.3],
    )

    # Historical candles
    fig.add_trace(
        go.Candlestick(
            x=historical_df['timestamps'],
            open=historical_df['open'],
            high=historical_df['high'],
            low=historical_df['low'],
            close=historical_df['close'],
            name='Historical',
            increasing_line_color='green',
            decreasing_line_color='red',
        ),
        row=1, col=1,
    )

    # Predicted candles, drawn in lighter colours
    fig.add_trace(
        go.Candlestick(
            x=pred_df.index,
            open=pred_df['open'],
            high=pred_df['high'],
            low=pred_df['low'],
            close=pred_df['close'],
            name='Prediction',
            increasing_line_color='lightgreen',
            decreasing_line_color='lightcoral',
            opacity=0.7,
        ),
        row=1, col=1,
    )

    # Volume (only when both frames carry it)
    if 'volume' in historical_df.columns and 'volume' in pred_df.columns:
        fig.add_trace(
            go.Bar(
                x=historical_df['timestamps'],
                y=historical_df['volume'],
                name='Historical Volume',
                marker_color='blue',
                opacity=0.5,
            ),
            row=2, col=1,
        )
        fig.add_trace(
            go.Bar(
                x=pred_df.index,
                y=pred_df['volume'],
                name='Predicted Volume',
                marker_color='red',
                opacity=0.5,
            ),
            row=2, col=1,
        )

    # Layout
    fig.update_layout(
        title='Kronos K-line Prediction',
        xaxis_rangeslider_visible=False,
        height=800,
        showlegend=True,
    )
    fig.show()
    return fig
5. 性能优化技巧
5.1 GPU加速
# Check GPU availability and choose the device
import contextlib

import torch

if torch.cuda.is_available():
    device = f"cuda:{torch.cuda.current_device()}"
    print(f"使用GPU: {torch.cuda.get_device_name()}")
    # Release cached-but-unused memory held by PyTorch's CUDA caching allocator
    torch.cuda.empty_cache()
    # Mixed precision via torch.autocast (torch.cuda.amp.autocast is deprecated).
    # NOTE(review): verify that Kronos forecast quality is unchanged under fp16.
    amp_context = torch.autocast(device_type="cuda", dtype=torch.float16)
else:
    device = "cpu"
    print("GPU不可用,使用CPU")
    # No CUDA autocast without a GPU: run in full precision
    amp_context = contextlib.nullcontext()

# Pass `device` when constructing KronosPredictor(..., device=device) so the
# model actually runs on the selected device.
with amp_context:
    # Replace `...` with the real arguments (df, x_timestamp, y_timestamp, pred_len, ...)
    predictions = predictor.predict(...)
5.2 批量处理优化
def optimize_batch_size(data_list, max_memory_gb=8):
    """Choose a batch size from a rough memory estimate.

    NOTE(review): the per-sample estimate is the host-side pandas memory of
    the input DataFrames. That is only a crude proxy for the GPU memory the
    model really needs (weights, activations, ...). Treat the result as an
    upper bound and lower it if you still run out of memory.

    Args:
        data_list: list of input DataFrames.
        max_memory_gb: memory budget in GB. 70% of it is used, leaving a 30%
            safety margin.

    Returns:
        A batch size of at least 1, so it is always usable as a ``range``
        step. Without CUDA, every sample goes into a single batch.
    """
    import torch

    if not torch.cuda.is_available():
        return max(1, len(data_list))  # CPU mode: process everything at once
    if not data_list:
        return 1

    # Size the batch for the largest sample, not just the first one (GB)
    sample_gb = max(d.memory_usage(deep=True).sum() for d in data_list) / 1e9
    if sample_gb <= 0:
        return len(data_list)

    budget_gb = max_memory_gb * 0.7  # keep a 30% margin
    return max(1, min(int(budget_gb / sample_gb), len(data_list)))
5.3 缓存策略
from functools import lru_cache
import hashlib
class CachedKronosPredictor:
    """Memoizing wrapper around a Kronos predictor.

    Results are kept in ``self.cache``, an unbounded dict, keyed by a SHA-256
    digest. The digest covers:
      - the input frame;
      - both timestamp series;
      - the sampling parameters.

    Sampling is stochastic, so a cache hit returns the earlier sample instead
    of drawing a new one.
    """

    def __init__(self, predictor):
        self.predictor = predictor
        self.cache = {}

    def _get_cache_key(self, df, pred_len, T, top_p, sample_count,
                       x_timestamp=None, y_timestamp=None):
        """Build the cache key for one prediction request.

        Design choices:
          - The exact values are hashed with ``pd.util.hash_pandas_object``.
            ``to_string()`` would round floats for display, so nearby inputs
            would collide.
          - Scalar parameters are hashed as a delimited ``repr``. Plain
            concatenation makes e.g. pred_len=1, T=21.0 and pred_len=12,
            T=1.0 identical.
          - Timestamps are part of the key, so a different horizon never
            returns a stale result.
        """
        hasher = hashlib.sha256()
        hasher.update(repr(list(df.columns)).encode())
        for part in (df, x_timestamp, y_timestamp):
            if part is None:
                hasher.update(b"<none>")
            else:
                if not isinstance(part, (pd.DataFrame, pd.Series, pd.Index)):
                    part = pd.Series(part)
                hasher.update(pd.util.hash_pandas_object(part, index=True).values.tobytes())
            hasher.update(b"|")  # separator between components
        hasher.update(repr((pred_len, T, top_p, sample_count)).encode())
        return hasher.hexdigest()

    def predict_with_cache(self, df, x_timestamp, y_timestamp,
                           pred_len, T=1.0, top_p=0.9, sample_count=1):
        """Return a cached prediction, or run the predictor and cache its result.

        Copies are stored and returned, so callers that mutate the returned
        frame cannot corrupt the cache.
        """
        cache_key = self._get_cache_key(
            df, pred_len, T, top_p, sample_count,
            x_timestamp=x_timestamp, y_timestamp=y_timestamp,
        )
        if cache_key in self.cache:
            print("使用缓存结果")
            return self.cache[cache_key].copy()

        # Run the prediction
        result = self.predictor.predict(
            df=df,
            x_timestamp=x_timestamp,
            y_timestamp=y_timestamp,
            pred_len=pred_len,
            T=T,
            top_p=top_p,
            sample_count=sample_count,
        )

        # Cache a private copy of the result
        self.cache[cache_key] = result.copy()
        return result
6. 常见问题和解决方案
6.1 内存不足
# Problem: running out of memory when predicting a large batch
# Solution: process the data in chunks
def process_large_dataset(df_list, batch_size=10, x_timestamp_list=None,
                          y_timestamp_list=None, pred_len=None, **predict_kwargs):
    """Predict a large list of series in chunks of ``batch_size``.

    ``predictor.predict_batch`` needs the matching timestamp lists and
    ``pred_len``. The timestamp lists are sliced in lockstep with
    ``df_list``. Extra keyword arguments (``T``, ``top_p``,
    ``sample_count``, ...) are forwarded unchanged. After each chunk, the
    CUDA caching allocator is emptied, but only when CUDA is available.

    Returns:
        List of predictions in the same order as ``df_list``.

    Raises:
        ValueError: on missing timestamp lists/``pred_len``, mismatched list
            lengths, or a non-positive ``batch_size``.
    """
    if x_timestamp_list is None or y_timestamp_list is None or pred_len is None:
        raise ValueError("x_timestamp_list, y_timestamp_list and pred_len are required")
    if not len(df_list) == len(x_timestamp_list) == len(y_timestamp_list):
        raise ValueError("df_list, x_timestamp_list and y_timestamp_list must have the same length")
    if batch_size < 1:
        raise ValueError("batch_size must be a positive integer")

    all_predictions = []
    for start in range(0, len(df_list), batch_size):
        stop = start + batch_size
        batch_predictions = predictor.predict_batch(
            df_list=df_list[start:stop],
            x_timestamp_list=x_timestamp_list[start:stop],
            y_timestamp_list=y_timestamp_list[start:stop],
            pred_len=pred_len,
            **predict_kwargs,
        )
        all_predictions.extend(batch_predictions)
        # Free cached GPU memory between chunks
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
    return all_predictions
6.2 预测不稳定
# Problem: repeated predictions differ a lot
# Solution: more samples and a lower temperature
def stable_prediction(df, x_timestamp, y_timestamp, pred_len):
    """Forecast with settings tuned for run-to-run stability.

    Uses a lower temperature (0.7), a 0.95 nucleus threshold and averages
    10 sampled paths.
    """
    stability_settings = {
        'T': 0.7,             # lower temperature
        'top_p': 0.95,        # higher threshold
        'sample_count': 10,   # more sampled paths
    }
    return predictor.predict(
        df=df,
        x_timestamp=x_timestamp,
        y_timestamp=y_timestamp,
        pred_len=pred_len,
        **stability_settings,
    )
6.3 长期预测失真
# Problem: long-horizon forecasts drift too far
# Solution: predict recursively, a short horizon at a time
def recursive_long_term_prediction(df, total_pred_len, step_len=24, lookback=200, freq='1h'):
    """Build a long forecast from repeated short forecasts.

    Each step feeds the last ``lookback`` rows to the predictor and forecasts
    up to ``step_len`` bars. It then appends those bars to the working data,
    so the next step conditions on them. The final step is shortened so
    exactly ``total_pred_len`` bars are produced; previously a remainder was
    silently dropped.

    Args:
        df: initial data with a 'timestamps' column and OHLCV columns.
        total_pred_len: total number of bars to forecast.
        step_len: bars forecast per step.
        lookback: history rows fed to each step (default 200, as before).
        freq: spacing of the generated timestamps (default one hour, as before).

    Returns:
        Concatenation of all step forecasts, each with a 'timestamps' column.
        An empty DataFrame is returned when ``total_pred_len`` <= 0.

    Raises:
        ValueError: if ``step_len`` < 1.
    """
    if step_len < 1:
        raise ValueError("step_len must be a positive integer")
    if total_pred_len <= 0:
        return pd.DataFrame()

    all_predictions = []
    current_df = df.copy()
    remaining = total_pred_len

    while remaining > 0:
        this_len = min(step_len, remaining)

        # History window for this step
        history = current_df.iloc[-lookback:]
        x_df = history[['open', 'high', 'low', 'close', 'volume']]
        x_timestamp = history['timestamps']

        # Target timestamps: one period after the last bar onwards
        y_timestamp = pd.date_range(
            start=x_timestamp.iloc[-1], periods=this_len + 1, freq=freq
        )[1:]

        # Predict this step
        pred = predictor.predict(
            df=x_df,
            x_timestamp=x_timestamp,
            y_timestamp=y_timestamp,
            pred_len=this_len,
            T=0.9,
            top_p=0.9,
            sample_count=5,
        )

        # Append the forecast so the next step conditions on it
        pred['timestamps'] = y_timestamp
        current_df = pd.concat([current_df, pred], ignore_index=True)
        all_predictions.append(pred)
        remaining -= this_len

    return pd.concat(all_predictions)
7. 总结
Kronos为金融K线预测提供了强大而灵活的工具。通过合理配置参数和使用策略,可以在不同场景下获得良好的预测效果。关键要点:
- 数据准备:确保数据格式正确,处理缺失值
- 参数调优:根据具体需求调整温度、采样参数
- 批量处理:利用批量预测提高效率
- 稳定性:通过多路径平均提高预测稳定性
- 性能优化:合理使用GPU和缓存策略
记住,Kronos是一个概率模型,预测结果应该作为决策支持而非唯一依据。在实际应用中,建议结合其他技术分析工具和风险管理策略。