Coverage for core/models/strategy_config.py: 44.17%
120 statements
« prev ^ index » next coverage.py v7.10.7, created at 2025-10-13 18:58 +0000
1"""
2策略配置模型
3"""
import math
from decimal import Decimal, InvalidOperation
from enum import Enum
from typing import Any, Dict, List, Optional, Tuple, Union

from pydantic import BaseModel, Field, validator
class ParameterType(str, Enum):
    """Kinds of value a strategy parameter can hold.

    Because the enum mixes in ``str``, every member compares equal to its raw
    string value (``ParameterType.INTEGER == "integer"``), which keeps the
    members convenient to serialize and to compare against plain strings.
    """

    # Member order is kept stable: iterating the enum yields this sequence.
    INTEGER = "integer"  # converted with int() during validation
    FLOAT = "float"  # converted with float() during validation
    STRING = "string"  # passed through as-is during validation
    BOOLEAN = "boolean"  # true/false flag
    LIST = "list"  # sequence of values
    DECIMAL = "decimal"  # converted with Decimal() during validation
class StrategyParameter(BaseModel):
    """Declaration of one tunable strategy parameter.

    Describes the parameter's name, value type, default, human-readable
    description, optional numeric bounds / allowed values, whether it is
    required, and the category it is grouped under.
    """

    # --- identity and typing (all required) ---
    name: str = Field(default=..., description="参数名称")
    type: ParameterType = Field(default=..., description="参数类型")
    default_value: Any = Field(default=..., description="默认值")
    description: str = Field(default=..., description="参数描述")

    # --- optional constraints (None means "no constraint") ---
    # NOTE(review): under pydantic v1 without smart_union, Union[int, float, Decimal]
    # is tried left to right and int(0.01) == 0, so fractional float bounds may be
    # truncated; pydantic v2 "smart" unions keep the exact type — confirm version.
    min_value: Optional[Union[int, float, Decimal]] = Field(default=None, description="最小值")
    max_value: Optional[Union[int, float, Decimal]] = Field(default=None, description="最大值")
    allowed_values: Optional[List[Any]] = Field(default=None, description="允许的值列表")

    # --- metadata ---
    required: bool = Field(default=True, description="是否必需")
    category: str = Field(default="general", description="参数分类")
def _coerce_parameter_value(
    param_def: StrategyParameter, raw_value: Any
) -> Tuple[bool, Any]:
    """Convert ``raw_value`` to the Python type declared by ``param_def.type``.

    Returns:
        ``(True, converted_value)`` on success, or ``(False, error_message)``
        when the value cannot be represented as the declared type.
    """
    name = param_def.name
    param_type = param_def.type

    if param_type == ParameterType.INTEGER:
        int_error = f"参数 {name} 必须是整数"
        # Reject fractional numbers instead of silently truncating them
        # (int(2.9) == 2). NaN and infinities also fail these checks.
        if isinstance(raw_value, float) and not raw_value.is_integer():
            return False, int_error
        if isinstance(raw_value, Decimal) and not (
            raw_value.is_finite() and raw_value == raw_value.to_integral_value()
        ):
            return False, int_error
        try:
            return True, int(raw_value)
        except (ValueError, TypeError, OverflowError):
            return False, int_error

    if param_type == ParameterType.FLOAT:
        float_error = f"参数 {name} 必须是数字"
        try:
            value = float(raw_value)
        except (ValueError, TypeError, OverflowError):
            # OverflowError: float() of an int too large for a double.
            return False, float_error
        # float() accepts "nan"/"inf"; NaN would slip through every range check.
        if not math.isfinite(value):
            return False, float_error
        return True, value

    if param_type == ParameterType.DECIMAL:
        decimal_error = f"参数 {name} 必须是数字"
        try:
            value = Decimal(str(raw_value))
        except (ValueError, TypeError, InvalidOperation):
            # Malformed text signals InvalidOperation, which is not a ValueError.
            return False, decimal_error
        # Decimal NaN raises InvalidOperation on ordering comparisons; reject
        # non-finite values up front.
        if not value.is_finite():
            return False, decimal_error
        return True, value

    if param_type == ParameterType.BOOLEAN:
        if isinstance(raw_value, str):
            return True, raw_value.lower() in ["true", "1", "yes", "on"]
        return True, bool(raw_value)

    if param_type == ParameterType.LIST:
        if isinstance(raw_value, list):
            return True, raw_value
        if isinstance(raw_value, tuple):
            return True, list(raw_value)
        if isinstance(raw_value, str):
            # Comma-separated text; blank items (e.g. from "" or a trailing
            # comma) are dropped rather than kept as empty strings.
            return True, [item.strip() for item in raw_value.split(",") if item.strip()]
        return False, f"参数 {name} 必须是列表"

    # STRING (and any other type) is passed through unchanged.
    return True, raw_value


