Merge pull request #50 from dat515-2025/merge/update_workers

feat(workers): update workers
This commit is contained in:
2025-11-12 00:42:16 +01:00
committed by GitHub
4 changed files with 125 additions and 147 deletions

View File

@@ -2,7 +2,7 @@ from datetime import datetime, timedelta
from typing import List, Optional from typing import List, Optional
import random import random
from fastapi import APIRouter, Depends, Response, status from fastapi import APIRouter, Depends
from pydantic import BaseModel, Field, conint, confloat, validator from pydantic import BaseModel, Field, conint, confloat, validator
from sqlalchemy import select from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
@@ -98,46 +98,19 @@ async def generate_mock_transactions(
return results return results
@router.get("/scrape")
async def scrape_mock_bank():
# 80% of the time: nothing to scrape
if random.random() < 0.8:
return []
@router.get("/scrape", response_model=Optional[TransactionRead]) transactions = []
async def scrape_mock_bank( count = random.randint(1, 10)
session: AsyncSession = Depends(get_async_session), for _ in range(count):
user: User = Depends(current_active_user), transactions.append({
): "amount": round(random.uniform(-200.0, 200.0), 2),
# 95% of the time: nothing to scrape "date": (datetime.utcnow().date() - timedelta(days=random.randint(0, 30))).isoformat(),
if random.random() < 0.95: "description": "Mock transaction",
return Response(status_code=status.HTTP_204_NO_CONTENT) })
# 5% chance: create a new transaction and return it return transactions
amount = round(random.uniform(-200.0, 200.0), 2)
tx_date = datetime.utcnow().date()
# Optionally attach a random category owned by this user (if any)
res = await session.execute(select(Category).where(Category.user_id == user.id))
user_categories = list(res.scalars())
chosen_categories = []
if user_categories:
chosen_categories = [random.choice(user_categories)]
# Build and persist transaction
tx = Transaction(
amount=amount,
description="Mock bank scrape",
user_id=user.id,
date=tx_date,
)
if chosen_categories:
tx.categories = chosen_categories
session.add(tx)
await session.commit()
await session.refresh(tx)
await session.refresh(tx, attribute_names=["categories"]) # ensure categories are loaded
return TransactionRead(
id=tx.id,
amount=float(tx.amount),
description=tx.description,
date=tx.date,
category_ids=[c.id for c in (tx.categories or [])],
)

View File

@@ -1,5 +1,7 @@
import os import os
from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from app.core.base import Base from app.core.base import Base
DATABASE_URL = os.getenv("DATABASE_URL") DATABASE_URL = os.getenv("DATABASE_URL")
@@ -23,6 +25,7 @@ host_env = os.getenv("MARIADB_HOST", "localhost")
ssl_enabled = host_env not in {"localhost", "127.0.0.1"} ssl_enabled = host_env not in {"localhost", "127.0.0.1"}
connect_args = {"ssl": {"ssl": True}} if ssl_enabled else {} connect_args = {"ssl": {"ssl": True}} if ssl_enabled else {}
# Async engine/session for the async parts of the app
engine = create_async_engine( engine = create_async_engine(
DATABASE_URL, DATABASE_URL,
pool_pre_ping=True, pool_pre_ping=True,
@@ -30,3 +33,13 @@ engine = create_async_engine(
connect_args=connect_args, connect_args=connect_args,
) )
async_session_maker = async_sessionmaker(engine, expire_on_commit=False) async_session_maker = async_sessionmaker(engine, expire_on_commit=False)
# Synchronous engine/session for sync utilities (e.g., bank_scraper)
# Derive the sync URL from the async one by swapping the asyncmy driver
# for pymysql; host, credentials and database name are shared.
SYNC_DATABASE_URL = DATABASE_URL.replace("+asyncmy", "+pymysql")
engine_sync = create_engine(
    SYNC_DATABASE_URL,
    pool_pre_ping=True,  # probe connections before use so stale/idle-closed connections are replaced
    echo=os.getenv("SQL_ECHO", "0") == "1",  # opt-in SQL statement logging via env flag
    # NOTE(review): reuses the connect_args dict built for the async (asyncmy)
    # engine above; confirm pymysql accepts the same {"ssl": {"ssl": True}} shape.
    connect_args=connect_args,
)
# Session factory for synchronous code paths; expire_on_commit=False keeps
# loaded attributes usable after commit, matching the async session factory.
sync_session_maker = sessionmaker(bind=engine_sync, expire_on_commit=False)

View File

@@ -7,7 +7,7 @@ from uuid import UUID
import httpx import httpx
from sqlalchemy import select from sqlalchemy import select
from app.core.db import async_session_maker from app.core.db import sync_session_maker
from app.models.transaction import Transaction from app.models.transaction import Transaction
from app.models.user import User from app.models.user import User
@@ -20,26 +20,78 @@ CERTS = (
) )
async def aload_ceska_sporitelna_transactions(user_id: str) -> None: def load_mock_bank_transactions(user_id: str) -> None:
try: try:
uid = UUID(str(user_id)) uid = UUID(str(user_id))
except Exception: except Exception:
logger.error("Invalid user_id provided to bank_scraper (async): %r", user_id) logger.error("Invalid user_id provided to bank_scraper (sync): %r", user_id)
return return
await _aload_ceska_sporitelna_transactions(uid) _load_mock_bank_transactions(uid)
async def aload_all_ceska_sporitelna_transactions() -> None: def load_all_mock_bank_transactions() -> None:
async with async_session_maker() as session: with sync_session_maker() as session:
result = await session.execute(select(User)) users = session.execute(select(User)).unique().scalars().all()
users = result.unique().scalars().all() logger.info("[BankScraper] Starting Mock Bank scrape for all users | count=%d", len(users))
processed = 0
for user in users:
try:
_load_mock_bank_transactions(user.id)
processed += 1
except Exception:
logger.exception("[BankScraper] Error scraping for user id=%s email=%s", user.id,
getattr(user, 'email', None))
logger.info("[BankScraper] Finished Mock Bank scrape for all users | processed=%d", processed)
def _load_mock_bank_transactions(user_id: UUID) -> None:
    """Fetch transactions from the local mock-bank endpoint and persist them.

    Verifies the user exists first (so rows are never written for unknown
    ids), then calls the mock-bank ``/scrape`` endpoint and inserts one
    ``Transaction`` per returned record in a single commit.

    Network failures are logged and swallowed, mirroring the CSAS scraper's
    handling of ``httpx.HTTPError``, so one user's failure cannot abort a
    bulk scrape.
    """
    with sync_session_maker() as session:
        user: User | None = (
            session.execute(select(User).where(User.id == user_id))
            .unique()
            .scalar_one_or_none()
        )
        if user is None:
            logger.warning("User not found for id=%s", user_id)
            return
        try:
            with httpx.Client() as client:
                response = client.get("http://127.0.0.1:8000/mock-bank/scrape")
        except httpx.HTTPError as e:
            # Consistent with _load_ceska_sporitelna_transactions: a network
            # failure is a logged no-op, not a crash of the worker task.
            logger.error(
                "[BankScraper] Mock bank request failed for user id=%s: %s",
                user_id,
                e,
            )
            return
        if response.status_code != httpx.codes.OK:
            return
        transactions = [
            Transaction(
                amount=transaction["amount"],
                description=transaction.get("description"),
                date=strptime(transaction["date"], "%Y-%m-%d"),
                user_id=user_id,
            )
            for transaction in response.json()
        ]
        session.add_all(transactions)
        session.commit()
def load_ceska_sporitelna_transactions(user_id: str) -> None:
    """Public sync entry point: parse *user_id* and scrape CSAS for that user.

    A malformed id is logged and ignored so a bad task argument never
    crashes the worker process.
    """
    try:
        parsed_id = UUID(str(user_id))
    except Exception:
        logger.error("Invalid user_id provided to bank_scraper (sync): %r", user_id)
        return
    _load_ceska_sporitelna_transactions(parsed_id)
def load_all_ceska_sporitelna_transactions() -> None:
with sync_session_maker() as session:
users = session.execute(select(User)).unique().scalars().all()
logger.info("[BankScraper] Starting CSAS scrape for all users | count=%d", len(users)) logger.info("[BankScraper] Starting CSAS scrape for all users | count=%d", len(users))
processed = 0 processed = 0
for user in users: for user in users:
try: try:
await _aload_ceska_sporitelna_transactions(user.id) _load_ceska_sporitelna_transactions(user.id)
processed += 1 processed += 1
except Exception: except Exception:
logger.exception("[BankScraper] Error scraping for user id=%s email=%s", user.id, logger.exception("[BankScraper] Error scraping for user id=%s email=%s", user.id,
@@ -47,10 +99,9 @@ async def aload_all_ceska_sporitelna_transactions() -> None:
logger.info("[BankScraper] Finished CSAS scrape for all users | processed=%d", processed) logger.info("[BankScraper] Finished CSAS scrape for all users | processed=%d", processed)
async def _aload_ceska_sporitelna_transactions(user_id: UUID) -> None: def _load_ceska_sporitelna_transactions(user_id: UUID) -> None:
async with (async_session_maker() as session): with sync_session_maker() as session:
result = await session.execute(select(User).where(User.id == user_id)) user: User | None = session.execute(select(User).where(User.id == user_id)).unique().scalar_one_or_none()
user: User = result.unique().scalar_one_or_none()
if user is None: if user is None:
logger.warning("User not found for id=%s", user_id) logger.warning("User not found for id=%s", user_id)
return return
@@ -65,8 +116,8 @@ async def _aload_ceska_sporitelna_transactions(user_id: UUID) -> None:
accounts = [] accounts = []
try: try:
async with httpx.AsyncClient(cert=CERTS, timeout=httpx.Timeout(20.0)) as client: with httpx.Client(cert=CERTS, timeout=httpx.Timeout(20.0)) as client:
response = await client.get( response = client.get(
"https://webapi.developers.erstegroup.com/api/csas/sandbox/v4/account-information/my/accounts?size=10&page=0&sort=iban&order=desc", "https://webapi.developers.erstegroup.com/api/csas/sandbox/v4/account-information/my/accounts?size=10&page=0&sort=iban&order=desc",
headers={ headers={
"Authorization": f"Bearer {cfg['access_token']}", "Authorization": f"Bearer {cfg['access_token']}",
@@ -77,7 +128,7 @@ async def _aload_ceska_sporitelna_transactions(user_id: UUID) -> None:
if response.status_code != httpx.codes.OK: if response.status_code != httpx.codes.OK:
return return
for account in response.json()["accounts"]: for account in response.json().get("accounts", []):
accounts.append(account) accounts.append(account)
except (httpx.HTTPError,) as e: except (httpx.HTTPError,) as e:
@@ -85,11 +136,13 @@ async def _aload_ceska_sporitelna_transactions(user_id: UUID) -> None:
return return
for account in accounts: for account in accounts:
id = account["id"] acc_id = account.get("id")
if not acc_id:
continue
url = f"https://webapi.developers.erstegroup.com/api/csas/sandbox/v4/account-information/my/accounts/{id}/transactions?size=100&page=0&sort=bookingdate&order=desc" url = f"https://webapi.developers.erstegroup.com/api/csas/sandbox/v4/account-information/my/accounts/{acc_id}/transactions?size=100&page=0&sort=bookingdate&order=desc"
async with httpx.AsyncClient(cert=CERTS) as client: with httpx.Client(cert=CERTS) as client:
response = await client.get( response = client.get(
url, url,
headers={ headers={
"Authorization": f"Bearer {cfg['access_token']}", "Authorization": f"Bearer {cfg['access_token']}",
@@ -100,7 +153,7 @@ async def _aload_ceska_sporitelna_transactions(user_id: UUID) -> None:
if response.status_code != httpx.codes.OK: if response.status_code != httpx.codes.OK:
continue continue
transactions = response.json()["transactions"] transactions = response.json().get("transactions", [])
for transaction in transactions: for transaction in transactions:
description = transaction.get("entryDetails", {}).get("transactionDetails", {}).get( description = transaction.get("entryDetails", {}).get("transactionDetails", {}).get(
@@ -108,9 +161,12 @@ async def _aload_ceska_sporitelna_transactions(user_id: UUID) -> None:
date_str = transaction.get("bookingDate", {}).get("date") date_str = transaction.get("bookingDate", {}).get("date")
date = strptime(date_str, "%Y-%m-%d") if date_str else None date = strptime(date_str, "%Y-%m-%d") if date_str else None
amount = transaction.get("amount", {}).get("value") amount = transaction.get("amount", {}).get("value")
if transaction.get("creditDebitIndicator") == "DBIT": if transaction.get("creditDebitIndicator") == "DBIT" and amount is not None:
amount = -abs(amount) amount = -abs(amount)
if amount is None:
continue
obj = Transaction( obj = Transaction(
amount=amount, amount=amount,
description=description, description=description,
@@ -118,7 +174,4 @@ async def _aload_ceska_sporitelna_transactions(user_id: UUID) -> None:
user_id=user_id, user_id=user_id,
) )
session.add(obj) session.add(obj)
await session.commit() session.commit()
pass
pass

View File

@@ -1,12 +1,10 @@
import logging import logging
import asyncio
import os import os
import smtplib import smtplib
from email.message import EmailMessage from email.message import EmailMessage
from celery import shared_task
import app.services.bank_scraper import app.services.bank_scraper
from app.celery_app import celery_app
logger = logging.getLogger("celery_tasks") logger = logging.getLogger("celery_tasks")
if not logger.handlers: if not logger.handlers:
@@ -15,73 +13,7 @@ if not logger.handlers:
logger.setLevel(logging.INFO) logger.setLevel(logging.INFO)
def run_coro(coro) -> None: @celery_app.task(name="workers.send_email")
"""Run an async coroutine in a fresh event loop without using run_until_complete.
Primary strategy runs in a new loop in the current thread. If that fails due to
debugger patches (e.g., Bad file descriptor from pydevd_nest_asyncio), fall back
to running in a dedicated thread with its own event loop.
"""
import threading
def _cleanup_loop(loop):
try:
pending = [t for t in asyncio.all_tasks(loop) if not t.done()]
for t in pending:
t.cancel()
if pending:
loop.run_until_complete(asyncio.gather(*pending, return_exceptions=True))
except Exception:
pass
finally:
try:
loop.close()
finally:
asyncio.set_event_loop(None)
# First attempt: Run in current thread with a fresh event loop
try:
loop = asyncio.get_event_loop_policy().new_event_loop()
try:
asyncio.set_event_loop(loop)
task = loop.create_task(coro)
task.add_done_callback(lambda _t: loop.stop())
loop.run_forever()
exc = task.exception()
if exc:
raise exc
return
finally:
_cleanup_loop(loop)
except OSError as e:
logger.warning("run_coro primary strategy failed (%s). Falling back to thread runner.", e)
except Exception:
# For any other unexpected errors, try thread fallback as well
logger.exception("run_coro primary strategy raised; attempting thread fallback")
# Fallback: Run in a dedicated thread with its own event loop
error = {"exc": None}
def _thread_target():
loop = asyncio.new_event_loop()
try:
asyncio.set_event_loop(loop)
task = loop.create_task(coro)
task.add_done_callback(lambda _t: loop.stop())
loop.run_forever()
exc = task.exception()
if exc:
error["exc"] = exc
finally:
_cleanup_loop(loop)
th = threading.Thread(target=_thread_target, name="celery-async-runner", daemon=True)
th.start()
th.join()
if error["exc"] is not None:
raise error["exc"]
@shared_task(name="workers.send_email")
def send_email(to: str, subject: str, body: str) -> None: def send_email(to: str, subject: str, body: str) -> None:
if not (to and subject and body): if not (to and subject and body):
logger.error("Email task missing fields. to=%r subject=%r body_len=%r", to, subject, len(body) if body else 0) logger.error("Email task missing fields. to=%r subject=%r body_len=%r", to, subject, len(body) if body else 0)
@@ -128,20 +60,27 @@ def send_email(to: str, subject: str, body: str) -> None:
host, port, use_tls, use_ssl) host, port, use_tls, use_ssl)
@shared_task(name="workers.load_transactions") @celery_app.task(name="workers.load_transactions")
def load_transactions(user_id: str) -> None: def load_transactions(user_id: str) -> None:
if not user_id: if not user_id:
logger.error("Load transactions task missing user_id.") logger.error("Load transactions task missing user_id.")
return return
run_coro(app.services.bank_scraper.aload_ceska_sporitelna_transactions(user_id)) logger.info("[Celery] Starting load_transactions | user_id=%s", user_id)
try:
# Placeholder for real transaction loading logic # Use synchronous bank scraper functions directly, mirroring load_all_transactions
logger.info("[Celery] Transactions loaded for user_id=%s", user_id) app.services.bank_scraper.load_mock_bank_transactions(user_id)
app.services.bank_scraper.load_ceska_sporitelna_transactions(user_id)
except Exception:
logger.exception("Failed to load transactions for user_id=%s", user_id)
else:
logger.info("[Celery] Finished load_transactions | user_id=%s", user_id)
@shared_task(name="workers.load_all_transactions") @celery_app.task(name="workers.load_all_transactions")
def load_all_transactions() -> None: def load_all_transactions() -> None:
logger.info("[Celery] Starting load_all_transactions") logger.info("[Celery] Starting load_all_transactions")
run_coro(app.services.bank_scraper.aload_all_ceska_sporitelna_transactions()) # Now use synchronous bank scraper functions directly
app.services.bank_scraper.load_all_mock_bank_transactions()
app.services.bank_scraper.load_all_ceska_sporitelna_transactions()
logger.info("[Celery] Finished load_all_transactions") logger.info("[Celery] Finished load_all_transactions")