170 lines
5.1 KiB
Python
170 lines
5.1 KiB
Python
import logging
|
|
from dataclasses import dataclass
|
|
from uuid import UUID
|
|
|
|
import jwt
|
|
from fastapi import Depends, HTTPException, Request, status
|
|
from jwt import PyJWKClient, PyJWKClientError, PyJWKSetError
|
|
from sqlalchemy import select
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
|
|
from app.core.config import settings
|
|
from app.core.database import get_session
|
|
from app.models.nuzlocke_run import NuzlockeRun
|
|
from app.models.user import User
|
|
|
|
# Module logger; used below to report JWKS/JWT verification failures.
logger: logging.Logger = logging.getLogger(__name__)

# JWKS client cache. Stays None until _get_jwks_client() builds it, which
# only happens once settings.supabase_url is set; after that it is reused.
_jwks_client: PyJWKClient | None = None
|
|
|
|
|
|
@dataclass
class AuthUser:
    """Authenticated user info extracted from a verified JWT.

    Instances are built by ``get_current_user`` from the token's ``sub``,
    ``email`` and ``role`` claims.
    """

    id: str  # Supabase user UUID, taken from the JWT "sub" claim
    email: str | None = None  # "email" claim, or None when the token has none
    role: str | None = None  # "role" claim, or None when the token has none
|
|
|
|
|
def _build_jwks_url(base_url: str) -> str:
|
|
"""Build the JWKS URL, adding /auth/v1 prefix for Supabase Cloud."""
|
|
base = base_url.rstrip("/")
|
|
if "/auth/v1" in base:
|
|
return f"{base}/.well-known/jwks.json"
|
|
# Supabase Cloud URLs need the /auth/v1 prefix;
|
|
# local GoTrue serves JWKS at root but uses HS256 fallback anyway.
|
|
return f"{base}/auth/v1/.well-known/jwks.json"
|
|
|
|
|
|
def _get_jwks_client() -> PyJWKClient | None:
    """Return the shared JWKS client, creating it on first use.

    Returns ``None`` while ``settings.supabase_url`` is unset; in that case
    callers fall back to HS256 verification. Once created, the client is
    stored in the module-level ``_jwks_client`` and reused. It caches the
    fetched key set with ``lifespan=300``.
    """
    global _jwks_client
    # Nothing to do if the client already exists or cannot be built yet.
    if _jwks_client is not None or not settings.supabase_url:
        return _jwks_client
    _jwks_client = PyJWKClient(
        _build_jwks_url(settings.supabase_url),
        cache_jwk_set=True,
        lifespan=300,
    )
    return _jwks_client
|
|
|
|
|
|
def _extract_token(request: Request) -> str | None:
|
|
"""Extract Bearer token from Authorization header."""
|
|
auth_header = request.headers.get("Authorization")
|
|
if not auth_header:
|
|
return None
|
|
parts = auth_header.split()
|
|
if len(parts) != 2 or parts[0].lower() != "bearer":
|
|
return None
|
|
return parts[1]
|
|
|
|
|
|
def _verify_jwt_hs256(token: str) -> dict | None:
    """Decode *token* with the shared HS256 secret.

    Returns the decoded claims on success. Returns ``None`` when no secret is
    configured, or when PyJWT rejects the token (bad signature, wrong
    audience, expiry, malformed input, etc.). The audience must be
    ``"authenticated"``.
    """
    shared_secret = settings.supabase_jwt_secret
    if not shared_secret:
        return None
    try:
        claims = jwt.decode(
            token,
            shared_secret,
            algorithms=["HS256"],
            audience="authenticated",
        )
    except jwt.InvalidTokenError:
        return None
    return claims