def _parameter_constraint_errors(param_def: StrategyParameter, value: Any) -> List[str]:
    """Return error messages for bound and allowed-value violations of ``value``."""
    name = param_def.name
    errors: List[str] = []

    # Ordering comparisons raise TypeError for incomparable types (e.g. a list
    # or string checked against a numeric bound); report it instead of crashing.
    try:
        if param_def.min_value is not None and value < param_def.min_value:
            errors.append(f"参数 {name} 不能小于 {param_def.min_value}")
        if param_def.max_value is not None and value > param_def.max_value:
            errors.append(f"参数 {name} 不能大于 {param_def.max_value}")
    except TypeError:
        errors.append(f"参数 {name} 的值无法与范围限制比较")

    if param_def.allowed_values and value not in param_def.allowed_values:
        errors.append(f"参数 {name} 必须是以下值之一: {param_def.allowed_values}")

    return errors


class StrategyConfigTemplate(BaseModel):
    """Template describing a strategy: metadata, parameters and recommendations."""

    strategy_name: str = Field(..., description="策略名称")
    display_name: str = Field(..., description="显示名称")
    description: str = Field(..., description="策略描述")
    version: str = Field("1.0.0", description="策略版本")
    author: str = Field("", description="策略作者")

    # Parameter definitions
    parameters: List[StrategyParameter] = Field(
        default_factory=list, description="策略参数"
    )

    # Recommended risk configuration
    recommended_risk_config: Dict[str, Any] = Field(
        default_factory=dict, description="推荐的风险配置"
    )

    # Strategy capabilities
    supports_realtime: bool = Field(True, description="是否支持实时交易")
    supports_backtest: bool = Field(True, description="是否支持回测")
    min_capital: Decimal = Field(Decimal("10000"), description="最小资金要求")
    recommended_capital: Decimal = Field(Decimal("100000"), description="推荐资金")

    # Market support
    supported_markets: List[str] = Field(
        default_factory=lambda: ["US", "HK"], description="支持的市场"
    )
    supported_symbols: List[str] = Field(
        default_factory=list, description="推荐的股票代码"
    )

    def get_parameter_by_name(self, name: str) -> Optional[StrategyParameter]:
        """Return the first parameter definition named ``name``, or ``None``."""
        return next((param for param in self.parameters if param.name == name), None)

    def get_parameters_by_category(self, category: str) -> List[StrategyParameter]:
        """Return all parameter definitions whose ``category`` equals ``category``."""
        return [param for param in self.parameters if param.category == category]

    def validate_config(self, config: Dict[str, Any]) -> Dict[str, Any]:
        """Validate a strategy configuration against this template's parameters.

        Handling of required parameters that are absent from ``config``:
        they fall back to their ``default_value`` with a warning. If that default
        is ``None``, they are reported as errors instead.

        Handling of supplied values:
        each one is converted to its parameter's declared type, then checked
        against ``min_value``/``max_value``/``allowed_values``. Keys with no
        matching definition are kept unchanged and reported as warnings.

        Args:
            config: Mapping of parameter name to raw (possibly string) value.

        Returns:
            A dict with keys:

            - ``"valid"``: bool, true when there are no errors.
            - ``"errors"``: list of messages.
            - ``"warnings"``: list of messages.
            - ``"config"``: the converted configuration. Values that fail type
              conversion are omitted from it.
        """
        errors: List[str] = []
        warnings: List[str] = []
        validated_config: Dict[str, Any] = {}

        # Required parameters that were not supplied.
        for param in self.parameters:
            if not param.required or param.name in config:
                continue
            if param.default_value is not None:
                validated_config[param.name] = param.default_value
                warnings.append(
                    f"参数 {param.name} 使用默认值: {param.default_value}"
                )
            else:
                errors.append(f"缺少必需参数: {param.name}")

        # Supplied values: type conversion, then range / allowed-value checks.
        for param_name, param_value in config.items():
            param_def = self.get_parameter_by_name(param_name)
            if param_def is None:
                warnings.append(f"未知参数: {param_name}")
                validated_config[param_name] = param_value
                continue

            ok, result = _coerce_parameter_value(param_def, param_value)
            if not ok:
                errors.append(result)
                continue

            validated_config[param_name] = result
            errors.extend(_parameter_constraint_errors(param_def, result))

        return {
            "valid": len(errors) == 0,
            "errors": errors,
            "warnings": warnings,
            "config": validated_config,
        }
