♻️ Refactor Users API and dependencies (#561)
Co-authored-by: Sebastián Ramírez <tiangolo@gmail.com>
This commit is contained in:
@@ -3,14 +3,11 @@ from typing import Annotated
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from sqlmodel import Session, select
|
||||
|
||||
from app.api import deps
|
||||
from app.models import Item, ItemCreate, ItemOut, ItemUpdate, User
|
||||
from app.api.deps import CurrentUser, SessionDep
|
||||
from app.models import Item, ItemCreate, ItemOut, ItemUpdate
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
SessionDep = Annotated[Session, Depends(deps.get_db)]
|
||||
CurrentUser = Annotated[User, Depends(deps.get_current_active_user)]
|
||||
|
||||
|
||||
@router.get("/")
|
||||
def read_items(
|
||||
|
@@ -1,99 +1,88 @@
|
||||
from typing import Any, List
|
||||
from typing import Annotated, Any, List
|
||||
|
||||
from fastapi import APIRouter, Body, Depends, HTTPException
|
||||
from fastapi.encoders import jsonable_encoder
|
||||
from pydantic.networks import EmailStr
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlmodel import select
|
||||
|
||||
from app import crud, models, schemas
|
||||
from app.api import deps
|
||||
from app import crud
|
||||
from app.api.deps import (
|
||||
CurrentUser,
|
||||
SessionDep,
|
||||
get_current_active_superuser,
|
||||
)
|
||||
from app.core.config import settings
|
||||
from app.models import User, UserCreate, UserCreateOpen, UserOut, UserUpdate
|
||||
from app.utils import send_new_account_email
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get("/", response_model=List[schemas.User])
|
||||
def read_users(
|
||||
db: Session = Depends(deps.get_db),
|
||||
skip: int = 0,
|
||||
limit: int = 100,
|
||||
current_user: models.User = Depends(deps.get_current_active_superuser),
|
||||
) -> Any:
|
||||
@router.get("/", dependencies=[Depends(get_current_active_superuser)])
|
||||
def read_users(session: SessionDep, skip: int = 0, limit: int = 100) -> List[UserOut]:
|
||||
"""
|
||||
Retrieve users.
|
||||
"""
|
||||
users = crud.user.get_multi(db, skip=skip, limit=limit)
|
||||
return users
|
||||
statement = select(User).offset(skip).limit(limit)
|
||||
users = session.exec(statement).all()
|
||||
return users # type: ignore
|
||||
|
||||
|
||||
@router.post("/", response_model=schemas.User)
|
||||
def create_user(
|
||||
*,
|
||||
db: Session = Depends(deps.get_db),
|
||||
user_in: schemas.UserCreate,
|
||||
current_user: models.User = Depends(deps.get_current_active_superuser),
|
||||
) -> Any:
|
||||
@router.post("/", dependencies=[Depends(get_current_active_superuser)])
|
||||
def create_user(*, session: SessionDep, user_in: UserCreate) -> UserOut:
|
||||
"""
|
||||
Create new user.
|
||||
"""
|
||||
user = crud.user.get_by_email(db, email=user_in.email)
|
||||
user = crud.get_user_by_email(session=session, email=user_in.email)
|
||||
if user:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="The user with this username already exists in the system.",
|
||||
)
|
||||
user = crud.user.create(db, obj_in=user_in)
|
||||
|
||||
user = crud.create_user(session=session, user_create=user_in)
|
||||
if settings.EMAILS_ENABLED and user_in.email:
|
||||
send_new_account_email(
|
||||
email_to=user_in.email, username=user_in.email, password=user_in.password
|
||||
)
|
||||
return user
|
||||
return user # type: ignore
|
||||
|
||||
|
||||
@router.put("/me", response_model=schemas.User)
|
||||
def update_user_me(
|
||||
*,
|
||||
db: Session = Depends(deps.get_db),
|
||||
password: str = Body(None),
|
||||
full_name: str = Body(None),
|
||||
email: EmailStr = Body(None),
|
||||
current_user: models.User = Depends(deps.get_current_active_user),
|
||||
) -> Any:
|
||||
"""
|
||||
Update own user.
|
||||
"""
|
||||
current_user_data = jsonable_encoder(current_user)
|
||||
user_in = schemas.UserUpdate(**current_user_data)
|
||||
if password is not None:
|
||||
user_in.password = password
|
||||
if full_name is not None:
|
||||
user_in.full_name = full_name
|
||||
if email is not None:
|
||||
user_in.email = email
|
||||
user = crud.user.update(db, db_obj=current_user, obj_in=user_in)
|
||||
return user
|
||||
# TODO: Refactor when SQLModel has update
|
||||
# @router.put("/me")
|
||||
# def update_user_me(
|
||||
# *,
|
||||
# session: SessionDep,
|
||||
# password: Annotated[str, Body(None)],
|
||||
# full_name: Annotated[str, Body(None)],
|
||||
# email: Annotated[EmailStr, Body(None)],
|
||||
# current_user: CurrentUser,
|
||||
# ) -> UserOut:
|
||||
# """
|
||||
# Update own user.
|
||||
# """
|
||||
# current_user_data = jsonable_encoder(current_user)
|
||||
# user_in = UserUpdate(**current_user_data)
|
||||
# if password is not None:
|
||||
# user_in.password = password
|
||||
# if full_name is not None:
|
||||
# user_in.full_name = full_name
|
||||
# if email is not None:
|
||||
# user_in.email = email
|
||||
# user = crud.user.update(session, session_obj=current_user, obj_in=user_in)
|
||||
# return user
|
||||
|
||||
|
||||
@router.get("/me", response_model=schemas.User)
|
||||
def read_user_me(
|
||||
db: Session = Depends(deps.get_db),
|
||||
current_user: models.User = Depends(deps.get_current_active_user),
|
||||
) -> Any:
|
||||
@router.get("/me")
|
||||
def read_user_me(session: SessionDep, current_user: CurrentUser) -> UserOut:
|
||||
"""
|
||||
Get current user.
|
||||
"""
|
||||
return current_user
|
||||
return current_user # type: ignore
|
||||
|
||||
|
||||
@router.post("/open", response_model=schemas.User)
|
||||
def create_user_open(
|
||||
*,
|
||||
db: Session = Depends(deps.get_db),
|
||||
password: str = Body(...),
|
||||
email: EmailStr = Body(...),
|
||||
full_name: str = Body(None),
|
||||
) -> Any:
|
||||
@router.post("/open")
|
||||
def create_user_open(session: SessionDep, user_in: UserCreateOpen) -> UserOut:
|
||||
"""
|
||||
Create new user without the need to be logged in.
|
||||
"""
|
||||
@@ -102,52 +91,52 @@ def create_user_open(
|
||||
status_code=403,
|
||||
detail="Open user registration is forbidden on this server",
|
||||
)
|
||||
user = crud.user.get_by_email(db, email=email)
|
||||
user = crud.get_user_by_email(session=session, email=user_in.email)
|
||||
if user:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="The user with this username already exists in the system",
|
||||
)
|
||||
user_in = schemas.UserCreate(password=password, email=email, full_name=full_name)
|
||||
user = crud.user.create(db, obj_in=user_in)
|
||||
return user
|
||||
user_create = UserCreate.from_orm(user_in)
|
||||
user = crud.create_user(session=session, user_create=user_create)
|
||||
return user # type: ignore
|
||||
|
||||
|
||||
@router.get("/{user_id}", response_model=schemas.User)
|
||||
@router.get("/{user_id}")
|
||||
def read_user_by_id(
|
||||
user_id: int,
|
||||
current_user: models.User = Depends(deps.get_current_active_user),
|
||||
db: Session = Depends(deps.get_db),
|
||||
) -> Any:
|
||||
user_id: int, session: SessionDep, current_user: CurrentUser
|
||||
) -> UserOut:
|
||||
"""
|
||||
Get a specific user by id.
|
||||
"""
|
||||
user = crud.user.get(db, id=user_id)
|
||||
user = session.get(User, user_id)
|
||||
if user == current_user:
|
||||
return user
|
||||
if not crud.user.is_superuser(current_user):
|
||||
return user # type: ignore
|
||||
if not current_user.is_superuser:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="The user doesn't have enough privileges"
|
||||
# TODO: Review status code
|
||||
status_code=400,
|
||||
detail="The user doesn't have enough privileges",
|
||||
)
|
||||
return user
|
||||
return user # type: ignore
|
||||
|
||||
|
||||
@router.put("/{user_id}", response_model=schemas.User)
|
||||
def update_user(
|
||||
*,
|
||||
db: Session = Depends(deps.get_db),
|
||||
user_id: int,
|
||||
user_in: schemas.UserUpdate,
|
||||
current_user: models.User = Depends(deps.get_current_active_superuser),
|
||||
) -> Any:
|
||||
"""
|
||||
Update a user.
|
||||
"""
|
||||
user = crud.user.get(db, id=user_id)
|
||||
if not user:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail="The user with this username does not exist in the system",
|
||||
)
|
||||
user = crud.user.update(db, db_obj=user, obj_in=user_in)
|
||||
return user
|
||||
# TODO: Refactor when SQLModel has update
|
||||
# @router.put("/{user_id}", dependencies=[Depends(get_current_active_superuser)])
|
||||
# def update_user(
|
||||
# *,
|
||||
# session: SessionDep,
|
||||
# user_id: int,
|
||||
# user_in: UserUpdate,
|
||||
# ) -> UserOut:
|
||||
# """
|
||||
# Update a user.
|
||||
# """
|
||||
# user = session.get(User, user_id)
|
||||
# if not user:
|
||||
# raise HTTPException(
|
||||
# status_code=404,
|
||||
# detail="The user with this username does not exist in the system",
|
||||
# )
|
||||
# user = crud.user.update(session, db_obj=user, obj_in=user_in)
|
||||
# return user # type: ignore
|
||||
|
@@ -1,15 +1,15 @@
|
||||
from typing import Generator
|
||||
from typing import Annotated, Generator
|
||||
|
||||
from fastapi import Depends, HTTPException, status
|
||||
from fastapi.security import OAuth2PasswordBearer
|
||||
from jose import jwt
|
||||
from pydantic import ValidationError
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlmodel import Session
|
||||
|
||||
from app import crud, models, schemas
|
||||
from app.core import security
|
||||
from app.core.config import settings
|
||||
from app.db.engine import engine
|
||||
from app.models import TokenPayload, User
|
||||
|
||||
reusable_oauth2 = OAuth2PasswordBearer(
|
||||
tokenUrl=f"{settings.API_V1_STR}/login/access-token"
|
||||
@@ -21,37 +21,40 @@ def get_db() -> Generator:
|
||||
yield session
|
||||
|
||||
|
||||
def get_current_user(
|
||||
db: Session = Depends(get_db), token: str = Depends(reusable_oauth2)
|
||||
) -> models.User:
|
||||
SessionDep = Annotated[Session, Depends(get_db)]
|
||||
TokenDep = Annotated[str, Depends(reusable_oauth2)]
|
||||
|
||||
|
||||
def get_current_user(session: SessionDep, token: TokenDep) -> User:
|
||||
try:
|
||||
payload = jwt.decode(
|
||||
token, settings.SECRET_KEY, algorithms=[security.ALGORITHM]
|
||||
)
|
||||
token_data = schemas.TokenPayload(**payload)
|
||||
token_data = TokenPayload(**payload)
|
||||
except (jwt.JWTError, ValidationError):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Could not validate credentials",
|
||||
)
|
||||
user = crud.user.get(db, id=token_data.sub)
|
||||
user = session.get(User, token_data.sub)
|
||||
if not user:
|
||||
raise HTTPException(status_code=404, detail="User not found")
|
||||
return user
|
||||
|
||||
|
||||
def get_current_active_user(
|
||||
current_user: models.User = Depends(get_current_user),
|
||||
) -> models.User:
|
||||
if not crud.user.is_active(current_user):
|
||||
current_user: Annotated[User, Depends(get_current_user)]
|
||||
) -> User:
|
||||
if not current_user.is_active:
|
||||
raise HTTPException(status_code=400, detail="Inactive user")
|
||||
return current_user
|
||||
|
||||
|
||||
def get_current_active_superuser(
|
||||
current_user: models.User = Depends(get_current_user),
|
||||
) -> models.User:
|
||||
if not crud.user.is_superuser(current_user):
|
||||
CurrentUser = Annotated[User, Depends(get_current_active_user)]
|
||||
|
||||
|
||||
def get_current_active_superuser(current_user: CurrentUser) -> User:
|
||||
if not current_user.is_superuser:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="The user doesn't have enough privileges"
|
||||
)
|
||||
|
@@ -8,12 +8,12 @@ from .crud_user import user
|
||||
# from app.schemas.item import ItemCreate, ItemUpdate
|
||||
|
||||
# item = CRUDBase[Item, ItemCreate, ItemUpdate](Item)
|
||||
from sqlmodel import Session
|
||||
from sqlmodel import Session, select
|
||||
from app.core.security import get_password_hash
|
||||
from app.models import UserCreate, User
|
||||
|
||||
|
||||
def create_user(session: Session, *, user_create: UserCreate) -> User:
|
||||
def create_user(*, session: Session, user_create: UserCreate) -> User:
|
||||
db_obj = User.from_orm(
|
||||
user_create, update={"hashed_password": get_password_hash(user_create.password)}
|
||||
)
|
||||
@@ -21,3 +21,9 @@ def create_user(session: Session, *, user_create: UserCreate) -> User:
|
||||
session.commit()
|
||||
session.refresh(db_obj)
|
||||
return db_obj
|
||||
|
||||
|
||||
def get_user_by_email(*, session: Session, email: str) -> User | None:
|
||||
statement = select(User).where(User.email == email)
|
||||
session_user = session.exec(statement).first()
|
||||
return session_user
|
||||
|
@@ -24,11 +24,6 @@ class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
|
||||
def get(self, db: Session, id: Any) -> Optional[ModelType]:
|
||||
return db.query(self.model).filter(self.model.id == id).first()
|
||||
|
||||
def get_multi(
|
||||
self, db: Session, *, skip: int = 0, limit: int = 100
|
||||
) -> List[ModelType]:
|
||||
return db.query(self.model).offset(skip).limit(limit).all()
|
||||
|
||||
def create(self, db: Session, *, obj_in: CreateSchemaType) -> ModelType:
|
||||
obj_in_data = jsonable_encoder(obj_in)
|
||||
db_obj = self.model(**obj_in_data) # type: ignore
|
||||
|
@@ -17,6 +17,12 @@ class UserCreate(UserBase):
|
||||
password: str
|
||||
|
||||
|
||||
class UserCreateOpen(SQLModel):
|
||||
email: EmailStr
|
||||
password: str
|
||||
full_name: Union[str, None] = None
|
||||
|
||||
|
||||
# Properties to receive via API on update, all are optional
|
||||
class UserUpdate(UserBase):
|
||||
email: Union[EmailStr, None] = None
|
||||
|
Reference in New Issue
Block a user