Files
carrramba-encore-rate/backend/backend/idfm_interface/idfm_interface.py

562 lines
22 KiB
Python

from collections import defaultdict
from re import compile as re_compile
from time import time
from typing import (
AsyncIterator,
ByteString,
Callable,
Iterable,
List,
Type,
)
from aiofiles import open as async_open
from aiohttp import ClientSession
from msgspec import ValidationError
from msgspec.json import Decoder
from rich import print
from shapefile import Reader as ShapeFileReader, ShapeRecord
from ..db import Database
from ..models import ConnectionArea, Line, LinePicto, Stop, StopArea, StopShape
from .idfm_types import (
ConnectionArea as IdfmConnectionArea,
Destinations as IdfmDestinations,
IdfmLineState,
IdfmResponse,
Line as IdfmLine,
LinePicto as IdfmPicto,
IdfmState,
Stop as IdfmStop,
StopArea as IdfmStopArea,
StopAreaStopAssociation,
StopAreaType,
StopLineAsso as IdfmStopLineAsso,
TransportMode,
)
from .ratp_types import Picto as RatpPicto
class IdfmInterface:
    """Client for the Ile-de-France Mobilites (IDFM) and RATP data sources.

    The class-level constants below hold the remote endpoints and the regular
    expressions used to pick identifiers out of colon-separated references.
    """

    # PRIM marketplace root. It used to be stored under IDFM_ROOT_URL and was
    # then immediately overwritten; it now has its own name so that the two
    # roots can no longer be confused.
    IDFM_PRIM_ROOT_URL = "https://prim.iledefrance-mobilites.fr/marketplace"
    IDFM_STOP_MON_URL = f"{IDFM_PRIM_ROOT_URL}/stop-monitoring"

    # Open-data portal root (the value IDFM_ROOT_URL always ended up with).
    IDFM_ROOT_URL = "https://data.iledefrance-mobilites.fr/explore/dataset"
    IDFM_STOPS_URL = (
        f"{IDFM_ROOT_URL}/arrets/download/?format=json&timezone=Europe/Berlin"
    )
    IDFM_PICTO_URL = f"{IDFM_ROOT_URL}/referentiel-des-lignes/files"

    # RATP open-data portal and its line pictogram files.
    RATP_ROOT_URL = "https://data.ratp.fr/explore/dataset"
    RATP_PICTO_URL = (
        f"{RATP_ROOT_URL}"
        "/pictogrammes-des-lignes-de-metro-rer-tramway-bus-et-noctilien/files"
    )

    # Capture the segment following "Operator::" / "Line::" in references
    # shaped like "<prefix>:Operator::<id>:" and "<prefix>:Line::<id>:".
    OPERATOR_RE = re_compile(r"[^:]+:Operator::([^:]+):")
    LINE_RE = re_compile(r"[^:]+:Line::([^:]+):")
def __init__(self, api_key: str, database: Database) -> None:
    """Keep the credentials and build one JSON decoder per payload type.

    Args:
        api_key: Key sent in the ``apikey`` HTTP header.
        database: Database handle, stored on the instance.
    """
    self._database = database
    self._api_key = api_key
    self._http_headers = {"Accept": "application/json", "apikey": api_key}

    # Real-time stop-monitoring response.
    self._response_json_decoder = Decoder(type=IdfmResponse)

    # Reference datasets: each one is a JSON array of typed records.
    self._json_stops_decoder = Decoder(type=List[IdfmStop])
    self._json_stop_areas_decoder = Decoder(type=List[IdfmStopArea])
    self._json_connection_areas_decoder = Decoder(type=List[IdfmConnectionArea])
    self._json_lines_decoder = Decoder(type=List[IdfmLine])
    self._json_ratp_pictos_decoder = Decoder(type=List[RatpPicto])

    # Association datasets (stop <-> line, stop area <-> stop).
    self._json_stops_lines_assos_decoder = Decoder(type=List[IdfmStopLineAsso])
    self._json_stop_area_stop_asso_decoder = Decoder(
        type=List[StopAreaStopAssociation]
    )