class StrategyConfigPreset:
    """Registry of the built-in strategy templates plus lookup helpers."""

    @staticmethod
    def _technical_int(
        name: str, default_value: int, description: str, min_value: int, max_value: int
    ) -> StrategyParameter:
        """Build an INTEGER parameter in the "technical" category."""
        return StrategyParameter(
            name=name,
            type=ParameterType.INTEGER,
            default_value=default_value,
            description=description,
            min_value=min_value,
            max_value=max_value,
            category="technical",
        )

    @staticmethod
    def _position_size() -> StrategyParameter:
        """Build the position-size parameter shared by every built-in template."""
        return StrategyParameter(
            name="position_size",
            type=ParameterType.FLOAT,
            default_value=0.04,
            description="仓位大小比例",
            min_value=0.01,
            max_value=1.0,
            category="risk",
        )

    @staticmethod
    def _symbols() -> StrategyParameter:
        """Build the watched-symbols parameter shared by every built-in template."""
        return StrategyParameter(
            name="symbols",
            type=ParameterType.LIST,
            default_value=["AAPL.US"],
            description="监控股票列表",
            category="market",
        )

    @staticmethod
    def get_strategy_templates() -> Dict[str, StrategyConfigTemplate]:
        """Build a fresh mapping of strategy name -> template on every call."""
        preset = StrategyConfigPreset
        tech = preset._technical_int

        ma_crossover = StrategyConfigTemplate(
            strategy_name="ma_crossover",
            display_name="移动平均线交叉策略",
            description="基于短期和长期移动平均线交叉的交易策略",
            version="1.0.0",
            author="系统",
            parameters=[
                tech("short_period", 5, "短期均线周期", 2, 50),
                tech("long_period", 20, "长期均线周期", 10, 200),
                preset._position_size(),
                preset._symbols(),
            ],
            recommended_risk_config={
                "max_position_ratio": 0.1,
                "stop_loss_ratio": 0.05,
                "max_drawdown": 0.15,
            },
            supports_realtime=True,
            supports_backtest=True,
            min_capital=Decimal("10000"),
            recommended_capital=Decimal("100000"),
            supported_markets=["US", "HK"],
            supported_symbols=["AAPL.US", "MSFT.US", "GOOGL.US"],
        )

        rsi = StrategyConfigTemplate(
            strategy_name="rsi",
            display_name="RSI策略",
            description="基于相对强弱指标的超买超卖策略",
            version="1.0.0",
            author="系统",
            parameters=[
                tech("period", 14, "RSI计算周期", 5, 50),
                tech("overbought", 70, "超买阈值", 60, 90),
                tech("oversold", 30, "超卖阈值", 10, 40),
                preset._position_size(),
                preset._symbols(),
            ],
            recommended_risk_config={
                "max_position_ratio": 0.15,
                "stop_loss_ratio": 0.08,
                "max_drawdown": 0.20,
            },
            supports_realtime=True,
            supports_backtest=True,
            min_capital=Decimal("50000"),
            recommended_capital=Decimal("200000"),
            supported_markets=["US", "HK"],
            supported_symbols=["AAPL.US", "MSFT.US", "TSLA.US"],
        )

        macd = StrategyConfigTemplate(
            strategy_name="macd",
            display_name="MACD策略",
            description="基于移动平均收敛散度的趋势跟踪策略",
            version="1.0.0",
            author="系统",
            parameters=[
                tech("fast_period", 5, "快速EMA周期", 5, 50),
                tech("slow_period", 13, "慢速EMA周期", 10, 100),
                tech("signal_period", 4, "信号线周期", 3, 30),
                preset._position_size(),
                preset._symbols(),
            ],
            recommended_risk_config={
                "max_position_ratio": 0.12,
                "stop_loss_ratio": 0.06,
                "max_drawdown": 0.18,
            },
            supports_realtime=True,
            supports_backtest=True,
            min_capital=Decimal("75000"),
            recommended_capital=Decimal("300000"),
            supported_markets=["US", "HK"],
            supported_symbols=["AAPL.US", "MSFT.US", "AMZN.US"],
        )

        bollinger = StrategyConfigTemplate(
            strategy_name="bollinger",
            display_name="布林带策略",
            description="基于布林带的均值回归策略",
            version="1.0.0",
            author="系统",
            parameters=[
                tech("period", 20, "移动平均周期", 10, 50),
                StrategyParameter(
                    name="std_dev",
                    type=ParameterType.FLOAT,
                    default_value=2.0,
                    description="标准差倍数",
                    min_value=1.0,
                    max_value=3.0,
                    category="technical",
                ),
                preset._position_size(),
                preset._symbols(),
            ],
            recommended_risk_config={
                "max_position_ratio": 0.08,
                "stop_loss_ratio": 0.04,
                "max_drawdown": 0.12,
            },
            supports_realtime=True,
            supports_backtest=True,
            min_capital=Decimal("50000"),
            recommended_capital=Decimal("150000"),
            supported_markets=["US", "HK"],
            supported_symbols=["AAPL.US", "MSFT.US", "NVDA.US"],
        )

        # Insertion order is part of the observable result; keep it stable.
        return {
            "ma_crossover": ma_crossover,
            "rsi": rsi,
            "macd": macd,
            "bollinger": bollinger,
        }

    @staticmethod
    def get_strategy_template(strategy_name: str) -> Optional[StrategyConfigTemplate]:
        """Return the template registered under ``strategy_name``, or ``None``."""
        return StrategyConfigPreset.get_strategy_templates().get(strategy_name)

    @staticmethod
    def get_recommended_risk_config(strategy_name: str) -> Dict[str, Any]:
        """Return the template's recommended risk config, or ``{}`` if unknown."""
        template = StrategyConfigPreset.get_strategy_template(strategy_name)
        return template.recommended_risk_config if template else {}

    @staticmethod
    def get_strategy_parameters(strategy_name: str) -> List[Dict[str, Any]]:
        """Describe each parameter of a strategy as a plain dict.

        Returns an empty list when ``strategy_name`` is not a known template.
        The ``"type"`` entry is the enum's string value.
        """
        template = StrategyConfigPreset.get_strategy_template(strategy_name)
        if not template:
            return []

        described: List[Dict[str, Any]] = []
        for p in template.parameters:
            described.append(
                {
                    "name": p.name,
                    "type": p.type.value,
                    "default_value": p.default_value,
                    "description": p.description,
                    "min_value": p.min_value,
                    "max_value": p.max_value,
                    "allowed_values": p.allowed_values,
                    "required": p.required,
                    "category": p.category,
                }
            )
        return described