Coverage for core/models/strategy_config.py: 44.17%

120 statements  

« prev     ^ index     » next       coverage.py v7.10.7, created at 2025-10-13 18:58 +0000

1""" 

2策略配置模型 

3""" 

4 

5from decimal import Decimal 

6from enum import Enum 

7from typing import Any, Dict, List, Optional, Union 

8 

9from pydantic import BaseModel, Field, validator 

10 

11 

12class ParameterType(str, Enum): 

13 """参数类型枚举""" 

14 

15 INTEGER = "integer" 

16 FLOAT = "float" 

17 STRING = "string" 

18 BOOLEAN = "boolean" 

19 LIST = "list" 

20 DECIMAL = "decimal" 

21 

22 

23class StrategyParameter(BaseModel): 

24 """策略参数定义""" 

25 

26 name: str = Field(..., description="参数名称") 

27 type: ParameterType = Field(..., description="参数类型") 

28 default_value: Any = Field(..., description="默认值") 

29 description: str = Field(..., description="参数描述") 

30 min_value: Optional[Union[int, float, Decimal]] = Field(None, description="最小值") 

31 max_value: Optional[Union[int, float, Decimal]] = Field(None, description="最大值") 

32 allowed_values: Optional[List[Any]] = Field(None, description="允许的值列表") 

33 required: bool = Field(True, description="是否必需") 

34 category: str = Field("general", description="参数分类") 

35 

36 

37class StrategyConfigTemplate(BaseModel): 

38 """策略配置模板""" 

39 

40 strategy_name: str = Field(..., description="策略名称") 

41 display_name: str = Field(..., description="显示名称") 

42 description: str = Field(..., description="策略描述") 

43 version: str = Field("1.0.0", description="策略版本") 

44 author: str = Field("", description="策略作者") 

45 

46 # 参数定义 

47 parameters: List[StrategyParameter] = Field( 

48 default_factory=list, description="策略参数" 

49 ) 

50 

51 # 风险配置建议 

52 recommended_risk_config: Dict[str, Any] = Field( 

53 default_factory=dict, description="推荐的风险配置" 

54 ) 

55 

56 # 策略特性 

57 supports_realtime: bool = Field(True, description="是否支持实时交易") 

58 supports_backtest: bool = Field(True, description="是否支持回测") 

59 min_capital: Decimal = Field(Decimal("10000"), description="最小资金要求") 

60 recommended_capital: Decimal = Field(Decimal("100000"), description="推荐资金") 

61 

62 # 市场支持 

63 supported_markets: List[str] = Field( 

64 default_factory=lambda: ["US", "HK"], description="支持的市场" 

65 ) 

66 supported_symbols: List[str] = Field( 

67 default_factory=list, description="推荐的股票代码" 

68 ) 

69 

70 def get_parameter_by_name(self, name: str) -> Optional[StrategyParameter]: 

71 """根据名称获取参数""" 

72 for param in self.parameters: 

73 if param.name == name: 

74 return param 

75 return None 

76 

77 def get_parameters_by_category(self, category: str) -> List[StrategyParameter]: 

78 """根据分类获取参数""" 

79 return [param for param in self.parameters if param.category == category] 

80 

81 def validate_config(self, config: Dict[str, Any]) -> Dict[str, Any]: 

82 """验证策略配置""" 

83 errors = [] 

84 warnings = [] 

85 validated_config = {} 

86 

87 # 检查必需参数 

88 for param in self.parameters: 

89 if param.required and param.name not in config: 

90 if param.default_value is not None: 

91 validated_config[param.name] = param.default_value 

92 warnings.append( 

93 f"参数 {param.name} 使用默认值: {param.default_value}" 

94 ) 

95 else: 

96 errors.append(f"缺少必需参数: {param.name}") 

97 

98 # 验证参数值 

99 for param_name, param_value in config.items(): 

100 param_def = self.get_parameter_by_name(param_name) 

101 if not param_def: 

