Coverage for core/services/auth_service.py: 63.51%

74 statements  

« prev     ^ index     » next       coverage.py v7.10.7, created at 2025-10-13 18:58 +0000

1""" 

2认证服务层 

3""" 

4 

5import hashlib 

6import secrets 

7from datetime import datetime, timedelta, timezone 

8from typing import Optional 

9 

10import jwt 

11 

12from core.models.user import LoginRequest, LoginResponse, User, UserCreate 

13from core.repositories.user_repository import UserRepository 

14from infrastructure.config.settings import settings 

15 

16 

class AuthService:
    """Authentication service: password hashing, JWT issuing/verification, login.

    Behavior depends on ``settings.environment``. Only when it equals
    ``"production"``:

    - issued tokens carry an ``exp`` claim;
    - expiry is enforced on verification;
    - the login captcha is checked.
    """

    # PBKDF2 parameters shared by hash_password and verify_password. Changing
    # either one invalidates every stored hash, so they live in one place.
    _PBKDF2_HASH_NAME = "sha256"
    _PBKDF2_ITERATIONS = 100000

    def __init__(self):
        self.user_repository = UserRepository()
        self.secret_key = settings.jwt_secret_key
        self.algorithm = "HS256"
        # Lifetime of production tokens; non-production tokens get no exp.
        self.token_expire_days = 7

    def _pbkdf2_hex(self, password: str, salt: str) -> str:
        """Return the hex-encoded PBKDF2-HMAC digest of *password* with *salt*.

        The salt is used as its hex *string* encoded to bytes, not as the
        decoded raw bytes. Existing stored hashes depend on this, so keep it.
        """
        return hashlib.pbkdf2_hmac(
            self._PBKDF2_HASH_NAME,
            password.encode(),
            salt.encode(),
            self._PBKDF2_ITERATIONS,
        ).hex()

    def hash_password(self, password: str) -> str:
        """Hash *password* with a fresh random salt.

        Returns:
            ``"<salt>:<digest>"``, where salt is 32 hex characters (16 random
            bytes) and digest is the hex PBKDF2-HMAC-SHA256 output.
        """
        salt = secrets.token_hex(16)
        return f"{salt}:{self._pbkdf2_hex(password, salt)}"

    def verify_password(self, password: str, hashed_password: str) -> bool:
        """Check *password* against a ``"<salt>:<digest>"`` string from hash_password.

        Returns:
            True on a match. False for a wrong password or for a missing or
            malformed stored hash; this method does not raise for those cases.
        """
        try:
            salt, expected = hashed_password.split(":")
            actual = self._pbkdf2_hex(password, salt)
            # Constant-time comparison, so response timing does not reveal how
            # many leading digest characters matched.
            return secrets.compare_digest(actual, expected)
        except (AttributeError, TypeError, ValueError):
            # AttributeError: None or non-str input.
            # ValueError: not exactly one ":" separator (or an unencodable str).
            # TypeError: non-ASCII str passed to compare_digest.
            return False

    def create_access_token(self, user: User) -> str:
        """Issue a JWT for *user*, signed with ``self.algorithm`` (HS256).

        The payload holds ``user_id``, ``username`` and ``user_type``. In
        production an ``exp`` claim of now (UTC) + ``token_expire_days`` is
        added. In any other environment no ``exp`` is set, so the token never
        expires.
        """
        payload = {
            "user_id": user.id,
            "username": user.username,
            "user_type": user.user_type,
        }
        if settings.environment == "production":
            payload["exp"] = datetime.now(timezone.utc) + timedelta(
                days=self.token_expire_days
            )
        return jwt.encode(payload, self.secret_key, algorithm=self.algorithm)

    def verify_token(self, token: str) -> Optional[dict]:
        """Decode and verify *token*.

        The signature is always checked. In production an ``exp`` claim must
        be present and not yet passed. Elsewhere expiry is not checked at all.

        Returns:
            The decoded claims dict, or None if the token is invalid, expired,
            or (in production) lacks ``exp``.
        """
        if settings.environment == "production":
            # Require exp. Tokens minted outside production carry none and
            # would otherwise stay valid forever if the secret is shared.
            options = {"require": ["exp"]}
        else:
            # Outside production only the signature and format are verified.
            options = {"verify_exp": False}
        try:
            return jwt.decode(
                token,
                self.secret_key,
                algorithms=[self.algorithm],
                options=options,
            )
        except jwt.InvalidTokenError:
            # Also covers jwt.ExpiredSignatureError, which is a subclass.
            return None

    def login(self, login_data: LoginRequest) -> Optional[LoginResponse]:
        """Authenticate *login_data* and return a bearer token on success.

        Returns:
            A LoginResponse on success. None if any of these hold:
            the captcha check fails (production only), the username is
            unknown, no password is stored, or the password does not match.
        """
        # NOTE(review): this only checks that a 4-character captcha was sent.
        # It is never compared against an issued challenge, so it does not
        # actually stop automated logins. Non-production skips it entirely.
        if settings.environment == "production":
            if not login_data.captcha or len(login_data.captcha) != 4:
                return None

        user = self.user_repository.get_user_by_username(login_data.username)
        if not user:
            return None

        # The password hash is fetched separately from the user object.
        stored_password = self.user_repository.get_user_password(user.id)
        if not stored_password or not self.verify_password(
            login_data.password, stored_password
        ):
            return None

        access_token = self.create_access_token(user)
        return LoginResponse(access_token=access_token, token_type="bearer", user=user)

    def get_current_user(self, token: str) -> Optional[User]:
        """Resolve the user referenced by *token*.

        Returns:
            The user for the token's ``user_id`` claim. None if the token
            fails verification or carries a falsy ``user_id``.
        """
        payload = self.verify_token(token)
        if not payload:
            return None

        user_id = payload.get("user_id")
        if not user_id:
            return None

        return self.user_repository.get_user_by_id(user_id)

    def create_default_admin(self) -> bool:
        """Ensure the built-in ``test_for_dev`` admin user and its demo account exist.

        Returns:
            True if the user already exists, or if both the user and its demo
            account were created. False, after logging the error, if creation
            failed.
        """
        # NOTE(review): these are hard-coded admin credentials, and nothing
        # here limits this to non-production environments. Confirm the caller
        # never runs this in production.
        username = "test_for_dev"
        if self.user_repository.get_user_by_username(username):
            return True

        admin_data = UserCreate(
            username=username,
            email="test_for_dev@example.com",
            password="test_only_for_dev",
            user_type="admin",
            status="active",
        )

        try:
            # The service itself is passed in; presumably the repository uses
            # it to hash the password (TODO confirm).
            user = self.user_repository.create_user(admin_data, self)

            from core.models.user import AccountCreate

            self.user_repository.create_account(
                AccountCreate(
                    user_id=user.id,
                    account_type="demo",
                    balance=100000.0,
                    available_balance=100000.0,
                    currency="CNY",
                )
            )
            return True
        except Exception as e:
            # Best-effort helper: report with a traceback and signal failure
            # rather than raise.
            # If create_account fails after create_user succeeded, the user is
            # left without an account, and the early return above will not
            # repair that on a later call.
            import logging

            logging.getLogger(__name__).exception("创建默认管理员失败: %s", e)
            return False