Coverage for core/repositories/user_repository.py: 31.97%
122 statements
« prev ^ index » next coverage.py v7.10.7, created at 2025-10-13 18:58 +0000
1"""
2用户数据仓库层
3"""
5import json
6import uuid
7from datetime import datetime, timezone
8from typing import List, Optional
10from core.models.user import (Account, AccountCreate, User, UserCreate,
11 UserResponse, UserUpdate)
12from infrastructure.database.redis_client import get_redis
class UserRepository:
    """Redis-backed data repository for users and their accounts.

    Key layout written by this class:

    * ``user:{user_id}`` -- user record dict (model fields plus the hashed
      ``password``), with ``created_at``/``updated_at`` as ISO-8601 strings.
    * ``users`` -- set of all user IDs.
    * ``user_by_username:{username}`` -- user ID, used for username lookup.
    * ``account:{account_id}`` -- account record dict.
    * ``accounts`` -- set of all account IDs.

    Records and the username index are written with a TTL of
    ``RECORD_TTL_SECONDS``. The ``users``/``accounts`` sets have no TTL, so
    they can hold IDs whose records have already expired; read paths skip
    such IDs.

    NOTE(review): records are passed to ``setex`` as plain dicts, and ``get``
    is expected to return dicts. This relies on the client returned by
    ``get_redis()`` doing the (de)serialization -- confirm in
    ``infrastructure.database.redis_client``.
    """

    # Expiry for every record / index key written here: 30 days, in seconds.
    RECORD_TTL_SECONDS = 86400 * 30

    def __init__(self):
        self.redis = get_redis()
        self.users_key = "users"
        self.accounts_key = "accounts"
        self.user_by_username_key = "user_by_username"

    def _get_user_key(self, user_id: str) -> str:
        """Return the Redis key holding the record for ``user_id``."""
        return f"user:{user_id}"

    def _get_account_key(self, account_id: str) -> str:
        """Return the Redis key holding the record for ``account_id``."""
        return f"account:{account_id}"

    def _get_user_by_username_key(self, username: str) -> str:
        """Return the Redis key of the username -> user ID index entry."""
        return f"user_by_username:{username}"

    @staticmethod
    def _to_record(model) -> dict:
        """Dump a model to a dict with its timestamps as ISO-8601 strings.

        Converts ``created_at`` and ``updated_at``.
        """
        record = model.model_dump()
        record["created_at"] = record["created_at"].isoformat()
        record["updated_at"] = record["updated_at"].isoformat()
        return record

    def _drop_username_index(self, username: str, user_id: str) -> None:
        """Delete the index entry for ``username`` if it still points at ``user_id``.

        The guard avoids removing an entry that has since been re-pointed at a
        different user (``create_user`` overwrites existing entries).
        """
        username_key = self._get_user_by_username_key(username)
        if self.redis.get(username_key) == user_id:
            self.redis.delete(username_key)

    def create_user(self, user_data: UserCreate, auth_service) -> User:
        """Create and persist a new user.

        Args:
            user_data: Fields of the new user, including the plaintext
                ``password``.
            auth_service: Provides ``hash_password(plaintext)``. Only the
                resulting hash is stored.

        Returns:
            The created ``User``. The password hash is stored in the Redis
            record but is not set on the returned model.

        NOTE(review): the username is not checked for uniqueness here, and an
        existing ``user_by_username`` entry is overwritten. Confirm that
        callers check beforehand.
        """
        user_id = str(uuid.uuid4())
        now = datetime.now(timezone.utc)

        user = User(
            id=user_id,
            username=user_data.username,
            email=user_data.email,
            phone=user_data.phone,
            user_type=user_data.user_type,
            status=user_data.status,
            created_at=now,
            updated_at=now,
        )

        record = self._to_record(user)
        # The password hash lives only in the stored record, never on the model.
        record["password"] = auth_service.hash_password(user_data.password)
        self.redis.setex(self._get_user_key(user_id), self.RECORD_TTL_SECONDS, record)

        self.redis.sadd(self.users_key, user_id)
        self.redis.setex(
            self._get_user_by_username_key(user_data.username),
            self.RECORD_TTL_SECONDS,
            user_id,
        )
        return user

    def get_user_by_id(self, user_id: str) -> Optional[User]:
        """Return the user stored under ``user_id``, or ``None`` if absent."""
        user_key = self._get_user_key(user_id)
        data = self.redis.get(user_key)
        if data and isinstance(data, dict):
            return User(**data)
        return None

    def get_user_by_username(self, username: str) -> Optional[User]:
        """Return the user with ``username`` via the username index, or ``None``."""
        user_id = self.redis.get(self._get_user_by_username_key(username))
        if user_id:
            return self.get_user_by_id(user_id)
        return None

    def get_user_password(self, user_id: str) -> Optional[str]:
        """Return the stored password hash for ``user_id``, or ``None``."""
        data = self.redis.get(self._get_user_key(user_id))
        if data and isinstance(data, dict):
            return data.get("password")
        return None

    def get_all_users(self) -> List[User]:
        """Return every user whose record still exists.

        IDs in the ``users`` set whose record has expired are skipped.
        """
        users = []
        for user_id in self.redis.smembers(self.users_key):
            user = self.get_user_by_id(user_id)
            if user:
                users.append(user)
        return users

    def update_user(
        self, user_id: str, user_data: UserUpdate, auth_service=None
    ) -> Optional[User]:
        """Apply a partial update to a user and persist it.

        Only fields explicitly set on ``user_data`` are applied. A non-empty
        ``password`` is hashed with ``auth_service.hash_password`` and stored.
        It is ignored when it is ``None``/empty or when no ``auth_service`` is
        given; in those cases the previously stored hash is kept.

        The record's TTL is refreshed. The username index is re-pointed (and
        its TTL refreshed) so lookups by the new username work after a rename.

        Returns:
            The updated ``User``, or ``None`` if no record exists for ``user_id``.
        """
        user_key = self._get_user_key(user_id)
        stored = self.redis.get(user_key)
        if not (stored and isinstance(stored, dict)):
            return None
        user = User(**stored)
        old_username = user.username

        update_data = user_data.model_dump(exclude_unset=True)
        # The password is never set on the User model; handle it separately.
        new_password = update_data.pop("password", None)
        # Start from the stored hash. The record is rewritten wholesale below,
        # so a hash that is not carried forward would be lost.
        password_hash = stored.get("password")
        if new_password and auth_service:
            password_hash = auth_service.hash_password(new_password)

        for field, value in update_data.items():
            setattr(user, field, value)
        user.updated_at = datetime.now(timezone.utc)

        record = self._to_record(user)
        if password_hash is not None:
            record["password"] = password_hash
        self.redis.setex(user_key, self.RECORD_TTL_SECONDS, record)

        # Keep the username index in sync with the (possibly renamed) record.
        if user.username != old_username:
            self._drop_username_index(old_username, user_id)
        self.redis.setex(
            self._get_user_by_username_key(user.username),
            self.RECORD_TTL_SECONDS,
            user_id,
        )
        return user

    def delete_user(self, user_id: str) -> bool:
        """Delete a user, its username index entry and all of its accounts.

        Returns:
            ``True`` if the user existed and was deleted, ``False`` otherwise.
        """
        user = self.get_user_by_id(user_id)
        if not user:
            return False

        self.redis.delete(self._get_user_key(user_id))
        self.redis.srem(self.users_key, user_id)
        self._drop_username_index(user.username, user_id)
        self._delete_user_accounts(user_id)
        return True

    def _iter_user_account_records(self, user_id: str):
        """Yield ``(account_id, account_key, record)`` for each account owned by ``user_id``.

        NOTE(review): this scans every ID in the ``accounts`` set, so cost
        grows with the total number of accounts, not the user's.
        """
        for account_id in self.redis.smembers(self.accounts_key):
            account_key = self._get_account_key(account_id)
            record = self.redis.get(account_key)
            if record and isinstance(record, dict) and record.get("user_id") == user_id:
                yield account_id, account_key, record

    def _delete_user_accounts(self, user_id: str):
        """Delete every account record owned by ``user_id``.

        Also removes those IDs from the ``accounts`` set.
        """
        # Materialize first so the deletions below cannot disturb the iteration.
        owned = list(self._iter_user_account_records(user_id))
        for account_id, account_key, _ in owned:
            self.redis.delete(account_key)
            self.redis.srem(self.accounts_key, account_id)

    def create_account(self, account_data: AccountCreate) -> Account:
        """Create and persist a new account.

        Returns:
            The created ``Account``.
        """
        account_id = str(uuid.uuid4())
        now = datetime.now(timezone.utc)

        account = Account(
            id=account_id,
            user_id=account_data.user_id,
            account_type=account_data.account_type,
            balance=account_data.balance,
            available_balance=account_data.available_balance,
            frozen_balance=account_data.frozen_balance,
            currency=account_data.currency,
            created_at=now,
            updated_at=now,
        )

        self.redis.setex(
            self._get_account_key(account_id),
            self.RECORD_TTL_SECONDS,
            self._to_record(account),
        )
        self.redis.sadd(self.accounts_key, account_id)
        return account

    def get_user_accounts(self, user_id: str) -> List[Account]:
        """Return all existing accounts owned by ``user_id``."""
        return [
            Account(**record)
            for _, _, record in self._iter_user_account_records(user_id)
        ]