mirror of
https://github.com/dat515-2025/Group-8.git
synced 2026-03-22 23:20:56 +01:00
144 lines
5.1 KiB
Python
144 lines
5.1 KiB
Python
from datetime import datetime, timedelta
|
|
from typing import List, Optional
|
|
import random
|
|
|
|
from fastapi import APIRouter, Depends, Response, status
|
|
from pydantic import BaseModel, Field, conint, confloat, validator
|
|
from sqlalchemy import select
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
|
|
from app.services.db import get_async_session
|
|
from app.services.user_service import current_active_user
|
|
from app.models.user import User
|
|
from app.models.transaction import Transaction
|
|
from app.models.categories import Category
|
|
from app.schemas.transaction import TransactionRead
|
|
|
|
# Every endpoint in this module is mounted under /mock-bank and tagged
# "mock-bank" for grouping in FastAPI's generated API docs.
router = APIRouter(tags=["mock-bank"], prefix="/mock-bank")
|
|
|
|
|
|
class GenerateOptions(BaseModel):
    """Request body for ``POST /mock-bank/generate``.

    Field names are camelCase to match the JSON payload. Every field has a
    default; dates left as None/empty are filled in by the endpoint.
    """

    # Number of transactions to produce (strict int, must be > 0).
    count: conint(strict=True, gt=0) = Field(default=10, description="Number of transactions to generate")
    # Inclusive amount bounds (strict floats); maxAmount must be >= minAmount.
    minAmount: confloat(strict=True) = Field(default=-200.0, description="Minimum transaction amount")
    maxAmount: confloat(strict=True) = Field(default=200.0, description="Maximum transaction amount")
    # Inclusive date bounds as YYYY-MM-DD strings; None or "" means "not provided".
    startDate: Optional[str] = Field(None, description="Earliest date (YYYY-MM-DD)")
    endDate: Optional[str] = Field(None, description="Latest date (YYYY-MM-DD)")
    # Candidate category ids to pick from at random.
    categoryIds: List[int] = Field(default_factory=list, description="Optional category IDs to assign randomly")

    @validator("startDate", "endDate")
    def _validate_date_format(cls, v):
        """Reject a non-empty date that is not in ``YYYY-MM-DD`` form.

        Checked per field so that a malformed date supplied on its own is
        reported as a validation error rather than failing later when the
        endpoint parses it.

        Raises:
            ValueError: if ``v`` is non-empty and cannot be parsed.
        """
        if not v:
            # None / "" mean "use the default"; the endpoint treats both as absent.
            return v
        try:
            datetime.strptime(v, "%Y-%m-%d")
        except ValueError:
            raise ValueError("Invalid date format, expected YYYY-MM-DD") from None
        return v

    @validator("maxAmount")
    def _validate_amounts(cls, v, values):
        """Ensure ``maxAmount`` is not below ``minAmount``.

        ``values`` holds only fields that already validated successfully, so
        ``minAmount`` may be missing if it was itself invalid.
        """
        min_amt = values.get("minAmount")
        if min_amt is not None and v < min_amt:
            raise ValueError("maxAmount must be greater than or equal to minAmount")
        return v

    @validator("endDate")
    def _validate_dates(cls, v, values):
        """Ensure ``endDate`` is not earlier than ``startDate`` when both are set.

        Runs after ``_validate_date_format``; ``startDate`` appears in
        ``values`` only if it passed its own validation.
        """
        sd = values.get("startDate")
        if v and sd:
            try:
                ed = datetime.strptime(v, "%Y-%m-%d").date()
                st = datetime.strptime(sd, "%Y-%m-%d").date()
            except ValueError:
                raise ValueError("Invalid date format, expected YYYY-MM-DD") from None
            if ed < st:
                raise ValueError("endDate must be greater than or equal to startDate")
        return v
|
|
|
|
|
|
class GeneratedTransaction(BaseModel):
    """One item of the response returned by ``POST /mock-bank/generate``."""

    # Transaction amount; may be negative.
    amount: float
    # Calendar date as an ISO ``YYYY-MM-DD`` string.
    date: str
    # Category ids attached to the transaction; empty when none were assigned.
    category_ids: List[int] = Field(default=[])
    # Optional free-text description.
    description: Optional[str] = Field(default=None)
