feat: add require_admin dependency and protect admin endpoints
Add require_admin FastAPI dependency that checks is_admin column on users table. Apply it to all admin-facing write endpoints (games, pokemon, evolutions, bosses, routes CRUD). Run-scoped endpoints remain protected by require_auth only since they manage user's own data. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -5,7 +5,7 @@ from sqlalchemy import or_, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import selectinload
|
||||
|
||||
from app.core.auth import AuthUser, require_auth
|
||||
from app.core.auth import AuthUser, require_admin, require_auth
|
||||
from app.core.database import get_session
|
||||
from app.models.boss_battle import BossBattle
|
||||
from app.models.boss_pokemon import BossPokemon
|
||||
@@ -86,7 +86,7 @@ async def reorder_bosses(
|
||||
game_id: int,
|
||||
data: BossReorderRequest,
|
||||
session: AsyncSession = Depends(get_session),
|
||||
_user: AuthUser = Depends(require_auth),
|
||||
_user: AuthUser = Depends(require_admin),
|
||||
):
|
||||
vg_id = await _get_version_group_id(session, game_id)
|
||||
|
||||
@@ -130,7 +130,7 @@ async def create_boss(
|
||||
game_id: int,
|
||||
data: BossBattleCreate,
|
||||
session: AsyncSession = Depends(get_session),
|
||||
_user: AuthUser = Depends(require_auth),
|
||||
_user: AuthUser = Depends(require_admin),
|
||||
):
|
||||
vg_id = await _get_version_group_id(session, game_id)
|
||||
|
||||
@@ -161,7 +161,7 @@ async def update_boss(
|
||||
boss_id: int,
|
||||
data: BossBattleUpdate,
|
||||
session: AsyncSession = Depends(get_session),
|
||||
_user: AuthUser = Depends(require_auth),
|
||||
_user: AuthUser = Depends(require_admin),
|
||||
):
|
||||
vg_id = await _get_version_group_id(session, game_id)
|
||||
|
||||
@@ -202,7 +202,7 @@ async def delete_boss(
|
||||
game_id: int,
|
||||
boss_id: int,
|
||||
session: AsyncSession = Depends(get_session),
|
||||
_user: AuthUser = Depends(require_auth),
|
||||
_user: AuthUser = Depends(require_admin),
|
||||
):
|
||||
vg_id = await _get_version_group_id(session, game_id)
|
||||
|
||||
@@ -225,7 +225,7 @@ async def bulk_import_bosses(
|
||||
game_id: int,
|
||||
items: list[BulkBossItem],
|
||||
session: AsyncSession = Depends(get_session),
|
||||
_user: AuthUser = Depends(require_auth),
|
||||
_user: AuthUser = Depends(require_admin),
|
||||
):
|
||||
vg_id = await _get_version_group_id(session, game_id)
|
||||
|
||||
@@ -268,7 +268,7 @@ async def set_boss_team(
|
||||
boss_id: int,
|
||||
team: list[BossPokemonInput],
|
||||
session: AsyncSession = Depends(get_session),
|
||||
_user: AuthUser = Depends(require_auth),
|
||||
_user: AuthUser = Depends(require_admin),
|
||||
):
|
||||
vg_id = await _get_version_group_id(session, game_id)
|
||||
|
||||
|
||||
@@ -3,6 +3,7 @@ from sqlalchemy import func, or_, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import joinedload
|
||||
|
||||
from app.core.auth import AuthUser, require_admin
|
||||
from app.core.database import get_session
|
||||
from app.models.evolution import Evolution
|
||||
from app.models.pokemon import Pokemon
|
||||
@@ -89,7 +90,9 @@ async def list_evolutions(
|
||||
|
||||
@router.post("/evolutions", response_model=EvolutionAdminResponse, status_code=201)
|
||||
async def create_evolution(
|
||||
data: EvolutionCreate, session: AsyncSession = Depends(get_session)
|
||||
data: EvolutionCreate,
|
||||
session: AsyncSession = Depends(get_session),
|
||||
_user: AuthUser = Depends(require_admin),
|
||||
):
|
||||
from_pokemon = await session.get(Pokemon, data.from_pokemon_id)
|
||||
if from_pokemon is None:
|
||||
@@ -117,6 +120,7 @@ async def update_evolution(
|
||||
evolution_id: int,
|
||||
data: EvolutionUpdate,
|
||||
session: AsyncSession = Depends(get_session),
|
||||
_user: AuthUser = Depends(require_admin),
|
||||
):
|
||||
evolution = await session.get(Evolution, evolution_id)
|
||||
if evolution is None:
|
||||
@@ -150,7 +154,9 @@ async def update_evolution(
|
||||
|
||||
@router.delete("/evolutions/{evolution_id}", status_code=204)
|
||||
async def delete_evolution(
|
||||
evolution_id: int, session: AsyncSession = Depends(get_session)
|
||||
evolution_id: int,
|
||||
session: AsyncSession = Depends(get_session),
|
||||
_user: AuthUser = Depends(require_admin),
|
||||
):
|
||||
evolution = await session.get(Evolution, evolution_id)
|
||||
if evolution is None:
|
||||
@@ -164,6 +170,7 @@ async def delete_evolution(
|
||||
async def bulk_import_evolutions(
|
||||
items: list[BulkEvolutionItem],
|
||||
session: AsyncSession = Depends(get_session),
|
||||
_user: AuthUser = Depends(require_admin),
|
||||
):
|
||||
# Build pokeapi_id -> id mapping
|
||||
result = await session.execute(select(Pokemon.pokeapi_id, Pokemon.id))
|
||||
|
||||
@@ -6,7 +6,7 @@ from sqlalchemy import delete, select, update
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import selectinload
|
||||
|
||||
from app.core.auth import AuthUser, require_auth
|
||||
from app.core.auth import AuthUser, require_admin
|
||||
from app.core.database import get_session
|
||||
from app.models.boss_battle import BossBattle
|
||||
from app.models.game import Game
|
||||
@@ -232,7 +232,7 @@ async def list_game_routes(
|
||||
async def create_game(
|
||||
data: GameCreate,
|
||||
session: AsyncSession = Depends(get_session),
|
||||
_user: AuthUser = Depends(require_auth),
|
||||
_user: AuthUser = Depends(require_admin),
|
||||
):
|
||||
existing = await session.execute(select(Game).where(Game.slug == data.slug))
|
||||
if existing.scalar_one_or_none() is not None:
|
||||
@@ -252,7 +252,7 @@ async def update_game(
|
||||
game_id: int,
|
||||
data: GameUpdate,
|
||||
session: AsyncSession = Depends(get_session),
|
||||
_user: AuthUser = Depends(require_auth),
|
||||
_user: AuthUser = Depends(require_admin),
|
||||
):
|
||||
game = await session.get(Game, game_id)
|
||||
if game is None:
|
||||
@@ -280,7 +280,7 @@ async def update_game(
|
||||
async def delete_game(
|
||||
game_id: int,
|
||||
session: AsyncSession = Depends(get_session),
|
||||
_user: AuthUser = Depends(require_auth),
|
||||
_user: AuthUser = Depends(require_admin),
|
||||
):
|
||||
result = await session.execute(
|
||||
select(Game).where(Game.id == game_id).options(selectinload(Game.runs))
|
||||
@@ -338,7 +338,7 @@ async def create_route(
|
||||
game_id: int,
|
||||
data: RouteCreate,
|
||||
session: AsyncSession = Depends(get_session),
|
||||
_user: AuthUser = Depends(require_auth),
|
||||
_user: AuthUser = Depends(require_admin),
|
||||
):
|
||||
vg_id = await _get_version_group_id(session, game_id)
|
||||
|
||||
@@ -354,7 +354,7 @@ async def reorder_routes(
|
||||
game_id: int,
|
||||
data: RouteReorderRequest,
|
||||
session: AsyncSession = Depends(get_session),
|
||||
_user: AuthUser = Depends(require_auth),
|
||||
_user: AuthUser = Depends(require_admin),
|
||||
):
|
||||
vg_id = await _get_version_group_id(session, game_id)
|
||||
|
||||
@@ -381,7 +381,7 @@ async def update_route(
|
||||
route_id: int,
|
||||
data: RouteUpdate,
|
||||
session: AsyncSession = Depends(get_session),
|
||||
_user: AuthUser = Depends(require_auth),
|
||||
_user: AuthUser = Depends(require_admin),
|
||||
):
|
||||
vg_id = await _get_version_group_id(session, game_id)
|
||||
|
||||
@@ -402,7 +402,7 @@ async def delete_route(
|
||||
game_id: int,
|
||||
route_id: int,
|
||||
session: AsyncSession = Depends(get_session),
|
||||
_user: AuthUser = Depends(require_auth),
|
||||
_user: AuthUser = Depends(require_admin),
|
||||
):
|
||||
vg_id = await _get_version_group_id(session, game_id)
|
||||
|
||||
@@ -437,7 +437,7 @@ async def bulk_import_routes(
|
||||
game_id: int,
|
||||
items: list[BulkRouteItem],
|
||||
session: AsyncSession = Depends(get_session),
|
||||
_user: AuthUser = Depends(require_auth),
|
||||
_user: AuthUser = Depends(require_admin),
|
||||
):
|
||||
vg_id = await _get_version_group_id(session, game_id)
|
||||
|
||||
|
||||
@@ -3,6 +3,7 @@ from sqlalchemy import func, or_, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import joinedload, selectinload
|
||||
|
||||
from app.core.auth import AuthUser, require_admin
|
||||
from app.core.database import get_session
|
||||
from app.models.evolution import Evolution
|
||||
from app.models.pokemon import Pokemon
|
||||
@@ -68,6 +69,7 @@ async def list_pokemon(
|
||||
async def bulk_import_pokemon(
|
||||
items: list[BulkImportItem],
|
||||
session: AsyncSession = Depends(get_session),
|
||||
_user: AuthUser = Depends(require_admin),
|
||||
):
|
||||
created = 0
|
||||
updated = 0
|
||||
@@ -100,7 +102,9 @@ async def bulk_import_pokemon(
|
||||
|
||||
@router.post("/pokemon", response_model=PokemonResponse, status_code=201)
|
||||
async def create_pokemon(
|
||||
data: PokemonCreate, session: AsyncSession = Depends(get_session)
|
||||
data: PokemonCreate,
|
||||
session: AsyncSession = Depends(get_session),
|
||||
_user: AuthUser = Depends(require_admin),
|
||||
):
|
||||
existing = await session.execute(
|
||||
select(Pokemon).where(Pokemon.pokeapi_id == data.pokeapi_id)
|
||||
@@ -321,6 +325,7 @@ async def update_pokemon(
|
||||
pokemon_id: int,
|
||||
data: PokemonUpdate,
|
||||
session: AsyncSession = Depends(get_session),
|
||||
_user: AuthUser = Depends(require_admin),
|
||||
):
|
||||
pokemon = await session.get(Pokemon, pokemon_id)
|
||||
if pokemon is None:
|
||||
@@ -349,7 +354,11 @@ async def update_pokemon(
|
||||
|
||||
|
||||
@router.delete("/pokemon/{pokemon_id}", status_code=204)
|
||||
async def delete_pokemon(pokemon_id: int, session: AsyncSession = Depends(get_session)):
|
||||
async def delete_pokemon(
|
||||
pokemon_id: int,
|
||||
session: AsyncSession = Depends(get_session),
|
||||
_user: AuthUser = Depends(require_admin),
|
||||
):
|
||||
result = await session.execute(
|
||||
select(Pokemon)
|
||||
.where(Pokemon.id == pokemon_id)
|
||||
@@ -405,6 +414,7 @@ async def add_route_encounter(
|
||||
route_id: int,
|
||||
data: RouteEncounterCreate,
|
||||
session: AsyncSession = Depends(get_session),
|
||||
_user: AuthUser = Depends(require_admin),
|
||||
):
|
||||
route = await session.get(Route, route_id)
|
||||
if route is None:
|
||||
@@ -436,6 +446,7 @@ async def update_route_encounter(
|
||||
encounter_id: int,
|
||||
data: RouteEncounterUpdate,
|
||||
session: AsyncSession = Depends(get_session),
|
||||
_user: AuthUser = Depends(require_admin),
|
||||
):
|
||||
result = await session.execute(
|
||||
select(RouteEncounter)
|
||||
@@ -466,6 +477,7 @@ async def remove_route_encounter(
|
||||
route_id: int,
|
||||
encounter_id: int,
|
||||
session: AsyncSession = Depends(get_session),
|
||||
_user: AuthUser = Depends(require_admin),
|
||||
):
|
||||
encounter = await session.execute(
|
||||
select(RouteEncounter).where(
|
||||
|
||||
@@ -1,9 +1,14 @@
|
||||
from dataclasses import dataclass
|
||||
from uuid import UUID
|
||||
|
||||
import jwt
|
||||
from fastapi import Depends, HTTPException, Request, status
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.core.config import settings
|
||||
from app.core.database import get_session
|
||||
from app.models.user import User
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -81,3 +86,22 @@ def require_auth(user: AuthUser | None = Depends(get_current_user)) -> AuthUser:
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
return user
|
||||
|
||||
|
||||
async def require_admin(
|
||||
user: AuthUser = Depends(require_auth),
|
||||
session: AsyncSession = Depends(get_session),
|
||||
) -> AuthUser:
|
||||
"""
|
||||
Dependency that requires admin privileges.
|
||||
Raises 401 if not authenticated, 403 if not an admin.
|
||||
"""
|
||||
result = await session.execute(select(User).where(User.id == UUID(user.id)))
|
||||
db_user = result.scalar_one_or_none()
|
||||
|
||||
if db_user is None or not db_user.is_admin:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Admin access required",
|
||||
)
|
||||
return user
|
||||
|
||||
@@ -1,12 +1,14 @@
|
||||
import time
|
||||
from uuid import UUID
|
||||
|
||||
import jwt
|
||||
import pytest
|
||||
from httpx import ASGITransport, AsyncClient
|
||||
|
||||
from app.core.auth import AuthUser, get_current_user, require_auth
|
||||
from app.core.auth import AuthUser, get_current_user, require_admin, require_auth
|
||||
from app.core.config import settings
|
||||
from app.main import app
|
||||
from app.models.user import User
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -177,3 +179,140 @@ async def test_read_endpoint_without_token(db_session):
|
||||
) as ac:
|
||||
response = await ac.get("/runs")
|
||||
assert response.status_code == 200
|
||||
|
||||
|
||||
async def test_require_admin_valid_admin_user(db_session):
|
||||
"""Test require_admin passes through for admin user."""
|
||||
user_id = "11111111-1111-1111-1111-111111111111"
|
||||
admin_user = User(
|
||||
id=UUID(user_id),
|
||||
email="admin@example.com",
|
||||
is_admin=True,
|
||||
)
|
||||
db_session.add(admin_user)
|
||||
await db_session.commit()
|
||||
|
||||
auth_user = AuthUser(id=user_id, email="admin@example.com")
|
||||
result = await require_admin(user=auth_user, session=db_session)
|
||||
assert result is auth_user
|
||||
|
||||
|
||||
async def test_require_admin_non_admin_user(db_session):
|
||||
"""Test require_admin raises 403 for non-admin user."""
|
||||
from fastapi import HTTPException
|
||||
|
||||
user_id = "22222222-2222-2222-2222-222222222222"
|
||||
regular_user = User(
|
||||
id=UUID(user_id),
|
||||
email="user@example.com",
|
||||
is_admin=False,
|
||||
)
|
||||
db_session.add(regular_user)
|
||||
await db_session.commit()
|
||||
|
||||
auth_user = AuthUser(id=user_id, email="user@example.com")
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await require_admin(user=auth_user, session=db_session)
|
||||
assert exc_info.value.status_code == 403
|
||||
assert exc_info.value.detail == "Admin access required"
|
||||
|
||||
|
||||
async def test_require_admin_user_not_in_db(db_session):
|
||||
"""Test require_admin raises 403 for user not in database."""
|
||||
from fastapi import HTTPException
|
||||
|
||||
auth_user = AuthUser(
|
||||
id="33333333-3333-3333-3333-333333333333", email="ghost@example.com"
|
||||
)
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await require_admin(user=auth_user, session=db_session)
|
||||
assert exc_info.value.status_code == 403
|
||||
assert exc_info.value.detail == "Admin access required"
|
||||
|
||||
|
||||
async def test_admin_endpoint_returns_403_for_non_admin(
|
||||
db_session, jwt_secret, monkeypatch
|
||||
):
|
||||
"""Test that admin endpoint returns 403 for authenticated non-admin user."""
|
||||
user_id = "44444444-4444-4444-4444-444444444444"
|
||||
regular_user = User(
|
||||
id=UUID(user_id),
|
||||
email="nonadmin@example.com",
|
||||
is_admin=False,
|
||||
)
|
||||
db_session.add(regular_user)
|
||||
await db_session.commit()
|
||||
|
||||
monkeypatch.setattr(settings, "supabase_jwt_secret", jwt_secret)
|
||||
token = jwt.encode(
|
||||
{
|
||||
"sub": user_id,
|
||||
"email": "nonadmin@example.com",
|
||||
"role": "authenticated",
|
||||
"aud": "authenticated",
|
||||
"exp": int(time.time()) + 3600,
|
||||
},
|
||||
jwt_secret,
|
||||
algorithm="HS256",
|
||||
)
|
||||
|
||||
async with AsyncClient(
|
||||
transport=ASGITransport(app=app),
|
||||
base_url="http://test",
|
||||
headers={"Authorization": f"Bearer {token}"},
|
||||
) as ac:
|
||||
response = await ac.post(
|
||||
"/games",
|
||||
json={
|
||||
"name": "Test Game",
|
||||
"slug": "test-game",
|
||||
"generation": 1,
|
||||
"region": "Kanto",
|
||||
"category": "core",
|
||||
},
|
||||
)
|
||||
assert response.status_code == 403
|
||||
assert response.json()["detail"] == "Admin access required"
|
||||
|
||||
|
||||
async def test_admin_endpoint_succeeds_for_admin(db_session, jwt_secret, monkeypatch):
|
||||
"""Test that admin endpoint succeeds for authenticated admin user."""
|
||||
user_id = "55555555-5555-5555-5555-555555555555"
|
||||
admin_user = User(
|
||||
id=UUID(user_id),
|
||||
email="admin@example.com",
|
||||
is_admin=True,
|
||||
)
|
||||
db_session.add(admin_user)
|
||||
await db_session.commit()
|
||||
|
||||
monkeypatch.setattr(settings, "supabase_jwt_secret", jwt_secret)
|
||||
token = jwt.encode(
|
||||
{
|
||||
"sub": user_id,
|
||||
"email": "admin@example.com",
|
||||
"role": "authenticated",
|
||||
"aud": "authenticated",
|
||||
"exp": int(time.time()) + 3600,
|
||||
},
|
||||
jwt_secret,
|
||||
algorithm="HS256",
|
||||
)
|
||||
|
||||
async with AsyncClient(
|
||||
transport=ASGITransport(app=app),
|
||||
base_url="http://test",
|
||||
headers={"Authorization": f"Bearer {token}"},
|
||||
) as ac:
|
||||
response = await ac.post(
|
||||
"/games",
|
||||
json={
|
||||
"name": "Test Game",
|
||||
"slug": "test-game",
|
||||
"generation": 1,
|
||||
"region": "Kanto",
|
||||
"category": "core",
|
||||
},
|
||||
)
|
||||
assert response.status_code == 201
|
||||
assert response.json()["name"] == "Test Game"
|
||||
|
||||
Reference in New Issue
Block a user