"""Transaction CRUD endpoints, category assignment and a running-balance series.

Every query issued here filters on the authenticated user's id, so a user can
only see or modify their own transactions and categories.
"""
from typing import List, Optional
from datetime import date

from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy import select, and_, func
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload

from app.models.transaction import Transaction
from app.models.categories import Category
from app.schemas.transaction import (
    TransactionCreate,
    TransactionRead,
    TransactionUpdate,
)
from app.services.db import get_async_session
from app.services.user_service import current_active_user
from app.models.user import User

router = APIRouter(prefix="/transactions", tags=["transactions"])


def _to_read_model(tx: Transaction) -> TransactionRead:
    """Convert an ORM transaction into its API read schema.

    ``tx.categories`` should already be loaded (eager-loaded or refreshed);
    lazily loading a relationship is avoided throughout this module because
    the session is an ``AsyncSession``.
    """
    return TransactionRead(
        id=tx.id,
        amount=tx.amount,
        description=tx.description,
        date=tx.date,
        category_ids=[c.id for c in (tx.categories or [])],
    )


def _parse_date(value):
    """Return *value* as a ``date``, parsing ``YYYY-MM-DD`` strings.

    Non-string values are returned unchanged (presumably the schema already
    produced a ``date`` -- TODO confirm against ``app.schemas.transaction``).

    Raises:
        HTTPException: 400 if *value* is a string that is not an ISO date.
    """
    if not isinstance(value, str):
        return value
    try:
        return date.fromisoformat(value)
    except ValueError:
        raise HTTPException(
            status_code=400, detail="Invalid date format, expected YYYY-MM-DD"
        ) from None


def _date_range_conditions(user_id, start_date, end_date) -> list:
    """Build WHERE clauses for the user's transactions in an inclusive date range.

    Either bound may be ``None``, meaning that side of the range is open.
    """
    cond = [Transaction.user_id == user_id]
    if start_date is not None:
        cond.append(Transaction.date >= start_date)
    if end_date is not None:
        cond.append(Transaction.date <= end_date)
    return cond


async def _get_owned_transaction(
    session: AsyncSession,
    transaction_id: int,
    user_id,
    *,
    with_categories: bool = True,
) -> Transaction:
    """Fetch one of the user's transactions or raise 404.

    When ``with_categories`` is true the ``categories`` relationship is
    eager-loaded by the query itself (``selectinload``), replacing the
    separate ``session.refresh(..., ["categories"])`` call.

    Raises:
        HTTPException: 404 if no transaction with that id belongs to the user.
    """
    stmt = select(Transaction).where(
        Transaction.id == transaction_id, Transaction.user_id == user_id
    )
    if with_categories:
        stmt = stmt.options(selectinload(Transaction.categories))
    tx = (await session.execute(stmt)).scalar_one_or_none()
    if tx is None:
        raise HTTPException(status_code=404, detail="Transaction not found")
    return tx


async def _get_owned_category(
    session: AsyncSession, category_id: int, user_id
) -> Category:
    """Fetch one of the user's categories or raise 404.

    Raises:
        HTTPException: 404 if no category with that id belongs to the user.
    """
    res = await session.execute(
        select(Category).where(Category.id == category_id, Category.user_id == user_id)
    )
    cat = res.scalar_one_or_none()
    if cat is None:
        raise HTTPException(status_code=404, detail="Category not found")
    return cat


async def _load_owned_categories(
    session: AsyncSession, user_id, category_ids
) -> List[Category]:
    """Load exactly the requested categories, all owned by the user.

    Raises:
        HTTPException: 400 if *category_ids* contains duplicates, or if any id
            does not exist or belongs to another user.
    """
    # Reject duplicates explicitly: a set-based count comparison would let
    # them through silently.
    if len(category_ids) != len(set(category_ids)):
        raise HTTPException(status_code=400, detail="Duplicate category IDs in payload")
    res = await session.execute(
        select(Category).where(
            Category.user_id == user_id, Category.id.in_(category_ids)
        )
    )
    categories = list(res.scalars())
    if len(categories) != len(category_ids):
        raise HTTPException(status_code=400, detail="One or more categories not found")
    return categories


@router.post("/create", response_model=TransactionRead, status_code=status.HTTP_201_CREATED)
async def create_transaction(
    payload: TransactionCreate,
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user),
):
    """Create a transaction for the current user.

    ``date`` is only set when provided, so the column's default applies
    otherwise. ``category_ids`` must be distinct and all owned by the user.

    Raises:
        HTTPException: 400 on an invalid date string, duplicate category ids,
            or unknown/foreign category ids.
    """
    tx_kwargs = dict(
        amount=payload.amount,
        description=payload.description,
        user_id=user.id,
    )
    if payload.date is not None:
        tx_kwargs["date"] = _parse_date(payload.date)
    tx = Transaction(**tx_kwargs)

    if payload.category_ids:
        tx.categories = await _load_owned_categories(
            session, user.id, payload.category_ids
        )

    session.add(tx)
    await session.commit()
    # Pick up DB-generated values (id, defaulted date), then the relationship.
    await session.refresh(tx)
    await session.refresh(tx, attribute_names=["categories"])
    return _to_read_model(tx)


