Supabase JWT key was switched to ECC P-256, but the JWKS verification only accepted RS256. Add ES256 to the accepted algorithms list. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
372 lines
12 KiB
Python
import time
|
|
from unittest.mock import MagicMock, patch
|
|
from uuid import UUID
|
|
|
|
import jwt
|
|
import pytest
|
|
from cryptography.hazmat.primitives.asymmetric import ec, rsa
|
|
from httpx import ASGITransport, AsyncClient
|
|
|
|
from app.core.auth import AuthUser, get_current_user, require_admin, require_auth
|
|
from app.main import app
|
|
from app.models.user import User
|
|
|
|
|
|
@pytest.fixture(scope="module")
def rsa_key_pair():
    """Provide one RSA-2048 key pair shared by every test in this module.

    The private half signs the RS256 test tokens; the public half is what the
    mocked JWKS client hands back for verification.

    Returns:
        tuple: ``(private_key, public_key)``.
    """
    signing_key = rsa.generate_private_key(key_size=2048, public_exponent=65537)
    return signing_key, signing_key.public_key()
|
|
|
|
|
|
@pytest.fixture
def valid_token(rsa_key_pair):
    """Return an RS256 JWT for ``user-123`` that stays valid for one hour."""
    signing_key = rsa_key_pair[0]
    expires_at = int(time.time()) + 3600
    claims = dict(
        sub="user-123",
        email="test@example.com",
        role="authenticated",
        aud="authenticated",
        exp=expires_at,
    )
    return jwt.encode(claims, signing_key, algorithm="RS256")
|
|
|
|
|
|
@pytest.fixture
def expired_token(rsa_key_pair):
    """Return a correctly signed RS256 JWT whose ``exp`` is already in the past."""
    signing_key = rsa_key_pair[0]
    # Backdate expiry by one hour so signature checks pass but exp fails.
    expired_at = int(time.time()) - 3600
    claims = dict(
        sub="user-123",
        email="test@example.com",
        role="authenticated",
        aud="authenticated",
        exp=expired_at,
    )
    return jwt.encode(claims, signing_key, algorithm="RS256")
|
|
|
|
|
|
@pytest.fixture
def invalid_token():
    """Return an unexpired RS256 JWT signed by a throwaway key.

    The key is freshly generated here, so it never matches the public key the
    mocked JWKS client returns, and signature verification must fail.
    """
    unrelated_key = rsa.generate_private_key(public_exponent=65537, key_size=2048)
    expires_at = int(time.time()) + 3600
    claims = dict(
        sub="user-123",
        email="test@example.com",
        role="authenticated",
        aud="authenticated",
        exp=expires_at,
    )
    return jwt.encode(claims, unrelated_key, algorithm="RS256")
|
|
|
|
|
|
@pytest.fixture
def mock_jwks_client(rsa_key_pair):
    """Return a fake JWKS client that resolves every token to the test RSA public key.

    Only ``get_signing_key_from_jwt`` is configured; its result exposes the
    public key through a ``.key`` attribute.
    """
    _, rsa_public = rsa_key_pair
    signing_key = MagicMock(key=rsa_public)
    client = MagicMock()
    client.get_signing_key_from_jwt.return_value = signing_key
    return client
|
|
|
|
|
|
@pytest.fixture(scope="module")
def ec_key_pair():
    """Provide one ECC P-256 (secp256r1) key pair shared across this module.

    Used to sign and verify ES256 test tokens.

    Returns:
        tuple: ``(private_key, public_key)``.
    """
    curve = ec.SECP256R1()
    signing_key = ec.generate_private_key(curve)
    return signing_key, signing_key.public_key()
|
|
|
|
|
|
@pytest.fixture
def valid_es256_token(ec_key_pair):
    """Return an ES256 JWT for ``user-456`` that stays valid for one hour."""
    signing_key = ec_key_pair[0]
    expires_at = int(time.time()) + 3600
    claims = dict(
        sub="user-456",
        email="ec-user@example.com",
        role="authenticated",
        aud="authenticated",
        exp=expires_at,
    )
    return jwt.encode(claims, signing_key, algorithm="ES256")
|
|
|
|
|
|
@pytest.fixture
def mock_jwks_client_ec(ec_key_pair):
    """Return a fake JWKS client that resolves every token to the test EC public key."""
    ec_public = ec_key_pair[1]
    signing_key = MagicMock()
    signing_key.key = ec_public
    # configure_mock accepts dotted attribute paths for nested mocks.
    client = MagicMock()
    client.configure_mock(**{"get_signing_key_from_jwt.return_value": signing_key})
    return client
