🎉 First commit !!!

This commit is contained in:
2023-01-22 16:53:45 +01:00
commit dde835760a
68 changed files with 3250 additions and 0 deletions

1
backend/.gitignore vendored Normal file
View File

@@ -0,0 +1 @@
!**/__pycache__/

0
backend/README.md Normal file
View File

View File

@@ -0,0 +1,19 @@
version: '3.7'
services:
database:
image: postgres:15.1-alpine
restart: always
environment:
- POSTGRES_USER=postgres
- POSTGRES_PASSWORD=postgres
logging:
options:
max-size: 10m
max-file: "3"
ports:
- '127.0.0.1:5438:5432'
volumes:
- ./docker/database/docker-entrypoint-initdb.d:/docker-entrypoint-initdb.d
- ./docker/database/data:/var/lib/postgresql/data

View File

@@ -0,0 +1,11 @@
#!/bin/bash
set -e
psql -v ON_ERROR_STOP=1 --username "$POSTGRES_USER" --dbname "$POSTGRES_DB" <<-EOSQL
CREATE USER idfm_matrix_bot;
CREATE DATABASE bot;
CREATE DATABASE idfm;
GRANT ALL PRIVILEGES ON DATABASE bot TO idfm_matrix_bot;
GRANT ALL PRIVILEGES ON DATABASE idfm TO idfm_matrix_bot;
EOSQL

View File

@@ -0,0 +1,4 @@
from .db import Database
from .base_class import Base
db = Database()

View File

@@ -0,0 +1,34 @@
from collections.abc import Iterable
from sqlalchemy import select
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import declarative_base
from typing import Iterable, Self
Base = declarative_base()
Base.db = None
async def base_add(cls, stops: Self | Iterable[Self]) -> bool:
try:
method = (
cls.db.session.add_all
if isinstance(stops, Iterable)
else cls.db.session.add
)
method(stops)
await cls.db.session.commit()
except IntegrityError as err:
print(err)
Base.add = classmethod(base_add)
async def base_get_by_id(cls, id_: int | str) -> None | Base:
res = await cls.db.session.execute(select(cls).where(cls.id == id_))
element = res.scalar_one_or_none()
return element
Base.get_by_id = classmethod(base_get_by_id)

View File

@@ -0,0 +1,80 @@
from asyncio import gather as asyncio_gather
from functools import wraps
from pathlib import Path
from time import time
from typing import Callable, Iterable, Optional
from rich import print
from sqlalchemy import event, select, tuple_
from sqlalchemy.engine import Engine
from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
from sqlalchemy.orm import (
selectinload,
sessionmaker,
with_polymorphic,
)
from sqlalchemy.orm.attributes import set_committed_value
from .base_class import Base
# import logging
# logging.basicConfig()
# logger = logging.getLogger("bot.sqltime")
# logger.setLevel(logging.DEBUG)
# @event.listens_for(Engine, "before_cursor_execute")
# def before_cursor_execute(conn, cursor, statement, parameters, context, executemany):
# conn.info.setdefault("query_start_time", []).append(time())
# logger.debug("Start Query: %s", statement)
# @event.listens_for(Engine, "after_cursor_execute")
# def after_cursor_execute(conn, cursor, statement, parameters, context, executemany):
# total = time() - conn.info["query_start_time"].pop(-1)
# logger.debug("Query Complete!")
# logger.debug("Total Time: %f", total)
class Database:
def __init__(self) -> None:
self._engine = None
self._session_maker = None
self._session = None
@property
def session(self) -> None:
if self._session is None:
self._session = self._session_maker()
return self._session
def use_session(func: Callable):
@wraps(func)
async def wrapper(self, *args, **kwargs):
if self._check_session() is not None:
return await func(self, *args, **kwargs)
# TODO: Raise an exception ?
return wrapper
async def connect(self, db_path: str, clear_static_data: bool = False) -> None:
# TODO: Preserve UserLastStopSearchResults table from drop.
self._engine = create_async_engine(db_path)
self._session_maker = sessionmaker(
self._engine, expire_on_commit=False, class_=AsyncSession
)
await self.session.execute("CREATE EXTENSION IF NOT EXISTS pg_trgm;")
async with self._engine.begin() as conn:
if clear_static_data:
await conn.run_sync(Base.metadata.drop_all)
await conn.run_sync(Base.metadata.create_all)
async def disconnect(self) -> None:
if self._session is not None:
await self._session.close()
self._session = None
await self._engine.dispose()

View File

@@ -0,0 +1,2 @@
from .idfm_interface import IdfmInterface
from .idfm_types import *

View File

