110 lines
3.3 KiB
Python
110 lines
3.3 KiB
Python
"""
|
|
安全认证模块
|
|
"""
|
|
from datetime import datetime, timedelta
|
|
from typing import Any, Optional, Annotated
|
|
from jose import jwt, JWTError
|
|
import bcrypt
|
|
from fastapi import Depends, HTTPException, status
|
|
from fastapi.security import OAuth2PasswordBearer
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
from sqlalchemy import select
|
|
|
|
from app.core.config import settings
|
|
from app.core.database import get_db
|
|
from app.models.models import User
|
|
|
|
|
|
# 密码加密直接使用 bcrypt
|
|
|
|
# OAuth2 "password" flow. ``tokenUrl`` points at the login endpoint that issues
# bearer tokens; this scheme is injected into get_current_user below to pull
# the token out of incoming requests.
_LOGIN_PATH = "/auth/login"
oauth2_scheme = OAuth2PasswordBearer(tokenUrl=f"{settings.API_PREFIX}{_LOGIN_PATH}")
|
|
|
|
|
|
def verify_password(plain_password: str, hashed_password: str) -> bool:
    """Check a plaintext password against a stored bcrypt hash.

    Args:
        plain_password: The password as submitted by the user.
        hashed_password: The stored hash (presumably produced by
            ``get_password_hash``), as a UTF-8 string.

    Returns:
        True if the password matches the hash. False if it does not match,
        if the stored hash is not something bcrypt can parse, or if either
        string cannot be UTF-8 encoded.
    """
    try:
        return bcrypt.checkpw(
            plain_password.encode("utf-8"),
            hashed_password.encode("utf-8"),
        )
    except ValueError:
        # bcrypt raises ValueError ("Invalid salt") for a malformed or
        # non-bcrypt stored hash. str.encode raises UnicodeEncodeError (a
        # ValueError subclass) for lone surrogates. Neither can ever be a
        # successful match, so report "no match" instead of letting the
        # exception propagate to the caller as an unhandled error.
        return False
|
|
|
|
|
|
def get_password_hash(password: str) -> str:
    """Hash *password* with bcrypt using a freshly generated salt.

    Returns:
        The bcrypt hash decoded to a ``str`` so it can be stored directly.

    NOTE(review): bcrypt only consumes the first 72 bytes of its input;
    depending on the installed bcrypt version longer passwords are either
    silently truncated or rejected with ValueError — confirm which applies.
    """
    raw_password = password.encode("utf-8")
    hashed = bcrypt.hashpw(raw_password, bcrypt.gensalt())
    return hashed.decode("utf-8")
|
|
|
|
|
|
def create_access_token(subject: str | Any, expires_delta: Optional[timedelta] = None) -> str:
    """Create a signed JWT access token for *subject*.

    Args:
        subject: Identity stored in the ``sub`` claim; converted with ``str()``.
        expires_delta: Token lifetime. ``None`` (the default) falls back to
            ``settings.ACCESS_TOKEN_EXPIRE_MINUTES`` minutes. Any explicit
            value, including ``timedelta(0)``, is used as given.

    Returns:
        The encoded JWT, signed with ``settings.SECRET_KEY`` using
        ``settings.ALGORITHM``.
    """
    # Local import: the module header only brings in datetime and timedelta.
    from datetime import timezone

    # Compare against None explicitly: a zero timedelta is falsy and would
    # otherwise be silently replaced by the default lifetime.
    if expires_delta is None:
        expires_delta = timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)

    # Timezone-aware "now" (datetime.utcnow() is deprecated since Python 3.12).
    # python-jose converts a datetime ``exp`` to a UTC epoch timestamp, so the
    # encoded claim is the same as with the previous naive-UTC value.
    expire = datetime.now(timezone.utc) + expires_delta

    to_encode = {"exp": expire, "sub": str(subject)}
    return jwt.encode(to_encode, settings.SECRET_KEY, algorithm=settings.ALGORITHM)
|
|
|
|
|
|
def decode_token(token: str) -> Optional[dict]:
    """Decode and verify a JWT created with this module's key and algorithm.

    Returns:
        The decoded claims, or ``None`` if ``jwt.decode`` raised a
        ``JWTError`` (the base error python-jose uses for invalid signatures,
        malformed tokens and failed claim checks).
    """
    try:
        claims = jwt.decode(
            token,
            settings.SECRET_KEY,
            algorithms=[settings.ALGORITHM],
        )
    except JWTError:
        claims = None
    return claims
|
|
|
|
|
|
async def get_current_user(
    token: Annotated[str, Depends(oauth2_scheme)],
    db: AsyncSession = Depends(get_db)
) -> User:
    """Resolve the ``User`` identified by the bearer token's ``sub`` claim.

    Args:
        token: Bearer token extracted by ``oauth2_scheme``.
        db: Async database session from ``get_db``.

    Returns:
        The ``User`` row whose ``id`` equals the integer value of ``sub``.

    Raises:
        HTTPException: 401 (with ``WWW-Authenticate: Bearer``) if the token
            does not decode, has no ``sub``, has a ``sub`` that is not an
            integer, or refers to a user that does not exist.
    """
    credentials_exception = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="无法验证凭据",
        headers={"WWW-Authenticate": "Bearer"},
    )

    payload = decode_token(token)
    if payload is None:
        raise credentials_exception

    subject: Optional[str] = payload.get("sub")
    if subject is None:
        raise credentials_exception

    try:
        user_id = int(subject)
    except (TypeError, ValueError):
        # A correctly signed token whose subject is not a numeric id must be
        # rejected as 401, not crash the request with an unhandled 500.
        raise credentials_exception from None

    result = await db.execute(
        select(User).where(User.id == user_id)
    )
    user = result.scalar_one_or_none()
    if user is None:
        raise credentials_exception

    return user
|
|
|
|
|
|
async def get_current_active_user(
    current_user: Annotated[User, Depends(get_current_user)]
) -> User:
    """Return the authenticated user only if their account is active.

    Raises:
        HTTPException: 400 when ``current_user.is_active`` is falsy.
    """
    if current_user.is_active:
        return current_user
    raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="用户已被禁用")
|
|
|
|
|
|
async def get_current_admin_user(
    current_user: Annotated[User, Depends(get_current_active_user)]
) -> User:
    """Return the active user only if their role is ``"admin"``.

    Raises:
        HTTPException: 403 for any role other than ``"admin"``.
    """
    is_admin = current_user.role == "admin"
    if is_admin:
        return current_user
    raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="需要管理员权限")
|
|
|
|
|
|
async def get_current_manager_user(
    current_user: Annotated[User, Depends(get_current_active_user)]
) -> User:
    """Return the active user only if their role is ``"admin"`` or ``"manager"``.

    Raises:
        HTTPException: 403 for any other role.
    """
    permitted_roles = ("admin", "manager")
    if current_user.role in permitted_roles:
        return current_user
    raise HTTPException(
        status_code=status.HTTP_403_FORBIDDEN,
        detail="需要管理员或经理权限",
    )
|