diff --git a/backend/backend/idfm_interface/idfm_interface.py b/backend/backend/idfm_interface/idfm_interface.py index 82b4316..b6bf4b7 100644 --- a/backend/backend/idfm_interface/idfm_interface.py +++ b/backend/backend/idfm_interface/idfm_interface.py @@ -1,44 +1,15 @@ from collections import defaultdict -from logging import getLogger from re import compile as re_compile -from time import time -from typing import ( - AsyncIterator, - ByteString, - Callable, - Iterable, - List, - Type, -) +from typing import ByteString from aiofiles import open as async_open from aiohttp import ClientSession from msgspec import ValidationError from msgspec.json import Decoder -from pyproj import Transformer -from shapefile import Reader as ShapeFileReader, ShapeRecord # type: ignore from ..db import Database -from ..models import ConnectionArea, Line, LinePicto, Stop, StopArea, StopShape -from .idfm_types import ( - ConnectionArea as IdfmConnectionArea, - Destinations as IdfmDestinations, - IdfmLineState, - IdfmResponse, - Line as IdfmLine, - LinePicto as IdfmPicto, - IdfmState, - Stop as IdfmStop, - StopArea as IdfmStopArea, - StopAreaStopAssociation, - StopAreaType, - StopLineAsso as IdfmStopLineAsso, - TransportMode, -) -from .ratp_types import Picto as RatpPicto - - -logger = getLogger(__name__) +from ..models import Line, Stop, StopArea +from .idfm_types import Destinations as IdfmDestinations, IdfmResponse, IdfmState class IdfmInterface: @@ -46,18 +17,6 @@ class IdfmInterface: IDFM_ROOT_URL = "https://prim.iledefrance-mobilites.fr/marketplace" IDFM_STOP_MON_URL = f"{IDFM_ROOT_URL}/stop-monitoring" - IDFM_ROOT_URL = "https://data.iledefrance-mobilites.fr/explore/dataset" - IDFM_STOPS_URL = ( - f"{IDFM_ROOT_URL}/arrets/download/?format=json&timezone=Europe/Berlin" - ) - IDFM_PICTO_URL = f"{IDFM_ROOT_URL}/referentiel-des-lignes/files" - - RATP_ROOT_URL = "https://data.ratp.fr/explore/dataset" - RATP_PICTO_URL = ( - f"{RATP_ROOT_URL}" - "/pictogrammes-des-lignes-de-metro-rer-tramway-bus-et-noctilien/files" - ) - OPERATOR_RE = re_compile(r"[^:]+:Operator::([^:]+):") LINE_RE = re_compile(r"[^:]+:Line::C([^:]+):") @@ -67,459 +26,12 @@ class IdfmInterface: self._http_headers = {"Accept": "application/json", "apikey": self._api_key} - self._epsg2154_epsg3857_transformer = Transformer.from_crs(2154, 3857) - - self._json_stops_decoder = Decoder(type=List[IdfmStop]) - self._json_stop_areas_decoder = Decoder(type=List[IdfmStopArea]) - self._json_connection_areas_decoder = Decoder(type=List[IdfmConnectionArea]) - self._json_lines_decoder = Decoder(type=List[IdfmLine]) - self._json_stops_lines_assos_decoder = Decoder(type=List[IdfmStopLineAsso]) - self._json_ratp_pictos_decoder = Decoder(type=List[RatpPicto]) - self._json_stop_area_stop_asso_decoder = Decoder( - type=List[StopAreaStopAssociation] - ) - self._response_json_decoder = Decoder(type=IdfmResponse) async def startup(self) -> None: - BATCH_SIZE = 10000 - STEPS: tuple[ - tuple[ - Type[ConnectionArea] | Type[Stop] | Type[StopArea] | Type[StopShape], - Callable, - Callable, - ], - ..., - ] = ( - ( - StopShape, - self._request_stop_shapes, - self._format_idfm_stop_shapes, - ), - ( - ConnectionArea, - self._request_idfm_connection_areas, - self._format_idfm_connection_areas, - ), - ( - StopArea, - self._request_idfm_stop_areas, - self._format_idfm_stop_areas, - ), - (Stop, self._request_idfm_stops, self._format_idfm_stops), - ) - - for model, get_method, format_method in STEPS: - step_begin_ts = time() - elements = [] - - async for element in get_method(): - elements.append(element) - - if len(elements) == BATCH_SIZE: - await model.add(format_method(*elements)) - elements.clear() - - if elements: - await model.add(format_method(*elements)) - - print(f"Add {model.__name__}s: {time() - step_begin_ts}s") - - begin_ts = time() - await self._load_lines() - print(f"Add Lines and IDFM LinePictos: {time() - begin_ts}s") - - begin_ts = time() - await self._load_ratp_pictos(30) - print(f"Add RATP LinePictos: {time() - begin_ts}s") - - begin_ts = time() - await self._load_lines_stops_assos() - print(f"Link Stops to Lines: {time() - begin_ts}s") - - begin_ts = time() - await self._load_stop_assos() - print(f"Link Stops to StopAreas: {time() - begin_ts}s") - - async def _load_lines(self, batch_size: int = 5000) -> None: - lines, pictos = [], [] - picto_ids = set() - async for line in self._request_idfm_lines(): - if (picto := line.fields.picto) is not None and picto.id_ not in picto_ids: - picto_ids.add(picto.id_) - pictos.append(picto) - - lines.append(line) - if len(lines) == batch_size: - await LinePicto.add(IdfmInterface._format_idfm_pictos(*pictos)) - await Line.add(await self._format_idfm_lines(*lines)) - lines.clear() - pictos.clear() - - if pictos: - await LinePicto.add(IdfmInterface._format_idfm_pictos(*pictos)) - if lines: - await Line.add(await self._format_idfm_lines(*lines)) - - async def _load_ratp_pictos(self, batch_size: int = 5) -> None: - pictos = [] - - async for picto in self._request_ratp_pictos(): - pictos.append(picto) - if len(pictos) == batch_size: - formatted_pictos = IdfmInterface._format_ratp_pictos(*pictos) - await LinePicto.add(map(lambda picto: picto[1], formatted_pictos)) - await Line.add_pictos(formatted_pictos) - pictos.clear() - - if pictos: - formatted_pictos = IdfmInterface._format_ratp_pictos(*pictos) - await LinePicto.add(map(lambda picto: picto[1], formatted_pictos)) - await Line.add_pictos(formatted_pictos) - - async def _load_lines_stops_assos(self, batch_size: int = 5000) -> None: - total_assos_nb = total_found_nb = 0 - assos = [] - async for asso in self._request_idfm_stops_lines_associations(): - fields = asso.fields - try: - stop_id = int(fields.stop_id.rsplit(":", 1)[-1]) - except ValueError as err: - print(err) - print(f"{fields.stop_id = }") - continue - - assos.append((fields.route_long_name, fields.operatorname, stop_id)) - if len(assos) == batch_size: - total_assos_nb += batch_size - total_found_nb += await Line.add_stops(assos) - assos.clear() - - if assos: - total_assos_nb += len(assos) - total_found_nb += await Line.add_stops(assos) - - print(f"{total_found_nb} line <-> stop ({total_assos_nb = } found)") - - async def _load_stop_assos(self, batch_size: int = 5000) -> None: - total_assos_nb = area_stop_assos_nb = conn_stop_assos_nb = 0 - area_stop_assos = [] - connection_stop_assos = [] - - async for asso in self._request_idfm_stop_area_stop_associations(): - fields = asso.fields - - stop_id = int(fields.arrid) - - area_stop_assos.append((int(fields.zdaid), stop_id)) - connection_stop_assos.append((int(fields.zdcid), stop_id)) - - if len(area_stop_assos) == batch_size: - total_assos_nb += batch_size - - if (found_nb := await StopArea.add_stops(area_stop_assos)) is not None: - area_stop_assos_nb += found_nb - area_stop_assos.clear() - - if ( - found_nb := await ConnectionArea.add_stops(connection_stop_assos) - ) is not None: - conn_stop_assos_nb += found_nb - connection_stop_assos.clear() - - if area_stop_assos: - total_assos_nb += len(area_stop_assos) - if (found_nb := await StopArea.add_stops(area_stop_assos)) is not None: - area_stop_assos_nb += found_nb - if ( - found_nb := await ConnectionArea.add_stops(connection_stop_assos) - ) is not None: - conn_stop_assos_nb += found_nb - - print(f"{area_stop_assos_nb} stop area <-> stop ({total_assos_nb = } found)") - print(f"{conn_stop_assos_nb} stop area <-> stop ({total_assos_nb = } found)") - - # TODO: This method is synchronous due to the shapefile library. - # It's not a blocking issue but it could be nice to find an alternative. - async def _request_stop_shapes(self) -> AsyncIterator[ShapeRecord]: - # TODO: Use HTTP - with ShapeFileReader("./tests/datasets/REF_LDA.zip") as reader: - for record in reader.shapeRecords(): - yield record - - async def _request_idfm_stops(self) -> AsyncIterator[IdfmStop]: - # headers = {"Accept": "application/json", "apikey": self._api_key} - # async with ClientSession(headers=headers) as session: - # async with session.get(self.STOPS_URL) as response: - # # print("Status:", response.status) - # if response.status == 200: - # for point in self._json_stops_decoder.decode(await response.read()): - # yield point - # TODO: Use HTTP - async with async_open("./tests/datasets/stops_dataset.json", "rb") as raw: - for element in self._json_stops_decoder.decode(await raw.read()): - yield element - - async def _request_idfm_stop_areas(self) -> AsyncIterator[IdfmStopArea]: - # TODO: Use HTTP - async with async_open("./tests/datasets/zones-d-arrets.json", "rb") as raw: - for element in self._json_stop_areas_decoder.decode(await raw.read()): - yield element - - async def _request_idfm_connection_areas(self) -> AsyncIterator[IdfmConnectionArea]: - async with async_open( - "./tests/datasets/zones-de-correspondance.json", "rb" - ) as raw: - for element in self._json_connection_areas_decoder.decode(await raw.read()): - yield element - - async def _request_idfm_lines(self) -> AsyncIterator[IdfmLine]: - # TODO: Use HTTP - async with async_open("./tests/datasets/lines_dataset.json", "rb") as raw: - for element in self._json_lines_decoder.decode(await raw.read()): - yield element - - async def _request_idfm_stops_lines_associations( - self, - ) -> AsyncIterator[IdfmStopLineAsso]: - # TODO: Use HTTP - async with async_open("./tests/datasets/arrets-lignes.json", "rb") as raw: - for element in self._json_stops_lines_assos_decoder.decode( - await raw.read() - ): - yield element - - async def _request_idfm_stop_area_stop_associations( - self, - ) -> AsyncIterator[StopAreaStopAssociation]: - # TODO: Use HTTP - async with async_open("./tests/datasets/relations.json", "rb") as raw: - for element in self._json_stop_area_stop_asso_decoder.decode( - await raw.read() - ): - yield element - - async def _request_ratp_pictos(self) -> AsyncIterator[RatpPicto]: - # TODO: Use HTTP - async with async_open( - "./tests/datasets/pictogrammes-des-lignes-de-metro-rer-tramway-bus-et-noctilien.json", - "rb", - ) as fd: - for element in self._json_ratp_pictos_decoder.decode(await fd.read()): - yield element - - @classmethod - def _format_idfm_pictos(cls, *pictos: IdfmPicto) -> Iterable[LinePicto]: - ret = [] - - for picto in pictos: - ret.append( - LinePicto( - id=picto.id_, - mime_type=picto.mimetype, - height_px=picto.height, - width_px=picto.width, - filename=picto.filename, - url=f"{cls.IDFM_PICTO_URL}/{picto.id_}/download", - thumbnail=picto.thumbnail, - format=picto.format, - ) - ) - - return ret - - @classmethod - def _format_ratp_pictos(cls, *pictos: RatpPicto) -> Iterable[tuple[str, LinePicto]]: - ret = [] - - for picto in pictos: - if (fields := picto.fields.noms_des_fichiers) is not None: - ret.append( - ( - picto.fields.indices_commerciaux, - LinePicto( - id=fields.id_, - mime_type=f"image/{fields.format.lower()}", - height_px=fields.height, - width_px=fields.width, - filename=fields.filename, - url=f"{cls.RATP_PICTO_URL}/{fields.id_}/download", - thumbnail=fields.thumbnail, - format=fields.format, - ), - ) - ) - - return ret - - async def _format_idfm_lines(self, *lines: IdfmLine) -> Iterable[Line]: - ret = [] - - optional_value = IdfmLine.optional_value - - for line in lines: - fields = line.fields - - picto_id = fields.picto.id_ if fields.picto is not None else None - - line_id = fields.id_line - try: - formatted_line_id = int(line_id[1:] if line_id[0] == "C" else line_id) - except ValueError: - logger.warning("Unable to format %s line id.", line_id) - continue - - try: - operator_id = int(fields.operatorref) # type: ignore - except (ValueError, TypeError): - logger.warning("Unable to format %s operator id.", fields.operatorref) - operator_id = 0 - - ret.append( - Line( - id=formatted_line_id, - short_name=fields.shortname_line, - name=fields.name_line, - status=IdfmLineState(fields.status.value), - transport_mode=TransportMode(fields.transportmode.value), - transport_submode=optional_value(fields.transportsubmode), - network_name=optional_value(fields.networkname), - group_of_lines_id=optional_value(fields.id_groupoflines), - group_of_lines_shortname=optional_value( - fields.shortname_groupoflines - ), - colour_web_hexa=fields.colourweb_hexa, - text_colour_hexa=fields.textcolourprint_hexa, - operator_id=operator_id, - operator_name=optional_value(fields.operatorname), - accessibility=IdfmState(fields.accessibility.value), - visual_signs_available=IdfmState( - fields.visualsigns_available.value - ), - audible_signs_available=IdfmState( - fields.audiblesigns_available.value - ), - picto_id=fields.picto.id_ if fields.picto is not None else None, - record_id=line.recordid, - record_ts=int(line.record_timestamp.timestamp()), - ) - ) - - return ret - - def _format_idfm_stops(self, *stops: IdfmStop) -> Iterable[Stop]: - for stop in stops: - fields = stop.fields - - try: - created_ts = int(fields.arrcreated.timestamp()) # type: ignore - except AttributeError: - created_ts = None - - epsg3857_point = self._epsg2154_epsg3857_transformer.transform( - fields.arrxepsg2154, fields.arryepsg2154 - ) - - try: - postal_region = int(fields.arrpostalregion) - except ValueError: - logger.warning( - "Unable to format %s postal region.", fields.arrpostalregion - ) - continue - - yield Stop( - id=int(fields.arrid), - name=fields.arrname, - epsg3857_x=epsg3857_point[0], - epsg3857_y=epsg3857_point[1], - town_name=fields.arrtown, - postal_region=postal_region, - transport_mode=TransportMode(fields.arrtype.value), - version=fields.arrversion, - created_ts=created_ts, - changed_ts=int(fields.arrchanged.timestamp()), - accessibility=IdfmState(fields.arraccessibility.value), - visual_signs_available=IdfmState(fields.arrvisualsigns.value), - audible_signs_available=IdfmState(fields.arraudiblesignals.value), - record_id=stop.recordid, - record_ts=int(stop.record_timestamp.timestamp()), - ) - - def _format_idfm_stop_areas(self, *stop_areas: IdfmStopArea) -> Iterable[StopArea]: - for stop_area in stop_areas: - fields = stop_area.fields - - try: - created_ts = int(fields.zdacreated.timestamp()) # type: ignore - except AttributeError: - created_ts = None - - epsg3857_point = self._epsg2154_epsg3857_transformer.transform( - fields.zdaxepsg2154, fields.zdayepsg2154 - ) - - yield StopArea( - id=int(fields.zdaid), - name=fields.zdaname, - town_name=fields.zdatown, - postal_region=fields.zdapostalregion, - epsg3857_x=epsg3857_point[0], - epsg3857_y=epsg3857_point[1], - type=StopAreaType(fields.zdatype.value), - version=fields.zdaversion, - created_ts=created_ts, - changed_ts=int(fields.zdachanged.timestamp()), - ) - - def _format_idfm_connection_areas( - self, - *connection_areas: IdfmConnectionArea, - ) -> Iterable[ConnectionArea]: - for connection_area in connection_areas: - - epsg3857_point = self._epsg2154_epsg3857_transformer.transform( - connection_area.zdcxepsg2154, connection_area.zdcyepsg2154 - ) - - yield ConnectionArea( - id=int(connection_area.zdcid), - name=connection_area.zdcname, - town_name=connection_area.zdctown, - postal_region=connection_area.zdcpostalregion, - epsg3857_x=epsg3857_point[0], - epsg3857_y=epsg3857_point[1], - transport_mode=StopAreaType(connection_area.zdctype.value), - version=connection_area.zdcversion, - created_ts=int(connection_area.zdccreated.timestamp()), - changed_ts=int(connection_area.zdcchanged.timestamp()), - ) - - def _format_idfm_stop_shapes( - self, *shape_records: ShapeRecord - ) -> Iterable[StopShape]: - for shape_record in shape_records: - - epsg3857_points = [ - self._epsg2154_epsg3857_transformer.transform(*point) - for point in shape_record.shape.points - ] - - bbox_it = iter(shape_record.shape.bbox) - epsg3857_bbox = [ - self._epsg2154_epsg3857_transformer.transform(*point) - for point in zip(bbox_it, bbox_it) - ] - - yield StopShape( - id=shape_record.record[1], - type=shape_record.shape.shapeType, - epsg3857_bbox=epsg3857_bbox, - epsg3857_points=epsg3857_points, - ) + ... async def render_line_picto(self, line: Line) -> tuple[None | str, None | str]: - begin_ts = time() line_picto_path = line_picto_format = None target = f"/tmp/{line.id}_repr" @@ -531,12 +43,9 @@ class IdfmInterface: line_picto_path = target line_picto_format = picto.mime_type - print(f"render_line_picto: {time() - begin_ts}") return (line_picto_path, line_picto_format) async def _get_line_picto(self, line: Line) -> ByteString | None: - print("---------------------------------------------------------------------") - begin_ts = time() data = None picto = line.picto @@ -544,25 +53,20 @@ class IdfmInterface: headers = ( self._http_headers if picto.url.startswith(self.IDFM_ROOT_URL) else None ) - session_begin_ts = time() - async with ClientSession(headers=headers) as session: - session_creation_ts = time() - print(f"Session creation {session_creation_ts - session_begin_ts}") - async with session.get(picto.url) as response: - get_end_ts = time() - print(f"GET {get_end_ts - session_creation_ts}") - data = await response.read() - print(f"read {time() - get_end_ts}") - print(f"render_line_picto: {time() - begin_ts}") - print("---------------------------------------------------------------------") + async with ClientSession(headers=headers) as session: + async with session.get(picto.url) as response: + data = await response.read() + return data async def get_next_passages(self, stop_point_id: int) -> IdfmResponse | None: ret = None params = {"MonitoringRef": f"STIF:StopPoint:Q:{stop_point_id}:"} + async with ClientSession(headers=self._http_headers) as session: async with session.get(self.IDFM_STOP_MON_URL, params=params) as response: + if response.status == 200: data = await response.read() try: @@ -573,8 +77,6 @@ class IdfmInterface: return ret async def get_destinations(self, stop_id: int) -> IdfmDestinations | None: - begin_ts = time() - destinations: IdfmDestinations = defaultdict(set) if (stop := await Stop.get_by_id(stop_id)) is not None: @@ -582,7 +84,6 @@ class IdfmInterface: elif (stop_area := await StopArea.get_by_id(stop_id)) is not None: expected_stop_ids = {stop.id for stop in stop_area.stops} - else: return None @@ -593,6 +94,7 @@ class IdfmInterface: for stop_visit in delivery.MonitoredStopVisit: monitoring_ref = stop_visit.MonitoringRef.value + try: monitored_stop_id = int(monitoring_ref.split(":")[-2]) except (IndexError, ValueError): @@ -603,9 +105,7 @@ class IdfmInterface: if ( dst_names := journey.DestinationName ) and monitored_stop_id in expected_stop_ids: - line_id = journey.LineRef.value.split(":")[-2] destinations[line_id].add(dst_names[0].value) - print(f"get_next_passages: {time() - begin_ts}") return destinations