fix(tests): fixed test runtime errors regarding database connection

This commit is contained in:
ribardej
2025-11-11 14:50:43 +01:00
parent f58083870f
commit 1da927dc07
3 changed files with 36 additions and 10 deletions

View File

@@ -28,7 +28,8 @@ from app.services.user_service import SECRET
from fastapi import FastAPI from fastapi import FastAPI
import sentry_sdk import sentry_sdk
from fastapi_users.db import SQLAlchemyUserDatabase from fastapi_users.db import SQLAlchemyUserDatabase
from app.core.db import async_session_maker from app.core.db import async_session_maker, engine
from app.core.base import Base
sentry_sdk.init( sentry_sdk.init(
dsn=os.getenv("SENTRY_DSN"), dsn=os.getenv("SENTRY_DSN"),
@@ -37,6 +38,14 @@ sentry_sdk.init(
fastApi = FastAPI() fastApi = FastAPI()
@fastApi.on_event("startup")
async def on_startup():
# Ensure DB schema is created for tests/dev
from sqlalchemy.ext.asyncio import AsyncEngine
from sqlalchemy import text
async with engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
# CORS for frontend dev server # CORS for frontend dev server
fastApi.add_middleware( fastApi.add_middleware(
CORSMiddleware, CORSMiddleware,

View File

@@ -1,27 +1,35 @@
import os import os
from urllib.parse import urlparse
from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker
from app.core.base import Base from app.core.base import Base
# Determine DATABASE_URL with sensible defaults for local testing
DATABASE_URL = os.getenv("DATABASE_URL") DATABASE_URL = os.getenv("DATABASE_URL")
if not DATABASE_URL: if not DATABASE_URL:
mariadb_host = os.getenv("MARIADB_HOST", "localhost") mariadb_host = os.getenv("MARIADB_HOST")
mariadb_port = os.getenv("MARIADB_PORT", "3306") mariadb_port = os.getenv("MARIADB_PORT", "3306")
mariadb_db = os.getenv("MARIADB_DB", "group_project") mariadb_db = os.getenv("MARIADB_DB")
mariadb_user = os.getenv("MARIADB_USER", "root") mariadb_user = os.getenv("MARIADB_USER")
mariadb_password = os.getenv("MARIADB_PASSWORD", "strongpassword") mariadb_password = os.getenv("MARIADB_PASSWORD")
if mariadb_host and mariadb_db and mariadb_user and mariadb_password: if mariadb_host and mariadb_db and mariadb_user and mariadb_password:
DATABASE_URL = f"mysql+asyncmy://{mariadb_user}:{mariadb_password}@{mariadb_host}:{mariadb_port}/{mariadb_db}" DATABASE_URL = f"mysql+asyncmy://{mariadb_user}:{mariadb_password}@{mariadb_host}:{mariadb_port}/{mariadb_db}"
else: else:
raise Exception("Only MariaDB is supported. Please set the DATABASE_URL environment variable.") # Default to local SQLite for tests/development when nothing is configured
DATABASE_URL = os.getenv("SQLITE_URL", "sqlite+aiosqlite:///./test.db")
# Load all models to register them # Load all models to register them
from app.models.user import User from app.models.user import User
from app.models.transaction import Transaction from app.models.transaction import Transaction
from app.models.categories import Category from app.models.categories import Category
host_env = os.getenv("MARIADB_HOST", "localhost") # Configure connect args based on backend
ssl_enabled = host_env not in {"localhost", "127.0.0.1"} parsed = urlparse(DATABASE_URL)
connect_args = {"ssl": {"ssl": True}} if ssl_enabled else {} scheme = parsed.scheme
connect_args = {}
if scheme.startswith("mysql"):
host_env = os.getenv("MARIADB_HOST", parsed.hostname or "localhost")
ssl_enabled = host_env not in {"localhost", "127.0.0.1"}
connect_args = {"ssl": {"ssl": True}} if ssl_enabled else {}
engine = create_async_engine( engine = create_async_engine(
DATABASE_URL, DATABASE_URL,

View File

@@ -3,11 +3,20 @@ from fastapi import Depends
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from fastapi_users.db import SQLAlchemyUserDatabase from fastapi_users.db import SQLAlchemyUserDatabase
from ..core.db import async_session_maker from ..core.db import async_session_maker, engine
from ..core.base import Base
from ..models.user import User, OAuthAccount from ..models.user import User, OAuthAccount
_initialized = False
async def get_async_session() -> AsyncGenerator[AsyncSession, None]: async def get_async_session() -> AsyncGenerator[AsyncSession, None]:
global _initialized
if not _initialized:
# Lazily ensure tables exist; helpful for test runs without migrations
async with engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
_initialized = True
async with async_session_maker() as session: async with async_session_maker() as session:
yield session yield session