116 lines
4.3 KiB
Python
116 lines
4.3 KiB
Python
from collections import defaultdict
from logging import getLogger
from re import compile as re_compile
from typing import ByteString

from aiofiles import open as async_open
from aiohttp import ClientSession
from msgspec import DecodeError, ValidationError
from msgspec.json import Decoder

from .idfm_types import Destinations as IdfmDestinations, IdfmResponse, IdfmState
from ..db import Database
from ..models import Line, Stop, StopArea
|
|
|
|
|
|
class IdfmInterface:
    """Async client for the IDFM PRIM marketplace API (real-time stop monitoring)."""

    IDFM_ROOT_URL = "https://prim.iledefrance-mobilites.fr/marketplace"
    IDFM_STOP_MON_URL = f"{IDFM_ROOT_URL}/stop-monitoring"

    # Not used inside this class; presumably consumed elsewhere to parse
    # "...:Operator::<id>:" / "...:Line::C<id>:" references — confirm.
    OPERATOR_RE = re_compile(r"[^:]+:Operator::([^:]+):")
    LINE_RE = re_compile(r"[^:]+:Line::C([^:]+):")

    _logger = getLogger(__name__)

    def __init__(self, api_key: str, database: Database) -> None:
        """Store credentials and prepare the shared response decoder.

        Args:
            api_key: PRIM API key, sent as the ``apikey`` header to IDFM URLs.
            database: Database handle (stored; not used by the methods below).
        """
        self._api_key = api_key
        self._database = database

        self._http_headers = {"Accept": "application/json", "apikey": self._api_key}

        # Typed msgspec decoder reused for every stop-monitoring response.
        self._response_json_decoder = Decoder(type=IdfmResponse)

    async def startup(self) -> None:
        """Startup hook; currently a no-op."""
        ...

    @staticmethod
    def _format_line_id(line_id: str) -> int:
        """Convert a raw line id to an integer, dropping one leading ``"C"``.

        For example ``"C123"`` and ``"123"`` both give ``123``.

        Raises:
            ValueError: if what remains is not an integer literal (this
                includes the empty string).
        """
        return int(line_id.removeprefix("C"))

    async def render_line_picto(self, line: Line) -> tuple[None | str, None | str]:
        """Download the pictogram of ``line`` and write it to a local file.

        Args:
            line: Line whose ``picto`` (if any) should be fetched.

        Returns:
            ``(path, mime_type)`` of the written file, or ``(None, None)`` when
            the line has no picto or it could not be downloaded.
        """
        picto = line.picto
        if picto is None:
            return (None, None)

        picto_data = await self._get_line_picto(line)
        if picto_data is None:
            return (None, None)

        # NOTE(review): fixed, predictable path under /tmp; consider the
        # tempfile module if several users/processes share the host.
        target = f"/tmp/{line.id}_repr"
        async with async_open(target, "wb") as fd:
            await fd.write(picto_data)

        return (target, picto.mime_type)

    async def _get_line_picto(self, line: Line) -> bytes | None:
        """Fetch the raw pictogram bytes of ``line``.

        Returns:
            The response body on HTTP 200, otherwise ``None`` (no picto, no
            picto URL, or a non-200 response, which is logged).
        """
        picto = line.picto
        if picto is None or picto.url is None:
            return None

        # The API-key header is only attached to URLs under IDFM_ROOT_URL.
        headers = (
            self._http_headers if picto.url.startswith(self.IDFM_ROOT_URL) else None
        )

        async with ClientSession(headers=headers) as session:
            async with session.get(picto.url) as response:
                if response.status != 200:
                    # Don't hand an error page back as if it were an image.
                    self._logger.warning(
                        "Unable to fetch picto %s: HTTP %d", picto.url, response.status
                    )
                    return None
                return await response.read()

    async def get_next_passages(self, stop_point_id: int) -> IdfmResponse | None:
        """Query the stop-monitoring endpoint for ``stop_point_id``.

        Args:
            stop_point_id: Id placed in the ``STIF:StopPoint:Q:<id>:``
                MonitoringRef.

        Returns:
            The decoded response, or ``None`` on a non-200 status or when the
            body cannot be decoded into ``IdfmResponse`` (both are logged).
        """
        params = {"MonitoringRef": f"STIF:StopPoint:Q:{stop_point_id}:"}

        async with ClientSession(headers=self._http_headers) as session:
            async with session.get(self.IDFM_STOP_MON_URL, params=params) as response:
                if response.status != 200:
                    self._logger.warning(
                        "Stop monitoring request for %s failed: HTTP %d",
                        stop_point_id,
                        response.status,
                    )
                    return None
                data = await response.read()

        try:
            return self._response_json_decoder.decode(data)
        except DecodeError as err:
            # msgspec.ValidationError subclasses DecodeError, so this covers both
            # malformed JSON and JSON that doesn't match the IdfmResponse schema.
            self._logger.warning("Unable to decode stop monitoring response: %s", err)
            return None

    async def get_destinations(self, stop_id: int) -> IdfmDestinations | None:
        """Collect the destinations currently announced at a stop or stop area.

        ``stop_id`` is looked up first as a ``Stop``, then as a ``StopArea``.
        For an area, visits at any of its stops are kept.

        Args:
            stop_id: Id of a stop or of a stop area.

        Returns:
            A mapping ``line id -> set of destination names``. It is empty if
            the real-time feed is unavailable. Returns ``None`` when
            ``stop_id`` matches neither a stop nor a stop area.
        """
        destinations: IdfmDestinations = defaultdict(set)

        if (stop := await Stop.get_by_id(stop_id)) is not None:
            expected_stop_ids = {stop.id}
        elif (stop_area := await StopArea.get_by_id(stop_id)) is not None:
            expected_stop_ids = {area_stop.id for area_stop in stop_area.stops}
        else:
            return None

        # NOTE(review): stop-area ids are queried with the same
        # "STIF:StopPoint:Q:<id>:" MonitoringRef as single stops — confirm the
        # API accepts that form for stop areas.
        res = await self.get_next_passages(stop_id)
        if res is None:
            return destinations

        for delivery in res.Siri.ServiceDelivery.StopMonitoringDelivery:
            if delivery.Status != IdfmState.true:
                continue

            for stop_visit in delivery.MonitoredStopVisit:
                monitoring_ref = stop_visit.MonitoringRef.value
                try:
                    # The stop id is the second-to-last ":"-separated field
                    # (presumably mirroring the "STIF:StopPoint:Q:<id>:" form).
                    monitored_stop_id = int(monitoring_ref.split(":")[-2])
                except (IndexError, ValueError):
                    self._logger.warning(
                        "Unable to get stop id from %s", monitoring_ref
                    )
                    continue

                journey = stop_visit.MonitoredVehicleJourney
                dst_names = journey.DestinationName
                if not dst_names or monitored_stop_id not in expected_stop_ids:
                    continue

                line_ref = journey.LineRef.value
                try:
                    line_id = self._format_line_id(line_ref.split(":")[-2])
                except (IndexError, ValueError):
                    # Skip one malformed LineRef instead of failing the whole lookup.
                    self._logger.warning("Unable to get line id from %s", line_ref)
                    continue

                destinations[line_id].add(dst_names[0].value)

        return destinations
|