🏷️ Add mypy to the GitHub Action for tests and fixed types in the whole project (#655)
Co-authored-by: Sebastián Ramírez <tiangolo@gmail.com>
This commit is contained in:
@@ -17,7 +17,7 @@ reusable_oauth2 = OAuth2PasswordBearer(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def get_db() -> Generator:
|
def get_db() -> Generator[Session, None, None]:
|
||||||
with Session(engine) as session:
|
with Session(engine) as session:
|
||||||
yield session
|
yield session
|
||||||
|
|
||||||
|
@@ -1,7 +1,7 @@
|
|||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException
|
from fastapi import APIRouter, Depends, HTTPException
|
||||||
from sqlmodel import delete, func, select
|
from sqlmodel import col, delete, func, select
|
||||||
|
|
||||||
from app import crud
|
from app import crud
|
||||||
from app.api.deps import (
|
from app.api.deps import (
|
||||||
@@ -189,16 +189,17 @@ def delete_user(
|
|||||||
user = session.get(User, user_id)
|
user = session.get(User, user_id)
|
||||||
if not user:
|
if not user:
|
||||||
raise HTTPException(status_code=404, detail="User not found")
|
raise HTTPException(status_code=404, detail="User not found")
|
||||||
|
elif user != current_user and not current_user.is_superuser:
|
||||||
if (user == current_user and not current_user.is_superuser) or (
|
raise HTTPException(
|
||||||
user != current_user and current_user.is_superuser
|
status_code=403, detail="The user doesn't have enough privileges"
|
||||||
):
|
)
|
||||||
statement = delete(Item).where(Item.owner_id == user_id)
|
|
||||||
session.exec(statement)
|
|
||||||
session.delete(user)
|
|
||||||
session.commit()
|
|
||||||
return Message(message="User deleted successfully")
|
|
||||||
elif user == current_user and current_user.is_superuser:
|
elif user == current_user and current_user.is_superuser:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=403, detail="Super users are not allowed to delete themselves"
|
status_code=403, detail="Super users are not allowed to delete themselves"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
statement = delete(Item).where(col(Item.owner_id) == user_id)
|
||||||
|
session.exec(statement) # type: ignore
|
||||||
|
session.delete(user)
|
||||||
|
session.commit()
|
||||||
|
return Message(message="User deleted successfully")
|
||||||
|
@@ -68,7 +68,7 @@ class Settings(BaseSettings):
|
|||||||
@field_validator("EMAILS_FROM_NAME")
|
@field_validator("EMAILS_FROM_NAME")
|
||||||
def get_project_name(cls, v: str | None, info: ValidationInfo) -> str:
|
def get_project_name(cls, v: str | None, info: ValidationInfo) -> str:
|
||||||
if not v:
|
if not v:
|
||||||
return info.data["PROJECT_NAME"]
|
return str(info.data["PROJECT_NAME"])
|
||||||
return v
|
return v
|
||||||
|
|
||||||
EMAIL_RESET_TOKEN_EXPIRE_HOURS: int = 48
|
EMAIL_RESET_TOKEN_EXPIRE_HOURS: int = 48
|
||||||
@@ -89,7 +89,7 @@ class Settings(BaseSettings):
|
|||||||
FIRST_SUPERUSER: str
|
FIRST_SUPERUSER: str
|
||||||
FIRST_SUPERUSER_PASSWORD: str
|
FIRST_SUPERUSER_PASSWORD: str
|
||||||
USERS_OPEN_REGISTRATION: bool = False
|
USERS_OPEN_REGISTRATION: bool = False
|
||||||
model_config = SettingsConfigDict(case_sensitive=True)
|
model_config = SettingsConfigDict(env_file=".env")
|
||||||
|
|
||||||
|
|
||||||
settings = Settings()
|
settings = Settings() # type: ignore
|
||||||
|
@@ -12,13 +12,8 @@ pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
|
|||||||
ALGORITHM = "HS256"
|
ALGORITHM = "HS256"
|
||||||
|
|
||||||
|
|
||||||
def create_access_token(subject: str | Any, expires_delta: timedelta = None) -> str:
|
def create_access_token(subject: str | Any, expires_delta: timedelta) -> str:
|
||||||
if expires_delta:
|
expire = datetime.utcnow() + expires_delta
|
||||||
expire = datetime.utcnow() + expires_delta
|
|
||||||
else:
|
|
||||||
expire = datetime.utcnow() + timedelta(
|
|
||||||
minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES
|
|
||||||
)
|
|
||||||
to_encode = {"exp": expire, "sub": str(subject)}
|
to_encode = {"exp": expire, "sub": str(subject)}
|
||||||
encoded_jwt = jwt.encode(to_encode, settings.SECRET_KEY, algorithm=ALGORITHM)
|
encoded_jwt = jwt.encode(to_encode, settings.SECRET_KEY, algorithm=ALGORITHM)
|
||||||
return encoded_jwt
|
return encoded_jwt
|
||||||
|
@@ -6,7 +6,7 @@ from app.api.main import api_router
|
|||||||
from app.core.config import settings
|
from app.core.config import settings
|
||||||
|
|
||||||
|
|
||||||
def custom_generate_unique_id(route: APIRoute):
|
def custom_generate_unique_id(route: APIRoute) -> str:
|
||||||
return f"{route.tags[0]}-{route.name}"
|
return f"{route.tags[0]}-{route.name}"
|
||||||
|
|
||||||
|
|
||||||
|
@@ -25,7 +25,7 @@ class UserCreateOpen(SQLModel):
|
|||||||
# Properties to receive via API on update, all are optional
|
# Properties to receive via API on update, all are optional
|
||||||
# TODO replace email str with EmailStr when sqlmodel supports it
|
# TODO replace email str with EmailStr when sqlmodel supports it
|
||||||
class UserUpdate(UserBase):
|
class UserUpdate(UserBase):
|
||||||
email: str | None = None
|
email: str | None = None # type: ignore
|
||||||
password: str | None = None
|
password: str | None = None
|
||||||
|
|
||||||
|
|
||||||
@@ -70,7 +70,7 @@ class ItemCreate(ItemBase):
|
|||||||
|
|
||||||
# Properties to receive on item update
|
# Properties to receive on item update
|
||||||
class ItemUpdate(ItemBase):
|
class ItemUpdate(ItemBase):
|
||||||
title: str | None = None
|
title: str | None = None # type: ignore
|
||||||
|
|
||||||
|
|
||||||
# Database model, database table inferred from class name
|
# Database model, database table inferred from class name
|
||||||
|
@@ -6,7 +6,7 @@ from app.tests.utils.item import create_random_item
|
|||||||
|
|
||||||
|
|
||||||
def test_create_item(
|
def test_create_item(
|
||||||
client: TestClient, superuser_token_headers: dict, db: Session
|
client: TestClient, superuser_token_headers: dict[str, str], db: Session
|
||||||
) -> None:
|
) -> None:
|
||||||
data = {"title": "Foo", "description": "Fighters"}
|
data = {"title": "Foo", "description": "Fighters"}
|
||||||
response = client.post(
|
response = client.post(
|
||||||
@@ -23,7 +23,7 @@ def test_create_item(
|
|||||||
|
|
||||||
|
|
||||||
def test_read_item(
|
def test_read_item(
|
||||||
client: TestClient, superuser_token_headers: dict, db: Session
|
client: TestClient, superuser_token_headers: dict[str, str], db: Session
|
||||||
) -> None:
|
) -> None:
|
||||||
item = create_random_item(db)
|
item = create_random_item(db)
|
||||||
response = client.get(
|
response = client.get(
|
||||||
@@ -39,7 +39,7 @@ def test_read_item(
|
|||||||
|
|
||||||
|
|
||||||
def test_read_item_not_found(
|
def test_read_item_not_found(
|
||||||
client: TestClient, superuser_token_headers: dict, db: Session
|
client: TestClient, superuser_token_headers: dict[str, str], db: Session
|
||||||
) -> None:
|
) -> None:
|
||||||
response = client.get(
|
response = client.get(
|
||||||
f"{settings.API_V1_STR}/items/999",
|
f"{settings.API_V1_STR}/items/999",
|
||||||
@@ -51,7 +51,7 @@ def test_read_item_not_found(
|
|||||||
|
|
||||||
|
|
||||||
def test_read_item_not_enough_permissions(
|
def test_read_item_not_enough_permissions(
|
||||||
client: TestClient, normal_user_token_headers: dict, db: Session
|
client: TestClient, normal_user_token_headers: dict[str, str], db: Session
|
||||||
) -> None:
|
) -> None:
|
||||||
item = create_random_item(db)
|
item = create_random_item(db)
|
||||||
response = client.get(
|
response = client.get(
|
||||||
@@ -64,7 +64,7 @@ def test_read_item_not_enough_permissions(
|
|||||||
|
|
||||||
|
|
||||||
def test_read_items(
|
def test_read_items(
|
||||||
client: TestClient, superuser_token_headers: dict, db: Session
|
client: TestClient, superuser_token_headers: dict[str, str], db: Session
|
||||||
) -> None:
|
) -> None:
|
||||||
create_random_item(db)
|
create_random_item(db)
|
||||||
create_random_item(db)
|
create_random_item(db)
|
||||||
@@ -78,7 +78,7 @@ def test_read_items(
|
|||||||
|
|
||||||
|
|
||||||
def test_update_item(
|
def test_update_item(
|
||||||
client: TestClient, superuser_token_headers: dict, db: Session
|
client: TestClient, superuser_token_headers: dict[str, str], db: Session
|
||||||
) -> None:
|
) -> None:
|
||||||
item = create_random_item(db)
|
item = create_random_item(db)
|
||||||
data = {"title": "Updated title", "description": "Updated description"}
|
data = {"title": "Updated title", "description": "Updated description"}
|
||||||
@@ -96,7 +96,7 @@ def test_update_item(
|
|||||||
|
|
||||||
|
|
||||||
def test_update_item_not_found(
|
def test_update_item_not_found(
|
||||||
client: TestClient, superuser_token_headers: dict, db: Session
|
client: TestClient, superuser_token_headers: dict[str, str], db: Session
|
||||||
) -> None:
|
) -> None:
|
||||||
data = {"title": "Updated title", "description": "Updated description"}
|
data = {"title": "Updated title", "description": "Updated description"}
|
||||||
response = client.put(
|
response = client.put(
|
||||||
@@ -110,7 +110,7 @@ def test_update_item_not_found(
|
|||||||
|
|
||||||
|
|
||||||
def test_update_item_not_enough_permissions(
|
def test_update_item_not_enough_permissions(
|
||||||
client: TestClient, normal_user_token_headers: dict, db: Session
|
client: TestClient, normal_user_token_headers: dict[str, str], db: Session
|
||||||
) -> None:
|
) -> None:
|
||||||
item = create_random_item(db)
|
item = create_random_item(db)
|
||||||
data = {"title": "Updated title", "description": "Updated description"}
|
data = {"title": "Updated title", "description": "Updated description"}
|
||||||
@@ -125,7 +125,7 @@ def test_update_item_not_enough_permissions(
|
|||||||
|
|
||||||
|
|
||||||
def test_delete_item(
|
def test_delete_item(
|
||||||
client: TestClient, superuser_token_headers: dict, db: Session
|
client: TestClient, superuser_token_headers: dict[str, str], db: Session
|
||||||
) -> None:
|
) -> None:
|
||||||
item = create_random_item(db)
|
item = create_random_item(db)
|
||||||
response = client.delete(
|
response = client.delete(
|
||||||
@@ -138,7 +138,7 @@ def test_delete_item(
|
|||||||
|
|
||||||
|
|
||||||
def test_delete_item_not_found(
|
def test_delete_item_not_found(
|
||||||
client: TestClient, superuser_token_headers: dict, db: Session
|
client: TestClient, superuser_token_headers: dict[str, str], db: Session
|
||||||
) -> None:
|
) -> None:
|
||||||
response = client.delete(
|
response = client.delete(
|
||||||
f"{settings.API_V1_STR}/items/999",
|
f"{settings.API_V1_STR}/items/999",
|
||||||
@@ -150,7 +150,7 @@ def test_delete_item_not_found(
|
|||||||
|
|
||||||
|
|
||||||
def test_delete_item_not_enough_permissions(
|
def test_delete_item_not_enough_permissions(
|
||||||
client: TestClient, normal_user_token_headers: dict, db: Session
|
client: TestClient, normal_user_token_headers: dict[str, str], db: Session
|
||||||
) -> None:
|
) -> None:
|
||||||
item = create_random_item(db)
|
item = create_random_item(db)
|
||||||
response = client.delete(
|
response = client.delete(
|
||||||
|
@@ -1,7 +1,8 @@
|
|||||||
from app.utils import generate_password_reset_token
|
|
||||||
from fastapi.testclient import TestClient
|
from fastapi.testclient import TestClient
|
||||||
|
from pytest_mock import MockerFixture
|
||||||
|
|
||||||
from app.core.config import settings
|
from app.core.config import settings
|
||||||
|
from app.utils import generate_password_reset_token
|
||||||
|
|
||||||
|
|
||||||
def test_get_access_token(client: TestClient) -> None:
|
def test_get_access_token(client: TestClient) -> None:
|
||||||
@@ -38,7 +39,7 @@ def test_use_access_token(
|
|||||||
|
|
||||||
|
|
||||||
def test_recovery_password(
|
def test_recovery_password(
|
||||||
client: TestClient, normal_user_token_headers: dict[str, str], mocker
|
client: TestClient, normal_user_token_headers: dict[str, str], mocker: MockerFixture
|
||||||
) -> None:
|
) -> None:
|
||||||
mocker.patch("app.utils.send_email", return_value=None)
|
mocker.patch("app.utils.send_email", return_value=None)
|
||||||
mocker.patch("app.core.config.settings.EMAILS_ENABLED", True)
|
mocker.patch("app.core.config.settings.EMAILS_ENABLED", True)
|
||||||
|
@@ -1,4 +1,5 @@
|
|||||||
from fastapi.testclient import TestClient
|
from fastapi.testclient import TestClient
|
||||||
|
from pytest_mock import MockerFixture
|
||||||
from sqlmodel import Session
|
from sqlmodel import Session
|
||||||
|
|
||||||
from app import crud
|
from app import crud
|
||||||
@@ -30,7 +31,10 @@ def test_get_users_normal_user_me(
|
|||||||
|
|
||||||
|
|
||||||
def test_create_user_new_email(
|
def test_create_user_new_email(
|
||||||
client: TestClient, superuser_token_headers: dict, db: Session, mocker
|
client: TestClient,
|
||||||
|
superuser_token_headers: dict[str, str],
|
||||||
|
db: Session,
|
||||||
|
mocker: MockerFixture,
|
||||||
) -> None:
|
) -> None:
|
||||||
mocker.patch("app.utils.send_email")
|
mocker.patch("app.utils.send_email")
|
||||||
mocker.patch("app.core.config.settings.EMAILS_ENABLED", True)
|
mocker.patch("app.core.config.settings.EMAILS_ENABLED", True)
|
||||||
@@ -50,7 +54,7 @@ def test_create_user_new_email(
|
|||||||
|
|
||||||
|
|
||||||
def test_get_existing_user(
|
def test_get_existing_user(
|
||||||
client: TestClient, superuser_token_headers: dict, db: Session
|
client: TestClient, superuser_token_headers: dict[str, str], db: Session
|
||||||
) -> None:
|
) -> None:
|
||||||
username = random_email()
|
username = random_email()
|
||||||
password = random_lower_string()
|
password = random_lower_string()
|
||||||
@@ -107,7 +111,7 @@ def test_get_existing_user_permissions_error(
|
|||||||
|
|
||||||
|
|
||||||
def test_create_user_existing_username(
|
def test_create_user_existing_username(
|
||||||
client: TestClient, superuser_token_headers: dict, db: Session
|
client: TestClient, superuser_token_headers: dict[str, str], db: Session
|
||||||
) -> None:
|
) -> None:
|
||||||
username = random_email()
|
username = random_email()
|
||||||
# username = email
|
# username = email
|
||||||
@@ -140,7 +144,7 @@ def test_create_user_by_normal_user(
|
|||||||
|
|
||||||
|
|
||||||
def test_retrieve_users(
|
def test_retrieve_users(
|
||||||
client: TestClient, superuser_token_headers: dict, db: Session
|
client: TestClient, superuser_token_headers: dict[str, str], db: Session
|
||||||
) -> None:
|
) -> None:
|
||||||
username = random_email()
|
username = random_email()
|
||||||
password = random_lower_string()
|
password = random_lower_string()
|
||||||
@@ -179,7 +183,7 @@ def test_update_user_me(
|
|||||||
|
|
||||||
|
|
||||||
def test_update_password_me(
|
def test_update_password_me(
|
||||||
client: TestClient, superuser_token_headers: dict, db: Session
|
client: TestClient, superuser_token_headers: dict[str, str], db: Session
|
||||||
) -> None:
|
) -> None:
|
||||||
new_password = random_lower_string()
|
new_password = random_lower_string()
|
||||||
data = {
|
data = {
|
||||||
@@ -209,7 +213,7 @@ def test_update_password_me(
|
|||||||
|
|
||||||
|
|
||||||
def test_update_password_me_incorrect_password(
|
def test_update_password_me_incorrect_password(
|
||||||
client: TestClient, superuser_token_headers: dict, db: Session
|
client: TestClient, superuser_token_headers: dict[str, str], db: Session
|
||||||
) -> None:
|
) -> None:
|
||||||
new_password = random_lower_string()
|
new_password = random_lower_string()
|
||||||
data = {"current_password": new_password, "new_password": new_password}
|
data = {"current_password": new_password, "new_password": new_password}
|
||||||
@@ -224,7 +228,7 @@ def test_update_password_me_incorrect_password(
|
|||||||
|
|
||||||
|
|
||||||
def test_update_password_me_same_password_error(
|
def test_update_password_me_same_password_error(
|
||||||
client: TestClient, superuser_token_headers: dict, db: Session
|
client: TestClient, superuser_token_headers: dict[str, str], db: Session
|
||||||
) -> None:
|
) -> None:
|
||||||
data = {
|
data = {
|
||||||
"current_password": settings.FIRST_SUPERUSER_PASSWORD,
|
"current_password": settings.FIRST_SUPERUSER_PASSWORD,
|
||||||
@@ -242,7 +246,7 @@ def test_update_password_me_same_password_error(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_create_user_open(client: TestClient, mocker) -> None:
|
def test_create_user_open(client: TestClient, mocker: MockerFixture) -> None:
|
||||||
mocker.patch("app.core.config.settings.USERS_OPEN_REGISTRATION", True)
|
mocker.patch("app.core.config.settings.USERS_OPEN_REGISTRATION", True)
|
||||||
username = random_email()
|
username = random_email()
|
||||||
password = random_lower_string()
|
password = random_lower_string()
|
||||||
@@ -258,7 +262,9 @@ def test_create_user_open(client: TestClient, mocker) -> None:
|
|||||||
assert created_user["full_name"] == full_name
|
assert created_user["full_name"] == full_name
|
||||||
|
|
||||||
|
|
||||||
def test_create_user_open_forbidden_error(client: TestClient, mocker) -> None:
|
def test_create_user_open_forbidden_error(
|
||||||
|
client: TestClient, mocker: MockerFixture
|
||||||
|
) -> None:
|
||||||
mocker.patch("app.core.config.settings.USERS_OPEN_REGISTRATION", False)
|
mocker.patch("app.core.config.settings.USERS_OPEN_REGISTRATION", False)
|
||||||
username = random_email()
|
username = random_email()
|
||||||
password = random_lower_string()
|
password = random_lower_string()
|
||||||
@@ -272,7 +278,9 @@ def test_create_user_open_forbidden_error(client: TestClient, mocker) -> None:
|
|||||||
assert r.json()["detail"] == "Open user registration is forbidden on this server"
|
assert r.json()["detail"] == "Open user registration is forbidden on this server"
|
||||||
|
|
||||||
|
|
||||||
def test_create_user_open_already_exists_error(client: TestClient, mocker) -> None:
|
def test_create_user_open_already_exists_error(
|
||||||
|
client: TestClient, mocker: MockerFixture
|
||||||
|
) -> None:
|
||||||
mocker.patch("app.core.config.settings.USERS_OPEN_REGISTRATION", True)
|
mocker.patch("app.core.config.settings.USERS_OPEN_REGISTRATION", True)
|
||||||
password = random_lower_string()
|
password = random_lower_string()
|
||||||
full_name = random_lower_string()
|
full_name = random_lower_string()
|
||||||
@@ -382,6 +390,7 @@ def test_delete_user_current_super_user_error(
|
|||||||
client: TestClient, superuser_token_headers: dict[str, str], db: Session
|
client: TestClient, superuser_token_headers: dict[str, str], db: Session
|
||||||
) -> None:
|
) -> None:
|
||||||
super_user = crud.get_user_by_email(session=db, email=settings.FIRST_SUPERUSER)
|
super_user = crud.get_user_by_email(session=db, email=settings.FIRST_SUPERUSER)
|
||||||
|
assert super_user
|
||||||
user_id = super_user.id
|
user_id = super_user.id
|
||||||
|
|
||||||
r = client.delete(
|
r = client.delete(
|
||||||
|
@@ -13,7 +13,7 @@ from app.tests.utils.utils import get_superuser_token_headers
|
|||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session", autouse=True)
|
@pytest.fixture(scope="session", autouse=True)
|
||||||
def db() -> Generator:
|
def db() -> Generator[Session, None, None]:
|
||||||
with Session(engine) as session:
|
with Session(engine) as session:
|
||||||
init_db(session)
|
init_db(session)
|
||||||
yield session
|
yield session
|
||||||
@@ -25,7 +25,7 @@ def db() -> Generator:
|
|||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="module")
|
@pytest.fixture(scope="module")
|
||||||
def client() -> Generator:
|
def client() -> Generator[TestClient, None, None]:
|
||||||
with TestClient(app) as c:
|
with TestClient(app) as c:
|
||||||
yield c
|
yield c
|
||||||
|
|
||||||
|
@@ -1,21 +1,22 @@
|
|||||||
from unittest.mock import MagicMock
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
|
from pytest_mock import MockerFixture
|
||||||
from sqlmodel import select
|
from sqlmodel import select
|
||||||
|
|
||||||
from app.backend_pre_start import init, logger
|
from app.backend_pre_start import init, logger
|
||||||
|
|
||||||
|
|
||||||
def test_init_successful_connection(mocker):
|
def test_init_successful_connection(mocker: MockerFixture) -> None:
|
||||||
engine_mock = MagicMock()
|
engine_mock = MagicMock()
|
||||||
|
|
||||||
session_mock = MagicMock()
|
session_mock = MagicMock()
|
||||||
exec_mock = MagicMock(return_value=True)
|
exec_mock = MagicMock(return_value=True)
|
||||||
session_mock.configure_mock(**{'exec.return_value': exec_mock})
|
session_mock.configure_mock(**{"exec.return_value": exec_mock})
|
||||||
mocker.patch('sqlmodel.Session', return_value=session_mock)
|
mocker.patch("sqlmodel.Session", return_value=session_mock)
|
||||||
|
|
||||||
mocker.patch.object(logger, 'info')
|
mocker.patch.object(logger, "info")
|
||||||
mocker.patch.object(logger, 'error')
|
mocker.patch.object(logger, "error")
|
||||||
mocker.patch.object(logger, 'warn')
|
mocker.patch.object(logger, "warn")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
init(engine_mock)
|
init(engine_mock)
|
||||||
@@ -23,6 +24,10 @@ def test_init_successful_connection(mocker):
|
|||||||
except Exception:
|
except Exception:
|
||||||
connection_successful = False
|
connection_successful = False
|
||||||
|
|
||||||
assert connection_successful, "The database connection should be successful and not raise an exception."
|
assert (
|
||||||
|
connection_successful
|
||||||
|
), "The database connection should be successful and not raise an exception."
|
||||||
|
|
||||||
assert session_mock.exec.called_once_with(select(1)), "The session should execute a select statement once."
|
assert session_mock.exec.called_once_with(
|
||||||
|
select(1)
|
||||||
|
), "The session should execute a select statement once."
|
||||||
|
@@ -1,21 +1,22 @@
|
|||||||
from unittest.mock import MagicMock
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
|
from pytest_mock import MockerFixture
|
||||||
from sqlmodel import select
|
from sqlmodel import select
|
||||||
|
|
||||||
from app.celeryworker_pre_start import init, logger
|
from app.celeryworker_pre_start import init, logger
|
||||||
|
|
||||||
|
|
||||||
def test_init_successful_connection(mocker):
|
def test_init_successful_connection(mocker: MockerFixture) -> None:
|
||||||
engine_mock = MagicMock()
|
engine_mock = MagicMock()
|
||||||
|
|
||||||
session_mock = MagicMock()
|
session_mock = MagicMock()
|
||||||
exec_mock = MagicMock(return_value=True)
|
exec_mock = MagicMock(return_value=True)
|
||||||
session_mock.configure_mock(**{'exec.return_value': exec_mock})
|
session_mock.configure_mock(**{"exec.return_value": exec_mock})
|
||||||
mocker.patch('sqlmodel.Session', return_value=session_mock)
|
mocker.patch("sqlmodel.Session", return_value=session_mock)
|
||||||
|
|
||||||
mocker.patch.object(logger, 'info')
|
mocker.patch.object(logger, "info")
|
||||||
mocker.patch.object(logger, 'error')
|
mocker.patch.object(logger, "error")
|
||||||
mocker.patch.object(logger, 'warn')
|
mocker.patch.object(logger, "warn")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
init(engine_mock)
|
init(engine_mock)
|
||||||
@@ -23,6 +24,10 @@ def test_init_successful_connection(mocker):
|
|||||||
except Exception:
|
except Exception:
|
||||||
connection_successful = False
|
connection_successful = False
|
||||||
|
|
||||||
assert connection_successful, "The database connection should be successful and not raise an exception."
|
assert (
|
||||||
|
connection_successful
|
||||||
|
), "The database connection should be successful and not raise an exception."
|
||||||
|
|
||||||
assert session_mock.exec.called_once_with(select(1)), "The session should execute a select statement once."
|
assert session_mock.exec.called_once_with(
|
||||||
|
select(1)
|
||||||
|
), "The session should execute a select statement once."
|
||||||
|
@@ -1,21 +1,22 @@
|
|||||||
from unittest.mock import MagicMock
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
|
from pytest_mock import MockerFixture
|
||||||
from sqlmodel import select
|
from sqlmodel import select
|
||||||
|
|
||||||
from app.tests_pre_start import init, logger
|
from app.tests_pre_start import init, logger
|
||||||
|
|
||||||
|
|
||||||
def test_init_successful_connection(mocker):
|
def test_init_successful_connection(mocker: MockerFixture) -> None:
|
||||||
engine_mock = MagicMock()
|
engine_mock = MagicMock()
|
||||||
|
|
||||||
session_mock = MagicMock()
|
session_mock = MagicMock()
|
||||||
exec_mock = MagicMock(return_value=True)
|
exec_mock = MagicMock(return_value=True)
|
||||||
session_mock.configure_mock(**{'exec.return_value': exec_mock})
|
session_mock.configure_mock(**{"exec.return_value": exec_mock})
|
||||||
mocker.patch('sqlmodel.Session', return_value=session_mock)
|
mocker.patch("sqlmodel.Session", return_value=session_mock)
|
||||||
|
|
||||||
mocker.patch.object(logger, 'info')
|
mocker.patch.object(logger, "info")
|
||||||
mocker.patch.object(logger, 'error')
|
mocker.patch.object(logger, "error")
|
||||||
mocker.patch.object(logger, 'warn')
|
mocker.patch.object(logger, "warn")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
init(engine_mock)
|
init(engine_mock)
|
||||||
@@ -23,6 +24,10 @@ def test_init_successful_connection(mocker):
|
|||||||
except Exception:
|
except Exception:
|
||||||
connection_successful = False
|
connection_successful = False
|
||||||
|
|
||||||
assert connection_successful, "The database connection should be successful and not raise an exception."
|
assert (
|
||||||
|
connection_successful
|
||||||
|
), "The database connection should be successful and not raise an exception."
|
||||||
|
|
||||||
assert session_mock.exec.called_once_with(select(1)), "The session should execute a select statement once."
|
assert session_mock.exec.called_once_with(
|
||||||
|
select(1)
|
||||||
|
), "The session should execute a select statement once."
|
||||||
|
@@ -42,6 +42,8 @@ def authentication_token_from_email(
|
|||||||
user = crud.create_user(session=db, user_create=user_in_create)
|
user = crud.create_user(session=db, user_create=user_in_create)
|
||||||
else:
|
else:
|
||||||
user_in_update = UserUpdate(password=password)
|
user_in_update = UserUpdate(password=password)
|
||||||
|
if not user.id:
|
||||||
|
raise Exception("User id not set")
|
||||||
user = crud.update_user(session=db, user_id=user.id, user_in=user_in_update)
|
user = crud.update_user(session=db, user_id=user.id, user_in=user_in_update)
|
||||||
|
|
||||||
return user_authentication_headers(client=client, email=email, password=password)
|
return user_authentication_headers(client=client, email=email, password=password)
|
||||||
|
@@ -4,7 +4,7 @@ from datetime import datetime, timedelta
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
import emails
|
import emails # type: ignore
|
||||||
from jinja2 import Template
|
from jinja2 import Template
|
||||||
from jose import JWTError, jwt
|
from jose import JWTError, jwt
|
||||||
|
|
||||||
@@ -109,6 +109,6 @@ def generate_password_reset_token(email: str) -> str:
|
|||||||
def verify_password_reset_token(token: str) -> str | None:
|
def verify_password_reset_token(token: str) -> str | None:
|
||||||
try:
|
try:
|
||||||
decoded_token = jwt.decode(token, settings.SECRET_KEY, algorithms=["HS256"])
|
decoded_token = jwt.decode(token, settings.SECRET_KEY, algorithms=["HS256"])
|
||||||
return decoded_token["sub"]
|
return str(decoded_token["sub"])
|
||||||
except JWTError:
|
except JWTError:
|
||||||
return None
|
return None
|
||||||
|
@@ -3,7 +3,7 @@ import sentry_sdk
|
|||||||
from app.core.celery_app import celery_app
|
from app.core.celery_app import celery_app
|
||||||
from app.core.config import settings
|
from app.core.config import settings
|
||||||
|
|
||||||
sentry_sdk.init(dsn=settings.SENTRY_DSN)
|
sentry_sdk.init(dsn=str(settings.SENTRY_DSN))
|
||||||
|
|
||||||
|
|
||||||
@celery_app.task(acks_late=True)
|
@celery_app.task(acks_late=True)
|
||||||
|
@@ -35,6 +35,9 @@ mypy = "^1.8.0"
|
|||||||
ruff = "^0.2.2"
|
ruff = "^0.2.2"
|
||||||
pre-commit = "^3.6.2"
|
pre-commit = "^3.6.2"
|
||||||
pytest-mock = "^3.12.0"
|
pytest-mock = "^3.12.0"
|
||||||
|
types-python-jose = "^3.3.4.20240106"
|
||||||
|
types-passlib = "^1.7.7.20240106"
|
||||||
|
celery-types = "^0.22.0"
|
||||||
|
|
||||||
[tool.isort]
|
[tool.isort]
|
||||||
multi_line_output = 3
|
multi_line_output = 3
|
||||||
@@ -47,6 +50,7 @@ build-backend = "poetry.masonry.api"
|
|||||||
|
|
||||||
[tool.mypy]
|
[tool.mypy]
|
||||||
strict = true
|
strict = true
|
||||||
|
exclude = ["venv", "alembic"]
|
||||||
|
|
||||||
[tool.ruff]
|
[tool.ruff]
|
||||||
target-version = "py310"
|
target-version = "py310"
|
||||||
|
@@ -3,6 +3,5 @@
|
|||||||
set -x
|
set -x
|
||||||
|
|
||||||
mypy app
|
mypy app
|
||||||
black app --check
|
ruff app
|
||||||
isort --recursive --check-only app
|
ruff format app --check
|
||||||
flake8
|
|
||||||
|
@@ -3,4 +3,5 @@ set -e
|
|||||||
|
|
||||||
python /app/app/tests_pre_start.py
|
python /app/app/tests_pre_start.py
|
||||||
|
|
||||||
|
bash ./scripts/lint.sh
|
||||||
bash ./scripts/test.sh "$@"
|
bash ./scripts/test.sh "$@"
|
||||||
|
Reference in New Issue
Block a user