async def startup(self) -> None:
    """Load every dataset and hand it to the models, printing step timings.

    Stop shapes, connection areas, stop areas and stops are read through
    their ``_request_*`` generator, formatted, and passed to ``model.add``
    in batches of 10000 elements. Lines, pictos and the association tables
    are then loaded through the dedicated ``_load_*`` coroutines.
    """
    BATCH_SIZE = 10000

    # (model, async element source, formatter) -- processed in this order.
    STEPS: tuple[
        tuple[
            Type[ConnectionArea] | Type[Stop] | Type[StopArea] | Type[StopShape],
            Callable,
            Callable,
        ],
        ...,
    ] = (
        (
            StopShape,
            self._request_stop_shapes,
            IdfmInterface._format_idfm_stop_shapes,
        ),
        (
            ConnectionArea,
            self._request_idfm_connection_areas,
            IdfmInterface._format_idfm_connection_areas,
        ),
        (
            StopArea,
            self._request_idfm_stop_areas,
            IdfmInterface._format_idfm_stop_areas,
        ),
        (Stop, self._request_idfm_stops, IdfmInterface._format_idfm_stops),
    )

    for model, fetch, to_models in STEPS:
        step_started_at = time()
        pending = []
        async for item in fetch():
            pending.append(item)
            if len(pending) != BATCH_SIZE:
                continue
            # Full batch: flush it before reading further.
            await model.add(to_models(*pending))
            pending.clear()
        # Flush the last, partially filled batch.
        if pending:
            await model.add(to_models(*pending))
        print(f"Add {model.__name__}s: {time() - step_started_at}s")

    started_at = time()
    await self._load_lines()
    print(f"Add Lines and IDFM LinePictos: {time() - started_at}s")

    started_at = time()
    await self._load_ratp_pictos(30)
    print(f"Add RATP LinePictos: {time() - started_at}s")

    started_at = time()
    await self._load_lines_stops_assos()
    print(f"Link Stops to Lines: {time() - started_at}s")

    started_at = time()
    await self._load_stop_assos()
    print(f"Link Stops to StopAreas: {time() - started_at}s")
async def _load_lines(self, batch_size: int = 5000) -> None:
    """Add the IDFM lines and their pictos, ``batch_size`` lines at a time.

    A picto is queued only the first time its id is met. For every batch
    the pictos are added before the lines that reference them.

    Args:
        batch_size: Number of lines gathered before each flush.
    """
    pending_lines, pending_pictos = [], []
    seen_picto_ids = set()

    async for idfm_line in self._request_idfm_lines():
        picto = idfm_line.fields.picto
        if picto is not None and picto.id_ not in seen_picto_ids:
            seen_picto_ids.add(picto.id_)
            pending_pictos.append(picto)

        pending_lines.append(idfm_line)
        if len(pending_lines) != batch_size:
            continue

        await LinePicto.add(IdfmInterface._format_idfm_pictos(*pending_pictos))
        await Line.add(await self._format_idfm_lines(*pending_lines))
        pending_lines.clear()
        pending_pictos.clear()

    # Leftovers of the last, incomplete batch.
    if pending_pictos:
        await LinePicto.add(IdfmInterface._format_idfm_pictos(*pending_pictos))
    if pending_lines:
        await Line.add(await self._format_idfm_lines(*pending_lines))
async def _load_ratp_pictos(self, batch_size: int = 5) -> None:
    """Add the RATP pictos and attach them to lines, batch by batch.

    Each batch is formatted into ``(line short name, LinePicto)`` pairs; the
    pictos are added first, then the pairs are given to ``Line.add_pictos``.

    Args:
        batch_size: Number of RATP records gathered before each flush.
    """
    batch = []
    async for ratp_picto in self._request_ratp_pictos():
        batch.append(ratp_picto)
        if len(batch) != batch_size:
            continue
        named_pictos = IdfmInterface._format_ratp_pictos(*batch)
        await LinePicto.add(map(lambda pair: pair[1], named_pictos))
        await Line.add_pictos(named_pictos)
        batch.clear()

    # Leftovers of the last, incomplete batch.
    if batch:
        named_pictos = IdfmInterface._format_ratp_pictos(*batch)
        await LinePicto.add(map(lambda pair: pair[1], named_pictos))
        await Line.add_pictos(named_pictos)