@@ -0,0 +1,447 @@
from pathlib import Path
from re import compile as re_compile
from time import time
from typing import ByteString, Iterable, List, Optional
from aiofiles import open as async_open
from aiohttp import ClientSession
from msgspec import ValidationError
from msgspec.json import Decoder
from rich import print
from ..db import Database
from ..models import Line, LinePicto, Stop, StopArea
from .idfm_types import (
IdfmLineState,
IdfmResponse,
Line as IdfmLine,
MonitoredVehicleJourney,
LinePicto as IdfmPicto,
IdfmState,
Stop as IdfmStop,
StopArea as IdfmStopArea,
StopAreaStopAssociation,
StopLineAsso as IdfmStopLineAsso,
Stops,
)
from .ratp_types import Picto as RatpPicto
class IdfmInterface:
IDFM_ROOT_URL = "https://prim.iledefrance-mobilites.fr/marketplace"
IDFM_STOP_MON_URL = f"{IDFM_ROOT_URL}/stop-monitoring"
IDFM_ROOT_URL = "https://data.iledefrance-mobilites.fr/explore/dataset"
IDFM_STOPS_URL = (
f"{IDFM_ROOT_URL}/arrets/download/?format=json&timezone=Europe/Berlin"
)
IDFM_PICTO_URL = f"{IDFM_ROOT_URL}/referentiel-des-lignes/files"
RATP_ROOT_URL = "https://data.ratp.fr/explore/dataset"
RATP_PICTO_URL = f"{RATP_ROOT_URL}/pictogrammes-des-lignes-de-metro-rer-tramway-bus-et-noctilien/files"
OPERATOR_RE = re_compile(r"[^:]+:Operator::([^:]+):")
LINE_RE = re_compile(r"[^:]+:Line::([^:]+):")
def __init__(self, api_key: str, database: Database) -> None:
self._api_key = api_key
self._database = database
self._http_headers = {"Accept": "application/json", "apikey": self._api_key}
self._json_stops_decoder = Decoder(type=List[IdfmStop])
self._json_stop_areas_decoder = Decoder(type=List[IdfmStopArea])
self._json_lines_decoder = Decoder(type=List[IdfmLine])
self._json_stops_lines_assos_decoder = Decoder(type=List[IdfmStopLineAsso])
self._json_ratp_pictos_decoder = Decoder(type=List[RatpPicto])
self._json_stop_area_stop_asso_decoder = Decoder(
type=List[StopAreaStopAssociation]
)
self._response_json_decoder = Decoder(type=IdfmResponse)
async def startup(self) -> None:
BATCH_SIZE = 10000
STEPS = (
(
StopArea,
self._request_idfm_stop_areas,
IdfmInterface._format_idfm_stop_areas,
),
(Stop, self._request_idfm_stops, IdfmInterface._format_idfm_stops),
)
for model, get_method, format_method in STEPS:
step_begin_ts = time()
elements = []
async for element in get_method():
elements.append(element)
if len(elements) == BATCH_SIZE:
await model.add(format_method(*elements))
elements.clear()
if elements:
await model.add(format_method(*elements))
print(f"Add {model.__name__}s: {time() - step_begin_ts}s")
begin_ts = time()
await self._load_lines()
print(f"Add Lines and IDFM LinePictos: {time() - begin_ts}s")
begin_ts = time()
await self._load_ratp_pictos(30)
print(f"Add RATP LinePictos: {time() - begin_ts}s")
begin_ts = time()
await self._load_lines_stops_assos()
print(f"Link Stops to Lines: {time() - begin_ts}s")
begin_ts = time()
await self._load_stop_areas_stops_assos()
print(f"Link Stops to StopAreas: {time() - begin_ts}s")
async def _load_lines(self, batch_size: int = 5000) -> None:
lines, pictos = [], []
picto_ids = set()
async for line in self._request_idfm_lines():
if (picto := line.fields.picto) is not None and picto.id_ not in picto_ids:
picto_ids.add(picto.id_)
pictos.append(picto)
lines.append(line)
if len(lines) == batch_size:
await LinePicto.add(IdfmInterface._format_idfm_pictos(*pictos))
await Line.add(await self._format_idfm_lines(*lines))
lines.clear()
pictos.clear()
if pictos:
await LinePicto.add(IdfmInterface._format_idfm_pictos(*pictos))
if lines:
await Line.add(await self._format_idfm_lines(*lines))
async def _load_ratp_pictos(self, batch_size: int = 5) -> None:
pictos = []
async for picto in self._request_ratp_pictos():
pictos.append(picto)
if len(pictos) == batch_size:
formatted_pictos = IdfmInterface._format_ratp_pictos(*pictos)
await LinePicto.add(formatted_pictos.values())
await Line.add_pictos(formatted_pictos)
pictos.clear()
if pictos:
formatted_pictos = IdfmInterface._format_ratp_pictos(*pictos)
await LinePicto.add(formatted_pictos.values())
await Line.add_pictos(formatted_pictos)
async def _load_lines_stops_assos(self, batch_size: int = 5000) -> None:
total_assos_nb = total_found_nb = 0
assos = []
async for asso in self._request_idfm_stops_lines_associations():
fields = asso.fields
try:
stop_id = int(fields.stop_id.rsplit(":", 1)[-1])
except ValueError as err:
print(err)
print(f"{fields.stop_id = }")
continue
assos.append((fields.route_long_name, fields.operatorname, stop_id))
if len(assos) == batch_size:
total_assos_nb += batch_size
total_found_nb += await Line.add_stops(assos)
assos.clear()
if assos:
total_assos_nb += len(assos)
total_found_nb += await Line.add_stops(assos)
print(f"{total_found_nb} line <-> stop ({total_assos_nb = } found)")
async def _load_stop_areas_stops_assos(self, batch_size: int = 5000) -> None:
total_assos_nb = total_found_nb = 0
assos = []
async for asso in self._request_idfm_stop_area_stop_associations():
fields = asso.fields
assos.append((int(fields.zdaid), int(fields.arrid)))
if len(assos) == batch_size:
total_assos_nb += batch_size
total_found_nb += await StopArea.add_stops(assos)
assos.clear()
if assos:
total_assos_nb += len(assos)
total_found_nb += await StopArea.add_stops(assos)
print(f"{total_found_nb} stop area <-> stop ({total_assos_nb = } found)")
async def _request_idfm_stops(self):
# headers = {"Accept": "application/json", "apikey": self._api_key}
# async with ClientSession(headers=headers) as session:
# async with session.get(self.STOPS_URL) as response:
# # print("Status:", response.status)
# if response.status == 200:
# for point in self._json_stops_decoder.decode(await response.read()):
# yield point
# TODO: Use HTTP
async with async_open("./tests/datasets/stops_dataset.json", "rb") as raw:
for element in self._json_stops_decoder.decode(await raw.read()):
yield element
async def _request_idfm_stop_areas(self):
# TODO: Use HTTP
async with async_open("./tests/datasets/zones-d-arrets.json", "rb") as raw:
for element in self._json_stop_areas_decoder.decode(await raw.read()):
yield element
async def _request_idfm_lines(self):
# TODO: Use HTTP
async with async_open("./tests/datasets/lines_dataset.json", "rb") as raw:
for element in self._json_lines_decoder.decode(await raw.read()):
yield element
async def _request_idfm_stops_lines_associations(self):
# TODO: Use HTTP
async with async_open("./tests/datasets/arrets-lignes.json", "rb") as raw:
for element in self._json_stops_lines_assos_decoder.decode(
await raw.read()
):
yield element
async def _request_idfm_stop_area_stop_associations(self):
# TODO: Use HTTP
async with async_open("./tests/datasets/relations.json", "rb") as raw:
for element in self._json_stop_area_stop_asso_decoder.decode(
await raw.read()
):
yield element
async def _request_ratp_pictos(self):
# TODO: Use HTTP
async with async_open(
"./tests/datasets/pictogrammes-des-lignes-de-metro-rer-tramway-bus-et-noctilien.json",
"rb",
) as fd:
for element in self._json_ratp_pictos_decoder.decode(await fd.read()):
yield element
@classmethod
def _format_idfm_pictos(cls, *pictos: IdfmPicto) -> Iterable[LinePicto]:
ret = []
for picto in pictos:
ret.append(
LinePicto(
id=picto.id_,
mime_type=picto.mimetype,
height_px=picto.height,
width_px=picto.width,
filename=picto.filename,
url=f"{cls.IDFM_PICTO_URL}/{picto.id_}/download",
thumbnail=picto.thumbnail,
format=picto.format,
)
)
return ret
@classmethod
def _format_ratp_pictos(cls, *pictos: RatpPicto) -> dict[str, None | LinePicto]:
ret = {}
for picto in pictos:
if (fields := picto.fields.noms_des_fichiers) is not None:
ret[picto.fields.indices_commerciaux] = LinePicto(
id=fields.id_,
mime_type=f"image/{fields.format.lower()}",
height_px=fields.height,
width_px=fields.width,
filename=fields.filename,
url=f"{cls.RATP_PICTO_URL}/{fields.id_}/download",
thumbnail=fields.thumbnail,
format=fields.format,
)
return ret
async def _format_idfm_lines(self, *lines: IdfmLine) -> Iterable[Line]:
ret = []
optional_value = IdfmLine.optional_value
for line in lines:
fields = line.fields
picto_id = fields.picto.id_ if fields.picto is not None else None
picto = await LinePicto.get_by_id(picto_id) if picto_id else None
ret.append(
Line(
id=fields.id_line,
short_name=fields.shortname_line,
name=fields.name_line,
status=IdfmLineState(fields.status.value),
transport_mode=fields.transportmode.value,
transport_submode=optional_value(fields.transportsubmode),
network_name=optional_value(fields.networkname),
group_of_lines_id=optional_value(fields.id_groupoflines),
group_of_lines_shortname=optional_value(
fields.shortname_groupoflines
),
colour_web_hexa=fields.colourweb_hexa,
text_colour_hexa=fields.textcolourprint_hexa,
operator_id=optional_value(fields.operatorref),
operator_name=optional_value(fields.operatorname),
accessibility=fields.accessibility.value,
visual_signs_available=fields.visualsigns_available.value,
audible_signs_available=fields.audiblesigns_available.value,
picto_id=fields.picto.id_ if fields.picto is not None else None,
picto=picto,
record_id=line.recordid,
record_ts=int(line.record_timestamp.timestamp()),
)
)
return ret
@staticmethod
def _format_idfm_stops(*stops: IdfmStop) -> Iterable[Stop]:
for stop in stops:
fields = stop.fields
try:
created_ts = int(fields.arrcreated.timestamp())
except AttributeError:
created_ts = None
yield Stop(
id=int(fields.arrid),
name=fields.arrname,
latitude=fields.arrgeopoint.lat,
longitude=fields.arrgeopoint.lon,
town_name=fields.arrtown,
postal_region=fields.arrpostalregion,
xepsg2154=fields.arrxepsg2154,
yepsg2154=fields.arryepsg2154,
transport_mode=fields.arrtype.value,
version=fields.arrversion,
created_ts=created_ts,
changed_ts=int(fields.arrchanged.timestamp()),
accessibility=fields.arraccessibility.value,
visual_signs_available=fields.arrvisualsigns.value,
audible_signs_available=fields.arraudiblesignals.value,
record_id=stop.recordid,
record_ts=int(stop.record_timestamp.timestamp()),
)
@staticmethod
def _format_idfm_stop_areas(*stop_areas: IdfmStopArea) -> Iterable[StopArea]:
for stop_area in stop_areas:
fields = stop_area.fields
try:
created_ts = int(fields.arrcreated.timestamp())
except AttributeError:
created_ts = None
yield StopArea(
id=int(fields.zdaid),
name=fields.zdaname,
town_name=fields.zdatown,
postal_region=fields.zdapostalregion,
xepsg2154=fields.zdaxepsg2154,
yepsg2154=fields.zdayepsg2154,
type=fields.zdatype.value,
version=fields.zdaversion,
created_ts=created_ts,
changed_ts=int(fields.zdachanged.timestamp()),
)
async def render_line_picto(self, line: Line) -> tuple[None | str, None | str]:
begin_ts = time()
line_picto_path = line_picto_format = None
target = f"/tmp/{line.id}_repr"
picto = line.picto
if picto is not None:
picto_data = await self._get_line_picto(line)
async with async_open(target, "wb") as fd:
await fd.write(picto_data)
line_picto_path = target
line_picto_format = picto.mime_type
print(f"render_line_picto: {time() - begin_ts}")
return (line_picto_path, line_picto_format)
async def _get_line_picto(self, line: Line) -> Optional[ByteString]:
print("---------------------------------------------------------------------")
begin_ts = time()
data = None
picto = line.picto
if picto is not None:
headers = (
self._http_headers if picto.url.startswith(self.IDFM_ROOT_URL) else None
)
session_begin_ts = time()
async with ClientSession(headers=headers) as session:
session_creation_ts = time()
print(f"Session creation {session_creation_ts - session_begin_ts}")
async with session.get(picto.url) as response:
get_end_ts = time()
print(f"GET {get_end_ts - session_creation_ts}")
data = await response.read()
print(f"read {time() - get_end_ts}")
print(f"render_line_picto: {time() - begin_ts}")
print("---------------------------------------------------------------------")
return data
async def get_next_passages(self, stop_point_id: str) -> Optional[IdfmResponse]:
# print("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
begin_ts = time()
ret = None
params = {"MonitoringRef": f"STIF:StopPoint:Q:{stop_point_id}:"}
session_begin_ts = time()
async with ClientSession(headers=self._http_headers) as session:
session_creation_ts = time()
# print(f"Session creation {session_creation_ts - session_begin_ts}")
async with session.get(self.IDFM_STOP_MON_URL, params=params) as response:
get_end_ts = time()
# print(f"GET {get_end_ts - session_creation_ts}")
if response.status == 200:
get_end_ts = time()
# print(f"GET {get_end_ts - session_creation_ts}")
data = await response.read()
# print(data)
try:
ret = self._response_json_decoder.decode(data)
except ValidationError as err:
print(err)
# print(f"read {time() - get_end_ts}")
# print(f"get_next_passages: {time() - begin_ts}")
# print("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
return ret
async def get_destinations(self, stop_point_id: str) -> Iterable[str]:
# TODO: Store in database the destination for the given stop and line id.
begin_ts = time()
destinations: dict[str, str] = {}
if (res := await self.get_next_passages(stop_point_id)) is not None:
for delivery in res.Siri.ServiceDelivery.StopMonitoringDelivery:
if delivery.Status == IdfmState.true:
for stop_visit in delivery.MonitoredStopVisit:
journey = stop_visit.MonitoredVehicleJourney
if (destination_name := journey.DestinationName) and (
line_ref := journey.LineRef
):
line_id = line_ref.value.replace("STIF:Line::", "")[:-1]
print(f"{line_id = }")
destinations[line_id] = destination_name[0].value
print(f"get_next_passages: {time() - begin_ts}")
return destinations

