mirror of
https://github.com/dat515-2025/Group-8.git
synced 2026-03-22 23:20:56 +01:00
281 lines
9.8 KiB
Python
281 lines
9.8 KiB
Python
from typing import List, Optional
|
|
from datetime import date
|
|
|
|
from fastapi import APIRouter, Depends, HTTPException, status
|
|
from sqlalchemy import select, and_, func
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
|
|
from app.models.transaction import Transaction
|
|
from app.models.categories import Category
|
|
from app.schemas.transaction import (
|
|
TransactionCreate,
|
|
TransactionRead,
|
|
TransactionUpdate,
|
|
)
|
|
from app.services.db import get_async_session
|
|
from app.services.user_service import current_active_user
|
|
from app.models.user import User
|
|
|
|
# Router collecting the transaction endpoints defined in this module; every
# route registered on it is served under the "/transactions" prefix.
router = APIRouter(prefix="/transactions", tags=["transactions"])
|
|
|
|
|
|
def _to_read_model(tx: Transaction) -> TransactionRead:
    """Build the ``TransactionRead`` response schema from a ``Transaction`` row.

    The scalar fields are copied as-is; the related categories are reduced to
    a list of their IDs (an empty list when ``tx.categories`` is falsy).
    """
    fields = {
        "id": tx.id,
        "amount": tx.amount,
        "description": tx.description,
        "date": tx.date,
    }
    fields["category_ids"] = [category.id for category in (tx.categories or [])]
    return TransactionRead(**fields)
|
|
|
|
|
|
@router.post("/create", response_model=TransactionRead, status_code=status.HTTP_201_CREATED)
async def create_transaction(
    payload: TransactionCreate,
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user),
):
    """Create a transaction owned by the authenticated user.

    Args:
        payload: Amount, description, optional date and optional category IDs.
        session: Async database session.
        user: The authenticated user; becomes the owner of the new transaction.

    Returns:
        The stored transaction as a ``TransactionRead``.

    Raises:
        HTTPException: 400 if ``payload.date`` is a string that is not a valid
            ISO date, or if any requested category does not exist among the
            user's own categories.
    """
    # Build transaction; set `date` only if provided to let the default apply otherwise
    tx_kwargs = dict(
        amount=payload.amount,
        description=payload.description,
        user_id=user.id,
    )
    if payload.date is not None:
        parsed_date = payload.date
        if isinstance(parsed_date, str):
            try:
                parsed_date = date.fromisoformat(parsed_date)
            except ValueError:
                # The parsing traceback adds nothing for the client; drop the chain.
                raise HTTPException(
                    status_code=400, detail="Invalid date format, expected YYYY-MM-DD"
                ) from None
        tx_kwargs["date"] = parsed_date
    tx = Transaction(**tx_kwargs)

    # Attach categories if provided (and owned by user)
    if payload.category_ids:
        res = await session.execute(
            select(Category).where(
                Category.user_id == user.id, Category.id.in_(payload.category_ids)
            )
        )
        categories = list(res.scalars())
        # Compare against the distinct requested IDs: repeated IDs in the payload
        # are tolerated, so a mismatch can only mean that at least one requested
        # category is missing or belongs to another user.
        if len(categories) != len(set(payload.category_ids)):
            raise HTTPException(
                status_code=400,
                detail="One or more categories not found",
            )
        tx.categories = categories

    session.add(tx)
    await session.commit()
    await session.refresh(tx)
    # Ensure categories are loaded explicitly before the response model reads them
    await session.refresh(tx, attribute_names=["categories"])
    return _to_read_model(tx)
