diff --git a/7project/backend/app/app.py b/7project/backend/app/app.py index 56ef23c..c9e5ded 100644 --- a/7project/backend/app/app.py +++ b/7project/backend/app/app.py @@ -47,10 +47,10 @@ fastApi.include_router( fastApi.include_router( fastapi_users.get_oauth_router( - app.services.user_service.mojeid_oauth_service, + app.services.user_service.get_oauth_provider("MojeID"), auth_backend, "SECRET", - associate_by_email=True + associate_by_email=True, ), prefix="/auth/mojeid", tags=["auth"], diff --git a/7project/backend/app/oauth/moje_id.py b/7project/backend/app/oauth/moje_id.py index 0c118e3..b199f63 100644 --- a/7project/backend/app/oauth/moje_id.py +++ b/7project/backend/app/oauth/moje_id.py @@ -1,11 +1,10 @@ import json -from typing import Optional, Literal +from typing import Optional, Literal, Any from httpx_oauth.clients.openid import OpenID -from httpx_oauth.oauth2 import OAuth2Token, GetAccessTokenError, T +from httpx_oauth.oauth2 import T -# claims=%7B%22id_token%22%3A%7B%22birthdate%22%3A%7B%22essential%22%3Atrue%7D%2C%22name%22%3A%7B%22essential%22%3Atrue%7D%2C%22given_name%22%3A%7B%22essential%22%3Atrue%7D%2C%22family_name%22%3A%7B%22essential%22%3Atrue%7D%2C%22email%22%3A%7B%22essential%22%3Atrue%7D%2C%22address%22%3A%7B%22essential%22%3Afalse%7D%2C%22mojeid_valid%22%3A%7B%22essential%22%3Atrue%7D%7D%7D class MojeIDOAuth(OpenID): def __init__(self, client_id: str, client_secret: str): super().__init__( @@ -16,6 +15,14 @@ class MojeIDOAuth(OpenID): base_scopes=["openid", "email", "profile"], ) + async def get_user_info(self, token: str) -> Optional[Any]: + info = await self.get_profile(token) + + return { + "first_name": info.get("given_name"), + "last_name": info.get("family_name"), + } + async def get_authorization_url( self, redirect_uri: str, diff --git a/7project/backend/app/services/user_service.py b/7project/backend/app/services/user_service.py index dcaf0d9..feee712 100644 --- a/7project/backend/app/services/user_service.py +++ b/7project/backend/app/services/user_service.py @@ -3,7 +3,7 @@ import uuid from typing import Optional from fastapi import Depends, Request -from fastapi_users import BaseUserManager, FastAPIUsers, UUIDIDMixin +from fastapi_users import BaseUserManager, FastAPIUsers, UUIDIDMixin, models from fastapi_users.authentication import ( AuthenticationBackend, BearerTransport, @@ -20,15 +20,41 @@ SECRET = os.getenv("SECRET", "CHANGE_ME_SECRET") FRONTEND_URL = os.getenv("FRONTEND_URL", "http://localhost:5173") BACKEND_URL = os.getenv("BACKEND_URL", "http://localhost:8000") -mojeid_oauth_service = MojeIDOAuth( - os.getenv("MOJEID_CLIENT_ID", "CHANGE_ME_CLIENT_ID"), - os.getenv("MOJEID_CLIENT_SECRET", "CHANGE_ME_CLIENT_SECRET"), -) +providers = { + "MojeID": MojeIDOAuth( + os.getenv("MOJEID_CLIENT_ID", "CHANGE_ME_CLIENT_ID"), + os.getenv("MOJEID_CLIENT_SECRET", "CHANGE_ME_CLIENT_SECRET"), + ) +} + + +def get_oauth_provider(name: str) -> Optional[MojeIDOAuth]: + if name not in providers: + return None + return providers[name] + class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]): reset_password_token_secret = SECRET verification_token_secret = SECRET + async def oauth_callback(self: "BaseUserManager[models.UOAP, models.ID]", oauth_name: str, access_token: str, + account_id: str, account_email: str, expires_at: Optional[int] = None, + refresh_token: Optional[str] = None, request: Optional[Request] = None, *, + associate_by_email: bool = False, is_verified_by_default: bool = False) -> models.UOAP: + + user = await super().oauth_callback(oauth_name, access_token, account_id, account_email, expires_at, + refresh_token, request, associate_by_email=associate_by_email, + is_verified_by_default=is_verified_by_default) + + # set additional user info from the OAuth provider + provider = get_oauth_provider(oauth_name) + if provider is not None and hasattr(provider, "get_user_info"): + update_dict = await provider.get_user_info(access_token) + await self.user_db.update(user, update_dict) + + return user + async def on_after_register(self, user: User, request: Optional[Request] = None): await self.request_verify(user, request) @@ -58,14 +84,18 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]): print("[Email Fallback] Subject:", subject) print("[Email Fallback] Body:\n", body) + async def get_user_manager(user_db: SQLAlchemyUserDatabase = Depends(get_user_db)): yield UserManager(user_db) + bearer_transport = BearerTransport(tokenUrl="auth/jwt/login") + def get_jwt_strategy() -> JWTStrategy: return JWTStrategy(secret=SECRET, lifetime_seconds=3600) + auth_backend = AuthenticationBackend( name="jwt", transport=bearer_transport, @@ -76,4 +106,3 @@ fastapi_users = FastAPIUsers[User, uuid.UUID](get_user_manager, [auth_backend]) current_active_user = fastapi_users.current_user(active=True) current_active_verified_user = fastapi_users.current_user(active=True, verified=True) -