Adopt SQLModel, create models, start using it (#559)

* 🔥 Remove old SQLAlchemy models

*  Add new SQLModel models

* 🔧 Update Alembic configs to work with SQLModel

*  Re-generate initial Alembic migration

* 🔧 Update PostgreSQL driver connection string URL

*  Create new SQLModel engine

* 🔥 Remove old unneeded SQLAlchemy-specific files

* ♻️ Update init_db

* ♻️ Use new SQLModel session

* ♻️ Update conftest with new DB Session

* ♻️ Update pre-start scripts to use SQLModel session

* ♻️ Import new SQLModel models

*  Create new simplified create_user crud util

* ♻️ Update import in CRUDBase class (soon to be removed)

* 🙈 Update .gitignore with Python files
This commit is contained in:
Sebastián Ramírez
2023-11-25 00:08:22 +01:00
committed by GitHub
parent 2d92cd70a4
commit a66a9256dd
26 changed files with 193 additions and 163 deletions

View File

@@ -1,2 +1,3 @@
__pycache__ __pycache__
app.egg-info app.egg-info
*.pyc

View File

@@ -20,9 +20,9 @@ fileConfig(config.config_file_name)
# target_metadata = mymodel.Base.metadata # target_metadata = mymodel.Base.metadata
# target_metadata = None # target_metadata = None
from app.db.base import Base # noqa from app.models import SQLModel # noqa
target_metadata = Base.metadata target_metadata = SQLModel.metadata
# other values from the config, defined by the needs of env.py, # other values from the config, defined by the needs of env.py,
# can be acquired: # can be acquired:
@@ -35,7 +35,7 @@ def get_url():
password = os.getenv("POSTGRES_PASSWORD", "") password = os.getenv("POSTGRES_PASSWORD", "")
server = os.getenv("POSTGRES_SERVER", "db") server = os.getenv("POSTGRES_SERVER", "db")
db = os.getenv("POSTGRES_DB", "app") db = os.getenv("POSTGRES_DB", "app")
return f"postgresql://{user}:{password}@{server}/{db}" return f"postgresql+psycopg://{user}:{password}@{server}/{db}"
def run_migrations_offline(): def run_migrations_offline():

View File

@@ -7,6 +7,7 @@ Create Date: ${create_date}
""" """
from alembic import op from alembic import op
import sqlalchemy as sa import sqlalchemy as sa
import sqlmodel.sql.sqltypes
${imports if imports else ""} ${imports if imports else ""}
# revision identifiers, used by Alembic. # revision identifiers, used by Alembic.

View File

@@ -1,59 +0,0 @@
"""First revision
Revision ID: d4867f3a4c0a
Revises:
Create Date: 2019-04-17 13:53:32.978401
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "d4867f3a4c0a"
down_revision = None
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.create_table(
"user",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("full_name", sa.String(), nullable=True),
sa.Column("email", sa.String(), nullable=True),
sa.Column("hashed_password", sa.String(), nullable=True),
sa.Column("is_active", sa.Boolean(), nullable=True),
sa.Column("is_superuser", sa.Boolean(), nullable=True),
sa.PrimaryKeyConstraint("id"),
)
op.create_index(op.f("ix_user_email"), "user", ["email"], unique=True)
op.create_index(op.f("ix_user_full_name"), "user", ["full_name"], unique=False)
op.create_index(op.f("ix_user_id"), "user", ["id"], unique=False)
op.create_table(
"item",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("title", sa.String(), nullable=True),
sa.Column("description", sa.String(), nullable=True),
sa.Column("owner_id", sa.Integer(), nullable=True),
sa.ForeignKeyConstraint(["owner_id"], ["user.id"],),
sa.PrimaryKeyConstraint("id"),
)
op.create_index(op.f("ix_item_description"), "item", ["description"], unique=False)
op.create_index(op.f("ix_item_id"), "item", ["id"], unique=False)
op.create_index(op.f("ix_item_title"), "item", ["title"], unique=False)
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.drop_index(op.f("ix_item_title"), table_name="item")
op.drop_index(op.f("ix_item_id"), table_name="item")
op.drop_index(op.f("ix_item_description"), table_name="item")
op.drop_table("item")
op.drop_index(op.f("ix_user_id"), table_name="user")
op.drop_index(op.f("ix_user_full_name"), table_name="user")
op.drop_index(op.f("ix_user_email"), table_name="user")
op.drop_table("user")
# ### end Alembic commands ###

View File

@@ -0,0 +1,48 @@
"""Initialize models
Revision ID: e2412789c190
Revises:
Create Date: 2023-11-24 22:55:43.195942
"""
from alembic import op
import sqlalchemy as sa
import sqlmodel.sql.sqltypes
# revision identifiers, used by Alembic.
revision = 'e2412789c190'
down_revision = None
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.create_table('user',
sa.Column('email', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
sa.Column('is_active', sa.Boolean(), nullable=False),
sa.Column('is_superuser', sa.Boolean(), nullable=False),
sa.Column('full_name', sqlmodel.sql.sqltypes.AutoString(), nullable=True),
sa.Column('id', sa.Integer(), nullable=False),
sa.Column('hashed_password', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
sa.PrimaryKeyConstraint('id')
)
op.create_index(op.f('ix_user_email'), 'user', ['email'], unique=True)
op.create_table('item',
sa.Column('description', sqlmodel.sql.sqltypes.AutoString(), nullable=True),
sa.Column('id', sa.Integer(), nullable=False),
sa.Column('title', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
sa.Column('owner_id', sa.Integer(), nullable=False),
sa.ForeignKeyConstraint(['owner_id'], ['user.id'], ),
sa.PrimaryKeyConstraint('id')
)
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.drop_table('item')
op.drop_index(op.f('ix_user_email'), table_name='user')
op.drop_table('user')
# ### end Alembic commands ###

View File

@@ -9,7 +9,7 @@ from sqlalchemy.orm import Session
from app import crud, models, schemas from app import crud, models, schemas
from app.core import security from app.core import security
from app.core.config import settings from app.core.config import settings
from app.db.session import SessionLocal from app.db.engine import engine
reusable_oauth2 = OAuth2PasswordBearer( reusable_oauth2 = OAuth2PasswordBearer(
tokenUrl=f"{settings.API_V1_STR}/login/access-token" tokenUrl=f"{settings.API_V1_STR}/login/access-token"
@@ -17,11 +17,8 @@ reusable_oauth2 = OAuth2PasswordBearer(
def get_db() -> Generator: def get_db() -> Generator:
try: with Session(engine) as session:
db = SessionLocal() yield session
yield db
finally:
db.close()
def get_current_user( def get_current_user(

View File

@@ -1,8 +1,9 @@
import logging import logging
from sqlmodel import Session, select
from tenacity import after_log, before_log, retry, stop_after_attempt, wait_fixed from tenacity import after_log, before_log, retry, stop_after_attempt, wait_fixed
from app.db.session import SessionLocal from app.db.engine import engine
logging.basicConfig(level=logging.INFO) logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -19,9 +20,9 @@ wait_seconds = 1
) )
def init() -> None: def init() -> None:
try: try:
db = SessionLocal() with Session(engine) as session:
# Try to create session to check if DB is awake # Try to create session to check if DB is awake
db.execute("SELECT 1") session.exec(select(1))
except Exception as e: except Exception as e:
logger.error(e) logger.error(e)
raise e raise e

View File

@@ -1,8 +1,9 @@
import logging import logging
from sqlmodel import Session, select
from tenacity import after_log, before_log, retry, stop_after_attempt, wait_fixed from tenacity import after_log, before_log, retry, stop_after_attempt, wait_fixed
from app.db.session import SessionLocal from app.db.engine import engine
logging.basicConfig(level=logging.INFO) logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -20,8 +21,8 @@ wait_seconds = 1
def init() -> None: def init() -> None:
try: try:
# Try to create session to check if DB is awake # Try to create session to check if DB is awake
db = SessionLocal() with Session(engine) as session:
db.execute("SELECT 1") session.exec(select(1))
except Exception as e: except Exception as e:
logger.error(e) logger.error(e)
raise e raise e

View File

@@ -44,7 +44,7 @@ class Settings(BaseSettings):
if isinstance(v, str): if isinstance(v, str):
return v return v
return PostgresDsn.build( return PostgresDsn.build(
scheme="postgresql", scheme="postgresql+psycopg",
user=values.get("POSTGRES_USER"), user=values.get("POSTGRES_USER"),
password=values.get("POSTGRES_PASSWORD"), password=values.get("POSTGRES_PASSWORD"),
host=values.get("POSTGRES_SERVER"), host=values.get("POSTGRES_SERVER"),

View File

@@ -8,3 +8,16 @@ from .crud_user import user
# from app.schemas.item import ItemCreate, ItemUpdate # from app.schemas.item import ItemCreate, ItemUpdate
# item = CRUDBase[Item, ItemCreate, ItemUpdate](Item) # item = CRUDBase[Item, ItemCreate, ItemUpdate](Item)
from sqlmodel import Session
from app.core.security import get_password_hash
from app.models import UserCreate, User
def create_user(session: Session, *, user_create: UserCreate) -> User:
db_obj = User.from_orm(
user_create, update={"hashed_password": get_password_hash(user_create.password)}
)
session.add(db_obj)
session.commit()
session.refresh(db_obj)
return db_obj

View File

@@ -4,9 +4,7 @@ from fastapi.encoders import jsonable_encoder
from pydantic import BaseModel from pydantic import BaseModel
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from app.db.base_class import Base ModelType = TypeVar("ModelType", bound=Any)
ModelType = TypeVar("ModelType", bound=Base)
CreateSchemaType = TypeVar("CreateSchemaType", bound=BaseModel) CreateSchemaType = TypeVar("CreateSchemaType", bound=BaseModel)
UpdateSchemaType = TypeVar("UpdateSchemaType", bound=BaseModel) UpdateSchemaType = TypeVar("UpdateSchemaType", bound=BaseModel)

View File

@@ -4,7 +4,7 @@ from fastapi.encoders import jsonable_encoder
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from app.crud.base import CRUDBase from app.crud.base import CRUDBase
from app.models.item import Item from app.models import Item
from app.schemas.item import ItemCreate, ItemUpdate from app.schemas.item import ItemCreate, ItemUpdate

View File

@@ -4,7 +4,7 @@ from sqlalchemy.orm import Session
from app.core.security import get_password_hash, verify_password from app.core.security import get_password_hash, verify_password
from app.crud.base import CRUDBase from app.crud.base import CRUDBase
from app.models.user import User from app.models import User
from app.schemas.user import UserCreate, UserUpdate from app.schemas.user import UserCreate, UserUpdate

View File

@@ -1,5 +0,0 @@
# Import all the models, so that Base has them before being
# imported by Alembic
from app.db.base_class import Base # noqa
from app.models.item import Item # noqa
from app.models.user import User # noqa

View File

@@ -1,13 +0,0 @@
from typing import Any
from sqlalchemy.ext.declarative import as_declarative, declared_attr
@as_declarative()
class Base:
id: Any
__name__: str
# Generate __tablename__ automatically
@declared_attr
def __tablename__(cls) -> str:
return cls.__name__.lower()

View File

@@ -0,0 +1,5 @@
from sqlmodel import create_engine
from app.core.config import settings
engine = create_engine(settings.SQLALCHEMY_DATABASE_URI)

View File

@@ -1,25 +1,26 @@
from sqlalchemy.orm import Session from sqlmodel import Session, select
from app import crud, schemas from app import crud
from app.core.config import settings from app.core.config import settings
from app.db import base # noqa: F401 from app.models import User, UserCreate # noqa: F401
# make sure all SQL Alchemy models are imported (app.db.base) before initializing DB # make sure all SQLModel models are imported (app.models) before initializing DB
# otherwise, SQL Alchemy might fail to initialize relationships properly # otherwise, SQLModel might fail to initialize relationships properly
# for more details: https://github.com/tiangolo/full-stack-fastapi-postgresql/issues/28 # for more details: https://github.com/tiangolo/full-stack-fastapi-postgresql/issues/28
def init_db(db: Session) -> None: def init_db(session: Session) -> None:
# Tables should be created with Alembic migrations # Tables should be created with Alembic migrations
# But if you don't want to use migrations, create # But if you don't want to use migrations, create
# the tables un-commenting the next line # the tables un-commenting the next line
# Base.metadata.create_all(bind=engine) # Base.metadata.create_all(bind=engine)
user = session.exec(
user = crud.user.get_by_email(db, email=settings.FIRST_SUPERUSER) select(User).where(User.email == settings.FIRST_SUPERUSER)
).first()
if not user: if not user:
user_in = schemas.UserCreate( user_in = UserCreate(
email=settings.FIRST_SUPERUSER, email=settings.FIRST_SUPERUSER,
password=settings.FIRST_SUPERUSER_PASSWORD, password=settings.FIRST_SUPERUSER_PASSWORD,
is_superuser=True, is_superuser=True,
) )
user = crud.user.create(db, obj_in=user_in) # noqa: F841 user = crud.create_user(session, user_create=user_in)

View File

@@ -1,7 +0,0 @@
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from app.core.config import settings
engine = create_engine(settings.SQLALCHEMY_DATABASE_URI, pool_pre_ping=True)
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)

View File

@@ -1,15 +1,17 @@
import logging import logging
from sqlmodel import Session
from app.db.engine import engine
from app.db.init_db import init_db from app.db.init_db import init_db
from app.db.session import SessionLocal
logging.basicConfig(level=logging.INFO) logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def init() -> None: def init() -> None:
db = SessionLocal() with Session(engine) as session:
init_db(db) init_db(session)
def main() -> None: def main() -> None:

View File

@@ -0,0 +1,82 @@
from typing import Union
from pydantic import BaseModel, EmailStr
from sqlmodel import Field, Relationship, SQLModel
# Shared properties
class UserBase(SQLModel):
email: EmailStr = Field(unique=True, index=True)
is_active: bool = True
is_superuser: bool = False
full_name: Union[str, None] = None
# Properties to receive via API on creation
class UserCreate(UserBase):
password: str
# Properties to receive via API on update, all are optional
class UserUpdate(UserBase):
email: Union[EmailStr, None] = None
password: Union[str, None] = None
# Database model, database table inferred from class name
class User(UserBase, table=True):
id: Union[int, None] = Field(default=None, primary_key=True)
hashed_password: str
items: list["Item"] = Relationship(back_populates="owner")
# Properties to return via API, id is always required
class UserOut(UserBase):
id: int
# Shared properties
class ItemBase(SQLModel):
title: str
description: Union[str, None] = None
# Properties to receive on item creation
class ItemCreate(ItemBase):
title: str
# Properties to receive on item update
class ItemUpdate(ItemBase):
title: Union[str, None] = None
# Database model, database table inferred from class name
class Item(ItemBase, table=True):
id: Union[int, None] = Field(default=None, primary_key=True)
title: str
owner_id: Union[int, None] = Field(
default=None, foreign_key="user.id", nullable=False
)
owner: Union[User, None] = Relationship(back_populates="items")
# Properties to return via API, id is always required
class ItemOut(ItemBase):
id: int
# Generic message
class Msg(BaseModel):
msg: str
# JSON payload containing access token
class Token(BaseModel):
access_token: str
token_type: str
# Contents of JWT token
class TokenPayload(BaseModel):
sub: Union[int, None] = None

View File

@@ -1,2 +0,0 @@
from .item import Item
from .user import User

View File

@@ -1,17 +0,0 @@
from typing import TYPE_CHECKING
from sqlalchemy import Column, ForeignKey, Integer, String
from sqlalchemy.orm import relationship
from app.db.base_class import Base
if TYPE_CHECKING:
from .user import User # noqa: F401
class Item(Base):
id = Column(Integer, primary_key=True, index=True)
title = Column(String, index=True)
description = Column(String, index=True)
owner_id = Column(Integer, ForeignKey("user.id"))
owner = relationship("User", back_populates="items")

View File

@@ -1,19 +0,0 @@
from typing import TYPE_CHECKING
from sqlalchemy import Boolean, Column, Integer, String
from sqlalchemy.orm import relationship
from app.db.base_class import Base
if TYPE_CHECKING:
from .item import Item # noqa: F401
class User(Base):
id = Column(Integer, primary_key=True, index=True)
full_name = Column(String, index=True)
email = Column(String, unique=True, index=True, nullable=False)
hashed_password = Column(String, nullable=False)
is_active = Column(Boolean(), default=True)
is_superuser = Column(Boolean(), default=False)
items = relationship("Item", back_populates="owner")

View File

@@ -5,7 +5,7 @@ from fastapi.testclient import TestClient
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from app.core.config import settings from app.core.config import settings
from app.db.session import SessionLocal from app.db.engine import engine
from app.main import app from app.main import app
from app.tests.utils.user import authentication_token_from_email from app.tests.utils.user import authentication_token_from_email
from app.tests.utils.utils import get_superuser_token_headers from app.tests.utils.utils import get_superuser_token_headers
@@ -13,7 +13,8 @@ from app.tests.utils.utils import get_superuser_token_headers
@pytest.fixture(scope="session") @pytest.fixture(scope="session")
def db() -> Generator: def db() -> Generator:
yield SessionLocal() with Session(engine) as session:
yield session
@pytest.fixture(scope="module") @pytest.fixture(scope="module")

View File

@@ -5,7 +5,7 @@ from sqlalchemy.orm import Session
from app import crud from app import crud
from app.core.config import settings from app.core.config import settings
from app.models.user import User from app.models import User
from app.schemas.user import UserCreate, UserUpdate from app.schemas.user import UserCreate, UserUpdate
from app.tests.utils.utils import random_email, random_lower_string from app.tests.utils.utils import random_email, random_lower_string

View File

@@ -1,8 +1,9 @@
import logging import logging
from sqlmodel import Session, select
from tenacity import after_log, before_log, retry, stop_after_attempt, wait_fixed from tenacity import after_log, before_log, retry, stop_after_attempt, wait_fixed
from app.db.session import SessionLocal from app.db.engine import engine
logging.basicConfig(level=logging.INFO) logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -20,8 +21,8 @@ wait_seconds = 1
def init() -> None: def init() -> None:
try: try:
# Try to create session to check if DB is awake # Try to create session to check if DB is awake
db = SessionLocal() with Session(engine) as session:
db.execute("SELECT 1") session.exec(select(1))
except Exception as e: except Exception as e:
logger.error(e) logger.error(e)
raise e raise e