|
|
|
|
|
|
async def test_get_current_user_valid_es256_token(
    valid_es256_token, mock_jwks_client_ec
):
    """An ES256 (ECC P-256) bearer token resolves to the expected user.

    Covers the Supabase signing key moving from RSA to ECC P-256.
    """
    auth_header = f"Bearer {valid_es256_token}"

    class MockRequest:
        headers = {"Authorization": auth_header}

    with patch("app.core.auth._get_jwks_client", return_value=mock_jwks_client_ec):
        user = get_current_user(MockRequest())

    assert user is not None
    assert (user.id, user.email, user.role) == (
        "user-456",
        "ec-user@example.com",
        "authenticated",
    )
|
|
|
|
|
|
async def test_get_current_user_valid_token(valid_token, mock_jwks_client):
    """A valid RS256 bearer token resolves to the expected user."""
    auth_header = f"Bearer {valid_token}"

    class MockRequest:
        headers = {"Authorization": auth_header}

    with patch("app.core.auth._get_jwks_client", return_value=mock_jwks_client):
        user = get_current_user(MockRequest())

    assert user is not None
    assert (user.id, user.email, user.role) == (
        "user-123",
        "test@example.com",
        "authenticated",
    )
|
|
|
|
|
|
async def test_get_current_user_no_token(mock_jwks_client):
    """A request with no Authorization header yields no user."""

    class MockRequest:
        headers = {}

    with patch("app.core.auth._get_jwks_client", return_value=mock_jwks_client):
        resolved = get_current_user(MockRequest())

    assert resolved is None
|
|
|
|
|
|
async def test_get_current_user_expired_token(expired_token, mock_jwks_client):
    """A correctly signed but expired token yields no user."""
    auth_header = f"Bearer {expired_token}"

    class MockRequest:
        headers = {"Authorization": auth_header}

    with patch("app.core.auth._get_jwks_client", return_value=mock_jwks_client):
        resolved = get_current_user(MockRequest())

    assert resolved is None
|
|
|
|
|
|
async def test_get_current_user_invalid_token(invalid_token, mock_jwks_client):
    """A token signed by a key the JWKS does not serve yields no user."""
    auth_header = f"Bearer {invalid_token}"

    class MockRequest:
        headers = {"Authorization": auth_header}

    with patch("app.core.auth._get_jwks_client", return_value=mock_jwks_client):
        resolved = get_current_user(MockRequest())

    assert resolved is None
|
|
|
|
|
|
async def test_get_current_user_malformed_header(mock_jwks_client):
    """An Authorization header without the ``Bearer`` scheme yields no user."""

    class MockRequest:
        headers = {"Authorization": "NotBearer token"}

    with patch("app.core.auth._get_jwks_client", return_value=mock_jwks_client):
        resolved = get_current_user(MockRequest())

    assert resolved is None
|
|
|
|
|
|
async def test_require_auth_valid_user():
    """require_auth returns the very same user object it was given."""
    authed = AuthUser(id="user-123", email="test@example.com")
    assert require_auth(authed) is authed
|
|
|
|
|
|
async def test_require_auth_no_user():
    """require_auth rejects an anonymous caller with 401 'Authentication required'."""
    from fastapi import HTTPException

    with pytest.raises(HTTPException) as excinfo:
        require_auth(None)

    error = excinfo.value
    assert error.status_code == 401
    assert error.detail == "Authentication required"
|
|
|
|
|
|
async def test_protected_endpoint_without_token(db_session):
    """POST /runs is rejected with 401 when no bearer token is sent."""
    new_run = {"game_id": 1, "name": "Test Run"}
    transport = ASGITransport(app=app)

    async with AsyncClient(transport=transport, base_url="http://test") as client:
        response = await client.post("/runs", json=new_run)

    assert response.status_code == 401
    assert response.json()["detail"] == "Authentication required"
|
|
|
|
|
|
async def test_protected_endpoint_with_expired_token(
    db_session, expired_token, mock_jwks_client
):
    """POST /runs rejects an expired bearer token with 401.

    An expired token is expected to resolve to no user, so the request should
    fail the same way as a request that sends no token at all: status 401 with
    detail "Authentication required".
    """
    new_run = {"game_id": 1, "name": "Test Run"}

    with patch("app.core.auth._get_jwks_client", return_value=mock_jwks_client):
        async with AsyncClient(
            transport=ASGITransport(app=app),
            base_url="http://test",
            headers={"Authorization": f"Bearer {expired_token}"},
        ) as client:
            response = await client.post("/runs", json=new_run)

    assert response.status_code == 401
    # Pin the detail as well, so that a 401 produced for some unrelated reason
    # cannot pass as "expired token rejected".
    assert response.json()["detail"] == "Authentication required"
