"""Database upsert helpers for seed data."""

from sqlalchemy import delete, select
from sqlalchemy.dialects.postgresql import insert
from sqlalchemy.ext.asyncio import AsyncSession

from app.models.boss_battle import BossBattle
from app.models.boss_pokemon import BossPokemon
from app.models.evolution import Evolution
from app.models.game import Game
from app.models.pokemon import Pokemon
from app.models.route import Route
from app.models.route_encounter import RouteEncounter
from app.models.version_group import VersionGroup


async def upsert_version_groups(
    session: AsyncSession,
    vg_data: dict[str, dict],
) -> dict[str, int]:
    """Insert or update one version group per entry of ``vg_data``.

    The display name is built by joining the names of the group's games with
    " / ", after stripping the "Pokemon " prefix from each. Rows are matched
    on ``slug``; on conflict only ``name`` is refreshed.

    Returns a ``{slug: id}`` mapping built from every row currently in the
    version group table (not only the rows touched by this call).
    """
    for slug, info in vg_data.items():
        short_names = [
            game["name"].replace("Pokemon ", "")
            for game in info["games"].values()
        ]
        display_name = " / ".join(short_names)

        upsert = insert(VersionGroup).values(name=display_name, slug=slug)
        upsert = upsert.on_conflict_do_update(
            index_elements=["slug"],
            set_={"name": display_name},
        )
        await session.execute(upsert)

    await session.flush()

    rows = await session.execute(select(VersionGroup.slug, VersionGroup.id))
    return {row.slug: row.id for row in rows}


async def upsert_games(
    session: AsyncSession,
    games: list[dict],
    slug_to_vg_id: dict[str, int] | None = None,
) -> dict[str, int]:
    """Insert or update game rows keyed by ``slug``.

    When ``slug_to_vg_id`` is given and contains a non-None id for a game's
    slug, ``version_group_id`` is written as well; otherwise that column is
    left out of both the insert and the update.

    Returns a ``{slug: id}`` mapping built from every row in the game table.
    """
    for game in games:
        values = {
            "name": game["name"],
            "slug": game["slug"],
            "generation": game["generation"],
            "region": game["region"],
            "release_year": game.get("release_year"),
            "color": game.get("color"),
        }

        if slug_to_vg_id is not None:
            vg_id = slug_to_vg_id.get(game["slug"])
            if vg_id is not None:
                values["version_group_id"] = vg_id

        # On conflict refresh everything except the conflict key itself.
        refreshed = {col: val for col, val in values.items() if col != "slug"}

        upsert = insert(Game).values(**values).on_conflict_do_update(
            index_elements=["slug"],
            set_=refreshed,
        )
        await session.execute(upsert)

    await session.flush()

    rows = await session.execute(select(Game.slug, Game.id))
    return {row.slug: row.id for row in rows}


async def upsert_pokemon(session: AsyncSession, pokemon_list: list[dict]) -> dict[int, int]:
    """Insert or update pokemon rows keyed by ``pokeapi_id``.

    Returns a ``{pokeapi_id: id}`` mapping built from every row in the
    pokemon table.
    """
    for poke in pokemon_list:
        values = {
            "pokeapi_id": poke["pokeapi_id"],
            "national_dex": poke["national_dex"],
            "name": poke["name"],
            "types": poke["types"],
            "sprite_url": poke.get("sprite_url"),
        }
        # On conflict refresh everything except the conflict key itself.
        refreshed = {col: val for col, val in values.items() if col != "pokeapi_id"}

        upsert = insert(Pokemon).values(**values).on_conflict_do_update(
            index_elements=["pokeapi_id"],
            set_=refreshed,
        )
        await session.execute(upsert)

    await session.flush()

    rows = await session.execute(select(Pokemon.pokeapi_id, Pokemon.id))
    return {row.pokeapi_id: row.id for row in rows}


