"""CRUD endpoints for the current user's transactions and their category links.

Every query is filtered by the authenticated user's id, so a user can only
read or modify transactions and categories that belong to them.
"""

from typing import List, Optional

from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload

from app.models.transaction import Transaction
from app.models.categories import Category
from app.schemas.transaction import (
    TransactionCreate,
    TransactionRead,
    TransactionUpdate,
)
from app.services.db import get_async_session
from app.services.user_service import current_active_user
from app.models.user import User

router = APIRouter(prefix="/transactions", tags=["transactions"])


def _to_read_model(tx: Transaction) -> TransactionRead:
    """Build the response schema for ``tx``.

    Callers in this module load ``tx.categories`` before calling this, so
    reading the collection here does not trigger a lazy load. A missing or
    empty collection yields an empty ``category_ids`` list.
    """
    return TransactionRead(
        id=tx.id,
        amount=tx.amount,
        description=tx.description,
        category_ids=[c.id for c in (tx.categories or [])],
    )


async def _get_owned_transaction(
    session: AsyncSession, transaction_id: int, user: User
) -> Transaction:
    """Return the transaction with ``transaction_id`` owned by ``user``.

    Raises:
        HTTPException: 404 when no such transaction exists for this user.
    """
    res = await session.execute(
        select(Transaction).where(
            Transaction.id == transaction_id, Transaction.user_id == user.id
        )
    )
    tx: Optional[Transaction] = res.scalar_one_or_none()
    if tx is None:
        raise HTTPException(status_code=404, detail="Transaction not found")
    return tx


async def _get_owned_category(
    session: AsyncSession, category_id: int, user: User
) -> Category:
    """Return the category with ``category_id`` owned by ``user``.

    Raises:
        HTTPException: 404 when no such category exists for this user.
    """
    res = await session.execute(
        select(Category).where(Category.id == category_id, Category.user_id == user.id)
    )
    cat: Optional[Category] = res.scalar_one_or_none()
    if cat is None:
        raise HTTPException(status_code=404, detail="Category not found")
    return cat


async def _load_owned_categories(
    session: AsyncSession,
    category_ids: List[int],
    user: User,
    *,
    not_found_detail: str,
) -> List[Category]:
    """Load the categories in ``category_ids`` that belong to ``user``.

    Repeated ids are tolerated: the fetched rows are compared against the
    number of *distinct* requested ids, so only an unknown (or foreign-owned)
    id causes a failure.

    Raises:
        HTTPException: 400 with ``not_found_detail`` when at least one
            requested id did not match a category of this user.
    """
    res = await session.execute(
        select(Category).where(
            Category.user_id == user.id, Category.id.in_(category_ids)
        )
    )
    categories = list(res.scalars())
    if len(categories) != len(set(category_ids)):
        raise HTTPException(status_code=400, detail=not_found_detail)
    return categories


async def _reload_for_response(session: AsyncSession, tx: Transaction) -> TransactionRead:
    """Reload ``tx`` after a commit and convert it to the response schema.

    Column attributes are refreshed first, then the ``categories``
    relationship, which is the sequence ``create_transaction`` always used.
    NOTE(review): if the session factory keeps the default
    ``expire_on_commit=True`` (not visible from this module), skipping the
    first refresh would leave ``id``/``amount``/``description`` expired and
    reading them would attempt implicit IO, which AsyncSession does not allow.
    """
    await session.refresh(tx)
    await session.refresh(tx, attribute_names=["categories"])
    return _to_read_model(tx)


@router.post("/", response_model=TransactionRead, status_code=status.HTTP_201_CREATED)
async def create_transaction(
    payload: TransactionCreate,
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user),
):
    """Create a transaction for the current user, optionally attaching categories."""
    tx = Transaction(amount=payload.amount, description=payload.description, user_id=user.id)

    # Attach categories if provided (and owned by user)
    if payload.category_ids:
        tx.categories = await _load_owned_categories(
            session,
            payload.category_ids,
            user,
            not_found_detail=(
                "Duplicate category IDs provided or one or more categories not found"
            ),
        )

    session.add(tx)
    await session.commit()
    return await _reload_for_response(session, tx)


@router.get("/", response_model=List[TransactionRead])
async def list_transactions(
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user),
):
    """List the current user's transactions ordered by id."""
    res = await session.execute(
        select(Transaction)
        .where(Transaction.user_id == user.id)
        .order_by(Transaction.id)
        # Load every transaction's categories in one batched SELECT instead of
        # issuing a refresh per row.
        .options(selectinload(Transaction.categories))
    )
    return [_to_read_model(tx) for tx in res.scalars()]


@router.get("/{transaction_id}", response_model=TransactionRead)
async def get_transaction(
    transaction_id: int,
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user),
):
    """Return one of the current user's transactions, or 404 if not found."""
    tx = await _get_owned_transaction(session, transaction_id, user)
    await session.refresh(tx, attribute_names=["categories"])
    return _to_read_model(tx)


@router.patch("/{transaction_id}", response_model=TransactionRead)
async def update_transaction(
    transaction_id: int,
    payload: TransactionUpdate,
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user),
):
    """Partially update a transaction.

    Fields left as ``None`` are not changed. An empty ``category_ids`` list
    removes all categories; a non-empty list replaces the current set.
    """
    tx = await _get_owned_transaction(session, transaction_id, user)

    if payload.amount is not None:
        tx.amount = payload.amount
    if payload.description is not None:
        tx.description = payload.description

    if payload.category_ids is not None:
        # Preload categories to avoid async lazy-load during assignment
        await session.refresh(tx, attribute_names=["categories"])
        if payload.category_ids:
            tx.categories = await _load_owned_categories(
                session,
                payload.category_ids,
                user,
                not_found_detail="One or more categories not found",
            )
        else:
            tx.categories = []

    await session.commit()
    return await _reload_for_response(session, tx)


@router.delete("/{transaction_id}", status_code=status.HTTP_204_NO_CONTENT)
async def delete_transaction(
    transaction_id: int,
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user),
):
    """Delete one of the current user's transactions, or 404 if not found."""
    tx = await _get_owned_transaction(session, transaction_id, user)
    await session.delete(tx)
    await session.commit()
    return None


@router.post("/{transaction_id}/categories/{category_id}", response_model=TransactionRead)
async def assign_category(
    transaction_id: int,
    category_id: int,
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user),
):
    """Attach a category to a transaction; a no-op if it is already attached."""
    # Load transaction and category ensuring ownership
    tx = await _get_owned_transaction(session, transaction_id, user)
    cat = await _get_owned_category(session, category_id, user)

    # The collection must be loaded before it can be inspected or mutated.
    await session.refresh(tx, attribute_names=["categories"])
    if cat not in tx.categories:
        tx.categories.append(cat)

    await session.commit()
    return await _reload_for_response(session, tx)


@router.delete("/{transaction_id}/categories/{category_id}", response_model=TransactionRead)
async def unassign_category(
    transaction_id: int,
    category_id: int,
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user),
):
    """Detach a category from a transaction; a no-op if it is not attached."""
    tx = await _get_owned_transaction(session, transaction_id, user)
    cat = await _get_owned_category(session, category_id, user)

    # The collection must be loaded before it can be inspected or mutated.
    await session.refresh(tx, attribute_names=["categories"])
    if cat in tx.categories:
        tx.categories.remove(cat)

    await session.commit()
    return await _reload_for_response(session, tx)