From ecfb3c8cb38599a946e8d6bce1a93c494b19935e Mon Sep 17 00:00:00 2001 From: Adrien Date: Thu, 13 Apr 2023 20:57:15 +0200 Subject: [PATCH] =?UTF-8?q?=E2=9C=A8=20Handle=20IDFM=20connection=20areas?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../backend/idfm_interface/idfm_interface.py | 90 +++++++++++++++---- backend/backend/idfm_interface/idfm_types.py | 13 +++ 2 files changed, 87 insertions(+), 16 deletions(-) diff --git a/backend/backend/idfm_interface/idfm_interface.py b/backend/backend/idfm_interface/idfm_interface.py index faf2769..5d16edf 100644 --- a/backend/backend/idfm_interface/idfm_interface.py +++ b/backend/backend/idfm_interface/idfm_interface.py @@ -1,6 +1,14 @@ +from collections import defaultdict from re import compile as re_compile from time import time -from typing import AsyncIterator, ByteString, Callable, Iterable, List, Type +from typing import ( + AsyncIterator, + ByteString, + Callable, + Iterable, + List, + Type, +) from aiofiles import open as async_open from aiohttp import ClientSession @@ -10,8 +18,9 @@ from msgspec.json import Decoder from rich import print from ..db import Database -from ..models import Line, LinePicto, Stop, StopArea +from ..models import ConnectionArea, Line, LinePicto, Stop, StopArea, StopShape from .idfm_types import ( + ConnectionArea as IdfmConnectionArea, IdfmLineState, IdfmResponse, Line as IdfmLine, @@ -55,6 +64,7 @@ class IdfmInterface: self._json_stops_decoder = Decoder(type=List[IdfmStop]) self._json_stop_areas_decoder = Decoder(type=List[IdfmStopArea]) + self._json_connection_areas_decoder = Decoder(type=List[IdfmConnectionArea]) self._json_lines_decoder = Decoder(type=List[IdfmLine]) self._json_stops_lines_assos_decoder = Decoder(type=List[IdfmStopLineAsso]) self._json_ratp_pictos_decoder = Decoder(type=List[RatpPicto]) @@ -67,6 +77,11 @@ class IdfmInterface: async def startup(self) -> None: BATCH_SIZE = 10000 STEPS: tuple[tuple[Type[Stop] | Type[StopArea], Callable, Callable], ...] = ( + ( + ConnectionArea, + self._request_idfm_connection_areas, + IdfmInterface._format_idfm_connection_areas, + ), ( StopArea, self._request_idfm_stop_areas, @@ -104,7 +119,7 @@ class IdfmInterface: print(f"Link Stops to Lines: {time() - begin_ts}s") begin_ts = time() - await self._load_stop_areas_stops_assos() + await self._load_stop_assos() print(f"Link Stops to StopAreas: {time() - begin_ts}s") async def _load_lines(self, batch_size: int = 5000) -> None: @@ -167,25 +182,43 @@ class IdfmInterface: print(f"{total_found_nb} line <-> stop ({total_assos_nb = } found)") - async def _load_stop_areas_stops_assos(self, batch_size: int = 5000) -> None: - total_assos_nb = total_found_nb = 0 - assos = [] + async def _load_stop_assos(self, batch_size: int = 5000) -> None: + total_assos_nb = area_stop_assos_nb = conn_stop_assos_nb = 0 + area_stop_assos = [] + connection_stop_assos = [] + async for asso in self._request_idfm_stop_area_stop_associations(): fields = asso.fields - assos.append((int(fields.zdaid), int(fields.arrid))) - if len(assos) == batch_size: + stop_id = int(fields.arrid) + + area_stop_assos.append((int(fields.zdaid), stop_id)) + connection_stop_assos.append((int(fields.zdcid), stop_id)) + + if len(area_stop_assos) == batch_size: total_assos_nb += batch_size - if (found_nb := await StopArea.add_stops(assos)) is not None: - total_found_nb += found_nb - assos.clear() - if assos: - total_assos_nb += len(assos) - if (found_nb := await StopArea.add_stops(assos)) is not None: - total_found_nb += found_nb + if (found_nb := await StopArea.add_stops(area_stop_assos)) is not None: + area_stop_assos_nb += found_nb + area_stop_assos.clear() - print(f"{total_found_nb} stop area <-> stop ({total_assos_nb = } found)") + if ( + found_nb := await ConnectionArea.add_stops(connection_stop_assos) + ) is not None: + conn_stop_assos_nb += found_nb + connection_stop_assos.clear() + + if area_stop_assos: + total_assos_nb += len(area_stop_assos) + if (found_nb := await StopArea.add_stops(area_stop_assos)) is not None: + area_stop_assos_nb += found_nb + if ( + found_nb := await ConnectionArea.add_stops(connection_stop_assos) + ) is not None: + conn_stop_assos_nb += found_nb + + print(f"{area_stop_assos_nb} stop area <-> stop ({total_assos_nb = } found)") + print(f"{conn_stop_assos_nb} stop area <-> stop ({total_assos_nb = } found)") async def _request_idfm_stops(self) -> AsyncIterator[IdfmStop]: # headers = {"Accept": "application/json", "apikey": self._api_key} @@ -206,6 +239,13 @@ class IdfmInterface: for element in self._json_stop_areas_decoder.decode(await raw.read()): yield element + async def _request_idfm_connection_areas(self) -> AsyncIterator[IdfmConnectionArea]: + async with async_open( + "./tests/datasets/zones-de-correspondance.json", "rb" + ) as raw: + for element in self._json_connection_areas_decoder.decode(await raw.read()): + yield element + async def _request_idfm_lines(self) -> AsyncIterator[IdfmLine]: # TODO: Use HTTP async with async_open("./tests/datasets/lines_dataset.json", "rb") as raw: @@ -378,6 +418,24 @@ class IdfmInterface: changed_ts=int(fields.zdachanged.timestamp()), ) + @staticmethod + def _format_idfm_connection_areas( + *connection_areas: IdfmConnectionArea, + ) -> Iterable[ConnectionArea]: + for connection_area in connection_areas: + yield ConnectionArea( + id=int(connection_area.zdcid), + name=connection_area.zdcname, + town_name=connection_area.zdctown, + postal_region=connection_area.zdcpostalregion, + xepsg2154=connection_area.zdcxepsg2154, + yepsg2154=connection_area.zdcyepsg2154, + transport_mode=StopAreaType(connection_area.zdctype.value), + version=connection_area.zdcversion, + created_ts=int(connection_area.zdccreated.timestamp()), + changed_ts=int(connection_area.zdcchanged.timestamp()), + ) + async def render_line_picto(self, line: Line) -> tuple[None | str, None | str]: begin_ts = time() line_picto_path = line_picto_format = None diff --git a/backend/backend/idfm_interface/idfm_types.py b/backend/backend/idfm_interface/idfm_types.py index 82b9cbb..efb2c71 100644 --- a/backend/backend/idfm_interface/idfm_types.py +++ b/backend/backend/idfm_interface/idfm_types.py @@ -116,6 +116,19 @@ class StopArea(Struct): record_timestamp: datetime +class ConnectionArea(Struct): + zdcid: str + zdcversion: str + zdccreated: datetime + zdcchanged: datetime + zdcname: str + zdcxepsg2154: int + zdcyepsg2154: int + zdctown: str + zdcpostalregion: str + zdctype: StopAreaType + + class StopAreaStopAssociationFields(Struct, kw_only=True): arrid: str # TODO: use int ? artid: str | None = None