Handle IDFM connection areas

This commit is contained in:
2023-04-13 20:57:15 +02:00
parent 293a1391bc
commit ecfb3c8cb3
2 changed files with 87 additions and 16 deletions

View File

@@ -1,6 +1,14 @@
from collections import defaultdict
from re import compile as re_compile
from time import time
from typing import AsyncIterator, ByteString, Callable, Iterable, List, Type
from typing import (
AsyncIterator,
ByteString,
Callable,
Iterable,
List,
Type,
)
from aiofiles import open as async_open
from aiohttp import ClientSession
@@ -10,8 +18,9 @@ from msgspec.json import Decoder
from rich import print
from ..db import Database
from ..models import Line, LinePicto, Stop, StopArea
from ..models import ConnectionArea, Line, LinePicto, Stop, StopArea, StopShape
from .idfm_types import (
ConnectionArea as IdfmConnectionArea,
IdfmLineState,
IdfmResponse,
Line as IdfmLine,
@@ -55,6 +64,7 @@ class IdfmInterface:
self._json_stops_decoder = Decoder(type=List[IdfmStop])
self._json_stop_areas_decoder = Decoder(type=List[IdfmStopArea])
self._json_connection_areas_decoder = Decoder(type=List[IdfmConnectionArea])
self._json_lines_decoder = Decoder(type=List[IdfmLine])
self._json_stops_lines_assos_decoder = Decoder(type=List[IdfmStopLineAsso])
self._json_ratp_pictos_decoder = Decoder(type=List[RatpPicto])
@@ -67,6 +77,11 @@ class IdfmInterface:
async def startup(self) -> None:
BATCH_SIZE = 10000
STEPS: tuple[tuple[Type[Stop] | Type[StopArea], Callable, Callable], ...] = (
(
ConnectionArea,
self._request_idfm_connection_areas,
IdfmInterface._format_idfm_connection_areas,
),
(
StopArea,
self._request_idfm_stop_areas,
@@ -104,7 +119,7 @@ class IdfmInterface:
print(f"Link Stops to Lines: {time() - begin_ts}s")
begin_ts = time()
await self._load_stop_areas_stops_assos()
await self._load_stop_assos()
print(f"Link Stops to StopAreas: {time() - begin_ts}s")
async def _load_lines(self, batch_size: int = 5000) -> None:
@@ -167,25 +182,43 @@ class IdfmInterface:
print(f"{total_found_nb} line <-> stop ({total_assos_nb = } found)")
async def _load_stop_areas_stops_assos(self, batch_size: int = 5000) -> None:
total_assos_nb = total_found_nb = 0
assos = []
async def _load_stop_assos(self, batch_size: int = 5000) -> None:
total_assos_nb = area_stop_assos_nb = conn_stop_assos_nb = 0
area_stop_assos = []
connection_stop_assos = []
async for asso in self._request_idfm_stop_area_stop_associations():
fields = asso.fields
assos.append((int(fields.zdaid), int(fields.arrid)))
if len(assos) == batch_size:
stop_id = int(fields.arrid)
area_stop_assos.append((int(fields.zdaid), stop_id))
connection_stop_assos.append((int(fields.zdcid), stop_id))
if len(area_stop_assos) == batch_size:
total_assos_nb += batch_size
if (found_nb := await StopArea.add_stops(assos)) is not None:
total_found_nb += found_nb
assos.clear()
if assos:
total_assos_nb += len(assos)
if (found_nb := await StopArea.add_stops(assos)) is not None:
total_found_nb += found_nb
if (found_nb := await StopArea.add_stops(area_stop_assos)) is not None:
area_stop_assos_nb += found_nb
area_stop_assos.clear()
print(f"{total_found_nb} stop area <-> stop ({total_assos_nb = } found)")
if (
found_nb := await ConnectionArea.add_stops(connection_stop_assos)
) is not None:
conn_stop_assos_nb += found_nb
connection_stop_assos.clear()
if area_stop_assos:
total_assos_nb += len(area_stop_assos)
if (found_nb := await StopArea.add_stops(area_stop_assos)) is not None:
area_stop_assos_nb += found_nb
if (
found_nb := await ConnectionArea.add_stops(connection_stop_assos)
) is not None:
conn_stop_assos_nb += found_nb
print(f"{area_stop_assos_nb} stop area <-> stop ({total_assos_nb = } found)")
print(f"{conn_stop_assos_nb} stop area <-> stop ({total_assos_nb = } found)")
async def _request_idfm_stops(self) -> AsyncIterator[IdfmStop]:
# headers = {"Accept": "application/json", "apikey": self._api_key}
@@ -206,6 +239,13 @@ class IdfmInterface:
for element in self._json_stop_areas_decoder.decode(await raw.read()):
yield element
async def _request_idfm_connection_areas(self) -> AsyncIterator[IdfmConnectionArea]:
async with async_open(
"./tests/datasets/zones-de-correspondance.json", "rb"
) as raw:
for element in self._json_connection_areas_decoder.decode(await raw.read()):
yield element
async def _request_idfm_lines(self) -> AsyncIterator[IdfmLine]:
# TODO: Use HTTP
async with async_open("./tests/datasets/lines_dataset.json", "rb") as raw:
@@ -378,6 +418,24 @@ class IdfmInterface:
changed_ts=int(fields.zdachanged.timestamp()),
)
@staticmethod
def _format_idfm_connection_areas(
*connection_areas: IdfmConnectionArea,
) -> Iterable[ConnectionArea]:
for connection_area in connection_areas:
yield ConnectionArea(
id=int(connection_area.zdcid),
name=connection_area.zdcname,
town_name=connection_area.zdctown,
postal_region=connection_area.zdcpostalregion,
xepsg2154=connection_area.zdcxepsg2154,
yepsg2154=connection_area.zdcyepsg2154,
transport_mode=StopAreaType(connection_area.zdctype.value),
version=connection_area.zdcversion,
created_ts=int(connection_area.zdccreated.timestamp()),
changed_ts=int(connection_area.zdcchanged.timestamp()),
)
async def render_line_picto(self, line: Line) -> tuple[None | str, None | str]:
begin_ts = time()
line_picto_path = line_picto_format = None

View File

@@ -116,6 +116,19 @@ class StopArea(Struct):
record_timestamp: datetime
class ConnectionArea(Struct):
zdcid: str
zdcversion: str
zdccreated: datetime
zdcchanged: datetime
zdcname: str
zdcxepsg2154: int
zdcyepsg2154: int
zdctown: str
zdcpostalregion: str
zdctype: StopAreaType
class StopAreaStopAssociationFields(Struct, kw_only=True):
arrid: str # TODO: use int ?
artid: str | None = None