🎉 First commit !!!

This commit is contained in:
2023-01-22 16:53:45 +01:00
commit dde835760a
68 changed files with 3250 additions and 0 deletions

View File

@@ -0,0 +1,4 @@
from .db import Database
from .base_class import Base
db = Database()

View File

@@ -0,0 +1,34 @@
from collections.abc import Iterable
from sqlalchemy import select
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import declarative_base
from typing import Iterable, Self
Base = declarative_base()
Base.db = None
async def base_add(cls, stops: Self | Iterable[Self]) -> bool:
try:
method = (
cls.db.session.add_all
if isinstance(stops, Iterable)
else cls.db.session.add
)
method(stops)
await cls.db.session.commit()
except IntegrityError as err:
print(err)
Base.add = classmethod(base_add)
async def base_get_by_id(cls, id_: int | str) -> None | Base:
res = await cls.db.session.execute(select(cls).where(cls.id == id_))
element = res.scalar_one_or_none()
return element
Base.get_by_id = classmethod(base_get_by_id)

View File

@@ -0,0 +1,80 @@
from asyncio import gather as asyncio_gather
from functools import wraps
from pathlib import Path
from time import time
from typing import Callable, Iterable, Optional
from rich import print
from sqlalchemy import event, select, tuple_
from sqlalchemy.engine import Engine
from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
from sqlalchemy.orm import (
selectinload,
sessionmaker,
with_polymorphic,
)
from sqlalchemy.orm.attributes import set_committed_value
from .base_class import Base
# import logging
# logging.basicConfig()
# logger = logging.getLogger("bot.sqltime")
# logger.setLevel(logging.DEBUG)
# @event.listens_for(Engine, "before_cursor_execute")
# def before_cursor_execute(conn, cursor, statement, parameters, context, executemany):
# conn.info.setdefault("query_start_time", []).append(time())
# logger.debug("Start Query: %s", statement)
# @event.listens_for(Engine, "after_cursor_execute")
# def after_cursor_execute(conn, cursor, statement, parameters, context, executemany):
# total = time() - conn.info["query_start_time"].pop(-1)
# logger.debug("Query Complete!")
# logger.debug("Total Time: %f", total)
class Database:
def __init__(self) -> None:
self._engine = None
self._session_maker = None
self._session = None
@property
def session(self) -> None:
if self._session is None:
self._session = self._session_maker()
return self._session
def use_session(func: Callable):
@wraps(func)
async def wrapper(self, *args, **kwargs):
if self._check_session() is not None:
return await func(self, *args, **kwargs)
# TODO: Raise an exception ?
return wrapper
async def connect(self, db_path: str, clear_static_data: bool = False) -> None:
# TODO: Preserve UserLastStopSearchResults table from drop.
self._engine = create_async_engine(db_path)
self._session_maker = sessionmaker(
self._engine, expire_on_commit=False, class_=AsyncSession
)
await self.session.execute("CREATE EXTENSION IF NOT EXISTS pg_trgm;")
async with self._engine.begin() as conn:
if clear_static_data:
await conn.run_sync(Base.metadata.drop_all)
await conn.run_sync(Base.metadata.create_all)
async def disconnect(self) -> None:
if self._session is not None:
await self._session.close()
self._session = None
await self._engine.dispose()

View File

@@ -0,0 +1,2 @@
from .idfm_interface import IdfmInterface
from .idfm_types import *

View File

