♻️ Refactor old CRUD utils and tests (#622)

This commit is contained in:
Alejandra
2024-02-29 15:42:55 -05:00
committed by GitHub
parent ad0abb08ef
commit 2d0f77421f
13 changed files with 104 additions and 307 deletions

View File

@@ -166,22 +166,12 @@ def update_user(
Update a user. Update a user.
""" """
db_user = session.get(User, user_id) db_user = crud.update_user(session=session, user_id=user_id, user_in=user_in)
if not db_user: if db_user is None:
raise HTTPException( raise HTTPException(
status_code=404, status_code=404,
detail="The user with this username does not exist in the system", detail="The user with this username does not exist in the system",
) )
user_data = user_in.model_dump(exclude_unset=True)
extra_data = {}
if "password" in user_data:
password = user_data["password"]
hashed_password = get_password_hash(password)
extra_data["hashed_password"] = hashed_password
db_user.sqlmodel_update(user_data, update=extra_data)
session.add(db_user)
session.commit()
session.refresh(db_user)
return db_user return db_user

55
src/backend/app/crud.py Normal file
View File

@@ -0,0 +1,55 @@
from typing import Any
from sqlmodel import Session, select
from app.core.security import get_password_hash, verify_password
from app.models import Item, ItemCreate, User, UserCreate, UserUpdate
def create_user(*, session: Session, user_create: UserCreate) -> User:
db_obj = User.model_validate(
user_create, update={"hashed_password": get_password_hash(user_create.password)}
)
session.add(db_obj)
session.commit()
session.refresh(db_obj)
return db_obj
def update_user(*, session: Session, user_id: int, user_in: UserUpdate) -> Any:
db_user = session.get(User, user_id)
if not db_user:
return None
user_data = user_in.model_dump(exclude_unset=True)
extra_data = {}
if "password" in user_data:
password = user_data["password"]
hashed_password = get_password_hash(password)
extra_data["hashed_password"] = hashed_password
db_user.sqlmodel_update(user_data, update=extra_data)
session.add(db_user)
session.commit()
session.refresh(db_user)
return db_user
def get_user_by_email(*, session: Session, email: str) -> User | None:
statement = select(User).where(User.email == email)
session_user = session.exec(statement).first()
return session_user
def authenticate(*, session: Session, email: str, password: str) -> User | None:
db_user = get_user_by_email(session=session, email=email)
if not db_user:
return None
if not verify_password(password, db_user.hashed_password):
return None
return db_user
def create_item(*, session: Session, item_in: ItemCreate, owner_id: int) -> Item:
db_item = Item.model_validate(item_in, update={"owner_id": owner_id})
session.add(db_item)
session.commit()
session.refresh(db_item)
return db_item

View File

@@ -1,37 +0,0 @@
# For a new basic set of CRUD operations you could just do
# from .base import CRUDBase
# from app.models.item import Item
# from app.schemas.item import ItemCreate, ItemUpdate
# item = CRUDBase[Item, ItemCreate, ItemUpdate](Item)
from sqlmodel import Session, select
from app.core.security import get_password_hash, verify_password
from app.models import User, UserCreate
from .crud_item import item as item
from .crud_user import user as user
def create_user(*, session: Session, user_create: UserCreate) -> User:
db_obj = User.from_orm(
user_create, update={"hashed_password": get_password_hash(user_create.password)}
)
session.add(db_obj)
session.commit()
session.refresh(db_obj)
return db_obj
def get_user_by_email(*, session: Session, email: str) -> User | None:
statement = select(User).where(User.email == email)
session_user = session.exec(statement).first()
return session_user
def authenticate(*, session: Session, email: str, password: str) -> User | None:
db_user = get_user_by_email(session=session, email=email)
if not db_user:
return None
if not verify_password(password, db_user.hashed_password):
return None
return db_user

View File

@@ -1,59 +0,0 @@
from typing import Any, Generic, TypeVar
from fastapi.encoders import jsonable_encoder
from pydantic import BaseModel
from sqlalchemy.orm import Session
ModelType = TypeVar("ModelType", bound=Any)
CreateSchemaType = TypeVar("CreateSchemaType", bound=BaseModel)
UpdateSchemaType = TypeVar("UpdateSchemaType", bound=BaseModel)
class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
def __init__(self, model: type[ModelType]):
"""
CRUD object with default methods to Create, Read, Update, Delete (CRUD).
**Parameters**
* `model`: A SQLAlchemy model class
* `schema`: A Pydantic model (schema) class
"""
self.model = model
def get(self, db: Session, id: Any) -> ModelType | None:
return db.query(self.model).filter(self.model.id == id).first()
def create(self, db: Session, *, obj_in: CreateSchemaType) -> ModelType:
obj_in_data = jsonable_encoder(obj_in)
db_obj = self.model(**obj_in_data) # type: ignore
db.add(db_obj)
db.commit()
db.refresh(db_obj)
return db_obj
def update(
self,
db: Session,
*,
db_obj: ModelType,
obj_in: UpdateSchemaType | dict[str, Any],
) -> ModelType:
obj_data = jsonable_encoder(db_obj)
if isinstance(obj_in, dict):
update_data = obj_in
else:
update_data = obj_in.dict(exclude_unset=True)
for field in obj_data:
if field in update_data:
setattr(db_obj, field, update_data[field])
db.add(db_obj)
db.commit()
db.refresh(db_obj)
return db_obj
def remove(self, db: Session, *, id: int) -> ModelType:
obj = db.query(self.model).get(id)
db.delete(obj)
db.commit()
return obj

View File

@@ -1,32 +0,0 @@
from fastapi.encoders import jsonable_encoder
from sqlalchemy.orm import Session
from app.crud.base import CRUDBase
from app.models import Item
from app.schemas.item import ItemCreate, ItemUpdate
class CRUDItem(CRUDBase[Item, ItemCreate, ItemUpdate]):
def create_with_owner(
self, db: Session, *, obj_in: ItemCreate, owner_id: int
) -> Item:
obj_in_data = jsonable_encoder(obj_in)
db_obj = self.model(**obj_in_data, owner_id=owner_id)
db.add(db_obj)
db.commit()
db.refresh(db_obj)
return db_obj
def get_multi_by_owner(
self, db: Session, *, owner_id: int, skip: int = 0, limit: int = 100
) -> list[Item]:
return (
db.query(self.model)
.filter(Item.owner_id == owner_id)
.offset(skip)
.limit(limit)
.all()
)
item = CRUDItem(Item)

View File

@@ -1,55 +0,0 @@
from typing import Any
from sqlalchemy.orm import Session
from app.core.security import get_password_hash, verify_password
from app.crud.base import CRUDBase
from app.models import User
from app.schemas.user import UserCreate, UserUpdate
class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]):
def get_by_email(self, db: Session, *, email: str) -> User | None:
return db.query(User).filter(User.email == email).first()
def create(self, db: Session, *, obj_in: UserCreate) -> User:
db_obj = User(
email=obj_in.email,
hashed_password=get_password_hash(obj_in.password),
full_name=obj_in.full_name,
is_superuser=obj_in.is_superuser,
)
db.add(db_obj)
db.commit()
db.refresh(db_obj)
return db_obj
def update(
self, db: Session, *, db_obj: User, obj_in: UserUpdate | dict[str, Any]
) -> User:
if isinstance(obj_in, dict):
update_data = obj_in
else:
update_data = obj_in.dict(exclude_unset=True)
if update_data["password"]:
hashed_password = get_password_hash(update_data["password"])
del update_data["password"]
update_data["hashed_password"] = hashed_password
return super().update(db, db_obj=db_obj, obj_in=update_data)
def authenticate(self, db: Session, *, email: str, password: str) -> User | None:
user = self.get_by_email(db, email=email)
if not user:
return None
if not verify_password(password, user.hashed_password):
return None
return user
def is_active(self, user: User) -> bool:
return user.is_active
def is_superuser(self, user: User) -> bool:
return user.is_superuser
user = CRUDUser(User)