async def _load_lines_stops_assos(self, batch_size: int = 5000) -> None:
    """Link stops to lines from the stop/line association dataset.

    The numeric stop id is the text after the last ``:`` of the record's
    ``stop_id``; records whose suffix is not an integer are reported and
    skipped. Totals are printed once the dataset is exhausted.

    Args:
        batch_size: Number of associations sent per ``Line.add_stops`` call.
    """
    total_assos_nb = 0
    total_found_nb = 0
    batch = []

    async for association in self._request_idfm_stops_lines_associations():
        # NOTE: `fields` and `total_assos_nb` are echoed by name through the
        # f-string "=" specifier below, so they must keep these names.
        fields = association.fields
        try:
            stop_id = int(fields.stop_id.rpartition(":")[2])
        except ValueError as err:
            print(err)
            print(f"{fields.stop_id = }")
            continue

        batch.append((fields.route_long_name, fields.operatorname, stop_id))
        if len(batch) != batch_size:
            continue

        total_assos_nb += batch_size
        total_found_nb += await Line.add_stops(batch)
        batch.clear()

    # Leftovers of the last, incomplete batch.
    if batch:
        total_assos_nb += len(batch)
        total_found_nb += await Line.add_stops(batch)

    print(f"{total_found_nb} line <-> stop ({total_assos_nb = } found)")
async def _load_stop_assos(self, batch_size: int = 5000) -> None:
    """Link stops to their stop area and to their connection area.

    Every record of the stop-area/stop association dataset yields two pairs:
    ``(zdaid, arrid)`` for ``StopArea.add_stops`` and ``(zdcid, arrid)`` for
    ``ConnectionArea.add_stops``. Pairs are flushed every ``batch_size``
    records, and two summary lines are printed at the end.

    Args:
        batch_size: Number of records gathered before each flush.
    """
    total_assos_nb = area_stop_assos_nb = conn_stop_assos_nb = 0
    area_stop_assos = []
    connection_stop_assos = []

    async def flush() -> None:
        """Send both pending lists to the models and update the counters."""
        nonlocal total_assos_nb, area_stop_assos_nb, conn_stop_assos_nb
        # Both lists grow in lockstep, so either length is the record count.
        total_assos_nb += len(area_stop_assos)
        # add_stops may return None; only a number is accumulated.
        if (found_nb := await StopArea.add_stops(area_stop_assos)) is not None:
            area_stop_assos_nb += found_nb
        area_stop_assos.clear()
        if (
            found_nb := await ConnectionArea.add_stops(connection_stop_assos)
        ) is not None:
            conn_stop_assos_nb += found_nb
        connection_stop_assos.clear()

    async for asso in self._request_idfm_stop_area_stop_associations():
        fields = asso.fields
        stop_id = int(fields.arrid)
        area_stop_assos.append((int(fields.zdaid), stop_id))
        connection_stop_assos.append((int(fields.zdcid), stop_id))
        if len(area_stop_assos) == batch_size:
            await flush()

    # Leftovers of the last, incomplete batch.
    if area_stop_assos:
        await flush()

    print(f"{area_stop_assos_nb} stop area <-> stop ({total_assos_nb = } found)")
    # This line used to be labelled "stop area" as well, which made the two
    # counters indistinguishable in the output.
    print(
        f"{conn_stop_assos_nb} connection area <-> stop ({total_assos_nb = } found)"
    )
# TODO: This method is synchronous due to the shapefile library.
# It's not a blocking issue but it could be nice to find an alternative.
async def _request_stop_shapes(self) -> AsyncIterator[ShapeRecord]:
    """Yield every shape record of the local ``REF_LDA.zip`` shapefile."""
    # TODO: Use HTTP
    with ShapeFileReader("./tests/datasets/REF_LDA.zip") as shapefile:
        for shape_record in shapefile.shapeRecords():
            yield shape_record