@@ -0,0 +1,447 @@
from pathlib import Path
from re import compile as re_compile
from time import time
from typing import ByteString, Iterable, List, Optional
from aiofiles import open as async_open
from aiohttp import ClientSession
from msgspec import ValidationError
from msgspec.json import Decoder
from rich import print
from ..db import Database
from ..models import Line, LinePicto, Stop, StopArea
from .idfm_types import (
IdfmLineState,
IdfmResponse,
Line as IdfmLine,
MonitoredVehicleJourney,
LinePicto as IdfmPicto,
IdfmState,
Stop as IdfmStop,
StopArea as IdfmStopArea,
StopAreaStopAssociation,
StopLineAsso as IdfmStopLineAsso,
Stops,
)
from .ratp_types import Picto as RatpPicto
class IdfmInterface:
IDFM_ROOT_URL = "https://prim.iledefrance-mobilites.fr/marketplace"
IDFM_STOP_MON_URL = f"{IDFM_ROOT_URL}/stop-monitoring"
IDFM_ROOT_URL = "https://data.iledefrance-mobilites.fr/explore/dataset"
IDFM_STOPS_URL = (
f"{IDFM_ROOT_URL}/arrets/download/?format=json&timezone=Europe/Berlin"
)
IDFM_PICTO_URL = f"{IDFM_ROOT_URL}/referentiel-des-lignes/files"
RATP_ROOT_URL = "https://data.ratp.fr/explore/dataset"
RATP_PICTO_URL = f"{RATP_ROOT_URL}/pictogrammes-des-lignes-de-metro-rer-tramway-bus-et-noctilien/files"
OPERATOR_RE = re_compile(r"[^:]+:Operator::([^:]+):")
LINE_RE = re_compile(r"[^:]+:Line::([^:]+):")
def __init__(self, api_key: str, database: Database) -> None:
self._api_key = api_key
self._database = database
self._http_headers = {"Accept": "application/json", "apikey": self._api_key}
self._json_stops_decoder = Decoder(type=List[IdfmStop])
self._json_stop_areas_decoder = Decoder(type=List[IdfmStopArea])
self._json_lines_decoder = Decoder(type=List[IdfmLine])
self._json_stops_lines_assos_decoder = Decoder(type=List[IdfmStopLineAsso])
self._json_ratp_pictos_decoder = Decoder(type=List[RatpPicto])
self._json_stop_area_stop_asso_decoder = Decoder(
type=List[StopAreaStopAssociation]
)
self._response_json_decoder = Decoder(type=IdfmResponse)
async def startup(self) -> None:
BATCH_SIZE = 10000
STEPS = (
(
StopArea,
self._request_idfm_stop_areas,
IdfmInterface._format_idfm_stop_areas,
),
(Stop, self._request_idfm_stops, IdfmInterface._format_idfm_stops),
)
for model, get_method, format_method in STEPS:
step_begin_ts = time()
elements = []
async for element in get_method():
elements.append(element)
if len(elements) == BATCH_SIZE:
await model.add(format_method(*elements))
elements.clear()
if elements:
await model.add(format_method(*elements))
print(f"Add {model.__name__}s: {time() - step_begin_ts}s")
begin_ts = time()
await self._load_lines()
print(f"Add Lines and IDFM LinePictos: {time() - begin_ts}s")
begin_ts = time()
await self._load_ratp_pictos(30)
print(f"Add RATP LinePictos: {time() - begin_ts}s")
begin_ts = time()
await self._load_lines_stops_assos()
print(f"Link Stops to Lines: {time() - begin_ts}s")
begin_ts = time()
await self._load_stop_areas_stops_assos()
print(f"Link Stops to StopAreas: {time() - begin_ts}s")
async def _load_lines(self, batch_size: int = 5000) -> None:
lines, pictos = [], []
picto_ids = set()
async for line in self._request_idfm_lines():
if (picto := line.fields.picto) is not None and picto.id_ not in picto_ids:
picto_ids.add(picto.id_)
pictos.append(picto)
lines.append(line)
if len(lines) == batch_size:
await LinePicto.add(IdfmInterface._format_idfm_pictos(*pictos))
await Line.add(await self._format_idfm_lines(*lines))
lines.clear()
pictos.clear()
if pictos:
await LinePicto.add(IdfmInterface._format_idfm_pictos(*pictos))
if lines:
await Line.add(await self._format_idfm_lines(*lines))
async def _load_ratp_pictos(self, batch_size: int = 5) -> None:
pictos = []
async for picto in self._request_ratp_pictos():
pictos.append(picto)
if len(pictos) == batch_size:
formatted_pictos = IdfmInterface._format_ratp_pictos(*pictos)
await LinePicto.add(formatted_pictos.values())
await Line.add_pictos(formatted_pictos)
pictos.clear()
if pictos:
formatted_pictos = IdfmInterface._format_ratp_pictos(*pictos)
await LinePicto.add(formatted_pictos.values())
await Line.add_pictos(formatted_pictos)
async def _load_lines_stops_assos(self, batch_size: int = 5000) -> None:
total_assos_nb = total_found_nb = 0
assos = []
async for asso in self._request_idfm_stops_lines_associations():
fields = asso.fields
try:
stop_id = int(fields.stop_id.rsplit(":", 1)[-1])
except ValueError as err:
print(err)
print(f"{fields.stop_id = }")
continue
assos.append((fields.route_long_name, fields.operatorname, stop_id))
if len(assos) == batch_size:
total_assos_nb += batch_size
total_found_nb += await Line.add_stops(assos)
assos.clear()
if assos:
total_assos_nb += len(assos)
total_found_nb += await Line.add_stops(assos)
print(f"{total_found_nb} line <-> stop ({total_assos_nb = } found)")
async def _load_stop_areas_stops_assos(self, batch_size: int = 5000) -> None:
total_assos_nb = total_found_nb = 0
assos = []
async for asso in self._request_idfm_stop_area_stop_associations():
fields = asso.fields
assos.append((int(fields.zdaid), int(fields.arrid)))
if len(assos) == batch_size:
total_assos_nb += batch_size
total_found_nb += await StopArea.add_stops(assos)
assos.clear()
if assos:
total_assos_nb += len(assos)
total_found_nb += await StopArea.add_stops(assos)
print(f"{total_found_nb} stop area <-> stop ({total_assos_nb = } found)")
async def _request_idfm_stops(self):
# headers = {"Accept": "application/json", "apikey": self._api_key}
# async with ClientSession(headers=headers) as session:
# async with session.get(self.STOPS_URL) as response:
# # print("Status:", response.status)
# if response.status == 200:
# for point in self._json_stops_decoder.decode(await response.read()):
# yield point
# TODO: Use HTTP
async with async_open("./tests/datasets/stops_dataset.json", "rb") as raw:
for element in self._json_stops_decoder.decode(await raw.read()):
yield element
async def _request_idfm_stop_areas(self):
# TODO: Use HTTP
async with async_open("./tests/datasets/zones-d-arrets.json", "rb") as raw:
for element in self._json_stop_areas_decoder.decode(await raw.read()):
yield element
async def _request_idfm_lines(self):
# TODO: Use HTTP
async with async_open("./tests/datasets/lines_dataset.json", "rb") as raw:
for element in self._json_lines_decoder.decode(await raw.read()):
yield element
async def _request_idfm_stops_lines_associations(self):
# TODO: Use HTTP
async with async_open("./tests/datasets/arrets-lignes.json", "rb") as raw:
for element in self._json_stops_lines_assos_decoder.decode(
await raw.read()
):
yield element
async def _request_idfm_stop_area_stop_associations(self):
# TODO: Use HTTP
async with async_open("./tests/datasets/relations.json", "rb") as raw:
for element in self._json_stop_area_stop_asso_decoder.decode(
await raw.read()
):
yield element
async def _request_ratp_pictos(self):
# TODO: Use HTTP
async with async_open(
"./tests/datasets/pictogrammes-des-lignes-de-metro-rer-tramway-bus-et-noctilien.json",
"rb",
) as fd:
for element in self._json_ratp_pictos_decoder.decode(await fd.read()):
yield element
@classmethod
def _format_idfm_pictos(cls, *pictos: IdfmPicto) -> Iterable[LinePicto]:
ret = []
for picto in pictos:
ret.append(
LinePicto(
id=picto.id_,
mime_type=picto.mimetype,
height_px=picto.height,
width_px=picto.width,
filename=picto.filename,
url=f"{cls.IDFM_PICTO_URL}/{picto.id_}/download",
thumbnail=picto.thumbnail,
format=picto.format,
)
)
return ret
@classmethod
def _format_ratp_pictos(cls, *pictos: RatpPicto) -> dict[str, None | LinePicto]:
ret = {}
for picto in pictos:
if (fields := picto.fields.noms_des_fichiers) is not None:
ret[picto.fields.indices_commerciaux] = LinePicto(
id=fields.id_,
mime_type=f"image/{fields.format.lower()}",
height_px=fields.height,
width_px=fields.width,
filename=fields.filename,
url=f"{cls.RATP_PICTO_URL}/{fields.id_}/download",
thumbnail=fields.thumbnail,
format=fields.format,
)
return ret
async def _format_idfm_lines(self, *lines: IdfmLine) -> Iterable[Line]:
ret = []
optional_value = IdfmLine.optional_value
for line in lines:
fields = line.fields
picto_id = fields.picto.id_ if fields.picto is not None else None
picto = await LinePicto.get_by_id(picto_id) if picto_id else None
ret.append(
Line(
id=fields.id_line,
short_name=fields.shortname_line,
name=fields.name_line,
status=IdfmLineState(fields.status.value),
transport_mode=fields.transportmode.value,
transport_submode=optional_value(fields.transportsubmode),
network_name=optional_value(fields.networkname),
group_of_lines_id=optional_value(fields.id_groupoflines),
group_of_lines_shortname=optional_value(
fields.shortname_groupoflines
),
colour_web_hexa=fields.colourweb_hexa,
text_colour_hexa=fields.textcolourprint_hexa,
operator_id=optional_value(fields.operatorref),
operator_name=optional_value(fields.operatorname),
accessibility=fields.accessibility.value,
visual_signs_available=fields.visualsigns_available.value,
audible_signs_available=fields.audiblesigns_available.value,
picto_id=fields.picto.id_ if fields.picto is not None else None,
picto=picto,
record_id=line.recordid,
record_ts=int(line.record_timestamp.timestamp()),
)
)
return ret
@staticmethod
def _format_idfm_stops(*stops: IdfmStop) -> Iterable[Stop]:
for stop in stops:
fields = stop.fields
try:
created_ts = int(fields.arrcreated.timestamp())
except AttributeError:
created_ts = None
yield Stop(
id=int(fields.arrid),
name=fields.arrname,
latitude=fields.arrgeopoint.lat,
longitude=fields.arrgeopoint.lon,
town_name=fields.arrtown,
postal_region=fields.arrpostalregion,
xepsg2154=fields.arrxepsg2154,
yepsg2154=fields.arryepsg2154,
transport_mode=fields.arrtype.value,
version=fields.arrversion,
created_ts=created_ts,
changed_ts=int(fields.arrchanged.timestamp()),
accessibility=fields.arraccessibility.value,
visual_signs_available=fields.arrvisualsigns.value,
audible_signs_available=fields.arraudiblesignals.value,
record_id=stop.recordid,
record_ts=int(stop.record_timestamp.timestamp()),
)
@staticmethod
def _format_idfm_stop_areas(*stop_areas: IdfmStopArea) -> Iterable[StopArea]:
for stop_area in stop_areas:
fields = stop_area.fields
try:
created_ts = int(fields.arrcreated.timestamp())
except AttributeError:
created_ts = None
yield StopArea(
id=int(fields.zdaid),
name=fields.zdaname,
town_name=fields.zdatown,
postal_region=fields.zdapostalregion,
xepsg2154=fields.zdaxepsg2154,
yepsg2154=fields.zdayepsg2154,
type=fields.zdatype.value,
version=fields.zdaversion,
created_ts=created_ts,
changed_ts=int(fields.zdachanged.timestamp()),
)
async def render_line_picto(self, line: Line) -> tuple[None | str, None | str]:
begin_ts = time()
line_picto_path = line_picto_format = None
target = f"/tmp/{line.id}_repr"
picto = line.picto
if picto is not None:
picto_data = await self._get_line_picto(line)
async with async_open(target, "wb") as fd:
await fd.write(picto_data)
line_picto_path = target
line_picto_format = picto.mime_type
print(f"render_line_picto: {time() - begin_ts}")
return (line_picto_path, line_picto_format)
async def _get_line_picto(self, line: Line) -> Optional[ByteString]:
print("---------------------------------------------------------------------")
begin_ts = time()
data = None
picto = line.picto
if picto is not None:
headers = (
self._http_headers if picto.url.startswith(self.IDFM_ROOT_URL) else None
)
session_begin_ts = time()
async with ClientSession(headers=headers) as session:
session_creation_ts = time()
print(f"Session creation {session_creation_ts - session_begin_ts}")
async with session.get(picto.url) as response:
get_end_ts = time()
print(f"GET {get_end_ts - session_creation_ts}")
data = await response.read()
print(f"read {time() - get_end_ts}")
print(f"render_line_picto: {time() - begin_ts}")
print("---------------------------------------------------------------------")
return data
async def get_next_passages(self, stop_point_id: str) -> Optional[IdfmResponse]:
# print("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
begin_ts = time()
ret = None
params = {"MonitoringRef": f"STIF:StopPoint:Q:{stop_point_id}:"}
session_begin_ts = time()
async with ClientSession(headers=self._http_headers) as session:
session_creation_ts = time()
# print(f"Session creation {session_creation_ts - session_begin_ts}")
async with session.get(self.IDFM_STOP_MON_URL, params=params) as response:
get_end_ts = time()
# print(f"GET {get_end_ts - session_creation_ts}")
if response.status == 200:
get_end_ts = time()
# print(f"GET {get_end_ts - session_creation_ts}")
data = await response.read()
# print(data)
try:
ret = self._response_json_decoder.decode(data)
except ValidationError as err:
print(err)
# print(f"read {time() - get_end_ts}")
# print(f"get_next_passages: {time() - begin_ts}")
# print("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
return ret
async def get_destinations(self, stop_point_id: str) -> Iterable[str]:
# TODO: Store in database the destination for the given stop and line id.
begin_ts = time()
destinations: dict[str, str] = {}
if (res := await self.get_next_passages(stop_point_id)) is not None:
for delivery in res.Siri.ServiceDelivery.StopMonitoringDelivery:
if delivery.Status == IdfmState.true:
for stop_visit in delivery.MonitoredStopVisit:
journey = stop_visit.MonitoredVehicleJourney
if (destination_name := journey.DestinationName) and (
line_ref := journey.LineRef
):
line_id = line_ref.value.replace("STIF:Line::", "")[:-1]
print(f"{line_id = }")
destinations[line_id] = destination_name[0].value
print(f"get_next_passages: {time() - begin_ts}")
return destinations

