✅ Add tests to raise coverage to at least 90% and fix recover password logic (#632)
Co-authored-by: Sebastián Ramírez <tiangolo@gmail.com>
This commit is contained in:
@@ -9,7 +9,7 @@ from app.api.deps import CurrentUser, SessionDep
|
||||
from app.core import security
|
||||
from app.core.config import settings
|
||||
from app.core.security import get_password_hash
|
||||
from app.models import Message, NewPassword, Token, UserOut
|
||||
from app.models import Message, NewPassword, Token, User, UserOut
|
||||
from app.utils import (
|
||||
generate_password_reset_token,
|
||||
send_reset_password_email,
|
||||
@@ -73,10 +73,10 @@ def reset_password(session: SessionDep, body: NewPassword) -> Message:
|
||||
"""
|
||||
Reset password
|
||||
"""
|
||||
email = verify_password_reset_token(token=body.token)
|
||||
if not email:
|
||||
user_id = verify_password_reset_token(token=body.token)
|
||||
if not user_id:
|
||||
raise HTTPException(status_code=400, detail="Invalid token")
|
||||
user = crud.get_user_by_email(session=session, email=email)
|
||||
user = session.get(User, int(user_id))
|
||||
if not user:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
|
@@ -144,8 +144,7 @@ def read_user_by_id(
|
||||
return user
|
||||
if not current_user.is_superuser:
|
||||
raise HTTPException(
|
||||
# TODO: Review status code
|
||||
status_code=400,
|
||||
status_code=403,
|
||||
detail="The user doesn't have enough privileges",
|
||||
)
|
||||
return user
|
||||
@@ -194,5 +193,5 @@ def delete_user(
|
||||
return Message(message="User deleted successfully")
|
||||
elif user == current_user and current_user.is_superuser:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="Super users are not allowed to delete themselves"
|
||||
status_code=403, detail="Super users are not allowed to delete themselves"
|
||||
)
|
||||
|
@@ -1,5 +1,6 @@
|
||||
import logging
|
||||
|
||||
from sqlalchemy import Engine
|
||||
from sqlmodel import Session, select
|
||||
from tenacity import after_log, before_log, retry, stop_after_attempt, wait_fixed
|
||||
|
||||
@@ -18,9 +19,9 @@ wait_seconds = 1
|
||||
before=before_log(logger, logging.INFO),
|
||||
after=after_log(logger, logging.WARN),
|
||||
)
|
||||
def init() -> None:
|
||||
def init(db_engine: Engine) -> None:
|
||||
try:
|
||||
with Session(engine) as session:
|
||||
with Session(db_engine) as session:
|
||||
# Try to create session to check if DB is awake
|
||||
session.exec(select(1))
|
||||
except Exception as e:
|
||||
@@ -30,7 +31,7 @@ def init() -> None:
|
||||
|
||||
def main() -> None:
|
||||
logger.info("Initializing service")
|
||||
init()
|
||||
init(engine)
|
||||
logger.info("Service finished initializing")
|
||||
|
||||
|
||||
|
@@ -1,5 +1,6 @@
|
||||
import logging
|
||||
|
||||
from sqlalchemy import Engine
|
||||
from sqlmodel import Session, select
|
||||
from tenacity import after_log, before_log, retry, stop_after_attempt, wait_fixed
|
||||
|
||||
@@ -18,10 +19,10 @@ wait_seconds = 1
|
||||
before=before_log(logger, logging.INFO),
|
||||
after=after_log(logger, logging.WARN),
|
||||
)
|
||||
def init() -> None:
|
||||
def init(db_engine: Engine) -> None:
|
||||
try:
|
||||
# Try to create session to check if DB is awake
|
||||
with Session(engine) as session:
|
||||
with Session(db_engine) as session:
|
||||
session.exec(select(1))
|
||||
except Exception as e:
|
||||
logger.error(e)
|
||||
@@ -30,7 +31,7 @@ def init() -> None:
|
||||
|
||||
def main() -> None:
|
||||
logger.info("Initializing service")
|
||||
init()
|
||||
init(engine)
|
||||
logger.info("Service finished initializing")
|
||||
|
||||
|
||||
|
@@ -47,6 +47,7 @@ def authenticate(*, session: Session, email: str, password: str) -> User | None:
|
||||
return None
|
||||
return db_user
|
||||
|
||||
|
||||
def create_item(*, session: Session, item_in: ItemCreate, owner_id: int) -> Item:
|
||||
db_item = Item.model_validate(item_in, update={"owner_id": owner_id})
|
||||
session.add(db_item)
|
||||
|
@@ -36,3 +36,127 @@ def test_read_item(
|
||||
assert content["description"] == item.description
|
||||
assert content["id"] == item.id
|
||||
assert content["owner_id"] == item.owner_id
|
||||
|
||||
|
||||
def test_read_item_not_found(
|
||||
client: TestClient, superuser_token_headers: dict, db: Session
|
||||
) -> None:
|
||||
response = client.get(
|
||||
f"{settings.API_V1_STR}/items/999",
|
||||
headers=superuser_token_headers,
|
||||
)
|
||||
assert response.status_code == 404
|
||||
content = response.json()
|
||||
assert content["detail"] == "Item not found"
|
||||
|
||||
|
||||
def test_read_item_not_enough_permissions(
|
||||
client: TestClient, normal_user_token_headers: dict, db: Session
|
||||
) -> None:
|
||||
item = create_random_item(db)
|
||||
response = client.get(
|
||||
f"{settings.API_V1_STR}/items/{item.id}",
|
||||
headers=normal_user_token_headers,
|
||||
)
|
||||
assert response.status_code == 400
|
||||
content = response.json()
|
||||
assert content["detail"] == "Not enough permissions"
|
||||
|
||||
|
||||
def test_read_items(
|
||||
client: TestClient, superuser_token_headers: dict, db: Session
|
||||
) -> None:
|
||||
create_random_item(db)
|
||||
create_random_item(db)
|
||||
response = client.get(
|
||||
f"{settings.API_V1_STR}/items/",
|
||||
headers=superuser_token_headers,
|
||||
)
|
||||
assert response.status_code == 200
|
||||
content = response.json()
|
||||
assert len(content["data"]) >= 2
|
||||
|
||||
|
||||
def test_update_item(
|
||||
client: TestClient, superuser_token_headers: dict, db: Session
|
||||
) -> None:
|
||||
item = create_random_item(db)
|
||||
data = {"title": "Updated title", "description": "Updated description"}
|
||||
response = client.put(
|
||||
f"{settings.API_V1_STR}/items/{item.id}",
|
||||
headers=superuser_token_headers,
|
||||
json=data,
|
||||
)
|
||||
assert response.status_code == 200
|
||||
content = response.json()
|
||||
assert content["title"] == data["title"]
|
||||
assert content["description"] == data["description"]
|
||||
assert content["id"] == item.id
|
||||
assert content["owner_id"] == item.owner_id
|
||||
|
||||
|
||||
def test_update_item_not_found(
|
||||
client: TestClient, superuser_token_headers: dict, db: Session
|
||||
) -> None:
|
||||
data = {"title": "Updated title", "description": "Updated description"}
|
||||
response = client.put(
|
||||
f"{settings.API_V1_STR}/items/999",
|
||||
headers=superuser_token_headers,
|
||||
json=data,
|
||||
)
|
||||
assert response.status_code == 404
|
||||
content = response.json()
|
||||
assert content["detail"] == "Item not found"
|
||||
|
||||
|
||||
def test_update_item_not_enough_permissions(
|
||||
client: TestClient, normal_user_token_headers: dict, db: Session
|
||||
) -> None:
|
||||
item = create_random_item(db)
|
||||
data = {"title": "Updated title", "description": "Updated description"}
|
||||
response = client.put(
|
||||
f"{settings.API_V1_STR}/items/{item.id}",
|
||||
headers=normal_user_token_headers,
|
||||
json=data,
|
||||
)
|
||||
assert response.status_code == 400
|
||||
content = response.json()
|
||||
assert content["detail"] == "Not enough permissions"
|
||||
|
||||
|
||||
def test_delete_item(
|
||||
client: TestClient, superuser_token_headers: dict, db: Session
|
||||
) -> None:
|
||||
item = create_random_item(db)
|
||||
response = client.delete(
|
||||
f"{settings.API_V1_STR}/items/{item.id}",
|
||||
headers=superuser_token_headers,
|
||||
)
|
||||
assert response.status_code == 200
|
||||
content = response.json()
|
||||
assert content["message"] == "Item deleted successfully"
|
||||
|
||||
|
||||
def test_delete_item_not_found(
|
||||
client: TestClient, superuser_token_headers: dict, db: Session
|
||||
) -> None:
|
||||
response = client.delete(
|
||||
f"{settings.API_V1_STR}/items/999",
|
||||
headers=superuser_token_headers,
|
||||
)
|
||||
assert response.status_code == 404
|
||||
content = response.json()
|
||||
assert content["detail"] == "Item not found"
|
||||
|
||||
|
||||
def test_delete_item_not_enough_permissions(
|
||||
client: TestClient, normal_user_token_headers: dict, db: Session
|
||||
) -> None:
|
||||
item = create_random_item(db)
|
||||
response = client.delete(
|
||||
f"{settings.API_V1_STR}/items/{item.id}",
|
||||
headers=normal_user_token_headers,
|
||||
)
|
||||
assert response.status_code == 400
|
||||
content = response.json()
|
||||
assert content["detail"] == "Not enough permissions"
|
||||
|
@@ -15,6 +15,15 @@ def test_get_access_token(client: TestClient) -> None:
|
||||
assert tokens["access_token"]
|
||||
|
||||
|
||||
def test_get_access_token_incorrect_password(client: TestClient) -> None:
|
||||
login_data = {
|
||||
"username": settings.FIRST_SUPERUSER,
|
||||
"password": "incorrect",
|
||||
}
|
||||
r = client.post(f"{settings.API_V1_STR}/login/access-token", data=login_data)
|
||||
assert r.status_code == 400
|
||||
|
||||
|
||||
def test_use_access_token(
|
||||
client: TestClient, superuser_token_headers: dict[str, str]
|
||||
) -> None:
|
||||
@@ -25,3 +34,64 @@ def test_use_access_token(
|
||||
result = r.json()
|
||||
assert r.status_code == 200
|
||||
assert "email" in result
|
||||
|
||||
|
||||
def test_recovery_password(
|
||||
client: TestClient, normal_user_token_headers: dict[str, str], mocker
|
||||
) -> None:
|
||||
mocker.patch("app.utils.send_reset_password_email", return_value=None)
|
||||
mocker.patch("app.utils.send_email", return_value=None)
|
||||
email = "test@example.com"
|
||||
r = client.post(
|
||||
f"{settings.API_V1_STR}/password-recovery/{email}",
|
||||
headers=normal_user_token_headers,
|
||||
)
|
||||
assert r.status_code == 200
|
||||
assert r.json() == {"message": "Password recovery email sent"}
|
||||
|
||||
|
||||
def test_recovery_password_user_not_exits(
|
||||
client: TestClient, normal_user_token_headers: dict[str, str]
|
||||
) -> None:
|
||||
email = "jVgQr@example.com"
|
||||
r = client.post(
|
||||
f"{settings.API_V1_STR}/password-recovery/{email}",
|
||||
headers=normal_user_token_headers,
|
||||
)
|
||||
assert r.status_code == 404
|
||||
|
||||
|
||||
def test_reset_password(
|
||||
client: TestClient, superuser_token_headers: dict[str, str]
|
||||
) -> None:
|
||||
login_data = {
|
||||
"username": settings.FIRST_SUPERUSER,
|
||||
"password": settings.FIRST_SUPERUSER_PASSWORD,
|
||||
}
|
||||
r = client.post(f"{settings.API_V1_STR}/login/access-token", data=login_data)
|
||||
token = r.json().get("access_token")
|
||||
|
||||
data = {"new_password": "changethis", "token": token}
|
||||
r = client.post(
|
||||
f"{settings.API_V1_STR}/reset-password/",
|
||||
headers=superuser_token_headers,
|
||||
json=data
|
||||
)
|
||||
assert r.status_code == 200
|
||||
assert r.json() == {"message": "Password updated successfully"}
|
||||
|
||||
|
||||
def test_reset_password_invalid_token(
|
||||
client: TestClient, superuser_token_headers: dict[str, str]
|
||||
) -> None:
|
||||
data = {"new_password": "changethis", "token": "invalid"}
|
||||
r = client.post(
|
||||
f"{settings.API_V1_STR}/reset-password/",
|
||||
headers=superuser_token_headers,
|
||||
json=data
|
||||
)
|
||||
response = r.json()
|
||||
|
||||
assert "detail" in response
|
||||
assert r.status_code == 400
|
||||
assert response["detail"] == "Invalid token"
|
||||
|
@@ -30,8 +30,10 @@ def test_get_users_normal_user_me(
|
||||
|
||||
|
||||
def test_create_user_new_email(
|
||||
client: TestClient, superuser_token_headers: dict, db: Session
|
||||
client: TestClient, superuser_token_headers: dict, db: Session, mocker
|
||||
) -> None:
|
||||
mocker.patch("app.utils.send_new_account_email")
|
||||
mocker.patch("app.core.config.settings.EMAILS_ENABLED", True)
|
||||
username = random_email()
|
||||
password = random_lower_string()
|
||||
data = {"email": username, "password": password}
|
||||
@@ -66,6 +68,46 @@ def test_get_existing_user(
|
||||
assert existing_user.email == api_user["email"]
|
||||
|
||||
|
||||
def test_get_existing_user_current_user(
|
||||
client: TestClient, db: Session
|
||||
) -> None:
|
||||
username = random_email()
|
||||
password = random_lower_string()
|
||||
user_in = UserCreate(email=username, password=password)
|
||||
user = crud.create_user(session=db, user_create=user_in)
|
||||
user_id = user.id
|
||||
|
||||
login_data = {
|
||||
"username": username,
|
||||
"password": password,
|
||||
}
|
||||
r = client.post(f"{settings.API_V1_STR}/login/access-token", data=login_data)
|
||||
tokens = r.json()
|
||||
a_token = tokens["access_token"]
|
||||
headers = {"Authorization": f"Bearer {a_token}"}
|
||||
|
||||
r = client.get(
|
||||
f"{settings.API_V1_STR}/users/{user_id}",
|
||||
headers=headers,
|
||||
)
|
||||
assert 200 <= r.status_code < 300
|
||||
api_user = r.json()
|
||||
existing_user = crud.get_user_by_email(session=db, email=username)
|
||||
assert existing_user
|
||||
assert existing_user.email == api_user["email"]
|
||||
|
||||
|
||||
def test_get_existing_user_permissions_error(
|
||||
client: TestClient, normal_user_token_headers: dict[str, str], db: Session
|
||||
) -> None:
|
||||
r = client.get(
|
||||
f"{settings.API_V1_STR}/users/999999",
|
||||
headers=normal_user_token_headers,
|
||||
)
|
||||
assert r.status_code == 403
|
||||
assert r.json() == {"detail": "The user doesn't have enough privileges"}
|
||||
|
||||
|
||||
def test_create_user_existing_username(
|
||||
client: TestClient, superuser_token_headers: dict, db: Session
|
||||
) -> None:
|
||||
@@ -119,3 +161,223 @@ def test_retrieve_users(
|
||||
assert "count" in all_users
|
||||
for item in all_users["data"]:
|
||||
assert "email" in item
|
||||
|
||||
|
||||
def test_update_user_me(
|
||||
client: TestClient, normal_user_token_headers: dict[str, str], db: Session
|
||||
) -> None:
|
||||
full_name = "Updated Name"
|
||||
email = "updated email"
|
||||
data = {"full_name": full_name, "email": email}
|
||||
r = client.patch(
|
||||
f"{settings.API_V1_STR}/users/me",
|
||||
headers=normal_user_token_headers,
|
||||
json=data,
|
||||
)
|
||||
assert r.status_code == 200
|
||||
updated_user = r.json()
|
||||
assert updated_user["email"] == email
|
||||
assert updated_user["full_name"] == full_name
|
||||
|
||||
|
||||
def test_update_password_me(
|
||||
client: TestClient, superuser_token_headers: dict, db: Session
|
||||
) -> None:
|
||||
new_password = random_lower_string()
|
||||
data = {"current_password": settings.FIRST_SUPERUSER_PASSWORD, "new_password": new_password}
|
||||
r = client.patch(
|
||||
f"{settings.API_V1_STR}/users/me/password",
|
||||
headers=superuser_token_headers,
|
||||
json=data,
|
||||
)
|
||||
assert r.status_code == 200
|
||||
updated_user = r.json()
|
||||
assert updated_user["message"] == "Password updated successfully"
|
||||
|
||||
# Revert to the old password to keep consistency in test
|
||||
old_data = {"current_password": new_password, "new_password": settings.FIRST_SUPERUSER_PASSWORD}
|
||||
r = client.patch(
|
||||
f"{settings.API_V1_STR}/users/me/password",
|
||||
headers=superuser_token_headers,
|
||||
json=old_data,
|
||||
)
|
||||
assert r.status_code == 200
|
||||
|
||||
|
||||
def test_update_password_me_incorrect_password(
|
||||
client: TestClient, superuser_token_headers: dict, db: Session
|
||||
) -> None:
|
||||
new_password = random_lower_string()
|
||||
data = {"current_password": new_password, "new_password": new_password}
|
||||
r = client.patch(
|
||||
f"{settings.API_V1_STR}/users/me/password",
|
||||
headers=superuser_token_headers,
|
||||
json=data,
|
||||
)
|
||||
assert r.status_code == 400
|
||||
updated_user = r.json()
|
||||
assert updated_user["detail"] == "Incorrect password"
|
||||
|
||||
|
||||
def test_update_password_me_same_password_error(
|
||||
client: TestClient, superuser_token_headers: dict, db: Session
|
||||
) -> None:
|
||||
data = {"current_password": settings.FIRST_SUPERUSER_PASSWORD, "new_password": settings.FIRST_SUPERUSER_PASSWORD}
|
||||
r = client.patch(
|
||||
f"{settings.API_V1_STR}/users/me/password",
|
||||
headers=superuser_token_headers,
|
||||
json=data,
|
||||
)
|
||||
assert r.status_code == 400
|
||||
updated_user = r.json()
|
||||
assert updated_user["detail"] == "New password cannot be the same as the current one"
|
||||
|
||||
|
||||
def test_create_user_open(
|
||||
client: TestClient, mocker
|
||||
) -> None:
|
||||
mocker.patch("app.core.config.settings.USERS_OPEN_REGISTRATION", True)
|
||||
username = random_email()
|
||||
password = random_lower_string()
|
||||
full_name = random_lower_string()
|
||||
data = {"email": username, "password": password, "full_name": full_name}
|
||||
r = client.post(
|
||||
f"{settings.API_V1_STR}/users/open",
|
||||
json=data,
|
||||
)
|
||||
assert r.status_code == 200
|
||||
created_user = r.json()
|
||||
assert created_user["email"] == username
|
||||
assert created_user["full_name"] == full_name
|
||||
|
||||
|
||||
def test_create_user_open_forbidden_error(
|
||||
client: TestClient, mocker
|
||||
) -> None:
|
||||
mocker.patch("app.core.config.settings.USERS_OPEN_REGISTRATION", False)
|
||||
username = random_email()
|
||||
password = random_lower_string()
|
||||
full_name = random_lower_string()
|
||||
data = {"email": username, "password": password, "full_name": full_name}
|
||||
r = client.post(
|
||||
f"{settings.API_V1_STR}/users/open",
|
||||
json=data,
|
||||
)
|
||||
assert r.status_code == 403
|
||||
assert r.json()["detail"] == "Open user registration is forbidden on this server"
|
||||
|
||||
|
||||
def test_create_user_open_already_exists_error(
|
||||
client: TestClient, mocker
|
||||
) -> None:
|
||||
mocker.patch("app.core.config.settings.USERS_OPEN_REGISTRATION", True)
|
||||
password = random_lower_string()
|
||||
full_name = random_lower_string()
|
||||
data = {"email": settings.FIRST_SUPERUSER, "password": password, "full_name": full_name}
|
||||
r = client.post(
|
||||
f"{settings.API_V1_STR}/users/open",
|
||||
json=data,
|
||||
)
|
||||
assert r.status_code == 400
|
||||
assert r.json()["detail"] == "The user with this username already exists in the system"
|
||||
|
||||
|
||||
def test_update_user(
|
||||
client: TestClient, superuser_token_headers: dict[str, str], db: Session
|
||||
) -> None:
|
||||
username = random_email()
|
||||
password = random_lower_string()
|
||||
user_in = UserCreate(email=username, password=password)
|
||||
user = crud.create_user(session=db, user_create=user_in)
|
||||
|
||||
data = {"full_name": "Updated_full_name"}
|
||||
r = client.patch(
|
||||
f"{settings.API_V1_STR}/users/{user.id}",
|
||||
headers=superuser_token_headers,
|
||||
json=data,
|
||||
)
|
||||
assert r.status_code == 200
|
||||
updated_user = r.json()
|
||||
assert updated_user["full_name"] == "Updated_full_name"
|
||||
|
||||
|
||||
def test_update_user_not_exists(
|
||||
client: TestClient, superuser_token_headers: dict[str, str], db: Session
|
||||
) -> None:
|
||||
data = {"full_name": "Updated_full_name"}
|
||||
r = client.patch(
|
||||
f"{settings.API_V1_STR}/users/99999999",
|
||||
headers=superuser_token_headers,
|
||||
json=data,
|
||||
)
|
||||
assert r.status_code == 404
|
||||
assert r.json()["detail"] == "The user with this username does not exist in the system"
|
||||
|
||||
|
||||
def test_delete_user_super_user(
|
||||
client: TestClient, superuser_token_headers: dict[str, str], db: Session
|
||||
) -> None:
|
||||
username = random_email()
|
||||
password = random_lower_string()
|
||||
user_in = UserCreate(email=username, password=password)
|
||||
user = crud.create_user(session=db, user_create=user_in)
|
||||
user_id = user.id
|
||||
r = client.delete(
|
||||
f"{settings.API_V1_STR}/users/{user_id}",
|
||||
headers=superuser_token_headers,
|
||||
)
|
||||
assert r.status_code == 200
|
||||
deleted_user = r.json()
|
||||
assert deleted_user["message"] == "User deleted successfully"
|
||||
|
||||
|
||||
def test_delete_user_current_user(
|
||||
client: TestClient, db: Session
|
||||
) -> None:
|
||||
username = random_email()
|
||||
password = random_lower_string()
|
||||
user_in = UserCreate(email=username, password=password)
|
||||
user = crud.create_user(session=db, user_create=user_in)
|
||||
user_id = user.id
|
||||
|
||||
login_data = {
|
||||
"username": username,
|
||||
"password": password,
|
||||
}
|
||||
r = client.post(f"{settings.API_V1_STR}/login/access-token", data=login_data)
|
||||
tokens = r.json()
|
||||
a_token = tokens["access_token"]
|
||||
headers = {"Authorization": f"Bearer {a_token}"}
|
||||
|
||||
r = client.delete(
|
||||
f"{settings.API_V1_STR}/users/{user_id}",
|
||||
headers=headers,
|
||||
)
|
||||
assert r.status_code == 200
|
||||
deleted_user = r.json()
|
||||
assert deleted_user["message"] == "User deleted successfully"
|
||||
|
||||
|
||||
def test_delete_user_not_found(
|
||||
client: TestClient, superuser_token_headers: dict[str, str], db: Session
|
||||
) -> None:
|
||||
r = client.delete(
|
||||
f"{settings.API_V1_STR}/users/99999999",
|
||||
headers=superuser_token_headers,
|
||||
)
|
||||
assert r.status_code == 404
|
||||
assert r.json()["detail"] == "User not found"
|
||||
|
||||
|
||||
def test_delete_user_current_super_user_error(
|
||||
client: TestClient, superuser_token_headers: dict[str, str], db: Session
|
||||
) -> None:
|
||||
super_user = crud.get_user_by_email(session=db, email=settings.FIRST_SUPERUSER)
|
||||
user_id = super_user.id
|
||||
|
||||
r = client.delete(
|
||||
f"{settings.API_V1_STR}/users/{user_id}",
|
||||
headers=superuser_token_headers,
|
||||
)
|
||||
assert r.status_code == 403
|
||||
assert r.json()["detail"] == "Super users are not allowed to delete themselves"
|
||||
|
0
backend/app/tests/scripts/__init__.py
Normal file
0
backend/app/tests/scripts/__init__.py
Normal file
28
backend/app/tests/scripts/test_backend_pre_start.py
Normal file
28
backend/app/tests/scripts/test_backend_pre_start.py
Normal file
@@ -0,0 +1,28 @@
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from sqlmodel import select
|
||||
|
||||
from app.backend_pre_start import init, logger
|
||||
|
||||
|
||||
def test_init_successful_connection(mocker):
|
||||
engine_mock = MagicMock()
|
||||
|
||||
session_mock = MagicMock()
|
||||
exec_mock = MagicMock(return_value=True)
|
||||
session_mock.configure_mock(**{'exec.return_value': exec_mock})
|
||||
mocker.patch('sqlmodel.Session', return_value=session_mock)
|
||||
|
||||
mocker.patch.object(logger, 'info')
|
||||
mocker.patch.object(logger, 'error')
|
||||
mocker.patch.object(logger, 'warn')
|
||||
|
||||
try:
|
||||
init(engine_mock)
|
||||
connection_successful = True
|
||||
except Exception:
|
||||
connection_successful = False
|
||||
|
||||
assert connection_successful, "The database connection should be successful and not raise an exception."
|
||||
|
||||
assert session_mock.exec.called_once_with(select(1)), "The session should execute a select statement once."
|
28
backend/app/tests/scripts/test_celery_pre_start.py
Normal file
28
backend/app/tests/scripts/test_celery_pre_start.py
Normal file
@@ -0,0 +1,28 @@
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from sqlmodel import select
|
||||
|
||||
from app.celeryworker_pre_start import init, logger
|
||||
|
||||
|
||||
def test_init_successful_connection(mocker):
|
||||
engine_mock = MagicMock()
|
||||
|
||||
session_mock = MagicMock()
|
||||
exec_mock = MagicMock(return_value=True)
|
||||
session_mock.configure_mock(**{'exec.return_value': exec_mock})
|
||||
mocker.patch('sqlmodel.Session', return_value=session_mock)
|
||||
|
||||
mocker.patch.object(logger, 'info')
|
||||
mocker.patch.object(logger, 'error')
|
||||
mocker.patch.object(logger, 'warn')
|
||||
|
||||
try:
|
||||
init(engine_mock)
|
||||
connection_successful = True
|
||||
except Exception:
|
||||
connection_successful = False
|
||||
|
||||
assert connection_successful, "The database connection should be successful and not raise an exception."
|
||||
|
||||
assert session_mock.exec.called_once_with(select(1)), "The session should execute a select statement once."
|
28
backend/app/tests/scripts/test_test_pre_start.py
Normal file
28
backend/app/tests/scripts/test_test_pre_start.py
Normal file
@@ -0,0 +1,28 @@
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from sqlmodel import select
|
||||
|
||||
from app.tests_pre_start import init, logger
|
||||
|
||||
|
||||
def test_init_successful_connection(mocker):
|
||||
engine_mock = MagicMock()
|
||||
|
||||
session_mock = MagicMock()
|
||||
exec_mock = MagicMock(return_value=True)
|
||||
session_mock.configure_mock(**{'exec.return_value': exec_mock})
|
||||
mocker.patch('sqlmodel.Session', return_value=session_mock)
|
||||
|
||||
mocker.patch.object(logger, 'info')
|
||||
mocker.patch.object(logger, 'error')
|
||||
mocker.patch.object(logger, 'warn')
|
||||
|
||||
try:
|
||||
init(engine_mock)
|
||||
connection_successful = True
|
||||
except Exception:
|
||||
connection_successful = False
|
||||
|
||||
assert connection_successful, "The database connection should be successful and not raise an exception."
|
||||
|
||||
assert session_mock.exec.called_once_with(select(1)), "The session should execute a select statement once."
|
@@ -1,5 +1,6 @@
|
||||
import logging
|
||||
|
||||
from sqlalchemy import Engine
|
||||
from sqlmodel import Session, select
|
||||
from tenacity import after_log, before_log, retry, stop_after_attempt, wait_fixed
|
||||
|
||||
@@ -18,10 +19,10 @@ wait_seconds = 1
|
||||
before=before_log(logger, logging.INFO),
|
||||
after=after_log(logger, logging.WARN),
|
||||
)
|
||||
def init() -> None:
|
||||
def init(db_engine: Engine) -> None:
|
||||
try:
|
||||
# Try to create session to check if DB is awake
|
||||
with Session(engine) as session:
|
||||
with Session(db_engine) as session:
|
||||
session.exec(select(1))
|
||||
except Exception as e:
|
||||
logger.error(e)
|
||||
@@ -30,7 +31,7 @@ def init() -> None:
|
||||
|
||||
def main() -> None:
|
||||
logger.info("Initializing service")
|
||||
init()
|
||||
init(engine)
|
||||
logger.info("Service finished initializing")
|
||||
|
||||
|
||||
|
Reference in New Issue
Block a user