async def _request_idfm_stops(self) -> AsyncIterator[IdfmStop]:
    """Yield the stops decoded from the local ``stops_dataset.json`` file."""
    # TODO: Use HTTP -- download the dataset with a ClientSession carrying
    # the "Accept"/"apikey" headers and decode the body of a 200 response,
    # instead of reading the copy shipped with the tests.
    async with async_open("./tests/datasets/stops_dataset.json", "rb") as dataset:
        payload = await dataset.read()
        for idfm_stop in self._json_stops_decoder.decode(payload):
            yield idfm_stop
async def _request_idfm_stop_areas(self) -> AsyncIterator[IdfmStopArea]:
    """Yield the stop areas decoded from the local ``zones-d-arrets.json``."""
    # TODO: Use HTTP
    async with async_open("./tests/datasets/zones-d-arrets.json", "rb") as dataset:
        payload = await dataset.read()
        for stop_area in self._json_stop_areas_decoder.decode(payload):
            yield stop_area
async def _request_idfm_connection_areas(self) -> AsyncIterator[IdfmConnectionArea]:
    """Yield the connection areas decoded from the local dataset file."""
    dataset_path = "./tests/datasets/zones-de-correspondance.json"
    async with async_open(dataset_path, "rb") as dataset:
        payload = await dataset.read()
        for connection_area in self._json_connection_areas_decoder.decode(payload):
            yield connection_area
async def _request_idfm_lines(self) -> AsyncIterator[IdfmLine]:
    """Yield the lines decoded from the local ``lines_dataset.json`` file."""
    # TODO: Use HTTP
    async with async_open("./tests/datasets/lines_dataset.json", "rb") as dataset:
        payload = await dataset.read()
        for idfm_line in self._json_lines_decoder.decode(payload):
            yield idfm_line
async def _request_idfm_stops_lines_associations(
    self,
) -> AsyncIterator[IdfmStopLineAsso]:
    """Yield the stop/line associations from the local ``arrets-lignes.json``."""
    # TODO: Use HTTP
    async with async_open("./tests/datasets/arrets-lignes.json", "rb") as dataset:
        payload = await dataset.read()
        for association in self._json_stops_lines_assos_decoder.decode(payload):
            yield association
async def _request_idfm_stop_area_stop_associations(
    self,
) -> AsyncIterator[StopAreaStopAssociation]:
    """Yield the stop-area/stop associations from the local ``relations.json``."""
    # TODO: Use HTTP
    async with async_open("./tests/datasets/relations.json", "rb") as dataset:
        payload = await dataset.read()
        for association in self._json_stop_area_stop_asso_decoder.decode(payload):
            yield association
async def _request_ratp_pictos(self) -> AsyncIterator[RatpPicto]:
    """Yield the RATP picto records decoded from the local dataset file."""
    # TODO: Use HTTP
    async with async_open(
        "./tests/datasets/pictogrammes-des-lignes-de-metro-rer-tramway-bus-et-noctilien.json",
        "rb",
    ) as dataset:
        payload = await dataset.read()
        for ratp_picto in self._json_ratp_pictos_decoder.decode(payload):
            yield ratp_picto
@classmethod
def _format_idfm_pictos(cls, *pictos: IdfmPicto) -> Iterable[LinePicto]:
    """Convert IDFM picto records into ``LinePicto`` models.

    The download URL of each picto is built from ``IDFM_PICTO_URL`` and the
    picto id.

    Returns:
        A list with one ``LinePicto`` per input record, in input order.
    """
    return [
        LinePicto(
            id=idfm_picto.id_,
            mime_type=idfm_picto.mimetype,
            height_px=idfm_picto.height,
            width_px=idfm_picto.width,
            filename=idfm_picto.filename,
            url=f"{cls.IDFM_PICTO_URL}/{idfm_picto.id_}/download",
            thumbnail=idfm_picto.thumbnail,
            format=idfm_picto.format,
        )
        for idfm_picto in pictos
    ]