View File

@@ -0,0 +1,277 @@
from __future__ import annotations
from datetime import datetime
from enum import Enum, StrEnum
from typing import Any, Literal, Optional, NamedTuple
from msgspec import Struct
class Coordinate(NamedTuple):
lat: float
lon: float
class IdfmState(Enum):
unknown = "unknown"
false = "false"
partial = "partial"
true = "true"
class TrainStatus(Enum):
unknown = ""
arrived = "arrived"
onTime = "onTime"
delayed = "delayed"
noReport = "noReport"
early = "early"
cancelled = "cancelled"
undefined = "undefined"
class TransportMode(StrEnum):
bus = "bus"
tram = "tram"
metro = "metro"
rail = "rail"
funicular = "funicular"
class TransportSubMode(Enum):
unknown = "unknown"
localBus = "localBus"
regionalBus = "regionalBus"
highFrequencyBus = "highFrequencyBus"
expressBus = "expressBus"
nightBus = "nightBus"
demandAndResponseBus = "demandAndResponseBus"
airportLinkBus = "airportLinkBus"
regionalRail = "regionalRail"
railShuttle = "railShuttle"
suburbanRailway = "suburbanRailway"
local = "local"
class StopFields(Struct, kw_only=True):
arrgeopoint: Coordinate
arrtown: str
arrcreated: None | datetime = None
arryepsg2154: int
arrpostalregion: str
arrid: str
arrxepsg2154: int
arraccessibility: IdfmState
arrvisualsigns: IdfmState
arrtype: TransportMode
arrname: str
arrversion: str
arrchanged: datetime
arraudiblesignals: IdfmState
class Point(Struct):
coordinates: Coordinate
class Stop(Struct):
datasetid: str
recordid: str
fields: StopFields
record_timestamp: datetime
# geometry: Union[Point]
Stops = dict[str, Stop]
class StopAreaType(Enum):
metroStation = "metroStation"
onstreetBus = "onstreetBus"
onstreetTram = "onstreetTram"
railStation = "railStation"
class StopAreaFields(Struct, kw_only=True):
zdaname: str
zdcid: str
zdatown: str
zdaversion: str
zdaid: str
zdacreated: Optional[datetime] = None
zdatype: StopAreaType
zdayepsg2154: int
zdapostalregion: str
zdachanged: datetime
zdaxepsg2154: int
class StopArea(Struct):
datasetid: str
recordid: str
fields: StopAreaFields
record_timestamp: datetime
class StopAreaStopAssociationFields(Struct, kw_only=True):
arrid: str # TODO: use int ?
artid: Optional[str] = None
arrversion: str
zdcid: str
version: int
zdaid: str
zdaversion: str
artversion: Optional[str] = None
class StopAreaStopAssociation(Struct):
datasetid: str
recordid: str
fields: StopAreaStopAssociationFields
record_timestamp: datetime
class IdfmLineState(Enum):
active = "active"
class LinePicto(Struct, rename={"id_": "id"}):
id_: str
mimetype: str
height: int
width: int
filename: str
thumbnail: bool
format: str
# color_summary: list[str]
class LineFields(Struct, kw_only=True):
name_line: str
status: IdfmLineState
accessibility: IdfmState
shortname_groupoflines: Optional[str] = None
transportmode: TransportMode
colourweb_hexa: str
textcolourprint_hexa: str
transportsubmode: Optional[TransportSubMode] = TransportSubMode.unknown
operatorref: Optional[str] = None
visualsigns_available: IdfmState
networkname: Optional[str] = None
id_line: str
id_groupoflines: Optional[str] = None
operatorname: Optional[str] = None
audiblesigns_available: IdfmState
shortname_line: str
picto: Optional[LinePicto] = None
class Line(Struct):
datasetid: str
recordid: str
fields: LineFields
record_timestamp: datetime
@staticmethod
def optional_value(value: Any) -> Any:
if value:
return value.value if isinstance(value, Enum) else value
return "NULL"
Lines = dict[str, Line]
# TODO: Set structs frozen
class StopLineAssoFields(Struct):
pointgeo: Coordinate
stop_id: str
stop_name: str
operatorname: str
nom_commune: str
route_long_name: str
id: str
stop_lat: str
stop_lon: str
code_insee: str
class StopLineAsso(Struct):
datasetid: str
recordid: str
fields: StopLineAssoFields
# geometry: Union[Point]
class Value(Struct):
value: str
class FramedVehicleJourney(Struct):
DataFrameRef: Value
DatedVehicleJourneyRef: str
class TrainNumber(Struct):
TrainNumberRef: list[Value]
class MonitoredCall(Struct, kw_only=True):
Order: Optional[int] = None
StopPointName: list[Value]
VehicleAtStop: bool
DestinationDisplay: list[Value]
AimedArrivalTime: Optional[datetime] = None
ExpectedArrivalTime: Optional[datetime] = None
ArrivalPlatformName: Optional[Value] = None
AimedDepartureTime: Optional[datetime] = None
ExpectedDepartureTime: Optional[datetime] = None
ArrivalStatus: TrainStatus = None
DepartureStatus: TrainStatus = None
class MonitoredVehicleJourney(Struct, kw_only=True):
LineRef: Value
OperatorRef: Value
FramedVehicleJourneyRef: FramedVehicleJourney
DestinationRef: Value
DestinationName: list[Value] | None = None
JourneyNote: list[Value] | None = None
TrainNumbers: Optional[TrainNumber] = None
MonitoredCall: MonitoredCall
class StopDelivery(Struct):
RecordedAtTime: datetime
ItemIdentifier: str
MonitoringRef: Value
MonitoredVehicleJourney: MonitoredVehicleJourney
class StopMonitoringDelivery(Struct):
ResponseTimestamp: datetime
Version: str
Status: IdfmState
MonitoredStopVisit: list[StopDelivery]
class ServiceDelivery(Struct):
ResponseTimestamp: datetime
ProducerRef: str
ResponseMessageIdentifier: str
StopMonitoringDelivery: list[StopMonitoringDelivery]
class Siri(Struct):
ServiceDelivery: ServiceDelivery
class IdfmOperator(Enum):
SNCF = "SNCF"
class IdfmResponse(Struct):
Siri: Siri