|
|
|
|
|
|
@router.get("/", response_model=List[TransactionRead])
async def list_transactions(
    start_date: Optional[date] = None,
    end_date: Optional[date] = None,
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user),
):
    """List the authenticated user's transactions, optionally bounded by date.

    Args:
        start_date: If given, only transactions dated on or after it are returned.
        end_date: If given, only transactions dated on or before it are returned.
        session: Async database session.
        user: The authenticated user whose transactions are listed.

    Returns:
        ``TransactionRead`` items ordered by date, then by ID.
    """
    # Function-scope import keeps this block self-contained.
    from sqlalchemy.orm import selectinload

    cond = [Transaction.user_id == user.id]
    if start_date is not None:
        cond.append(Transaction.date >= start_date)
    if end_date is not None:
        cond.append(Transaction.date <= end_date)

    # Load categories for all rows with one additional SELECT instead of
    # refreshing every transaction individually (one query per row).
    stmt = (
        select(Transaction)
        .where(and_(*cond))
        .options(selectinload(Transaction.categories))
        .order_by(Transaction.date, Transaction.id)
    )
    res = await session.execute(stmt)
    return [_to_read_model(tx) for tx in res.scalars()]
|
|
|
|
|
|
@router.get("/balance_series")
async def get_balance_series(
    start_date: Optional[date] = None,
    end_date: Optional[date] = None,
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user),
):
    """Return the cumulative balance per day for the authenticated user.

    Transactions are optionally restricted to ``start_date``..``end_date``
    (both inclusive). Amounts are summed per day and then accumulated in
    ascending order of the day label; the running total starts at 0.0 for the
    selected range.

    Returns:
        A list of ``{"date": <label>, "balance": <running total>}`` dicts.
    """
    filters = [Transaction.user_id == user.id]
    if start_date is not None:
        filters.append(Transaction.date >= start_date)
    if end_date is not None:
        filters.append(Transaction.date <= end_date)

    query = select(Transaction).where(and_(*filters)).order_by(Transaction.date, Transaction.id)
    rows = list((await session.execute(query)).scalars())

    # Sum the amounts of each day, keyed by the day's string label.
    totals_by_day = {}
    for row in rows:
        if hasattr(row.date, 'isoformat'):
            label = row.date.isoformat()
        else:
            label = str(row.date)
        totals_by_day[label] = totals_by_day.get(label, 0.0) + float(row.amount)

    # Walk the days in sorted label order, carrying the running balance forward.
    points = []
    balance = 0.0
    for label in sorted(totals_by_day):
        balance += totals_by_day[label]
        points.append({"date": label, "balance": balance})
    return points
|
|
|
|
|
|
@router.get("/{transaction_id}", response_model=TransactionRead)
async def get_transaction(
    transaction_id: int,
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user),
):
    """Fetch a single transaction belonging to the authenticated user.

    Raises:
        HTTPException: 404 if no transaction with this ID exists for the user.
    """
    lookup = select(Transaction).where(
        Transaction.id == transaction_id, Transaction.user_id == user.id
    )
    found: Optional[Transaction] = (await session.execute(lookup)).scalar_one_or_none()
    if not found:
        raise HTTPException(status_code=404, detail="Transaction not found")

    # Load the categories explicitly before the response model reads them.
    await session.refresh(found, attribute_names=["categories"])
    return _to_read_model(found)
|
|
|
|
|
|
@router.patch("/{transaction_id}/edit", response_model=TransactionRead)
async def update_transaction(
    transaction_id: int,
    payload: TransactionUpdate,
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user),
):
    """Partially update one of the authenticated user's transactions.

    Only fields that are not ``None`` in ``payload`` are applied. A non-``None``
    ``category_ids`` replaces the whole category set (an empty list clears it).

    Raises:
        HTTPException: 404 if the transaction does not exist for the user;
            400 for an invalid date string, duplicate category IDs, or
            categories that are not found among the user's categories.
    """
    stmt = select(Transaction).where(
        Transaction.id == transaction_id, Transaction.user_id == user.id
    )
    existing: Optional[Transaction] = (await session.execute(stmt)).scalar_one_or_none()
    if not existing:
        raise HTTPException(status_code=404, detail="Transaction not found")

    if payload.amount is not None:
        existing.amount = payload.amount
    if payload.description is not None:
        existing.description = payload.description
    if payload.date is not None:
        incoming_date = payload.date
        if isinstance(incoming_date, str):
            try:
                incoming_date = date.fromisoformat(incoming_date)
            except ValueError:
                raise HTTPException(status_code=400, detail="Invalid date format, expected YYYY-MM-DD")
        existing.date = incoming_date

    requested_ids = payload.category_ids
    if requested_ids is not None:
        # Load the current categories first so assigning the collection below
        # does not trigger a lazy load under the async session.
        await session.refresh(existing, attribute_names=["categories"])
        if not requested_ids:
            existing.categories = []
        else:
            # Reject repeated IDs before hitting the database.
            if len(set(requested_ids)) != len(requested_ids):
                raise HTTPException(status_code=400, detail="Duplicate category IDs in payload")
            owned = await session.execute(
                select(Category).where(
                    Category.user_id == user.id, Category.id.in_(requested_ids)
                )
            )
            matched = list(owned.scalars())
            if len(matched) != len(requested_ids):
                raise HTTPException(status_code=400, detail="One or more categories not found")
            existing.categories = matched

    await session.commit()
    await session.refresh(existing, attribute_names=["categories"])
    return _to_read_model(existing)