|
|
|
|
|
|
async def test_read_endpoint_without_token(db_session):
    """GET /runs is publicly readable and succeeds without any token."""
    transport = ASGITransport(app=app)

    async with AsyncClient(transport=transport, base_url="http://test") as client:
        response = await client.get("/runs")

    assert response.status_code == 200
|
|
|
|
|
|
async def test_require_admin_valid_admin_user(db_session):
    """require_admin returns the caller unchanged when their stored row is an admin."""
    admin_id = "11111111-1111-1111-1111-111111111111"

    db_session.add(
        User(id=UUID(admin_id), email="admin@example.com", is_admin=True)
    )
    await db_session.commit()

    caller = AuthUser(id=admin_id, email="admin@example.com")
    outcome = await require_admin(user=caller, session=db_session)

    assert outcome is caller
|
|
|
|
|
|
async def test_require_admin_non_admin_user(db_session):
    """require_admin refuses (403) a caller whose stored row has is_admin=False."""
    from fastapi import HTTPException

    regular_id = "22222222-2222-2222-2222-222222222222"

    db_session.add(
        User(id=UUID(regular_id), email="user@example.com", is_admin=False)
    )
    await db_session.commit()

    caller = AuthUser(id=regular_id, email="user@example.com")
    with pytest.raises(HTTPException) as excinfo:
        await require_admin(user=caller, session=db_session)

    error = excinfo.value
    assert error.status_code == 403
    assert error.detail == "Admin access required"
|
|
|
|
|
|
async def test_require_admin_user_not_in_db(db_session):
    """require_admin refuses (403) an authenticated caller with no matching User row."""
    from fastapi import HTTPException

    # Nothing is inserted for this id, so the lookup finds no user.
    ghost = AuthUser(
        id="33333333-3333-3333-3333-333333333333", email="ghost@example.com"
    )

    with pytest.raises(HTTPException) as excinfo:
        await require_admin(user=ghost, session=db_session)

    error = excinfo.value
    assert error.status_code == 403
    assert error.detail == "Admin access required"
|
|
|
|
|
|
async def test_admin_endpoint_returns_403_for_non_admin(
    db_session, rsa_key_pair, mock_jwks_client
):
    """POST /games answers 403 to a signed-in user whose stored row is not an admin."""
    user_id = "44444444-4444-4444-4444-444444444444"
    email = "nonadmin@example.com"

    # Seed the caller with is_admin=False: the token authenticates, but the
    # admin check should reject the request.
    db_session.add(User(id=UUID(user_id), email=email, is_admin=False))
    await db_session.commit()

    signing_key = rsa_key_pair[0]
    claims = {
        "sub": user_id,
        "email": email,
        "role": "authenticated",
        "aud": "authenticated",
        "exp": int(time.time()) + 3600,
    }
    token = jwt.encode(claims, signing_key, algorithm="RS256")

    new_game = {
        "name": "Test Game",
        "slug": "test-game",
        "generation": 1,
        "region": "Kanto",
        "category": "core",
    }

    with patch("app.core.auth._get_jwks_client", return_value=mock_jwks_client):
        async with AsyncClient(
            transport=ASGITransport(app=app),
            base_url="http://test",
            headers={"Authorization": f"Bearer {token}"},
        ) as client:
            response = await client.post("/games", json=new_game)

    assert response.status_code == 403
    assert response.json()["detail"] == "Admin access required"
|
|
|
|
|
|
async def test_admin_endpoint_succeeds_for_admin(
    db_session, rsa_key_pair, mock_jwks_client
):
    """POST /games creates the game (201) for a signed-in admin user."""
    user_id = "55555555-5555-5555-5555-555555555555"
    email = "admin@example.com"

    # Seed the caller as an admin so both authentication and the admin
    # check pass.
    db_session.add(User(id=UUID(user_id), email=email, is_admin=True))
    await db_session.commit()

    signing_key = rsa_key_pair[0]
    claims = {
        "sub": user_id,
        "email": email,
        "role": "authenticated",
        "aud": "authenticated",
        "exp": int(time.time()) + 3600,
    }
    token = jwt.encode(claims, signing_key, algorithm="RS256")

    new_game = {
        "name": "Test Game",
        "slug": "test-game",
        "generation": 1,
        "region": "Kanto",
        "category": "core",
    }

    with patch("app.core.auth._get_jwks_client", return_value=mock_jwks_client):
        async with AsyncClient(
            transport=ASGITransport(app=app),
            base_url="http://test",
            headers={"Authorization": f"Bearer {token}"},
        ) as client:
            response = await client.post("/games", json=new_game)

    assert response.status_code == 201
    assert response.json()["name"] == "Test Game"
|