View File

@@ -0,0 +1,25 @@
from datetime import datetime
from typing import Optional
from msgspec import Struct
class PictoFieldsFile(Struct, rename={"id_": "id"}):
id_: str
height: int
width: int
filename: str
thumbnail: bool
format: str
class PictoFields(Struct):
indices_commerciaux: str
noms_des_fichiers: Optional[PictoFieldsFile] = None
class Picto(Struct):
datasetid: str
recordid: str
fields: PictoFields
record_timestamp: datetime

View File

@@ -0,0 +1,3 @@
from .line import Line, LinePicto
from .stop import Stop, StopArea
from .user import UserLastStopSearchResults

View File

@@ -0,0 +1,176 @@
from asyncio import gather as asyncio_gather
from collections import defaultdict
from typing import Iterable, Self
from sqlalchemy import (
BigInteger,
Boolean,
Column,
Enum,
ForeignKey,
Integer,
select,
String,
Table,
)
from sqlalchemy.orm import Mapped, relationship, selectinload
from sqlalchemy.orm.attributes import set_committed_value
from sqlalchemy.sql.expression import tuple_
from ..db import Base, db
from ..idfm_interface.idfm_types import (
IdfmState,
IdfmLineState,
TransportMode,
TransportSubMode,
)
from .stop import _Stop
line_stop_association_table = Table(
"line_stop_association_table",
Base.metadata,
Column("line_id", ForeignKey("lines.id")),
Column("stop_id", ForeignKey("_stops.id")),
)
class LinePicto(Base):
db = db
id = Column(String, primary_key=True)
mime_type = Column(String, nullable=False)
height_px = Column(Integer, nullable=False)
width_px = Column(Integer, nullable=False)
filename = Column(String, nullable=False)
url = Column(String, nullable=False)
thumbnail = Column(Boolean, nullable=False)
format = Column(String, nullable=False)
__tablename__ = "line_pictos"
class Line(Base):
db = db
id = Column(String, primary_key=True)
short_name = Column(String)
name = Column(String, nullable=False)
status = Column(Enum(IdfmLineState), nullable=False)
transport_mode = Column(Enum(TransportMode), nullable=False)
transport_submode = Column(Enum(TransportSubMode), nullable=False)
network_name = Column(String)
group_of_lines_id = Column(String)
group_of_lines_shortname = Column(String)
colour_web_hexa = Column(String, nullable=False)
text_colour_hexa = Column(String, nullable=False)
operator_id = Column(String)
operator_name = Column(String)
accessibility = Column(Enum(IdfmState), nullable=False)
visual_signs_available = Column(Enum(IdfmState), nullable=False)
audible_signs_available = Column(Enum(IdfmState), nullable=False)
picto_id = Column(String, ForeignKey("line_pictos.id"))
picto: Mapped[LinePicto] = relationship(LinePicto, lazy="selectin")
record_id = Column(String, nullable=False)
record_ts = Column(BigInteger, nullable=False)
stops: Mapped[list["_Stop"]] = relationship(
"_Stop",
secondary=line_stop_association_table,
back_populates="lines",
lazy="selectin",
)
__tablename__ = "lines"
@classmethod
async def get_by_name(
cls, name: str, operator_name: None | str = None
) -> list[Self]:
filters = {"name": name}
if operator_name is not None:
filters["operator_name"] = operator_name
lines = None
stmt = (
select(Line)
.filter_by(**filters)
.options(selectinload(Line.stops), selectinload(Line.picto))
)
res = await cls.db.session.execute(stmt)
lines = res.scalars().all()
return lines
@classmethod
async def _add_picto_to_line(cls, line: str | Self, picto: LinePicto) -> None:
if isinstance(line, str):
if (lines := await cls.get_by_name(line)) is not None:
if len(lines) == 1:
line = lines[0]
else:
for candidate_line in lines:
if candidate_line.operator_name == "RATP":
line = candidate_line
break
if isinstance(line, Line) and line.picto is None:
line.picto = picto
line.picto_id = picto.id
@classmethod
async def add_pictos(cls, line_to_pictos: dict[str | Self, LinePicto]) -> None:
await asyncio_gather(
*[
cls._add_picto_to_line(line, picto)
for line, picto in line_to_pictos.items()
]
)
await cls.db.session.commit()
@classmethod
async def add_stops(cls, line_to_stop_ids: Iterable[tuple[str, str, str]]) -> int:
line_names_ops, stop_ids = set(), set()
for line_name, operator_name, stop_id in line_to_stop_ids:
line_names_ops.add((line_name, operator_name))
stop_ids.add(stop_id)
res = await cls.db.session.execute(
select(Line).where(
tuple_(Line.name, Line.operator_name).in_(line_names_ops)
)
)
lines = defaultdict(list)
for line in res.scalars():
lines[(line.name, line.operator_name)].append(line)
res = await cls.db.session.execute(select(_Stop).where(_Stop.id.in_(stop_ids)))
stops = {stop.id: stop for stop in res.scalars()}
found = 0
for line_name, operator_name, stop_id in line_to_stop_ids:
if (stop := stops.get(stop_id)) is not None:
if (stop_lines := lines.get((line_name, operator_name))) is not None:
if len(stop_lines) > 1:
print(stop_lines)
for stop_line in stop_lines:
stop_line.stops.append(stop)
found += 1
else:
print(f"No line found for {line_name}/{operator_name}")
else:
print(
f"No stop found for {stop_id} id (used by {line_name}/{operator_name})"
)
await cls.db.session.commit()
return found

