"""Asynchronous client for the IDFM PRIM marketplace stop-monitoring API."""

from collections import defaultdict
from re import compile as re_compile
from typing import ByteString

from aiofiles import open as async_open
from aiohttp import ClientSession
from msgspec import ValidationError
from msgspec.json import Decoder

from .idfm_types import Destinations as IdfmDestinations, IdfmResponse, IdfmState
from ..db import Database
from ..models import Line, Stop, StopArea


class IdfmInterface:
    """Thin asynchronous wrapper around the IDFM PRIM marketplace HTTP API.

    Every request opens its own ``ClientSession``; requests aimed at the
    marketplace carry the API key in the ``apikey`` header.
    """

    IDFM_ROOT_URL = "https://prim.iledefrance-mobilites.fr/marketplace"
    IDFM_STOP_MON_URL = f"{IDFM_ROOT_URL}/stop-monitoring"

    # Patterns capturing the identifier part of operator / line references.
    # They are not used by the methods of this class.
    OPERATOR_RE = re_compile(r"[^:]+:Operator::([^:]+):")
    LINE_RE = re_compile(r"[^:]+:Line::C([^:]+):")

    def __init__(self, api_key: str, database: Database) -> None:
        """Store the credentials and prepare the typed JSON decoder.

        Args:
            api_key: Key sent in the ``apikey`` header of marketplace requests.
            database: Database handle kept on the instance.
        """
        self._api_key = api_key
        self._database = database

        self._http_headers = {"Accept": "application/json", "apikey": self._api_key}

        # Built once so that every response is decoded with the same decoder.
        self._response_json_decoder = Decoder(type=IdfmResponse)

    async def startup(self) -> None:
        """Do nothing; placeholder start-up hook."""
        ...

    @staticmethod
    def _format_line_id(line_id: str) -> int:
        """Convert a raw line identifier to an integer.

        A single leading ``"C"`` is dropped before the conversion.

        Raises:
            ValueError: If the remaining text is not a valid integer (this
                includes the empty string).
        """
        return int(line_id.removeprefix("C"))

    async def render_line_picto(self, line: Line) -> tuple[None | str, None | str]:
        """Download the pictogram of *line* and store it in a temporary file.

        Returns:
            ``(path, mime_type)`` of the written file, or ``(None, None)`` when
            the line has no pictogram or its data could not be retrieved.
        """
        line_picto_path = line_picto_format = None
        target = f"/tmp/{line.id}_repr"

        picto = line.picto
        if picto is not None:
            if (picto_data := await self._get_line_picto(line)) is not None:
                async with async_open(target, "wb") as fd:
                    await fd.write(bytes(picto_data))
                line_picto_path = target
                line_picto_format = picto.mime_type

        return (line_picto_path, line_picto_format)

    async def _get_line_picto(self, line: Line) -> ByteString | None:
        """Fetch the raw pictogram bytes of *line*.

        Returns:
            The response body, or ``None`` when the line has no pictogram URL
            or the server did not answer with HTTP 200.
        """
        picto = line.picto
        if picto is None or picto.url is None:
            return None

        # Only send the API key to the marketplace itself, never to other hosts.
        headers = (
            self._http_headers if picto.url.startswith(self.IDFM_ROOT_URL) else None
        )

        async with ClientSession(headers=headers) as session:
            async with session.get(picto.url) as response:
                # An error page must not be mistaken for image data.
                if response.status != 200:
                    return None
                return await response.read()

    async def get_next_passages(self, stop_point_id: int) -> IdfmResponse | None:
        """Query the stop-monitoring endpoint for *stop_point_id*.

        Returns:
            The decoded response, or ``None`` when the HTTP status is not 200
            or the payload fails validation (the error is printed).
        """
        ret = None
        params = {"MonitoringRef": f"STIF:StopPoint:Q:{stop_point_id}:"}

        async with ClientSession(headers=self._http_headers) as session:
            async with session.get(self.IDFM_STOP_MON_URL, params=params) as response:
                if response.status == 200:
                    data = await response.read()
                    try:
                        ret = self._response_json_decoder.decode(data)
                    except ValidationError as err:
                        print(err)

        return ret

    async def get_destinations(self, stop_id: int) -> IdfmDestinations | None:
        """Collect the destination names announced at a stop, grouped by line.

        *stop_id* is looked up as a ``Stop`` first, then as a ``StopArea``; for
        an area, visits at any of its stops are accepted.

        Returns:
            A mapping of integer line id to a set of destination names, or
            ``None`` when *stop_id* matches neither a stop nor a stop area.
            The mapping is empty when no passage data could be fetched.
        """
        destinations: IdfmDestinations = defaultdict(set)

        if (stop := await Stop.get_by_id(stop_id)) is not None:
            expected_stop_ids = {stop.id}
        elif (stop_area := await StopArea.get_by_id(stop_id)) is not None:
            expected_stop_ids = {stop.id for stop in stop_area.stops}
        else:
            return None

        # NOTE(review): the request below always uses a "StopPoint" monitoring
        # reference, even when stop_id resolved to a StopArea - confirm that
        # this is intended.
        if (res := await self.get_next_passages(stop_id)) is None:
            return destinations

        for delivery in res.Siri.ServiceDelivery.StopMonitoringDelivery:
            if delivery.Status != IdfmState.true:
                continue

            for stop_visit in delivery.MonitoredStopVisit:
                monitoring_ref = stop_visit.MonitoringRef.value
                try:
                    monitored_stop_id = int(monitoring_ref.split(":")[-2])
                except (IndexError, ValueError):
                    print(f"Unable to get stop id from {monitoring_ref}")
                    continue

                journey = stop_visit.MonitoredVehicleJourney
                dst_names = journey.DestinationName
                if not dst_names or monitored_stop_id not in expected_stop_ids:
                    continue

                line_ref = journey.LineRef.value
                try:
                    line_id = IdfmInterface._format_line_id(line_ref.split(":")[-2])
                except (IndexError, ValueError):
                    # A single malformed line reference must not discard the
                    # destinations gathered from the other visits.
                    print(f"Unable to get line id from {line_ref}")
                    continue

                destinations[line_id].add(dst_names[0].value)

        return destinations