View File

@@ -0,0 +1,277 @@
from __future__ import annotations
from datetime import datetime
from enum import Enum, StrEnum
from typing import Any, Literal, Optional, NamedTuple
from msgspec import Struct
class Coordinate(NamedTuple):
lat: float
lon: float
class IdfmState(Enum):
unknown = "unknown"
false = "false"
partial = "partial"
true = "true"
class TrainStatus(Enum):
unknown = ""
arrived = "arrived"
onTime = "onTime"
delayed = "delayed"
noReport = "noReport"
early = "early"
cancelled = "cancelled"
undefined = "undefined"
class TransportMode(StrEnum):
bus = "bus"
tram = "tram"
metro = "metro"
rail = "rail"
funicular = "funicular"
class TransportSubMode(Enum):
unknown = "unknown"
localBus = "localBus"
regionalBus = "regionalBus"
highFrequencyBus = "highFrequencyBus"
expressBus = "expressBus"
nightBus = "nightBus"
demandAndResponseBus = "demandAndResponseBus"
airportLinkBus = "airportLinkBus"
regionalRail = "regionalRail"
railShuttle = "railShuttle"
suburbanRailway = "suburbanRailway"
local = "local"
class StopFields(Struct, kw_only=True):
arrgeopoint: Coordinate
arrtown: str
arrcreated: None | datetime = None
arryepsg2154: int
arrpostalregion: str
arrid: str
arrxepsg2154: int
arraccessibility: IdfmState
arrvisualsigns: IdfmState
arrtype: TransportMode
arrname: str
arrversion: str
arrchanged: datetime
arraudiblesignals: IdfmState
class Point(Struct):
coordinates: Coordinate
class Stop(Struct):
datasetid: str
recordid: str
fields: StopFields
record_timestamp: datetime
# geometry: Union[Point]
Stops = dict[str, Stop]
class StopAreaType(Enum):
metroStation = "metroStation"
onstreetBus = "onstreetBus"
onstreetTram = "onstreetTram"
railStation = "railStation"
class StopAreaFields(Struct, kw_only=True):
zdaname: str
zdcid: str
zdatown: str
zdaversion: str
zdaid: str
zdacreated: Optional[datetime] = None
zdatype: StopAreaType
zdayepsg2154: int
zdapostalregion: str
zdachanged: datetime
zdaxepsg2154: int
class StopArea(Struct):
datasetid: str
recordid: str
fields: StopAreaFields
record_timestamp: datetime
class StopAreaStopAssociationFields(Struct, kw_only=True):
arrid: str # TODO: use int ?
artid: Optional[str] = None
arrversion: str
zdcid: str
version: int
zdaid: str
zdaversion: str
artversion: Optional[str] = None
class StopAreaStopAssociation(Struct):
datasetid: str
recordid: str
fields: StopAreaStopAssociationFields
record_timestamp: datetime
class IdfmLineState(Enum):
active = "active"
class LinePicto(Struct, rename={"id_": "id"}):
id_: str
mimetype: str
height: int
width: int
filename: str
thumbnail: bool
format: str
# color_summary: list[str]
class LineFields(Struct, kw_only=True):
name_line: str
status: IdfmLineState
accessibility: IdfmState
shortname_groupoflines: Optional[str] = None
transportmode: TransportMode
colourweb_hexa: str
textcolourprint_hexa: str
transportsubmode: Optional[TransportSubMode] = TransportSubMode.unknown
operatorref: Optional[str] = None
visualsigns_available: IdfmState
networkname: Optional[str] = None
id_line: str
id_groupoflines: Optional[str] = None
operatorname: Optional[str] = None
audiblesigns_available: IdfmState
shortname_line: str
picto: Optional[LinePicto] = None
class Line(Struct):
datasetid: str
recordid: str
fields: LineFields
record_timestamp: datetime
@staticmethod
def optional_value(value: Any) -> Any:
if value:
return value.value if isinstance(value, Enum) else value
return "NULL"
Lines = dict[str, Line]
# TODO: Set structs frozen
class StopLineAssoFields(Struct):
pointgeo: Coordinate
stop_id: str
stop_name: str
operatorname: str
nom_commune: str
route_long_name: str
id: str
stop_lat: str
stop_lon: str
code_insee: str
class StopLineAsso(Struct):
datasetid: str
recordid: str
fields: StopLineAssoFields
# geometry: Union[Point]
class Value(Struct):
value: str
class FramedVehicleJourney(Struct):
DataFrameRef: Value
DatedVehicleJourneyRef: str
class TrainNumber(Struct):
TrainNumberRef: list[Value]
class MonitoredCall(Struct, kw_only=True):
Order: Optional[int] = None
StopPointName: list[Value]
VehicleAtStop: bool
DestinationDisplay: list[Value]
AimedArrivalTime: Optional[datetime] = None
ExpectedArrivalTime: Optional[datetime] = None
ArrivalPlatformName: Optional[Value] = None
AimedDepartureTime: Optional[datetime] = None
ExpectedDepartureTime: Optional[datetime] = None
ArrivalStatus: TrainStatus = None
DepartureStatus: TrainStatus = None
class MonitoredVehicleJourney(Struct, kw_only=True):
LineRef: Value
OperatorRef: Value
FramedVehicleJourneyRef: FramedVehicleJourney
DestinationRef: Value
DestinationName: list[Value] | None = None
JourneyNote: list[Value] | None = None
TrainNumbers: Optional[TrainNumber] = None
MonitoredCall: MonitoredCall
class StopDelivery(Struct):
RecordedAtTime: datetime
ItemIdentifier: str
MonitoringRef: Value
MonitoredVehicleJourney: MonitoredVehicleJourney
class StopMonitoringDelivery(Struct):
ResponseTimestamp: datetime
Version: str
Status: IdfmState
MonitoredStopVisit: list[StopDelivery]
class ServiceDelivery(Struct):
ResponseTimestamp: datetime
ProducerRef: str
ResponseMessageIdentifier: str
StopMonitoringDelivery: list[StopMonitoringDelivery]
class Siri(Struct):
ServiceDelivery: ServiceDelivery
class IdfmOperator(Enum):
SNCF = "SNCF"
class IdfmResponse(Struct):
Siri: Siri