|
|
|
|
|
|
@router.post("/generate", response_model=List[GeneratedTransaction])
async def generate_mock_transactions(
    options: GenerateOptions,
    user: User = Depends(current_active_user),
):
    """Return randomly generated mock transactions without saving them.

    Args:
        options: Generation parameters (count, amount bounds, date bounds,
            candidate category ids).
        user: The authenticated user; only ``user.id`` is used, to vary the seed.

    Returns:
        ``options.count`` ``GeneratedTransaction`` items.
        * amount: uniform in ``[minAmount, maxAmount]``, rounded to 2 decimals.
        * date: uniform over the inclusive range ``[start, end]``. ``end``
          defaults to today's UTC date. ``start`` defaults to 365 days before
          today, or 365 days before ``end`` when ``end`` is earlier than that.
        * category_ids: one id picked from ``categoryIds`` when it is
          non-empty, otherwise empty.
        * description: always None.
    """
    # Function-local import: the module-level import only brings in
    # ``datetime`` and ``timedelta``.
    from datetime import timezone

    # One aware "now" replaces the deprecated naive ``datetime.utcnow()``.
    now = datetime.now(timezone.utc)
    # Seeded from the current whole second XOR the user id, so repeated calls
    # by the same user within the same second produce the same sequence.
    rnd = random.Random(int(now.timestamp()) ^ int(user.id))
    today = now.date()

    if options.endDate:
        end_date = datetime.strptime(options.endDate, "%Y-%m-%d").date()
    else:
        end_date = today

    if options.startDate:
        start_date = datetime.strptime(options.startDate, "%Y-%m-%d").date()
    else:
        start_date = today - timedelta(days=365)
        if start_date > end_date:
            # The explicit endDate is more than a year in the past. Without
            # this, the span would clamp to 0 and every date would be
            # start_date, which lies *after* endDate. Anchor the default
            # one-year window on endDate instead.
            start_date = end_date - timedelta(days=365)

    # A negative span can only remain when an explicit startDate lies after
    # the defaulted end (today); it collapses to the single day start_date.
    span_days = max(0, (end_date - start_date).days)
    category_pool = options.categoryIds

    results: List[GeneratedTransaction] = []
    for _ in range(options.count):
        amount = round(rnd.uniform(options.minAmount, options.maxAmount), 2)
        # Inclusive day offset into the range.
        offset = rnd.randint(0, span_days) if span_days > 0 else 0
        tx_date = start_date + timedelta(days=offset)
        category_ids = [rnd.choice(category_pool)] if category_pool else []
        results.append(
            GeneratedTransaction(
                amount=amount,
                date=tx_date.isoformat(),
                category_ids=category_ids,
                description=None,
            )
        )
    return results
|
|
|
|
|
|
|
|
@router.get("/scrape", response_model=Optional[TransactionRead])
async def scrape_mock_bank(
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user),
):
    """Pretend to poll the user's bank for new activity.

    About 95% of calls find nothing and reply with ``204 No Content``.
    Otherwise the endpoint:
    * creates one transaction for the current user, dated today (UTC), with a
      random amount in ``[-200, 200]`` rounded to 2 decimals;
    * tags it with one randomly chosen category owned by the user, if the
      user owns any;
    * commits it and returns it as a ``TransactionRead``.

    Randomness comes from the module-level ``random`` generator.
    """
    roll = random.random()
    if roll < 0.95:
        # Nothing new at the "bank" this time.
        return Response(status_code=status.HTTP_204_NO_CONTENT)

    new_amount = round(random.uniform(-200.0, 200.0), 2)
    booked_on = datetime.utcnow().date()

    # Candidate categories are restricted to those owned by this user.
    owned_query = select(Category).where(Category.user_id == user.id)
    owned = list((await session.execute(owned_query)).scalars())
    picked = [random.choice(owned)] if owned else []

    tx = Transaction(
        amount=new_amount,
        description="Mock bank scrape",
        user_id=user.id,
        date=booked_on,
    )
    if picked:
        tx.categories = picked

    session.add(tx)
    await session.commit()
    # Reload the committed row, then explicitly load ``categories`` so the
    # response below can read it.
    await session.refresh(tx)
    await session.refresh(tx, attribute_names=["categories"])

    return TransactionRead(
        id=tx.id,
        amount=float(tx.amount),
        description=tx.description,
        date=tx.date,
        category_ids=[category.id for category in (tx.categories or [])],
    )
|