View File

@@ -0,0 +1,144 @@
from typing import Iterable, Self
from sqlalchemy import (
BigInteger,
Column,
Enum,
Float,
ForeignKey,
select,
String,
Table,
)
from sqlalchemy.orm import Mapped, relationship, selectinload, with_polymorphic
from sqlalchemy.schema import Index
from ..db import Base, db
from ..idfm_interface.idfm_types import TransportMode, IdfmState, StopAreaType
stop_area_stop_association_table = Table(
"stop_area_stop_association_table",
Base.metadata,
Column("stop_id", ForeignKey("_stops.id")),
Column("stop_area_id", ForeignKey("stop_areas.id")),
)
class _Stop(Base):
db = db
id = Column(BigInteger, primary_key=True)
kind = Column(String)
name = Column(String, nullable=False, index=True)
town_name = Column(String, nullable=False)
postal_region = Column(String, nullable=False)
xepsg2154 = Column(BigInteger, nullable=False)
yepsg2154 = Column(BigInteger, nullable=False)
version = Column(String, nullable=False)
created_ts = Column(BigInteger)
changed_ts = Column(BigInteger, nullable=False)
lines: Mapped[list["Line"]] = relationship(
"Line",
secondary="line_stop_association_table",
back_populates="stops",
# lazy="joined",
lazy="selectin",
)
areas: Mapped[list["StopArea"]] = relationship(
"StopArea", secondary=stop_area_stop_association_table, back_populates="stops"
)
__tablename__ = "_stops"
__mapper_args__ = {"polymorphic_identity": "_stops", "polymorphic_on": kind}
__table_args__ = (
# To optimize the ilike requests
Index(
"name_idx_gin",
name,
postgresql_ops={"name": "gin_trgm_ops"},
postgresql_using="gin",
),
)
# TODO: Test https://www.cybertec-postgresql.com/en/postgresql-more-performance-for-like-and-ilike-statements/
# TODO: Should be able to remove with_polymorphic ?
@classmethod
async def get_by_name(cls, name: str) -> list[Self]:
stop_stop_area = with_polymorphic(_Stop, [Stop, StopArea])
stmt = (
select(stop_stop_area)
.where(stop_stop_area.name.ilike(f"%{name}%"))
.options(
selectinload(stop_stop_area.areas),
selectinload(stop_stop_area.lines),
)
)
res = await cls.db.session.execute(stmt)
return res.scalars()
class Stop(_Stop):
id = Column(BigInteger, ForeignKey("_stops.id"), primary_key=True)
latitude = Column(Float, nullable=False)
longitude = Column(Float, nullable=False)
transport_mode = Column(Enum(TransportMode), nullable=False)
accessibility = Column(Enum(IdfmState), nullable=False)
visual_signs_available = Column(Enum(IdfmState), nullable=False)
audible_signs_available = Column(Enum(IdfmState), nullable=False)
record_id = Column(String, nullable=False)
record_ts = Column(BigInteger, nullable=False)
__tablename__ = "stops"
__mapper_args__ = {"polymorphic_identity": "stops", "polymorphic_load": "inline"}
class StopArea(_Stop):
id = Column(BigInteger, ForeignKey("_stops.id"), primary_key=True)
type = Column(Enum(StopAreaType), nullable=False)
stops: Mapped[list[_Stop]] = relationship(
_Stop,
secondary=stop_area_stop_association_table,
back_populates="areas",
lazy="selectin",
# lazy="joined",
)
__tablename__ = "stop_areas"
__mapper_args__ = {"polymorphic_identity": "stop_areas", "polymorphic_load": "inline"}
@classmethod
async def add_stops(cls, stop_area_to_stop_ids: Iterable[tuple[str, str]]) -> int:
stop_area_ids, stop_ids = set(), set()
for stop_area_id, stop_id in stop_area_to_stop_ids:
stop_area_ids.add(stop_area_id)
stop_ids.add(stop_id)
res = await cls.db.session.execute(
select(StopArea)
.where(StopArea.id.in_(stop_area_ids))
.options(selectinload(StopArea.stops))
)
stop_areas = {stop_area.id: stop_area for stop_area in res.scalars()}
res = await cls.db.session.execute(select(_Stop).where(_Stop.id.in_(stop_ids)))
stops = {stop.id: stop for stop in res.scalars()}
found = 0
for stop_area_id, stop_id in stop_area_to_stop_ids:
if (stop_area := stop_areas.get(stop_area_id)) is not None:
if (stop := stops.get(stop_id)) is not None:
stop_area.stops.append(stop)
found += 1
else:
print(f"No stop found for {stop_id} id")
else:
print(f"No stop area found for {stop_area_id}")
await cls.db.session.commit()
return found

