♻️ Refactor backend, settings, DB sessions, types, configs, plugins (#158)

* ♻️ Refactor backend, update DB session handling

*  Add mypy config and plugins

*  Use Python-jose instead of PyJWT

as it has some extra functionalities and features

*  Add/update scripts for test, lint, format

* 🔧 Update lint and format configs

* 🎨 Update import format, comments, and types

* 🎨 Add types to config

*  Add types for all the code, and small fixes

* 🎨 Use global imports to simplify exploring with Jupyter

* ♻️ Import schemas and models, instead of each class

* 🚚 Rename db_session to db for simplicity

* 📌 Update dependencies installation for testing
This commit is contained in:
Sebastián Ramírez
2020-04-20 19:03:13 +02:00
committed by GitHub
parent 4b80bdfdce
commit eed33d276d
59 changed files with 547 additions and 445 deletions

View File

@@ -1,5 +1,5 @@
from .crud_user import user # noqa: F401
from .crud_item import item # noqa: F401
from .crud_user import user # noqa: F401
# For a new basic set of CRUD operations you could just do

View File

@@ -1,4 +1,4 @@
from typing import List, Optional, Generic, TypeVar, Type
from typing import Any, Dict, Generic, List, Optional, Type, TypeVar, Union
from fastapi.encoders import jsonable_encoder
from pydantic import BaseModel
@@ -23,35 +23,44 @@ class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
"""
self.model = model
def get(self, db_session: Session, id: int) -> Optional[ModelType]:
return db_session.query(self.model).filter(self.model.id == id).first()
def get(self, db: Session, id: Any) -> Optional[ModelType]:
return db.query(self.model).filter(self.model.id == id).first()
def get_multi(self, db_session: Session, *, skip=0, limit=100) -> List[ModelType]:
return db_session.query(self.model).offset(skip).limit(limit).all()
def get_multi(
self, db: Session, *, skip: int = 0, limit: int = 100
) -> List[ModelType]:
return db.query(self.model).offset(skip).limit(limit).all()
def create(self, db_session: Session, *, obj_in: CreateSchemaType) -> ModelType:
def create(self, db: Session, *, obj_in: CreateSchemaType) -> ModelType:
obj_in_data = jsonable_encoder(obj_in)
db_obj = self.model(**obj_in_data)
db_session.add(db_obj)
db_session.commit()
db_session.refresh(db_obj)
db_obj = self.model(**obj_in_data) # type: ignore
db.add(db_obj)
db.commit()
db.refresh(db_obj)
return db_obj
def update(
self, db_session: Session, *, db_obj: ModelType, obj_in: UpdateSchemaType
self,
db: Session,
*,
db_obj: ModelType,
obj_in: Union[UpdateSchemaType, Dict[str, Any]]
) -> ModelType:
obj_data = jsonable_encoder(db_obj)
update_data = obj_in.dict(exclude_unset=True)
if isinstance(obj_in, dict):
update_data = obj_in
else:
update_data = obj_in.dict(exclude_unset=True)
for field in obj_data:
if field in update_data:
setattr(db_obj, field, update_data[field])
db_session.add(db_obj)
db_session.commit()
db_session.refresh(db_obj)
db.add(db_obj)
db.commit()
db.refresh(db_obj)
return db_obj
def remove(self, db_session: Session, *, id: int) -> ModelType:
obj = db_session.query(self.model).get(id)
db_session.delete(obj)
db_session.commit()
def remove(self, db: Session, *, id: int) -> ModelType:
obj = db.query(self.model).get(id)
db.delete(obj)
db.commit()
return obj

View File

@@ -3,27 +3,27 @@ from typing import List
from fastapi.encoders import jsonable_encoder
from sqlalchemy.orm import Session
from app.crud.base import CRUDBase
from app.models.item import Item
from app.schemas.item import ItemCreate, ItemUpdate
from app.crud.base import CRUDBase
class CRUDItem(CRUDBase[Item, ItemCreate, ItemUpdate]):
def create_with_owner(
self, db_session: Session, *, obj_in: ItemCreate, owner_id: int
self, db: Session, *, obj_in: ItemCreate, owner_id: int
) -> Item:
obj_in_data = jsonable_encoder(obj_in)
db_obj = self.model(**obj_in_data, owner_id=owner_id)
db_session.add(db_obj)
db_session.commit()
db_session.refresh(db_obj)
db.add(db_obj)
db.commit()
db.refresh(db_obj)
return db_obj
def get_multi_by_owner(
self, db_session: Session, *, owner_id: int, skip=0, limit=100
self, db: Session, *, owner_id: int, skip: int = 0, limit: int = 100
) -> List[Item]:
return (
db_session.query(self.model)
db.query(self.model)
.filter(Item.owner_id == owner_id)
.offset(skip)
.limit(limit)

View File

@@ -1,42 +1,44 @@
from typing import Optional
from typing import Any, Dict, Optional, Union
from sqlalchemy.orm import Session
from app.models.user import User
from app.schemas.user import UserCreate, UserUpdate, UserInDB
from app.core.security import verify_password, get_password_hash
from app.core.security import get_password_hash, verify_password
from app.crud.base import CRUDBase
from app.models.user import User
from app.schemas.user import UserCreate, UserUpdate
class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]):
def get_by_email(self, db_session: Session, *, email: str) -> Optional[User]:
return db_session.query(User).filter(User.email == email).first()
def get_by_email(self, db: Session, *, email: str) -> Optional[User]:
return db.query(User).filter(User.email == email).first()
def create(self, db_session: Session, *, obj_in: UserCreate) -> User:
def create(self, db: Session, *, obj_in: UserCreate) -> User:
db_obj = User(
email=obj_in.email,
hashed_password=get_password_hash(obj_in.password),
full_name=obj_in.full_name,
is_superuser=obj_in.is_superuser,
)
db_session.add(db_obj)
db_session.commit()
db_session.refresh(db_obj)
db.add(db_obj)
db.commit()
db.refresh(db_obj)
return db_obj
def update(self, db_session: Session, *, db_obj: User, obj_in: UserUpdate) -> User:
if obj_in.password:
def update(
self, db: Session, *, db_obj: User, obj_in: Union[UserUpdate, Dict[str, Any]]
) -> User:
if isinstance(obj_in, dict):
update_data = obj_in
else:
update_data = obj_in.dict(exclude_unset=True)
hashed_password = get_password_hash(obj_in.password)
if update_data["password"]:
hashed_password = get_password_hash(update_data["password"])
del update_data["password"]
update_data["hashed_password"] = hashed_password
use_obj_in = UserInDB.parse_obj(update_data)
return super().update(db_session, db_obj=db_obj, obj_in=use_obj_in)
return super().update(db, db_obj=db_obj, obj_in=update_data)
def authenticate(
self, db_session: Session, *, email: str, password: str
) -> Optional[User]:
user = self.get_by_email(db_session, email=email)
def authenticate(self, db: Session, *, email: str, password: str) -> Optional[User]:
user = self.get_by_email(db, email=email)
if not user:
return None
if not verify_password(password, user.hashed_password):