View File

@@ -0,0 +1,25 @@
from datetime import datetime
from typing import Optional
from msgspec import Struct
class PictoFieldsFile(Struct, rename={"id_": "id"}):
id_: str
height: int
width: int
filename: str
thumbnail: bool
format: str
class PictoFields(Struct):
indices_commerciaux: str
noms_des_fichiers: Optional[PictoFieldsFile] = None
class Picto(Struct):
datasetid: str
recordid: str
fields: PictoFields
record_timestamp: datetime

View File

@@ -0,0 +1,3 @@
from .line import Line, LinePicto
from .stop import Stop, StopArea
from .user import UserLastStopSearchResults

View File

@@ -0,0 +1,176 @@
from asyncio import gather as asyncio_gather
from collections import defaultdict
from typing import Iterable, Self
from sqlalchemy import (
BigInteger,
Boolean,
Column,
Enum,
ForeignKey,
Integer,
select,
String,
Table,
)
from sqlalchemy.orm import Mapped, relationship, selectinload
from sqlalchemy.orm.attributes import set_committed_value
from sqlalchemy.sql.expression import tuple_
from ..db import Base, db
from ..idfm_interface.idfm_types import (
IdfmState,
IdfmLineState,
TransportMode,
TransportSubMode,
)
from .stop import _Stop
line_stop_association_table = Table(
"line_stop_association_table",
Base.metadata,
Column("line_id", ForeignKey("lines.id")),
Column("stop_id", ForeignKey("_stops.id")),
)
class LinePicto(Base):
db = db
id = Column(String, primary_key=True)
mime_type = Column(String, nullable=False)
height_px = Column(Integer, nullable=False)
width_px = Column(Integer, nullable=False)
filename = Column(String, nullable=False)
url = Column(String, nullable=False)
thumbnail = Column(Boolean, nullable=False)
format = Column(String, nullable=False)
__tablename__ = "line_pictos"
class Line(Base):
db = db
id = Column(String, primary_key=True)
short_name = Column(String)
name = Column(String, nullable=False)
status = Column(Enum(IdfmLineState), nullable=False)
transport_mode = Column(Enum(TransportMode), nullable=False)
transport_submode = Column(Enum(TransportSubMode), nullable=False)
network_name = Column(String)
group_of_lines_id = Column(String)
group_of_lines_shortname = Column(String)
colour_web_hexa = Column(String, nullable=False)
text_colour_hexa = Column(String, nullable=False)
operator_id = Column(String)
operator_name = Column(String)
accessibility = Column(Enum(IdfmState), nullable=False)
visual_signs_available = Column(Enum(IdfmState), nullable=False)
audible_signs_available = Column(Enum(IdfmState), nullable=False)
picto_id = Column(String, ForeignKey("line_pictos.id"))
picto: Mapped[LinePicto] = relationship(LinePicto, lazy="selectin")
record_id = Column(String, nullable=False)
record_ts = Column(BigInteger, nullable=False)
stops: Mapped[list["_Stop"]] = relationship(
"_Stop",
secondary=line_stop_association_table,
back_populates="lines",
lazy="selectin",
)
__tablename__ = "lines"
@classmethod
async def get_by_name(
cls, name: str, operator_name: None | str = None
) -> list[Self]:
filters = {"name": name}
if operator_name is not None:
filters["operator_name"] = operator_name
lines = None
stmt = (
select(Line)
.filter_by(**filters)
.options(selectinload(Line.stops), selectinload(Line.picto))
)
res = await cls.db.session.execute(stmt)
lines = res.scalars().all()
return lines
@classmethod
async def _add_picto_to_line(cls, line: str | Self, picto: LinePicto) -> None:
if isinstance(line, str):
if (lines := await cls.get_by_name(line)) is not None:
if len(lines) == 1:
line = lines[0]
else:
for candidate_line in lines:
if candidate_line.operator_name == "RATP":
line = candidate_line
break
if isinstance(line, Line) and line.picto is None:
line.picto = picto
line.picto_id = picto.id
@classmethod
async def add_pictos(cls, line_to_pictos: dict[str | Self, LinePicto]) -> None:
await asyncio_gather(
*[
cls._add_picto_to_line(line, picto)
for line, picto in line_to_pictos.items()
]
)
await cls.db.session.commit()
@classmethod
async def add_stops(cls, line_to_stop_ids: Iterable[tuple[str, str, str]]) -> int:
line_names_ops, stop_ids = set(), set()
for line_name, operator_name, stop_id in line_to_stop_ids:
line_names_ops.add((line_name, operator_name))
stop_ids.add(stop_id)
res = await cls.db.session.execute(
select(Line).where(
tuple_(Line.name, Line.operator_name).in_(line_names_ops)
)
)
lines = defaultdict(list)
for line in res.scalars():
lines[(line.name, line.operator_name)].append(line)
res = await cls.db.session.execute(select(_Stop).where(_Stop.id.in_(stop_ids)))
stops = {stop.id: stop for stop in res.scalars()}
found = 0
for line_name, operator_name, stop_id in line_to_stop_ids:
if (stop := stops.get(stop_id)) is not None:
if (stop_lines := lines.get((line_name, operator_name))) is not None:
if len(stop_lines) > 1:
print(stop_lines)
for stop_line in stop_lines:
stop_line.stops.append(stop)
found += 1
else:
print(f"No line found for {line_name}/{operator_name}")
else:
print(
f"No stop found for {stop_id} id (used by {line_name}/{operator_name})"
)
await cls.db.session.commit()
return found

