feat: add require_admin dependency and protect admin endpoints

Add require_admin FastAPI dependency that checks is_admin column on users
table. Apply it to all admin-facing write endpoints (games, pokemon,
evolutions, bosses, routes CRUD). Run-scoped endpoints remain protected
by require_auth only since they manage user's own data.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-03-21 11:14:55 +01:00
parent 1042fff2b8
commit 2e66186fac
7 changed files with 226 additions and 27 deletions

View File

@@ -1,11 +1,11 @@
--- ---
# nuzlocke-tracker-f4d0 # nuzlocke-tracker-f4d0
title: Add require_admin dependency and protect admin endpoints title: Add require_admin dependency and protect admin endpoints
status: todo status: in-progress
type: task type: task
priority: normal priority: normal
created_at: 2026-03-21T10:06:19Z created_at: 2026-03-21T10:06:19Z
updated_at: 2026-03-21T10:06:24Z updated_at: 2026-03-21T10:14:36Z
parent: nuzlocke-tracker-ce4o parent: nuzlocke-tracker-ce4o
blocked_by: blocked_by:
- nuzlocke-tracker-dwah - nuzlocke-tracker-dwah
@@ -15,13 +15,13 @@ Add a `require_admin` FastAPI dependency that checks the `is_admin` column on th
## Checklist ## Checklist
- [ ] Add `require_admin` dependency in `backend/src/app/core/auth.py` that: - [x] Add `require_admin` dependency in `backend/src/app/core/auth.py` that:
- Requires authentication (reuses `require_auth`) - Requires authentication (reuses `require_auth`)
- Looks up the user in the `users` table by `AuthUser.id` - Looks up the user in the `users` table by `AuthUser.id`
- Returns 403 if `is_admin` is not `True` - Returns 403 if `is_admin` is not `True`
- [ ] Apply `require_admin` to write endpoints in: `games.py`, `pokemon.py`, `evolutions.py`, `bosses.py` (all POST/PUT/PATCH/DELETE) - [x] Apply `require_admin` to write endpoints in: `games.py`, `pokemon.py`, `evolutions.py`, `bosses.py` (all POST/PUT/PATCH/DELETE)
- [ ] Keep read endpoints (GET) accessible to all authenticated users - [x] Keep read endpoints (GET) accessible to all authenticated users
- [ ] Add tests for 403 response when non-admin user hits admin endpoints - [x] Add tests for 403 response when non-admin user hits admin endpoints
## Files to change ## Files to change
@@ -30,3 +30,20 @@ Add a `require_admin` FastAPI dependency that checks the `is_admin` column on th
- `backend/src/app/api/pokemon.py` — same - `backend/src/app/api/pokemon.py` — same
- `backend/src/app/api/evolutions.py` — same - `backend/src/app/api/evolutions.py` — same
- `backend/src/app/api/bosses.py` — same - `backend/src/app/api/bosses.py` — same
## Summary of Changes
Added `require_admin` FastAPI dependency to `backend/src/app/core/auth.py`:
- Depends on `require_auth` (returns 401 if not authenticated)
- Looks up user in `users` table by UUID
- Returns 403 if user not found or `is_admin` is not True
Applied `require_admin` to all admin-facing write endpoints:
- `games.py`: POST/PUT/DELETE for games and routes
- `pokemon.py`: POST/PUT/DELETE for pokemon and route encounters
- `evolutions.py`: POST/PUT/DELETE for evolutions
- `bosses.py`: POST/PUT/DELETE for game-scoped boss operations (run-scoped endpoints kept with `require_auth`)
Added tests in `test_auth.py`:
- Unit tests for `require_admin` (admin user, non-admin user, user not in DB)
- Integration tests for admin endpoint access (403 for non-admin, 201 for admin)

View File

