mirror of
https://github.com/dat515-2025/Group-8.git
synced 2026-03-22 06:57:47 +01:00
Merge pull request #50 from dat515-2025/merge/update_workers
feat(workers): update workers
This commit is contained in:
@@ -2,7 +2,7 @@ from datetime import datetime, timedelta
|
||||
from typing import List, Optional
|
||||
import random
|
||||
|
||||
from fastapi import APIRouter, Depends, Response, status
|
||||
from fastapi import APIRouter, Depends
|
||||
from pydantic import BaseModel, Field, conint, confloat, validator
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
@@ -55,8 +55,8 @@ class GeneratedTransaction(BaseModel):
|
||||
|
||||
@router.post("/generate", response_model=List[GeneratedTransaction])
|
||||
async def generate_mock_transactions(
|
||||
options: GenerateOptions,
|
||||
user: User = Depends(current_active_user),
|
||||
options: GenerateOptions,
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
# Seed randomness per user to make results less erratic across multiple calls in quick succession
|
||||
seed = int(datetime.utcnow().timestamp()) ^ int(user.id)
|
||||
@@ -98,46 +98,19 @@ async def generate_mock_transactions(
|
||||
return results
|
||||
|
||||
|
||||
@router.get("/scrape")
|
||||
async def scrape_mock_bank():
|
||||
# 80% of the time: nothing to scrape
|
||||
if random.random() < 0.8:
|
||||
return []
|
||||
|
||||
@router.get("/scrape", response_model=Optional[TransactionRead])
|
||||
async def scrape_mock_bank(
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
# 95% of the time: nothing to scrape
|
||||
if random.random() < 0.95:
|
||||
return Response(status_code=status.HTTP_204_NO_CONTENT)
|
||||
transactions = []
|
||||
count = random.randint(1, 10)
|
||||
for _ in range(count):
|
||||
transactions.append({
|
||||
"amount": round(random.uniform(-200.0, 200.0), 2),
|
||||
"date": (datetime.utcnow().date() - timedelta(days=random.randint(0, 30))).isoformat(),
|
||||
"description": "Mock transaction",
|
||||
})
|
||||
|
||||
# 5% chance: create a new transaction and return it
|
||||
amount = round(random.uniform(-200.0, 200.0), 2)
|
||||
tx_date = datetime.utcnow().date()
|
||||
|
||||
# Optionally attach a random category owned by this user (if any)
|
||||
res = await session.execute(select(Category).where(Category.user_id == user.id))
|
||||
user_categories = list(res.scalars())
|
||||
chosen_categories = []
|
||||
if user_categories:
|
||||
chosen_categories = [random.choice(user_categories)]
|
||||
|
||||
# Build and persist transaction
|
||||
tx = Transaction(
|
||||
amount=amount,
|
||||
description="Mock bank scrape",
|
||||
user_id=user.id,
|
||||
date=tx_date,
|
||||
)
|
||||
if chosen_categories:
|
||||
tx.categories = chosen_categories
|
||||
|
||||
session.add(tx)
|
||||
await session.commit()
|
||||
await session.refresh(tx)
|
||||
await session.refresh(tx, attribute_names=["categories"]) # ensure categories are loaded
|
||||
|
||||
return TransactionRead(
|
||||
id=tx.id,
|
||||
amount=float(tx.amount),
|
||||
description=tx.description,
|
||||
date=tx.date,
|
||||
category_ids=[c.id for c in (tx.categories or [])],
|
||||
)
|
||||
return transactions
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
import os
|
||||
from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from app.core.base import Base
|
||||
|
||||
DATABASE_URL = os.getenv("DATABASE_URL")
|
||||
@@ -23,6 +25,7 @@ host_env = os.getenv("MARIADB_HOST", "localhost")
|
||||
ssl_enabled = host_env not in {"localhost", "127.0.0.1"}
|
||||
connect_args = {"ssl": {"ssl": True}} if ssl_enabled else {}
|
||||
|
||||
# Async engine/session for the async parts of the app
|
||||
engine = create_async_engine(
|
||||
DATABASE_URL,
|
||||
pool_pre_ping=True,
|
||||
@@ -30,3 +33,13 @@ engine = create_async_engine(
|
||||
connect_args=connect_args,
|
||||
)
|
||||
async_session_maker = async_sessionmaker(engine, expire_on_commit=False)
|
||||
|
||||
# Synchronous engine/session for sync utilities (e.g., bank_scraper)
|
||||
SYNC_DATABASE_URL = DATABASE_URL.replace("+asyncmy", "+pymysql")
|
||||
engine_sync = create_engine(
|
||||
SYNC_DATABASE_URL,
|
||||
pool_pre_ping=True,
|
||||
echo=os.getenv("SQL_ECHO", "0") == "1",
|
||||
connect_args=connect_args,
|
||||
)
|
||||
sync_session_maker = sessionmaker(bind=engine_sync, expire_on_commit=False)
|
||||
|
||||
@@ -7,7 +7,7 @@ from uuid import UUID
|
||||
import httpx
|
||||
from sqlalchemy import select
|
||||
|
||||
from app.core.db import async_session_maker
|
||||
from app.core.db import sync_session_maker
|
||||
from app.models.transaction import Transaction
|
||||
from app.models.user import User
|
||||
|
||||
@@ -20,26 +20,78 @@ CERTS = (
|
||||
)
|
||||
|
||||
|
||||
async def aload_ceska_sporitelna_transactions(user_id: str) -> None:
|
||||
def load_mock_bank_transactions(user_id: str) -> None:
|
||||
try:
|
||||
uid = UUID(str(user_id))
|
||||
except Exception:
|
||||
logger.error("Invalid user_id provided to bank_scraper (async): %r", user_id)
|
||||
logger.error("Invalid user_id provided to bank_scraper (sync): %r", user_id)
|
||||
return
|
||||
|
||||
await _aload_ceska_sporitelna_transactions(uid)
|
||||
_load_mock_bank_transactions(uid)
|
||||
|
||||
|
||||
async def aload_all_ceska_sporitelna_transactions() -> None:
|
||||
async with async_session_maker() as session:
|
||||
result = await session.execute(select(User))
|
||||
users = result.unique().scalars().all()
|
||||
def load_all_mock_bank_transactions() -> None:
|
||||
with sync_session_maker() as session:
|
||||
users = session.execute(select(User)).unique().scalars().all()
|
||||
logger.info("[BankScraper] Starting Mock Bank scrape for all users | count=%d", len(users))
|
||||
|
||||
processed = 0
|
||||
for user in users:
|
||||
try:
|
||||
_load_mock_bank_transactions(user.id)
|
||||
processed += 1
|
||||
except Exception:
|
||||
logger.exception("[BankScraper] Error scraping for user id=%s email=%s", user.id,
|
||||
getattr(user, 'email', None))
|
||||
logger.info("[BankScraper] Finished Mock Bank scrape for all users | processed=%d", processed)
|
||||
|
||||
|
||||
def _load_mock_bank_transactions(user_id: UUID) -> None:
|
||||
with sync_session_maker() as session:
|
||||
user: User | None = session.execute(select(User).where(User.id == user_id)).unique().scalar_one_or_none()
|
||||
if user is None:
|
||||
logger.warning("User not found for id=%s", user_id)
|
||||
return
|
||||
|
||||
transactions = []
|
||||
with httpx.Client() as client:
|
||||
response = client.get("http://127.0.0.1:8000/mock-bank/scrape")
|
||||
if response.status_code != httpx.codes.OK:
|
||||
return
|
||||
for transaction in response.json():
|
||||
transactions.append(
|
||||
Transaction(
|
||||
amount=transaction["amount"],
|
||||
description=transaction.get("description"),
|
||||
date=strptime(transaction["date"], "%Y-%m-%d"),
|
||||
user_id=user_id,
|
||||
)
|
||||
)
|
||||
|
||||
for transaction in transactions:
|
||||
session.add(transaction)
|
||||
session.commit()
|
||||
|
||||
|
||||
def load_ceska_sporitelna_transactions(user_id: str) -> None:
|
||||
try:
|
||||
uid = UUID(str(user_id))
|
||||
except Exception:
|
||||
logger.error("Invalid user_id provided to bank_scraper (sync): %r", user_id)
|
||||
return
|
||||
|
||||
_load_ceska_sporitelna_transactions(uid)
|
||||
|
||||
|
||||
def load_all_ceska_sporitelna_transactions() -> None:
|
||||
with sync_session_maker() as session:
|
||||
users = session.execute(select(User)).unique().scalars().all()
|
||||
logger.info("[BankScraper] Starting CSAS scrape for all users | count=%d", len(users))
|
||||
|
||||
processed = 0
|
||||
for user in users:
|
||||
try:
|
||||
await _aload_ceska_sporitelna_transactions(user.id)
|
||||
_load_ceska_sporitelna_transactions(user.id)
|
||||
processed += 1
|
||||
except Exception:
|
||||
logger.exception("[BankScraper] Error scraping for user id=%s email=%s", user.id,
|
||||
@@ -47,10 +99,9 @@ async def aload_all_ceska_sporitelna_transactions() -> None:
|
||||
logger.info("[BankScraper] Finished CSAS scrape for all users | processed=%d", processed)
|
||||
|
||||
|
||||
async def _aload_ceska_sporitelna_transactions(user_id: UUID) -> None:
|
||||
async with (async_session_maker() as session):
|
||||
result = await session.execute(select(User).where(User.id == user_id))
|
||||
user: User = result.unique().scalar_one_or_none()
|
||||
def _load_ceska_sporitelna_transactions(user_id: UUID) -> None:
|
||||
with sync_session_maker() as session:
|
||||
user: User | None = session.execute(select(User).where(User.id == user_id)).unique().scalar_one_or_none()
|
||||
if user is None:
|
||||
logger.warning("User not found for id=%s", user_id)
|
||||
return
|
||||
@@ -65,8 +116,8 @@ async def _aload_ceska_sporitelna_transactions(user_id: UUID) -> None:
|
||||
|
||||
accounts = []
|
||||
try:
|
||||
async with httpx.AsyncClient(cert=CERTS, timeout=httpx.Timeout(20.0)) as client:
|
||||
response = await client.get(
|
||||
with httpx.Client(cert=CERTS, timeout=httpx.Timeout(20.0)) as client:
|
||||
response = client.get(
|
||||
"https://webapi.developers.erstegroup.com/api/csas/sandbox/v4/account-information/my/accounts?size=10&page=0&sort=iban&order=desc",
|
||||
headers={
|
||||
"Authorization": f"Bearer {cfg['access_token']}",
|
||||
@@ -77,7 +128,7 @@ async def _aload_ceska_sporitelna_transactions(user_id: UUID) -> None:
|
||||
if response.status_code != httpx.codes.OK:
|
||||
return
|
||||
|
||||
for account in response.json()["accounts"]:
|
||||
for account in response.json().get("accounts", []):
|
||||
accounts.append(account)
|
||||
|
||||
except (httpx.HTTPError,) as e:
|
||||
@@ -85,11 +136,13 @@ async def _aload_ceska_sporitelna_transactions(user_id: UUID) -> None:
|
||||
return
|
||||
|
||||
for account in accounts:
|
||||
id = account["id"]
|
||||
acc_id = account.get("id")
|
||||
if not acc_id:
|
||||
continue
|
||||
|
||||
url = f"https://webapi.developers.erstegroup.com/api/csas/sandbox/v4/account-information/my/accounts/{id}/transactions?size=100&page=0&sort=bookingdate&order=desc"
|
||||
async with httpx.AsyncClient(cert=CERTS) as client:
|
||||
response = await client.get(
|
||||
url = f"https://webapi.developers.erstegroup.com/api/csas/sandbox/v4/account-information/my/accounts/{acc_id}/transactions?size=100&page=0&sort=bookingdate&order=desc"
|
||||
with httpx.Client(cert=CERTS) as client:
|
||||
response = client.get(
|
||||
url,
|
||||
headers={
|
||||
"Authorization": f"Bearer {cfg['access_token']}",
|
||||
@@ -100,7 +153,7 @@ async def _aload_ceska_sporitelna_transactions(user_id: UUID) -> None:
|
||||
if response.status_code != httpx.codes.OK:
|
||||
continue
|
||||
|
||||
transactions = response.json()["transactions"]
|
||||
transactions = response.json().get("transactions", [])
|
||||
|
||||
for transaction in transactions:
|
||||
description = transaction.get("entryDetails", {}).get("transactionDetails", {}).get(
|
||||
@@ -108,9 +161,12 @@ async def _aload_ceska_sporitelna_transactions(user_id: UUID) -> None:
|
||||
date_str = transaction.get("bookingDate", {}).get("date")
|
||||
date = strptime(date_str, "%Y-%m-%d") if date_str else None
|
||||
amount = transaction.get("amount", {}).get("value")
|
||||
if transaction.get("creditDebitIndicator") == "DBIT":
|
||||
if transaction.get("creditDebitIndicator") == "DBIT" and amount is not None:
|
||||
amount = -abs(amount)
|
||||
|
||||
if amount is None:
|
||||
continue
|
||||
|
||||
obj = Transaction(
|
||||
amount=amount,
|
||||
description=description,
|
||||
@@ -118,7 +174,4 @@ async def _aload_ceska_sporitelna_transactions(user_id: UUID) -> None:
|
||||
user_id=user_id,
|
||||
)
|
||||
session.add(obj)
|
||||
await session.commit()
|
||||
|
||||
pass
|
||||
pass
|
||||
session.commit()
|
||||
|
||||
@@ -1,12 +1,10 @@
|
||||
import logging
|
||||
import asyncio
|
||||
import os
|
||||
import smtplib
|
||||
from email.message import EmailMessage
|
||||
|
||||
from celery import shared_task
|
||||
|
||||
import app.services.bank_scraper
|
||||
from app.celery_app import celery_app
|
||||
|
||||
logger = logging.getLogger("celery_tasks")
|
||||
if not logger.handlers:
|
||||
@@ -15,73 +13,7 @@ if not logger.handlers:
|
||||
logger.setLevel(logging.INFO)
|
||||
|
||||
|
||||
def run_coro(coro) -> None:
|
||||
"""Run an async coroutine in a fresh event loop without using run_until_complete.
|
||||
Primary strategy runs in a new loop in the current thread. If that fails due to
|
||||
debugger patches (e.g., Bad file descriptor from pydevd_nest_asyncio), fall back
|
||||
to running in a dedicated thread with its own event loop.
|
||||
"""
|
||||
import threading
|
||||
|
||||
def _cleanup_loop(loop):
|
||||
try:
|
||||
pending = [t for t in asyncio.all_tasks(loop) if not t.done()]
|
||||
for t in pending:
|
||||
t.cancel()
|
||||
if pending:
|
||||
loop.run_until_complete(asyncio.gather(*pending, return_exceptions=True))
|
||||
except Exception:
|
||||
pass
|
||||
finally:
|
||||
try:
|
||||
loop.close()
|
||||
finally:
|
||||
asyncio.set_event_loop(None)
|
||||
|
||||
# First attempt: Run in current thread with a fresh event loop
|
||||
try:
|
||||
loop = asyncio.get_event_loop_policy().new_event_loop()
|
||||
try:
|
||||
asyncio.set_event_loop(loop)
|
||||
task = loop.create_task(coro)
|
||||
task.add_done_callback(lambda _t: loop.stop())
|
||||
loop.run_forever()
|
||||
exc = task.exception()
|
||||
if exc:
|
||||
raise exc
|
||||
return
|
||||
finally:
|
||||
_cleanup_loop(loop)
|
||||
except OSError as e:
|
||||
logger.warning("run_coro primary strategy failed (%s). Falling back to thread runner.", e)
|
||||
except Exception:
|
||||
# For any other unexpected errors, try thread fallback as well
|
||||
logger.exception("run_coro primary strategy raised; attempting thread fallback")
|
||||
|
||||
# Fallback: Run in a dedicated thread with its own event loop
|
||||
error = {"exc": None}
|
||||
|
||||
def _thread_target():
|
||||
loop = asyncio.new_event_loop()
|
||||
try:
|
||||
asyncio.set_event_loop(loop)
|
||||
task = loop.create_task(coro)
|
||||
task.add_done_callback(lambda _t: loop.stop())
|
||||
loop.run_forever()
|
||||
exc = task.exception()
|
||||
if exc:
|
||||
error["exc"] = exc
|
||||
finally:
|
||||
_cleanup_loop(loop)
|
||||
|
||||
th = threading.Thread(target=_thread_target, name="celery-async-runner", daemon=True)
|
||||
th.start()
|
||||
th.join()
|
||||
if error["exc"] is not None:
|
||||
raise error["exc"]
|
||||
|
||||
|
||||
@shared_task(name="workers.send_email")
|
||||
@celery_app.task(name="workers.send_email")
|
||||
def send_email(to: str, subject: str, body: str) -> None:
|
||||
if not (to and subject and body):
|
||||
logger.error("Email task missing fields. to=%r subject=%r body_len=%r", to, subject, len(body) if body else 0)
|
||||
@@ -128,20 +60,27 @@ def send_email(to: str, subject: str, body: str) -> None:
|
||||
host, port, use_tls, use_ssl)
|
||||
|
||||
|
||||
@shared_task(name="workers.load_transactions")
|
||||
@celery_app.task(name="workers.load_transactions")
|
||||
def load_transactions(user_id: str) -> None:
|
||||
if not user_id:
|
||||
logger.error("Load transactions task missing user_id.")
|
||||
return
|
||||
|
||||
run_coro(app.services.bank_scraper.aload_ceska_sporitelna_transactions(user_id))
|
||||
|
||||
# Placeholder for real transaction loading logic
|
||||
logger.info("[Celery] Transactions loaded for user_id=%s", user_id)
|
||||
logger.info("[Celery] Starting load_transactions | user_id=%s", user_id)
|
||||
try:
|
||||
# Use synchronous bank scraper functions directly, mirroring load_all_transactions
|
||||
app.services.bank_scraper.load_mock_bank_transactions(user_id)
|
||||
app.services.bank_scraper.load_ceska_sporitelna_transactions(user_id)
|
||||
except Exception:
|
||||
logger.exception("Failed to load transactions for user_id=%s", user_id)
|
||||
else:
|
||||
logger.info("[Celery] Finished load_transactions | user_id=%s", user_id)
|
||||
|
||||
|
||||
@shared_task(name="workers.load_all_transactions")
|
||||
@celery_app.task(name="workers.load_all_transactions")
|
||||
def load_all_transactions() -> None:
|
||||
logger.info("[Celery] Starting load_all_transactions")
|
||||
run_coro(app.services.bank_scraper.aload_all_ceska_sporitelna_transactions())
|
||||
# Now use synchronous bank scraper functions directly
|
||||
app.services.bank_scraper.load_all_mock_bank_transactions()
|
||||
app.services.bank_scraper.load_all_ceska_sporitelna_transactions()
|
||||
logger.info("[Celery] Finished load_all_transactions")
|
||||
|
||||
Reference in New Issue
Block a user