Handle IDFM connection areas

This commit is contained in:
2023-04-13 20:57:15 +02:00
parent 293a1391bc
commit ecfb3c8cb3
2 changed files with 87 additions and 16 deletions

View File

@@ -1,6 +1,14 @@
from collections import defaultdict
from re import compile as re_compile from re import compile as re_compile
from time import time from time import time
from typing import AsyncIterator, ByteString, Callable, Iterable, List, Type from typing import (
AsyncIterator,
ByteString,
Callable,
Iterable,
List,
Type,
)
from aiofiles import open as async_open from aiofiles import open as async_open
from aiohttp import ClientSession from aiohttp import ClientSession
@@ -10,8 +18,9 @@ from msgspec.json import Decoder
from rich import print from rich import print
from ..db import Database from ..db import Database
from ..models import Line, LinePicto, Stop, StopArea from ..models import ConnectionArea, Line, LinePicto, Stop, StopArea, StopShape
from .idfm_types import ( from .idfm_types import (
ConnectionArea as IdfmConnectionArea,
IdfmLineState, IdfmLineState,
IdfmResponse, IdfmResponse,
Line as IdfmLine, Line as IdfmLine,
@@ -55,6 +64,7 @@ class IdfmInterface:
self._json_stops_decoder = Decoder(type=List[IdfmStop]) self._json_stops_decoder = Decoder(type=List[IdfmStop])
self._json_stop_areas_decoder = Decoder(type=List[IdfmStopArea]) self._json_stop_areas_decoder = Decoder(type=List[IdfmStopArea])
self._json_connection_areas_decoder = Decoder(type=List[IdfmConnectionArea])
self._json_lines_decoder = Decoder(type=List[IdfmLine]) self._json_lines_decoder = Decoder(type=List[IdfmLine])
self._json_stops_lines_assos_decoder = Decoder(type=List[IdfmStopLineAsso]) self._json_stops_lines_assos_decoder = Decoder(type=List[IdfmStopLineAsso])
self._json_ratp_pictos_decoder = Decoder(type=List[RatpPicto]) self._json_ratp_pictos_decoder = Decoder(type=List[RatpPicto])
@@ -67,6 +77,11 @@ class IdfmInterface:
async def startup(self) -> None: async def startup(self) -> None:
BATCH_SIZE = 10000 BATCH_SIZE = 10000
STEPS: tuple[tuple[Type[Stop] | Type[StopArea], Callable, Callable], ...] = ( STEPS: tuple[tuple[Type[Stop] | Type[StopArea], Callable, Callable], ...] = (
(
ConnectionArea,
self._request_idfm_connection_areas,
IdfmInterface._format_idfm_connection_areas,
),
( (
StopArea, StopArea,
self._request_idfm_stop_areas, self._request_idfm_stop_areas,
@@ -104,7 +119,7 @@ class IdfmInterface:
print(f"Link Stops to Lines: {time() - begin_ts}s") print(f"Link Stops to Lines: {time() - begin_ts}s")
begin_ts = time() begin_ts = time()
await self._load_stop_areas_stops_assos() await self._load_stop_assos()
print(f"Link Stops to StopAreas: {time() - begin_ts}s") print(f"Link Stops to StopAreas: {time() - begin_ts}s")
async def _load_lines(self, batch_size: int = 5000) -> None: async def _load_lines(self, batch_size: int = 5000) -> None:
@@ -167,25 +182,43 @@ class IdfmInterface:
print(f"{total_found_nb} line <-> stop ({total_assos_nb = } found)") print(f"{total_found_nb} line <-> stop ({total_assos_nb = } found)")
async def _load_stop_areas_stops_assos(self, batch_size: int = 5000) -> None: async def _load_stop_assos(self, batch_size: int = 5000) -> None:
total_assos_nb = total_found_nb = 0 total_assos_nb = area_stop_assos_nb = conn_stop_assos_nb = 0
assos = [] area_stop_assos = []
connection_stop_assos = []
async for asso in self._request_idfm_stop_area_stop_associations(): async for asso in self._request_idfm_stop_area_stop_associations():
fields = asso.fields fields = asso.fields
assos.append((int(fields.zdaid), int(fields.arrid))) stop_id = int(fields.arrid)
if len(assos) == batch_size:
area_stop_assos.append((int(fields.zdaid), stop_id))
connection_stop_assos.append((int(fields.zdcid), stop_id))
if len(area_stop_assos) == batch_size:
total_assos_nb += batch_size total_assos_nb += batch_size
if (found_nb := await StopArea.add_stops(assos)) is not None:
total_found_nb += found_nb
assos.clear()
if assos: if (found_nb := await StopArea.add_stops(area_stop_assos)) is not None:
total_assos_nb += len(assos) area_stop_assos_nb += found_nb
if (found_nb := await StopArea.add_stops(assos)) is not None: area_stop_assos.clear()
total_found_nb += found_nb
print(f"{total_found_nb} stop area <-> stop ({total_assos_nb = } found)") if (
found_nb := await ConnectionArea.add_stops(connection_stop_assos)
) is not None:
conn_stop_assos_nb += found_nb
connection_stop_assos.clear()
if area_stop_assos:
total_assos_nb += len(area_stop_assos)
if (found_nb := await StopArea.add_stops(area_stop_assos)) is not None:
area_stop_assos_nb += found_nb
if (
found_nb := await ConnectionArea.add_stops(connection_stop_assos)
) is not None:
conn_stop_assos_nb += found_nb
print(f"{area_stop_assos_nb} stop area <-> stop ({total_assos_nb = } found)")
print(f"{conn_stop_assos_nb} stop area <-> stop ({total_assos_nb = } found)")
async def _request_idfm_stops(self) -> AsyncIterator[IdfmStop]: async def _request_idfm_stops(self) -> AsyncIterator[IdfmStop]:
# headers = {"Accept": "application/json", "apikey": self._api_key} # headers = {"Accept": "application/json", "apikey": self._api_key}
@@ -206,6 +239,13 @@ class IdfmInterface:
for element in self._json_stop_areas_decoder.decode(await raw.read()): for element in self._json_stop_areas_decoder.decode(await raw.read()):
yield element yield element
async def _request_idfm_connection_areas(self) -> AsyncIterator[IdfmConnectionArea]:
async with async_open(
"./tests/datasets/zones-de-correspondance.json", "rb"
) as raw:
for element in self._json_connection_areas_decoder.decode(await raw.read()):
yield element
async def _request_idfm_lines(self) -> AsyncIterator[IdfmLine]: async def _request_idfm_lines(self) -> AsyncIterator[IdfmLine]:
# TODO: Use HTTP # TODO: Use HTTP
async with async_open("./tests/datasets/lines_dataset.json", "rb") as raw: async with async_open("./tests/datasets/lines_dataset.json", "rb") as raw:
@@ -378,6 +418,24 @@ class IdfmInterface:
changed_ts=int(fields.zdachanged.timestamp()), changed_ts=int(fields.zdachanged.timestamp()),
) )
@staticmethod
def _format_idfm_connection_areas(
*connection_areas: IdfmConnectionArea,
) -> Iterable[ConnectionArea]:
for connection_area in connection_areas:
yield ConnectionArea(
id=int(connection_area.zdcid),
name=connection_area.zdcname,
town_name=connection_area.zdctown,
postal_region=connection_area.zdcpostalregion,
xepsg2154=connection_area.zdcxepsg2154,
yepsg2154=connection_area.zdcyepsg2154,
transport_mode=StopAreaType(connection_area.zdctype.value),
version=connection_area.zdcversion,
created_ts=int(connection_area.zdccreated.timestamp()),
changed_ts=int(connection_area.zdcchanged.timestamp()),
)
async def render_line_picto(self, line: Line) -> tuple[None | str, None | str]: async def render_line_picto(self, line: Line) -> tuple[None | str, None | str]:
begin_ts = time() begin_ts = time()
line_picto_path = line_picto_format = None line_picto_path = line_picto_format = None

View File

@@ -116,6 +116,19 @@ class StopArea(Struct):
record_timestamp: datetime record_timestamp: datetime
class ConnectionArea(Struct):
zdcid: str
zdcversion: str
zdccreated: datetime
zdcchanged: datetime
zdcname: str
zdcxepsg2154: int
zdcyepsg2154: int
zdctown: str
zdcpostalregion: str
zdctype: StopAreaType
class StopAreaStopAssociationFields(Struct, kw_only=True): class StopAreaStopAssociationFields(Struct, kw_only=True):
arrid: str # TODO: use int ? arrid: str # TODO: use int ?
artid: str | None = None artid: str | None = None