async def _upsert_route(session: AsyncSession, values: dict) -> None:
    """Upsert a single route row on the ``uq_routes_version_group_name`` constraint.

    Every supplied column except ``name`` and ``version_group_id`` is
    rewritten when the row already exists.
    """
    refreshed = {
        col: val
        for col, val in values.items()
        if col not in ("name", "version_group_id")
    }
    upsert = insert(Route).values(**values).on_conflict_do_update(
        constraint="uq_routes_version_group_name",
        set_=refreshed,
    )
    await session.execute(upsert)


async def _route_ids_by_name(session: AsyncSession, version_group_id: int) -> dict[str, int]:
    """Return ``{name: id}`` for every route of the given version group."""
    rows = await session.execute(
        select(Route.name, Route.id).where(Route.version_group_id == version_group_id)
    )
    return {row.name: row.id for row in rows}


async def upsert_routes(
    session: AsyncSession,
    version_group_id: int,
    routes: list[dict],
) -> dict[str, int]:
    """Upsert route records for a version group, return {name: id} mapping.

    Routes may be hierarchical: an entry carrying a non-empty 'children' list
    is a parent, and each child is stored with ``parent_route_id`` pointing
    at that parent. The returned mapping covers every route of the version
    group, children included.
    """
    # Pass 1: top-level routes. Their parent link is explicitly cleared.
    for route in routes:
        await _upsert_route(
            session,
            {
                "name": route["name"],
                "version_group_id": version_group_id,
                "order": route["order"],
                "parent_route_id": None,
            },
        )

    await session.flush()

    # Parent ids are needed before any child can be written.
    name_to_id = await _route_ids_by_name(session, version_group_id)

    # Pass 2: children, linked to the parent they are nested under.
    for route in routes:
        children = route.get("children", [])
        if not children:
            continue

        parent_id = name_to_id[route["name"]]
        for child in children:
            await _upsert_route(
                session,
                {
                    "name": child["name"],
                    "version_group_id": version_group_id,
                    "order": child["order"],
                    "parent_route_id": parent_id,
                    "pinwheel_zone": child.get("pinwheel_zone"),
                },
            )

    await session.flush()

    # Re-read so the result also contains the child routes.
    return await _route_ids_by_name(session, version_group_id)


async def upsert_route_encounters(
    session: AsyncSession,
    route_id: int,
    encounters: list[dict],
    dex_to_id: dict[int, int],
    game_id: int,
) -> int:
    """Upsert the encounters of one route for one game.

    ``dex_to_id`` is looked up with each encounter's ``pokeapi_id``; entries
    with no match are reported via ``print`` and skipped. Existing rows
    (matched on the ``uq_route_pokemon_method_game`` constraint) get their
    encounter rate and level range refreshed.

    Returns the number of encounters written.
    """
    written = 0
    for enc in encounters:
        pokemon_id = dex_to_id.get(enc["pokeapi_id"])
        if pokemon_id is None:
            print(f" Warning: no pokemon_id for pokeapi_id {enc['pokeapi_id']}")
            continue

        values = {
            "route_id": route_id,
            "pokemon_id": pokemon_id,
            "game_id": game_id,
            "encounter_method": enc["method"],
            "encounter_rate": enc["encounter_rate"],
            "min_level": enc["min_level"],
            "max_level": enc["max_level"],
        }
        refreshed = {
            col: values[col]
            for col in ("encounter_rate", "min_level", "max_level")
        }

        upsert = insert(RouteEncounter).values(**values).on_conflict_do_update(
            constraint="uq_route_pokemon_method_game",
            set_=refreshed,
        )
        await session.execute(upsert)
        written += 1

    return written


