Add base class to simplify CRUD (#23)

This commit is contained in:
Manu
2020-01-19 22:40:50 +01:00
committed by Sebastián Ramírez
parent 1c975c7f2d
commit ab46165387
33 changed files with 322 additions and 283 deletions

View File

@@ -6,8 +6,8 @@ from sqlalchemy.orm import Session
from app import crud
from app.api.utils.db import get_db
from app.api.utils.security import get_current_active_user
from app.db_models.user import User as DBUser
from app.models.item import Item, ItemCreate, ItemUpdate
from app.models.user import User as DBUser
from app.schemas.item import Item, ItemCreate, ItemUpdate
router = APIRouter()
@@ -41,7 +41,9 @@ def create_item(
"""
Create new item.
"""
item = crud.item.create(db_session=db, item_in=item_in, owner_id=current_user.id)
item = crud.item.create_with_owner(
db_session=db, obj_in=item_in, owner_id=current_user.id
)
return item
@@ -61,7 +63,7 @@ def update_item(
raise HTTPException(status_code=404, detail="Item not found")
if not crud.user.is_superuser(current_user) and (item.owner_id != current_user.id):
raise HTTPException(status_code=400, detail="Not enough permissions")
item = crud.item.update(db_session=db, item=item, item_in=item_in)
item = crud.item.update(db_session=db, db_obj=item, obj_in=item_in)
return item

View File

@@ -10,10 +10,10 @@ from app.api.utils.security import get_current_user
from app.core import config
from app.core.jwt import create_access_token
from app.core.security import get_password_hash
from app.db_models.user import User as DBUser
from app.models.msg import Msg
from app.models.token import Token
from app.models.user import User
from app.models.user import User as DBUser
from app.schemas.msg import Msg
from app.schemas.token import Token
from app.schemas.user import User
from app.utils import (
generate_password_reset_token,
send_reset_password_email,

View File

@@ -2,15 +2,15 @@ from typing import List
from fastapi import APIRouter, Body, Depends, HTTPException
from fastapi.encoders import jsonable_encoder
from pydantic.types import EmailStr
from pydantic.networks import EmailStr
from sqlalchemy.orm import Session
from app import crud
from app.api.utils.db import get_db
from app.api.utils.security import get_current_active_superuser, get_current_active_user
from app.core import config
from app.db_models.user import User as DBUser
from app.models.user import User, UserCreate, UserInDB, UserUpdate
from app.models.user import User as DBUser
from app.schemas.user import User, UserCreate, UserUpdate
from app.utils import send_new_account_email
router = APIRouter()
@@ -46,7 +46,7 @@ def create_user(
status_code=400,
detail="The user with this username already exists in the system.",
)
user = crud.user.create(db, user_in=user_in)
user = crud.user.create(db, obj_in=user_in)
if config.EMAILS_ENABLED and user_in.email:
send_new_account_email(
email_to=user_in.email, username=user_in.email, password=user_in.password
@@ -74,7 +74,7 @@ def update_user_me(
user_in.full_name = full_name
if email is not None:
user_in.email = email
user = crud.user.update(db, user=current_user, user_in=user_in)
user = crud.user.update(db, db_obj=current_user, obj_in=user_in)
return user
@@ -103,7 +103,7 @@ def create_user_open(
if not config.USERS_OPEN_REGISTRATION:
raise HTTPException(
status_code=403,
detail="Open user resgistration is forbidden on this server",
detail="Open user registration is forbidden on this server",
)
user = crud.user.get_by_email(db, email=email)
if user:
@@ -112,7 +112,7 @@ def create_user_open(
detail="The user with this username already exists in the system",
)
user_in = UserCreate(password=password, email=email, full_name=full_name)
user = crud.user.create(db, user_in=user_in)
user = crud.user.create(db, obj_in=user_in)
return user
@@ -125,7 +125,7 @@ def read_user_by_id(
"""
Get a specific user by id.
"""
user = crud.user.get(db, user_id=user_id)
user = crud.user.get(db, id=user_id)
if user == current_user:
return user
if not crud.user.is_superuser(current_user):
@@ -141,16 +141,16 @@ def update_user(
db: Session = Depends(get_db),
user_id: int,
user_in: UserUpdate,
current_user: UserInDB = Depends(get_current_active_superuser),
current_user: DBUser = Depends(get_current_active_superuser),
):
"""
Update a user.
"""
user = crud.user.get(db, user_id=user_id)
user = crud.user.get(db, id=user_id)
if not user:
raise HTTPException(
status_code=404,
detail="The user with this username does not exist in the system",
)
user = crud.user.update(db, user=user, user_in=user_in)
user = crud.user.update(db, db_obj=user, obj_in=user_in)
return user

View File

@@ -1,10 +1,11 @@
from fastapi import APIRouter, Depends
from pydantic.types import EmailStr
from pydantic.networks import EmailStr
from app.api.utils.security import get_current_active_superuser
from app.core.celery_app import celery_app
from app.models.msg import Msg
from app.models.user import UserInDB
from app.schemas.msg import Msg
from app.schemas.user import User
from app.models.user import User as DBUser
from app.utils import send_test_email
router = APIRouter()
@@ -12,7 +13,7 @@ router = APIRouter()
@router.post("/test-celery/", response_model=Msg, status_code=201)
def test_celery(
msg: Msg, current_user: UserInDB = Depends(get_current_active_superuser)
msg: Msg, current_user: DBUser = Depends(get_current_active_superuser)
):
"""
Test Celery worker.
@@ -23,7 +24,7 @@ def test_celery(
@router.post("/test-email/", response_model=Msg, status_code=201)
def test_email(
email_to: EmailStr, current_user: UserInDB = Depends(get_current_active_superuser)
email_to: EmailStr, current_user: DBUser = Depends(get_current_active_superuser)
):
"""
Test emails.

View File

@@ -9,8 +9,8 @@ from app import crud
from app.api.utils.db import get_db
from app.core import config
from app.core.jwt import ALGORITHM
from app.db_models.user import User
from app.models.token import TokenPayload
from app.models.user import User
from app.schemas.token import TokenPayload
reusable_oauth2 = OAuth2PasswordBearer(tokenUrl="/api/v1/login/access-token")
@@ -25,7 +25,7 @@ def get_current_user(
raise HTTPException(
status_code=HTTP_403_FORBIDDEN, detail="Could not validate credentials"
)
user = crud.user.get(db, user_id=token_data.user_id)
user = crud.user.get(db, id=token_data.user_id)
if not user:
raise HTTPException(status_code=404, detail="User not found")
return user

View File

@@ -1 +1,10 @@
from . import item, user
from .crud_user import user
from .crud_item import item
# For a new basic set of CRUD operations you could just do
# from .base import CRUDBase
# from app.models.item import Item
# from app.schemas.item import ItemCreate, ItemUpdate
# item = CRUDBase[Item, ItemCreate, ItemUpdate](Item)

View File

@@ -0,0 +1,57 @@
from typing import List, Optional, Generic, TypeVar, Type
from fastapi.encoders import jsonable_encoder
from pydantic import BaseModel
from sqlalchemy.orm import Session
from app.db.base_class import Base
ModelType = TypeVar("ModelType", bound=Base)
CreateSchemaType = TypeVar("CreateSchemaType", bound=BaseModel)
UpdateSchemaType = TypeVar("UpdateSchemaType", bound=BaseModel)
class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
def __init__(self, model: Type[ModelType]):
"""
CRUD object with default methods to Create, Read, Update, Delete (CRUD).
**Parameters**
* `model`: A SQLAlchemy model class
* `schema`: A Pydantic model (schema) class
"""
self.model = model
def get(self, db_session: Session, id: int) -> Optional[ModelType]:
return db_session.query(self.model).filter(self.model.id == id).first()
def get_multi(self, db_session: Session, *, skip=0, limit=100) -> List[ModelType]:
return db_session.query(self.model).offset(skip).limit(limit).all()
def create(self, db_session: Session, *, obj_in: CreateSchemaType) -> ModelType:
obj_in_data = jsonable_encoder(obj_in)
db_obj = self.model(**obj_in_data)
db_session.add(db_obj)
db_session.commit()
db_session.refresh(db_obj)
return db_obj
def update(
self, db_session: Session, *, db_obj: ModelType, obj_in: UpdateSchemaType
) -> ModelType:
obj_data = jsonable_encoder(db_obj)
update_data = obj_in.dict(skip_defaults=True)
for field in obj_data:
if field in update_data:
setattr(db_obj, field, update_data[field])
db_session.add(db_obj)
db_session.commit()
db_session.refresh(db_obj)
return db_obj
def remove(self, db_session: Session, *, id: int) -> ModelType:
obj = db_session.query(self.model).get(id)
db_session.delete(obj)
db_session.commit()
return obj

View File

@@ -0,0 +1,34 @@
from typing import List
from fastapi.encoders import jsonable_encoder
from sqlalchemy.orm import Session
from app.models.item import Item
from app.schemas.item import ItemCreate, ItemUpdate
from app.crud.base import CRUDBase
class CRUDItem(CRUDBase[Item, ItemCreate, ItemUpdate]):
def create_with_owner(
self, db_session: Session, *, obj_in: ItemCreate, owner_id: int
) -> Item:
obj_in_data = jsonable_encoder(obj_in)
db_obj = self.model(**obj_in_data, owner_id=owner_id)
db_session.add(db_obj)
db_session.commit()
db_session.refresh(db_obj)
return db_obj
def get_multi_by_owner(
self, db_session: Session, *, owner_id: int, skip=0, limit=100
) -> List[Item]:
return (
db_session.query(self.model)
.filter(Item.owner_id == owner_id)
.offset(skip)
.limit(limit)
.all()
)
item = CRUDItem(Item)

View File

@@ -0,0 +1,44 @@
from typing import Optional
from sqlalchemy.orm import Session
from app.models.user import User
from app.schemas.user import UserCreate, UserUpdate
from app.core.security import verify_password, get_password_hash
from app.crud.base import CRUDBase
class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]):
def get_by_email(self, db_session: Session, *, email: str) -> Optional[User]:
return db_session.query(User).filter(User.email == email).first()
def create(self, db_session: Session, *, obj_in: UserCreate) -> User:
db_obj = User(
email=obj_in.email,
hashed_password=get_password_hash(obj_in.password),
full_name=obj_in.full_name,
is_superuser=obj_in.is_superuser,
)
db_session.add(db_obj)
db_session.commit()
db_session.refresh(db_obj)
return db_obj
def authenticate(
self, db_session: Session, *, email: str, password: str
) -> Optional[User]:
user = self.get_by_email(db_session, email=email)
if not user:
return None
if not verify_password(password, user.hashed_password):
return None
return user
def is_active(self, user: User) -> bool:
return user.is_active
def is_superuser(self, user: User) -> bool:
return user.is_superuser
user = CRUDUser(User)

View File

@@ -1,55 +0,0 @@
from typing import List, Optional
from fastapi.encoders import jsonable_encoder
from sqlalchemy.orm import Session
from app.db_models.item import Item
from app.models.item import ItemCreate, ItemUpdate
def get(db_session: Session, *, id: int) -> Optional[Item]:
return db_session.query(Item).filter(Item.id == id).first()
def get_multi(db_session: Session, *, skip=0, limit=100) -> List[Optional[Item]]:
return db_session.query(Item).offset(skip).limit(limit).all()
def get_multi_by_owner(
db_session: Session, *, owner_id: int, skip=0, limit=100
) -> List[Optional[Item]]:
return (
db_session.query(Item)
.filter(Item.owner_id == owner_id)
.offset(skip)
.limit(limit)
.all()
)
def create(db_session: Session, *, item_in: ItemCreate, owner_id: int) -> Item:
item_in_data = jsonable_encoder(item_in)
item = Item(**item_in_data, owner_id=owner_id)
db_session.add(item)
db_session.commit()
db_session.refresh(item)
return item
def update(db_session: Session, *, item: Item, item_in: ItemUpdate) -> Item:
item_data = jsonable_encoder(item)
update_data = item_in.dict(skip_defaults=True)
for field in item_data:
if field in update_data:
setattr(item, field, update_data[field])
db_session.add(item)
db_session.commit()
db_session.refresh(item)
return item
def remove(db_session: Session, *, id: int):
item = db_session.query(Item).filter(Item.id == id).first()
db_session.delete(item)
db_session.commit()
return item

View File

@@ -1,65 +0,0 @@
from typing import List, Optional
from fastapi.encoders import jsonable_encoder
from sqlalchemy.orm import Session
from app.core.security import get_password_hash, verify_password
from app.db_models.user import User
from app.models.user import UserCreate, UserUpdate
def get(db_session: Session, *, user_id: int) -> Optional[User]:
return db_session.query(User).filter(User.id == user_id).first()
def get_by_email(db_session: Session, *, email: str) -> Optional[User]:
return db_session.query(User).filter(User.email == email).first()
def authenticate(db_session: Session, *, email: str, password: str) -> Optional[User]:
user = get_by_email(db_session, email=email)
if not user:
return None
if not verify_password(password, user.hashed_password):
return None
return user
def is_active(user) -> bool:
return user.is_active
def is_superuser(user) -> bool:
return user.is_superuser
def get_multi(db_session: Session, *, skip=0, limit=100) -> List[Optional[User]]:
return db_session.query(User).offset(skip).limit(limit).all()
def create(db_session: Session, *, user_in: UserCreate) -> User:
user = User(
email=user_in.email,
hashed_password=get_password_hash(user_in.password),
full_name=user_in.full_name,
is_superuser=user_in.is_superuser,
)
db_session.add(user)
db_session.commit()
db_session.refresh(user)
return user
def update(db_session: Session, *, user: User, user_in: UserUpdate) -> User:
user_data = jsonable_encoder(user)
update_data = user_in.dict(skip_defaults=True)
for field in user_data:
if field in update_data:
setattr(user, field, update_data[field])
if user_in.password:
passwordhash = get_password_hash(user_in.password)
user.hashed_password = passwordhash
db_session.add(user)
db_session.commit()
db_session.refresh(user)
return user

View File

@@ -1,5 +1,5 @@
# Import all the models, so that Base has them before being
# imported by Alembic
from app.db.base_class import Base # noqa
from app.db_models.user import User # noqa
from app.db_models.item import Item # noqa
from app.models.user import User # noqa
from app.models.item import Item # noqa

View File

@@ -1,6 +1,6 @@
from app import crud
from app.core import config
from app.models.user import UserCreate
from app.schemas.user import UserCreate
# make sure all SQL Alchemy models are imported before initializing DB
# otherwise, SQL Alchemy might fail to initialize properly relationships
@@ -21,4 +21,4 @@ def init_db(db_session):
password=config.FIRST_SUPERUSER_PASSWORD,
is_superuser=True,
)
user = crud.user.create(db_session, user_in=user_in)
user = crud.user.create(db_session, obj_in=user_in)

View File

@@ -1,12 +0,0 @@
from sqlalchemy import Column, ForeignKey, Integer, String
from sqlalchemy.orm import relationship
from app.db.base_class import Base
class Item(Base):
id = Column(Integer, primary_key=True, index=True)
title = Column(String, index=True)
description = Column(String, index=True)
owner_id = Column(Integer, ForeignKey("user.id"))
owner = relationship("User", back_populates="items")

View File

@@ -1,14 +0,0 @@
from sqlalchemy import Boolean, Column, Integer, String
from sqlalchemy.orm import relationship
from app.db.base_class import Base
class User(Base):
id = Column(Integer, primary_key=True, index=True)
full_name = Column(String, index=True)
email = Column(String, unique=True, index=True)
hashed_password = Column(String)
is_active = Column(Boolean(), default=True)
is_superuser = Column(Boolean(), default=False)
items = relationship("Item", back_populates="owner")

View File

@@ -1,34 +1,12 @@
from pydantic import BaseModel
from sqlalchemy import Column, ForeignKey, Integer, String
from sqlalchemy.orm import relationship
from app.db.base_class import Base
# Shared properties
class ItemBase(BaseModel):
title: str = None
description: str = None
# Properties to receive on item creation
class ItemCreate(ItemBase):
title: str
# Properties to receive on item update
class ItemUpdate(ItemBase):
pass
# Properties shared by models stored in DB
class ItemInDBBase(ItemBase):
id: int
title: str
owner_id: int
# Properties to return to client
class Item(ItemInDBBase):
pass
# Properties properties stored in DB
class ItemInDB(ItemInDBBase):
pass
class Item(Base):
id = Column(Integer, primary_key=True, index=True)
title = Column(String, index=True)
description = Column(String, index=True)
owner_id = Column(Integer, ForeignKey("user.id"))
owner = relationship("User", back_populates="items")

View File

@@ -1,36 +1,14 @@
from typing import Optional
from sqlalchemy import Boolean, Column, Integer, String
from sqlalchemy.orm import relationship
from pydantic import BaseModel
from app.db.base_class import Base
# Shared properties
class UserBase(BaseModel):
email: Optional[str] = None
is_active: Optional[bool] = True
is_superuser: Optional[bool] = False
full_name: Optional[str] = None
class UserBaseInDB(UserBase):
id: int = None
# Properties to receive via API on creation
class UserCreate(UserBaseInDB):
email: str
password: str
# Properties to receive via API on update
class UserUpdate(UserBaseInDB):
password: Optional[str] = None
# Additional properties to return via API
class User(UserBaseInDB):
pass
# Additional properties stored in DB
class UserInDB(UserBaseInDB):
hashed_password: str
class User(Base):
id = Column(Integer, primary_key=True, index=True)
full_name = Column(String, index=True)
email = Column(String, unique=True, index=True)
hashed_password = Column(String)
is_active = Column(Boolean(), default=True)
is_superuser = Column(Boolean(), default=False)
items = relationship("Item", back_populates="owner")

View File

@@ -0,0 +1,38 @@
from pydantic import BaseModel
from .user import User
# Shared properties
class ItemBase(BaseModel):
title: str = None
description: str = None
# Properties to receive on item creation
class ItemCreate(ItemBase):
title: str
# Properties to receive on item update
class ItemUpdate(ItemBase):
pass
# Properties shared by models stored in DB
class ItemInDBBase(ItemBase):
id: int
title: str
owner_id: int
class Config:
orm_mode = True
# Properties to return to client
class Item(ItemInDBBase):
pass
# Properties properties stored in DB
class ItemInDB(ItemInDBBase):
pass

View File

@@ -0,0 +1,39 @@
from typing import Optional
from pydantic import BaseModel
# Shared properties
class UserBase(BaseModel):
email: Optional[str] = None
is_active: Optional[bool] = True
is_superuser: Optional[bool] = False
full_name: Optional[str] = None
class UserBaseInDB(UserBase):
id: int = None
class Config:
orm_mode = True
# Properties to receive via API on creation
class UserCreate(UserBaseInDB):
email: str
password: str
# Properties to receive via API on update
class UserUpdate(UserBaseInDB):
password: Optional[str] = None
# Additional properties to return via API
class User(UserBaseInDB):
pass
# Additional properties stored in DB
class UserInDB(UserBaseInDB):
hashed_password: str

View File

@@ -3,6 +3,7 @@ import requests
from app.core import config
from app.tests.utils.item import create_random_item
from app.tests.utils.utils import get_server_api
from app.tests.utils.user import create_random_user
def test_create_item(superuser_token_headers):
@@ -13,6 +14,7 @@ def test_create_item(superuser_token_headers):
headers=superuser_token_headers,
json=data,
)
assert response.status_code == 200
content = response.json()
assert content["title"] == data["title"]
assert content["description"] == data["description"]
@@ -27,6 +29,7 @@ def test_read_item(superuser_token_headers):
f"{server_api}{config.API_V1_STR}/items/{item.id}",
headers=superuser_token_headers,
)
assert response.status_code == 200
content = response.json()
assert content["title"] == item.title
assert content["description"] == item.description

View File

@@ -3,8 +3,7 @@ import requests
from app import crud
from app.core import config
from app.db.session import db_session
from app.models.user import UserCreate
from app.tests.utils.user import user_authentication_headers
from app.schemas.user import UserCreate
from app.tests.utils.utils import get_server_api, random_lower_string
@@ -53,7 +52,7 @@ def test_get_existing_user(superuser_token_headers):
username = random_lower_string()
password = random_lower_string()
user_in = UserCreate(email=username, password=password)
user = crud.user.create(db_session, user_in=user_in)
user = crud.user.create(db_session, obj_in=user_in)
user_id = user.id
r = requests.get(
f"{server_api}{config.API_V1_STR}/users/{user_id}",
@@ -71,7 +70,7 @@ def test_create_user_existing_username(superuser_token_headers):
# username = email
password = random_lower_string()
user_in = UserCreate(email=username, password=password)
user = crud.user.create(db_session, user_in=user_in)
crud.user.create(db_session, obj_in=user_in)
data = {"email": username, "password": password}
r = requests.post(
f"{server_api}{config.API_V1_STR}/users/",
@@ -89,7 +88,9 @@ def test_create_user_by_normal_user(normal_user_token_headers):
password = random_lower_string()
data = {"email": username, "password": password}
r = requests.post(
f"{server_api}{config.API_V1_STR}/users/", headers=normal_user_token_headers, json=data
f"{server_api}{config.API_V1_STR}/users/",
headers=normal_user_token_headers,
json=data,
)
assert r.status_code == 400
@@ -99,12 +100,12 @@ def test_retrieve_users(superuser_token_headers):
username = random_lower_string()
password = random_lower_string()
user_in = UserCreate(email=username, password=password)
user = crud.user.create(db_session, user_in=user_in)
user = crud.user.create(db_session, obj_in=user_in)
username2 = random_lower_string()
password2 = random_lower_string()
user_in2 = UserCreate(email=username2, password=password2)
user2 = crud.user.create(db_session, user_in=user_in2)
crud.user.create(db_session, obj_in=user_in2)
r = requests.get(
f"{server_api}{config.API_V1_STR}/users/", headers=superuser_token_headers

View File

@@ -1,5 +1,5 @@
from app import crud
from app.models.item import ItemCreate, ItemUpdate
from app.schemas.item import ItemCreate, ItemUpdate
from app.tests.utils.user import create_random_user
from app.tests.utils.utils import random_lower_string
from app.db.session import db_session
@@ -10,7 +10,9 @@ def test_create_item():
description = random_lower_string()
item_in = ItemCreate(title=title, description=description)
user = create_random_user()
item = crud.item.create(db_session=db_session, item_in=item_in, owner_id=user.id)
item = crud.item.create_with_owner(
db_session=db_session, obj_in=item_in, owner_id=user.id
)
assert item.title == title
assert item.description == description
assert item.owner_id == user.id
@@ -21,7 +23,9 @@ def test_get_item():
description = random_lower_string()
item_in = ItemCreate(title=title, description=description)
user = create_random_user()
item = crud.item.create(db_session=db_session, item_in=item_in, owner_id=user.id)
item = crud.item.create_with_owner(
db_session=db_session, obj_in=item_in, owner_id=user.id
)
stored_item = crud.item.get(db_session=db_session, id=item.id)
assert item.id == stored_item.id
assert item.title == stored_item.title
@@ -34,12 +38,12 @@ def test_update_item():
description = random_lower_string()
item_in = ItemCreate(title=title, description=description)
user = create_random_user()
item = crud.item.create(db_session=db_session, item_in=item_in, owner_id=user.id)
item = crud.item.create_with_owner(
db_session=db_session, obj_in=item_in, owner_id=user.id
)
description2 = random_lower_string()
item_update = ItemUpdate(description=description2)
item2 = crud.item.update(
db_session=db_session, item=item, item_in=item_update
)
item2 = crud.item.update(db_session=db_session, db_obj=item, obj_in=item_update)
assert item.id == item2.id
assert item.title == item2.title
assert item2.description == description2
@@ -51,7 +55,7 @@ def test_delete_item():
description = random_lower_string()
item_in = ItemCreate(title=title, description=description)
user = create_random_user()
item = crud.item.create(db_session=db_session, item_in=item_in, owner_id=user.id)
item = crud.item.create_with_owner(db_session=db_session, obj_in=item_in, owner_id=user.id)
item2 = crud.item.remove(db_session=db_session, id=item.id)
item3 = crud.item.get(db_session=db_session, id=item.id)
assert item3 is None

View File

@@ -2,7 +2,7 @@ from fastapi.encoders import jsonable_encoder
from app import crud
from app.db.session import db_session
from app.models.user import UserCreate
from app.schemas.user import UserCreate
from app.tests.utils.utils import random_lower_string
@@ -10,7 +10,7 @@ def test_create_user():
email = random_lower_string()
password = random_lower_string()
user_in = UserCreate(email=email, password=password)
user = crud.user.create(db_session, user_in=user_in)
user = crud.user.create(db_session, obj_in=user_in)
assert user.email == email
assert hasattr(user, "hashed_password")
@@ -19,7 +19,7 @@ def test_authenticate_user():
email = random_lower_string()
password = random_lower_string()
user_in = UserCreate(email=email, password=password)
user = crud.user.create(db_session, user_in=user_in)
user = crud.user.create(db_session, obj_in=user_in)
authenticated_user = crud.user.authenticate(
db_session, email=email, password=password
)
@@ -38,7 +38,7 @@ def test_check_if_user_is_active():
email = random_lower_string()
password = random_lower_string()
user_in = UserCreate(email=email, password=password)
user = crud.user.create(db_session, user_in=user_in)
user = crud.user.create(db_session, obj_in=user_in)
is_active = crud.user.is_active(user)
assert is_active is True
@@ -47,11 +47,8 @@ def test_check_if_user_is_active_inactive():
email = random_lower_string()
password = random_lower_string()
user_in = UserCreate(email=email, password=password, disabled=True)
print(user_in)
user = crud.user.create(db_session, user_in=user_in)
print(user)
user = crud.user.create(db_session, obj_in=user_in)
is_active = crud.user.is_active(user)
print(is_active)
assert is_active
@@ -59,7 +56,7 @@ def test_check_if_user_is_superuser():
email = random_lower_string()
password = random_lower_string()
user_in = UserCreate(email=email, password=password, is_superuser=True)
user = crud.user.create(db_session, user_in=user_in)
user = crud.user.create(db_session, obj_in=user_in)
is_superuser = crud.user.is_superuser(user)
assert is_superuser is True
@@ -68,7 +65,7 @@ def test_check_if_user_is_superuser_normal_user():
username = random_lower_string()
password = random_lower_string()
user_in = UserCreate(email=username, password=password)
user = crud.user.create(db_session, user_in=user_in)
user = crud.user.create(db_session, obj_in=user_in)
is_superuser = crud.user.is_superuser(user)
assert is_superuser is False
@@ -77,7 +74,7 @@ def test_get_user():
password = random_lower_string()
username = random_lower_string()
user_in = UserCreate(email=username, password=password, is_superuser=True)
user = crud.user.create(db_session, user_in=user_in)
user_2 = crud.user.get(db_session, user_id=user.id)
user = crud.user.create(db_session, obj_in=user_in)
user_2 = crud.user.get(db_session, id=user.id)
assert user.email == user_2.email
assert jsonable_encoder(user) == jsonable_encoder(user_2)

View File

@@ -1,6 +1,6 @@
from app import crud
from app.db.session import db_session
from app.models.item import ItemCreate
from app.schemas.item import ItemCreate
from app.tests.utils.user import create_random_user
from app.tests.utils.utils import random_lower_string
@@ -12,6 +12,6 @@ def create_random_item(owner_id: int = None):
title = random_lower_string()
description = random_lower_string()
item_in = ItemCreate(title=title, description=description, id=id)
return crud.item.create(
db_session=db_session, item_in=item_in, owner_id=owner_id
return crud.item.create_with_owner(
db_session=db_session, obj_in=item_in, owner_id=owner_id
)

View File

@@ -3,7 +3,7 @@ import requests
from app import crud
from app.core import config
from app.db.session import db_session
from app.models.user import UserCreate, UserUpdate
from app.schemas.user import UserCreate, UserUpdate
from app.tests.utils.utils import get_server_api, random_lower_string
@@ -21,7 +21,7 @@ def create_random_user():
email = random_lower_string()
password = random_lower_string()
user_in = UserCreate(username=email, email=email, password=password)
user = crud.user.create(db_session=db_session, user_in=user_in)
user = crud.user.create(db_session=db_session, obj_in=user_in)
return user
@@ -35,9 +35,9 @@ def authentication_token_from_email(email):
user = crud.user.get_by_email(db_session, email=email)
if not user:
user_in = UserCreate(username=email, email=email, password=password)
user = crud.user.create(db_session=db_session, user_in=user_in)
user = crud.user.create(db_session=db_session, obj_in=user_in)
else:
user_in = UserUpdate(password=password)
user = crud.user.update(db_session, user=user, user_in=user_in)
user = crud.user.update(db_session, obj_in=user, db_obj=user_in)
return user_authentication_headers(get_server_api(), email, password)