View File

@@ -0,0 +1,144 @@
from typing import Iterable, Self
from sqlalchemy import (
BigInteger,
Column,
Enum,
Float,
ForeignKey,
select,
String,
Table,
)
from sqlalchemy.orm import Mapped, relationship, selectinload, with_polymorphic
from sqlalchemy.schema import Index
from ..db import Base, db
from ..idfm_interface.idfm_types import TransportMode, IdfmState, StopAreaType
stop_area_stop_association_table = Table(
"stop_area_stop_association_table",
Base.metadata,
Column("stop_id", ForeignKey("_stops.id")),
Column("stop_area_id", ForeignKey("stop_areas.id")),
)
class _Stop(Base):
db = db
id = Column(BigInteger, primary_key=True)
kind = Column(String)
name = Column(String, nullable=False, index=True)
town_name = Column(String, nullable=False)
postal_region = Column(String, nullable=False)
xepsg2154 = Column(BigInteger, nullable=False)
yepsg2154 = Column(BigInteger, nullable=False)
version = Column(String, nullable=False)
created_ts = Column(BigInteger)
changed_ts = Column(BigInteger, nullable=False)
lines: Mapped[list["Line"]] = relationship(
"Line",
secondary="line_stop_association_table",
back_populates="stops",
# lazy="joined",
lazy="selectin",
)
areas: Mapped[list["StopArea"]] = relationship(
"StopArea", secondary=stop_area_stop_association_table, back_populates="stops"
)
__tablename__ = "_stops"
__mapper_args__ = {"polymorphic_identity": "_stops", "polymorphic_on": kind}
__table_args__ = (
# To optimize the ilike requests
Index(
"name_idx_gin",
name,
postgresql_ops={"name": "gin_trgm_ops"},
postgresql_using="gin",
),
)
# TODO: Test https://www.cybertec-postgresql.com/en/postgresql-more-performance-for-like-and-ilike-statements/
# TODO: Should be able to remove with_polymorphic ?
@classmethod
async def get_by_name(cls, name: str) -> list[Self]:
stop_stop_area = with_polymorphic(_Stop, [Stop, StopArea])
stmt = (
select(stop_stop_area)
.where(stop_stop_area.name.ilike(f"%{name}%"))
.options(
selectinload(stop_stop_area.areas),
selectinload(stop_stop_area.lines),
)
)
res = await cls.db.session.execute(stmt)
return res.scalars()
class Stop(_Stop):
id = Column(BigInteger, ForeignKey("_stops.id"), primary_key=True)
latitude = Column(Float, nullable=False)
longitude = Column(Float, nullable=False)
transport_mode = Column(Enum(TransportMode), nullable=False)
accessibility = Column(Enum(IdfmState), nullable=False)
visual_signs_available = Column(Enum(IdfmState), nullable=False)
audible_signs_available = Column(Enum(IdfmState), nullable=False)
record_id = Column(String, nullable=False)
record_ts = Column(BigInteger, nullable=False)
__tablename__ = "stops"
__mapper_args__ = {"polymorphic_identity": "stops", "polymorphic_load": "inline"}
class StopArea(_Stop):
id = Column(BigInteger, ForeignKey("_stops.id"), primary_key=True)
type = Column(Enum(StopAreaType), nullable=False)
stops: Mapped[list[_Stop]] = relationship(
_Stop,
secondary=stop_area_stop_association_table,
back_populates="areas",
lazy="selectin",
# lazy="joined",
)
__tablename__ = "stop_areas"
__mapper_args__ = {"polymorphic_identity": "stop_areas", "polymorphic_load": "inline"}
@classmethod
async def add_stops(cls, stop_area_to_stop_ids: Iterable[tuple[str, str]]) -> int:
stop_area_ids, stop_ids = set(), set()
for stop_area_id, stop_id in stop_area_to_stop_ids:
stop_area_ids.add(stop_area_id)
stop_ids.add(stop_id)
res = await cls.db.session.execute(
select(StopArea)
.where(StopArea.id.in_(stop_area_ids))
.options(selectinload(StopArea.stops))
)
stop_areas = {stop_area.id: stop_area for stop_area in res.scalars()}
res = await cls.db.session.execute(select(_Stop).where(_Stop.id.in_(stop_ids)))
stops = {stop.id: stop for stop in res.scalars()}
found = 0
for stop_area_id, stop_id in stop_area_to_stop_ids:
if (stop_area := stop_areas.get(stop_area_id)) is not None:
if (stop := stops.get(stop_id)) is not None:
stop_area.stops.append(stop)
found += 1
else:
print(f"No stop found for {stop_id} id")
else:
print(f"No stop area found for {stop_area_id}")
await cls.db.session.commit()
return found

