✨ Handle IDFM connection areas
This commit is contained in:
@@ -1,6 +1,14 @@
|
|||||||
|
from collections import defaultdict
|
||||||
from re import compile as re_compile
|
from re import compile as re_compile
|
||||||
from time import time
|
from time import time
|
||||||
from typing import AsyncIterator, ByteString, Callable, Iterable, List, Type
|
from typing import (
|
||||||
|
AsyncIterator,
|
||||||
|
ByteString,
|
||||||
|
Callable,
|
||||||
|
Iterable,
|
||||||
|
List,
|
||||||
|
Type,
|
||||||
|
)
|
||||||
|
|
||||||
from aiofiles import open as async_open
|
from aiofiles import open as async_open
|
||||||
from aiohttp import ClientSession
|
from aiohttp import ClientSession
|
||||||
@@ -10,8 +18,9 @@ from msgspec.json import Decoder
|
|||||||
from rich import print
|
from rich import print
|
||||||
|
|
||||||
from ..db import Database
|
from ..db import Database
|
||||||
from ..models import Line, LinePicto, Stop, StopArea
|
from ..models import ConnectionArea, Line, LinePicto, Stop, StopArea, StopShape
|
||||||
from .idfm_types import (
|
from .idfm_types import (
|
||||||
|
ConnectionArea as IdfmConnectionArea,
|
||||||
IdfmLineState,
|
IdfmLineState,
|
||||||
IdfmResponse,
|
IdfmResponse,
|
||||||
Line as IdfmLine,
|
Line as IdfmLine,
|
||||||
@@ -55,6 +64,7 @@ class IdfmInterface:
|
|||||||
|
|
||||||
self._json_stops_decoder = Decoder(type=List[IdfmStop])
|
self._json_stops_decoder = Decoder(type=List[IdfmStop])
|
||||||
self._json_stop_areas_decoder = Decoder(type=List[IdfmStopArea])
|
self._json_stop_areas_decoder = Decoder(type=List[IdfmStopArea])
|
||||||
|
self._json_connection_areas_decoder = Decoder(type=List[IdfmConnectionArea])
|
||||||
self._json_lines_decoder = Decoder(type=List[IdfmLine])
|
self._json_lines_decoder = Decoder(type=List[IdfmLine])
|
||||||
self._json_stops_lines_assos_decoder = Decoder(type=List[IdfmStopLineAsso])
|
self._json_stops_lines_assos_decoder = Decoder(type=List[IdfmStopLineAsso])
|
||||||
self._json_ratp_pictos_decoder = Decoder(type=List[RatpPicto])
|
self._json_ratp_pictos_decoder = Decoder(type=List[RatpPicto])
|
||||||
@@ -67,6 +77,11 @@ class IdfmInterface:
|
|||||||
async def startup(self) -> None:
|
async def startup(self) -> None:
|
||||||
BATCH_SIZE = 10000
|
BATCH_SIZE = 10000
|
||||||
STEPS: tuple[tuple[Type[Stop] | Type[StopArea], Callable, Callable], ...] = (
|
STEPS: tuple[tuple[Type[Stop] | Type[StopArea], Callable, Callable], ...] = (
|
||||||
|
(
|
||||||
|
ConnectionArea,
|
||||||
|
self._request_idfm_connection_areas,
|
||||||
|
IdfmInterface._format_idfm_connection_areas,
|
||||||
|
),
|
||||||
(
|
(
|
||||||
StopArea,
|
StopArea,
|
||||||
self._request_idfm_stop_areas,
|
self._request_idfm_stop_areas,
|
||||||
@@ -104,7 +119,7 @@ class IdfmInterface:
|
|||||||
print(f"Link Stops to Lines: {time() - begin_ts}s")
|
print(f"Link Stops to Lines: {time() - begin_ts}s")
|
||||||
|
|
||||||
begin_ts = time()
|
begin_ts = time()
|
||||||
await self._load_stop_areas_stops_assos()
|
await self._load_stop_assos()
|
||||||
print(f"Link Stops to StopAreas: {time() - begin_ts}s")
|
print(f"Link Stops to StopAreas: {time() - begin_ts}s")
|
||||||
|
|
||||||
async def _load_lines(self, batch_size: int = 5000) -> None:
|
async def _load_lines(self, batch_size: int = 5000) -> None:
|
||||||
@@ -167,25 +182,43 @@ class IdfmInterface:
|
|||||||
|
|
||||||
print(f"{total_found_nb} line <-> stop ({total_assos_nb = } found)")
|
print(f"{total_found_nb} line <-> stop ({total_assos_nb = } found)")
|
||||||
|
|
||||||
async def _load_stop_areas_stops_assos(self, batch_size: int = 5000) -> None:
|
async def _load_stop_assos(self, batch_size: int = 5000) -> None:
|
||||||
total_assos_nb = total_found_nb = 0
|
total_assos_nb = area_stop_assos_nb = conn_stop_assos_nb = 0
|
||||||
assos = []
|
area_stop_assos = []
|
||||||
|
connection_stop_assos = []
|
||||||
|
|
||||||
async for asso in self._request_idfm_stop_area_stop_associations():
|
async for asso in self._request_idfm_stop_area_stop_associations():
|
||||||
fields = asso.fields
|
fields = asso.fields
|
||||||
|
|
||||||
assos.append((int(fields.zdaid), int(fields.arrid)))
|
stop_id = int(fields.arrid)
|
||||||
if len(assos) == batch_size:
|
|
||||||
|
area_stop_assos.append((int(fields.zdaid), stop_id))
|
||||||
|
connection_stop_assos.append((int(fields.zdcid), stop_id))
|
||||||
|
|
||||||
|
if len(area_stop_assos) == batch_size:
|
||||||
total_assos_nb += batch_size
|
total_assos_nb += batch_size
|
||||||
if (found_nb := await StopArea.add_stops(assos)) is not None:
|
|
||||||
total_found_nb += found_nb
|
|
||||||
assos.clear()
|
|
||||||
|
|
||||||
if assos:
|
if (found_nb := await StopArea.add_stops(area_stop_assos)) is not None:
|
||||||
total_assos_nb += len(assos)
|
area_stop_assos_nb += found_nb
|
||||||
if (found_nb := await StopArea.add_stops(assos)) is not None:
|
area_stop_assos.clear()
|
||||||
total_found_nb += found_nb
|
|
||||||
|
|
||||||
print(f"{total_found_nb} stop area <-> stop ({total_assos_nb = } found)")
|
if (
|
||||||
|
found_nb := await ConnectionArea.add_stops(connection_stop_assos)
|
||||||
|
) is not None:
|
||||||
|
conn_stop_assos_nb += found_nb
|
||||||
|
connection_stop_assos.clear()
|
||||||
|
|
||||||
|
if area_stop_assos:
|
||||||
|
total_assos_nb += len(area_stop_assos)
|
||||||
|
if (found_nb := await StopArea.add_stops(area_stop_assos)) is not None:
|
||||||
|
area_stop_assos_nb += found_nb
|
||||||
|
if (
|
||||||
|
found_nb := await ConnectionArea.add_stops(connection_stop_assos)
|
||||||
|
) is not None:
|
||||||
|
conn_stop_assos_nb += found_nb
|
||||||
|
|
||||||
|
print(f"{area_stop_assos_nb} stop area <-> stop ({total_assos_nb = } found)")
|
||||||
|
print(f"{conn_stop_assos_nb} stop area <-> stop ({total_assos_nb = } found)")
|
||||||
|
|
||||||
async def _request_idfm_stops(self) -> AsyncIterator[IdfmStop]:
|
async def _request_idfm_stops(self) -> AsyncIterator[IdfmStop]:
|
||||||
# headers = {"Accept": "application/json", "apikey": self._api_key}
|
# headers = {"Accept": "application/json", "apikey": self._api_key}
|
||||||
@@ -206,6 +239,13 @@ class IdfmInterface:
|
|||||||
for element in self._json_stop_areas_decoder.decode(await raw.read()):
|
for element in self._json_stop_areas_decoder.decode(await raw.read()):
|
||||||
yield element
|
yield element
|
||||||
|
|
||||||
|
async def _request_idfm_connection_areas(self) -> AsyncIterator[IdfmConnectionArea]:
|
||||||
|
async with async_open(
|
||||||
|
"./tests/datasets/zones-de-correspondance.json", "rb"
|
||||||
|
) as raw:
|
||||||
|
for element in self._json_connection_areas_decoder.decode(await raw.read()):
|
||||||
|
yield element
|
||||||
|
|
||||||
async def _request_idfm_lines(self) -> AsyncIterator[IdfmLine]:
|
async def _request_idfm_lines(self) -> AsyncIterator[IdfmLine]:
|
||||||
# TODO: Use HTTP
|
# TODO: Use HTTP
|
||||||
async with async_open("./tests/datasets/lines_dataset.json", "rb") as raw:
|
async with async_open("./tests/datasets/lines_dataset.json", "rb") as raw:
|
||||||
@@ -378,6 +418,24 @@ class IdfmInterface:
|
|||||||
changed_ts=int(fields.zdachanged.timestamp()),
|
changed_ts=int(fields.zdachanged.timestamp()),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _format_idfm_connection_areas(
|
||||||
|
*connection_areas: IdfmConnectionArea,
|
||||||
|
) -> Iterable[ConnectionArea]:
|
||||||
|
for connection_area in connection_areas:
|
||||||
|
yield ConnectionArea(
|
||||||
|
id=int(connection_area.zdcid),
|
||||||
|
name=connection_area.zdcname,
|
||||||
|
town_name=connection_area.zdctown,
|
||||||
|
postal_region=connection_area.zdcpostalregion,
|
||||||
|
xepsg2154=connection_area.zdcxepsg2154,
|
||||||
|
yepsg2154=connection_area.zdcyepsg2154,
|
||||||
|
transport_mode=StopAreaType(connection_area.zdctype.value),
|
||||||
|
version=connection_area.zdcversion,
|
||||||
|
created_ts=int(connection_area.zdccreated.timestamp()),
|
||||||
|
changed_ts=int(connection_area.zdcchanged.timestamp()),
|
||||||
|
)
|
||||||
|
|
||||||
async def render_line_picto(self, line: Line) -> tuple[None | str, None | str]:
|
async def render_line_picto(self, line: Line) -> tuple[None | str, None | str]:
|
||||||
begin_ts = time()
|
begin_ts = time()
|
||||||
line_picto_path = line_picto_format = None
|
line_picto_path = line_picto_format = None
|
||||||
|
@@ -116,6 +116,19 @@ class StopArea(Struct):
|
|||||||
record_timestamp: datetime
|
record_timestamp: datetime
|
||||||
|
|
||||||
|
|
||||||
|
class ConnectionArea(Struct):
|
||||||
|
zdcid: str
|
||||||
|
zdcversion: str
|
||||||
|
zdccreated: datetime
|
||||||
|
zdcchanged: datetime
|
||||||
|
zdcname: str
|
||||||
|
zdcxepsg2154: int
|
||||||
|
zdcyepsg2154: int
|
||||||
|
zdctown: str
|
||||||
|
zdcpostalregion: str
|
||||||
|
zdctype: StopAreaType
|
||||||
|
|
||||||
|
|
||||||
class StopAreaStopAssociationFields(Struct, kw_only=True):
|
class StopAreaStopAssociationFields(Struct, kw_only=True):
|
||||||
arrid: str # TODO: use int ?
|
arrid: str # TODO: use int ?
|
||||||
artid: str | None = None
|
artid: str | None = None
|
||||||
|
Reference in New Issue
Block a user