From 19abbd18ccd28743d4eefb8279a96f9dd8d7862f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Luk=C3=A1=C5=A1=20Trkan?= Date: Wed, 24 Sep 2025 00:21:04 +0200 Subject: [PATCH] feat(infrastructure): allow ssl connection to database --- backend/app/db.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/backend/app/db.py b/backend/app/db.py index c25d356..c2b28fa 100644 --- a/backend/app/db.py +++ b/backend/app/db.py @@ -3,9 +3,8 @@ from typing import AsyncGenerator from fastapi import Depends from fastapi_users.db import SQLAlchemyBaseUserTableUUID, SQLAlchemyUserDatabase -from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine +from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine, async_sessionmaker from sqlalchemy.ext.declarative import DeclarativeMeta, declarative_base -from sqlalchemy.orm import sessionmaker DATABASE_URL = os.getenv("DATABASE_URL") if not DATABASE_URL: @@ -15,11 +14,9 @@ if not DATABASE_URL: mariadb_user = os.getenv("MARIADB_USER", "root") mariadb_password = os.getenv("MARIADB_PASSWORD", "strongpassword") #always use SSL except for localhost - i dont want to include certs - ssl_param = "?ssl=true" if mariadb_host != "localhost" else "" - if mariadb_host and mariadb_db and mariadb_user and mariadb_password: # Use MariaDB/MySQL over async driver - DATABASE_URL = f"mysql+asyncmy://{mariadb_user}:{mariadb_password}@{mariadb_host}:{mariadb_port}/{mariadb_db}{ssl_param}" + DATABASE_URL = f"mysql+asyncmy://{mariadb_user}:{mariadb_password}@{mariadb_host}:{mariadb_port}/{mariadb_db}" else: raise Exception("Only MariaDB is supported. Please set the DATABASE_URL environment variable.") @@ -30,12 +27,17 @@ class User(SQLAlchemyBaseUserTableUUID, Base): pass +# Nastavení connect_args pro SSL pouze pokud není localhost +ssl_enabled = os.getenv("MARIADB_HOST", "localhost") != "localhost" +connect_args = {"ssl": {"ssl": True}} if ssl_enabled else {} + engine = create_async_engine( DATABASE_URL, pool_pre_ping=True, echo=os.getenv("SQL_ECHO", "0") == "1", + connect_args=connect_args, ) -async_session_maker = sessionmaker(engine, class_=AsyncSession, expire_on_commit=False) +async_session_maker = async_sessionmaker(engine, expire_on_commit=False) async def create_db_and_tables():