View File

@@ -0,0 +1,25 @@
from sqlalchemy import Column, ForeignKey, String, Table
from sqlalchemy.orm import Mapped, relationship
from ..db import Base, db
from .stop import _Stop
user_last_stop_search_stops_associations_table = Table(
"user_last_stop_search_stops_associations_table",
Base.metadata,
Column("user_mxid", ForeignKey("user_last_stop_search_results.user_mxid")),
Column("stop_id", ForeignKey("_stops.id")),
)
class UserLastStopSearchResults(Base):
db = db
__tablename__ = "user_last_stop_search_results"
user_mxid = Column(String, primary_key=True)
request_content = Column(String, nullable=False)
stops: Mapped[list[_Stop]] = relationship(
_Stop, secondary=user_last_stop_search_stops_associations_table
)

View File

@@ -0,0 +1,3 @@
from .line import Line, TransportMode
from .next_passage import NextPassage, NextPassages
from .stop import Stop, StopArea

View File

@@ -0,0 +1,60 @@
from enum import StrEnum
from typing import Self
from pydantic import BaseModel
from ..idfm_interface import (
IdfmLineState,
IdfmState,
TransportMode as IdfmTransportMode,
TransportSubMode as IdfmTransportSubMode,
)
class TransportMode(StrEnum):
"""Computed transport mode from
idfm_interface.TransportMode and idfm_interface.TransportSubMode.
"""
bus = "bus"
tram = "tram"
metro = "metro"
funicular = "funicular"
# idfm_types.TransportMode.rail + idfm_types.TransportSubMode.regionalRail
rail_ter = "ter"
# idfm_types.TransportMode.rail + idfm_types.TransportSubMode.local
rail_rer = "rer"
# idfm_types.TransportMode.rail + idfm_types.TransportSubMode.suburbanRailway
rail_transilien = "transilien"
# idfm_types.TransportMode.rail + idfm_types.TransportSubMode.railShuttle
val = "val"
@classmethod
def from_idfm_transport_mode(
cls, mode: IdfmTransportMode, sub_mode: IdfmTransportSubMode
) -> Self:
if mode == IdfmTransportMode.rail:
if sub_mode == IdfmTransportSubMode.regionalRail:
return cls.rail_ter
if sub_mode == IdfmTransportSubMode.local:
return cls.rail_rer
if sub_mode == IdfmTransportSubMode.suburbanRailway:
return cls.rail_transilien
if sub_mode == IdfmTransportSubMode.railShuttle:
return cls.val
return TransportMode(mode)
class Line(BaseModel):
id: str
shortName: str
name: str
status: IdfmLineState
transportMode: TransportMode
backColorHexa: str
foreColorHexa: str
operatorId: str
accessibility: IdfmState
visualSignsAvailable: IdfmState
audibleSignsAvailable: IdfmState
stopIds: list[str]