|
|
|
|
|
|
@router.delete("/{transaction_id}/delete", status_code=status.HTTP_204_NO_CONTENT)
async def delete_transaction(
    transaction_id: int,
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user),
):
    """Delete one of the authenticated user's transactions.

    Raises:
        HTTPException: 404 if no transaction with this ID exists for the user.
    """
    stmt = select(Transaction).where(
        Transaction.id == transaction_id, Transaction.user_id == user.id
    )
    target = (await session.execute(stmt)).scalar_one_or_none()
    if not target:
        raise HTTPException(status_code=404, detail="Transaction not found")

    await session.delete(target)
    await session.commit()
    return None
|
|
|
|
|
|
@router.post("/{transaction_id}/categories/{category_id}", response_model=TransactionRead)
async def assign_category(
    transaction_id: int,
    category_id: int,
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user),
):
    """Attach a category to a transaction; both must belong to the user.

    Assigning a category that is already attached leaves the set unchanged.

    Raises:
        HTTPException: 404 if the transaction or the category is not found
            for the authenticated user.
    """
    # Both lookups are scoped to the current user to enforce ownership.
    tx_query = select(Transaction).where(
        Transaction.id == transaction_id, Transaction.user_id == user.id
    )
    transaction: Optional[Transaction] = (await session.execute(tx_query)).scalar_one_or_none()
    if not transaction:
        raise HTTPException(status_code=404, detail="Transaction not found")

    cat_query = select(Category).where(Category.id == category_id, Category.user_id == user.id)
    category: Optional[Category] = (await session.execute(cat_query)).scalar_one_or_none()
    if not category:
        raise HTTPException(status_code=404, detail="Category not found")

    await session.refresh(transaction, attribute_names=["categories"])
    already_linked = category in transaction.categories
    if not already_linked:
        transaction.categories.append(category)
        await session.commit()
        await session.refresh(transaction, attribute_names=["categories"])
    return _to_read_model(transaction)
|
|
|
|
|
|
@router.delete("/{transaction_id}/categories/{category_id}", response_model=TransactionRead)
async def unassign_category(
    transaction_id: int,
    category_id: int,
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user),
):
    """Detach a category from a transaction; both must belong to the user.

    Removing a category that is not attached leaves the set unchanged.

    Raises:
        HTTPException: 404 if the transaction or the category is not found
            for the authenticated user.
    """
    tx_query = select(Transaction).where(
        Transaction.id == transaction_id, Transaction.user_id == user.id
    )
    transaction: Optional[Transaction] = (await session.execute(tx_query)).scalar_one_or_none()
    if not transaction:
        raise HTTPException(status_code=404, detail="Transaction not found")

    cat_query = select(Category).where(Category.id == category_id, Category.user_id == user.id)
    category: Optional[Category] = (await session.execute(cat_query)).scalar_one_or_none()
    if not category:
        raise HTTPException(status_code=404, detail="Category not found")

    await session.refresh(transaction, attribute_names=["categories"])
    currently_linked = category in transaction.categories
    if currently_linked:
        transaction.categories.remove(category)
        await session.commit()
        await session.refresh(transaction, attribute_names=["categories"])
    return _to_read_model(transaction)
|