mirror of
https://github.com/dat515-2025/Group-8.git
synced 2026-03-22 15:12:08 +01:00
fix(backend): implemented jwt token invalidation so users cannot use it after expiry
This commit is contained in:
@@ -24,6 +24,23 @@ async def delete_me(
|
|||||||
await user_manager.delete(user)
|
await user_manager.delete(user)
|
||||||
|
|
||||||
# Keep existing paths as-is under /auth/* and /users/*
|
# Keep existing paths as-is under /auth/* and /users/*
|
||||||
|
from fastapi import Request, Response
|
||||||
|
from app.core.security import revoke_token, extract_bearer_token
|
||||||
|
|
||||||
|
|
||||||
|
@router.post(
|
||||||
|
"/auth/jwt/logout",
|
||||||
|
status_code=status.HTTP_204_NO_CONTENT,
|
||||||
|
tags=["auth"],
|
||||||
|
summary="Log out and revoke current token",
|
||||||
|
)
|
||||||
|
async def custom_logout(request: Request) -> Response:
|
||||||
|
"""Revoke the current bearer token so it cannot be used anymore."""
|
||||||
|
token = extract_bearer_token(request)
|
||||||
|
if token:
|
||||||
|
revoke_token(token)
|
||||||
|
return Response(status_code=status.HTTP_204_NO_CONTENT)
|
||||||
|
|
||||||
router.include_router(
|
router.include_router(
|
||||||
fastapi_users.get_auth_router(auth_backend), prefix="/auth/jwt", tags=["auth"]
|
fastapi_users.get_auth_router(auth_backend), prefix="/auth/jwt", tags=["auth"]
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -16,6 +16,8 @@ from app.api.csas import router as csas_router
|
|||||||
from app.api.categories import router as categories_router
|
from app.api.categories import router as categories_router
|
||||||
from app.api.transactions import router as transactions_router
|
from app.api.transactions import router as transactions_router
|
||||||
from app.services.user_service import auth_backend, current_active_verified_user, fastapi_users, get_oauth_provider, UserManager, get_jwt_strategy
|
from app.services.user_service import auth_backend, current_active_verified_user, fastapi_users, get_oauth_provider, UserManager, get_jwt_strategy
|
||||||
|
from app.core.security import extract_bearer_token, is_token_revoked, decode_and_verify_jwt
|
||||||
|
from app.services.user_service import SECRET
|
||||||
|
|
||||||
|
|
||||||
from fastapi import FastAPI
|
from fastapi import FastAPI
|
||||||
@@ -49,6 +51,24 @@ fastApi.include_router(categories_router)
|
|||||||
fastApi.include_router(transactions_router)
|
fastApi.include_router(transactions_router)
|
||||||
|
|
||||||
logging.basicConfig(filename='app.log', level=logging.INFO, format='%(asctime)s %(message)s')
|
logging.basicConfig(filename='app.log', level=logging.INFO, format='%(asctime)s %(message)s')
|
||||||
|
@fastApi.middleware("http")
|
||||||
|
async def auth_guard(request: Request, call_next):
|
||||||
|
# Enforce revoked/expired JWTs are rejected globally
|
||||||
|
token = extract_bearer_token(request)
|
||||||
|
if token:
|
||||||
|
# Deny if token is revoked
|
||||||
|
if is_token_revoked(token):
|
||||||
|
from fastapi import Response, status as _status
|
||||||
|
return Response(status_code=_status.HTTP_401_UNAUTHORIZED)
|
||||||
|
# Deny if token is expired or invalid
|
||||||
|
try:
|
||||||
|
decode_and_verify_jwt(token, SECRET)
|
||||||
|
except Exception:
|
||||||
|
from fastapi import Response, status as _status
|
||||||
|
return Response(status_code=_status.HTTP_401_UNAUTHORIZED)
|
||||||
|
return await call_next(request)
|
||||||
|
|
||||||
|
|
||||||
@fastApi.middleware("http")
|
@fastApi.middleware("http")
|
||||||
async def log_traffic(request: Request, call_next):
|
async def log_traffic(request: Request, call_next):
|
||||||
start_time = datetime.now()
|
start_time = datetime.now()
|
||||||
|
|||||||
45
7project/backend/app/core/security.py
Normal file
45
7project/backend/app/core/security.py
Normal file
@@ -0,0 +1,45 @@
|
|||||||
|
from typing import Optional
|
||||||
|
import re
|
||||||
|
import jwt
|
||||||
|
from fastapi import Request
|
||||||
|
|
||||||
|
# Simple in-memory revocation store. In production, consider Redis or database.
|
||||||
|
_REVOKED_TOKENS: set[str] = set()
|
||||||
|
|
||||||
|
# Bearer token regex
|
||||||
|
_BEARER_RE = re.compile(r"^[Bb]earer\s+(.+)$")
|
||||||
|
|
||||||
|
|
||||||
|
def extract_bearer_token(request: Request) -> Optional[str]:
|
||||||
|
auth = request.headers.get("authorization")
|
||||||
|
if not auth:
|
||||||
|
return None
|
||||||
|
m = _BEARER_RE.match(auth)
|
||||||
|
if not m:
|
||||||
|
return None
|
||||||
|
return m.group(1).strip()
|
||||||
|
|
||||||
|
|
||||||
|
def revoke_token(token: str) -> None:
|
||||||
|
if token:
|
||||||
|
_REVOKED_TOKENS.add(token)
|
||||||
|
|
||||||
|
|
||||||
|
def is_token_revoked(token: str) -> bool:
|
||||||
|
return token in _REVOKED_TOKENS
|
||||||
|
|
||||||
|
|
||||||
|
def decode_and_verify_jwt(token: str, secret: str) -> dict:
|
||||||
|
"""
|
||||||
|
Decode the JWT using the shared secret, verifying expiration and signature.
|
||||||
|
Audience is not verified here to be compatible with fastapi-users default tokens.
|
||||||
|
Raises jwt.ExpiredSignatureError if expired.
|
||||||
|
Raises jwt.InvalidTokenError for other issues.
|
||||||
|
Returns the decoded payload dict on success.
|
||||||
|
"""
|
||||||
|
return jwt.decode(
|
||||||
|
token,
|
||||||
|
secret,
|
||||||
|
algorithms=["HS256"],
|
||||||
|
options={"verify_aud": False},
|
||||||
|
) # verify_exp is True by default
|
||||||
@@ -1,2 +1,5 @@
|
|||||||
[tool.pytest.ini_options]
|
[tool.pytest.ini_options]
|
||||||
pythonpath = "."
|
pythonpath = "."
|
||||||
|
asyncio_mode = "auto"
|
||||||
|
asyncio_default_fixture_loop_scope = "session"
|
||||||
|
asyncio_default_test_loop_scope = "session"
|
||||||
@@ -1,7 +1,9 @@
|
|||||||
import sys
|
import sys
|
||||||
|
import uuid
|
||||||
import types
|
import types
|
||||||
import pytest
|
import pytest
|
||||||
from fastapi.testclient import TestClient
|
from fastapi.testclient import TestClient
|
||||||
|
from httpx import AsyncClient, ASGITransport
|
||||||
|
|
||||||
# Stub sentry_sdk to avoid optional dependency issues during import of app
|
# Stub sentry_sdk to avoid optional dependency issues during import of app
|
||||||
stub = types.ModuleType("sentry_sdk")
|
stub = types.ModuleType("sentry_sdk")
|
||||||
@@ -20,3 +22,50 @@ def fastapi_app():
|
|||||||
@pytest.fixture(scope="session")
|
@pytest.fixture(scope="session")
|
||||||
def client(fastapi_app):
|
def client(fastapi_app):
|
||||||
return TestClient(fastapi_app, raise_server_exceptions=True)
|
return TestClient(fastapi_app, raise_server_exceptions=True)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="function")
|
||||||
|
async def test_user(fastapi_app):
|
||||||
|
"""
|
||||||
|
Creates a new user asynchronously and returns their credentials.
|
||||||
|
Does NOT log them in.
|
||||||
|
Using AsyncClient with ASGITransport avoids event loop conflicts with DB connections.
|
||||||
|
"""
|
||||||
|
unique_email = f"testuser_{uuid.uuid4()}@example.com"
|
||||||
|
password = "a_strong_password"
|
||||||
|
user_payload = {"email": unique_email, "password": password}
|
||||||
|
|
||||||
|
transport = ASGITransport(app=fastapi_app, raise_app_exceptions=True)
|
||||||
|
async with AsyncClient(transport=transport, base_url="http://testserver") as ac:
|
||||||
|
response = await ac.post("/auth/register", json=user_payload)
|
||||||
|
assert response.status_code == 201
|
||||||
|
|
||||||
|
return {"username": unique_email, "password": password}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="function")
|
||||||
|
def authenticated_client(client: TestClient):
|
||||||
|
"""
|
||||||
|
Creates a new user, logs them in, and returns a client
|
||||||
|
with the authorization headers already set.
|
||||||
|
"""
|
||||||
|
# 1. Create a unique user
|
||||||
|
unique_email = f"testuser_{uuid.uuid4()}@example.com"
|
||||||
|
password = "a_strong_password"
|
||||||
|
user_payload = {"email": unique_email, "password": password}
|
||||||
|
|
||||||
|
register_resp = client.post("/auth/register", json=user_payload)
|
||||||
|
assert register_resp.status_code == 201
|
||||||
|
|
||||||
|
# 2. Log in to get the token
|
||||||
|
login_payload = {"username": unique_email, "password": password}
|
||||||
|
login_resp = client.post("/auth/jwt/login", data=login_payload)
|
||||||
|
token = login_resp.json()["access_token"]
|
||||||
|
|
||||||
|
# 3. Set the authorization header for subsequent requests
|
||||||
|
client.headers = {"Authorization": f"Bearer {token}"}
|
||||||
|
|
||||||
|
yield client
|
||||||
|
|
||||||
|
# Teardown: Clear headers after the test
|
||||||
|
client.headers.pop("Authorization", None)
|
||||||
|
|||||||
@@ -1,3 +1,6 @@
|
|||||||
|
import pytest
|
||||||
|
import uuid
|
||||||
|
from httpx import AsyncClient, ASGITransport
|
||||||
from fastapi import status
|
from fastapi import status
|
||||||
|
|
||||||
|
|
||||||
@@ -13,3 +16,83 @@ def test_e2e_minimal_auth_flow(client):
|
|||||||
# 3) Protected endpoint should not be accessible without token
|
# 3) Protected endpoint should not be accessible without token
|
||||||
me = client.get("/users/me")
|
me = client.get("/users/me")
|
||||||
assert me.status_code in (status.HTTP_401_UNAUTHORIZED, status.HTTP_403_FORBIDDEN)
|
assert me.status_code in (status.HTTP_401_UNAUTHORIZED, status.HTTP_403_FORBIDDEN)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_e2e_full_user_lifecycle(fastapi_app, test_user):
|
||||||
|
# Use an AsyncClient with ASGITransport for async tests
|
||||||
|
transport = ASGITransport(app=fastapi_app, raise_app_exceptions=True)
|
||||||
|
async with AsyncClient(transport=transport, base_url="http://testserver") as ac:
|
||||||
|
login_payload = test_user
|
||||||
|
|
||||||
|
# 1. Log in with the new credentials
|
||||||
|
login_resp = await ac.post("/auth/jwt/login", data=login_payload)
|
||||||
|
assert login_resp.status_code == status.HTTP_200_OK
|
||||||
|
token = login_resp.json()["access_token"]
|
||||||
|
headers = {"Authorization": f"Bearer {token}"}
|
||||||
|
|
||||||
|
# 2. Access a protected endpoint
|
||||||
|
me_resp = await ac.get("/users/me", headers=headers)
|
||||||
|
assert me_resp.status_code == status.HTTP_200_OK
|
||||||
|
assert me_resp.json()["email"] == test_user["username"]
|
||||||
|
|
||||||
|
# 3. Update the user's profile
|
||||||
|
update_payload = {"first_name": "Test"}
|
||||||
|
patch_resp = await ac.patch("/users/me", json=update_payload, headers=headers)
|
||||||
|
assert patch_resp.status_code == status.HTTP_200_OK
|
||||||
|
assert patch_resp.json()["first_name"] == "Test"
|
||||||
|
|
||||||
|
# 4. Log out
|
||||||
|
logout_resp = await ac.post("/auth/jwt/logout", headers=headers)
|
||||||
|
assert logout_resp.status_code in (status.HTTP_200_OK, status.HTTP_204_NO_CONTENT)
|
||||||
|
|
||||||
|
# 5. Verify token is invalid
|
||||||
|
me_again_resp = await ac.get("/users/me", headers=headers)
|
||||||
|
assert me_again_resp.status_code == status.HTTP_401_UNAUTHORIZED
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_e2e_transaction_workflow(fastapi_app, test_user):
|
||||||
|
transport = ASGITransport(app=fastapi_app, raise_app_exceptions=True)
|
||||||
|
async with AsyncClient(transport=transport, base_url="http://testserver") as ac:
|
||||||
|
# 1. Log in to get the token
|
||||||
|
login_resp = await ac.post("/auth/jwt/login", data=test_user)
|
||||||
|
token = login_resp.json()["access_token"]
|
||||||
|
headers = {"Authorization": f"Bearer {token}"}
|
||||||
|
|
||||||
|
# NEW STEP: Create a category first to get a valid ID
|
||||||
|
category_payload = {"name": "Test Category for E2E"}
|
||||||
|
create_category_resp = await ac.post("/categories/create", json=category_payload, headers=headers)
|
||||||
|
assert create_category_resp.status_code == status.HTTP_201_CREATED
|
||||||
|
category_id = create_category_resp.json()["id"]
|
||||||
|
|
||||||
|
# 2. Create a new transaction
|
||||||
|
tx_payload = {"amount": -55.40, "description": "Milk and eggs"}
|
||||||
|
tx_resp = await ac.post("/transactions/create", json=tx_payload, headers=headers)
|
||||||
|
assert tx_resp.status_code == status.HTTP_201_CREATED
|
||||||
|
tx_id = tx_resp.json()["id"]
|
||||||
|
|
||||||
|
# 3. Assign the category
|
||||||
|
assign_resp = await ac.post(f"/transactions/{tx_id}/categories/{category_id}", headers=headers)
|
||||||
|
assert assign_resp.status_code == status.HTTP_200_OK
|
||||||
|
|
||||||
|
# 4. Verify assignment
|
||||||
|
get_tx_resp = await ac.get(f"/transactions/{tx_id}", headers=headers)
|
||||||
|
assert category_id in get_tx_resp.json()["category_ids"]
|
||||||
|
|
||||||
|
# 5. Unassign the category
|
||||||
|
unassign_resp = await ac.delete(f"/transactions/{tx_id}/categories/{category_id}", headers=headers)
|
||||||
|
assert unassign_resp.status_code == status.HTTP_200_OK
|
||||||
|
|
||||||
|
# 6. Get the transaction again and verify the category is gone
|
||||||
|
get_tx_again_resp = await ac.get(f"/transactions/{tx_id}", headers=headers)
|
||||||
|
final_tx_data = get_tx_again_resp.json()
|
||||||
|
assert category_id not in final_tx_data["category_ids"]
|
||||||
|
|
||||||
|
# 7. Delete the transaction for cleanup
|
||||||
|
delete_resp = await ac.delete(f"/transactions/{tx_id}/delete", headers=headers)
|
||||||
|
assert delete_resp.status_code in (status.HTTP_200_OK, status.HTTP_204_NO_CONTENT)
|
||||||
|
|
||||||
|
# NEW STEP: Clean up the created category
|
||||||
|
delete_category_resp = await ac.delete(f"/categories/{category_id}", headers=headers)
|
||||||
|
assert delete_category_resp.status_code in (status.HTTP_200_OK, status.HTTP_204_NO_CONTENT)
|
||||||
@@ -1,5 +1,6 @@
|
|||||||
from fastapi import status
|
from fastapi import status
|
||||||
import pytest
|
import pytest
|
||||||
|
from httpx import AsyncClient, ASGITransport
|
||||||
|
|
||||||
|
|
||||||
def test_root_ok(client):
|
def test_root_ok(client):
|
||||||
@@ -16,3 +17,55 @@ def test_authenticated_route_requires_auth(client):
|
|||||||
def test_sentry_debug_raises_exception(client):
|
def test_sentry_debug_raises_exception(client):
|
||||||
with pytest.raises(ZeroDivisionError):
|
with pytest.raises(ZeroDivisionError):
|
||||||
client.get("/sentry-debug")
|
client.get("/sentry-debug")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_create_and_get_category(fastapi_app, test_user):
|
||||||
|
# Use AsyncClient for async tests
|
||||||
|
transport = ASGITransport(app=fastapi_app)
|
||||||
|
async with AsyncClient(transport=transport, base_url="http://testserver") as ac:
|
||||||
|
# 1. Log in to get an auth token
|
||||||
|
login_resp = await ac.post("/auth/jwt/login", data=test_user)
|
||||||
|
token = login_resp.json()["access_token"]
|
||||||
|
headers = {"Authorization": f"Bearer {token}"}
|
||||||
|
|
||||||
|
# 2. Define and create the new category
|
||||||
|
category_name = "Async Integration Test"
|
||||||
|
category_payload = {"name": category_name}
|
||||||
|
create_resp = await ac.post("/categories/create", json=category_payload, headers=headers)
|
||||||
|
|
||||||
|
# 3. Assert creation was successful
|
||||||
|
assert create_resp.status_code == status.HTTP_201_CREATED
|
||||||
|
created_data = create_resp.json()
|
||||||
|
category_id = created_data["id"]
|
||||||
|
assert created_data["name"] == category_name
|
||||||
|
|
||||||
|
# 4. GET the list of categories to verify
|
||||||
|
list_resp = await ac.get("/categories/", headers=headers)
|
||||||
|
assert list_resp.status_code == status.HTTP_200_OK
|
||||||
|
|
||||||
|
# 5. Check that our new category is in the list
|
||||||
|
categories_list = list_resp.json()
|
||||||
|
assert any(cat["name"] == category_name for cat in categories_list)
|
||||||
|
|
||||||
|
delete_resp = await ac.delete(f"/categories/{category_id}", headers=headers)
|
||||||
|
assert delete_resp.status_code in (status.HTTP_200_OK, status.HTTP_204_NO_CONTENT)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_create_transaction_missing_amount_fails(fastapi_app, test_user):
|
||||||
|
transport = ASGITransport(app=fastapi_app)
|
||||||
|
async with AsyncClient(transport=transport, base_url="http://testserver") as ac:
|
||||||
|
# 1. Log in to get an auth token
|
||||||
|
login_resp = await ac.post("/auth/jwt/login", data=test_user)
|
||||||
|
token = login_resp.json()["access_token"]
|
||||||
|
headers = {"Authorization": f"Bearer {token}"}
|
||||||
|
|
||||||
|
# 2. Define an invalid payload
|
||||||
|
invalid_payload = {"description": "This should fail"}
|
||||||
|
|
||||||
|
# 3. Attempt to create the transaction
|
||||||
|
resp = await ac.post("/transactions/create", json=invalid_payload, headers=headers)
|
||||||
|
|
||||||
|
# 4. Assert the expected validation error
|
||||||
|
assert resp.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY
|
||||||
|
|||||||
Reference in New Issue
Block a user