♻️ Refactor Users API and dependencies (#561)

Co-authored-by: Sebastián Ramírez <tiangolo@gmail.com>
This commit is contained in:
Alejandra
2023-11-29 12:13:15 -05:00
committed by GitHub
parent 2189b9f43b
commit 6f29eb2438
6 changed files with 115 additions and 119 deletions

View File

@@ -3,14 +3,11 @@ from typing import Annotated
from fastapi import APIRouter, Depends, HTTPException
from sqlmodel import Session, select
from app.api import deps
from app.models import Item, ItemCreate, ItemOut, ItemUpdate, User
from app.api.deps import CurrentUser, SessionDep
from app.models import Item, ItemCreate, ItemOut, ItemUpdate
router = APIRouter()
SessionDep = Annotated[Session, Depends(deps.get_db)]
CurrentUser = Annotated[User, Depends(deps.get_current_active_user)]
@router.get("/")
def read_items(

View File

@@ -1,99 +1,88 @@
from typing import Any, List
from typing import Annotated, Any, List
from fastapi import APIRouter, Body, Depends, HTTPException
from fastapi.encoders import jsonable_encoder
from pydantic.networks import EmailStr
from sqlalchemy.orm import Session
from sqlmodel import select
from app import crud, models, schemas
from app.api import deps
from app import crud
from app.api.deps import (
CurrentUser,
SessionDep,
get_current_active_superuser,
)
from app.core.config import settings
from app.models import User, UserCreate, UserCreateOpen, UserOut, UserUpdate
from app.utils import send_new_account_email
router = APIRouter()
@router.get("/", response_model=List[schemas.User])
def read_users(
db: Session = Depends(deps.get_db),
skip: int = 0,
limit: int = 100,
current_user: models.User = Depends(deps.get_current_active_superuser),
) -> Any:
@router.get("/", dependencies=[Depends(get_current_active_superuser)])
def read_users(session: SessionDep, skip: int = 0, limit: int = 100) -> List[UserOut]:
"""
Retrieve users.
"""
users = crud.user.get_multi(db, skip=skip, limit=limit)
return users
statement = select(User).offset(skip).limit(limit)
users = session.exec(statement).all()
return users # type: ignore
@router.post("/", response_model=schemas.User)
def create_user(
*,
db: Session = Depends(deps.get_db),
user_in: schemas.UserCreate,
current_user: models.User = Depends(deps.get_current_active_superuser),
) -> Any:
@router.post("/", dependencies=[Depends(get_current_active_superuser)])
def create_user(*, session: SessionDep, user_in: UserCreate) -> UserOut:
"""
Create new user.
"""
user = crud.user.get_by_email(db, email=user_in.email)
user = crud.get_user_by_email(session=session, email=user_in.email)
if user:
raise HTTPException(
status_code=400,
detail="The user with this username already exists in the system.",
)
user = crud.user.create(db, obj_in=user_in)
user = crud.create_user(session=session, user_create=user_in)
if settings.EMAILS_ENABLED and user_in.email:
send_new_account_email(
email_to=user_in.email, username=user_in.email, password=user_in.password
)
return user
return user # type: ignore
@router.put("/me", response_model=schemas.User)
def update_user_me(
*,
db: Session = Depends(deps.get_db),
password: str = Body(None),
full_name: str = Body(None),
email: EmailStr = Body(None),
current_user: models.User = Depends(deps.get_current_active_user),
) -> Any:
"""
Update own user.
"""
current_user_data = jsonable_encoder(current_user)
user_in = schemas.UserUpdate(**current_user_data)
if password is not None:
user_in.password = password
if full_name is not None:
user_in.full_name = full_name
if email is not None:
user_in.email = email
user = crud.user.update(db, db_obj=current_user, obj_in=user_in)
return user
# TODO: Refactor when SQLModel has update
# @router.put("/me")
# def update_user_me(
# *,
# session: SessionDep,
# password: Annotated[str, Body(None)],
# full_name: Annotated[str, Body(None)],
# email: Annotated[EmailStr, Body(None)],
# current_user: CurrentUser,
# ) -> UserOut:
# """
# Update own user.
# """
# current_user_data = jsonable_encoder(current_user)
# user_in = UserUpdate(**current_user_data)
# if password is not None:
# user_in.password = password
# if full_name is not None:
# user_in.full_name = full_name
# if email is not None:
# user_in.email = email
# user = crud.user.update(session, session_obj=current_user, obj_in=user_in)
# return user
@router.get("/me", response_model=schemas.User)
def read_user_me(
db: Session = Depends(deps.get_db),
current_user: models.User = Depends(deps.get_current_active_user),
) -> Any:
@router.get("/me")
def read_user_me(session: SessionDep, current_user: CurrentUser) -> UserOut:
"""
Get current user.
"""
return current_user
return current_user # type: ignore
@router.post("/open", response_model=schemas.User)
def create_user_open(
*,
db: Session = Depends(deps.get_db),
password: str = Body(...),
email: EmailStr = Body(...),
full_name: str = Body(None),
) -> Any:
@router.post("/open")
def create_user_open(session: SessionDep, user_in: UserCreateOpen) -> UserOut:
"""
Create new user without the need to be logged in.
"""
@@ -102,52 +91,52 @@ def create_user_open(
status_code=403,
detail="Open user registration is forbidden on this server",
)
user = crud.user.get_by_email(db, email=email)
user = crud.get_user_by_email(session=session, email=user_in.email)
if user:
raise HTTPException(
status_code=400,
detail="The user with this username already exists in the system",
)
user_in = schemas.UserCreate(password=password, email=email, full_name=full_name)
user = crud.user.create(db, obj_in=user_in)
return user
user_create = UserCreate.from_orm(user_in)
user = crud.create_user(session=session, user_create=user_create)
return user # type: ignore
@router.get("/{user_id}", response_model=schemas.User)
@router.get("/{user_id}")
def read_user_by_id(
user_id: int,
current_user: models.User = Depends(deps.get_current_active_user),
db: Session = Depends(deps.get_db),
) -> Any:
user_id: int, session: SessionDep, current_user: CurrentUser
) -> UserOut:
"""
Get a specific user by id.
"""
user = crud.user.get(db, id=user_id)
user = session.get(User, user_id)
if user == current_user:
return user
if not crud.user.is_superuser(current_user):
return user # type: ignore
if not current_user.is_superuser:
raise HTTPException(
status_code=400, detail="The user doesn't have enough privileges"
# TODO: Review status code
status_code=400,
detail="The user doesn't have enough privileges",
)
return user
return user # type: ignore
@router.put("/{user_id}", response_model=schemas.User)
def update_user(
*,
db: Session = Depends(deps.get_db),
user_id: int,
user_in: schemas.UserUpdate,
current_user: models.User = Depends(deps.get_current_active_superuser),
) -> Any:
"""
Update a user.
"""
user = crud.user.get(db, id=user_id)
if not user:
raise HTTPException(
status_code=404,
detail="The user with this username does not exist in the system",
)
user = crud.user.update(db, db_obj=user, obj_in=user_in)
return user
# TODO: Refactor when SQLModel has update
# @router.put("/{user_id}", dependencies=[Depends(get_current_active_superuser)])
# def update_user(
# *,
# session: SessionDep,
# user_id: int,
# user_in: UserUpdate,
# ) -> UserOut:
# """
# Update a user.
# """
# user = session.get(User, user_id)
# if not user:
# raise HTTPException(
# status_code=404,
# detail="The user with this username does not exist in the system",
# )
# user = crud.user.update(session, db_obj=user, obj_in=user_in)
# return user # type: ignore

View File

@@ -1,15 +1,15 @@
from typing import Generator
from typing import Annotated, Generator
from fastapi import Depends, HTTPException, status
from fastapi.security import OAuth2PasswordBearer
from jose import jwt
from pydantic import ValidationError
from sqlalchemy.orm import Session
from sqlmodel import Session
from app import crud, models, schemas
from app.core import security
from app.core.config import settings
from app.db.engine import engine
from app.models import TokenPayload, User
reusable_oauth2 = OAuth2PasswordBearer(
tokenUrl=f"{settings.API_V1_STR}/login/access-token"
@@ -21,37 +21,40 @@ def get_db() -> Generator:
yield session
def get_current_user(
db: Session = Depends(get_db), token: str = Depends(reusable_oauth2)
) -> models.User:
SessionDep = Annotated[Session, Depends(get_db)]
TokenDep = Annotated[str, Depends(reusable_oauth2)]
def get_current_user(session: SessionDep, token: TokenDep) -> User:
try:
payload = jwt.decode(
token, settings.SECRET_KEY, algorithms=[security.ALGORITHM]
)
token_data = schemas.TokenPayload(**payload)
token_data = TokenPayload(**payload)
except (jwt.JWTError, ValidationError):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Could not validate credentials",
)
user = crud.user.get(db, id=token_data.sub)
user = session.get(User, token_data.sub)
if not user:
raise HTTPException(status_code=404, detail="User not found")
return user
def get_current_active_user(
current_user: models.User = Depends(get_current_user),
) -> models.User:
if not crud.user.is_active(current_user):
current_user: Annotated[User, Depends(get_current_user)]
) -> User:
if not current_user.is_active:
raise HTTPException(status_code=400, detail="Inactive user")
return current_user
def get_current_active_superuser(
current_user: models.User = Depends(get_current_user),
) -> models.User:
if not crud.user.is_superuser(current_user):
CurrentUser = Annotated[User, Depends(get_current_active_user)]
def get_current_active_superuser(current_user: CurrentUser) -> User:
if not current_user.is_superuser:
raise HTTPException(
status_code=400, detail="The user doesn't have enough privileges"
)

