From 1da927dc07d81d28f885f743854b27b9b941505a Mon Sep 17 00:00:00 2001 From: ribardej Date: Tue, 11 Nov 2025 14:50:43 +0100 Subject: [PATCH] fix(tests): fixed test runtime errors regarding database connection --- 7project/backend/app/app.py | 11 ++++++++++- 7project/backend/app/core/db.py | 24 ++++++++++++++++-------- 7project/backend/app/services/db.py | 11 ++++++++++- 3 files changed, 36 insertions(+), 10 deletions(-) diff --git a/7project/backend/app/app.py b/7project/backend/app/app.py index 6bb0e5d..84414d7 100644 --- a/7project/backend/app/app.py +++ b/7project/backend/app/app.py @@ -28,7 +28,8 @@ from app.services.user_service import SECRET from fastapi import FastAPI import sentry_sdk from fastapi_users.db import SQLAlchemyUserDatabase -from app.core.db import async_session_maker +from app.core.db import async_session_maker, engine +from app.core.base import Base sentry_sdk.init( dsn=os.getenv("SENTRY_DSN"), @@ -37,6 +38,14 @@ sentry_sdk.init( fastApi = FastAPI() +@fastApi.on_event("startup") +async def on_startup(): + # Ensure DB schema is created for tests/dev + from sqlalchemy.ext.asyncio import AsyncEngine + from sqlalchemy import text + async with engine.begin() as conn: + await conn.run_sync(Base.metadata.create_all) + # CORS for frontend dev server fastApi.add_middleware( CORSMiddleware, diff --git a/7project/backend/app/core/db.py b/7project/backend/app/core/db.py index 1186352..5a20cb1 100644 --- a/7project/backend/app/core/db.py +++ b/7project/backend/app/core/db.py @@ -1,27 +1,35 @@ import os +from urllib.parse import urlparse from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker from app.core.base import Base +# Determine DATABASE_URL with sensible defaults for local testing DATABASE_URL = os.getenv("DATABASE_URL") if not DATABASE_URL: - mariadb_host = os.getenv("MARIADB_HOST", "localhost") + mariadb_host = os.getenv("MARIADB_HOST") mariadb_port = os.getenv("MARIADB_PORT", "3306") - mariadb_db = os.getenv("MARIADB_DB", "group_project") - mariadb_user = os.getenv("MARIADB_USER", "root") - mariadb_password = os.getenv("MARIADB_PASSWORD", "strongpassword") + mariadb_db = os.getenv("MARIADB_DB") + mariadb_user = os.getenv("MARIADB_USER") + mariadb_password = os.getenv("MARIADB_PASSWORD") if mariadb_host and mariadb_db and mariadb_user and mariadb_password: DATABASE_URL = f"mysql+asyncmy://{mariadb_user}:{mariadb_password}@{mariadb_host}:{mariadb_port}/{mariadb_db}" else: - raise Exception("Only MariaDB is supported. Please set the DATABASE_URL environment variable.") + # Default to local SQLite for tests/development when nothing is configured + DATABASE_URL = os.getenv("SQLITE_URL", "sqlite+aiosqlite:///./test.db") # Load all models to register them from app.models.user import User from app.models.transaction import Transaction from app.models.categories import Category -host_env = os.getenv("MARIADB_HOST", "localhost") -ssl_enabled = host_env not in {"localhost", "127.0.0.1"} -connect_args = {"ssl": {"ssl": True}} if ssl_enabled else {} +# Configure connect args based on backend +parsed = urlparse(DATABASE_URL) +scheme = parsed.scheme +connect_args = {} +if scheme.startswith("mysql"): + host_env = os.getenv("MARIADB_HOST", parsed.hostname or "localhost") + ssl_enabled = host_env not in {"localhost", "127.0.0.1"} + connect_args = {"ssl": {"ssl": True}} if ssl_enabled else {} engine = create_async_engine( DATABASE_URL, diff --git a/7project/backend/app/services/db.py b/7project/backend/app/services/db.py index 606af8d..65d1dae 100644 --- a/7project/backend/app/services/db.py +++ b/7project/backend/app/services/db.py @@ -3,11 +3,20 @@ from fastapi import Depends from sqlalchemy.ext.asyncio import AsyncSession from fastapi_users.db import SQLAlchemyUserDatabase -from ..core.db import async_session_maker +from ..core.db import async_session_maker, engine +from ..core.base import Base from ..models.user import User, OAuthAccount +_initialized = False + async def get_async_session() -> AsyncGenerator[AsyncSession, None]: + global _initialized + if not _initialized: + # Lazily ensure tables exist; helpful for test runs without migrations + async with engine.begin() as conn: + await conn.run_sync(Base.metadata.create_all) + _initialized = True async with async_session_maker() as session: yield session