@classmethod
def _format_ratp_pictos(cls, *pictos: RatpPicto) -> Iterable[tuple[str, LinePicto]]:
    """Convert RATP picto records into ``(indices_commerciaux, LinePicto)`` pairs.

    Records without file information (``noms_des_fichiers`` is ``None``) are
    left out. The MIME type is derived from the lower-cased file format and
    the download URL from ``RATP_PICTO_URL`` and the file id.

    Returns:
        A list of pairs, in input order.
    """
    pairs = []
    for ratp_picto in pictos:
        file_info = ratp_picto.fields.noms_des_fichiers
        if file_info is None:
            continue
        line_picto = LinePicto(
            id=file_info.id_,
            mime_type=f"image/{file_info.format.lower()}",
            height_px=file_info.height,
            width_px=file_info.width,
            filename=file_info.filename,
            url=f"{cls.RATP_PICTO_URL}/{file_info.id_}/download",
            thumbnail=file_info.thumbnail,
            format=file_info.format,
        )
        pairs.append((ratp_picto.fields.indices_commerciaux, line_picto))
    return pairs
async def _format_idfm_lines(self, *lines: IdfmLine) -> Iterable[Line]:
    """Convert IDFM line records into ``Line`` models.

    When a record carries a picto, the matching ``LinePicto`` is looked up
    by id (one awaited ``LinePicto.get_by_id`` call per such line).

    Returns:
        A list with one ``Line`` per input record, in input order.
    """
    optional_value = IdfmLine.optional_value
    formatted = []
    for idfm_line in lines:
        fields = idfm_line.fields
        picto_id = None if fields.picto is None else fields.picto.id_
        # A falsy id (None, and also an empty one) skips the lookup.
        picto = (await LinePicto.get_by_id(picto_id)) if picto_id else None
        formatted.append(
            Line(
                id=fields.id_line,
                short_name=fields.shortname_line,
                name=fields.name_line,
                status=IdfmLineState(fields.status.value),
                transport_mode=TransportMode(fields.transportmode.value),
                transport_submode=optional_value(fields.transportsubmode),
                network_name=optional_value(fields.networkname),
                group_of_lines_id=optional_value(fields.id_groupoflines),
                group_of_lines_shortname=optional_value(
                    fields.shortname_groupoflines
                ),
                colour_web_hexa=fields.colourweb_hexa,
                text_colour_hexa=fields.textcolourprint_hexa,
                operator_id=optional_value(fields.operatorref),
                operator_name=optional_value(fields.operatorname),
                accessibility=IdfmState(fields.accessibility.value),
                visual_signs_available=IdfmState(
                    fields.visualsigns_available.value
                ),
                audible_signs_available=IdfmState(
                    fields.audiblesigns_available.value
                ),
                picto_id=picto_id,
                picto=picto,
                record_id=idfm_line.recordid,
                record_ts=int(idfm_line.record_timestamp.timestamp()),
            )
        )
    return formatted
@staticmethod
def _format_idfm_stops(*stops: IdfmStop) -> Iterable[Stop]:
    """Lazily convert IDFM stop records into ``Stop`` models.

    This is a generator: a ``Stop`` is built only when it is consumed.
    ``created_ts`` is ``None`` when reading the creation timestamp raises
    ``AttributeError``.
    """
    for idfm_stop in stops:
        attrs = idfm_stop.fields
        try:
            created_ts = int(attrs.arrcreated.timestamp())  # type: ignore
        except AttributeError:
            # presumably arrcreated is None for some records -- TODO confirm
            created_ts = None

        yield Stop(
            id=int(attrs.arrid),
            name=attrs.arrname,
            latitude=attrs.arrgeopoint.lat,
            longitude=attrs.arrgeopoint.lon,
            town_name=attrs.arrtown,
            postal_region=attrs.arrpostalregion,
            xepsg2154=attrs.arrxepsg2154,
            yepsg2154=attrs.arryepsg2154,
            transport_mode=TransportMode(attrs.arrtype.value),
            version=attrs.arrversion,
            created_ts=created_ts,
            changed_ts=int(attrs.arrchanged.timestamp()),
            accessibility=IdfmState(attrs.arraccessibility.value),
            visual_signs_available=IdfmState(attrs.arrvisualsigns.value),
            audible_signs_available=IdfmState(attrs.arraudiblesignals.value),
            record_id=idfm_stop.recordid,
            record_ts=int(idfm_stop.record_timestamp.timestamp()),
        )