@@ -5,7 +5,7 @@ from sqlalchemy import or_, select
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload from sqlalchemy.orm import selectinload
from app.core.auth import AuthUser, require_auth from app.core.auth import AuthUser, require_admin, require_auth
from app.core.database import get_session from app.core.database import get_session
from app.models.boss_battle import BossBattle from app.models.boss_battle import BossBattle
from app.models.boss_pokemon import BossPokemon from app.models.boss_pokemon import BossPokemon
@@ -86,7 +86,7 @@ async def reorder_bosses(
game_id: int, game_id: int,
data: BossReorderRequest, data: BossReorderRequest,
session: AsyncSession = Depends(get_session), session: AsyncSession = Depends(get_session),
_user: AuthUser = Depends(require_auth), _user: AuthUser = Depends(require_admin),
): ):
vg_id = await _get_version_group_id(session, game_id) vg_id = await _get_version_group_id(session, game_id)
@@ -130,7 +130,7 @@ async def create_boss(
game_id: int, game_id: int,
data: BossBattleCreate, data: BossBattleCreate,
session: AsyncSession = Depends(get_session), session: AsyncSession = Depends(get_session),
_user: AuthUser = Depends(require_auth), _user: AuthUser = Depends(require_admin),
): ):
vg_id = await _get_version_group_id(session, game_id) vg_id = await _get_version_group_id(session, game_id)
@@ -161,7 +161,7 @@ async def update_boss(
boss_id: int, boss_id: int,
data: BossBattleUpdate, data: BossBattleUpdate,
session: AsyncSession = Depends(get_session), session: AsyncSession = Depends(get_session),
_user: AuthUser = Depends(require_auth), _user: AuthUser = Depends(require_admin),
): ):
vg_id = await _get_version_group_id(session, game_id) vg_id = await _get_version_group_id(session, game_id)
@@ -202,7 +202,7 @@ async def delete_boss(
game_id: int, game_id: int,
boss_id: int, boss_id: int,
session: AsyncSession = Depends(get_session), session: AsyncSession = Depends(get_session),
_user: AuthUser = Depends(require_auth), _user: AuthUser = Depends(require_admin),
): ):
vg_id = await _get_version_group_id(session, game_id) vg_id = await _get_version_group_id(session, game_id)
@@ -225,7 +225,7 @@ async def bulk_import_bosses(
game_id: int, game_id: int,
items: list[BulkBossItem], items: list[BulkBossItem],
session: AsyncSession = Depends(get_session), session: AsyncSession = Depends(get_session),
_user: AuthUser = Depends(require_auth), _user: AuthUser = Depends(require_admin),
): ):
vg_id = await _get_version_group_id(session, game_id) vg_id = await _get_version_group_id(session, game_id)
@@ -268,7 +268,7 @@ async def set_boss_team(
boss_id: int, boss_id: int,
team: list[BossPokemonInput], team: list[BossPokemonInput],
session: AsyncSession = Depends(get_session), session: AsyncSession = Depends(get_session),
_user: AuthUser = Depends(require_auth), _user: AuthUser = Depends(require_admin),
): ):
vg_id = await _get_version_group_id(session, game_id) vg_id = await _get_version_group_id(session, game_id)

View File

