🏷️ Add mypy to the GitHub Action for tests and fixed types in the whole project (#655)

Co-authored-by: Sebastián Ramírez <tiangolo@gmail.com>
This commit is contained in:
Esteban Maya
2024-03-10 14:47:21 -05:00
committed by GitHub
parent 6607eaded4
commit a230f4fb2c
19 changed files with 106 additions and 79 deletions

View File

@@ -17,7 +17,7 @@ reusable_oauth2 = OAuth2PasswordBearer(
)
def get_db() -> Generator:
def get_db() -> Generator[Session, None, None]:
with Session(engine) as session:
yield session

View File

@@ -1,7 +1,7 @@
from typing import Any
from fastapi import APIRouter, Depends, HTTPException
from sqlmodel import delete, func, select
from sqlmodel import col, delete, func, select
from app import crud
from app.api.deps import (
@@ -189,16 +189,17 @@ def delete_user(
user = session.get(User, user_id)
if not user:
raise HTTPException(status_code=404, detail="User not found")
if (user == current_user and not current_user.is_superuser) or (
user != current_user and current_user.is_superuser
):
statement = delete(Item).where(Item.owner_id == user_id)
session.exec(statement)
session.delete(user)
session.commit()
return Message(message="User deleted successfully")
elif user != current_user and not current_user.is_superuser:
raise HTTPException(
status_code=403, detail="The user doesn't have enough privileges"
)
elif user == current_user and current_user.is_superuser:
raise HTTPException(
status_code=403, detail="Super users are not allowed to delete themselves"
)
statement = delete(Item).where(col(Item.owner_id) == user_id)
session.exec(statement) # type: ignore
session.delete(user)
session.commit()
return Message(message="User deleted successfully")

View File

@@ -68,7 +68,7 @@ class Settings(BaseSettings):
@field_validator("EMAILS_FROM_NAME")
def get_project_name(cls, v: str | None, info: ValidationInfo) -> str:
if not v:
return info.data["PROJECT_NAME"]
return str(info.data["PROJECT_NAME"])
return v
EMAIL_RESET_TOKEN_EXPIRE_HOURS: int = 48
@@ -89,7 +89,7 @@ class Settings(BaseSettings):
FIRST_SUPERUSER: str
FIRST_SUPERUSER_PASSWORD: str
USERS_OPEN_REGISTRATION: bool = False
model_config = SettingsConfigDict(case_sensitive=True)
model_config = SettingsConfigDict(env_file=".env")
settings = Settings()
settings = Settings() # type: ignore

View File

@@ -12,13 +12,8 @@ pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
ALGORITHM = "HS256"
def create_access_token(subject: str | Any, expires_delta: timedelta = None) -> str:
if expires_delta:
def create_access_token(subject: str | Any, expires_delta: timedelta) -> str:
expire = datetime.utcnow() + expires_delta
else:
expire = datetime.utcnow() + timedelta(
minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES
)
to_encode = {"exp": expire, "sub": str(subject)}
encoded_jwt = jwt.encode(to_encode, settings.SECRET_KEY, algorithm=ALGORITHM)
return encoded_jwt

View File

@@ -6,7 +6,7 @@ from app.api.main import api_router
from app.core.config import settings
def custom_generate_unique_id(route: APIRoute):
def custom_generate_unique_id(route: APIRoute) -> str:
return f"{route.tags[0]}-{route.name}"

View File

@@ -25,7 +25,7 @@ class UserCreateOpen(SQLModel):
# Properties to receive via API on update, all are optional
# TODO replace email str with EmailStr when sqlmodel supports it
class UserUpdate(UserBase):
email: str | None = None
email: str | None = None # type: ignore
password: str | None = None
@@ -70,7 +70,7 @@ class ItemCreate(ItemBase):
# Properties to receive on item update
class ItemUpdate(ItemBase):
title: str | None = None
title: str | None = None # type: ignore
# Database model, database table inferred from class name

View File

@@ -6,7 +6,7 @@ from app.tests.utils.item import create_random_item
def test_create_item(
client: TestClient, superuser_token_headers: dict, db: Session
client: TestClient, superuser_token_headers: dict[str, str], db: Session
) -> None:
data = {"title": "Foo", "description": "Fighters"}
response = client.post(
@@ -23,7 +23,7 @@ def test_create_item(
def test_read_item(
client: TestClient, superuser_token_headers: dict, db: Session
client: TestClient, superuser_token_headers: dict[str, str], db: Session
) -> None:
item = create_random_item(db)
response = client.get(
@@ -39,7 +39,7 @@ def test_read_item(
def test_read_item_not_found(
client: TestClient, superuser_token_headers: dict, db: Session
client: TestClient, superuser_token_headers: dict[str, str], db: Session
) -> None:
response = client.get(
f"{settings.API_V1_STR}/items/999",
@@ -51,7 +51,7 @@ def test_read_item_not_found(
def test_read_item_not_enough_permissions(
client: TestClient, normal_user_token_headers: dict, db: Session
client: TestClient, normal_user_token_headers: dict[str, str], db: Session
) -> None:
item = create_random_item(db)
response = client.get(
@@ -64,7 +64,7 @@ def test_read_item_not_enough_permissions(
def test_read_items(
client: TestClient, superuser_token_headers: dict, db: Session
client: TestClient, superuser_token_headers: dict[str, str], db: Session
) -> None:
create_random_item(db)
create_random_item(db)
@@ -78,7 +78,7 @@ def test_read_items(
def test_update_item(
client: TestClient, superuser_token_headers: dict, db: Session
client: TestClient, superuser_token_headers: dict[str, str], db: Session
) -> None:
item = create_random_item(db)
data = {"title": "Updated title", "description": "Updated description"}
@@ -96,7 +96,7 @@ def test_update_item(
def test_update_item_not_found(
client: TestClient, superuser_token_headers: dict, db: Session
client: TestClient, superuser_token_headers: dict[str, str], db: Session
) -> None:
data = {"title": "Updated title", "description": "Updated description"}
response = client.put(
@@ -110,7 +110,7 @@ def test_update_item_not_found(
def test_update_item_not_enough_permissions(
client: TestClient, normal_user_token_headers: dict, db: Session
client: TestClient, normal_user_token_headers: dict[str, str], db: Session
) -> None:
item = create_random_item(db)
data = {"title": "Updated title", "description": "Updated description"}
@@ -125,7 +125,7 @@ def test_update_item_not_enough_permissions(
def test_delete_item(
client: TestClient, superuser_token_headers: dict, db: Session
client: TestClient, superuser_token_headers: dict[str, str], db: Session
) -> None:
item = create_random_item(db)
response = client.delete(
@@ -138,7 +138,7 @@ def test_delete_item(
def test_delete_item_not_found(
client: TestClient, superuser_token_headers: dict, db: Session
client: TestClient, superuser_token_headers: dict[str, str], db: Session
) -> None:
response = client.delete(
f"{settings.API_V1_STR}/items/999",
@@ -150,7 +150,7 @@ def test_delete_item_not_found(
def test_delete_item_not_enough_permissions(
client: TestClient, normal_user_token_headers: dict, db: Session
client: TestClient, normal_user_token_headers: dict[str, str], db: Session
) -> None:
item = create_random_item(db)
response = client.delete(

View File

@@ -1,7 +1,8 @@
from app.utils import generate_password_reset_token
from fastapi.testclient import TestClient
from pytest_mock import MockerFixture
from app.core.config import settings
from app.utils import generate_password_reset_token
def test_get_access_token(client: TestClient) -> None:
@@ -38,7 +39,7 @@ def test_use_access_token(
def test_recovery_password(
client: TestClient, normal_user_token_headers: dict[str, str], mocker
client: TestClient, normal_user_token_headers: dict[str, str], mocker: MockerFixture
) -> None:
mocker.patch("app.utils.send_email", return_value=None)
mocker.patch("app.core.config.settings.EMAILS_ENABLED", True)

View File

@@ -1,4 +1,5 @@
from fastapi.testclient import TestClient
from pytest_mock import MockerFixture
from sqlmodel import Session
from app import crud
@@ -30,7 +31,10 @@ def test_get_users_normal_user_me(
def test_create_user_new_email(
client: TestClient, superuser_token_headers: dict, db: Session, mocker
client: TestClient,
superuser_token_headers: dict[str, str],
db: Session,
mocker: MockerFixture,
) -> None:
mocker.patch("app.utils.send_email")
mocker.patch("app.core.config.settings.EMAILS_ENABLED", True)
@@ -50,7 +54,7 @@ def test_create_user_new_email(
def test_get_existing_user(
client: TestClient, superuser_token_headers: dict, db: Session
client: TestClient, superuser_token_headers: dict[str, str], db: Session
) -> None:
username = random_email()
password = random_lower_string()
@@ -107,7 +111,7 @@ def test_get_existing_user_permissions_error(
def test_create_user_existing_username(
client: TestClient, superuser_token_headers: dict, db: Session
client: TestClient, superuser_token_headers: dict[str, str], db: Session
) -> None:
username = random_email()
# username = email
@@ -140,7 +144,7 @@ def test_create_user_by_normal_user(
def test_retrieve_users(
client: TestClient, superuser_token_headers: dict, db: Session
client: TestClient, superuser_token_headers: dict[str, str], db: Session
) -> None:
username = random_email()
password = random_lower_string()
@@ -179,7 +183,7 @@ def test_update_user_me(
def test_update_password_me(
client: TestClient, superuser_token_headers: dict, db: Session
client: TestClient, superuser_token_headers: dict[str, str], db: Session
) -> None:
new_password = random_lower_string()
data = {
@@ -209,7 +213,7 @@ def test_update_password_me(
def test_update_password_me_incorrect_password(
client: TestClient, superuser_token_headers: dict, db: Session
client: TestClient, superuser_token_headers: dict[str, str], db: Session
) -> None:
new_password = random_lower_string()
data = {"current_password": new_password, "new_password": new_password}
@@ -224,7 +228,7 @@ def test_update_password_me_incorrect_password(
def test_update_password_me_same_password_error(
client: TestClient, superuser_token_headers: dict, db: Session
client: TestClient, superuser_token_headers: dict[str, str], db: Session
) -> None:
data = {
"current_password": settings.FIRST_SUPERUSER_PASSWORD,
@@ -242,7 +246,7 @@ def test_update_password_me_same_password_error(
)
def test_create_user_open(client: TestClient, mocker) -> None:
def test_create_user_open(client: TestClient, mocker: MockerFixture) -> None:
mocker.patch("app.core.config.settings.USERS_OPEN_REGISTRATION", True)
username = random_email()
password = random_lower_string()
@@ -258,7 +262,9 @@ def test_create_user_open(client: TestClient, mocker) -> None:
assert created_user["full_name"] == full_name
def test_create_user_open_forbidden_error(client: TestClient, mocker) -> None:
def test_create_user_open_forbidden_error(
client: TestClient, mocker: MockerFixture
) -> None:
mocker.patch("app.core.config.settings.USERS_OPEN_REGISTRATION", False)
username = random_email()
password = random_lower_string()
@@ -272,7 +278,9 @@ def test_create_user_open_forbidden_error(client: TestClient, mocker) -> None:
assert r.json()["detail"] == "Open user registration is forbidden on this server"
def test_create_user_open_already_exists_error(client: TestClient, mocker) -> None:
def test_create_user_open_already_exists_error(
client: TestClient, mocker: MockerFixture
) -> None:
mocker.patch("app.core.config.settings.USERS_OPEN_REGISTRATION", True)
password = random_lower_string()
full_name = random_lower_string()
@@ -382,6 +390,7 @@ def test_delete_user_current_super_user_error(
client: TestClient, superuser_token_headers: dict[str, str], db: Session
) -> None:
super_user = crud.get_user_by_email(session=db, email=settings.FIRST_SUPERUSER)
assert super_user
user_id = super_user.id
r = client.delete(

View File

@@ -13,7 +13,7 @@ from app.tests.utils.utils import get_superuser_token_headers
@pytest.fixture(scope="session", autouse=True)
def db() -> Generator:
def db() -> Generator[Session, None, None]:
with Session(engine) as session:
init_db(session)
yield session
@@ -25,7 +25,7 @@ def db() -> Generator:
@pytest.fixture(scope="module")
def client() -> Generator:
def client() -> Generator[TestClient, None, None]:
with TestClient(app) as c:
yield c

View File

@@ -1,21 +1,22 @@
from unittest.mock import MagicMock
from pytest_mock import MockerFixture
from sqlmodel import select
from app.backend_pre_start import init, logger
def test_init_successful_connection(mocker):
def test_init_successful_connection(mocker: MockerFixture) -> None:
engine_mock = MagicMock()
session_mock = MagicMock()
exec_mock = MagicMock(return_value=True)
session_mock.configure_mock(**{'exec.return_value': exec_mock})
mocker.patch('sqlmodel.Session', return_value=session_mock)
session_mock.configure_mock(**{"exec.return_value": exec_mock})
mocker.patch("sqlmodel.Session", return_value=session_mock)
mocker.patch.object(logger, 'info')
mocker.patch.object(logger, 'error')
mocker.patch.object(logger, 'warn')
mocker.patch.object(logger, "info")
mocker.patch.object(logger, "error")
mocker.patch.object(logger, "warn")
try:
init(engine_mock)
@@ -23,6 +24,10 @@ def test_init_successful_connection(mocker):
except Exception:
connection_successful = False
assert connection_successful, "The database connection should be successful and not raise an exception."
assert (
connection_successful
), "The database connection should be successful and not raise an exception."
assert session_mock.exec.called_once_with(select(1)), "The session should execute a select statement once."
assert session_mock.exec.called_once_with(
select(1)
), "The session should execute a select statement once."

View File

@@ -1,21 +1,22 @@
from unittest.mock import MagicMock
from pytest_mock import MockerFixture
from sqlmodel import select
from app.celeryworker_pre_start import init, logger
def test_init_successful_connection(mocker):
def test_init_successful_connection(mocker: MockerFixture) -> None:
engine_mock = MagicMock()
session_mock = MagicMock()
exec_mock = MagicMock(return_value=True)
session_mock.configure_mock(**{'exec.return_value': exec_mock})
mocker.patch('sqlmodel.Session', return_value=session_mock)
session_mock.configure_mock(**{"exec.return_value": exec_mock})
mocker.patch("sqlmodel.Session", return_value=session_mock)
mocker.patch.object(logger, 'info')
mocker.patch.object(logger, 'error')
mocker.patch.object(logger, 'warn')
mocker.patch.object(logger, "info")
mocker.patch.object(logger, "error")
mocker.patch.object(logger, "warn")
try:
init(engine_mock)
@@ -23,6 +24,10 @@ def test_init_successful_connection(mocker):
except Exception:
connection_successful = False
assert connection_successful, "The database connection should be successful and not raise an exception."
assert (
connection_successful
), "The database connection should be successful and not raise an exception."
assert session_mock.exec.called_once_with(select(1)), "The session should execute a select statement once."
assert session_mock.exec.called_once_with(
select(1)
), "The session should execute a select statement once."

View File

@@ -1,21 +1,22 @@
from unittest.mock import MagicMock
from pytest_mock import MockerFixture
from sqlmodel import select
from app.tests_pre_start import init, logger
def test_init_successful_connection(mocker):
def test_init_successful_connection(mocker: MockerFixture) -> None:
engine_mock = MagicMock()
session_mock = MagicMock()
exec_mock = MagicMock(return_value=True)
session_mock.configure_mock(**{'exec.return_value': exec_mock})
mocker.patch('sqlmodel.Session', return_value=session_mock)
session_mock.configure_mock(**{"exec.return_value": exec_mock})
mocker.patch("sqlmodel.Session", return_value=session_mock)
mocker.patch.object(logger, 'info')
mocker.patch.object(logger, 'error')
mocker.patch.object(logger, 'warn')
mocker.patch.object(logger, "info")
mocker.patch.object(logger, "error")
mocker.patch.object(logger, "warn")
try:
init(engine_mock)
@@ -23,6 +24,10 @@ def test_init_successful_connection(mocker):
except Exception:
connection_successful = False
assert connection_successful, "The database connection should be successful and not raise an exception."
assert (
connection_successful
), "The database connection should be successful and not raise an exception."
assert session_mock.exec.called_once_with(select(1)), "The session should execute a select statement once."
assert session_mock.exec.called_once_with(
select(1)
), "The session should execute a select statement once."

View File

@@ -42,6 +42,8 @@ def authentication_token_from_email(
user = crud.create_user(session=db, user_create=user_in_create)
else:
user_in_update = UserUpdate(password=password)
if not user.id:
raise Exception("User id not set")
user = crud.update_user(session=db, user_id=user.id, user_in=user_in_update)
return user_authentication_headers(client=client, email=email, password=password)

View File

@@ -4,7 +4,7 @@ from datetime import datetime, timedelta
from pathlib import Path
from typing import Any
import emails
import emails # type: ignore
from jinja2 import Template
from jose import JWTError, jwt
@@ -109,6 +109,6 @@ def generate_password_reset_token(email: str) -> str:
def verify_password_reset_token(token: str) -> str | None:
try:
decoded_token = jwt.decode(token, settings.SECRET_KEY, algorithms=["HS256"])
return decoded_token["sub"]
return str(decoded_token["sub"])
except JWTError:
return None

View File

@@ -3,7 +3,7 @@ import sentry_sdk
from app.core.celery_app import celery_app
from app.core.config import settings
sentry_sdk.init(dsn=settings.SENTRY_DSN)
sentry_sdk.init(dsn=str(settings.SENTRY_DSN))
@celery_app.task(acks_late=True)

View File

@@ -35,6 +35,9 @@ mypy = "^1.8.0"
ruff = "^0.2.2"
pre-commit = "^3.6.2"
pytest-mock = "^3.12.0"
types-python-jose = "^3.3.4.20240106"
types-passlib = "^1.7.7.20240106"
celery-types = "^0.22.0"
[tool.isort]
multi_line_output = 3
@@ -47,6 +50,7 @@ build-backend = "poetry.masonry.api"
[tool.mypy]
strict = true
exclude = ["venv", "alembic"]
[tool.ruff]
target-version = "py310"

View File

@@ -3,6 +3,5 @@
set -x
mypy app
black app --check
isort --recursive --check-only app
flake8
ruff app
ruff format app --check

View File

@@ -3,4 +3,5 @@ set -e
python /app/app/tests_pre_start.py
bash ./scripts/lint.sh
bash ./scripts/test.sh "$@"