🏷️ Add mypy to the GitHub Action for tests and fixed types in the whole project (#655)
Co-authored-by: Sebastián Ramírez <tiangolo@gmail.com>
This commit is contained in:
@@ -17,7 +17,7 @@ reusable_oauth2 = OAuth2PasswordBearer(
|
||||
)
|
||||
|
||||
|
||||
def get_db() -> Generator:
|
||||
def get_db() -> Generator[Session, None, None]:
|
||||
with Session(engine) as session:
|
||||
yield session
|
||||
|
||||
|
@@ -1,7 +1,7 @@
|
||||
from typing import Any
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from sqlmodel import delete, func, select
|
||||
from sqlmodel import col, delete, func, select
|
||||
|
||||
from app import crud
|
||||
from app.api.deps import (
|
||||
@@ -189,16 +189,17 @@ def delete_user(
|
||||
user = session.get(User, user_id)
|
||||
if not user:
|
||||
raise HTTPException(status_code=404, detail="User not found")
|
||||
|
||||
if (user == current_user and not current_user.is_superuser) or (
|
||||
user != current_user and current_user.is_superuser
|
||||
):
|
||||
statement = delete(Item).where(Item.owner_id == user_id)
|
||||
session.exec(statement)
|
||||
session.delete(user)
|
||||
session.commit()
|
||||
return Message(message="User deleted successfully")
|
||||
elif user != current_user and not current_user.is_superuser:
|
||||
raise HTTPException(
|
||||
status_code=403, detail="The user doesn't have enough privileges"
|
||||
)
|
||||
elif user == current_user and current_user.is_superuser:
|
||||
raise HTTPException(
|
||||
status_code=403, detail="Super users are not allowed to delete themselves"
|
||||
)
|
||||
|
||||
statement = delete(Item).where(col(Item.owner_id) == user_id)
|
||||
session.exec(statement) # type: ignore
|
||||
session.delete(user)
|
||||
session.commit()
|
||||
return Message(message="User deleted successfully")
|
||||
|
@@ -68,7 +68,7 @@ class Settings(BaseSettings):
|
||||
@field_validator("EMAILS_FROM_NAME")
|
||||
def get_project_name(cls, v: str | None, info: ValidationInfo) -> str:
|
||||
if not v:
|
||||
return info.data["PROJECT_NAME"]
|
||||
return str(info.data["PROJECT_NAME"])
|
||||
return v
|
||||
|
||||
EMAIL_RESET_TOKEN_EXPIRE_HOURS: int = 48
|
||||
@@ -89,7 +89,7 @@ class Settings(BaseSettings):
|
||||
FIRST_SUPERUSER: str
|
||||
FIRST_SUPERUSER_PASSWORD: str
|
||||
USERS_OPEN_REGISTRATION: bool = False
|
||||
model_config = SettingsConfigDict(case_sensitive=True)
|
||||
model_config = SettingsConfigDict(env_file=".env")
|
||||
|
||||
|
||||
settings = Settings()
|
||||
settings = Settings() # type: ignore
|
||||
|
@@ -12,13 +12,8 @@ pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
|
||||
ALGORITHM = "HS256"
|
||||
|
||||
|
||||
def create_access_token(subject: str | Any, expires_delta: timedelta = None) -> str:
|
||||
if expires_delta:
|
||||
expire = datetime.utcnow() + expires_delta
|
||||
else:
|
||||
expire = datetime.utcnow() + timedelta(
|
||||
minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES
|
||||
)
|
||||
def create_access_token(subject: str | Any, expires_delta: timedelta) -> str:
|
||||
expire = datetime.utcnow() + expires_delta
|
||||
to_encode = {"exp": expire, "sub": str(subject)}
|
||||
encoded_jwt = jwt.encode(to_encode, settings.SECRET_KEY, algorithm=ALGORITHM)
|
||||
return encoded_jwt
|
||||
|
@@ -6,7 +6,7 @@ from app.api.main import api_router
|
||||
from app.core.config import settings
|
||||
|
||||
|
||||
def custom_generate_unique_id(route: APIRoute):
|
||||
def custom_generate_unique_id(route: APIRoute) -> str:
|
||||
return f"{route.tags[0]}-{route.name}"
|
||||
|
||||
|
||||
|
@@ -25,7 +25,7 @@ class UserCreateOpen(SQLModel):
|
||||
# Properties to receive via API on update, all are optional
|
||||
# TODO replace email str with EmailStr when sqlmodel supports it
|
||||
class UserUpdate(UserBase):
|
||||
email: str | None = None
|
||||
email: str | None = None # type: ignore
|
||||
password: str | None = None
|
||||
|
||||
|
||||
@@ -70,7 +70,7 @@ class ItemCreate(ItemBase):
|
||||
|
||||
# Properties to receive on item update
|
||||
class ItemUpdate(ItemBase):
|
||||
title: str | None = None
|
||||
title: str | None = None # type: ignore
|
||||
|
||||
|
||||
# Database model, database table inferred from class name
|
||||
|
@@ -6,7 +6,7 @@ from app.tests.utils.item import create_random_item
|
||||
|
||||
|
||||
def test_create_item(
|
||||
client: TestClient, superuser_token_headers: dict, db: Session
|
||||
client: TestClient, superuser_token_headers: dict[str, str], db: Session
|
||||
) -> None:
|
||||
data = {"title": "Foo", "description": "Fighters"}
|
||||
response = client.post(
|
||||
@@ -23,7 +23,7 @@ def test_create_item(
|
||||
|
||||
|
||||
def test_read_item(
|
||||
client: TestClient, superuser_token_headers: dict, db: Session
|
||||
client: TestClient, superuser_token_headers: dict[str, str], db: Session
|
||||
) -> None:
|
||||
item = create_random_item(db)
|
||||
response = client.get(
|
||||
@@ -39,7 +39,7 @@ def test_read_item(
|
||||
|
||||
|
||||
def test_read_item_not_found(
|
||||
client: TestClient, superuser_token_headers: dict, db: Session
|
||||
client: TestClient, superuser_token_headers: dict[str, str], db: Session
|
||||
) -> None:
|
||||
response = client.get(
|
||||
f"{settings.API_V1_STR}/items/999",
|
||||
@@ -51,7 +51,7 @@ def test_read_item_not_found(
|
||||
|
||||
|
||||
def test_read_item_not_enough_permissions(
|
||||
client: TestClient, normal_user_token_headers: dict, db: Session
|
||||
client: TestClient, normal_user_token_headers: dict[str, str], db: Session
|
||||
) -> None:
|
||||
item = create_random_item(db)
|
||||
response = client.get(
|
||||
@@ -64,7 +64,7 @@ def test_read_item_not_enough_permissions(
|
||||
|
||||
|
||||
def test_read_items(
|
||||
client: TestClient, superuser_token_headers: dict, db: Session
|
||||
client: TestClient, superuser_token_headers: dict[str, str], db: Session
|
||||
) -> None:
|
||||
create_random_item(db)
|
||||
create_random_item(db)
|
||||
@@ -78,7 +78,7 @@ def test_read_items(
|
||||
|
||||
|
||||
def test_update_item(
|
||||
client: TestClient, superuser_token_headers: dict, db: Session
|
||||
client: TestClient, superuser_token_headers: dict[str, str], db: Session
|
||||
) -> None:
|
||||
item = create_random_item(db)
|
||||
data = {"title": "Updated title", "description": "Updated description"}
|
||||
@@ -96,7 +96,7 @@ def test_update_item(
|
||||
|
||||
|
||||
def test_update_item_not_found(
|
||||
client: TestClient, superuser_token_headers: dict, db: Session
|
||||
client: TestClient, superuser_token_headers: dict[str, str], db: Session
|
||||
) -> None:
|
||||
data = {"title": "Updated title", "description": "Updated description"}
|
||||
response = client.put(
|
||||
@@ -110,7 +110,7 @@ def test_update_item_not_found(
|
||||
|
||||
|
||||
def test_update_item_not_enough_permissions(
|
||||
client: TestClient, normal_user_token_headers: dict, db: Session
|
||||
client: TestClient, normal_user_token_headers: dict[str, str], db: Session
|
||||
) -> None:
|
||||
item = create_random_item(db)
|
||||
data = {"title": "Updated title", "description": "Updated description"}
|
||||
@@ -125,7 +125,7 @@ def test_update_item_not_enough_permissions(
|
||||
|
||||
|
||||
def test_delete_item(
|
||||
client: TestClient, superuser_token_headers: dict, db: Session
|
||||
client: TestClient, superuser_token_headers: dict[str, str], db: Session
|
||||
) -> None:
|
||||
item = create_random_item(db)
|
||||
response = client.delete(
|
||||
@@ -138,7 +138,7 @@ def test_delete_item(
|
||||
|
||||
|
||||
def test_delete_item_not_found(
|
||||
client: TestClient, superuser_token_headers: dict, db: Session
|
||||
client: TestClient, superuser_token_headers: dict[str, str], db: Session
|
||||
) -> None:
|
||||
response = client.delete(
|
||||
f"{settings.API_V1_STR}/items/999",
|
||||
@@ -150,7 +150,7 @@ def test_delete_item_not_found(
|
||||
|
||||
|
||||
def test_delete_item_not_enough_permissions(
|
||||
client: TestClient, normal_user_token_headers: dict, db: Session
|
||||
client: TestClient, normal_user_token_headers: dict[str, str], db: Session
|
||||
) -> None:
|
||||
item = create_random_item(db)
|
||||
response = client.delete(
|
||||
|
@@ -1,7 +1,8 @@
|
||||
from app.utils import generate_password_reset_token
|
||||
from fastapi.testclient import TestClient
|
||||
from pytest_mock import MockerFixture
|
||||
|
||||
from app.core.config import settings
|
||||
from app.utils import generate_password_reset_token
|
||||
|
||||
|
||||
def test_get_access_token(client: TestClient) -> None:
|
||||
@@ -38,7 +39,7 @@ def test_use_access_token(
|
||||
|
||||
|
||||
def test_recovery_password(
|
||||
client: TestClient, normal_user_token_headers: dict[str, str], mocker
|
||||
client: TestClient, normal_user_token_headers: dict[str, str], mocker: MockerFixture
|
||||
) -> None:
|
||||
mocker.patch("app.utils.send_email", return_value=None)
|
||||
mocker.patch("app.core.config.settings.EMAILS_ENABLED", True)
|
||||
|
@@ -1,4 +1,5 @@
|
||||
from fastapi.testclient import TestClient
|
||||
from pytest_mock import MockerFixture
|
||||
from sqlmodel import Session
|
||||
|
||||
from app import crud
|
||||
@@ -30,7 +31,10 @@ def test_get_users_normal_user_me(
|
||||
|
||||
|
||||
def test_create_user_new_email(
|
||||
client: TestClient, superuser_token_headers: dict, db: Session, mocker
|
||||
client: TestClient,
|
||||
superuser_token_headers: dict[str, str],
|
||||
db: Session,
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
mocker.patch("app.utils.send_email")
|
||||
mocker.patch("app.core.config.settings.EMAILS_ENABLED", True)
|
||||
@@ -50,7 +54,7 @@ def test_create_user_new_email(
|
||||
|
||||
|
||||
def test_get_existing_user(
|
||||
client: TestClient, superuser_token_headers: dict, db: Session
|
||||
client: TestClient, superuser_token_headers: dict[str, str], db: Session
|
||||
) -> None:
|
||||
username = random_email()
|
||||
password = random_lower_string()
|
||||
@@ -107,7 +111,7 @@ def test_get_existing_user_permissions_error(
|
||||
|
||||
|
||||
def test_create_user_existing_username(
|
||||
client: TestClient, superuser_token_headers: dict, db: Session
|
||||
client: TestClient, superuser_token_headers: dict[str, str], db: Session
|
||||
) -> None:
|
||||
username = random_email()
|
||||
# username = email
|
||||
@@ -140,7 +144,7 @@ def test_create_user_by_normal_user(
|
||||
|
||||
|
||||
def test_retrieve_users(
|
||||
client: TestClient, superuser_token_headers: dict, db: Session
|
||||
client: TestClient, superuser_token_headers: dict[str, str], db: Session
|
||||
) -> None:
|
||||
username = random_email()
|
||||
password = random_lower_string()
|
||||
@@ -179,7 +183,7 @@ def test_update_user_me(
|
||||
|
||||
|
||||
def test_update_password_me(
|
||||
client: TestClient, superuser_token_headers: dict, db: Session
|
||||
client: TestClient, superuser_token_headers: dict[str, str], db: Session
|
||||
) -> None:
|
||||
new_password = random_lower_string()
|
||||
data = {
|
||||
@@ -209,7 +213,7 @@ def test_update_password_me(
|
||||
|
||||
|
||||
def test_update_password_me_incorrect_password(
|
||||
client: TestClient, superuser_token_headers: dict, db: Session
|
||||
client: TestClient, superuser_token_headers: dict[str, str], db: Session
|
||||
) -> None:
|
||||
new_password = random_lower_string()
|
||||
data = {"current_password": new_password, "new_password": new_password}
|
||||
@@ -224,7 +228,7 @@ def test_update_password_me_incorrect_password(
|
||||
|
||||
|
||||
def test_update_password_me_same_password_error(
|
||||
client: TestClient, superuser_token_headers: dict, db: Session
|
||||
client: TestClient, superuser_token_headers: dict[str, str], db: Session
|
||||
) -> None:
|
||||
data = {
|
||||
"current_password": settings.FIRST_SUPERUSER_PASSWORD,
|
||||
@@ -242,7 +246,7 @@ def test_update_password_me_same_password_error(
|
||||
)
|
||||
|
||||
|
||||
def test_create_user_open(client: TestClient, mocker) -> None:
|
||||
def test_create_user_open(client: TestClient, mocker: MockerFixture) -> None:
|
||||
mocker.patch("app.core.config.settings.USERS_OPEN_REGISTRATION", True)
|
||||
username = random_email()
|
||||
password = random_lower_string()
|
||||
@@ -258,7 +262,9 @@ def test_create_user_open(client: TestClient, mocker) -> None:
|
||||
assert created_user["full_name"] == full_name
|
||||
|
||||
|
||||
def test_create_user_open_forbidden_error(client: TestClient, mocker) -> None:
|
||||
def test_create_user_open_forbidden_error(
|
||||
client: TestClient, mocker: MockerFixture
|
||||
) -> None:
|
||||
mocker.patch("app.core.config.settings.USERS_OPEN_REGISTRATION", False)
|
||||
username = random_email()
|
||||
password = random_lower_string()
|
||||
@@ -272,7 +278,9 @@ def test_create_user_open_forbidden_error(client: TestClient, mocker) -> None:
|
||||
assert r.json()["detail"] == "Open user registration is forbidden on this server"
|
||||
|
||||
|
||||
def test_create_user_open_already_exists_error(client: TestClient, mocker) -> None:
|
||||
def test_create_user_open_already_exists_error(
|
||||
client: TestClient, mocker: MockerFixture
|
||||
) -> None:
|
||||
mocker.patch("app.core.config.settings.USERS_OPEN_REGISTRATION", True)
|
||||
password = random_lower_string()
|
||||
full_name = random_lower_string()
|
||||
@@ -382,6 +390,7 @@ def test_delete_user_current_super_user_error(
|
||||
client: TestClient, superuser_token_headers: dict[str, str], db: Session
|
||||
) -> None:
|
||||
super_user = crud.get_user_by_email(session=db, email=settings.FIRST_SUPERUSER)
|
||||
assert super_user
|
||||
user_id = super_user.id
|
||||
|
||||
r = client.delete(
|
||||
|
@@ -13,7 +13,7 @@ from app.tests.utils.utils import get_superuser_token_headers
|
||||
|
||||
|
||||
@pytest.fixture(scope="session", autouse=True)
|
||||
def db() -> Generator:
|
||||
def db() -> Generator[Session, None, None]:
|
||||
with Session(engine) as session:
|
||||
init_db(session)
|
||||
yield session
|
||||
@@ -25,7 +25,7 @@ def db() -> Generator:
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def client() -> Generator:
|
||||
def client() -> Generator[TestClient, None, None]:
|
||||
with TestClient(app) as c:
|
||||
yield c
|
||||
|
||||
|
@@ -1,21 +1,22 @@
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from pytest_mock import MockerFixture
|
||||
from sqlmodel import select
|
||||
|
||||
from app.backend_pre_start import init, logger
|
||||
|
||||
|
||||
def test_init_successful_connection(mocker):
|
||||
def test_init_successful_connection(mocker: MockerFixture) -> None:
|
||||
engine_mock = MagicMock()
|
||||
|
||||
session_mock = MagicMock()
|
||||
exec_mock = MagicMock(return_value=True)
|
||||
session_mock.configure_mock(**{'exec.return_value': exec_mock})
|
||||
mocker.patch('sqlmodel.Session', return_value=session_mock)
|
||||
session_mock.configure_mock(**{"exec.return_value": exec_mock})
|
||||
mocker.patch("sqlmodel.Session", return_value=session_mock)
|
||||
|
||||
mocker.patch.object(logger, 'info')
|
||||
mocker.patch.object(logger, 'error')
|
||||
mocker.patch.object(logger, 'warn')
|
||||
mocker.patch.object(logger, "info")
|
||||
mocker.patch.object(logger, "error")
|
||||
mocker.patch.object(logger, "warn")
|
||||
|
||||
try:
|
||||
init(engine_mock)
|
||||
@@ -23,6 +24,10 @@ def test_init_successful_connection(mocker):
|
||||
except Exception:
|
||||
connection_successful = False
|
||||
|
||||
assert connection_successful, "The database connection should be successful and not raise an exception."
|
||||
assert (
|
||||
connection_successful
|
||||
), "The database connection should be successful and not raise an exception."
|
||||
|
||||
assert session_mock.exec.called_once_with(select(1)), "The session should execute a select statement once."
|
||||
assert session_mock.exec.called_once_with(
|
||||
select(1)
|
||||
), "The session should execute a select statement once."
|
||||
|
@@ -1,21 +1,22 @@
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from pytest_mock import MockerFixture
|
||||
from sqlmodel import select
|
||||
|
||||
from app.celeryworker_pre_start import init, logger
|
||||
|
||||
|
||||
def test_init_successful_connection(mocker):
|
||||
def test_init_successful_connection(mocker: MockerFixture) -> None:
|
||||
engine_mock = MagicMock()
|
||||
|
||||
session_mock = MagicMock()
|
||||
exec_mock = MagicMock(return_value=True)
|
||||
session_mock.configure_mock(**{'exec.return_value': exec_mock})
|
||||
mocker.patch('sqlmodel.Session', return_value=session_mock)
|
||||
session_mock.configure_mock(**{"exec.return_value": exec_mock})
|
||||
mocker.patch("sqlmodel.Session", return_value=session_mock)
|
||||
|
||||
mocker.patch.object(logger, 'info')
|
||||
mocker.patch.object(logger, 'error')
|
||||
mocker.patch.object(logger, 'warn')
|
||||
mocker.patch.object(logger, "info")
|
||||
mocker.patch.object(logger, "error")
|
||||
mocker.patch.object(logger, "warn")
|
||||
|
||||
try:
|
||||
init(engine_mock)
|
||||
@@ -23,6 +24,10 @@ def test_init_successful_connection(mocker):
|
||||
except Exception:
|
||||
connection_successful = False
|
||||
|
||||
assert connection_successful, "The database connection should be successful and not raise an exception."
|
||||
assert (
|
||||
connection_successful
|
||||
), "The database connection should be successful and not raise an exception."
|
||||
|
||||
assert session_mock.exec.called_once_with(select(1)), "The session should execute a select statement once."
|
||||
assert session_mock.exec.called_once_with(
|
||||
select(1)
|
||||
), "The session should execute a select statement once."
|
||||
|
@@ -1,21 +1,22 @@
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from pytest_mock import MockerFixture
|
||||
from sqlmodel import select
|
||||
|
||||
from app.tests_pre_start import init, logger
|
||||
|
||||
|
||||
def test_init_successful_connection(mocker):
|
||||
def test_init_successful_connection(mocker: MockerFixture) -> None:
|
||||
engine_mock = MagicMock()
|
||||
|
||||
session_mock = MagicMock()
|
||||
exec_mock = MagicMock(return_value=True)
|
||||
session_mock.configure_mock(**{'exec.return_value': exec_mock})
|
||||
mocker.patch('sqlmodel.Session', return_value=session_mock)
|
||||
session_mock.configure_mock(**{"exec.return_value": exec_mock})
|
||||
mocker.patch("sqlmodel.Session", return_value=session_mock)
|
||||
|
||||
mocker.patch.object(logger, 'info')
|
||||
mocker.patch.object(logger, 'error')
|
||||
mocker.patch.object(logger, 'warn')
|
||||
mocker.patch.object(logger, "info")
|
||||
mocker.patch.object(logger, "error")
|
||||
mocker.patch.object(logger, "warn")
|
||||
|
||||
try:
|
||||
init(engine_mock)
|
||||
@@ -23,6 +24,10 @@ def test_init_successful_connection(mocker):
|
||||
except Exception:
|
||||
connection_successful = False
|
||||
|
||||
assert connection_successful, "The database connection should be successful and not raise an exception."
|
||||
assert (
|
||||
connection_successful
|
||||
), "The database connection should be successful and not raise an exception."
|
||||
|
||||
assert session_mock.exec.called_once_with(select(1)), "The session should execute a select statement once."
|
||||
assert session_mock.exec.called_once_with(
|
||||
select(1)
|
||||
), "The session should execute a select statement once."
|
||||
|
@@ -42,6 +42,8 @@ def authentication_token_from_email(
|
||||
user = crud.create_user(session=db, user_create=user_in_create)
|
||||
else:
|
||||
user_in_update = UserUpdate(password=password)
|
||||
if not user.id:
|
||||
raise Exception("User id not set")
|
||||
user = crud.update_user(session=db, user_id=user.id, user_in=user_in_update)
|
||||
|
||||
return user_authentication_headers(client=client, email=email, password=password)
|
||||
|
@@ -4,7 +4,7 @@ from datetime import datetime, timedelta
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import emails
|
||||
import emails # type: ignore
|
||||
from jinja2 import Template
|
||||
from jose import JWTError, jwt
|
||||
|
||||
@@ -109,6 +109,6 @@ def generate_password_reset_token(email: str) -> str:
|
||||
def verify_password_reset_token(token: str) -> str | None:
|
||||
try:
|
||||
decoded_token = jwt.decode(token, settings.SECRET_KEY, algorithms=["HS256"])
|
||||
return decoded_token["sub"]
|
||||
return str(decoded_token["sub"])
|
||||
except JWTError:
|
||||
return None
|
||||
|
@@ -3,7 +3,7 @@ import sentry_sdk
|
||||
from app.core.celery_app import celery_app
|
||||
from app.core.config import settings
|
||||
|
||||
sentry_sdk.init(dsn=settings.SENTRY_DSN)
|
||||
sentry_sdk.init(dsn=str(settings.SENTRY_DSN))
|
||||
|
||||
|
||||
@celery_app.task(acks_late=True)
|
||||
|
Reference in New Issue
Block a user