@router.get("/", response_model=List[TransactionRead])
async def list_transactions(
    start_date: Optional[date] = None,
    end_date: Optional[date] = None,
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user),
):
    """List the user's transactions in ``[start_date, end_date]``, oldest first."""
    res = await session.execute(
        select(Transaction)
        # One extra SELECT for all categories instead of a refresh per row.
        .options(selectinload(Transaction.categories))
        .where(and_(*_date_range_conditions(user.id, start_date, end_date)))
        .order_by(Transaction.date, Transaction.id)
    )
    return [_to_read_model(tx) for tx in res.scalars()]


@router.get("/balance_series")
async def get_balance_series(
    start_date: Optional[date] = None,
    end_date: Optional[date] = None,
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user),
):
    """Return the cumulative balance per day for the user's transactions.

    The result is a list of ``{"date": <ISO string>, "balance": <float>}``
    sorted by date, where ``balance`` is the running total of amounts up to
    and including that day (within the requested range).
    """
    # Aggregate per day in the database rather than loading every row.
    res = await session.execute(
        select(Transaction.date, func.sum(Transaction.amount))
        .where(and_(*_date_range_conditions(user.id, start_date, end_date)))
        .group_by(Transaction.date)
    )
    daily = {}
    for day, total in res.all():
        key = day.isoformat() if hasattr(day, "isoformat") else str(day)
        # SUM ignores NULL amounts; an all-NULL day yields None -> 0.
        daily[key] = daily.get(key, 0.0) + float(total or 0)

    series = []
    running = 0.0
    for d in sorted(daily):
        running += daily[d]
        series.append({"date": d, "balance": running})
    return series


@router.get("/{transaction_id}", response_model=TransactionRead)
async def get_transaction(
    transaction_id: int,
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user),
):
    """Return a single transaction owned by the user (404 otherwise)."""
    tx = await _get_owned_transaction(session, transaction_id, user.id)
    return _to_read_model(tx)


@router.patch("/{transaction_id}/edit", response_model=TransactionRead)
async def update_transaction(
    transaction_id: int,
    payload: TransactionUpdate,
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user),
):
    """Partially update a transaction; fields left as ``None`` are unchanged.

    ``category_ids``: ``None`` keeps the current categories, ``[]`` clears
    them, and a non-empty list replaces them (ids must be distinct and owned).

    Raises:
        HTTPException: 404 if the transaction is not the user's; 400 on an
            invalid date string or bad category ids.
    """
    # Categories are eager-loaded so reassigning the collection never needs
    # an implicit lazy load.
    tx = await _get_owned_transaction(session, transaction_id, user.id)

    if payload.amount is not None:
        tx.amount = payload.amount
    if payload.description is not None:
        tx.description = payload.description
    if payload.date is not None:
        tx.date = _parse_date(payload.date)
    if payload.category_ids is not None:
        if payload.category_ids:
            tx.categories = await _load_owned_categories(
                session, user.id, payload.category_ids
            )
        else:
            tx.categories = []

    await session.commit()
    await session.refresh(tx, attribute_names=["categories"])
    return _to_read_model(tx)


@router.delete("/{transaction_id}/delete", status_code=status.HTTP_204_NO_CONTENT)
async def delete_transaction(
    transaction_id: int,
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user),
):
    """Delete a transaction owned by the user (404 otherwise)."""
    tx = await _get_owned_transaction(
        session, transaction_id, user.id, with_categories=False
    )
    await session.delete(tx)
    await session.commit()
    return None


@router.post("/{transaction_id}/categories/{category_id}", response_model=TransactionRead)
async def assign_category(
    transaction_id: int,
    category_id: int,
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user),
):
    """Attach one of the user's categories to one of their transactions.

    Attaching a category that is already present is a no-op.
    """
    tx = await _get_owned_transaction(session, transaction_id, user.id)
    cat = await _get_owned_category(session, category_id, user.id)

    if cat not in tx.categories:
        tx.categories.append(cat)
        await session.commit()
        await session.refresh(tx, attribute_names=["categories"])
    return _to_read_model(tx)


@router.delete("/{transaction_id}/categories/{category_id}", response_model=TransactionRead)
async def unassign_category(
    transaction_id: int,
    category_id: int,
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user),
):
    """Detach one of the user's categories from one of their transactions.

    Detaching a category that is not attached is a no-op.
    """
    tx = await _get_owned_transaction(session, transaction_id, user.id)
    cat = await _get_owned_category(session, category_id, user.id)

    if cat in tx.categories:
        tx.categories.remove(cat)
        await session.commit()
        await session.refresh(tx, attribute_names=["categories"])
    return _to_read_model(tx)