@@ -3,6 +3,7 @@ from sqlalchemy import func, or_, select
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import joinedload from sqlalchemy.orm import joinedload
from app.core.auth import AuthUser, require_admin
from app.core.database import get_session from app.core.database import get_session
from app.models.evolution import Evolution from app.models.evolution import Evolution
from app.models.pokemon import Pokemon from app.models.pokemon import Pokemon
@@ -89,7 +90,9 @@ async def list_evolutions(
@router.post("/evolutions", response_model=EvolutionAdminResponse, status_code=201) @router.post("/evolutions", response_model=EvolutionAdminResponse, status_code=201)
async def create_evolution( async def create_evolution(
data: EvolutionCreate, session: AsyncSession = Depends(get_session) data: EvolutionCreate,
session: AsyncSession = Depends(get_session),
_user: AuthUser = Depends(require_admin),
): ):
from_pokemon = await session.get(Pokemon, data.from_pokemon_id) from_pokemon = await session.get(Pokemon, data.from_pokemon_id)
if from_pokemon is None: if from_pokemon is None:
@@ -117,6 +120,7 @@ async def update_evolution(
evolution_id: int, evolution_id: int,
data: EvolutionUpdate, data: EvolutionUpdate,
session: AsyncSession = Depends(get_session), session: AsyncSession = Depends(get_session),
_user: AuthUser = Depends(require_admin),
): ):
evolution = await session.get(Evolution, evolution_id) evolution = await session.get(Evolution, evolution_id)
if evolution is None: if evolution is None:
@@ -150,7 +154,9 @@ async def update_evolution(
@router.delete("/evolutions/{evolution_id}", status_code=204) @router.delete("/evolutions/{evolution_id}", status_code=204)
async def delete_evolution( async def delete_evolution(
evolution_id: int, session: AsyncSession = Depends(get_session) evolution_id: int,
session: AsyncSession = Depends(get_session),
_user: AuthUser = Depends(require_admin),
): ):
evolution = await session.get(Evolution, evolution_id) evolution = await session.get(Evolution, evolution_id)
if evolution is None: if evolution is None:
@@ -164,6 +170,7 @@ async def delete_evolution(
async def bulk_import_evolutions( async def bulk_import_evolutions(
items: list[BulkEvolutionItem], items: list[BulkEvolutionItem],
session: AsyncSession = Depends(get_session), session: AsyncSession = Depends(get_session),
_user: AuthUser = Depends(require_admin),
): ):
# Build pokeapi_id -> id mapping # Build pokeapi_id -> id mapping
result = await session.execute(select(Pokemon.pokeapi_id, Pokemon.id)) result = await session.execute(select(Pokemon.pokeapi_id, Pokemon.id))

View File

@@ -6,7 +6,7 @@ from sqlalchemy import delete, select, update
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload from sqlalchemy.orm import selectinload
from app.core.auth import AuthUser, require_auth from app.core.auth import AuthUser, require_admin
from app.core.database import get_session from app.core.database import get_session
from app.models.boss_battle import BossBattle from app.models.boss_battle import BossBattle
from app.models.game import Game from app.models.game import Game
@@ -232,7 +232,7 @@ async def list_game_routes(
async def create_game( async def create_game(
data: GameCreate, data: GameCreate,
session: AsyncSession = Depends(get_session), session: AsyncSession = Depends(get_session),
_user: AuthUser = Depends(require_auth), _user: AuthUser = Depends(require_admin),
): ):
existing = await session.execute(select(Game).where(Game.slug == data.slug)) existing = await session.execute(select(Game).where(Game.slug == data.slug))
if existing.scalar_one_or_none() is not None: if existing.scalar_one_or_none() is not None:
@@ -252,7 +252,7 @@ async def update_game(
game_id: int, game_id: int,
data: GameUpdate, data: GameUpdate,
session: AsyncSession = Depends(get_session), session: AsyncSession = Depends(get_session),
_user: AuthUser = Depends(require_auth), _user: AuthUser = Depends(require_admin),
): ):
game = await session.get(Game, game_id) game = await session.get(Game, game_id)
if game is None: if game is None:
@@ -280,7 +280,7 @@ async def update_game(
async def delete_game( async def delete_game(
game_id: int, game_id: int,
session: AsyncSession = Depends(get_session), session: AsyncSession = Depends(get_session),
_user: AuthUser = Depends(require_auth), _user: AuthUser = Depends(require_admin),
): ):
result = await session.execute( result = await session.execute(
select(Game).where(Game.id == game_id).options(selectinload(Game.runs)) select(Game).where(Game.id == game_id).options(selectinload(Game.runs))
@@ -338,7 +338,7 @@ async def create_route(
game_id: int, game_id: int,
data: RouteCreate, data: RouteCreate,
session: AsyncSession = Depends(get_session), session: AsyncSession = Depends(get_session),
_user: AuthUser = Depends(require_auth), _user: AuthUser = Depends(require_admin),
): ):
vg_id = await _get_version_group_id(session, game_id) vg_id = await _get_version_group_id(session, game_id)
@@ -354,7 +354,7 @@ async def reorder_routes(
game_id: int, game_id: int,
data: RouteReorderRequest, data: RouteReorderRequest,
session: AsyncSession = Depends(get_session), session: AsyncSession = Depends(get_session),
_user: AuthUser = Depends(require_auth), _user: AuthUser = Depends(require_admin),
): ):
vg_id = await _get_version_group_id(session, game_id) vg_id = await _get_version_group_id(session, game_id)
@@ -381,7 +381,7 @@ async def update_route(
route_id: int, route_id: int,
data: RouteUpdate, data: RouteUpdate,
session: AsyncSession = Depends(get_session), session: AsyncSession = Depends(get_session),
_user: AuthUser = Depends(require_auth), _user: AuthUser = Depends(require_admin),
): ):
vg_id = await _get_version_group_id(session, game_id) vg_id = await _get_version_group_id(session, game_id)
@@ -402,7 +402,7 @@ async def delete_route(
game_id: int, game_id: int,
route_id: int, route_id: int,
session: AsyncSession = Depends(get_session), session: AsyncSession = Depends(get_session),
_user: AuthUser = Depends(require_auth), _user: AuthUser = Depends(require_admin),
): ):
vg_id = await _get_version_group_id(session, game_id) vg_id = await _get_version_group_id(session, game_id)
@@ -437,7 +437,7 @@ async def bulk_import_routes(
game_id: int, game_id: int,
items: list[BulkRouteItem], items: list[BulkRouteItem],
session: AsyncSession = Depends(get_session), session: AsyncSession = Depends(get_session),
_user: AuthUser = Depends(require_auth), _user: AuthUser = Depends(require_admin),
): ):
vg_id = await _get_version_group_id(session, game_id) vg_id = await _get_version_group_id(session, game_id)

