Coverage for core/repositories/user_repository.py: 31.97%

122 statements  

« prev     ^ index     » next       coverage.py v7.10.7, created at 2025-10-13 18:58 +0000

1""" 

2用户数据仓库层 

3""" 

4 

5import json 

6import uuid 

7from datetime import datetime, timezone 

8from typing import List, Optional 

9 

10from core.models.user import (Account, AccountCreate, User, UserCreate, 

11 UserResponse, UserUpdate) 

12from infrastructure.database.redis_client import get_redis 

13 

14 

class UserRepository:
    """Redis-backed data repository for users and their accounts.

    Keys written by this class:

    * ``user:{user_id}``              -> user record dict; also holds the
      ``password`` value produced by ``auth_service.hash_password``
    * ``account:{account_id}``        -> account record dict
    * ``user_by_username:{username}`` -> user_id (username lookup index)
    * ``users`` / ``accounts``        -> sets of all user ids / account ids

    Record keys and the username index are written with ``setex`` and a TTL of
    ``DEFAULT_TTL_SECONDS``. The ``users`` / ``accounts`` sets are written with
    ``sadd`` and no TTL, so they can contain ids whose record has already
    expired; the readers below skip such ids.

    NOTE(review): dicts are passed directly to the client returned by
    ``get_redis()`` and ``get`` is expected to return dicts, so that client
    presumably handles (de)serialization — confirm.
    """

    # Expiry applied to every record / index key written by this repository (30 days).
    DEFAULT_TTL_SECONDS = 86400 * 30

    def __init__(self):
        self.redis = get_redis()
        self.users_key = "users"
        self.accounts_key = "accounts"
        self.user_by_username_key = "user_by_username"

    # ------------------------------------------------------------------ keys

    def _get_user_key(self, user_id: str) -> str:
        """Return the Redis key holding the record of ``user_id``."""
        return f"user:{user_id}"

    def _get_account_key(self, account_id: str) -> str:
        """Return the Redis key holding the record of ``account_id``."""
        return f"account:{account_id}"

    def _get_user_by_username_key(self, username: str) -> str:
        """Return the Redis key of the username -> user_id index entry."""
        return f"user_by_username:{username}"

    # --------------------------------------------------------------- helpers

    @staticmethod
    def _to_storage_dict(model) -> dict:
        """Dump a model via ``model_dump()`` with timestamps as ISO-8601 strings.

        Only ``created_at`` / ``updated_at`` are converted (when they hold a
        ``datetime``); every other field is left exactly as dumped.
        """
        data = model.model_dump()
        for field in ("created_at", "updated_at"):
            value = data.get(field)
            if isinstance(value, datetime):
                data[field] = value.isoformat()
        return data

    def _iter_account_records(self):
        """Yield ``(account_id, record)`` for each id in the ``accounts`` set.

        Ids whose record is missing or is not a dict are skipped.
        """
        for account_id in self.redis.smembers(self.accounts_key):
            record = self.redis.get(self._get_account_key(account_id))
            if record and isinstance(record, dict):
                yield account_id, record

    # ----------------------------------------------------------------- users

    def create_user(self, user_data: UserCreate, auth_service) -> User:
        """Create and persist a new user.

        Args:
            user_data: Creation payload. ``user_data.password`` is passed to
                ``auth_service.hash_password`` and only the result is stored.
            auth_service: Object exposing ``hash_password(password)``.

        Returns:
            The new ``User``; it is constructed without the password.
        """
        user_id = str(uuid.uuid4())
        now = datetime.now(timezone.utc)

        user = User(
            id=user_id,
            username=user_data.username,
            email=user_data.email,
            phone=user_data.phone,
            user_type=user_data.user_type,
            status=user_data.status,
            created_at=now,
            updated_at=now,
        )

        user_dict = self._to_storage_dict(user)
        # The password hash lives only in the raw record, never on the User object.
        user_dict["password"] = auth_service.hash_password(user_data.password)
        self.redis.setex(
            self._get_user_key(user_id), self.DEFAULT_TTL_SECONDS, user_dict
        )

        self.redis.sadd(self.users_key, user_id)

        # NOTE(review): username uniqueness is not checked here; an existing
        # index entry for the same username is overwritten. Presumably the
        # caller validates uniqueness first — confirm.
        self.redis.setex(
            self._get_user_by_username_key(user_data.username),
            self.DEFAULT_TTL_SECONDS,
            user_id,
        )

        return user

    def get_user_by_id(self, user_id: str) -> Optional[User]:
        """Return the user stored under ``user_id``, or None if absent/not a dict."""
        data = self.redis.get(self._get_user_key(user_id))
        if data and isinstance(data, dict):
            return User(**data)
        return None

    def get_user_by_username(self, username: str) -> Optional[User]:
        """Resolve ``username`` through the username index, then load the user."""
        user_id = self.redis.get(self._get_user_by_username_key(username))
        if user_id:
            return self.get_user_by_id(user_id)
        return None

    def get_user_password(self, user_id: str) -> Optional[str]:
        """Return the stored password hash of ``user_id``, or None if unavailable."""
        data = self.redis.get(self._get_user_key(user_id))
        if data and isinstance(data, dict):
            return data.get("password")
        return None

    def get_all_users(self) -> List[User]:
        """Return every user in the ``users`` set whose record still exists."""
        users = []
        for user_id in self.redis.smembers(self.users_key):
            user = self.get_user_by_id(user_id)
            if user:
                users.append(user)
        return users

    def update_user(
        self, user_id: str, user_data: UserUpdate, auth_service=None
    ) -> Optional[User]:
        """Apply the explicitly-set fields of ``user_data`` to a stored user.

        Args:
            user_id: Id of the user to update.
            user_data: Partial update; only fields set by the caller
                (``exclude_unset=True``) are applied.
            auth_service: Needed only to hash a new password. If it is absent,
                or the supplied password is None / "", the password is left
                unchanged.

        Returns:
            The updated ``User``, or None if no such user exists.
        """
        user = self.get_user_by_id(user_id)
        if not user:
            return None

        update_data = user_data.model_dump(exclude_unset=True)

        # The password is never set on the User object; it is handled separately.
        new_password_hash = None
        if "password" in update_data:
            plain = update_data.pop("password")
            if plain is not None and plain != "" and auth_service:
                new_password_hash = auth_service.hash_password(plain)

        old_username = user.username
        for field, value in update_data.items():
            setattr(user, field, value)
        user.updated_at = datetime.now(timezone.utc)

        user_dict = self._to_storage_dict(user)
        # Fix: the User object is built without the password (create_user adds
        # the hash to the raw record only), so re-saving its dump would drop the
        # stored hash on every non-password update. Keep the existing hash
        # unless a new one was supplied.
        stored_password = (
            new_password_hash
            if new_password_hash is not None
            else self.get_user_password(user_id)
        )
        if stored_password is not None:
            user_dict["password"] = stored_password

        self.redis.setex(
            self._get_user_key(user_id), self.DEFAULT_TTL_SECONDS, user_dict
        )

        # Fix: keep the username index in step with the record. Move it when
        # the username changed, and refresh its TTL together with the record's
        # so lookups by username don't expire while the user record lives on.
        # NOTE(review): uniqueness of a new username is not checked here.
        if user.username != old_username:
            self.redis.delete(self._get_user_by_username_key(old_username))
        self.redis.setex(
            self._get_user_by_username_key(user.username),
            self.DEFAULT_TTL_SECONDS,
            user_id,
        )

        return user

    def delete_user(self, user_id: str) -> bool:
        """Delete a user, its username index entry and all of its accounts.

        Returns:
            True if the user existed and was deleted, False otherwise.
        """
        user = self.get_user_by_id(user_id)
        if not user:
            return False

        self.redis.delete(self._get_user_key(user_id))
        self.redis.srem(self.users_key, user_id)
        self.redis.delete(self._get_user_by_username_key(user.username))
        self._delete_user_accounts(user_id)

        return True

    # -------------------------------------------------------------- accounts

    def _delete_user_accounts(self, user_id: str):
        """Delete every account owned by ``user_id`` and drop its id from ``accounts``.

        Ownership is determined by scanning all account records (O(#accounts)).
        """
        # Collect matches first so deletions don't interleave with the scan.
        owned_ids = [
            account_id
            for account_id, record in self._iter_account_records()
            if record.get("user_id") == user_id
        ]
        for account_id in owned_ids:
            self.redis.delete(self._get_account_key(account_id))
            self.redis.srem(self.accounts_key, account_id)

    def create_account(self, account_data: AccountCreate) -> Account:
        """Create and persist a new account from ``account_data``.

        Returns:
            The new ``Account`` with a fresh id and creation timestamps.
        """
        account_id = str(uuid.uuid4())
        now = datetime.now(timezone.utc)

        account = Account(
            id=account_id,
            user_id=account_data.user_id,
            account_type=account_data.account_type,
            balance=account_data.balance,
            available_balance=account_data.available_balance,
            frozen_balance=account_data.frozen_balance,
            currency=account_data.currency,
            created_at=now,
            updated_at=now,
        )

        self.redis.setex(
            self._get_account_key(account_id),
            self.DEFAULT_TTL_SECONDS,
            self._to_storage_dict(account),
        )
        self.redis.sadd(self.accounts_key, account_id)

        return account

    def get_user_accounts(self, user_id: str) -> List[Account]:
        """Return all stored accounts whose record's ``user_id`` equals ``user_id``.

        Ownership is determined by scanning all account records (O(#accounts)).
        """
        return [
            Account(**record)
            for _, record in self._iter_account_records()
            if record.get("user_id") == user_id
        ]