@staticmethod
def _format_idfm_stop_areas(*stop_areas: IdfmStopArea) -> Iterable[StopArea]:
    """Lazily convert IDFM stop-area records into ``StopArea`` models.

    This is a generator: a ``StopArea`` is built only when it is consumed.
    ``created_ts`` is ``None`` when reading the creation timestamp raises
    ``AttributeError``.
    """
    for idfm_area in stop_areas:
        attrs = idfm_area.fields
        try:
            created_ts = int(attrs.zdacreated.timestamp())  # type: ignore
        except AttributeError:
            # presumably zdacreated is None for some records -- TODO confirm
            created_ts = None

        yield StopArea(
            id=int(attrs.zdaid),
            name=attrs.zdaname,
            town_name=attrs.zdatown,
            postal_region=attrs.zdapostalregion,
            xepsg2154=attrs.zdaxepsg2154,
            yepsg2154=attrs.zdayepsg2154,
            type=StopAreaType(attrs.zdatype.value),
            version=attrs.zdaversion,
            created_ts=created_ts,
            changed_ts=int(attrs.zdachanged.timestamp()),
        )
@staticmethod
def _format_idfm_connection_areas(
    *connection_areas: IdfmConnectionArea,
) -> Iterable[ConnectionArea]:
    """Lazily convert IDFM connection-area records into ``ConnectionArea`` models.

    This is a generator: a ``ConnectionArea`` is built only when it is
    consumed. The ``zdc*`` attributes are read directly on each record.
    """
    for area in connection_areas:
        yield ConnectionArea(
            id=int(area.zdcid),
            name=area.zdcname,
            town_name=area.zdctown,
            postal_region=area.zdcpostalregion,
            xepsg2154=area.zdcxepsg2154,
            yepsg2154=area.zdcyepsg2154,
            transport_mode=StopAreaType(area.zdctype.value),
            version=area.zdcversion,
            created_ts=int(area.zdccreated.timestamp()),
            changed_ts=int(area.zdcchanged.timestamp()),
        )
@staticmethod
def _format_idfm_stop_shapes(*shape_records: ShapeRecord) -> Iterable[StopShape]:
    """Lazily convert shapefile records into ``StopShape`` models.

    This is a generator: a ``StopShape`` is built only when it is consumed.
    """
    for entry in shape_records:
        # The identifier is the second attribute of the record
        # (NOTE(review): column meaning not visible here -- confirm).
        shape_id = entry.record[1]
        geometry = entry.shape
        yield StopShape(
            id=shape_id,
            type=geometry.shapeType,
            bounding_box=list(geometry.bbox),
            points=geometry.points,
        )
async def render_line_picto(self, line: Line) -> tuple[None | str, None | str]:
    """Download the picto of ``line`` and store it under ``/tmp``.

    Args:
        line: Line whose ``picto`` is to be rendered.

    Returns:
        ``(path, mime_type)`` of the written file, or ``(None, None)`` when
        the line has no picto or no picto data could be fetched.
    """
    started_at = time()
    rendered = (None, None)
    target = f"/tmp/{line.id}_repr"

    picto = line.picto
    picto_data = None if picto is None else await self._get_line_picto(line)
    if picto_data is not None:
        async with async_open(target, "wb") as output:
            await output.write(bytes(picto_data))
        rendered = (target, picto.mime_type)

    print(f"render_line_picto: {time() - started_at}")
    return rendered