def _resolve_after_route_id(boss: dict, route_name_to_id: dict[str, int] | None) -> int | None:
    """Translate a boss's ``after_route_name`` into a route id, if possible.

    Returns None when the boss names no route or no (non-empty) mapping was
    supplied. When a name is given but missing from the mapping, a warning is
    printed and None is returned.
    """
    route_name = boss.get("after_route_name")
    if not (route_name and route_name_to_id):
        return None

    route_id = route_name_to_id.get(route_name)
    if route_id is None:
        print(f" Warning: route '{route_name}' not found for boss '{boss['name']}'")
    return route_id


async def upsert_bosses(
    session: AsyncSession,
    version_group_id: int,
    bosses: list[dict],
    dex_to_id: dict[int, int],
    route_name_to_id: dict[str, int] | None = None,
) -> int:
    """Upsert boss battles for a version group, return count of bosses upserted.

    Bosses are matched on the ``uq_boss_battles_version_group_order``
    constraint. Each boss's team is replaced wholesale: its existing
    ``BossPokemon`` rows are deleted and the entries of ``boss["pokemon"]``
    are added again. Team members whose ``pokeapi_id`` is absent from
    ``dex_to_id`` are reported via ``print`` and skipped.
    """
    upserted = 0
    for boss in bosses:
        after_route_id = _resolve_after_route_id(boss, route_name_to_id)

        values = {
            "version_group_id": version_group_id,
            "name": boss["name"],
            "boss_type": boss["boss_type"],
            "specialty_type": boss.get("specialty_type"),
            "badge_name": boss.get("badge_name"),
            "badge_image_url": boss.get("badge_image_url"),
            "level_cap": boss["level_cap"],
            "order": boss["order"],
            "after_route_id": after_route_id,
            "location": boss["location"],
            "section": boss.get("section"),
            "sprite_url": boss.get("sprite_url"),
        }
        # The conflict key columns identify the row and are never rewritten.
        refreshed = {
            col: val
            for col, val in values.items()
            if col not in ("version_group_id", "order")
        }

        upsert = (
            insert(BossBattle)
            .values(**values)
            .on_conflict_do_update(
                constraint="uq_boss_battles_version_group_order",
                set_=refreshed,
            )
            .returning(BossBattle.id)
        )
        boss_id = (await session.execute(upsert)).scalar_one()

        # Replace the team: clear whatever is stored, then add the new members.
        await session.execute(
            delete(BossPokemon).where(BossPokemon.boss_battle_id == boss_id)
        )
        for member in boss.get("pokemon", []):
            pokemon_id = dex_to_id.get(member["pokeapi_id"])
            if pokemon_id is None:
                print(f" Warning: no pokemon_id for pokeapi_id {member['pokeapi_id']}")
                continue

            session.add(
                BossPokemon(
                    boss_battle_id=boss_id,
                    pokemon_id=pokemon_id,
                    level=member["level"],
                    order=member["order"],
                    condition_label=member.get("condition_label"),
                )
            )

        upserted += 1

    await session.flush()
    return upserted


async def upsert_evolutions(
    session: AsyncSession,
    evolutions: list[dict],
    dex_to_id: dict[int, int],
) -> int:
    """Replace the evolution table with ``evolutions``, return the row count.

    All existing ``Evolution`` rows are deleted first, then one row is added
    per entry whose ``from_pokeapi_id`` and ``to_pokeapi_id`` both resolve
    through ``dex_to_id``. Entries that do not resolve are skipped silently.
    """
    await session.execute(delete(Evolution))

    added = 0
    for evo in evolutions:
        source_id = dex_to_id.get(evo["from_pokeapi_id"])
        target_id = dex_to_id.get(evo["to_pokeapi_id"])
        if source_id is None or target_id is None:
            continue

        session.add(
            Evolution(
                from_pokemon_id=source_id,
                to_pokemon_id=target_id,
                trigger=evo["trigger"],
                min_level=evo.get("min_level"),
                item=evo.get("item"),
                held_item=evo.get("held_item"),
                condition=evo.get("condition"),
                region=evo.get("region"),
            )
        )
        added += 1

    await session.flush()
    return added