feat: add require_admin dependency and protect admin endpoints
Add require_admin FastAPI dependency that checks is_admin column on users table. Apply it to all admin-facing write endpoints (games, pokemon, evolutions, bosses, routes CRUD). Run-scoped endpoints remain protected by require_auth only since they manage user's own data. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -1,11 +1,11 @@
|
|||||||
---
|
---
|
||||||
# nuzlocke-tracker-f4d0
|
# nuzlocke-tracker-f4d0
|
||||||
title: Add require_admin dependency and protect admin endpoints
|
title: Add require_admin dependency and protect admin endpoints
|
||||||
status: todo
|
status: in-progress
|
||||||
type: task
|
type: task
|
||||||
priority: normal
|
priority: normal
|
||||||
created_at: 2026-03-21T10:06:19Z
|
created_at: 2026-03-21T10:06:19Z
|
||||||
updated_at: 2026-03-21T10:06:24Z
|
updated_at: 2026-03-21T10:14:36Z
|
||||||
parent: nuzlocke-tracker-ce4o
|
parent: nuzlocke-tracker-ce4o
|
||||||
blocked_by:
|
blocked_by:
|
||||||
- nuzlocke-tracker-dwah
|
- nuzlocke-tracker-dwah
|
||||||
@@ -15,13 +15,13 @@ Add a `require_admin` FastAPI dependency that checks the `is_admin` column on th
|
|||||||
|
|
||||||
## Checklist
|
## Checklist
|
||||||
|
|
||||||
- [ ] Add `require_admin` dependency in `backend/src/app/core/auth.py` that:
|
- [x] Add `require_admin` dependency in `backend/src/app/core/auth.py` that:
|
||||||
- Requires authentication (reuses `require_auth`)
|
- Requires authentication (reuses `require_auth`)
|
||||||
- Looks up the user in the `users` table by `AuthUser.id`
|
- Looks up the user in the `users` table by `AuthUser.id`
|
||||||
- Returns 403 if `is_admin` is not `True`
|
- Returns 403 if `is_admin` is not `True`
|
||||||
- [ ] Apply `require_admin` to write endpoints in: `games.py`, `pokemon.py`, `evolutions.py`, `bosses.py` (all POST/PUT/PATCH/DELETE)
|
- [x] Apply `require_admin` to write endpoints in: `games.py`, `pokemon.py`, `evolutions.py`, `bosses.py` (all POST/PUT/PATCH/DELETE)
|
||||||
- [ ] Keep read endpoints (GET) accessible to all authenticated users
|
- [x] Keep read endpoints (GET) accessible to all authenticated users
|
||||||
- [ ] Add tests for 403 response when non-admin user hits admin endpoints
|
- [x] Add tests for 403 response when non-admin user hits admin endpoints
|
||||||
|
|
||||||
## Files to change
|
## Files to change
|
||||||
|
|
||||||
@@ -30,3 +30,20 @@ Add a `require_admin` FastAPI dependency that checks the `is_admin` column on th
|
|||||||
- `backend/src/app/api/pokemon.py` — same
|
- `backend/src/app/api/pokemon.py` — same
|
||||||
- `backend/src/app/api/evolutions.py` — same
|
- `backend/src/app/api/evolutions.py` — same
|
||||||
- `backend/src/app/api/bosses.py` — same
|
- `backend/src/app/api/bosses.py` — same
|
||||||
|
|
||||||
|
## Summary of Changes
|
||||||
|
|
||||||
|
Added `require_admin` FastAPI dependency to `backend/src/app/core/auth.py`:
|
||||||
|
- Depends on `require_auth` (returns 401 if not authenticated)
|
||||||
|
- Looks up user in `users` table by UUID
|
||||||
|
- Returns 403 if user not found or `is_admin` is not True
|
||||||
|
|
||||||
|
Applied `require_admin` to all admin-facing write endpoints:
|
||||||
|
- `games.py`: POST/PUT/DELETE for games and routes
|
||||||
|
- `pokemon.py`: POST/PUT/DELETE for pokemon and route encounters
|
||||||
|
- `evolutions.py`: POST/PUT/DELETE for evolutions
|
||||||
|
- `bosses.py`: POST/PUT/DELETE for game-scoped boss operations (run-scoped endpoints kept with `require_auth`)
|
||||||
|
|
||||||
|
Added tests in `test_auth.py`:
|
||||||
|
- Unit tests for `require_admin` (admin user, non-admin user, user not in DB)
|
||||||
|
- Integration tests for admin endpoint access (403 for non-admin, 201 for admin)
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ from sqlalchemy import or_, select
|
|||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
from sqlalchemy.orm import selectinload
|
from sqlalchemy.orm import selectinload
|
||||||
|
|
||||||
from app.core.auth import AuthUser, require_auth
|
from app.core.auth import AuthUser, require_admin, require_auth
|
||||||
from app.core.database import get_session
|
from app.core.database import get_session
|
||||||
from app.models.boss_battle import BossBattle
|
from app.models.boss_battle import BossBattle
|
||||||
from app.models.boss_pokemon import BossPokemon
|
from app.models.boss_pokemon import BossPokemon
|
||||||
@@ -86,7 +86,7 @@ async def reorder_bosses(
|
|||||||
game_id: int,
|
game_id: int,
|
||||||
data: BossReorderRequest,
|
data: BossReorderRequest,
|
||||||
session: AsyncSession = Depends(get_session),
|
session: AsyncSession = Depends(get_session),
|
||||||
_user: AuthUser = Depends(require_auth),
|
_user: AuthUser = Depends(require_admin),
|
||||||
):
|
):
|
||||||
vg_id = await _get_version_group_id(session, game_id)
|
vg_id = await _get_version_group_id(session, game_id)
|
||||||
|
|
||||||
@@ -130,7 +130,7 @@ async def create_boss(
|
|||||||
game_id: int,
|
game_id: int,
|
||||||
data: BossBattleCreate,
|
data: BossBattleCreate,
|
||||||
session: AsyncSession = Depends(get_session),
|
session: AsyncSession = Depends(get_session),
|
||||||
_user: AuthUser = Depends(require_auth),
|
_user: AuthUser = Depends(require_admin),
|
||||||
):
|
):
|
||||||
vg_id = await _get_version_group_id(session, game_id)
|
vg_id = await _get_version_group_id(session, game_id)
|
||||||
|
|
||||||
@@ -161,7 +161,7 @@ async def update_boss(
|
|||||||
boss_id: int,
|
boss_id: int,
|
||||||
data: BossBattleUpdate,
|
data: BossBattleUpdate,
|
||||||
session: AsyncSession = Depends(get_session),
|
session: AsyncSession = Depends(get_session),
|
||||||
_user: AuthUser = Depends(require_auth),
|
_user: AuthUser = Depends(require_admin),
|
||||||
):
|
):
|
||||||
vg_id = await _get_version_group_id(session, game_id)
|
vg_id = await _get_version_group_id(session, game_id)
|
||||||
|
|
||||||
@@ -202,7 +202,7 @@ async def delete_boss(
|
|||||||
game_id: int,
|
game_id: int,
|
||||||
boss_id: int,
|
boss_id: int,
|
||||||
session: AsyncSession = Depends(get_session),
|
session: AsyncSession = Depends(get_session),
|
||||||
_user: AuthUser = Depends(require_auth),
|
_user: AuthUser = Depends(require_admin),
|
||||||
):
|
):
|
||||||
vg_id = await _get_version_group_id(session, game_id)
|
vg_id = await _get_version_group_id(session, game_id)
|
||||||
|
|
||||||
@@ -225,7 +225,7 @@ async def bulk_import_bosses(
|
|||||||
game_id: int,
|
game_id: int,
|
||||||
items: list[BulkBossItem],
|
items: list[BulkBossItem],
|
||||||
session: AsyncSession = Depends(get_session),
|
session: AsyncSession = Depends(get_session),
|
||||||
_user: AuthUser = Depends(require_auth),
|
_user: AuthUser = Depends(require_admin),
|
||||||
):
|
):
|
||||||
vg_id = await _get_version_group_id(session, game_id)
|
vg_id = await _get_version_group_id(session, game_id)
|
||||||
|
|
||||||
@@ -268,7 +268,7 @@ async def set_boss_team(
|
|||||||
boss_id: int,
|
boss_id: int,
|
||||||
team: list[BossPokemonInput],
|
team: list[BossPokemonInput],
|
||||||
session: AsyncSession = Depends(get_session),
|
session: AsyncSession = Depends(get_session),
|
||||||
_user: AuthUser = Depends(require_auth),
|
_user: AuthUser = Depends(require_admin),
|
||||||
):
|
):
|
||||||
vg_id = await _get_version_group_id(session, game_id)
|
vg_id = await _get_version_group_id(session, game_id)
|
||||||
|
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ from sqlalchemy import func, or_, select
|
|||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
from sqlalchemy.orm import joinedload
|
from sqlalchemy.orm import joinedload
|
||||||
|
|
||||||
|
from app.core.auth import AuthUser, require_admin
|
||||||
from app.core.database import get_session
|
from app.core.database import get_session
|
||||||
from app.models.evolution import Evolution
|
from app.models.evolution import Evolution
|
||||||
from app.models.pokemon import Pokemon
|
from app.models.pokemon import Pokemon
|
||||||
@@ -89,7 +90,9 @@ async def list_evolutions(
|
|||||||
|
|
||||||
@router.post("/evolutions", response_model=EvolutionAdminResponse, status_code=201)
|
@router.post("/evolutions", response_model=EvolutionAdminResponse, status_code=201)
|
||||||
async def create_evolution(
|
async def create_evolution(
|
||||||
data: EvolutionCreate, session: AsyncSession = Depends(get_session)
|
data: EvolutionCreate,
|
||||||
|
session: AsyncSession = Depends(get_session),
|
||||||
|
_user: AuthUser = Depends(require_admin),
|
||||||
):
|
):
|
||||||
from_pokemon = await session.get(Pokemon, data.from_pokemon_id)
|
from_pokemon = await session.get(Pokemon, data.from_pokemon_id)
|
||||||
if from_pokemon is None:
|
if from_pokemon is None:
|
||||||
@@ -117,6 +120,7 @@ async def update_evolution(
|
|||||||
evolution_id: int,
|
evolution_id: int,
|
||||||
data: EvolutionUpdate,
|
data: EvolutionUpdate,
|
||||||
session: AsyncSession = Depends(get_session),
|
session: AsyncSession = Depends(get_session),
|
||||||
|
_user: AuthUser = Depends(require_admin),
|
||||||
):
|
):
|
||||||
evolution = await session.get(Evolution, evolution_id)
|
evolution = await session.get(Evolution, evolution_id)
|
||||||
if evolution is None:
|
if evolution is None:
|
||||||
@@ -150,7 +154,9 @@ async def update_evolution(
|
|||||||
|
|
||||||
@router.delete("/evolutions/{evolution_id}", status_code=204)
|
@router.delete("/evolutions/{evolution_id}", status_code=204)
|
||||||
async def delete_evolution(
|
async def delete_evolution(
|
||||||
evolution_id: int, session: AsyncSession = Depends(get_session)
|
evolution_id: int,
|
||||||
|
session: AsyncSession = Depends(get_session),
|
||||||
|
_user: AuthUser = Depends(require_admin),
|
||||||
):
|
):
|
||||||
evolution = await session.get(Evolution, evolution_id)
|
evolution = await session.get(Evolution, evolution_id)
|
||||||
if evolution is None:
|
if evolution is None:
|
||||||
@@ -164,6 +170,7 @@ async def delete_evolution(
|
|||||||
async def bulk_import_evolutions(
|
async def bulk_import_evolutions(
|
||||||
items: list[BulkEvolutionItem],
|
items: list[BulkEvolutionItem],
|
||||||
session: AsyncSession = Depends(get_session),
|
session: AsyncSession = Depends(get_session),
|
||||||
|
_user: AuthUser = Depends(require_admin),
|
||||||
):
|
):
|
||||||
# Build pokeapi_id -> id mapping
|
# Build pokeapi_id -> id mapping
|
||||||
result = await session.execute(select(Pokemon.pokeapi_id, Pokemon.id))
|
result = await session.execute(select(Pokemon.pokeapi_id, Pokemon.id))
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ from sqlalchemy import delete, select, update
|
|||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
from sqlalchemy.orm import selectinload
|
from sqlalchemy.orm import selectinload
|
||||||
|
|
||||||
from app.core.auth import AuthUser, require_auth
|
from app.core.auth import AuthUser, require_admin
|
||||||
from app.core.database import get_session
|
from app.core.database import get_session
|
||||||
from app.models.boss_battle import BossBattle
|
from app.models.boss_battle import BossBattle
|
||||||
from app.models.game import Game
|
from app.models.game import Game
|
||||||
@@ -232,7 +232,7 @@ async def list_game_routes(
|
|||||||
async def create_game(
|
async def create_game(
|
||||||
data: GameCreate,
|
data: GameCreate,
|
||||||
session: AsyncSession = Depends(get_session),
|
session: AsyncSession = Depends(get_session),
|
||||||
_user: AuthUser = Depends(require_auth),
|
_user: AuthUser = Depends(require_admin),
|
||||||
):
|
):
|
||||||
existing = await session.execute(select(Game).where(Game.slug == data.slug))
|
existing = await session.execute(select(Game).where(Game.slug == data.slug))
|
||||||
if existing.scalar_one_or_none() is not None:
|
if existing.scalar_one_or_none() is not None:
|
||||||
@@ -252,7 +252,7 @@ async def update_game(
|
|||||||
game_id: int,
|
game_id: int,
|
||||||
data: GameUpdate,
|
data: GameUpdate,
|
||||||
session: AsyncSession = Depends(get_session),
|
session: AsyncSession = Depends(get_session),
|
||||||
_user: AuthUser = Depends(require_auth),
|
_user: AuthUser = Depends(require_admin),
|
||||||
):
|
):
|
||||||
game = await session.get(Game, game_id)
|
game = await session.get(Game, game_id)
|
||||||
if game is None:
|
if game is None:
|
||||||
@@ -280,7 +280,7 @@ async def update_game(
|
|||||||
async def delete_game(
|
async def delete_game(
|
||||||
game_id: int,
|
game_id: int,
|
||||||
session: AsyncSession = Depends(get_session),
|
session: AsyncSession = Depends(get_session),
|
||||||
_user: AuthUser = Depends(require_auth),
|
_user: AuthUser = Depends(require_admin),
|
||||||
):
|
):
|
||||||
result = await session.execute(
|
result = await session.execute(
|
||||||
select(Game).where(Game.id == game_id).options(selectinload(Game.runs))
|
select(Game).where(Game.id == game_id).options(selectinload(Game.runs))
|
||||||
@@ -338,7 +338,7 @@ async def create_route(
|
|||||||
game_id: int,
|
game_id: int,
|
||||||
data: RouteCreate,
|
data: RouteCreate,
|
||||||
session: AsyncSession = Depends(get_session),
|
session: AsyncSession = Depends(get_session),
|
||||||
_user: AuthUser = Depends(require_auth),
|
_user: AuthUser = Depends(require_admin),
|
||||||
):
|
):
|
||||||
vg_id = await _get_version_group_id(session, game_id)
|
vg_id = await _get_version_group_id(session, game_id)
|
||||||
|
|
||||||
@@ -354,7 +354,7 @@ async def reorder_routes(
|
|||||||
game_id: int,
|
game_id: int,
|
||||||
data: RouteReorderRequest,
|
data: RouteReorderRequest,
|
||||||
session: AsyncSession = Depends(get_session),
|
session: AsyncSession = Depends(get_session),
|
||||||
_user: AuthUser = Depends(require_auth),
|
_user: AuthUser = Depends(require_admin),
|
||||||
):
|
):
|
||||||
vg_id = await _get_version_group_id(session, game_id)
|
vg_id = await _get_version_group_id(session, game_id)
|
||||||
|
|
||||||
@@ -381,7 +381,7 @@ async def update_route(
|
|||||||
route_id: int,
|
route_id: int,
|
||||||
data: RouteUpdate,
|
data: RouteUpdate,
|
||||||
session: AsyncSession = Depends(get_session),
|
session: AsyncSession = Depends(get_session),
|
||||||
_user: AuthUser = Depends(require_auth),
|
_user: AuthUser = Depends(require_admin),
|
||||||
):
|
):
|
||||||
vg_id = await _get_version_group_id(session, game_id)
|
vg_id = await _get_version_group_id(session, game_id)
|
||||||
|
|
||||||
@@ -402,7 +402,7 @@ async def delete_route(
|
|||||||
game_id: int,
|
game_id: int,
|
||||||
route_id: int,
|
route_id: int,
|
||||||
session: AsyncSession = Depends(get_session),
|
session: AsyncSession = Depends(get_session),
|
||||||
_user: AuthUser = Depends(require_auth),
|
_user: AuthUser = Depends(require_admin),
|
||||||
):
|
):
|
||||||
vg_id = await _get_version_group_id(session, game_id)
|
vg_id = await _get_version_group_id(session, game_id)
|
||||||
|
|
||||||
@@ -437,7 +437,7 @@ async def bulk_import_routes(
|
|||||||
game_id: int,
|
game_id: int,
|
||||||
items: list[BulkRouteItem],
|
items: list[BulkRouteItem],
|
||||||
session: AsyncSession = Depends(get_session),
|
session: AsyncSession = Depends(get_session),
|
||||||
_user: AuthUser = Depends(require_auth),
|
_user: AuthUser = Depends(require_admin),
|
||||||
):
|
):
|
||||||
vg_id = await _get_version_group_id(session, game_id)
|
vg_id = await _get_version_group_id(session, game_id)
|
||||||
|
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ from sqlalchemy import func, or_, select
|
|||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
from sqlalchemy.orm import joinedload, selectinload
|
from sqlalchemy.orm import joinedload, selectinload
|
||||||
|
|
||||||
|
from app.core.auth import AuthUser, require_admin
|
||||||
from app.core.database import get_session
|
from app.core.database import get_session
|
||||||
from app.models.evolution import Evolution
|
from app.models.evolution import Evolution
|
||||||
from app.models.pokemon import Pokemon
|
from app.models.pokemon import Pokemon
|
||||||
@@ -68,6 +69,7 @@ async def list_pokemon(
|
|||||||
async def bulk_import_pokemon(
|
async def bulk_import_pokemon(
|
||||||
items: list[BulkImportItem],
|
items: list[BulkImportItem],
|
||||||
session: AsyncSession = Depends(get_session),
|
session: AsyncSession = Depends(get_session),
|
||||||
|
_user: AuthUser = Depends(require_admin),
|
||||||
):
|
):
|
||||||
created = 0
|
created = 0
|
||||||
updated = 0
|
updated = 0
|
||||||
@@ -100,7 +102,9 @@ async def bulk_import_pokemon(
|
|||||||
|
|
||||||
@router.post("/pokemon", response_model=PokemonResponse, status_code=201)
|
@router.post("/pokemon", response_model=PokemonResponse, status_code=201)
|
||||||
async def create_pokemon(
|
async def create_pokemon(
|
||||||
data: PokemonCreate, session: AsyncSession = Depends(get_session)
|
data: PokemonCreate,
|
||||||
|
session: AsyncSession = Depends(get_session),
|
||||||
|
_user: AuthUser = Depends(require_admin),
|
||||||
):
|
):
|
||||||
existing = await session.execute(
|
existing = await session.execute(
|
||||||
select(Pokemon).where(Pokemon.pokeapi_id == data.pokeapi_id)
|
select(Pokemon).where(Pokemon.pokeapi_id == data.pokeapi_id)
|
||||||
@@ -321,6 +325,7 @@ async def update_pokemon(
|
|||||||
pokemon_id: int,
|
pokemon_id: int,
|
||||||
data: PokemonUpdate,
|
data: PokemonUpdate,
|
||||||
session: AsyncSession = Depends(get_session),
|
session: AsyncSession = Depends(get_session),
|
||||||
|
_user: AuthUser = Depends(require_admin),
|
||||||
):
|
):
|
||||||
pokemon = await session.get(Pokemon, pokemon_id)
|
pokemon = await session.get(Pokemon, pokemon_id)
|
||||||
if pokemon is None:
|
if pokemon is None:
|
||||||
@@ -349,7 +354,11 @@ async def update_pokemon(
|
|||||||
|
|
||||||
|
|
||||||
@router.delete("/pokemon/{pokemon_id}", status_code=204)
|
@router.delete("/pokemon/{pokemon_id}", status_code=204)
|
||||||
async def delete_pokemon(pokemon_id: int, session: AsyncSession = Depends(get_session)):
|
async def delete_pokemon(
|
||||||
|
pokemon_id: int,
|
||||||
|
session: AsyncSession = Depends(get_session),
|
||||||
|
_user: AuthUser = Depends(require_admin),
|
||||||
|
):
|
||||||
result = await session.execute(
|
result = await session.execute(
|
||||||
select(Pokemon)
|
select(Pokemon)
|
||||||
.where(Pokemon.id == pokemon_id)
|
.where(Pokemon.id == pokemon_id)
|
||||||
@@ -405,6 +414,7 @@ async def add_route_encounter(
|
|||||||
route_id: int,
|
route_id: int,
|
||||||
data: RouteEncounterCreate,
|
data: RouteEncounterCreate,
|
||||||
session: AsyncSession = Depends(get_session),
|
session: AsyncSession = Depends(get_session),
|
||||||
|
_user: AuthUser = Depends(require_admin),
|
||||||
):
|
):
|
||||||
route = await session.get(Route, route_id)
|
route = await session.get(Route, route_id)
|
||||||
if route is None:
|
if route is None:
|
||||||
@@ -436,6 +446,7 @@ async def update_route_encounter(
|
|||||||
encounter_id: int,
|
encounter_id: int,
|
||||||
data: RouteEncounterUpdate,
|
data: RouteEncounterUpdate,
|
||||||
session: AsyncSession = Depends(get_session),
|
session: AsyncSession = Depends(get_session),
|
||||||
|
_user: AuthUser = Depends(require_admin),
|
||||||
):
|
):
|
||||||
result = await session.execute(
|
result = await session.execute(
|
||||||
select(RouteEncounter)
|
select(RouteEncounter)
|
||||||
@@ -466,6 +477,7 @@ async def remove_route_encounter(
|
|||||||
route_id: int,
|
route_id: int,
|
||||||
encounter_id: int,
|
encounter_id: int,
|
||||||
session: AsyncSession = Depends(get_session),
|
session: AsyncSession = Depends(get_session),
|
||||||
|
_user: AuthUser = Depends(require_admin),
|
||||||
):
|
):
|
||||||
encounter = await session.execute(
|
encounter = await session.execute(
|
||||||
select(RouteEncounter).where(
|
select(RouteEncounter).where(
|
||||||
|
|||||||
@@ -1,9 +1,14 @@
|
|||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
from uuid import UUID
|
||||||
|
|
||||||
import jwt
|
import jwt
|
||||||
from fastapi import Depends, HTTPException, Request, status
|
from fastapi import Depends, HTTPException, Request, status
|
||||||
|
from sqlalchemy import select
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
from app.core.config import settings
|
from app.core.config import settings
|
||||||
|
from app.core.database import get_session
|
||||||
|
from app.models.user import User
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -81,3 +86,22 @@ def require_auth(user: AuthUser | None = Depends(get_current_user)) -> AuthUser:
|
|||||||
headers={"WWW-Authenticate": "Bearer"},
|
headers={"WWW-Authenticate": "Bearer"},
|
||||||
)
|
)
|
||||||
return user
|
return user
|
||||||
|
|
||||||
|
|
||||||
|
async def require_admin(
|
||||||
|
user: AuthUser = Depends(require_auth),
|
||||||
|
session: AsyncSession = Depends(get_session),
|
||||||
|
) -> AuthUser:
|
||||||
|
"""
|
||||||
|
Dependency that requires admin privileges.
|
||||||
|
Raises 401 if not authenticated, 403 if not an admin.
|
||||||
|
"""
|
||||||
|
result = await session.execute(select(User).where(User.id == UUID(user.id)))
|
||||||
|
db_user = result.scalar_one_or_none()
|
||||||
|
|
||||||
|
if db_user is None or not db_user.is_admin:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_403_FORBIDDEN,
|
||||||
|
detail="Admin access required",
|
||||||
|
)
|
||||||
|
return user
|
||||||
|
|||||||
@@ -1,12 +1,14 @@
|
|||||||
import time
|
import time
|
||||||
|
from uuid import UUID
|
||||||
|
|
||||||
import jwt
|
import jwt
|
||||||
import pytest
|
import pytest
|
||||||
from httpx import ASGITransport, AsyncClient
|
from httpx import ASGITransport, AsyncClient
|
||||||
|
|
||||||
from app.core.auth import AuthUser, get_current_user, require_auth
|
from app.core.auth import AuthUser, get_current_user, require_admin, require_auth
|
||||||
from app.core.config import settings
|
from app.core.config import settings
|
||||||
from app.main import app
|
from app.main import app
|
||||||
|
from app.models.user import User
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
@@ -177,3 +179,140 @@ async def test_read_endpoint_without_token(db_session):
|
|||||||
) as ac:
|
) as ac:
|
||||||
response = await ac.get("/runs")
|
response = await ac.get("/runs")
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
|
|
||||||
|
|
||||||
|
async def test_require_admin_valid_admin_user(db_session):
|
||||||
|
"""Test require_admin passes through for admin user."""
|
||||||
|
user_id = "11111111-1111-1111-1111-111111111111"
|
||||||
|
admin_user = User(
|
||||||
|
id=UUID(user_id),
|
||||||
|
email="admin@example.com",
|
||||||
|
is_admin=True,
|
||||||
|
)
|
||||||
|
db_session.add(admin_user)
|
||||||
|
await db_session.commit()
|
||||||
|
|
||||||
|
auth_user = AuthUser(id=user_id, email="admin@example.com")
|
||||||
|
result = await require_admin(user=auth_user, session=db_session)
|
||||||
|
assert result is auth_user
|
||||||
|
|
||||||
|
|
||||||
|
async def test_require_admin_non_admin_user(db_session):
|
||||||
|
"""Test require_admin raises 403 for non-admin user."""
|
||||||
|
from fastapi import HTTPException
|
||||||
|
|
||||||
|
user_id = "22222222-2222-2222-2222-222222222222"
|
||||||
|
regular_user = User(
|
||||||
|
id=UUID(user_id),
|
||||||
|
email="user@example.com",
|
||||||
|
is_admin=False,
|
||||||
|
)
|
||||||
|
db_session.add(regular_user)
|
||||||
|
await db_session.commit()
|
||||||
|
|
||||||
|
auth_user = AuthUser(id=user_id, email="user@example.com")
|
||||||
|
with pytest.raises(HTTPException) as exc_info:
|
||||||
|
await require_admin(user=auth_user, session=db_session)
|
||||||
|
assert exc_info.value.status_code == 403
|
||||||
|
assert exc_info.value.detail == "Admin access required"
|
||||||
|
|
||||||
|
|
||||||
|
async def test_require_admin_user_not_in_db(db_session):
|
||||||
|
"""Test require_admin raises 403 for user not in database."""
|
||||||
|
from fastapi import HTTPException
|
||||||
|
|
||||||
|
auth_user = AuthUser(
|
||||||
|
id="33333333-3333-3333-3333-333333333333", email="ghost@example.com"
|
||||||
|
)
|
||||||
|
with pytest.raises(HTTPException) as exc_info:
|
||||||
|
await require_admin(user=auth_user, session=db_session)
|
||||||
|
assert exc_info.value.status_code == 403
|
||||||
|
assert exc_info.value.detail == "Admin access required"
|
||||||
|
|
||||||
|
|
||||||
|
async def test_admin_endpoint_returns_403_for_non_admin(
|
||||||
|
db_session, jwt_secret, monkeypatch
|
||||||
|
):
|
||||||
|
"""Test that admin endpoint returns 403 for authenticated non-admin user."""
|
||||||
|
user_id = "44444444-4444-4444-4444-444444444444"
|
||||||
|
regular_user = User(
|
||||||
|
id=UUID(user_id),
|
||||||
|
email="nonadmin@example.com",
|
||||||
|
is_admin=False,
|
||||||
|
)
|
||||||
|
db_session.add(regular_user)
|
||||||
|
await db_session.commit()
|
||||||
|
|
||||||
|
monkeypatch.setattr(settings, "supabase_jwt_secret", jwt_secret)
|
||||||
|
token = jwt.encode(
|
||||||
|
{
|
||||||
|
"sub": user_id,
|
||||||
|
"email": "nonadmin@example.com",
|
||||||
|
"role": "authenticated",
|
||||||
|
"aud": "authenticated",
|
||||||
|
"exp": int(time.time()) + 3600,
|
||||||
|
},
|
||||||
|
jwt_secret,
|
||||||
|
algorithm="HS256",
|
||||||
|
)
|
||||||
|
|
||||||
|
async with AsyncClient(
|
||||||
|
transport=ASGITransport(app=app),
|
||||||
|
base_url="http://test",
|
||||||
|
headers={"Authorization": f"Bearer {token}"},
|
||||||
|
) as ac:
|
||||||
|
response = await ac.post(
|
||||||
|
"/games",
|
||||||
|
json={
|
||||||
|
"name": "Test Game",
|
||||||
|
"slug": "test-game",
|
||||||
|
"generation": 1,
|
||||||
|
"region": "Kanto",
|
||||||
|
"category": "core",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
assert response.status_code == 403
|
||||||
|
assert response.json()["detail"] == "Admin access required"
|
||||||
|
|
||||||
|
|
||||||
|
async def test_admin_endpoint_succeeds_for_admin(db_session, jwt_secret, monkeypatch):
|
||||||
|
"""Test that admin endpoint succeeds for authenticated admin user."""
|
||||||
|
user_id = "55555555-5555-5555-5555-555555555555"
|
||||||
|
admin_user = User(
|
||||||
|
id=UUID(user_id),
|
||||||
|
email="admin@example.com",
|
||||||
|
is_admin=True,
|
||||||
|
)
|
||||||
|
db_session.add(admin_user)
|
||||||
|
await db_session.commit()
|
||||||
|
|
||||||
|
monkeypatch.setattr(settings, "supabase_jwt_secret", jwt_secret)
|
||||||
|
token = jwt.encode(
|
||||||
|
{
|
||||||
|
"sub": user_id,
|
||||||
|
"email": "admin@example.com",
|
||||||
|
"role": "authenticated",
|
||||||
|
"aud": "authenticated",
|
||||||
|
"exp": int(time.time()) + 3600,
|
||||||
|
},
|
||||||
|
jwt_secret,
|
||||||
|
algorithm="HS256",
|
||||||
|
)
|
||||||
|
|
||||||
|
async with AsyncClient(
|
||||||
|
transport=ASGITransport(app=app),
|
||||||
|
base_url="http://test",
|
||||||
|
headers={"Authorization": f"Bearer {token}"},
|
||||||
|
) as ac:
|
||||||
|
response = await ac.post(
|
||||||
|
"/games",
|
||||||
|
json={
|
||||||
|
"name": "Test Game",
|
||||||
|
"slug": "test-game",
|
||||||
|
"generation": 1,
|
||||||
|
"region": "Kanto",
|
||||||
|
"category": "core",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
assert response.status_code == 201
|
||||||
|
assert response.json()["name"] == "Test Game"
|
||||||
|
|||||||
Reference in New Issue
Block a user