♻️ Refactor Users API and dependencies (#561)

Co-authored-by: Sebastián Ramírez <tiangolo@gmail.com>
This commit is contained in:
Alejandra
2023-11-29 12:13:15 -05:00
committed by GitHub
parent 2189b9f43b
commit 6f29eb2438
6 changed files with 115 additions and 119 deletions

View File

@@ -3,14 +3,11 @@ from typing import Annotated
from fastapi import APIRouter, Depends, HTTPException from fastapi import APIRouter, Depends, HTTPException
from sqlmodel import Session, select from sqlmodel import Session, select
from app.api import deps from app.api.deps import CurrentUser, SessionDep
from app.models import Item, ItemCreate, ItemOut, ItemUpdate, User from app.models import Item, ItemCreate, ItemOut, ItemUpdate
router = APIRouter() router = APIRouter()
SessionDep = Annotated[Session, Depends(deps.get_db)]
CurrentUser = Annotated[User, Depends(deps.get_current_active_user)]
@router.get("/") @router.get("/")
def read_items( def read_items(

View File

@@ -1,99 +1,88 @@
from typing import Any, List from typing import Annotated, Any, List
from fastapi import APIRouter, Body, Depends, HTTPException from fastapi import APIRouter, Body, Depends, HTTPException
from fastapi.encoders import jsonable_encoder from fastapi.encoders import jsonable_encoder
from pydantic.networks import EmailStr from pydantic.networks import EmailStr
from sqlalchemy.orm import Session from sqlmodel import select
from app import crud, models, schemas from app import crud
from app.api import deps from app.api.deps import (
CurrentUser,
SessionDep,
get_current_active_superuser,
)
from app.core.config import settings from app.core.config import settings
from app.models import User, UserCreate, UserCreateOpen, UserOut, UserUpdate
from app.utils import send_new_account_email from app.utils import send_new_account_email
router = APIRouter() router = APIRouter()
@router.get("/", response_model=List[schemas.User]) @router.get("/", dependencies=[Depends(get_current_active_superuser)])
def read_users( def read_users(session: SessionDep, skip: int = 0, limit: int = 100) -> List[UserOut]:
db: Session = Depends(deps.get_db),
skip: int = 0,
limit: int = 100,
current_user: models.User = Depends(deps.get_current_active_superuser),
) -> Any:
""" """
Retrieve users. Retrieve users.
""" """
users = crud.user.get_multi(db, skip=skip, limit=limit) statement = select(User).offset(skip).limit(limit)
return users users = session.exec(statement).all()
return users # type: ignore
@router.post("/", response_model=schemas.User) @router.post("/", dependencies=[Depends(get_current_active_superuser)])
def create_user( def create_user(*, session: SessionDep, user_in: UserCreate) -> UserOut:
*,
db: Session = Depends(deps.get_db),
user_in: schemas.UserCreate,
current_user: models.User = Depends(deps.get_current_active_superuser),
) -> Any:
""" """
Create new user. Create new user.
""" """
user = crud.user.get_by_email(db, email=user_in.email) user = crud.get_user_by_email(session=session, email=user_in.email)
if user: if user:
raise HTTPException( raise HTTPException(
status_code=400, status_code=400,
detail="The user with this username already exists in the system.", detail="The user with this username already exists in the system.",
) )
user = crud.user.create(db, obj_in=user_in)
user = crud.create_user(session=session, user_create=user_in)
if settings.EMAILS_ENABLED and user_in.email: if settings.EMAILS_ENABLED and user_in.email:
send_new_account_email( send_new_account_email(
email_to=user_in.email, username=user_in.email, password=user_in.password email_to=user_in.email, username=user_in.email, password=user_in.password
) )
return user return user # type: ignore
@router.put("/me", response_model=schemas.User) # TODO: Refactor when SQLModel has update
def update_user_me( # @router.put("/me")
*, # def update_user_me(
db: Session = Depends(deps.get_db), # *,
password: str = Body(None), # session: SessionDep,
full_name: str = Body(None), # password: Annotated[str, Body(None)],
email: EmailStr = Body(None), # full_name: Annotated[str, Body(None)],
current_user: models.User = Depends(deps.get_current_active_user), # email: Annotated[EmailStr, Body(None)],
) -> Any: # current_user: CurrentUser,
""" # ) -> UserOut:
Update own user. # """
""" # Update own user.
current_user_data = jsonable_encoder(current_user) # """
user_in = schemas.UserUpdate(**current_user_data) # current_user_data = jsonable_encoder(current_user)
if password is not None: # user_in = UserUpdate(**current_user_data)
user_in.password = password # if password is not None:
if full_name is not None: # user_in.password = password
user_in.full_name = full_name # if full_name is not None:
if email is not None: # user_in.full_name = full_name
user_in.email = email # if email is not None:
user = crud.user.update(db, db_obj=current_user, obj_in=user_in) # user_in.email = email
return user # user = crud.user.update(session, session_obj=current_user, obj_in=user_in)
# return user
@router.get("/me", response_model=schemas.User) @router.get("/me")
def read_user_me( def read_user_me(session: SessionDep, current_user: CurrentUser) -> UserOut:
db: Session = Depends(deps.get_db),
current_user: models.User = Depends(deps.get_current_active_user),
) -> Any:
""" """
Get current user. Get current user.
""" """
return current_user return current_user # type: ignore
@router.post("/open", response_model=schemas.User) @router.post("/open")
def create_user_open( def create_user_open(session: SessionDep, user_in: UserCreateOpen) -> UserOut:
*,
db: Session = Depends(deps.get_db),
password: str = Body(...),
email: EmailStr = Body(...),
full_name: str = Body(None),
) -> Any:
""" """
Create new user without the need to be logged in. Create new user without the need to be logged in.
""" """
@@ -102,52 +91,52 @@ def create_user_open(
status_code=403, status_code=403,
detail="Open user registration is forbidden on this server", detail="Open user registration is forbidden on this server",
) )
user = crud.user.get_by_email(db, email=email) user = crud.get_user_by_email(session=session, email=user_in.email)
if user: if user:
raise HTTPException( raise HTTPException(
status_code=400, status_code=400,
detail="The user with this username already exists in the system", detail="The user with this username already exists in the system",
) )
user_in = schemas.UserCreate(password=password, email=email, full_name=full_name) user_create = UserCreate.from_orm(user_in)
user = crud.user.create(db, obj_in=user_in) user = crud.create_user(session=session, user_create=user_create)
return user return user # type: ignore
@router.get("/{user_id}", response_model=schemas.User) @router.get("/{user_id}")
def read_user_by_id( def read_user_by_id(
user_id: int, user_id: int, session: SessionDep, current_user: CurrentUser
current_user: models.User = Depends(deps.get_current_active_user), ) -> UserOut:
db: Session = Depends(deps.get_db),
) -> Any:
""" """
Get a specific user by id. Get a specific user by id.
""" """
user = crud.user.get(db, id=user_id) user = session.get(User, user_id)
if user == current_user: if user == current_user:
return user return user # type: ignore
if not crud.user.is_superuser(current_user): if not current_user.is_superuser:
raise HTTPException( raise HTTPException(
status_code=400, detail="The user doesn't have enough privileges" # TODO: Review status code
status_code=400,
detail="The user doesn't have enough privileges",
) )
return user return user # type: ignore
@router.put("/{user_id}", response_model=schemas.User) # TODO: Refactor when SQLModel has update
def update_user( # @router.put("/{user_id}", dependencies=[Depends(get_current_active_superuser)])
*, # def update_user(
db: Session = Depends(deps.get_db), # *,
user_id: int, # session: SessionDep,
user_in: schemas.UserUpdate, # user_id: int,
current_user: models.User = Depends(deps.get_current_active_superuser), # user_in: UserUpdate,
) -> Any: # ) -> UserOut:
""" # """
Update a user. # Update a user.
""" # """
user = crud.user.get(db, id=user_id) # user = session.get(User, user_id)
if not user: # if not user:
raise HTTPException( # raise HTTPException(
status_code=404, # status_code=404,
detail="The user with this username does not exist in the system", # detail="The user with this username does not exist in the system",
) # )
user = crud.user.update(db, db_obj=user, obj_in=user_in) # user = crud.user.update(session, db_obj=user, obj_in=user_in)
return user # return user # type: ignore

View File

@@ -1,15 +1,15 @@
from typing import Generator from typing import Annotated, Generator
from fastapi import Depends, HTTPException, status from fastapi import Depends, HTTPException, status
from fastapi.security import OAuth2PasswordBearer from fastapi.security import OAuth2PasswordBearer
from jose import jwt from jose import jwt
from pydantic import ValidationError from pydantic import ValidationError
from sqlalchemy.orm import Session from sqlmodel import Session
from app import crud, models, schemas
from app.core import security from app.core import security
from app.core.config import settings from app.core.config import settings
from app.db.engine import engine from app.db.engine import engine
from app.models import TokenPayload, User
reusable_oauth2 = OAuth2PasswordBearer( reusable_oauth2 = OAuth2PasswordBearer(
tokenUrl=f"{settings.API_V1_STR}/login/access-token" tokenUrl=f"{settings.API_V1_STR}/login/access-token"
@@ -21,37 +21,40 @@ def get_db() -> Generator:
yield session yield session
def get_current_user( SessionDep = Annotated[Session, Depends(get_db)]
db: Session = Depends(get_db), token: str = Depends(reusable_oauth2) TokenDep = Annotated[str, Depends(reusable_oauth2)]
) -> models.User:
def get_current_user(session: SessionDep, token: TokenDep) -> User:
try: try:
payload = jwt.decode( payload = jwt.decode(
token, settings.SECRET_KEY, algorithms=[security.ALGORITHM] token, settings.SECRET_KEY, algorithms=[security.ALGORITHM]
) )
token_data = schemas.TokenPayload(**payload) token_data = TokenPayload(**payload)
except (jwt.JWTError, ValidationError): except (jwt.JWTError, ValidationError):
raise HTTPException( raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, status_code=status.HTTP_403_FORBIDDEN,
detail="Could not validate credentials", detail="Could not validate credentials",
) )
user = crud.user.get(db, id=token_data.sub) user = session.get(User, token_data.sub)
if not user: if not user:
raise HTTPException(status_code=404, detail="User not found") raise HTTPException(status_code=404, detail="User not found")
return user return user
def get_current_active_user( def get_current_active_user(
current_user: models.User = Depends(get_current_user), current_user: Annotated[User, Depends(get_current_user)]
) -> models.User: ) -> User:
if not crud.user.is_active(current_user): if not current_user.is_active:
raise HTTPException(status_code=400, detail="Inactive user") raise HTTPException(status_code=400, detail="Inactive user")
return current_user return current_user
def get_current_active_superuser( CurrentUser = Annotated[User, Depends(get_current_active_user)]
current_user: models.User = Depends(get_current_user),
) -> models.User:
if not crud.user.is_superuser(current_user): def get_current_active_superuser(current_user: CurrentUser) -> User:
if not current_user.is_superuser:
raise HTTPException( raise HTTPException(
status_code=400, detail="The user doesn't have enough privileges" status_code=400, detail="The user doesn't have enough privileges"
) )

View File

@@ -8,12 +8,12 @@ from .crud_user import user
# from app.schemas.item import ItemCreate, ItemUpdate # from app.schemas.item import ItemCreate, ItemUpdate
# item = CRUDBase[Item, ItemCreate, ItemUpdate](Item) # item = CRUDBase[Item, ItemCreate, ItemUpdate](Item)
from sqlmodel import Session from sqlmodel import Session, select
from app.core.security import get_password_hash from app.core.security import get_password_hash
from app.models import UserCreate, User from app.models import UserCreate, User
def create_user(session: Session, *, user_create: UserCreate) -> User: def create_user(*, session: Session, user_create: UserCreate) -> User:
db_obj = User.from_orm( db_obj = User.from_orm(
user_create, update={"hashed_password": get_password_hash(user_create.password)} user_create, update={"hashed_password": get_password_hash(user_create.password)}
) )
@@ -21,3 +21,9 @@ def create_user(session: Session, *, user_create: UserCreate) -> User:
session.commit() session.commit()
session.refresh(db_obj) session.refresh(db_obj)
return db_obj return db_obj
def get_user_by_email(*, session: Session, email: str) -> User | None:
statement = select(User).where(User.email == email)
session_user = session.exec(statement).first()
return session_user

View File

@@ -24,11 +24,6 @@ class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
def get(self, db: Session, id: Any) -> Optional[ModelType]: def get(self, db: Session, id: Any) -> Optional[ModelType]:
return db.query(self.model).filter(self.model.id == id).first() return db.query(self.model).filter(self.model.id == id).first()
def get_multi(
self, db: Session, *, skip: int = 0, limit: int = 100
) -> List[ModelType]:
return db.query(self.model).offset(skip).limit(limit).all()
def create(self, db: Session, *, obj_in: CreateSchemaType) -> ModelType: def create(self, db: Session, *, obj_in: CreateSchemaType) -> ModelType:
obj_in_data = jsonable_encoder(obj_in) obj_in_data = jsonable_encoder(obj_in)
db_obj = self.model(**obj_in_data) # type: ignore db_obj = self.model(**obj_in_data) # type: ignore

View File

@@ -17,6 +17,12 @@ class UserCreate(UserBase):
password: str password: str
class UserCreateOpen(SQLModel):
email: EmailStr
password: str
full_name: Union[str, None] = None
# Properties to receive via API on update, all are optional # Properties to receive via API on update, all are optional
class UserUpdate(UserBase): class UserUpdate(UserBase):
email: Union[EmailStr, None] = None email: Union[EmailStr, None] = None