102 warnings.append(f"未知参数: {param_name}") 

103 validated_config[param_name] = param_value 

104 continue 

105 

106 # 类型验证 

107 if param_def.type == ParameterType.INTEGER: 

108 try: 

109 validated_value = int(param_value) 

110 validated_config[param_name] = validated_value 

111 except (ValueError, TypeError): 

112 errors.append(f"参数 {param_name} 必须是整数") 

113 continue 

114 elif param_def.type == ParameterType.FLOAT: 

115 try: 

116 validated_value = float(param_value) 

117 validated_config[param_name] = validated_value 

118 except (ValueError, TypeError): 

119 errors.append(f"参数 {param_name} 必须是数字") 

120 continue 

121 elif param_def.type == ParameterType.DECIMAL: 

122 try: 

123 validated_value = Decimal(str(param_value)) 

124 validated_config[param_name] = validated_value 

125 except (ValueError, TypeError): 

126 errors.append(f"参数 {param_name} 必须是数字") 

127 continue 

128 elif param_def.type == ParameterType.BOOLEAN: 

129 if isinstance(param_value, bool): 

130 validated_config[param_name] = param_value 

131 elif isinstance(param_value, str): 

132 validated_config[param_name] = param_value.lower() in [ 

133 "true", 

134 "1", 

135 "yes", 

136 "on", 

137 ] 

138 else: 

139 validated_config[param_name] = bool(param_value) 

140 elif param_def.type == ParameterType.LIST: 

141 if isinstance(param_value, list): 

142 validated_config[param_name] = param_value 

143 elif isinstance(param_value, str): 

144 validated_config[param_name] = [ 

145 item.strip() for item in param_value.split(",") 

146 ] 

147 else: 

148 errors.append(f"参数 {param_name} 必须是列表") 

149 continue 

150 else: 

151 validated_config[param_name] = param_value 

152 

153 # 范围验证 

154 if ( 

155 param_def.min_value is not None 

156 and validated_config[param_name] < param_def.min_value 

157 ): 

158 errors.append(f"参数 {param_name} 不能小于 {param_def.min_value}") 

159 

160 if ( 

161 param_def.max_value is not None 

162 and validated_config[param_name] > param_def.max_value 

163 ): 

164 errors.append(f"参数 {param_name} 不能大于 {param_def.max_value}") 

165 

166 # 允许值验证 

167 if ( 

168 param_def.allowed_values 

169 and validated_config[param_name] not in param_def.allowed_values 

170 ): 

171 errors.append( 

172 f"参数 {param_name} 必须是以下值之一: {param_def.allowed_values}" 

173 ) 

174 

175 return { 

176 "valid": len(errors) == 0, 

177 "errors": errors, 

178 "warnings": warnings, 

179 "config": validated_config, 

180 } 

181 

182 

183class StrategyConfigPreset: 

184 """策略配置预设""" 

185 

186 @staticmethod 

187 def get_strategy_templates() -> Dict[str, StrategyConfigTemplate]: 

188 """获取策略配置模板""" 

