🏷️ Add mypy to the GitHub Action for tests and fixed types in the whole project (#655)

Co-authored-by: Sebastián Ramírez <tiangolo@gmail.com>
This commit is contained in:
Esteban Maya
2024-03-10 14:47:21 -05:00
committed by GitHub
parent 6607eaded4
commit a230f4fb2c
19 changed files with 106 additions and 79 deletions

View File

@@ -17,7 +17,7 @@ reusable_oauth2 = OAuth2PasswordBearer(
) )
def get_db() -> Generator: def get_db() -> Generator[Session, None, None]:
with Session(engine) as session: with Session(engine) as session:
yield session yield session

View File

@@ -1,7 +1,7 @@
from typing import Any from typing import Any
from fastapi import APIRouter, Depends, HTTPException from fastapi import APIRouter, Depends, HTTPException
from sqlmodel import delete, func, select from sqlmodel import col, delete, func, select
from app import crud from app import crud
from app.api.deps import ( from app.api.deps import (
@@ -189,16 +189,17 @@ def delete_user(
user = session.get(User, user_id) user = session.get(User, user_id)
if not user: if not user:
raise HTTPException(status_code=404, detail="User not found") raise HTTPException(status_code=404, detail="User not found")
elif user != current_user and not current_user.is_superuser:
if (user == current_user and not current_user.is_superuser) or ( raise HTTPException(
user != current_user and current_user.is_superuser status_code=403, detail="The user doesn't have enough privileges"
): )
statement = delete(Item).where(Item.owner_id == user_id)
session.exec(statement)
session.delete(user)
session.commit()
return Message(message="User deleted successfully")
elif user == current_user and current_user.is_superuser: elif user == current_user and current_user.is_superuser:
raise HTTPException( raise HTTPException(
status_code=403, detail="Super users are not allowed to delete themselves" status_code=403, detail="Super users are not allowed to delete themselves"
) )
statement = delete(Item).where(col(Item.owner_id) == user_id)
session.exec(statement) # type: ignore
session.delete(user)
session.commit()
return Message(message="User deleted successfully")

View File

@@ -68,7 +68,7 @@ class Settings(BaseSettings):
@field_validator("EMAILS_FROM_NAME") @field_validator("EMAILS_FROM_NAME")
def get_project_name(cls, v: str | None, info: ValidationInfo) -> str: def get_project_name(cls, v: str | None, info: ValidationInfo) -> str:
if not v: if not v:
return info.data["PROJECT_NAME"] return str(info.data["PROJECT_NAME"])
return v return v
EMAIL_RESET_TOKEN_EXPIRE_HOURS: int = 48 EMAIL_RESET_TOKEN_EXPIRE_HOURS: int = 48
@@ -89,7 +89,7 @@ class Settings(BaseSettings):
FIRST_SUPERUSER: str FIRST_SUPERUSER: str
FIRST_SUPERUSER_PASSWORD: str FIRST_SUPERUSER_PASSWORD: str
USERS_OPEN_REGISTRATION: bool = False USERS_OPEN_REGISTRATION: bool = False
model_config = SettingsConfigDict(case_sensitive=True) model_config = SettingsConfigDict(env_file=".env")
settings = Settings() settings = Settings() # type: ignore

View File

@@ -12,13 +12,8 @@ pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
ALGORITHM = "HS256" ALGORITHM = "HS256"
def create_access_token(subject: str | Any, expires_delta: timedelta = None) -> str: def create_access_token(subject: str | Any, expires_delta: timedelta) -> str:
if expires_delta: expire = datetime.utcnow() + expires_delta
expire = datetime.utcnow() + expires_delta
else:
expire = datetime.utcnow() + timedelta(
minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES
)
to_encode = {"exp": expire, "sub": str(subject)} to_encode = {"exp": expire, "sub": str(subject)}
encoded_jwt = jwt.encode(to_encode, settings.SECRET_KEY, algorithm=ALGORITHM) encoded_jwt = jwt.encode(to_encode, settings.SECRET_KEY, algorithm=ALGORITHM)
return encoded_jwt return encoded_jwt

View File

@@ -6,7 +6,7 @@ from app.api.main import api_router
from app.core.config import settings from app.core.config import settings
def custom_generate_unique_id(route: APIRoute): def custom_generate_unique_id(route: APIRoute) -> str:
return f"{route.tags[0]}-{route.name}" return f"{route.tags[0]}-{route.name}"

View File

@@ -25,7 +25,7 @@ class UserCreateOpen(SQLModel):
# Properties to receive via API on update, all are optional # Properties to receive via API on update, all are optional
# TODO replace email str with EmailStr when sqlmodel supports it # TODO replace email str with EmailStr when sqlmodel supports it
class UserUpdate(UserBase): class UserUpdate(UserBase):
email: str | None = None email: str | None = None # type: ignore
password: str | None = None password: str | None = None
@@ -70,7 +70,7 @@ class ItemCreate(ItemBase):
# Properties to receive on item update # Properties to receive on item update
class ItemUpdate(ItemBase): class ItemUpdate(ItemBase):
title: str | None = None title: str | None = None # type: ignore
# Database model, database table inferred from class name # Database model, database table inferred from class name

View File

@@ -6,7 +6,7 @@ from app.tests.utils.item import create_random_item
def test_create_item( def test_create_item(
client: TestClient, superuser_token_headers: dict, db: Session client: TestClient, superuser_token_headers: dict[str, str], db: Session
) -> None: ) -> None:
data = {"title": "Foo", "description": "Fighters"} data = {"title": "Foo", "description": "Fighters"}
response = client.post( response = client.post(
@@ -23,7 +23,7 @@ def test_create_item(
def test_read_item( def test_read_item(
client: TestClient, superuser_token_headers: dict, db: Session client: TestClient, superuser_token_headers: dict[str, str], db: Session
) -> None: ) -> None:
item = create_random_item(db) item = create_random_item(db)
response = client.get( response = client.get(
@@ -39,7 +39,7 @@ def test_read_item(
def test_read_item_not_found( def test_read_item_not_found(
client: TestClient, superuser_token_headers: dict, db: Session client: TestClient, superuser_token_headers: dict[str, str], db: Session
) -> None: ) -> None:
response = client.get( response = client.get(
f"{settings.API_V1_STR}/items/999", f"{settings.API_V1_STR}/items/999",
@@ -51,7 +51,7 @@ def test_read_item_not_found(
def test_read_item_not_enough_permissions( def test_read_item_not_enough_permissions(
client: TestClient, normal_user_token_headers: dict, db: Session client: TestClient, normal_user_token_headers: dict[str, str], db: Session
) -> None: ) -> None:
item = create_random_item(db) item = create_random_item(db)
response = client.get( response = client.get(
@@ -64,7 +64,7 @@ def test_read_item_not_enough_permissions(
def test_read_items( def test_read_items(
client: TestClient, superuser_token_headers: dict, db: Session client: TestClient, superuser_token_headers: dict[str, str], db: Session
) -> None: ) -> None:
create_random_item(db) create_random_item(db)
create_random_item(db) create_random_item(db)
@@ -78,7 +78,7 @@ def test_read_items(
def test_update_item( def test_update_item(
client: TestClient, superuser_token_headers: dict, db: Session client: TestClient, superuser_token_headers: dict[str, str], db: Session
) -> None: ) -> None:
item = create_random_item(db) item = create_random_item(db)
data = {"title": "Updated title", "description": "Updated description"} data = {"title": "Updated title", "description": "Updated description"}
@@ -96,7 +96,7 @@ def test_update_item(
def test_update_item_not_found( def test_update_item_not_found(
client: TestClient, superuser_token_headers: dict, db: Session client: TestClient, superuser_token_headers: dict[str, str], db: Session
) -> None: ) -> None:
data = {"title": "Updated title", "description": "Updated description"} data = {"title": "Updated title", "description": "Updated description"}
response = client.put( response = client.put(
@@ -110,7 +110,7 @@ def test_update_item_not_found(
def test_update_item_not_enough_permissions( def test_update_item_not_enough_permissions(
client: TestClient, normal_user_token_headers: dict, db: Session client: TestClient, normal_user_token_headers: dict[str, str], db: Session
) -> None: ) -> None:
item = create_random_item(db) item = create_random_item(db)
data = {"title": "Updated title", "description": "Updated description"} data = {"title": "Updated title", "description": "Updated description"}
@@ -125,7 +125,7 @@ def test_update_item_not_enough_permissions(
def test_delete_item( def test_delete_item(
client: TestClient, superuser_token_headers: dict, db: Session client: TestClient, superuser_token_headers: dict[str, str], db: Session
) -> None: ) -> None:
item = create_random_item(db) item = create_random_item(db)
response = client.delete( response = client.delete(
@@ -138,7 +138,7 @@ def test_delete_item(
def test_delete_item_not_found( def test_delete_item_not_found(
client: TestClient, superuser_token_headers: dict, db: Session client: TestClient, superuser_token_headers: dict[str, str], db: Session
) -> None: ) -> None:
response = client.delete( response = client.delete(
f"{settings.API_V1_STR}/items/999", f"{settings.API_V1_STR}/items/999",
@@ -150,7 +150,7 @@ def test_delete_item_not_found(
def test_delete_item_not_enough_permissions( def test_delete_item_not_enough_permissions(
client: TestClient, normal_user_token_headers: dict, db: Session client: TestClient, normal_user_token_headers: dict[str, str], db: Session
) -> None: ) -> None:
item = create_random_item(db) item = create_random_item(db)
response = client.delete( response = client.delete(

View File

@@ -1,7 +1,8 @@
from app.utils import generate_password_reset_token
from fastapi.testclient import TestClient from fastapi.testclient import TestClient
from pytest_mock import MockerFixture
from app.core.config import settings from app.core.config import settings
from app.utils import generate_password_reset_token
def test_get_access_token(client: TestClient) -> None: def test_get_access_token(client: TestClient) -> None:
@@ -38,7 +39,7 @@ def test_use_access_token(
def test_recovery_password( def test_recovery_password(
client: TestClient, normal_user_token_headers: dict[str, str], mocker client: TestClient, normal_user_token_headers: dict[str, str], mocker: MockerFixture
) -> None: ) -> None:
mocker.patch("app.utils.send_email", return_value=None) mocker.patch("app.utils.send_email", return_value=None)
mocker.patch("app.core.config.settings.EMAILS_ENABLED", True) mocker.patch("app.core.config.settings.EMAILS_ENABLED", True)

View File

@@ -1,4 +1,5 @@
from fastapi.testclient import TestClient from fastapi.testclient import TestClient
from pytest_mock import MockerFixture
from sqlmodel import Session from sqlmodel import Session
from app import crud from app import crud
@@ -30,7 +31,10 @@ def test_get_users_normal_user_me(
def test_create_user_new_email( def test_create_user_new_email(
client: TestClient, superuser_token_headers: dict, db: Session, mocker client: TestClient,
superuser_token_headers: dict[str, str],
db: Session,
mocker: MockerFixture,
) -> None: ) -> None:
mocker.patch("app.utils.send_email") mocker.patch("app.utils.send_email")
mocker.patch("app.core.config.settings.EMAILS_ENABLED", True) mocker.patch("app.core.config.settings.EMAILS_ENABLED", True)
@@ -50,7 +54,7 @@ def test_create_user_new_email(
def test_get_existing_user( def test_get_existing_user(
client: TestClient, superuser_token_headers: dict, db: Session client: TestClient, superuser_token_headers: dict[str, str], db: Session
) -> None: ) -> None:
username = random_email() username = random_email()
password = random_lower_string() password = random_lower_string()
@@ -107,7 +111,7 @@ def test_get_existing_user_permissions_error(
def test_create_user_existing_username( def test_create_user_existing_username(
client: TestClient, superuser_token_headers: dict, db: Session client: TestClient, superuser_token_headers: dict[str, str], db: Session
) -> None: ) -> None:
username = random_email() username = random_email()
# username = email # username = email
@@ -140,7 +144,7 @@ def test_create_user_by_normal_user(
def test_retrieve_users( def test_retrieve_users(
client: TestClient, superuser_token_headers: dict, db: Session client: TestClient, superuser_token_headers: dict[str, str], db: Session
) -> None: ) -> None:
username = random_email() username = random_email()
password = random_lower_string() password = random_lower_string()
@@ -179,7 +183,7 @@ def test_update_user_me(
def test_update_password_me( def test_update_password_me(
client: TestClient, superuser_token_headers: dict, db: Session client: TestClient, superuser_token_headers: dict[str, str], db: Session
) -> None: ) -> None:
new_password = random_lower_string() new_password = random_lower_string()
data = { data = {
@@ -209,7 +213,7 @@ def test_update_password_me(
def test_update_password_me_incorrect_password( def test_update_password_me_incorrect_password(
client: TestClient, superuser_token_headers: dict, db: Session client: TestClient, superuser_token_headers: dict[str, str], db: Session
) -> None: ) -> None:
new_password = random_lower_string() new_password = random_lower_string()
data = {"current_password": new_password, "new_password": new_password} data = {"current_password": new_password, "new_password": new_password}
@@ -224,7 +228,7 @@ def test_update_password_me_incorrect_password(
def test_update_password_me_same_password_error( def test_update_password_me_same_password_error(
client: TestClient, superuser_token_headers: dict, db: Session client: TestClient, superuser_token_headers: dict[str, str], db: Session
) -> None: ) -> None:
data = { data = {
"current_password": settings.FIRST_SUPERUSER_PASSWORD, "current_password": settings.FIRST_SUPERUSER_PASSWORD,
@@ -242,7 +246,7 @@ def test_update_password_me_same_password_error(
) )
def test_create_user_open(client: TestClient, mocker) -> None: def test_create_user_open(client: TestClient, mocker: MockerFixture) -> None:
mocker.patch("app.core.config.settings.USERS_OPEN_REGISTRATION", True) mocker.patch("app.core.config.settings.USERS_OPEN_REGISTRATION", True)
username = random_email() username = random_email()
password = random_lower_string() password = random_lower_string()
@@ -258,7 +262,9 @@ def test_create_user_open(client: TestClient, mocker) -> None:
assert created_user["full_name"] == full_name assert created_user["full_name"] == full_name
def test_create_user_open_forbidden_error(client: TestClient, mocker) -> None: def test_create_user_open_forbidden_error(
client: TestClient, mocker: MockerFixture
) -> None:
mocker.patch("app.core.config.settings.USERS_OPEN_REGISTRATION", False) mocker.patch("app.core.config.settings.USERS_OPEN_REGISTRATION", False)
username = random_email() username = random_email()
password = random_lower_string() password = random_lower_string()
@@ -272,7 +278,9 @@ def test_create_user_open_forbidden_error(client: TestClient, mocker) -> None:
assert r.json()["detail"] == "Open user registration is forbidden on this server" assert r.json()["detail"] == "Open user registration is forbidden on this server"
def test_create_user_open_already_exists_error(client: TestClient, mocker) -> None: def test_create_user_open_already_exists_error(
client: TestClient, mocker: MockerFixture
) -> None:
mocker.patch("app.core.config.settings.USERS_OPEN_REGISTRATION", True) mocker.patch("app.core.config.settings.USERS_OPEN_REGISTRATION", True)
password = random_lower_string() password = random_lower_string()
full_name = random_lower_string() full_name = random_lower_string()
@@ -382,6 +390,7 @@ def test_delete_user_current_super_user_error(
client: TestClient, superuser_token_headers: dict[str, str], db: Session client: TestClient, superuser_token_headers: dict[str, str], db: Session
) -> None: ) -> None:
super_user = crud.get_user_by_email(session=db, email=settings.FIRST_SUPERUSER) super_user = crud.get_user_by_email(session=db, email=settings.FIRST_SUPERUSER)
assert super_user
user_id = super_user.id user_id = super_user.id
r = client.delete( r = client.delete(

View File

@@ -13,7 +13,7 @@ from app.tests.utils.utils import get_superuser_token_headers
@pytest.fixture(scope="session", autouse=True) @pytest.fixture(scope="session", autouse=True)
def db() -> Generator: def db() -> Generator[Session, None, None]:
with Session(engine) as session: with Session(engine) as session:
init_db(session) init_db(session)
yield session yield session
@@ -25,7 +25,7 @@ def db() -> Generator:
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
def client() -> Generator: def client() -> Generator[TestClient, None, None]:
with TestClient(app) as c: with TestClient(app) as c:
yield c yield c

View File

@@ -1,21 +1,22 @@
from unittest.mock import MagicMock from unittest.mock import MagicMock
from pytest_mock import MockerFixture
from sqlmodel import select from sqlmodel import select
from app.backend_pre_start import init, logger from app.backend_pre_start import init, logger
def test_init_successful_connection(mocker): def test_init_successful_connection(mocker: MockerFixture) -> None:
engine_mock = MagicMock() engine_mock = MagicMock()
session_mock = MagicMock() session_mock = MagicMock()
exec_mock = MagicMock(return_value=True) exec_mock = MagicMock(return_value=True)
session_mock.configure_mock(**{'exec.return_value': exec_mock}) session_mock.configure_mock(**{"exec.return_value": exec_mock})
mocker.patch('sqlmodel.Session', return_value=session_mock) mocker.patch("sqlmodel.Session", return_value=session_mock)
mocker.patch.object(logger, 'info') mocker.patch.object(logger, "info")
mocker.patch.object(logger, 'error') mocker.patch.object(logger, "error")
mocker.patch.object(logger, 'warn') mocker.patch.object(logger, "warn")
try: try:
init(engine_mock) init(engine_mock)
@@ -23,6 +24,10 @@ def test_init_successful_connection(mocker):
except Exception: except Exception:
connection_successful = False connection_successful = False
assert connection_successful, "The database connection should be successful and not raise an exception." assert (
connection_successful
), "The database connection should be successful and not raise an exception."
assert session_mock.exec.called_once_with(select(1)), "The session should execute a select statement once." assert session_mock.exec.called_once_with(
select(1)
), "The session should execute a select statement once."

View File

@@ -1,21 +1,22 @@
from unittest.mock import MagicMock from unittest.mock import MagicMock
from pytest_mock import MockerFixture
from sqlmodel import select from sqlmodel import select
from app.celeryworker_pre_start import init, logger from app.celeryworker_pre_start import init, logger
def test_init_successful_connection(mocker): def test_init_successful_connection(mocker: MockerFixture) -> None:
engine_mock = MagicMock() engine_mock = MagicMock()
session_mock = MagicMock() session_mock = MagicMock()
exec_mock = MagicMock(return_value=True) exec_mock = MagicMock(return_value=True)
session_mock.configure_mock(**{'exec.return_value': exec_mock}) session_mock.configure_mock(**{"exec.return_value": exec_mock})
mocker.patch('sqlmodel.Session', return_value=session_mock) mocker.patch("sqlmodel.Session", return_value=session_mock)
mocker.patch.object(logger, 'info') mocker.patch.object(logger, "info")
mocker.patch.object(logger, 'error') mocker.patch.object(logger, "error")
mocker.patch.object(logger, 'warn') mocker.patch.object(logger, "warn")
try: try:
init(engine_mock) init(engine_mock)
@@ -23,6 +24,10 @@ def test_init_successful_connection(mocker):
except Exception: except Exception:
connection_successful = False connection_successful = False
assert connection_successful, "The database connection should be successful and not raise an exception." assert (
connection_successful
), "The database connection should be successful and not raise an exception."
assert session_mock.exec.called_once_with(select(1)), "The session should execute a select statement once." assert session_mock.exec.called_once_with(
select(1)
), "The session should execute a select statement once."

View File

@@ -1,21 +1,22 @@
from unittest.mock import MagicMock from unittest.mock import MagicMock
from pytest_mock import MockerFixture
from sqlmodel import select from sqlmodel import select
from app.tests_pre_start import init, logger from app.tests_pre_start import init, logger
def test_init_successful_connection(mocker): def test_init_successful_connection(mocker: MockerFixture) -> None:
engine_mock = MagicMock() engine_mock = MagicMock()
session_mock = MagicMock() session_mock = MagicMock()
exec_mock = MagicMock(return_value=True) exec_mock = MagicMock(return_value=True)
session_mock.configure_mock(**{'exec.return_value': exec_mock}) session_mock.configure_mock(**{"exec.return_value": exec_mock})
mocker.patch('sqlmodel.Session', return_value=session_mock) mocker.patch("sqlmodel.Session", return_value=session_mock)
mocker.patch.object(logger, 'info') mocker.patch.object(logger, "info")
mocker.patch.object(logger, 'error') mocker.patch.object(logger, "error")
mocker.patch.object(logger, 'warn') mocker.patch.object(logger, "warn")
try: try:
init(engine_mock) init(engine_mock)
@@ -23,6 +24,10 @@ def test_init_successful_connection(mocker):
except Exception: except Exception:
connection_successful = False connection_successful = False
assert connection_successful, "The database connection should be successful and not raise an exception." assert (
connection_successful
), "The database connection should be successful and not raise an exception."
assert session_mock.exec.called_once_with(select(1)), "The session should execute a select statement once." assert session_mock.exec.called_once_with(
select(1)
), "The session should execute a select statement once."

View File

@@ -42,6 +42,8 @@ def authentication_token_from_email(
user = crud.create_user(session=db, user_create=user_in_create) user = crud.create_user(session=db, user_create=user_in_create)
else: else:
user_in_update = UserUpdate(password=password) user_in_update = UserUpdate(password=password)
if not user.id:
raise Exception("User id not set")
user = crud.update_user(session=db, user_id=user.id, user_in=user_in_update) user = crud.update_user(session=db, user_id=user.id, user_in=user_in_update)
return user_authentication_headers(client=client, email=email, password=password) return user_authentication_headers(client=client, email=email, password=password)

View File

@@ -4,7 +4,7 @@ from datetime import datetime, timedelta
from pathlib import Path from pathlib import Path
from typing import Any from typing import Any
import emails import emails # type: ignore
from jinja2 import Template from jinja2 import Template
from jose import JWTError, jwt from jose import JWTError, jwt
@@ -109,6 +109,6 @@ def generate_password_reset_token(email: str) -> str:
def verify_password_reset_token(token: str) -> str | None: def verify_password_reset_token(token: str) -> str | None:
try: try:
decoded_token = jwt.decode(token, settings.SECRET_KEY, algorithms=["HS256"]) decoded_token = jwt.decode(token, settings.SECRET_KEY, algorithms=["HS256"])
return decoded_token["sub"] return str(decoded_token["sub"])
except JWTError: except JWTError:
return None return None

View File

@@ -3,7 +3,7 @@ import sentry_sdk
from app.core.celery_app import celery_app from app.core.celery_app import celery_app
from app.core.config import settings from app.core.config import settings
sentry_sdk.init(dsn=settings.SENTRY_DSN) sentry_sdk.init(dsn=str(settings.SENTRY_DSN))
@celery_app.task(acks_late=True) @celery_app.task(acks_late=True)

View File

@@ -35,6 +35,9 @@ mypy = "^1.8.0"
ruff = "^0.2.2" ruff = "^0.2.2"
pre-commit = "^3.6.2" pre-commit = "^3.6.2"
pytest-mock = "^3.12.0" pytest-mock = "^3.12.0"
types-python-jose = "^3.3.4.20240106"
types-passlib = "^1.7.7.20240106"
celery-types = "^0.22.0"
[tool.isort] [tool.isort]
multi_line_output = 3 multi_line_output = 3
@@ -47,6 +50,7 @@ build-backend = "poetry.masonry.api"
[tool.mypy] [tool.mypy]
strict = true strict = true
exclude = ["venv", "alembic"]
[tool.ruff] [tool.ruff]
target-version = "py310" target-version = "py310"

View File

@@ -3,6 +3,5 @@
set -x set -x
mypy app mypy app
black app --check ruff app
isort --recursive --check-only app ruff format app --check
flake8

View File

@@ -3,4 +3,5 @@ set -e
python /app/app/tests_pre_start.py python /app/app/tests_pre_start.py
bash ./scripts/lint.sh
bash ./scripts/test.sh "$@" bash ./scripts/test.sh "$@"