View File

@@ -3,6 +3,7 @@ from sqlalchemy import func, or_, select
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import joinedload, selectinload from sqlalchemy.orm import joinedload, selectinload
from app.core.auth import AuthUser, require_admin
from app.core.database import get_session from app.core.database import get_session
from app.models.evolution import Evolution from app.models.evolution import Evolution
from app.models.pokemon import Pokemon from app.models.pokemon import Pokemon
@@ -68,6 +69,7 @@ async def list_pokemon(
async def bulk_import_pokemon( async def bulk_import_pokemon(
items: list[BulkImportItem], items: list[BulkImportItem],
session: AsyncSession = Depends(get_session), session: AsyncSession = Depends(get_session),
_user: AuthUser = Depends(require_admin),
): ):
created = 0 created = 0
updated = 0 updated = 0
@@ -100,7 +102,9 @@ async def bulk_import_pokemon(
@router.post("/pokemon", response_model=PokemonResponse, status_code=201) @router.post("/pokemon", response_model=PokemonResponse, status_code=201)
async def create_pokemon( async def create_pokemon(
data: PokemonCreate, session: AsyncSession = Depends(get_session) data: PokemonCreate,
session: AsyncSession = Depends(get_session),
_user: AuthUser = Depends(require_admin),
): ):
existing = await session.execute( existing = await session.execute(
select(Pokemon).where(Pokemon.pokeapi_id == data.pokeapi_id) select(Pokemon).where(Pokemon.pokeapi_id == data.pokeapi_id)
@@ -321,6 +325,7 @@ async def update_pokemon(
pokemon_id: int, pokemon_id: int,
data: PokemonUpdate, data: PokemonUpdate,
session: AsyncSession = Depends(get_session), session: AsyncSession = Depends(get_session),
_user: AuthUser = Depends(require_admin),
): ):
pokemon = await session.get(Pokemon, pokemon_id) pokemon = await session.get(Pokemon, pokemon_id)
if pokemon is None: if pokemon is None:
@@ -349,7 +354,11 @@ async def update_pokemon(
@router.delete("/pokemon/{pokemon_id}", status_code=204) @router.delete("/pokemon/{pokemon_id}", status_code=204)
async def delete_pokemon(pokemon_id: int, session: AsyncSession = Depends(get_session)): async def delete_pokemon(
pokemon_id: int,
session: AsyncSession = Depends(get_session),
_user: AuthUser = Depends(require_admin),
):
result = await session.execute( result = await session.execute(
select(Pokemon) select(Pokemon)
.where(Pokemon.id == pokemon_id) .where(Pokemon.id == pokemon_id)
@@ -405,6 +414,7 @@ async def add_route_encounter(
route_id: int, route_id: int,
data: RouteEncounterCreate, data: RouteEncounterCreate,
session: AsyncSession = Depends(get_session), session: AsyncSession = Depends(get_session),
_user: AuthUser = Depends(require_admin),
): ):
route = await session.get(Route, route_id) route = await session.get(Route, route_id)
if route is None: if route is None:
@@ -436,6 +446,7 @@ async def update_route_encounter(
encounter_id: int, encounter_id: int,
data: RouteEncounterUpdate, data: RouteEncounterUpdate,
session: AsyncSession = Depends(get_session), session: AsyncSession = Depends(get_session),
_user: AuthUser = Depends(require_admin),
): ):
result = await session.execute( result = await session.execute(
select(RouteEncounter) select(RouteEncounter)
@@ -466,6 +477,7 @@ async def remove_route_encounter(
route_id: int, route_id: int,
encounter_id: int, encounter_id: int,
session: AsyncSession = Depends(get_session), session: AsyncSession = Depends(get_session),
_user: AuthUser = Depends(require_admin),
): ):
encounter = await session.execute( encounter = await session.execute(
select(RouteEncounter).where( select(RouteEncounter).where(

View File

@@ -1,9 +1,14 @@
from dataclasses import dataclass from dataclasses import dataclass
from uuid import UUID
import jwt import jwt
from fastapi import Depends, HTTPException, Request, status from fastapi import Depends, HTTPException, Request, status
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.config import settings from app.core.config import settings
from app.core.database import get_session
from app.models.user import User
@dataclass @dataclass
@@ -81,3 +86,22 @@ def require_auth(user: AuthUser | None = Depends(get_current_user)) -> AuthUser:
headers={"WWW-Authenticate": "Bearer"}, headers={"WWW-Authenticate": "Bearer"},
) )
return user return user
async def require_admin(
user: AuthUser = Depends(require_auth),
session: AsyncSession = Depends(get_session),
) -> AuthUser:
"""
Dependency that requires admin privileges.
Raises 401 if not authenticated, 403 if not an admin.
"""
result = await session.execute(select(User).where(User.id == UUID(user.id)))
db_user = result.scalar_one_or_none()
if db_user is None or not db_user.is_admin:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Admin access required",
)
return user

View File

@@ -1,12 +1,14 @@
import time import time
from uuid import UUID
import jwt import jwt
import pytest import pytest
from httpx import ASGITransport, AsyncClient from httpx import ASGITransport, AsyncClient
from app.core.auth import AuthUser, get_current_user, require_auth from app.core.auth import AuthUser, get_current_user, require_admin, require_auth
from app.core.config import settings from app.core.config import settings
from app.main import app from app.main import app
from app.models.user import User
@pytest.fixture @pytest.fixture
@@ -177,3 +179,140 @@ async def test_read_endpoint_without_token(db_session):
) as ac: ) as ac:
response = await ac.get("/runs") response = await ac.get("/runs")
assert response.status_code == 200 assert response.status_code == 200
async def test_require_admin_valid_admin_user(db_session):
"""Test require_admin passes through for admin user."""
user_id = "11111111-1111-1111-1111-111111111111"
admin_user = User(
id=UUID(user_id),
email="admin@example.com",
is_admin=True,
)
db_session.add(admin_user)
await db_session.commit()
auth_user = AuthUser(id=user_id, email="admin@example.com")
result = await require_admin(user=auth_user, session=db_session)
assert result is auth_user
async def test_require_admin_non_admin_user(db_session):
"""Test require_admin raises 403 for non-admin user."""
from fastapi import HTTPException
user_id = "22222222-2222-2222-2222-222222222222"
regular_user = User(
id=UUID(user_id),
email="user@example.com",
is_admin=False,
)
db_session.add(regular_user)
await db_session.commit()
auth_user = AuthUser(id=user_id, email="user@example.com")
with pytest.raises(HTTPException) as exc_info:
await require_admin(user=auth_user, session=db_session)
assert exc_info.value.status_code == 403
assert exc_info.value.detail == "Admin access required"
async def test_require_admin_user_not_in_db(db_session):
"""Test require_admin raises 403 for user not in database."""
from fastapi import HTTPException
auth_user = AuthUser(
id="33333333-3333-3333-3333-333333333333", email="ghost@example.com"
)
with pytest.raises(HTTPException) as exc_info:
await require_admin(user=auth_user, session=db_session)
assert exc_info.value.status_code == 403
assert exc_info.value.detail == "Admin access required"
async def test_admin_endpoint_returns_403_for_non_admin(
db_session, jwt_secret, monkeypatch
):
"""Test that admin endpoint returns 403 for authenticated non-admin user."""
user_id = "44444444-4444-4444-4444-444444444444"
regular_user = User(
id=UUID(user_id),
email="nonadmin@example.com",
is_admin=False,
)
db_session.add(regular_user)
await db_session.commit()
monkeypatch.setattr(settings, "supabase_jwt_secret", jwt_secret)
token = jwt.encode(
{
"sub": user_id,
"email": "nonadmin@example.com",
"role": "authenticated",
"aud": "authenticated",
"exp": int(time.time()) + 3600,
},
jwt_secret,
algorithm="HS256",
)
async with AsyncClient(
transport=ASGITransport(app=app),
base_url="http://test",
headers={"Authorization": f"Bearer {token}"},
) as ac:
response = await ac.post(
"/games",
json={
"name": "Test Game",
"slug": "test-game",
"generation": 1,
"region": "Kanto",
"category": "core",
},
)
assert response.status_code == 403
assert response.json()["detail"] == "Admin access required"
async def test_admin_endpoint_succeeds_for_admin(db_session, jwt_secret, monkeypatch):
"""Test that admin endpoint succeeds for authenticated admin user."""
user_id = "55555555-5555-5555-5555-555555555555"
admin_user = User(
id=UUID(user_id),
email="admin@example.com",
is_admin=True,
)
db_session.add(admin_user)
await db_session.commit()
monkeypatch.setattr(settings, "supabase_jwt_secret", jwt_secret)
token = jwt.encode(
{
"sub": user_id,
"email": "admin@example.com",
"role": "authenticated",
"aud": "authenticated",
"exp": int(time.time()) + 3600,
},
jwt_secret,
algorithm="HS256",
)
async with AsyncClient(
transport=ASGITransport(app=app),
base_url="http://test",
headers={"Authorization": f"Bearer {token}"},
) as ac:
response = await ac.post(
"/games",
json={
"name": "Test Game",
"slug": "test-game",
"generation": 1,
"region": "Kanto",
"category": "core",
},
)
assert response.status_code == 201
assert response.json()["name"] == "Test Game"