136 lines
4.2 KiB
Python
136 lines
4.2 KiB
Python
"""
|
|
API路由 - 用户认证
|
|
"""
|
|
from typing import Annotated, List
|
|
from fastapi import APIRouter, Depends, HTTPException, Query
|
|
from fastapi.security import OAuth2PasswordRequestForm
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
from sqlalchemy import select
|
|
|
|
from app.core.database import get_db
|
|
from app.core.security import verify_password, create_access_token, get_current_active_user
|
|
from app.schemas.schemas import UserLogin, UserCreate, UserResponse, Token
|
|
from app.models.models import User
|
|
|
|
# Shared router for this module: every endpoint below is mounted under the
# "/auth" path prefix and grouped under a single OpenAPI tag.
router = APIRouter(
    tags=["用户认证"],
    prefix="/auth",
)
|
|
|
|
|
|
@router.post("/login", response_model=Token, summary="用户登录")
async def login(
    form_data: OAuth2PasswordRequestForm = Depends(),
    db: AsyncSession = Depends(get_db)
):
    """Log a user in from an OAuth2 password form and issue a bearer token.

    Args:
        form_data: OAuth2 password form carrying ``username`` and ``password``.
        db: Async database session provided by ``get_db``.

    Returns:
        A dict shaped like the ``Token`` response model:
        ``{"access_token": <token>, "token_type": "bearer"}``.

    Raises:
        HTTPException: 401 when no user has that username or the password does
            not verify; 403 when the credentials are valid but ``is_active`` is
            falsy.
    """
    result = await db.execute(
        select(User).where(User.username == form_data.username)
    )
    user = result.scalar_one_or_none()

    # One shared 401 message for "unknown user" and "wrong password" keeps the
    # response body from revealing which usernames exist.
    # NOTE(review): verify_password is skipped when the user is missing, so
    # response timing may still differ between the two cases — consider
    # verifying against a dummy hash to equalize it.
    if not user or not verify_password(form_data.password, user.password_hash):
        raise HTTPException(status_code=401, detail="用户名或密码错误")

    # Checked after the password, so the "disabled" status is only disclosed
    # to someone who already knows the account's password.
    if not user.is_active:
        raise HTTPException(status_code=403, detail="账户已禁用")

    access_token = create_access_token(user.id)
    return {"access_token": access_token, "token_type": "bearer"}
|
|
|
|
|
|
@router.post("/login-json", summary="用户登录(JSON格式)")
async def login_json(
    login_data: UserLogin,
    db: AsyncSession = Depends(get_db)
):
    """Authenticate with a JSON body and return the token in a response envelope.

    Args:
        login_data: Request body carrying ``username`` and ``password``.
        db: Async database session provided by ``get_db``.

    Returns:
        ``{"code": 200, "message": "登录成功", "data": {"access_token": ...,
        "token_type": "bearer"}}``.

    Raises:
        HTTPException: 401 for an unknown username or a wrong password; 403 when
            the account's ``is_active`` flag is falsy.
    """
    lookup = select(User).where(User.username == login_data.username)
    account = (await db.execute(lookup)).scalar_one_or_none()

    # The password is only verified when an account was found (short-circuit).
    credentials_ok = bool(account) and verify_password(
        login_data.password, account.password_hash
    )
    if not credentials_ok:
        raise HTTPException(status_code=401, detail="用户名或密码错误")

    if not account.is_active:
        raise HTTPException(status_code=403, detail="账户已禁用")

    token_payload = {
        "access_token": create_access_token(account.id),
        "token_type": "bearer",
    }
    return {"code": 200, "message": "登录成功", "data": token_payload}
|
|
|
|
|
|
@router.post("/register", summary="用户注册")
async def register(
    user_data: UserCreate,
    db: AsyncSession = Depends(get_db)
):
    """Create a new user account.

    Args:
        user_data: Request body with ``username``, ``password``, ``staff_id``
            and ``role``.
        db: Async database session provided by ``get_db``.

    Returns:
        ``{"code": 200, "message": "注册成功", "data": {"id": <new user id>}}``.

    Raises:
        HTTPException: 400 when a user with the same username already exists.
    """
    from app.core.security import get_password_hash

    # Reject early if the username is already taken.
    result = await db.execute(
        select(User).where(User.username == user_data.username)
    )
    if result.scalar_one_or_none():
        raise HTTPException(status_code=400, detail="用户名已存在")

    # NOTE(review): this check-then-insert is not atomic. Two concurrent
    # requests for the same username can both pass the check above. Whether
    # that yields a duplicate row or a database error depends on a unique
    # constraint on ``username`` that is not visible here — confirm.
    # NOTE(review): ``role`` is taken straight from the request on an
    # endpoint with no authentication dependency, so a caller can pick their
    # own role — confirm this is intended.
    user = User(
        username=user_data.username,
        password_hash=get_password_hash(user_data.password),
        staff_id=user_data.staff_id,
        role=user_data.role
    )
    db.add(user)
    # flush() emits the pending INSERT without committing, so ``user.id`` is
    # populated when it is database-generated. Presumably ``get_db`` commits
    # the session after the request — verify.
    await db.flush()

    return {"code": 200, "message": "注册成功", "data": {"id": user.id}}
|
|
|
|
|
|
@router.get("/me", response_model=UserResponse, summary="获取当前用户")
async def get_current_user_info(
    current_user: Annotated[User, Depends(get_current_active_user)]
):
    """Return the user resolved by the ``get_current_active_user`` dependency.

    The returned ``User`` object is serialized through the ``UserResponse``
    response model declared on the route.
    """
    me = current_user
    return me
|
|
|
|
|
|
@router.get("/users", summary="获取用户列表")
async def get_users(
    page: int = Query(1, ge=1),
    page_size: int = Query(20, ge=1, le=100),
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(get_current_active_user)
):
    """Return one page of users (highest id first) plus the total user count.

    Args:
        page: 1-based page number (validated ``>= 1``).
        page_size: Rows per page (validated to ``1..100``).
        db: Async database session provided by ``get_db``.
        current_user: Not read in the body. Declaring it makes the endpoint
            depend on ``get_current_active_user``, so a request is only served
            when that dependency succeeds.

    Returns:
        ``{"code": 200, "message": "success", "data": [...], "total": <int>,
        "page": page, "page_size": page_size}``. Each ``data`` item has
        ``id``, ``username``, ``role``, ``is_active``, ``last_login`` and
        ``created_at``.
    """
    from sqlalchemy import func

    # NOTE(review): any authenticated user can list every account here —
    # confirm whether a role check is required.

    # Count in the database instead of loading every User row into memory
    # just to call len() on the result.
    count_result = await db.execute(select(func.count()).select_from(User))
    total = count_result.scalar_one()

    offset = (page - 1) * page_size
    result = await db.execute(
        select(User).order_by(User.id.desc()).offset(offset).limit(page_size)
    )
    users = result.scalars().all()

    return {
        "code": 200,
        "message": "success",
        "data": [
            {
                "id": u.id,
                "username": u.username,
                "role": u.role,
                "is_active": u.is_active,
                "last_login": u.last_login,
                "created_at": u.created_at
            }
            for u in users
        ],
        "total": total,
        "page": page,
        "page_size": page_size
    }
|