View File

@@ -0,0 +1,22 @@
from pydantic import BaseModel
from ..idfm_interface.idfm_types import TrainStatus
class NextPassage(BaseModel):
line: str
operator: str
destinations: list[str]
atStop: bool
aimedArrivalTs: None | int
expectedArrivalTs: None | int
arrivalPlatformName: None | str
aimedDepartTs: None | int
expectedDepartTs: None | int
arrivalStatus: TrainStatus
departStatus: TrainStatus
class NextPassages(BaseModel):
ts: int
passages: dict[str, dict[str, list[NextPassage]]]

View File

@@ -0,0 +1,25 @@
from pydantic import BaseModel
from ..idfm_interface import IdfmLineState, IdfmState, StopAreaType, TransportMode
class Stop(BaseModel):
id: int
name: str
town: str
lat: float
lon: float
# xepsg2154: int
# yepsg2154: int
lines: list[str]
class StopArea(BaseModel):
id: int
name: str
town: str
# xepsg2154: int
# yepsg2154: int
type: StopAreaType
lines: list[str] # SNCF lines are linked to stop areas and not stops.
stops: list[Stop]

208
backend/main.py Normal file
View File

@@ -0,0 +1,208 @@
from collections import defaultdict
from datetime import datetime
from os import environ
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from fastapi.staticfiles import StaticFiles
from rich import print
from idfm_matrix_backend.db import db
from idfm_matrix_backend.idfm_interface import IdfmInterface
from idfm_matrix_backend.models import Line, Stop, StopArea
from idfm_matrix_backend.schemas import (
Line as LineSchema,
TransportMode,
NextPassage as NextPassageSchema,
NextPassages as NextPassagesSchema,
Stop as StopSchema,
StopArea as StopAreaSchema,
)
API_KEY = environ.get("API_KEY")
# TODO: Add error message if no key is given.
# TODO: Remove postgresql+asyncpg from environ variable
DB_PATH = "postgresql+asyncpg://postgres:postgres@127.0.0.1:5438/idfm"
app = FastAPI()
app.add_middleware(
CORSMiddleware,
allow_origins=[
"https://localhost:4443",
"https://localhost:3000",
],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
idfm_interface = IdfmInterface(API_KEY, db)
@app.on_event("startup")
async def startup():
# await db.connect(DB_PATH, clear_static_data=True)
# await idfm_interface.startup()
await db.connect(DB_PATH, clear_static_data=False)
print("Connected")
@app.on_event("shutdown")
async def shutdown():
await db.disconnect()
# /addwidget https://localhost:4443/static/#?widgetId=$matrix_widget_id&userId=$matrix_user_id
# /addwidget https://localhost:3000/widget?widgetId=$matrix_widget_id&userId=$matrix_user_id
STATIC_ROOT = "../frontend/"
app.mount("/widget", StaticFiles(directory=STATIC_ROOT, html=True), name="widget")
def optional_datetime_to_ts(dt: datetime) -> int | None:
return dt.timestamp() if dt else None
@app.get("/line/{line_id}", response_model=LineSchema)
async def get_line(line_id: str) -> JSONResponse:
line: Line | None = await Line.get_by_id(line_id)
if line is None:
raise HTTPException(status_code=404, detail=f'Line "{line_id}" not found')
return LineSchema(
id=line.id,
shortName=line.short_name,
name=line.name,
status=line.status,
transportMode=TransportMode.from_idfm_transport_mode(
line.transport_mode, line.transport_submode
),
backColorHexa=line.colour_web_hexa,
foreColorHexa=line.text_colour_hexa,
operatorId=line.operator_id,
accessibility=line.accessibility,
visualSignsAvailable=line.visual_signs_available,
audibleSignsAvailable=line.audible_signs_available,
stopIds=[stop.id for stop in line.stops],
)
def _format_stop(stop: Stop) -> StopSchema:
print(stop.__dict__)
return StopSchema(
id=stop.id,
name=stop.name,
town=stop.town_name,
# xepsg2154=stop.xepsg2154,
# yepsg2154=stop.yepsg2154,
lat=stop.latitude,
lon=stop.longitude,
lines=[line.id for line in stop.lines],
)
# châtelet
@app.get("/stop/")
async def get_stop(
name: str = "", limit: int = 10
) -> list[StopAreaSchema | StopSchema]:
# TODO: Add limit support
formatted = []
matching_stops = await Stop.get_by_name(name)
# print(matching_stops, flush=True)
stop_areas: dict[int, StopArea] = {}
stops: dict[int, Stop] = {}
for stop in matching_stops:
# print(f"{stop.__dict__ = }", flush=True)
dst = stop_areas if isinstance(stop, StopArea) else stops
dst[stop.id] = stop
for stop_area in stop_areas.values():
formatted_stops = []
for stop in stop_area.stops:
formatted_stops.append(_format_stop(stop))
try:
del stops[stop.id]
except KeyError as err:
print(err)
formatted.append(
StopAreaSchema(
id=stop_area.id,
name=stop_area.name,
town=stop_area.town_name,
# xepsg2154=stop_area.xepsg2154,
# yepsg2154=stop_area.yepsg2154,
type=stop_area.type,
lines=[line.id for line in stop_area.lines],
stops=formatted_stops,
)
)
# print(f"{stops = }", flush=True)
formatted.extend(_format_stop(stop) for stop in stops.values())
return formatted
# TODO: Cache response for 30 secs ?
@app.get("/stop/nextPassages/{stop_id}")
async def get_next_passages(stop_id: str) -> JSONResponse:
res = await idfm_interface.get_next_passages(stop_id)
# print(res)
service_delivery = res.Siri.ServiceDelivery
stop_monitoring_deliveries = service_delivery.StopMonitoringDelivery
by_line_by_dst_passages = defaultdict(lambda: defaultdict(list))
for delivery in stop_monitoring_deliveries:
for stop_visit in delivery.MonitoredStopVisit:
journey = stop_visit.MonitoredVehicleJourney
# re.match will return None if the given journey.LineRef.value is not valid.
try:
line_id = IdfmInterface.LINE_RE.match(journey.LineRef.value).group(1)
except AttributeError as exc:
raise HTTPException(
status_code=404, detail=f'Line "{journey.LineRef.value}" not found'
) from exc
call = journey.MonitoredCall
dst_names = call.DestinationDisplay
dsts = [dst.value for dst in dst_names] if dst_names else []
print(f"{call.ArrivalPlatformName = }")
next_passage = NextPassageSchema(
line=line_id,
operator=journey.OperatorRef.value,
destinations=dsts,
atStop=call.VehicleAtStop,
aimedArrivalTs=optional_datetime_to_ts(call.AimedArrivalTime),
expectedArrivalTs=optional_datetime_to_ts(call.ExpectedArrivalTime),
arrivalPlatformName=call.ArrivalPlatformName.value if call.ArrivalPlatformName else None,
aimedDepartTs=optional_datetime_to_ts(call.AimedDepartureTime),
expectedDepartTs=optional_datetime_to_ts(call.ExpectedDepartureTime),
arrivalStatus=call.ArrivalStatus.value,
departStatus=call.DepartureStatus.value,
)
by_line_passages = by_line_by_dst_passages[line_id]
# TODO: by_line_passages[dst].extend(dsts) instead ?
for dst in dsts:
by_line_passages[dst].append(next_passage)
return NextPassagesSchema(
ts=service_delivery.ResponseTimestamp.timestamp(),
passages=by_line_by_dst_passages,
)

58
backend/pyproject.toml Normal file
View File

@@ -0,0 +1,58 @@
[tool.poetry]
name = "idfm-matrix-widget"
version = "0.1.0"
description = ""
authors = ["Adrien SUEUR <me@adrien.run>"]
readme = "README.md"
packages = [{include = "idfm_matrix_backend"}]
[tool.poetry.dependencies]
python = "^3.11"
aiohttp = "^3.8.3"
rich = "^12.6.0"
aiofiles = "^22.1.0"
sqlalchemy = {extras = ["asyncio"], version = "^1.4.46"}
fastapi = "^0.88.0"
uvicorn = "^0.20.0"
asyncpg = "^0.27.0"
msgspec = "^0.12.0"
[build-system]
requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api"
[tool.poetry.dev-dependencies]
mypy = "^0.971"
pylsp-mypy = "^0.6.2"
autopep8 = "^1.7.0"
mccabe = "^0.7.0"
pycodestyle = "^2.9.1"
pydocstyle = "^6.1.1"
pyflakes = "^2.5.0"
pylint = "^2.14.5"
rope = "^1.3.0"
python-lsp-server = {extras = ["yapf"], version = "^1.5.0"}
python-lsp-black = "^1.2.1"
black = "^22.10.0"
whatthepatch = "^1.0.2"
[tool.poetry.group.dev.dependencies]
types-aiofiles = "^22.1.0.2"
sqlalchemy-stubs = "^0.4"
wrapt = "^1.14.1"
pydocstyle = "^6.2.2"
pylint = "^2.15.9"
dill = "^0.3.6"
[tool.pylsp-mypy]
enabled = true
[mypy]
plugins = "sqlmypy"
[pycodestyle]
max_line_length = 100
[pylint]
max-line-length = 100