|
|
|
|
|
|
def _verify_jwt(token: str) -> dict | None:
    """Verify *token* and return its claims, or ``None`` if it is not valid.

    When a JWKS client is available, the token is first checked against the
    published asymmetric keys (RS256/ES256, audience ``"authenticated"``).
    Any JWKS lookup or validation failure is logged as a warning. In every
    non-success case the shared-secret HS256 check is tried last.
    """
    jwks_client = _get_jwks_client()
    if not jwks_client:
        logger.debug("No JWKS client available (SUPABASE_URL not set?)")
        return _verify_jwt_hs256(token)

    try:
        public_key = jwks_client.get_signing_key_from_jwt(token).key
        return jwt.decode(
            token,
            public_key,
            algorithms=["RS256", "ES256"],
            audience="authenticated",
        )
    except jwt.InvalidTokenError as err:
        logger.warning("JWKS JWT validation failed: %s", err)
    except PyJWKClientError as err:
        logger.warning("JWKS client error: %s", err)
    except PyJWKSetError as err:
        logger.warning("JWKS set error: %s", err)

    # Asymmetric verification did not succeed; try the shared secret instead.
    return _verify_jwt_hs256(token)
|
|
|
|
|
|
def get_current_user(request: Request) -> AuthUser | None:
    """
    Resolve the caller from the request's bearer token, if any.

    Returns an ``AuthUser`` built from the verified token's ``sub`` (as the
    id), ``email`` and ``role`` claims. Returns ``None`` when there is no
    bearer token, when verification fails, or when the claims have no
    usable ``sub``.
    """
    token = _extract_token(request)
    claims = _verify_jwt(token) if token else None
    # Supabase puts the user ID in the 'sub' claim.
    if not claims or not (subject := claims.get("sub")):
        return None
    return AuthUser(
        id=subject,
        email=claims.get("email"),
        role=claims.get("role"),
    )
|
|
|
|
|
|
def require_auth(user: AuthUser | None = Depends(get_current_user)) -> AuthUser:
    """
    FastAPI dependency that insists on an authenticated caller.

    Passes through the ``AuthUser`` resolved by ``get_current_user``. When no
    valid bearer token was supplied, it raises HTTP 401 with a
    ``WWW-Authenticate: Bearer`` challenge header.
    """
    if user is not None:
        return user
    raise HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Authentication required",
        headers={"WWW-Authenticate": "Bearer"},
    )
|
|
|
|
|
|
async def require_admin(
    user: AuthUser = Depends(require_auth),
    session: AsyncSession = Depends(get_session),
) -> AuthUser:
    """
    Dependency that requires admin privileges.

    Looks up the authenticated user's ``User`` row by id and checks its
    ``is_admin`` flag.

    Raises 401 (via ``require_auth``) if not authenticated. Raises 403 if
    the user id is not a valid UUID, if no matching ``User`` row exists, or
    if the row is not an admin.
    """
    try:
        # str() also tolerates an id that is already a UUID object.
        user_uuid = UUID(str(user.id))
    except ValueError:
        # A verified token whose 'sub' is not a UUID previously escaped here
        # as an unhandled ValueError (HTTP 500). Treat it as "not an admin"
        # without querying the database.
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Admin access required",
        ) from None

    result = await session.execute(select(User).where(User.id == user_uuid))
    db_user = result.scalar_one_or_none()

    if db_user is None or not db_user.is_admin:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Admin access required",
        )
    return user
|
|
|
|
|
|
def require_run_owner(run: NuzlockeRun, user: AuthUser) -> None:
    """
    Verify user owns the run. Raises 403 if not owner.

    Unowned (legacy) runs are read-only and reject all mutations. A user
    whose id is not a valid UUID can never match ``run.owner_id``, so it is
    rejected with the same 403 as any other non-owner, not a 500.
    """
    if run.owner_id is None:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="This run has no owner and cannot be modified",
        )

    try:
        # str() also tolerates an id that is already a UUID object.
        user_uuid: UUID | None = UUID(str(user.id))
    except ValueError:
        # Previously an unhandled ValueError (HTTP 500) for non-UUID subjects.
        user_uuid = None

    if user_uuid is None or user_uuid != run.owner_id:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Only the run owner can perform this action",
        )
|