189 return { 

190 "ma_crossover": StrategyConfigTemplate( 

191 strategy_name="ma_crossover", 

192 display_name="移动平均线交叉策略", 

193 description="基于短期和长期移动平均线交叉的交易策略", 

194 version="1.0.0", 

195 author="系统", 

196 parameters=[ 

197 StrategyParameter( 

198 name="short_period", 

199 type=ParameterType.INTEGER, 

200 default_value=5, 

201 description="短期均线周期", 

202 min_value=2, 

203 max_value=50, 

204 category="technical", 

205 ), 

206 StrategyParameter( 

207 name="long_period", 

208 type=ParameterType.INTEGER, 

209 default_value=20, 

210 description="长期均线周期", 

211 min_value=10, 

212 max_value=200, 

213 category="technical", 

214 ), 

215 StrategyParameter( 

216 name="position_size", 

217 type=ParameterType.FLOAT, 

218 default_value=0.04, 

219 description="仓位大小比例", 

220 min_value=0.01, 

221 max_value=1.0, 

222 category="risk", 

223 ), 

224 StrategyParameter( 

225 name="symbols", 

226 type=ParameterType.LIST, 

227 default_value=["AAPL.US"], 

228 description="监控股票列表", 

229 category="market", 

230 ), 

231 ], 

232 recommended_risk_config={ 

233 "max_position_ratio": 0.1, 

234 "stop_loss_ratio": 0.05, 

235 "max_drawdown": 0.15, 

236 }, 

237 supports_realtime=True, 

238 supports_backtest=True, 

239 min_capital=Decimal("10000"), 

240 recommended_capital=Decimal("100000"), 

241 supported_markets=["US", "HK"], 

242 supported_symbols=["AAPL.US", "MSFT.US", "GOOGL.US"], 

243 ), 

244 "rsi": StrategyConfigTemplate( 

245 strategy_name="rsi", 

246 display_name="RSI策略", 

247 description="基于相对强弱指标的超买超卖策略", 

248 version="1.0.0", 

249 author="系统", 

250 parameters=[ 

251 StrategyParameter( 

252 name="period", 

253 type=ParameterType.INTEGER, 

254 default_value=14, 

255 description="RSI计算周期", 

256 min_value=5, 

257 max_value=50, 

258 category="technical", 

259 ), 

260 StrategyParameter( 

261 name="overbought", 

262 type=ParameterType.INTEGER, 

263 default_value=70, 

264 description="超买阈值", 

265 min_value=60, 

266 max_value=90, 

267 category="technical", 

268 ), 

269 StrategyParameter( 

270 name="oversold", 

271 type=ParameterType.INTEGER, 

272 default_value=30, 

273 description="超卖阈值", 

274 min_value=10, 

275 max_value=40, 

276 category="technical", 

277 ), 

278 StrategyParameter( 

279 name="position_size", 

280 type=ParameterType.FLOAT, 

281 default_value=0.04, 

282 description="仓位大小比例", 

283 min_value=0.01, 

284 max_value=1.0, 

285 category="risk", 

286 ), 

287 StrategyParameter( 

288 name="symbols", 

289 type=ParameterType.LIST, 

290 default_value=["AAPL.US"], 

291 description="监控股票列表", 

292 category="market", 

293 ), 

294 ], 

295 recommended_risk_config={ 

296 "max_position_ratio": 0.15, 

297 "stop_loss_ratio": 0.08, 

298 "max_drawdown": 0.20, 

299 }, 

300 supports_realtime=True, 

301 supports_backtest=True, 

302 min_capital=Decimal("50000"), 

303 recommended_capital=Decimal("200000"), 

304 supported_markets=["US", "HK"], 

305 supported_symbols=["AAPL.US", "MSFT.US", "TSLA.US"], 

306 ), 

307 "macd": StrategyConfigTemplate( 

308 strategy_name="macd", 

309 display_name="MACD策略", 

310 description="基于移动平均收敛散度的趋势跟踪策略", 

311 version="1.0.0", 

312 author="系统", 

313 parameters=[ 

314 StrategyParameter( 

315 name="fast_period", 

316 type=ParameterType.INTEGER, 

317 default_value=5, 

318 description="快速EMA周期", 

319 min_value=5, 

320 max_value=50, 

321 category="technical", 

322 ), 

323 StrategyParameter( 

324 name="slow_period", 

325 type=ParameterType.INTEGER, 

326 default_value=13, 

327 description="慢速EMA周期", 

328 min_value=10, 

329 max_value=100, 

330 category="technical", 

331 ), 

332 StrategyParameter( 

333 name="signal_period", 

334 type=ParameterType.INTEGER, 

335 default_value=4, 

336 description="信号线周期", 

337 min_value=3, 

338 max_value=30, 

339 category="technical", 

340 ), 

341 StrategyParameter( 

342 name="position_size", 

343 type=ParameterType.FLOAT, 

344 default_value=0.04, 

345 description="仓位大小比例", 

346 min_value=0.01, 

347 max_value=1.0, 

348 category="risk", 

349 ), 

350 StrategyParameter( 

351 name="symbols", 

352 type=ParameterType.LIST, 

353 default_value=["AAPL.US"], 

354 description="监控股票列表", 

355 category="market", 

356 ), 

357 ], 

358 recommended_risk_config={ 

359 "max_position_ratio": 0.12, 

360 "stop_loss_ratio": 0.06, 

361 "max_drawdown": 0.18, 

362 }, 

363 supports_realtime=True, 

364 supports_backtest=True, 

365 min_capital=Decimal("75000"), 

366 recommended_capital=Decimal("300000"), 

367 supported_markets=["US", "HK"], 

368 supported_symbols=["AAPL.US", "MSFT.US", "AMZN.US"], 

369 ), 

370 "bollinger": StrategyConfigTemplate( 

371 strategy_name="bollinger", 

372 display_name="布林带策略", 

373 description="基于布林带的均值回归策略", 

374 version="1.0.0", 

375 author="系统", 

376 parameters=[ 

377 StrategyParameter( 

378 name="period", 

379 type=ParameterType.INTEGER, 

380 default_value=20, 

381 description="移动平均周期", 

382 min_value=10, 

383 max_value=50, 

384 category="technical", 

385 ), 

386 StrategyParameter( 

387 name="std_dev", 

388 type=ParameterType.FLOAT, 

389 default_value=2.0, 

390 description="标准差倍数", 

391 min_value=1.0, 

392 max_value=3.0, 

393 category="technical", 

394 ), 

395 StrategyParameter( 

396 name="position_size", 

397 type=ParameterType.FLOAT, 

398 default_value=0.04, 

399 description="仓位大小比例", 

400 min_value=0.01, 

401 max_value=1.0, 

402 category="risk", 

403 ), 

404 StrategyParameter( 

405 name="symbols", 

406 type=ParameterType.LIST, 

407 default_value=["AAPL.US"], 

408 description="监控股票列表", 

409 category="market", 

410 ), 

411 ], 

412 recommended_risk_config={ 

413 "max_position_ratio": 0.08, 

414 "stop_loss_ratio": 0.04, 

415 "max_drawdown": 0.12, 

416 }, 

417 supports_realtime=True, 

418 supports_backtest=True, 

419 min_capital=Decimal("50000"), 

420 recommended_capital=Decimal("150000"), 

421 supported_markets=["US", "HK"], 

422 supported_symbols=["AAPL.US", "MSFT.US", "NVDA.US"], 

423 ), 

424 } 

425 

426 @staticmethod 

427 def get_strategy_template(strategy_name: str) -> Optional[StrategyConfigTemplate]: 

428 """获取策略配置模板""" 

429 templates = StrategyConfigPreset.get_strategy_templates() 

430 return templates.get(strategy_name) 

431 

432 @staticmethod 

433 def get_recommended_risk_config(strategy_name: str) -> Dict[str, Any]: 

434 """获取策略推荐的风险配置""" 

435 template = StrategyConfigPreset.get_strategy_template(strategy_name) 

436 if template: 

437 return template.recommended_risk_config 

438 return {} 

439 

440 @staticmethod 

441 def get_strategy_parameters(strategy_name: str) -> List[Dict[str, Any]]: 

442 """获取策略参数定义""" 

443 template = StrategyConfigPreset.get_strategy_template(strategy_name) 

444 if template: 

445 return [ 

446 { 

447 "name": param.name, 

448 "type": param.type.value, 

449 "default_value": param.default_value, 

450 "description": param.description, 

451 "min_value": param.min_value, 

452 "max_value": param.max_value, 

453 "allowed_values": param.allowed_values, 

454 "required": param.required, 

455 "category": param.category, 

456 } 

457 for param in template.parameters 

458 ] 

459 return []