✨ Adopt SQLModel, create models, start using it (#559)
* 🔥 Remove old SQLAlchemy models * ✨ Add new SQLModel models * 🔧 Update Alembic configs to work with SQLModel * ✨ Re-generate initial Alembic migration * 🔧 Update PostgreSQL driver connection string URL * ✨ Create new SQLModel engine * 🔥 Remove old unneeded SQLAlchemy-specific files * ♻️ Update init_db * ♻️ Use new SQLModel session * ♻️ Update conftest with new DB Session * ♻️ Update pre-start scripts to use SQLModel session * ♻️ Import new SQLModel models * ✨ Create new simplified create_user crud util * ♻️ Update import in CRUDBase class (soon to be removed) * 🙈 Update .gitignore with Python files
This commit is contained in:

committed by
GitHub

parent
2d92cd70a4
commit
a66a9256dd
@@ -20,9 +20,9 @@ fileConfig(config.config_file_name)
|
||||
# target_metadata = mymodel.Base.metadata
|
||||
# target_metadata = None
|
||||
|
||||
from app.db.base import Base # noqa
|
||||
from app.models import SQLModel # noqa
|
||||
|
||||
target_metadata = Base.metadata
|
||||
target_metadata = SQLModel.metadata
|
||||
|
||||
# other values from the config, defined by the needs of env.py,
|
||||
# can be acquired:
|
||||
@@ -35,7 +35,7 @@ def get_url():
|
||||
password = os.getenv("POSTGRES_PASSWORD", "")
|
||||
server = os.getenv("POSTGRES_SERVER", "db")
|
||||
db = os.getenv("POSTGRES_DB", "app")
|
||||
return f"postgresql://{user}:{password}@{server}/{db}"
|
||||
return f"postgresql+psycopg://{user}:{password}@{server}/{db}"
|
||||
|
||||
|
||||
def run_migrations_offline():
|
||||
|
@@ -7,6 +7,7 @@ Create Date: ${create_date}
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
import sqlmodel.sql.sqltypes
|
||||
${imports if imports else ""}
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
|
@@ -1,59 +0,0 @@
|
||||
"""First revision
|
||||
|
||||
Revision ID: d4867f3a4c0a
|
||||
Revises:
|
||||
Create Date: 2019-04-17 13:53:32.978401
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "d4867f3a4c0a"
|
||||
down_revision = None
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade():
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.create_table(
|
||||
"user",
|
||||
sa.Column("id", sa.Integer(), nullable=False),
|
||||
sa.Column("full_name", sa.String(), nullable=True),
|
||||
sa.Column("email", sa.String(), nullable=True),
|
||||
sa.Column("hashed_password", sa.String(), nullable=True),
|
||||
sa.Column("is_active", sa.Boolean(), nullable=True),
|
||||
sa.Column("is_superuser", sa.Boolean(), nullable=True),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
op.create_index(op.f("ix_user_email"), "user", ["email"], unique=True)
|
||||
op.create_index(op.f("ix_user_full_name"), "user", ["full_name"], unique=False)
|
||||
op.create_index(op.f("ix_user_id"), "user", ["id"], unique=False)
|
||||
op.create_table(
|
||||
"item",
|
||||
sa.Column("id", sa.Integer(), nullable=False),
|
||||
sa.Column("title", sa.String(), nullable=True),
|
||||
sa.Column("description", sa.String(), nullable=True),
|
||||
sa.Column("owner_id", sa.Integer(), nullable=True),
|
||||
sa.ForeignKeyConstraint(["owner_id"], ["user.id"],),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
op.create_index(op.f("ix_item_description"), "item", ["description"], unique=False)
|
||||
op.create_index(op.f("ix_item_id"), "item", ["id"], unique=False)
|
||||
op.create_index(op.f("ix_item_title"), "item", ["title"], unique=False)
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade():
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_index(op.f("ix_item_title"), table_name="item")
|
||||
op.drop_index(op.f("ix_item_id"), table_name="item")
|
||||
op.drop_index(op.f("ix_item_description"), table_name="item")
|
||||
op.drop_table("item")
|
||||
op.drop_index(op.f("ix_user_id"), table_name="user")
|
||||
op.drop_index(op.f("ix_user_full_name"), table_name="user")
|
||||
op.drop_index(op.f("ix_user_email"), table_name="user")
|
||||
op.drop_table("user")
|
||||
# ### end Alembic commands ###
|
@@ -0,0 +1,48 @@
|
||||
"""Initialize models
|
||||
|
||||
Revision ID: e2412789c190
|
||||
Revises:
|
||||
Create Date: 2023-11-24 22:55:43.195942
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
import sqlmodel.sql.sqltypes
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = 'e2412789c190'
|
||||
down_revision = None
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade():
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.create_table('user',
|
||||
sa.Column('email', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
|
||||
sa.Column('is_active', sa.Boolean(), nullable=False),
|
||||
sa.Column('is_superuser', sa.Boolean(), nullable=False),
|
||||
sa.Column('full_name', sqlmodel.sql.sqltypes.AutoString(), nullable=True),
|
||||
sa.Column('id', sa.Integer(), nullable=False),
|
||||
sa.Column('hashed_password', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
)
|
||||
op.create_index(op.f('ix_user_email'), 'user', ['email'], unique=True)
|
||||
op.create_table('item',
|
||||
sa.Column('description', sqlmodel.sql.sqltypes.AutoString(), nullable=True),
|
||||
sa.Column('id', sa.Integer(), nullable=False),
|
||||
sa.Column('title', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
|
||||
sa.Column('owner_id', sa.Integer(), nullable=False),
|
||||
sa.ForeignKeyConstraint(['owner_id'], ['user.id'], ),
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
)
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade():
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_table('item')
|
||||
op.drop_index(op.f('ix_user_email'), table_name='user')
|
||||
op.drop_table('user')
|
||||
# ### end Alembic commands ###
|
@@ -9,7 +9,7 @@ from sqlalchemy.orm import Session
|
||||
from app import crud, models, schemas
|
||||
from app.core import security
|
||||
from app.core.config import settings
|
||||
from app.db.session import SessionLocal
|
||||
from app.db.engine import engine
|
||||
|
||||
reusable_oauth2 = OAuth2PasswordBearer(
|
||||
tokenUrl=f"{settings.API_V1_STR}/login/access-token"
|
||||
@@ -17,11 +17,8 @@ reusable_oauth2 = OAuth2PasswordBearer(
|
||||
|
||||
|
||||
def get_db() -> Generator:
|
||||
try:
|
||||
db = SessionLocal()
|
||||
yield db
|
||||
finally:
|
||||
db.close()
|
||||
with Session(engine) as session:
|
||||
yield session
|
||||
|
||||
|
||||
def get_current_user(
|
||||
|
@@ -1,8 +1,9 @@
|
||||
import logging
|
||||
|
||||
from sqlmodel import Session, select
|
||||
from tenacity import after_log, before_log, retry, stop_after_attempt, wait_fixed
|
||||
|
||||
from app.db.session import SessionLocal
|
||||
from app.db.engine import engine
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -19,9 +20,9 @@ wait_seconds = 1
|
||||
)
|
||||
def init() -> None:
|
||||
try:
|
||||
db = SessionLocal()
|
||||
# Try to create session to check if DB is awake
|
||||
db.execute("SELECT 1")
|
||||
with Session(engine) as session:
|
||||
# Try to create session to check if DB is awake
|
||||
session.exec(select(1))
|
||||
except Exception as e:
|
||||
logger.error(e)
|
||||
raise e
|
||||
|
@@ -1,8 +1,9 @@
|
||||
import logging
|
||||
|
||||
from sqlmodel import Session, select
|
||||
from tenacity import after_log, before_log, retry, stop_after_attempt, wait_fixed
|
||||
|
||||
from app.db.session import SessionLocal
|
||||
from app.db.engine import engine
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -20,8 +21,8 @@ wait_seconds = 1
|
||||
def init() -> None:
|
||||
try:
|
||||
# Try to create session to check if DB is awake
|
||||
db = SessionLocal()
|
||||
db.execute("SELECT 1")
|
||||
with Session(engine) as session:
|
||||
session.exec(select(1))
|
||||
except Exception as e:
|
||||
logger.error(e)
|
||||
raise e
|
||||
|
@@ -44,7 +44,7 @@ class Settings(BaseSettings):
|
||||
if isinstance(v, str):
|
||||
return v
|
||||
return PostgresDsn.build(
|
||||
scheme="postgresql",
|
||||
scheme="postgresql+psycopg",
|
||||
user=values.get("POSTGRES_USER"),
|
||||
password=values.get("POSTGRES_PASSWORD"),
|
||||
host=values.get("POSTGRES_SERVER"),
|
||||
|
@@ -8,3 +8,16 @@ from .crud_user import user
|
||||
# from app.schemas.item import ItemCreate, ItemUpdate
|
||||
|
||||
# item = CRUDBase[Item, ItemCreate, ItemUpdate](Item)
|
||||
from sqlmodel import Session
|
||||
from app.core.security import get_password_hash
|
||||
from app.models import UserCreate, User
|
||||
|
||||
|
||||
def create_user(session: Session, *, user_create: UserCreate) -> User:
|
||||
db_obj = User.from_orm(
|
||||
user_create, update={"hashed_password": get_password_hash(user_create.password)}
|
||||
)
|
||||
session.add(db_obj)
|
||||
session.commit()
|
||||
session.refresh(db_obj)
|
||||
return db_obj
|
||||
|
@@ -4,9 +4,7 @@ from fastapi.encoders import jsonable_encoder
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.db.base_class import Base
|
||||
|
||||
ModelType = TypeVar("ModelType", bound=Base)
|
||||
ModelType = TypeVar("ModelType", bound=Any)
|
||||
CreateSchemaType = TypeVar("CreateSchemaType", bound=BaseModel)
|
||||
UpdateSchemaType = TypeVar("UpdateSchemaType", bound=BaseModel)
|
||||
|
||||
|
@@ -4,7 +4,7 @@ from fastapi.encoders import jsonable_encoder
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.crud.base import CRUDBase
|
||||
from app.models.item import Item
|
||||
from app.models import Item
|
||||
from app.schemas.item import ItemCreate, ItemUpdate
|
||||
|
||||
|
||||
|
@@ -4,7 +4,7 @@ from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.security import get_password_hash, verify_password
|
||||
from app.crud.base import CRUDBase
|
||||
from app.models.user import User
|
||||
from app.models import User
|
||||
from app.schemas.user import UserCreate, UserUpdate
|
||||
|
||||
|
||||
|
@@ -1,5 +0,0 @@
|
||||
# Import all the models, so that Base has them before being
|
||||
# imported by Alembic
|
||||
from app.db.base_class import Base # noqa
|
||||
from app.models.item import Item # noqa
|
||||
from app.models.user import User # noqa
|
@@ -1,13 +0,0 @@
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy.ext.declarative import as_declarative, declared_attr
|
||||
|
||||
|
||||
@as_declarative()
|
||||
class Base:
|
||||
id: Any
|
||||
__name__: str
|
||||
# Generate __tablename__ automatically
|
||||
@declared_attr
|
||||
def __tablename__(cls) -> str:
|
||||
return cls.__name__.lower()
|
5
src/backend/app/app/db/engine.py
Normal file
5
src/backend/app/app/db/engine.py
Normal file
@@ -0,0 +1,5 @@
|
||||
from sqlmodel import create_engine
|
||||
|
||||
from app.core.config import settings
|
||||
|
||||
engine = create_engine(settings.SQLALCHEMY_DATABASE_URI)
|
@@ -1,25 +1,26 @@
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlmodel import Session, select
|
||||
|
||||
from app import crud, schemas
|
||||
from app import crud
|
||||
from app.core.config import settings
|
||||
from app.db import base # noqa: F401
|
||||
from app.models import User, UserCreate # noqa: F401
|
||||
|
||||
# make sure all SQL Alchemy models are imported (app.db.base) before initializing DB
|
||||
# otherwise, SQL Alchemy might fail to initialize relationships properly
|
||||
# make sure all SQLModel models are imported (app.models) before initializing DB
|
||||
# otherwise, SQLModel might fail to initialize relationships properly
|
||||
# for more details: https://github.com/tiangolo/full-stack-fastapi-postgresql/issues/28
|
||||
|
||||
|
||||
def init_db(db: Session) -> None:
|
||||
def init_db(session: Session) -> None:
|
||||
# Tables should be created with Alembic migrations
|
||||
# But if you don't want to use migrations, create
|
||||
# the tables un-commenting the next line
|
||||
# Base.metadata.create_all(bind=engine)
|
||||
|
||||
user = crud.user.get_by_email(db, email=settings.FIRST_SUPERUSER)
|
||||
user = session.exec(
|
||||
select(User).where(User.email == settings.FIRST_SUPERUSER)
|
||||
).first()
|
||||
if not user:
|
||||
user_in = schemas.UserCreate(
|
||||
user_in = UserCreate(
|
||||
email=settings.FIRST_SUPERUSER,
|
||||
password=settings.FIRST_SUPERUSER_PASSWORD,
|
||||
is_superuser=True,
|
||||
)
|
||||
user = crud.user.create(db, obj_in=user_in) # noqa: F841
|
||||
user = crud.create_user(session, user_create=user_in)
|
||||
|
@@ -1,7 +0,0 @@
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from app.core.config import settings
|
||||
|
||||
engine = create_engine(settings.SQLALCHEMY_DATABASE_URI, pool_pre_ping=True)
|
||||
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
@@ -1,15 +1,17 @@
|
||||
import logging
|
||||
|
||||
from sqlmodel import Session
|
||||
|
||||
from app.db.engine import engine
|
||||
from app.db.init_db import init_db
|
||||
from app.db.session import SessionLocal
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def init() -> None:
|
||||
db = SessionLocal()
|
||||
init_db(db)
|
||||
with Session(engine) as session:
|
||||
init_db(session)
|
||||
|
||||
|
||||
def main() -> None:
|
||||
|
82
src/backend/app/app/models.py
Normal file
82
src/backend/app/app/models.py
Normal file
@@ -0,0 +1,82 @@
|
||||
from typing import Union
|
||||
|
||||
from pydantic import BaseModel, EmailStr
|
||||
from sqlmodel import Field, Relationship, SQLModel
|
||||
|
||||
|
||||
# Shared properties
|
||||
class UserBase(SQLModel):
|
||||
email: EmailStr = Field(unique=True, index=True)
|
||||
is_active: bool = True
|
||||
is_superuser: bool = False
|
||||
full_name: Union[str, None] = None
|
||||
|
||||
|
||||
# Properties to receive via API on creation
|
||||
class UserCreate(UserBase):
|
||||
password: str
|
||||
|
||||
|
||||
# Properties to receive via API on update, all are optional
|
||||
class UserUpdate(UserBase):
|
||||
email: Union[EmailStr, None] = None
|
||||
password: Union[str, None] = None
|
||||
|
||||
|
||||
# Database model, database table inferred from class name
|
||||
class User(UserBase, table=True):
|
||||
id: Union[int, None] = Field(default=None, primary_key=True)
|
||||
hashed_password: str
|
||||
items: list["Item"] = Relationship(back_populates="owner")
|
||||
|
||||
|
||||
# Properties to return via API, id is always required
|
||||
class UserOut(UserBase):
|
||||
id: int
|
||||
|
||||
|
||||
# Shared properties
|
||||
class ItemBase(SQLModel):
|
||||
title: str
|
||||
description: Union[str, None] = None
|
||||
|
||||
|
||||
# Properties to receive on item creation
|
||||
class ItemCreate(ItemBase):
|
||||
title: str
|
||||
|
||||
|
||||
# Properties to receive on item update
|
||||
class ItemUpdate(ItemBase):
|
||||
title: Union[str, None] = None
|
||||
|
||||
|
||||
# Database model, database table inferred from class name
|
||||
class Item(ItemBase, table=True):
|
||||
id: Union[int, None] = Field(default=None, primary_key=True)
|
||||
title: str
|
||||
owner_id: Union[int, None] = Field(
|
||||
default=None, foreign_key="user.id", nullable=False
|
||||
)
|
||||
owner: Union[User, None] = Relationship(back_populates="items")
|
||||
|
||||
|
||||
# Properties to return via API, id is always required
|
||||
class ItemOut(ItemBase):
|
||||
id: int
|
||||
|
||||
|
||||
# Generic message
|
||||
class Msg(BaseModel):
|
||||
msg: str
|
||||
|
||||
|
||||
# JSON payload containing access token
|
||||
class Token(BaseModel):
|
||||
access_token: str
|
||||
token_type: str
|
||||
|
||||
|
||||
# Contents of JWT token
|
||||
class TokenPayload(BaseModel):
|
||||
sub: Union[int, None] = None
|
@@ -1,2 +0,0 @@
|
||||
from .item import Item
|
||||
from .user import User
|
@@ -1,17 +0,0 @@
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from sqlalchemy import Column, ForeignKey, Integer, String
|
||||
from sqlalchemy.orm import relationship
|
||||
|
||||
from app.db.base_class import Base
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .user import User # noqa: F401
|
||||
|
||||
|
||||
class Item(Base):
|
||||
id = Column(Integer, primary_key=True, index=True)
|
||||
title = Column(String, index=True)
|
||||
description = Column(String, index=True)
|
||||
owner_id = Column(Integer, ForeignKey("user.id"))
|
||||
owner = relationship("User", back_populates="items")
|
@@ -1,19 +0,0 @@
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from sqlalchemy import Boolean, Column, Integer, String
|
||||
from sqlalchemy.orm import relationship
|
||||
|
||||
from app.db.base_class import Base
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .item import Item # noqa: F401
|
||||
|
||||
|
||||
class User(Base):
|
||||
id = Column(Integer, primary_key=True, index=True)
|
||||
full_name = Column(String, index=True)
|
||||
email = Column(String, unique=True, index=True, nullable=False)
|
||||
hashed_password = Column(String, nullable=False)
|
||||
is_active = Column(Boolean(), default=True)
|
||||
is_superuser = Column(Boolean(), default=False)
|
||||
items = relationship("Item", back_populates="owner")
|
@@ -5,7 +5,7 @@ from fastapi.testclient import TestClient
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.config import settings
|
||||
from app.db.session import SessionLocal
|
||||
from app.db.engine import engine
|
||||
from app.main import app
|
||||
from app.tests.utils.user import authentication_token_from_email
|
||||
from app.tests.utils.utils import get_superuser_token_headers
|
||||
@@ -13,7 +13,8 @@ from app.tests.utils.utils import get_superuser_token_headers
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def db() -> Generator:
|
||||
yield SessionLocal()
|
||||
with Session(engine) as session:
|
||||
yield session
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
|
@@ -5,7 +5,7 @@ from sqlalchemy.orm import Session
|
||||
|
||||
from app import crud
|
||||
from app.core.config import settings
|
||||
from app.models.user import User
|
||||
from app.models import User
|
||||
from app.schemas.user import UserCreate, UserUpdate
|
||||
from app.tests.utils.utils import random_email, random_lower_string
|
||||
|
||||
|
@@ -1,8 +1,9 @@
|
||||
import logging
|
||||
|
||||
from sqlmodel import Session, select
|
||||
from tenacity import after_log, before_log, retry, stop_after_attempt, wait_fixed
|
||||
|
||||
from app.db.session import SessionLocal
|
||||
from app.db.engine import engine
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -20,8 +21,8 @@ wait_seconds = 1
|
||||
def init() -> None:
|
||||
try:
|
||||
# Try to create session to check if DB is awake
|
||||
db = SessionLocal()
|
||||
db.execute("SELECT 1")
|
||||
with Session(engine) as session:
|
||||
session.exec(select(1))
|
||||
except Exception as e:
|
||||
logger.error(e)
|
||||
raise e
|
||||
|
Reference in New Issue
Block a user