async def _get_line_picto(self, line: Line) -> ByteString | None:
    """Fetch the raw picto image of ``line`` over HTTP.

    The API-key headers are only sent when the picto URL starts with
    ``IDFM_ROOT_URL``; other hosts get a header-less request. Timings of the
    session creation, the GET and the body read are printed.

    Args:
        line: Line whose ``picto.url`` is to be downloaded.

    Returns:
        The response body, or ``None`` when the line has no picto, the picto
        has no URL, or the server did not answer with HTTP 200.
    """
    print("---------------------------------------------------------------------")
    begin_ts = time()
    data = None
    picto = line.picto
    if picto is not None and picto.url is not None:
        headers = (
            self._http_headers if picto.url.startswith(self.IDFM_ROOT_URL) else None
        )
        session_begin_ts = time()
        async with ClientSession(headers=headers) as session:
            session_creation_ts = time()
            print(f"Session creation {session_creation_ts - session_begin_ts}")
            async with session.get(picto.url) as response:
                get_end_ts = time()
                print(f"GET {get_end_ts - session_creation_ts}")
                # Same success criterion as get_next_passages: without this
                # check the body of an error page would be returned (and then
                # stored) as if it were the image.
                if response.status == 200:
                    data = await response.read()
                    print(f"read {time() - get_end_ts}")
                else:
                    print(f"Unable to get {picto.url}: HTTP {response.status}")
    # This timing used to be reported under the caller's name.
    print(f"_get_line_picto: {time() - begin_ts}")
    print("---------------------------------------------------------------------")
    return data
async def get_next_passages(self, stop_point_id: str) -> IdfmResponse | None:
    """Query the stop-monitoring endpoint for one stop point.

    Args:
        stop_point_id: Identifier inserted in the ``MonitoringRef`` query
            parameter (``STIF:StopPoint:Q:<id>:``).

    Returns:
        The decoded response, or ``None`` when the HTTP status is not 200 or
        the body fails validation (the validation error is printed).
    """
    query = {"MonitoringRef": f"STIF:StopPoint:Q:{stop_point_id}:"}
    async with ClientSession(headers=self._http_headers) as session:
        async with session.get(self.IDFM_STOP_MON_URL, params=query) as response:
            if response.status != 200:
                return None
            payload = await response.read()
            try:
                return self._response_json_decoder.decode(payload)
            except ValidationError as err:
                print(err)
                return None
async def get_destinations(self, stop_id: int) -> IdfmDestinations | None:
    """Collect the destination names currently served at a stop, per line.

    ``stop_id`` is first looked up as a ``Stop``, then as a ``StopArea``; for
    a stop area every stop it contains is accepted. Monitored visits whose
    stop is not one of the accepted stops, or that carry no destination
    name, are ignored.

    Args:
        stop_id: Identifier of a stop or of a stop area.

    Returns:
        A mapping ``line id -> set of destination names`` (possibly empty),
        or ``None`` when ``stop_id`` matches neither a stop nor a stop area.
    """
    begin_ts = time()
    destinations: IdfmDestinations = defaultdict(set)

    if (stop := await Stop.get_by_id(stop_id)) is not None:
        expected_stop_ids = {stop.id}
    elif (stop_area := await StopArea.get_by_id(stop_id)) is not None:
        expected_stop_ids = {stop.id for stop in stop_area.stops}
    else:
        return None

    # get_next_passages is annotated with a str identifier.
    if (res := await self.get_next_passages(str(stop_id))) is not None:
        for delivery in res.Siri.ServiceDelivery.StopMonitoringDelivery:
            if delivery.Status != IdfmState.true:
                continue
            for stop_visit in delivery.MonitoredStopVisit:
                monitoring_ref = stop_visit.MonitoringRef.value
                try:
                    # The id is the last but one ":"-separated segment.
                    monitored_stop_id = int(monitoring_ref.split(":")[-2])
                except (IndexError, ValueError):
                    print(f"Unable to get stop id from {monitoring_ref}")
                    continue

                journey = stop_visit.MonitoredVehicleJourney
                dst_names = journey.DestinationName
                if not dst_names or monitored_stop_id not in expected_stop_ids:
                    continue

                line_ref = journey.LineRef.value
                try:
                    line_id = line_ref.split(":")[-2]
                except IndexError:
                    # Guarded like the stop id above: one malformed
                    # reference must not abort the whole request.
                    print(f"Unable to get line id from {line_ref}")
                    continue
                destinations[line_id].add(dst_names[0].value)

    # This timing used to be reported under the name of get_next_passages.
    print(f"get_destinations: {time() - begin_ts}")
    return destinations