View File

@@ -8,12 +8,12 @@ from .crud_user import user
# from app.schemas.item import ItemCreate, ItemUpdate
# item = CRUDBase[Item, ItemCreate, ItemUpdate](Item)
from sqlmodel import Session
from sqlmodel import Session, select
from app.core.security import get_password_hash
from app.models import UserCreate, User
def create_user(session: Session, *, user_create: UserCreate) -> User:
def create_user(*, session: Session, user_create: UserCreate) -> User:
db_obj = User.from_orm(
user_create, update={"hashed_password": get_password_hash(user_create.password)}
)
@@ -21,3 +21,9 @@ def create_user(session: Session, *, user_create: UserCreate) -> User:
session.commit()
session.refresh(db_obj)
return db_obj
def get_user_by_email(*, session: Session, email: str) -> User | None:
statement = select(User).where(User.email == email)
session_user = session.exec(statement).first()
return session_user

View File

@@ -24,11 +24,6 @@ class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
def get(self, db: Session, id: Any) -> Optional[ModelType]:
return db.query(self.model).filter(self.model.id == id).first()
def get_multi(
self, db: Session, *, skip: int = 0, limit: int = 100
) -> List[ModelType]:
return db.query(self.model).offset(skip).limit(limit).all()
def create(self, db: Session, *, obj_in: CreateSchemaType) -> ModelType:
obj_in_data = jsonable_encoder(obj_in)
db_obj = self.model(**obj_in_data) # type: ignore

View File

@@ -17,6 +17,12 @@ class UserCreate(UserBase):
password: str
class UserCreateOpen(SQLModel):
email: EmailStr
password: str
full_name: Union[str, None] = None
# Properties to receive via API on update, all are optional
class UserUpdate(UserBase):
email: Union[EmailStr, None] = None