View File

@@ -0,0 +1,25 @@
from sqlalchemy import Column, ForeignKey, String, Table
from sqlalchemy.orm import Mapped, relationship
from ..db import Base, db
from .stop import _Stop
user_last_stop_search_stops_associations_table = Table(
"user_last_stop_search_stops_associations_table",
Base.metadata,
Column("user_mxid", ForeignKey("user_last_stop_search_results.user_mxid")),
Column("stop_id", ForeignKey("_stops.id")),
)
class UserLastStopSearchResults(Base):
db = db
__tablename__ = "user_last_stop_search_results"
user_mxid = Column(String, primary_key=True)
request_content = Column(String, nullable=False)
stops: Mapped[list[_Stop]] = relationship(
_Stop, secondary=user_last_stop_search_stops_associations_table
)

View File

@@ -0,0 +1,3 @@
from .line import Line, TransportMode
from .next_passage import NextPassage, NextPassages
from .stop import Stop, StopArea

View File

@@ -0,0 +1,60 @@
from enum import StrEnum
from typing import Self
from pydantic import BaseModel
from ..idfm_interface import (
IdfmLineState,
IdfmState,
TransportMode as IdfmTransportMode,
TransportSubMode as IdfmTransportSubMode,
)
class TransportMode(StrEnum):
"""Computed transport mode from
idfm_interface.TransportMode and idfm_interface.TransportSubMode.
"""
bus = "bus"
tram = "tram"
metro = "metro"
funicular = "funicular"
# idfm_types.TransportMode.rail + idfm_types.TransportSubMode.regionalRail
rail_ter = "ter"
# idfm_types.TransportMode.rail + idfm_types.TransportSubMode.local
rail_rer = "rer"
# idfm_types.TransportMode.rail + idfm_types.TransportSubMode.suburbanRailway
rail_transilien = "transilien"
# idfm_types.TransportMode.rail + idfm_types.TransportSubMode.railShuttle
val = "val"
@classmethod
def from_idfm_transport_mode(
cls, mode: IdfmTransportMode, sub_mode: IdfmTransportSubMode
) -> Self:
if mode == IdfmTransportMode.rail:
if sub_mode == IdfmTransportSubMode.regionalRail:
return cls.rail_ter
if sub_mode == IdfmTransportSubMode.local:
return cls.rail_rer
if sub_mode == IdfmTransportSubMode.suburbanRailway:
return cls.rail_transilien
if sub_mode == IdfmTransportSubMode.railShuttle:
return cls.val
return TransportMode(mode)
class Line(BaseModel):
id: str
shortName: str
name: str
status: IdfmLineState
transportMode: TransportMode
backColorHexa: str
foreColorHexa: str
operatorId: str
accessibility: IdfmState
visualSignsAvailable: IdfmState
audibleSignsAvailable: IdfmState
stopIds: list[str]

View File

@@ -0,0 +1,22 @@
from pydantic import BaseModel
from ..idfm_interface.idfm_types import TrainStatus
class NextPassage(BaseModel):
line: str
operator: str
destinations: list[str]
atStop: bool
aimedArrivalTs: None | int
expectedArrivalTs: None | int
arrivalPlatformName: None | str
aimedDepartTs: None | int
expectedDepartTs: None | int
arrivalStatus: TrainStatus
departStatus: TrainStatus
class NextPassages(BaseModel):
ts: int
passages: dict[str, dict[str, list[NextPassage]]]

View File

@@ -0,0 +1,25 @@
from pydantic import BaseModel
from ..idfm_interface import IdfmLineState, IdfmState, StopAreaType, TransportMode
class Stop(BaseModel):
id: int
name: str
town: str
lat: float
lon: float
# xepsg2154: int
# yepsg2154: int
lines: list[str]
class StopArea(BaseModel):
id: int
name: str
town: str
# xepsg2154: int
# yepsg2154: int
type: StopAreaType
lines: list[str] # SNCF lines are linked to stop areas and not stops.
stops: list[Stop]