View File

@@ -1,5 +1,5 @@
from fastapi.testclient import TestClient from fastapi.testclient import TestClient
from sqlalchemy.orm import Session from sqlmodel import Session
from app.core.config import settings from app.core.config import settings
from app.tests.utils.item import create_random_item from app.tests.utils.item import create_random_item

View File

@@ -1,9 +1,9 @@
from fastapi.testclient import TestClient from fastapi.testclient import TestClient
from sqlalchemy.orm import Session from sqlmodel import Session
from app import crud from app import crud
from app.core.config import settings from app.core.config import settings
from app.schemas.user import UserCreate from app.models import UserCreate
from app.tests.utils.utils import random_email, random_lower_string from app.tests.utils.utils import random_email, random_lower_string
@@ -42,7 +42,7 @@ def test_create_user_new_email(
) )
assert 200 <= r.status_code < 300 assert 200 <= r.status_code < 300
created_user = r.json() created_user = r.json()
user = crud.user.get_by_email(db, email=username) user = crud.get_user_by_email(session=db, email=username)
assert user assert user
assert user.email == created_user["email"] assert user.email == created_user["email"]
@@ -53,7 +53,7 @@ def test_get_existing_user(
username = random_email() username = random_email()
password = random_lower_string() password = random_lower_string()
user_in = UserCreate(email=username, password=password) user_in = UserCreate(email=username, password=password)
user = crud.user.create(db, obj_in=user_in) user = crud.create_user(session=db, user_create=user_in)
user_id = user.id user_id = user.id
r = client.get( r = client.get(
f"{settings.API_V1_STR}/users/{user_id}", f"{settings.API_V1_STR}/users/{user_id}",
@@ -61,7 +61,7 @@ def test_get_existing_user(
) )
assert 200 <= r.status_code < 300 assert 200 <= r.status_code < 300
api_user = r.json() api_user = r.json()
existing_user = crud.user.get_by_email(db, email=username) existing_user = crud.get_user_by_email(session=db, email=username)
assert existing_user assert existing_user
assert existing_user.email == api_user["email"] assert existing_user.email == api_user["email"]
@@ -73,7 +73,7 @@ def test_create_user_existing_username(
# username = email # username = email
password = random_lower_string() password = random_lower_string()
user_in = UserCreate(email=username, password=password) user_in = UserCreate(email=username, password=password)
crud.user.create(db, obj_in=user_in) crud.create_user(session=db, user_create=user_in)
data = {"email": username, "password": password} data = {"email": username, "password": password}
r = client.post( r = client.post(
f"{settings.API_V1_STR}/users/", f"{settings.API_V1_STR}/users/",
@@ -105,12 +105,12 @@ def test_retrieve_users(
username = random_email() username = random_email()
password = random_lower_string() password = random_lower_string()
user_in = UserCreate(email=username, password=password) user_in = UserCreate(email=username, password=password)
crud.user.create(db, obj_in=user_in) crud.create_user(session=db, user_create=user_in)
username2 = random_email() username2 = random_email()
password2 = random_lower_string() password2 = random_lower_string()
user_in2 = UserCreate(email=username2, password=password2) user_in2 = UserCreate(email=username2, password=password2)
crud.user.create(db, obj_in=user_in2) crud.create_user(session=db, user_create=user_in2)
r = client.get(f"{settings.API_V1_STR}/users/", headers=superuser_token_headers) r = client.get(f"{settings.API_V1_STR}/users/", headers=superuser_token_headers)
all_users = r.json() all_users = r.json()

View File

@@ -2,7 +2,7 @@ from collections.abc import Generator
import pytest import pytest
from fastapi.testclient import TestClient from fastapi.testclient import TestClient
from sqlalchemy.orm import Session from sqlmodel import Session
from app.core.config import settings from app.core.config import settings
from app.db.engine import engine from app.db.engine import engine

View File

@@ -1,61 +0,0 @@
from sqlalchemy.orm import Session
from app import crud
from app.schemas.item import ItemCreate, ItemUpdate
from app.tests.utils.user import create_random_user
from app.tests.utils.utils import random_lower_string
def test_create_item(db: Session) -> None:
title = random_lower_string()
description = random_lower_string()
item_in = ItemCreate(title=title, description=description)
user = create_random_user(db)
item = crud.item.create_with_owner(db=db, obj_in=item_in, owner_id=user.id)
assert item.title == title
assert item.description == description
assert item.owner_id == user.id
def test_get_item(db: Session) -> None:
title = random_lower_string()
description = random_lower_string()
item_in = ItemCreate(title=title, description=description)
user = create_random_user(db)
item = crud.item.create_with_owner(db=db, obj_in=item_in, owner_id=user.id)
stored_item = crud.item.get(db=db, id=item.id)
assert stored_item
assert item.id == stored_item.id
assert item.title == stored_item.title
assert item.description == stored_item.description
assert item.owner_id == stored_item.owner_id
def test_update_item(db: Session) -> None:
title = random_lower_string()
description = random_lower_string()
item_in = ItemCreate(title=title, description=description)
user = create_random_user(db)
item = crud.item.create_with_owner(db=db, obj_in=item_in, owner_id=user.id)
description2 = random_lower_string()
item_update = ItemUpdate(description=description2)
item2 = crud.item.update(db=db, db_obj=item, obj_in=item_update)
assert item.id == item2.id
assert item.title == item2.title
assert item2.description == description2
assert item.owner_id == item2.owner_id
def test_delete_item(db: Session) -> None:
title = random_lower_string()
description = random_lower_string()
item_in = ItemCreate(title=title, description=description)
user = create_random_user(db)
item = crud.item.create_with_owner(db=db, obj_in=item_in, owner_id=user.id)
item2 = crud.item.remove(db=db, id=item.id)
item3 = crud.item.get(db=db, id=item.id)
assert item3 is None
assert item2.id == item.id
assert item2.title == title
assert item2.description == description
assert item2.owner_id == user.id

View File

@@ -1,9 +1,9 @@
from fastapi.encoders import jsonable_encoder from fastapi.encoders import jsonable_encoder
from sqlalchemy.orm import Session from sqlmodel import Session
from app import crud from app import crud
from app.core.security import verify_password from app.core.security import verify_password
from app.schemas.user import UserCreate, UserUpdate from app.models import User, UserCreate, UserUpdate
from app.tests.utils.utils import random_email, random_lower_string from app.tests.utils.utils import random_email, random_lower_string
@@ -11,7 +11,7 @@ def test_create_user(db: Session) -> None:
email = random_email() email = random_email()
password = random_lower_string() password = random_lower_string()
user_in = UserCreate(email=email, password=password) user_in = UserCreate(email=email, password=password)
user = crud.user.create(db, obj_in=user_in) user = crud.create_user(session=db, user_create=user_in)
assert user.email == email assert user.email == email
assert hasattr(user, "hashed_password") assert hasattr(user, "hashed_password")
@@ -20,8 +20,8 @@ def test_authenticate_user(db: Session) -> None:
email = random_email() email = random_email()
password = random_lower_string() password = random_lower_string()
user_in = UserCreate(email=email, password=password) user_in = UserCreate(email=email, password=password)
user = crud.user.create(db, obj_in=user_in) user = crud.create_user(session=db, user_create=user_in)
authenticated_user = crud.user.authenticate(db, email=email, password=password) authenticated_user = crud.authenticate(session=db, email=email, password=password)
assert authenticated_user assert authenticated_user
assert user.email == authenticated_user.email assert user.email == authenticated_user.email
@@ -29,7 +29,7 @@ def test_authenticate_user(db: Session) -> None:
def test_not_authenticate_user(db: Session) -> None: def test_not_authenticate_user(db: Session) -> None:
email = random_email() email = random_email()
password = random_lower_string() password = random_lower_string()
user = crud.user.authenticate(db, email=email, password=password) user = crud.authenticate(session=db, email=email, password=password)
assert user is None assert user is None
@@ -37,44 +37,40 @@ def test_check_if_user_is_active(db: Session) -> None:
email = random_email() email = random_email()
password = random_lower_string() password = random_lower_string()
user_in = UserCreate(email=email, password=password) user_in = UserCreate(email=email, password=password)
user = crud.user.create(db, obj_in=user_in) user = crud.create_user(session=db, user_create=user_in)
is_active = crud.user.is_active(user) assert user.is_active is True
assert is_active is True
def test_check_if_user_is_active_inactive(db: Session) -> None: def test_check_if_user_is_active_inactive(db: Session) -> None:
email = random_email() email = random_email()
password = random_lower_string() password = random_lower_string()
user_in = UserCreate(email=email, password=password, disabled=True) user_in = UserCreate(email=email, password=password, disabled=True)
user = crud.user.create(db, obj_in=user_in) user = crud.create_user(session=db, user_create=user_in)
is_active = crud.user.is_active(user) assert user.is_active
assert is_active
def test_check_if_user_is_superuser(db: Session) -> None: def test_check_if_user_is_superuser(db: Session) -> None:
email = random_email() email = random_email()
password = random_lower_string() password = random_lower_string()
user_in = UserCreate(email=email, password=password, is_superuser=True) user_in = UserCreate(email=email, password=password, is_superuser=True)
user = crud.user.create(db, obj_in=user_in) user = crud.create_user(session=db, user_create=user_in)
is_superuser = crud.user.is_superuser(user) assert user.is_superuser is True
assert is_superuser is True
def test_check_if_user_is_superuser_normal_user(db: Session) -> None: def test_check_if_user_is_superuser_normal_user(db: Session) -> None:
username = random_email() username = random_email()
password = random_lower_string() password = random_lower_string()
user_in = UserCreate(email=username, password=password) user_in = UserCreate(email=username, password=password)
user = crud.user.create(db, obj_in=user_in) user = crud.create_user(session=db, user_create=user_in)
is_superuser = crud.user.is_superuser(user) assert user.is_superuser is False
assert is_superuser is False
def test_get_user(db: Session) -> None: def test_get_user(db: Session) -> None:
password = random_lower_string() password = random_lower_string()
username = random_email() username = random_email()
user_in = UserCreate(email=username, password=password, is_superuser=True) user_in = UserCreate(email=username, password=password, is_superuser=True)
user = crud.user.create(db, obj_in=user_in) user = crud.create_user(session=db, user_create=user_in)
user_2 = crud.user.get(db, id=user.id) user_2 = db.get(User, user.id)
assert user_2 assert user_2
assert user.email == user_2.email assert user.email == user_2.email
assert jsonable_encoder(user) == jsonable_encoder(user_2) assert jsonable_encoder(user) == jsonable_encoder(user_2)
@@ -84,11 +80,12 @@ def test_update_user(db: Session) -> None:
password = random_lower_string() password = random_lower_string()
email = random_email() email = random_email()
user_in = UserCreate(email=email, password=password, is_superuser=True) user_in = UserCreate(email=email, password=password, is_superuser=True)
user = crud.user.create(db, obj_in=user_in) user = crud.create_user(session=db, user_create=user_in)
new_password = random_lower_string() new_password = random_lower_string()
user_in_update = UserUpdate(password=new_password, is_superuser=True) user_in_update = UserUpdate(password=new_password, is_superuser=True)
crud.user.update(db, db_obj=user, obj_in=user_in_update) if user.id is not None:
user_2 = crud.user.get(db, id=user.id) crud.update_user(session=db, user_id=user.id, user_in=user_in_update)
user_2 = db.get(User, user.id)
assert user_2 assert user_2
assert user.email == user_2.email assert user.email == user_2.email
assert verify_password(new_password, user_2.hashed_password) assert verify_password(new_password, user_2.hashed_password)

View File

@@ -1,16 +1,16 @@
from sqlalchemy.orm import Session from sqlmodel import Session
from app import crud, models from app import crud
from app.schemas.item import ItemCreate from app.models import Item, ItemCreate
from app.tests.utils.user import create_random_user from app.tests.utils.user import create_random_user
from app.tests.utils.utils import random_lower_string from app.tests.utils.utils import random_lower_string
def create_random_item(db: Session, *, owner_id: int | None = None) -> models.Item: def create_random_item(db: Session) -> Item:
if owner_id is None: user = create_random_user(db)
user = create_random_user(db) owner_id = user.id
owner_id = user.id assert owner_id is not None
title = random_lower_string() title = random_lower_string()
description = random_lower_string() description = random_lower_string()
item_in = ItemCreate(title=title, description=description, id=id) item_in = ItemCreate(title=title, description=description)
return crud.item.create_with_owner(db=db, obj_in=item_in, owner_id=owner_id) return crud.create_item(session=db, item_in=item_in, owner_id=owner_id)

View File

@@ -1,10 +1,9 @@
from fastapi.testclient import TestClient from fastapi.testclient import TestClient
from sqlalchemy.orm import Session from sqlmodel import Session
from app import crud from app import crud
from app.core.config import settings from app.core.config import settings
from app.models import User from app.models import User, UserCreate, UserUpdate
from app.schemas.user import UserCreate, UserUpdate
from app.tests.utils.utils import random_email, random_lower_string from app.tests.utils.utils import random_email, random_lower_string
@@ -23,8 +22,8 @@ def user_authentication_headers(
def create_random_user(db: Session) -> User: def create_random_user(db: Session) -> User:
email = random_email() email = random_email()
password = random_lower_string() password = random_lower_string()
user_in = UserCreate(username=email, email=email, password=password) user_in = UserCreate(email=email, password=password)
user = crud.user.create(db=db, obj_in=user_in) user = crud.create_user(session=db, user_create=user_in)
return user return user
@@ -37,12 +36,12 @@ def authentication_token_from_email(
If the user doesn't exist it is created first. If the user doesn't exist it is created first.
""" """
password = random_lower_string() password = random_lower_string()
user = crud.user.get_by_email(db, email=email) user = crud.get_user_by_email(session=db, email=email)
if not user: if not user:
user_in_create = UserCreate(username=email, email=email, password=password) user_in_create = UserCreate(email=email, password=password)
user = crud.user.create(db, obj_in=user_in_create) user = crud.create_user(session=db, user_create=user_in_create)
else: else:
user_in_update = UserUpdate(password=password) user_in_update = UserUpdate(password=password)
user = crud.user.update(db, db_obj=user, obj_in=user_in_update) user = crud.update_user(session=db, user_id=user.id, user_in=user_in_update)
return user_authentication_headers(client=client, email=email, password=password) return user_authentication_headers(client=client, email=email, password=password)