26 Commits

Author SHA1 Message Date
5da918c04b 👽️ Take the last IDFM format into account 2023-06-11 22:41:44 +02:00
2eaf0f4ed5 Use of db merge when adds fails due to single key violations 2023-06-11 22:28:15 +02:00
c42b687870 🐛 Fix IdfmInterface circular import issue 2023-06-11 22:24:09 +02:00
d8adb4f52d ♻️ Remove code in charge of db filling from IdfmInterface 2023-06-11 22:22:05 +02:00
5e7f440b54 ♻️ Add the db_updater package 2023-06-11 22:18:47 +02:00
824536ddbe 💥 Rename API_KEY to IDFM_API_KEY 2023-05-28 12:45:03 +02:00
7fbdd0606c ️ Reduce the size of the backend docker image 2023-05-28 12:40:10 +02:00
581f6b7b8f 🐛 Add workaround for fastapi-cache issue #144 2023-05-28 10:45:14 +02:00
404b228cbf 🔥 Remove env variables from backend dockerfile 2023-05-26 23:55:58 +02:00
e2ff90cd5f ️ Use Redis to cache REST responses 2023-05-26 18:10:47 +02:00
cd700ebd42 🐛 The backend shall serve requests once the database is reachable 2023-05-26 18:09:24 +02:00
c44a52b7ae ♻️ Add backend and frontend to docker-compose 2023-05-26 18:01:04 +02:00
b3b36bc3de Replace rich with icecream for temporary tracing 2023-05-11 21:44:58 +02:00
5e0d7b174c 🏷️ Fix some type issues (mypy) 2023-05-11 21:40:38 +02:00
b437bbbf70 🎨 Split main into several APIRouters 2023-05-11 21:17:02 +02:00
85fdb28cc6 🐛 Set default value to Settings.clear_static_data 2023-05-11 20:31:24 +02:00
b894d68a7a ♻️ Use of pydantic to manage config+env variables
FastAPI release has been updated allowing to use lifespan parameter
to prepare/shutdown sub components.
2023-05-10 22:30:30 +02:00
ef26509b87 🐛 Fix invalid line id returned by /stop/{stop_id}/nextPassages endpoint 2023-05-09 23:25:30 +02:00
de82eb6c55 🔊 Reformat error logs generated by frontend BusinessData class 2023-05-08 17:25:14 +02:00
0ba4c1e6fa ️ Stop.postal_region and Line.id/operator_id can be integer values 2023-05-08 17:15:21 +02:00
0f1c16ab53 💄 Force App to use 100% of the shortest side 2023-05-08 16:17:56 +02:00
5692bc96d5 🐛 Check backend return codes before handling data 2023-05-08 15:05:43 +02:00
d15fee75ca ♻️ Fix /stop/ endpoints inconsistency 2023-05-08 13:44:44 +02:00
93047c8706 ♻️ Use declarative table configuration to define backend db tables 2023-05-08 13:17:09 +02:00
c84b78d3e2 🚑️ /stop responses didn't return StopArea.stops fields 2023-05-07 12:36:38 +02:00
c6e3881966 Add sqlalchemy-utils types dependency 2023-05-07 12:20:28 +02:00
29 changed files with 1360 additions and 867 deletions

12
backend/.dockerignore Normal file
View File

@@ -0,0 +1,12 @@
.dir-locals.el
.dockerignore
.gitignore
**/.mypy_cache
**/.ruff_cache
.venv
**/__pycache__
config
docker
poetry.lock
tests
Dockerfile

View File

@@ -0,0 +1,38 @@
FROM python:3.11-slim as builder
RUN pip install poetry
ENV POETRY_NO_INTERACTION=1 \
POETRY_VIRTUALENVS_IN_PROJECT=1 \
POETRY_VIRTUALENVS_CREATE=1 \
POETRY_CACHE_DIR=/tmp/poetry_cache
WORKDIR /app
COPY ./pyproject.toml /app
RUN poetry install --only=main --no-root && \
rm -rf ${POETRY_CACHE_DIR}
FROM python:3.11-slim as runtime
WORKDIR /app
RUN apt update && \
apt install -y --no-install-recommends libpq5 && \
apt clean && \
rm -rf /var/lib/apt/lists/*
env VIRTUAL_ENV=/app/.venv \
PATH="/app/.venv/bin:$PATH"
COPY --from=builder ${VIRTUAL_ENV} ${VIRTUAL_ENV}
COPY backend /app/backend
COPY dependencies.py /app
COPY config.sample.yaml /app
COPY routers/ /app/routers
COPY main.py /app
CMD ["python", "./main.py"]

View File

@@ -0,0 +1,43 @@
FROM python:3.11-slim as builder
RUN apt update && \
apt install -y --no-install-recommends proj-bin && \
apt clean && \
rm -rf /var/lib/apt/lists/*
RUN pip install poetry
ENV POETRY_NO_INTERACTION=1 \
POETRY_VIRTUALENVS_IN_PROJECT=1 \
POETRY_VIRTUALENVS_CREATE=1 \
POETRY_CACHE_DIR=/tmp/poetry_cache
WORKDIR /app
COPY ./pyproject.toml /app
RUN poetry install --only=db_updater --no-root && \
rm -rf ${POETRY_CACHE_DIR}
FROM python:3.11-slim as runtime
WORKDIR /app
RUN apt update && \
apt install -y --no-install-recommends libpq5 && \
apt clean && \
rm -rf /var/lib/apt/lists/*
env VIRTUAL_ENV=/app/.venv \
PATH="/app/.venv/bin:$PATH"
COPY --from=builder ${VIRTUAL_ENV} ${VIRTUAL_ENV}
COPY backend /app/backend
COPY dependencies.py /app
COPY config.sample.yaml /app
COPY config.local.yaml /app
COPY db_updater /app/db_updater
CMD ["python", "-m", "db_updater.fill_db"]

View File

@@ -1,7 +1,7 @@
from __future__ import annotations from __future__ import annotations
from logging import getLogger from logging import getLogger
from typing import Iterable, Self, TYPE_CHECKING from typing import Self, Sequence, TYPE_CHECKING
from sqlalchemy import select from sqlalchemy import select
from sqlalchemy.exc import IntegrityError from sqlalchemy.exc import IntegrityError
@@ -17,22 +17,35 @@ class Base(DeclarativeBase):
db: Database | None = None db: Database | None = None
@classmethod @classmethod
async def add(cls, objs: Self | Iterable[Self]) -> bool: async def add(cls, objs: Sequence[Self]) -> bool:
if cls.db is not None and (session := await cls.db.get_session()) is not None: if cls.db is not None and (session := await cls.db.get_session()) is not None:
async with session.begin():
try: try:
if isinstance(objs, Iterable): async with session.begin():
session.add_all(objs) session.add_all(objs)
else:
session.add(objs)
except (AttributeError, IntegrityError) as err: except IntegrityError as err:
logger.warning(err)
return await cls.merge(objs)
except AttributeError as err:
logger.error(err) logger.error(err)
return False return False
return True return True
@classmethod
async def merge(cls, objs: Sequence[Self]) -> bool:
if cls.db is not None and (session := await cls.db.get_session()) is not None:
async with session.begin():
for obj in objs:
await session.merge(obj)
return True
return False
@classmethod @classmethod
async def get_by_id(cls, id_: int | str) -> Self | None: async def get_by_id(cls, id_: int | str) -> Self | None:
if cls.db is not None and (session := await cls.db.get_session()) is not None: if cls.db is not None and (session := await cls.db.get_session()) is not None:

View File

@@ -1,10 +1,11 @@
from asyncio import sleep
from logging import getLogger from logging import getLogger
from typing import Annotated, AsyncIterator from typing import Annotated, AsyncIterator
from fastapi import Depends from fastapi import Depends
from opentelemetry.instrumentation.sqlalchemy import SQLAlchemyInstrumentor from opentelemetry.instrumentation.sqlalchemy import SQLAlchemyInstrumentor
from sqlalchemy import text from sqlalchemy import text
from sqlalchemy.exc import SQLAlchemyError from sqlalchemy.exc import OperationalError, SQLAlchemyError
from sqlalchemy.ext.asyncio import ( from sqlalchemy.ext.asyncio import (
async_sessionmaker, async_sessionmaker,
AsyncEngine, AsyncEngine,
@@ -13,7 +14,7 @@ from sqlalchemy.ext.asyncio import (
) )
from .base_class import Base from .base_class import Base
from ..settings import DatabaseSettings
logger = getLogger(__name__) logger = getLogger(__name__)
@@ -33,9 +34,17 @@ class Database:
return None return None
# TODO: Preserve UserLastStopSearchResults table from drop. # TODO: Preserve UserLastStopSearchResults table from drop.
async def connect(self, db_path: str, clear_static_data: bool = False) -> bool: async def connect(
self, settings: DatabaseSettings, clear_static_data: bool = False
) -> bool:
password = settings.password
path = (
f"{settings.driver}://{settings.user}:"
f"{password.get_secret_value() if password is not None else ''}"
f"@{settings.host}:{settings.port}/{settings.name}"
)
self._async_engine = create_async_engine( self._async_engine = create_async_engine(
db_path, pool_pre_ping=True, pool_size=10, max_overflow=20 path, pool_pre_ping=True, pool_size=10, max_overflow=20
) )
if self._async_engine is not None: if self._async_engine is not None:
@@ -48,11 +57,20 @@ class Database:
class_=AsyncSession, class_=AsyncSession,
) )
ret = False
while not ret:
try:
async with self._async_engine.begin() as session: async with self._async_engine.begin() as session:
await session.execute(text("CREATE EXTENSION IF NOT EXISTS pg_trgm;")) await session.execute(
text("CREATE EXTENSION IF NOT EXISTS pg_trgm;")
)
if clear_static_data: if clear_static_data:
await session.run_sync(Base.metadata.drop_all) await session.run_sync(Base.metadata.drop_all)
await session.run_sync(Base.metadata.create_all) await session.run_sync(Base.metadata.create_all)
ret = True
except OperationalError as err:
logger.error(err)
await sleep(1)
return True return True

View File

@@ -1,5 +1,3 @@
from .idfm_interface import IdfmInterface
from .idfm_types import ( from .idfm_types import (
Coordinate, Coordinate,
Destinations, Destinations,
@@ -38,7 +36,6 @@ __all__ = [
"Coordinate", "Coordinate",
"Destinations", "Destinations",
"FramedVehicleJourney", "FramedVehicleJourney",
"IdfmInterface",
"IdfmLineState", "IdfmLineState",
"IdfmOperator", "IdfmOperator",
"IdfmResponse", "IdfmResponse",

View File

@@ -1,40 +1,15 @@
from collections import defaultdict from collections import defaultdict
from re import compile as re_compile from re import compile as re_compile
from time import time from typing import ByteString
from typing import (
AsyncIterator,
ByteString,
Callable,
Iterable,
List,
Type,
)
from aiofiles import open as async_open from aiofiles import open as async_open
from aiohttp import ClientSession from aiohttp import ClientSession
from msgspec import ValidationError from msgspec import ValidationError
from msgspec.json import Decoder from msgspec.json import Decoder
from pyproj import Transformer
from shapefile import Reader as ShapeFileReader, ShapeRecord # type: ignore
from ..db import Database from ..db import Database
from ..models import ConnectionArea, Line, LinePicto, Stop, StopArea, StopShape from ..models import Line, Stop, StopArea
from .idfm_types import ( from .idfm_types import Destinations as IdfmDestinations, IdfmResponse, IdfmState
ConnectionArea as IdfmConnectionArea,
Destinations as IdfmDestinations,
IdfmLineState,
IdfmResponse,
Line as IdfmLine,
LinePicto as IdfmPicto,
IdfmState,
Stop as IdfmStop,
StopArea as IdfmStopArea,
StopAreaStopAssociation,
StopAreaType,
StopLineAsso as IdfmStopLineAsso,
TransportMode,
)
from .ratp_types import Picto as RatpPicto
class IdfmInterface: class IdfmInterface:
@@ -42,20 +17,8 @@ class IdfmInterface:
IDFM_ROOT_URL = "https://prim.iledefrance-mobilites.fr/marketplace" IDFM_ROOT_URL = "https://prim.iledefrance-mobilites.fr/marketplace"
IDFM_STOP_MON_URL = f"{IDFM_ROOT_URL}/stop-monitoring" IDFM_STOP_MON_URL = f"{IDFM_ROOT_URL}/stop-monitoring"
IDFM_ROOT_URL = "https://data.iledefrance-mobilites.fr/explore/dataset"
IDFM_STOPS_URL = (
f"{IDFM_ROOT_URL}/arrets/download/?format=json&timezone=Europe/Berlin"
)
IDFM_PICTO_URL = f"{IDFM_ROOT_URL}/referentiel-des-lignes/files"
RATP_ROOT_URL = "https://data.ratp.fr/explore/dataset"
RATP_PICTO_URL = (
f"{RATP_ROOT_URL}"
"/pictogrammes-des-lignes-de-metro-rer-tramway-bus-et-noctilien/files"
)
OPERATOR_RE = re_compile(r"[^:]+:Operator::([^:]+):") OPERATOR_RE = re_compile(r"[^:]+:Operator::([^:]+):")
LINE_RE = re_compile(r"[^:]+:Line::([^:]+):") LINE_RE = re_compile(r"[^:]+:Line::C([^:]+):")
def __init__(self, api_key: str, database: Database) -> None: def __init__(self, api_key: str, database: Database) -> None:
self._api_key = api_key self._api_key = api_key
@@ -63,438 +26,12 @@ class IdfmInterface:
self._http_headers = {"Accept": "application/json", "apikey": self._api_key} self._http_headers = {"Accept": "application/json", "apikey": self._api_key}
self._epsg2154_epsg3857_transformer = Transformer.from_crs(2154, 3857)
self._json_stops_decoder = Decoder(type=List[IdfmStop])
self._json_stop_areas_decoder = Decoder(type=List[IdfmStopArea])
self._json_connection_areas_decoder = Decoder(type=List[IdfmConnectionArea])
self._json_lines_decoder = Decoder(type=List[IdfmLine])
self._json_stops_lines_assos_decoder = Decoder(type=List[IdfmStopLineAsso])
self._json_ratp_pictos_decoder = Decoder(type=List[RatpPicto])
self._json_stop_area_stop_asso_decoder = Decoder(
type=List[StopAreaStopAssociation]
)
self._response_json_decoder = Decoder(type=IdfmResponse) self._response_json_decoder = Decoder(type=IdfmResponse)
async def startup(self) -> None: async def startup(self) -> None:
BATCH_SIZE = 10000 ...
STEPS: tuple[
tuple[
Type[ConnectionArea] | Type[Stop] | Type[StopArea] | Type[StopShape],
Callable,
Callable,
],
...,
] = (
(
StopShape,
self._request_stop_shapes,
self._format_idfm_stop_shapes,
),
(
ConnectionArea,
self._request_idfm_connection_areas,
self._format_idfm_connection_areas,
),
(
StopArea,
self._request_idfm_stop_areas,
self._format_idfm_stop_areas,
),
(Stop, self._request_idfm_stops, self._format_idfm_stops),
)
for model, get_method, format_method in STEPS:
step_begin_ts = time()
elements = []
async for element in get_method():
elements.append(element)
if len(elements) == BATCH_SIZE:
await model.add(format_method(*elements))
elements.clear()
if elements:
await model.add(format_method(*elements))
print(f"Add {model.__name__}s: {time() - step_begin_ts}s")
begin_ts = time()
await self._load_lines()
print(f"Add Lines and IDFM LinePictos: {time() - begin_ts}s")
begin_ts = time()
await self._load_ratp_pictos(30)
print(f"Add RATP LinePictos: {time() - begin_ts}s")
begin_ts = time()
await self._load_lines_stops_assos()
print(f"Link Stops to Lines: {time() - begin_ts}s")
begin_ts = time()
await self._load_stop_assos()
print(f"Link Stops to StopAreas: {time() - begin_ts}s")
async def _load_lines(self, batch_size: int = 5000) -> None:
lines, pictos = [], []
picto_ids = set()
async for line in self._request_idfm_lines():
if (picto := line.fields.picto) is not None and picto.id_ not in picto_ids:
picto_ids.add(picto.id_)
pictos.append(picto)
lines.append(line)
if len(lines) == batch_size:
await LinePicto.add(IdfmInterface._format_idfm_pictos(*pictos))
await Line.add(await self._format_idfm_lines(*lines))
lines.clear()
pictos.clear()
if pictos:
await LinePicto.add(IdfmInterface._format_idfm_pictos(*pictos))
if lines:
await Line.add(await self._format_idfm_lines(*lines))
async def _load_ratp_pictos(self, batch_size: int = 5) -> None:
pictos = []
async for picto in self._request_ratp_pictos():
pictos.append(picto)
if len(pictos) == batch_size:
formatted_pictos = IdfmInterface._format_ratp_pictos(*pictos)
await LinePicto.add(map(lambda picto: picto[1], formatted_pictos))
await Line.add_pictos(formatted_pictos)
pictos.clear()
if pictos:
formatted_pictos = IdfmInterface._format_ratp_pictos(*pictos)
await LinePicto.add(map(lambda picto: picto[1], formatted_pictos))
await Line.add_pictos(formatted_pictos)
async def _load_lines_stops_assos(self, batch_size: int = 5000) -> None:
total_assos_nb = total_found_nb = 0
assos = []
async for asso in self._request_idfm_stops_lines_associations():
fields = asso.fields
try:
stop_id = int(fields.stop_id.rsplit(":", 1)[-1])
except ValueError as err:
print(err)
print(f"{fields.stop_id = }")
continue
assos.append((fields.route_long_name, fields.operatorname, stop_id))
if len(assos) == batch_size:
total_assos_nb += batch_size
total_found_nb += await Line.add_stops(assos)
assos.clear()
if assos:
total_assos_nb += len(assos)
total_found_nb += await Line.add_stops(assos)
print(f"{total_found_nb} line <-> stop ({total_assos_nb = } found)")
async def _load_stop_assos(self, batch_size: int = 5000) -> None:
total_assos_nb = area_stop_assos_nb = conn_stop_assos_nb = 0
area_stop_assos = []
connection_stop_assos = []
async for asso in self._request_idfm_stop_area_stop_associations():
fields = asso.fields
stop_id = int(fields.arrid)
area_stop_assos.append((int(fields.zdaid), stop_id))
connection_stop_assos.append((int(fields.zdcid), stop_id))
if len(area_stop_assos) == batch_size:
total_assos_nb += batch_size
if (found_nb := await StopArea.add_stops(area_stop_assos)) is not None:
area_stop_assos_nb += found_nb
area_stop_assos.clear()
if (
found_nb := await ConnectionArea.add_stops(connection_stop_assos)
) is not None:
conn_stop_assos_nb += found_nb
connection_stop_assos.clear()
if area_stop_assos:
total_assos_nb += len(area_stop_assos)
if (found_nb := await StopArea.add_stops(area_stop_assos)) is not None:
area_stop_assos_nb += found_nb
if (
found_nb := await ConnectionArea.add_stops(connection_stop_assos)
) is not None:
conn_stop_assos_nb += found_nb
print(f"{area_stop_assos_nb} stop area <-> stop ({total_assos_nb = } found)")
print(f"{conn_stop_assos_nb} stop area <-> stop ({total_assos_nb = } found)")
# TODO: This method is synchronous due to the shapefile library.
# It's not a blocking issue but it could be nice to find an alternative.
async def _request_stop_shapes(self) -> AsyncIterator[ShapeRecord]:
# TODO: Use HTTP
with ShapeFileReader("./tests/datasets/REF_LDA.zip") as reader:
for record in reader.shapeRecords():
yield record
async def _request_idfm_stops(self) -> AsyncIterator[IdfmStop]:
# headers = {"Accept": "application/json", "apikey": self._api_key}
# async with ClientSession(headers=headers) as session:
# async with session.get(self.STOPS_URL) as response:
# # print("Status:", response.status)
# if response.status == 200:
# for point in self._json_stops_decoder.decode(await response.read()):
# yield point
# TODO: Use HTTP
async with async_open("./tests/datasets/stops_dataset.json", "rb") as raw:
for element in self._json_stops_decoder.decode(await raw.read()):
yield element
async def _request_idfm_stop_areas(self) -> AsyncIterator[IdfmStopArea]:
# TODO: Use HTTP
async with async_open("./tests/datasets/zones-d-arrets.json", "rb") as raw:
for element in self._json_stop_areas_decoder.decode(await raw.read()):
yield element
async def _request_idfm_connection_areas(self) -> AsyncIterator[IdfmConnectionArea]:
async with async_open(
"./tests/datasets/zones-de-correspondance.json", "rb"
) as raw:
for element in self._json_connection_areas_decoder.decode(await raw.read()):
yield element
async def _request_idfm_lines(self) -> AsyncIterator[IdfmLine]:
# TODO: Use HTTP
async with async_open("./tests/datasets/lines_dataset.json", "rb") as raw:
for element in self._json_lines_decoder.decode(await raw.read()):
yield element
async def _request_idfm_stops_lines_associations(
self,
) -> AsyncIterator[IdfmStopLineAsso]:
# TODO: Use HTTP
async with async_open("./tests/datasets/arrets-lignes.json", "rb") as raw:
for element in self._json_stops_lines_assos_decoder.decode(
await raw.read()
):
yield element
async def _request_idfm_stop_area_stop_associations(
self,
) -> AsyncIterator[StopAreaStopAssociation]:
# TODO: Use HTTP
async with async_open("./tests/datasets/relations.json", "rb") as raw:
for element in self._json_stop_area_stop_asso_decoder.decode(
await raw.read()
):
yield element
async def _request_ratp_pictos(self) -> AsyncIterator[RatpPicto]:
# TODO: Use HTTP
async with async_open(
"./tests/datasets/pictogrammes-des-lignes-de-metro-rer-tramway-bus-et-noctilien.json",
"rb",
) as fd:
for element in self._json_ratp_pictos_decoder.decode(await fd.read()):
yield element
@classmethod
def _format_idfm_pictos(cls, *pictos: IdfmPicto) -> Iterable[LinePicto]:
ret = []
for picto in pictos:
ret.append(
LinePicto(
id=picto.id_,
mime_type=picto.mimetype,
height_px=picto.height,
width_px=picto.width,
filename=picto.filename,
url=f"{cls.IDFM_PICTO_URL}/{picto.id_}/download",
thumbnail=picto.thumbnail,
format=picto.format,
)
)
return ret
@classmethod
def _format_ratp_pictos(cls, *pictos: RatpPicto) -> Iterable[tuple[str, LinePicto]]:
ret = []
for picto in pictos:
if (fields := picto.fields.noms_des_fichiers) is not None:
ret.append(
(
picto.fields.indices_commerciaux,
LinePicto(
id=fields.id_,
mime_type=f"image/{fields.format.lower()}",
height_px=fields.height,
width_px=fields.width,
filename=fields.filename,
url=f"{cls.RATP_PICTO_URL}/{fields.id_}/download",
thumbnail=fields.thumbnail,
format=fields.format,
),
)
)
return ret
async def _format_idfm_lines(self, *lines: IdfmLine) -> Iterable[Line]:
ret = []
optional_value = IdfmLine.optional_value
for line in lines:
fields = line.fields
picto_id = fields.picto.id_ if fields.picto is not None else None
ret.append(
Line(
id=fields.id_line,
short_name=fields.shortname_line,
name=fields.name_line,
status=IdfmLineState(fields.status.value),
transport_mode=TransportMode(fields.transportmode.value),
transport_submode=optional_value(fields.transportsubmode),
network_name=optional_value(fields.networkname),
group_of_lines_id=optional_value(fields.id_groupoflines),
group_of_lines_shortname=optional_value(
fields.shortname_groupoflines
),
colour_web_hexa=fields.colourweb_hexa,
text_colour_hexa=fields.textcolourprint_hexa,
operator_id=optional_value(fields.operatorref),
operator_name=optional_value(fields.operatorname),
accessibility=IdfmState(fields.accessibility.value),
visual_signs_available=IdfmState(
fields.visualsigns_available.value
),
audible_signs_available=IdfmState(
fields.audiblesigns_available.value
),
picto_id=fields.picto.id_ if fields.picto is not None else None,
record_id=line.recordid,
record_ts=int(line.record_timestamp.timestamp()),
)
)
return ret
def _format_idfm_stops(self, *stops: IdfmStop) -> Iterable[Stop]:
for stop in stops:
fields = stop.fields
try:
created_ts = int(fields.arrcreated.timestamp()) # type: ignore
except AttributeError:
created_ts = None
epsg3857_point = self._epsg2154_epsg3857_transformer.transform(
fields.arrxepsg2154, fields.arryepsg2154
)
yield Stop(
id=int(fields.arrid),
name=fields.arrname,
epsg3857_x=epsg3857_point[0],
epsg3857_y=epsg3857_point[1],
town_name=fields.arrtown,
postal_region=fields.arrpostalregion,
transport_mode=TransportMode(fields.arrtype.value),
version=fields.arrversion,
created_ts=created_ts,
changed_ts=int(fields.arrchanged.timestamp()),
accessibility=IdfmState(fields.arraccessibility.value),
visual_signs_available=IdfmState(fields.arrvisualsigns.value),
audible_signs_available=IdfmState(fields.arraudiblesignals.value),
record_id=stop.recordid,
record_ts=int(stop.record_timestamp.timestamp()),
)
def _format_idfm_stop_areas(self, *stop_areas: IdfmStopArea) -> Iterable[StopArea]:
for stop_area in stop_areas:
fields = stop_area.fields
try:
created_ts = int(fields.zdacreated.timestamp()) # type: ignore
except AttributeError:
created_ts = None
epsg3857_point = self._epsg2154_epsg3857_transformer.transform(
fields.zdaxepsg2154, fields.zdayepsg2154
)
yield StopArea(
id=int(fields.zdaid),
name=fields.zdaname,
town_name=fields.zdatown,
postal_region=fields.zdapostalregion,
epsg3857_x=epsg3857_point[0],
epsg3857_y=epsg3857_point[1],
type=StopAreaType(fields.zdatype.value),
version=fields.zdaversion,
created_ts=created_ts,
changed_ts=int(fields.zdachanged.timestamp()),
)
def _format_idfm_connection_areas(
self,
*connection_areas: IdfmConnectionArea,
) -> Iterable[ConnectionArea]:
for connection_area in connection_areas:
epsg3857_point = self._epsg2154_epsg3857_transformer.transform(
connection_area.zdcxepsg2154, connection_area.zdcyepsg2154
)
yield ConnectionArea(
id=int(connection_area.zdcid),
name=connection_area.zdcname,
town_name=connection_area.zdctown,
postal_region=connection_area.zdcpostalregion,
epsg3857_x=epsg3857_point[0],
epsg3857_y=epsg3857_point[1],
transport_mode=StopAreaType(connection_area.zdctype.value),
version=connection_area.zdcversion,
created_ts=int(connection_area.zdccreated.timestamp()),
changed_ts=int(connection_area.zdcchanged.timestamp()),
)
def _format_idfm_stop_shapes(
self, *shape_records: ShapeRecord
) -> Iterable[StopShape]:
for shape_record in shape_records:
epsg3857_points = [
self._epsg2154_epsg3857_transformer.transform(*point)
for point in shape_record.shape.points
]
bbox_it = iter(shape_record.shape.bbox)
epsg3857_bbox = [
self._epsg2154_epsg3857_transformer.transform(*point)
for point in zip(bbox_it, bbox_it)
]
yield StopShape(
id=shape_record.record[1],
type=shape_record.shape.shapeType,
epsg3857_bbox=epsg3857_bbox,
epsg3857_points=epsg3857_points,
)
async def render_line_picto(self, line: Line) -> tuple[None | str, None | str]: async def render_line_picto(self, line: Line) -> tuple[None | str, None | str]:
begin_ts = time()
line_picto_path = line_picto_format = None line_picto_path = line_picto_format = None
target = f"/tmp/{line.id}_repr" target = f"/tmp/{line.id}_repr"
@@ -506,12 +43,9 @@ class IdfmInterface:
line_picto_path = target line_picto_path = target
line_picto_format = picto.mime_type line_picto_format = picto.mime_type
print(f"render_line_picto: {time() - begin_ts}")
return (line_picto_path, line_picto_format) return (line_picto_path, line_picto_format)
async def _get_line_picto(self, line: Line) -> ByteString | None: async def _get_line_picto(self, line: Line) -> ByteString | None:
print("---------------------------------------------------------------------")
begin_ts = time()
data = None data = None
picto = line.picto picto = line.picto
@@ -519,25 +53,20 @@ class IdfmInterface:
headers = ( headers = (
self._http_headers if picto.url.startswith(self.IDFM_ROOT_URL) else None self._http_headers if picto.url.startswith(self.IDFM_ROOT_URL) else None
) )
session_begin_ts = time()
async with ClientSession(headers=headers) as session:
session_creation_ts = time()
print(f"Session creation {session_creation_ts - session_begin_ts}")
async with session.get(picto.url) as response:
get_end_ts = time()
print(f"GET {get_end_ts - session_creation_ts}")
data = await response.read()
print(f"read {time() - get_end_ts}")
print(f"render_line_picto: {time() - begin_ts}") async with ClientSession(headers=headers) as session:
print("---------------------------------------------------------------------") async with session.get(picto.url) as response:
data = await response.read()
return data return data
async def get_next_passages(self, stop_point_id: int) -> IdfmResponse | None: async def get_next_passages(self, stop_point_id: int) -> IdfmResponse | None:
ret = None ret = None
params = {"MonitoringRef": f"STIF:StopPoint:Q:{stop_point_id}:"} params = {"MonitoringRef": f"STIF:StopPoint:Q:{stop_point_id}:"}
async with ClientSession(headers=self._http_headers) as session: async with ClientSession(headers=self._http_headers) as session:
async with session.get(self.IDFM_STOP_MON_URL, params=params) as response: async with session.get(self.IDFM_STOP_MON_URL, params=params) as response:
if response.status == 200: if response.status == 200:
data = await response.read() data = await response.read()
try: try:
@@ -548,8 +77,6 @@ class IdfmInterface:
return ret return ret
async def get_destinations(self, stop_id: int) -> IdfmDestinations | None: async def get_destinations(self, stop_id: int) -> IdfmDestinations | None:
begin_ts = time()
destinations: IdfmDestinations = defaultdict(set) destinations: IdfmDestinations = defaultdict(set)
if (stop := await Stop.get_by_id(stop_id)) is not None: if (stop := await Stop.get_by_id(stop_id)) is not None:
@@ -557,7 +84,6 @@ class IdfmInterface:
elif (stop_area := await StopArea.get_by_id(stop_id)) is not None: elif (stop_area := await StopArea.get_by_id(stop_id)) is not None:
expected_stop_ids = {stop.id for stop in stop_area.stops} expected_stop_ids = {stop.id for stop in stop_area.stops}
else: else:
return None return None
@@ -568,6 +94,7 @@ class IdfmInterface:
for stop_visit in delivery.MonitoredStopVisit: for stop_visit in delivery.MonitoredStopVisit:
monitoring_ref = stop_visit.MonitoringRef.value monitoring_ref = stop_visit.MonitoringRef.value
try: try:
monitored_stop_id = int(monitoring_ref.split(":")[-2]) monitored_stop_id = int(monitoring_ref.split(":")[-2])
except (IndexError, ValueError): except (IndexError, ValueError):
@@ -578,9 +105,7 @@ class IdfmInterface:
if ( if (
dst_names := journey.DestinationName dst_names := journey.DestinationName
) and monitored_stop_id in expected_stop_ids: ) and monitored_stop_id in expected_stop_ids:
line_id = journey.LineRef.value.split(":")[-2] line_id = journey.LineRef.value.split(":")[-2]
destinations[line_id].add(dst_names[0].value) destinations[line_id].add(dst_names[0].value)
print(f"get_next_passages: {time() - begin_ts}")
return destinations return destinations

View File

@@ -116,19 +116,26 @@ class StopArea(Struct):
record_timestamp: datetime record_timestamp: datetime
class ConnectionArea(Struct): class ConnectionAreaFields(Struct, kw_only=True):
zdcid: str zdcid: str
zdcversion: str zdcversion: str
zdccreated: datetime zdccreated: datetime
zdcchanged: datetime zdcchanged: datetime
zdcname: str zdcname: str
zdcxepsg2154: int zdcxepsg2154: int | None = None
zdcyepsg2154: int zdcyepsg2154: int | None = None
zdctown: str zdctown: str
zdcpostalregion: str zdcpostalregion: str
zdctype: StopAreaType zdctype: StopAreaType
class ConnectionArea(Struct):
datasetid: str
recordid: str
fields: ConnectionAreaFields
record_timestamp: datetime
class StopAreaStopAssociationFields(Struct, kw_only=True): class StopAreaStopAssociationFields(Struct, kw_only=True):
arrid: str # TODO: use int ? arrid: str # TODO: use int ?
artid: str | None = None artid: str | None = None
@@ -149,6 +156,7 @@ class StopAreaStopAssociation(Struct):
class IdfmLineState(Enum): class IdfmLineState(Enum):
active = "active" active = "active"
available_soon = "prochainement active"
class LinePicto(Struct, rename={"id_": "id"}): class LinePicto(Struct, rename={"id_": "id"}):

View File

@@ -1,6 +1,3 @@
from datetime import datetime
from typing import Optional
from msgspec import Struct from msgspec import Struct
@@ -13,13 +10,6 @@ class PictoFieldsFile(Struct, rename={"id_": "id"}):
format: str format: str
class PictoFields(Struct):
indices_commerciaux: str
noms_des_fichiers: Optional[PictoFieldsFile] = None
class Picto(Struct): class Picto(Struct):
datasetid: str indices_commerciaux: str
recordid: str noms_des_fichiers: PictoFieldsFile | None = None
fields: PictoFields
record_timestamp: datetime

View File

@@ -5,13 +5,11 @@ from typing import Iterable, Self, Sequence
from sqlalchemy import ( from sqlalchemy import (
BigInteger, BigInteger,
Boolean, Boolean,
Column,
Enum, Enum,
ForeignKey, ForeignKey,
Integer, Integer,
select, select,
String, String,
Table,
) )
from sqlalchemy.orm import Mapped, mapped_column, relationship, selectinload from sqlalchemy.orm import Mapped, mapped_column, relationship, selectinload
from sqlalchemy.sql.expression import tuple_ from sqlalchemy.sql.expression import tuple_
@@ -25,12 +23,14 @@ from ..idfm_interface.idfm_types import (
) )
from .stop import _Stop from .stop import _Stop
line_stop_association_table = Table(
"line_stop_association_table", class LineStopAssociations(Base):
Base.metadata,
Column("line_id", ForeignKey("lines.id")), id = mapped_column(BigInteger, primary_key=True)
Column("stop_id", ForeignKey("_stops.id")), line_id = mapped_column(BigInteger, ForeignKey("lines.id"))
) stop_id = mapped_column(BigInteger, ForeignKey("_stops.id"))
__tablename__ = "line_stop_associations"
class LinePicto(Base): class LinePicto(Base):
@@ -53,7 +53,7 @@ class Line(Base):
db = db db = db
id = mapped_column(String, primary_key=True) id = mapped_column(BigInteger, primary_key=True)
short_name = mapped_column(String) short_name = mapped_column(String)
name = mapped_column(String, nullable=False) name = mapped_column(String, nullable=False)
@@ -68,7 +68,7 @@ class Line(Base):
colour_web_hexa = mapped_column(String, nullable=False) colour_web_hexa = mapped_column(String, nullable=False)
text_colour_hexa = mapped_column(String, nullable=False) text_colour_hexa = mapped_column(String, nullable=False)
operator_id = mapped_column(String) operator_id = mapped_column(Integer)
operator_name = mapped_column(String) operator_name = mapped_column(String)
accessibility = mapped_column(Enum(IdfmState), nullable=False) accessibility = mapped_column(Enum(IdfmState), nullable=False)
@@ -83,7 +83,7 @@ class Line(Base):
stops: Mapped[list[_Stop]] = relationship( stops: Mapped[list[_Stop]] = relationship(
"_Stop", "_Stop",
secondary=line_stop_association_table, secondary="line_stop_associations",
back_populates="lines", back_populates="lines",
lazy="selectin", lazy="selectin",
) )

View File

@@ -1,11 +1,10 @@
from __future__ import annotations from __future__ import annotations
from logging import getLogger from logging import getLogger
from typing import Annotated, Iterable, Sequence, TYPE_CHECKING from typing import Iterable, Sequence, TYPE_CHECKING
from sqlalchemy import ( from sqlalchemy import (
BigInteger, BigInteger,
Column,
Computed, Computed,
desc, desc,
Enum, Enum,
@@ -16,13 +15,13 @@ from sqlalchemy import (
JSON, JSON,
select, select,
String, String,
Table,
) )
from sqlalchemy.orm import ( from sqlalchemy.orm import (
mapped_column, mapped_column,
Mapped, Mapped,
relationship, relationship,
selectinload, selectinload,
with_polymorphic,
) )
from sqlalchemy.schema import Index from sqlalchemy.schema import Index
from sqlalchemy_utils.types.ts_vector import TSVectorType from sqlalchemy_utils.types.ts_vector import TSVectorType
@@ -36,12 +35,13 @@ if TYPE_CHECKING:
logger = getLogger(__name__) logger = getLogger(__name__)
stop_area_stop_association_table = Table( class StopAreaStopAssociations(Base):
"stop_area_stop_association_table",
Base.metadata, id = mapped_column(BigInteger, primary_key=True)
Column("stop_id", ForeignKey("_stops.id")), stop_id = mapped_column(BigInteger, ForeignKey("_stops.id"))
Column("stop_area_id", ForeignKey("stop_areas.id")), stop_area_id = mapped_column(BigInteger, ForeignKey("stop_areas.id"))
)
__tablename__ = "stop_area_stop_associations"
class _Stop(Base): class _Stop(Base):
@@ -53,7 +53,7 @@ class _Stop(Base):
name = mapped_column(String, nullable=False, index=True) name = mapped_column(String, nullable=False, index=True)
town_name = mapped_column(String, nullable=False) town_name = mapped_column(String, nullable=False)
postal_region = mapped_column(String, nullable=False) postal_region = mapped_column(Integer, nullable=False)
epsg3857_x = mapped_column(Float, nullable=False) epsg3857_x = mapped_column(Float, nullable=False)
epsg3857_y = mapped_column(Float, nullable=False) epsg3857_y = mapped_column(Float, nullable=False)
@@ -63,12 +63,14 @@ class _Stop(Base):
lines: Mapped[list[Line]] = relationship( lines: Mapped[list[Line]] = relationship(
"Line", "Line",
secondary="line_stop_association_table", secondary="line_stop_associations",
back_populates="stops", back_populates="stops",
lazy="selectin", lazy="selectin",
) )
areas: Mapped[list["StopArea"]] = relationship( areas: Mapped[list["StopArea"]] = relationship(
"StopArea", secondary=stop_area_stop_association_table, back_populates="stops" "StopArea",
secondary="stop_area_stop_associations",
back_populates="stops",
) )
connection_area_id: Mapped[int] = mapped_column( connection_area_id: Mapped[int] = mapped_column(
ForeignKey("connection_areas.id"), nullable=True ForeignKey("connection_areas.id"), nullable=True
@@ -98,11 +100,17 @@ class _Stop(Base):
if (session := await cls.db.get_session()) is not None: if (session := await cls.db.get_session()) is not None:
async with session.begin(): async with session.begin():
match_stmt = cls.names_tsv.match(name, postgresql_regconfig="french") descendants = with_polymorphic(_Stop, "*")
ranking_stmt = func.ts_rank_cd(
cls.names_tsv, func.plainto_tsquery("french", name) match_stmt = descendants.names_tsv.match(
name, postgresql_regconfig="french"
)
ranking_stmt = func.ts_rank_cd(
descendants.names_tsv, func.plainto_tsquery("french", name)
)
stmt = (
select(descendants).filter(match_stmt).order_by(desc(ranking_stmt))
) )
stmt = select(cls).filter(match_stmt).order_by(desc(ranking_stmt))
res = await session.execute(stmt) res = await session.execute(stmt)
stops = res.scalars().all() stops = res.scalars().all()
@@ -136,7 +144,7 @@ class StopArea(_Stop):
stops: Mapped[list["Stop"]] = relationship( stops: Mapped[list["Stop"]] = relationship(
"Stop", "Stop",
secondary=stop_area_stop_association_table, secondary="stop_area_stop_associations",
back_populates="areas", back_populates="areas",
lazy="selectin", lazy="selectin",
) )

View File

@@ -1,25 +1,27 @@
from sqlalchemy import Column, ForeignKey, String, Table from sqlalchemy import BigInteger, ForeignKey, String
from sqlalchemy.orm import Mapped, mapped_column, relationship from sqlalchemy.orm import Mapped, mapped_column, relationship
from ..db import Base, db from ..db import Base, db
from .stop import _Stop from .stop import _Stop
user_last_stop_search_stops_associations_table = Table(
"user_last_stop_search_stops_associations_table", class UserLastStopSearchStopAssociations(Base):
Base.metadata, id = mapped_column(BigInteger, primary_key=True)
Column("user_mxid", ForeignKey("user_last_stop_search_results.user_mxid")), user_mxid = mapped_column(
Column("stop_id", ForeignKey("_stops.id")), String, ForeignKey("user_last_stop_search_results.user_mxid")
) )
stop_id = mapped_column(BigInteger, ForeignKey("_stops.id"))
__tablename__ = "user_last_stop_search_stop_associations"
class UserLastStopSearchResults(Base): class UserLastStopSearchResults(Base):
db = db db = db
__tablename__ = "user_last_stop_search_results"
user_mxid = mapped_column(String, primary_key=True) user_mxid = mapped_column(String, primary_key=True)
request_content = mapped_column(String, nullable=False) request_content = mapped_column(String, nullable=False)
stops: Mapped[_Stop] = relationship( stops: Mapped[_Stop] = relationship(
_Stop, secondary=user_last_stop_search_stops_associations_table _Stop, secondary="user_last_stop_search_stop_associations"
) )
__tablename__ = "user_last_stop_search_results"

View File

@@ -46,7 +46,7 @@ class TransportMode(StrEnum):
class Line(BaseModel): class Line(BaseModel):
id: str id: int
shortName: str shortName: str
name: str name: str
status: IdfmLineState status: IdfmLineState

View File

@@ -4,7 +4,7 @@ from ..idfm_interface.idfm_types import TrainStatus
class NextPassage(BaseModel): class NextPassage(BaseModel):
line: str line: int
operator: str operator: str
destinations: list[str] destinations: list[str]
atStop: bool atStop: bool
@@ -19,4 +19,4 @@ class NextPassage(BaseModel):
class NextPassages(BaseModel): class NextPassages(BaseModel):
ts: int ts: int
passages: dict[str, dict[str, list[NextPassage]]] passages: dict[int, dict[str, list[NextPassage]]]

View File

@@ -0,0 +1,59 @@
from typing import Any
from pydantic import BaseModel, BaseSettings, Field, root_validator, SecretStr
class HttpSettings(BaseModel):
    # HTTP server settings.
    # Bind address of the HTTP server.
    host: str = "127.0.0.1"
    # Listening TCP port.
    port: int = 8080
    # Path to a TLS certificate file; served over plain HTTP when None.
    cert: str | None = None
def check_user_password(cls, values: dict[str, Any]) -> dict[str, Any]:
    """Shared pydantic root validator: ``user``/``password`` come in pairs.

    Raises ValueError when exactly one of the two is provided; returns
    the values unchanged otherwise.
    """
    user, password = values.get("user"), values.get("password")
    if user is not None and password is None:
        raise ValueError("user is set, password shall be set too.")
    if password is not None and user is None:
        raise ValueError("password is set, user shall be set too.")
    return values
class DatabaseSettings(BaseModel):
    # PostgreSQL connection settings.
    # Database name.
    name: str = "carrramba-encore-rate"
    host: str = "127.0.0.1"
    port: int = 5432
    # SQLAlchemy dialect+driver string.
    driver: str = "postgresql+psycopg"
    # Optional credentials; the validator below enforces that user and
    # password are either both set or both unset.
    user: str | None = None
    password: SecretStr | None = None
    _user_password_validation = root_validator(allow_reuse=True)(check_user_password)
class CacheSettings(BaseModel):
    # Redis cache settings.
    # Master switch: caching is skipped entirely when False.
    enable: bool = False
    host: str = "127.0.0.1"
    port: int = 6379
    # Optional credentials; the validator below enforces that user and
    # password are either both set or both unset.
    user: str | None = None
    password: SecretStr | None = None
    _user_password_validation = root_validator(allow_reuse=True)(check_user_password)
class TracingSettings(BaseModel):
    # Tracing settings (presumably OpenTelemetry — confirm against main.py).
    enable: bool = False
class Settings(BaseSettings):
    # Top-level application settings, loadable from YAML and/or env vars.
    app_name: str
    # Mandatory IDFM API key; may be supplied via the IDFM_API_KEY env var.
    idfm_api_key: SecretStr = Field(..., env="IDFM_API_KEY")
    # When True, static data is cleared/reloaded at startup — TODO confirm
    # exact semantics against db.connect()/IdfmInterface.startup().
    clear_static_data: bool = Field(False, env="CLEAR_STATIC_DATA")
    http: HttpSettings = HttpSettings()
    db: DatabaseSettings = DatabaseSettings()
    cache: CacheSettings = CacheSettings()
    tracing: TracingSettings = TracingSettings()

21
backend/config.local.yaml Normal file
View File

@@ -0,0 +1,21 @@
app_name: carrramba-encore-rate
clear_static_data: false
http:
host: 0.0.0.0
port: 8080
cert: ./config/cert.pem
db:
name: carrramba-encore-rate
host: 127.0.0.1
port: 5432
driver: postgresql+psycopg
user: cer
password: cer_password
cache:
enable: true
tracing:
enable: false

View File

@@ -0,0 +1,23 @@
app_name: carrramba-encore-rate
clear_static_data: false
http:
host: 0.0.0.0
port: 8080
# cert: ./config/cert.pem
db:
name: carrramba-encore-rate
host: postgres
port: 5432
driver: postgresql+psycopg
user: cer
password: cer_password
cache:
enable: true
host: redis
# TODO: Add user credentials
tracing:
enable: false

575
backend/db_updater/fill_db.py Executable file
View File

@@ -0,0 +1,575 @@
#!/usr/bin/env python3
from asyncio import run, gather
from logging import getLogger, INFO, Handler as LoggingHandler, NOTSET
from itertools import islice
from time import time
from os import environ
from typing import Callable, Iterable, List, Type
from aiofiles.tempfile import NamedTemporaryFile
from aiohttp import ClientSession
from msgspec import ValidationError
from msgspec.json import Decoder
from pyproj import Transformer
from shapefile import Reader as ShapeFileReader, ShapeRecord # type: ignore
from tqdm import tqdm
from yaml import safe_load
from backend.db import Base, db, Database
from backend.models import ConnectionArea, Line, LinePicto, Stop, StopArea, StopShape
from backend.idfm_interface.idfm_types import (
ConnectionArea as IdfmConnectionArea,
IdfmLineState,
Line as IdfmLine,
LinePicto as IdfmPicto,
IdfmState,
Stop as IdfmStop,
StopArea as IdfmStopArea,
StopAreaStopAssociation,
StopAreaType,
StopLineAsso as IdfmStopLineAsso,
TransportMode,
)
from backend.idfm_interface.ratp_types import Picto as RatpPicto
from backend.settings import Settings
CONFIG_PATH = environ.get("CONFIG_PATH", "./config.sample.yaml")
BATCH_SIZE = 1000
IDFM_ROOT_URL = "https://data.iledefrance-mobilites.fr/explore/dataset"
IDFM_CONNECTION_AREAS_URL = (
f"{IDFM_ROOT_URL}/zones-de-correspondance/download/?format=json"
)
IDFM_LINES_URL = f"{IDFM_ROOT_URL}/referentiel-des-lignes/download/?format=json"
IDFM_PICTO_URL = f"{IDFM_ROOT_URL}/referentiel-des-lignes/files"
IDFM_STOP_AREAS_URL = f"{IDFM_ROOT_URL}/zones-d-arrets/download/?format=json"
IDFM_STOP_SHAPES_URL = "https://eu.ftp.opendatasoft.com/stif/Reflex/REF_ArR.zip"
IDFM_STOP_AREA_SHAPES_URL = "https://eu.ftp.opendatasoft.com/stif/Reflex/REF_ZdA.zip"
IDFM_STOP_STOP_AREAS_ASSOS_URL = f"{IDFM_ROOT_URL}/relations/download/?format=json"
IDFM_STOPS_LINES_ASSOS_URL = f"{IDFM_ROOT_URL}/arrets-lignes/download/?format=json"
IDFM_STOPS_URL = f"{IDFM_ROOT_URL}/arrets/download/?format=json"
RATP_ROOT_URL = "https://data.ratp.fr/api/explore/v2.1/catalog/datasets"
RATP_PICTOS_URL = (
f"{RATP_ROOT_URL}"
"/pictogrammes-des-lignes-de-metro-rer-tramway-bus-et-noctilien/exports/json?lang=fr"
)
# From https://stackoverflow.com/a/38739634
class TqdmLoggingHandler(LoggingHandler):
    """Logging handler that emits records through tqdm.write().

    Routing log output through tqdm keeps messages from corrupting the
    progress bars that are active while the db is being filled.
    """

    def __init__(self, level=NOTSET):
        super().__init__(level)

    def emit(self, record):
        # Delegate any formatting/writing failure to the standard
        # logging error handler instead of letting it propagate.
        try:
            tqdm.write(self.format(record))
            self.flush()
        except Exception:
            self.handleError(record)
logger = getLogger(__name__)
logger.setLevel(INFO)
logger.addHandler(TqdmLoggingHandler())
epsg2154_epsg3857_transformer = Transformer.from_crs(2154, 3857)
json_stops_decoder = Decoder(type=List[IdfmStop])
json_stop_areas_decoder = Decoder(type=List[IdfmStopArea])
json_connection_areas_decoder = Decoder(type=List[IdfmConnectionArea])
json_lines_decoder = Decoder(type=List[IdfmLine])
json_stops_lines_assos_decoder = Decoder(type=List[IdfmStopLineAsso])
json_ratp_pictos_decoder = Decoder(type=List[RatpPicto])
json_stop_area_stop_asso_decoder = Decoder(type=List[StopAreaStopAssociation])
def format_idfm_pictos(*pictos: IdfmPicto) -> Iterable[LinePicto]:
    """Map raw IDFM picto records onto LinePicto ORM objects."""
    return [
        LinePicto(
            id=picto.id_,
            mime_type=picto.mimetype,
            height_px=picto.height,
            width_px=picto.width,
            filename=picto.filename,
            url=f"{IDFM_PICTO_URL}/{picto.id_}/download",
            thumbnail=picto.thumbnail,
            format=picto.format,
        )
        for picto in pictos
    ]
def format_ratp_pictos(*pictos: RatpPicto) -> Iterable[tuple[str, LinePicto]]:
    """Map raw RATP picto records onto (commercial line code, LinePicto) pairs.

    Records without file metadata (noms_des_fichiers is None) are skipped.
    """
    formatted = []
    for picto in pictos:
        fields = picto.noms_des_fichiers
        if fields is None:
            continue
        line_picto = LinePicto(
            id=fields.id_,
            mime_type=f"image/{fields.format.lower()}",
            height_px=fields.height,
            width_px=fields.width,
            filename=fields.filename,
            # NOTE(review): RATP_PICTOS_URL already ends with an export
            # path/query string; confirm this download URL is valid.
            url=f"{RATP_PICTOS_URL}/{fields.id_}/download",
            thumbnail=fields.thumbnail,
            format=fields.format,
        )
        formatted.append((picto.indices_commerciaux, line_picto))
    return formatted
def format_idfm_lines(*lines: IdfmLine) -> Iterable[Line]:
    """Map raw IDFM line records onto Line ORM objects.

    Records whose line id cannot be parsed as an integer are skipped with
    a warning; an unparsable operator id is replaced by 0 instead of
    dropping the whole record.
    """
    ret = []
    optional_value = IdfmLine.optional_value
    for line in lines:
        fields = line.fields
        line_id = fields.id_line
        # IDFM line ids look like "C01234": strip the "C" prefix and keep
        # the numeric part.  removeprefix() also copes with an empty id,
        # which the previous line_id[0] indexing crashed on (uncaught
        # IndexError); int("") now raises ValueError and is skipped.
        try:
            formatted_line_id = int(line_id.removeprefix("C"))
        except ValueError:
            logger.warning("Unable to format %s line id.", line_id)
            continue
        try:
            operator_id = int(fields.operatorref)  # type: ignore
        except (ValueError, TypeError):
            logger.warning("Unable to format %s operator id.", fields.operatorref)
            operator_id = 0
        ret.append(
            Line(
                id=formatted_line_id,
                short_name=fields.shortname_line,
                name=fields.name_line,
                status=IdfmLineState(fields.status.value),
                transport_mode=TransportMode(fields.transportmode.value),
                transport_submode=optional_value(fields.transportsubmode),
                network_name=optional_value(fields.networkname),
                group_of_lines_id=optional_value(fields.id_groupoflines),
                group_of_lines_shortname=optional_value(fields.shortname_groupoflines),
                colour_web_hexa=fields.colourweb_hexa,
                text_colour_hexa=fields.textcolourprint_hexa,
                operator_id=operator_id,
                operator_name=optional_value(fields.operatorname),
                accessibility=IdfmState(fields.accessibility.value),
                visual_signs_available=IdfmState(fields.visualsigns_available.value),
                audible_signs_available=IdfmState(fields.audiblesigns_available.value),
                picto_id=fields.picto.id_ if fields.picto is not None else None,
                record_id=line.recordid,
                record_ts=int(line.record_timestamp.timestamp()),
            )
        )
    return ret
def format_idfm_stops(*stops: IdfmStop) -> Iterable[Stop]:
    """Map raw IDFM stop records onto Stop ORM objects (generator).

    Records whose postal region is not an integer are skipped with a
    warning.
    """
    for stop in stops:
        fields = stop.fields
        # arrcreated may be absent/None; fall back to a null timestamp.
        try:
            created_ts = int(fields.arrcreated.timestamp())  # type: ignore
        except AttributeError:
            created_ts = None
        # Reproject EPSG:2154 coordinates to EPSG:3857 (web mercator).
        epsg3857_point = epsg2154_epsg3857_transformer.transform(
            fields.arrxepsg2154, fields.arryepsg2154
        )
        try:
            postal_region = int(fields.arrpostalregion)
        except ValueError:
            logger.warning("Unable to format %s postal region.", fields.arrpostalregion)
            continue
        yield Stop(
            id=int(fields.arrid),
            name=fields.arrname,
            epsg3857_x=epsg3857_point[0],
            epsg3857_y=epsg3857_point[1],
            town_name=fields.arrtown,
            postal_region=postal_region,
            transport_mode=TransportMode(fields.arrtype.value),
            version=fields.arrversion,
            created_ts=created_ts,
            changed_ts=int(fields.arrchanged.timestamp()),
            accessibility=IdfmState(fields.arraccessibility.value),
            visual_signs_available=IdfmState(fields.arrvisualsigns.value),
            audible_signs_available=IdfmState(fields.arraudiblesignals.value),
            record_id=stop.recordid,
            record_ts=int(stop.record_timestamp.timestamp()),
        )
def format_idfm_stop_areas(*stop_areas: IdfmStopArea) -> Iterable[StopArea]:
    """Map raw IDFM stop-area records onto StopArea ORM objects (generator)."""
    for stop_area in stop_areas:
        fields = stop_area.fields
        # zdacreated may be absent/None; fall back to a null timestamp.
        try:
            created_ts = int(fields.zdacreated.timestamp())  # type: ignore
        except AttributeError:
            created_ts = None
        # Reproject EPSG:2154 coordinates to EPSG:3857 (web mercator).
        epsg3857_point = epsg2154_epsg3857_transformer.transform(
            fields.zdaxepsg2154, fields.zdayepsg2154
        )
        yield StopArea(
            id=int(fields.zdaid),
            name=fields.zdaname,
            town_name=fields.zdatown,
            postal_region=fields.zdapostalregion,
            epsg3857_x=epsg3857_point[0],
            epsg3857_y=epsg3857_point[1],
            type=StopAreaType(fields.zdatype.value),
            version=fields.zdaversion,
            created_ts=created_ts,
            changed_ts=int(fields.zdachanged.timestamp()),
        )
def format_idfm_connection_areas(
    *connection_areas: IdfmConnectionArea,
) -> Iterable[ConnectionArea]:
    """Map raw IDFM connection-area records onto ConnectionArea ORM
    objects (generator).

    Unlike stops/stop areas, zdccreated is read unconditionally here —
    presumably always present in this dataset; confirm.
    """
    for connection_area in connection_areas:
        fields = connection_area.fields
        # Reproject EPSG:2154 coordinates to EPSG:3857 (web mercator).
        epsg3857_point = epsg2154_epsg3857_transformer.transform(
            fields.zdcxepsg2154, fields.zdcyepsg2154
        )
        yield ConnectionArea(
            id=int(fields.zdcid),
            name=fields.zdcname,
            town_name=fields.zdctown,
            postal_region=fields.zdcpostalregion,
            epsg3857_x=epsg3857_point[0],
            epsg3857_y=epsg3857_point[1],
            transport_mode=StopAreaType(fields.zdctype.value),
            version=fields.zdcversion,
            created_ts=int(fields.zdccreated.timestamp()),
            changed_ts=int(fields.zdcchanged.timestamp()),
        )
def format_idfm_stop_shapes(*shape_records: ShapeRecord) -> Iterable[StopShape]:
    """Map pyshp ShapeRecords onto StopShape ORM objects (generator)."""
    for shape_record in shape_records:
        # Reproject every polygon point from EPSG:2154 to EPSG:3857.
        epsg3857_points = [
            epsg2154_epsg3857_transformer.transform(*point)
            for point in shape_record.shape.points
        ]
        try:
            # The flat bbox [x1, y1, x2, y2] is paired into (x, y) points
            # via the zip(it, it) trick before reprojection.
            bbox_it = iter(shape_record.shape.bbox)
            epsg3857_bbox = [
                epsg2154_epsg3857_transformer.transform(*point)
                for point in zip(bbox_it, bbox_it)
            ]
        except AttributeError:
            # Handle stop shapes for which no bbox is provided
            epsg3857_bbox = []
        yield StopShape(
            # record[1] is used as the shape id — presumably the stop ref
            # column of the shapefile; confirm against the REF_ArR schema.
            id=shape_record.record[1],
            type=shape_record.shape.shapeType,
            epsg3857_bbox=epsg3857_bbox,
            epsg3857_points=epsg3857_points,
        )
async def http_get(url: str) -> str | None:
    """Download url and return its body decoded as text, or None on any
    non-200 response.

    A tqdm progress bar tracks the download; its total comes from the
    content-length header when the server provides one.
    """
    chunks: list[bytes] = []
    headers = {"Accept": "application/json"}
    async with ClientSession(headers=headers) as session:
        async with session.get(url) as response:
            size = int(response.headers.get("content-length", 0)) or None
            progress_bar = tqdm(desc=f"Downloading {url}", total=size)
            if response.status != 200:
                return None
            async for chunk in response.content.iter_chunked(1024 * 1024):
                # Accumulate raw bytes and decode once at the end:
                # decoding chunk by chunk raised UnicodeDecodeError when a
                # multi-byte UTF-8 sequence straddled a chunk boundary.
                chunks.append(chunk)
                progress_bar.update(len(chunk))
    return b"".join(chunks).decode()
async def http_request(
    url: str, decode: Callable, format_method: Callable, model: Type[Base]
) -> bool:
    """Fetch url, decode its JSON payload and insert the formatted rows
    through model.add() in batches of BATCH_SIZE.

    Returns False when the download fails or the payload does not
    validate; True otherwise.
    """
    data = await http_get(url)
    if data is None:
        return False
    batch: list = []
    try:
        for element in decode(data):
            batch.append(element)
            if len(batch) == BATCH_SIZE:
                await model.add(format_method(*batch))
                batch.clear()
        # Flush the last, possibly partial, batch.
        if batch:
            await model.add(format_method(*batch))
    except ValidationError as err:
        logger.warning(err)
        return False
    return True
async def load_idfm_stops() -> bool:
    """Download the IDFM stops dataset and fill the db with Stop rows."""
    return await http_request(
        IDFM_STOPS_URL, json_stops_decoder.decode, format_idfm_stops, Stop
    )
async def load_idfm_stop_areas() -> bool:
    """Download the IDFM stop-areas dataset and fill the db with StopArea rows."""
    return await http_request(
        IDFM_STOP_AREAS_URL,
        json_stop_areas_decoder.decode,
        format_idfm_stop_areas,
        StopArea,
    )
async def load_idfm_connection_areas() -> bool:
    """Download the IDFM connection-areas dataset and fill the db with
    ConnectionArea rows."""
    return await http_request(
        IDFM_CONNECTION_AREAS_URL,
        json_connection_areas_decoder.decode,
        format_idfm_connection_areas,
        ConnectionArea,
    )
async def load_idfm_stop_shapes(url: str) -> None:
    """Download a zipped shapefile from url and fill the db with the
    StopShapes it contains.

    The archive is spooled to a temporary file, read with pyshp, and
    inserted in batches of BATCH_SIZE.  Nothing is inserted when the
    HTTP response is not 200.
    """
    async with ClientSession(headers={"Accept": "application/zip"}) as session:
        async with session.get(url) as response:
            size = int(response.headers.get("content-length", 0)) or None
            dl_progress_bar = tqdm(desc=f"Downloading {url}", total=size)
            if response.status != 200:
                return
            async with NamedTemporaryFile(suffix=".zip") as tmp_file:
                async for chunk in response.content.iter_chunked(1024 * 1024):
                    await tmp_file.write(chunk)
                    dl_progress_bar.update(len(chunk))
                with ShapeFileReader(tmp_file.name) as reader:
                    step_begin_ts = time()
                    shapes = reader.shapeRecords()
                    shapes_len = len(shapes)
                    db_progress_bar = tqdm(
                        desc=f"Filling db with {shapes_len} StopShapes",
                        total=shapes_len,
                    )
                    # Direct list slicing replaces the previous
                    # islice/while loop: the old "begin > len(shapes)"
                    # stop condition ran one extra empty batch whenever
                    # shapes_len was a multiple of BATCH_SIZE, and each
                    # islice() re-scanned the list from its start.
                    for begin in range(0, shapes_len, BATCH_SIZE):
                        batch = shapes[begin : begin + BATCH_SIZE]
                        await StopShape.add(list(format_idfm_stop_shapes(*batch)))
                        # Clamp the last update so the bar never overshoots.
                        db_progress_bar.update(min(BATCH_SIZE, shapes_len - begin))
                    logger.info(
                        f"Add {StopShape.__name__}s: {time() - step_begin_ts}s"
                    )
async def load_idfm_lines() -> None:
    """Fill the db with IDFM lines and their pictos.

    Pictos are deduplicated by id and inserted before the lines that
    reference them, in batches of BATCH_SIZE lines.
    """
    data = await http_get(IDFM_LINES_URL)
    if data is None:
        return None
    lines, pictos = [], []
    seen_picto_ids = set()
    for line in json_lines_decoder.decode(data):
        picto = line.fields.picto
        if picto is not None and picto.id_ not in seen_picto_ids:
            seen_picto_ids.add(picto.id_)
            pictos.append(picto)
        lines.append(line)
        if len(lines) == BATCH_SIZE:
            await LinePicto.add(list(format_idfm_pictos(*pictos)))
            await Line.add(list(format_idfm_lines(*lines)))
            lines.clear()
            pictos.clear()
    # Flush whatever did not fill a complete batch.
    if pictos:
        await LinePicto.add(list(format_idfm_pictos(*pictos)))
    if lines:
        await Line.add(list(format_idfm_lines(*lines)))
async def load_ratp_pictos(batch_size: int = 5) -> None:
    """Fill the db with RATP pictos and associate them with their lines."""

    async def flush(batch: list) -> None:
        # One insert for the pictos themselves, one to attach them to
        # their lines (keyed by commercial line code).
        formatteds = format_ratp_pictos(*batch)
        await LinePicto.add([pair[1] for pair in formatteds])
        await Line.add_pictos(formatteds)

    data = await http_get(RATP_PICTOS_URL)
    if data is None:
        return None
    batch = []
    for picto in json_ratp_pictos_decoder.decode(data):
        batch.append(picto)
        if len(batch) == batch_size:
            await flush(batch)
            batch.clear()
    if batch:
        await flush(batch)
async def load_lines_stops_assos(batch_size: int = 5000) -> None:
    """Fill the line <-> stop association table from the IDFM
    arrets-lignes dataset, in batches of batch_size.
    """
    data = await http_get(IDFM_STOPS_LINES_ASSOS_URL)
    if data is None:
        return None
    total_assos_nb = total_found_nb = 0
    assos = []
    for asso in json_stops_lines_assos_decoder.decode(data):
        fields = asso.fields
        try:
            # Keep the numeric tail of the colon-separated stop ref.
            # NOTE(review): assumes the id is always the last field and
            # never followed by a trailing colon — confirm against data.
            stop_id = int(fields.stop_id.rsplit(":", 1)[-1])
        except ValueError as err:
            logger.error(err)
            logger.error(f"{fields.stop_id = }")
            continue
        assos.append((fields.route_long_name, fields.operatorname, stop_id))
        if len(assos) == batch_size:
            total_assos_nb += batch_size
            total_found_nb += await Line.add_stops(assos)
            assos.clear()
    # Flush the last, possibly partial, batch.
    if assos:
        total_assos_nb += len(assos)
        total_found_nb += await Line.add_stops(assos)
    logger.info(f"{total_found_nb} line <-> stop ({total_assos_nb = } found)")
async def load_stop_assos(batch_size: int = 5000) -> None:
    """Fill both the stop area <-> stop and the connection area <-> stop
    association tables from the IDFM relations dataset.

    Each record yields one association of each kind; they are pushed in
    batches of batch_size and both counters are logged at the end.
    """
    data = await http_get(IDFM_STOP_STOP_AREAS_ASSOS_URL)
    if data is None:
        return None
    total_assos_nb = area_stop_assos_nb = conn_stop_assos_nb = 0
    area_stop_assos = []
    connection_stop_assos = []
    for asso in json_stop_area_stop_asso_decoder.decode(data):
        fields = asso.fields
        stop_id = int(fields.arrid)
        area_stop_assos.append((int(fields.zdaid), stop_id))
        connection_stop_assos.append((int(fields.zdcid), stop_id))
        if len(area_stop_assos) == batch_size:
            total_assos_nb += batch_size
            if (found_nb := await StopArea.add_stops(area_stop_assos)) is not None:
                area_stop_assos_nb += found_nb
            area_stop_assos.clear()
            if (
                found_nb := await ConnectionArea.add_stops(connection_stop_assos)
            ) is not None:
                conn_stop_assos_nb += found_nb
            connection_stop_assos.clear()
    # Flush the last, possibly partial, batches.
    if area_stop_assos:
        total_assos_nb += len(area_stop_assos)
        if (found_nb := await StopArea.add_stops(area_stop_assos)) is not None:
            area_stop_assos_nb += found_nb
        if (
            found_nb := await ConnectionArea.add_stops(connection_stop_assos)
        ) is not None:
            conn_stop_assos_nb += found_nb
    logger.info(f"{area_stop_assos_nb} stop area <-> stop ({total_assos_nb = } found)")
    # Fixed copy-pasted message: this line previously also said "stop area".
    logger.info(
        f"{conn_stop_assos_nb} connection area <-> stop ({total_assos_nb = } found)"
    )
async def prepare(db: Database) -> None:
    """Run every loader: lines first, then the independent datasets,
    finally the association tables that reference them.

    NOTE(review): the db parameter is not used in this body — the
    loaders go through module-level models; confirm whether it can go.
    """
    # Lines must exist before the line <-> stop associations are added.
    await load_idfm_lines()
    await gather(
        load_idfm_stops(),
        load_idfm_stop_areas(),
        load_idfm_connection_areas(),
        load_ratp_pictos(),
    )
    await gather(
        load_idfm_stop_shapes(IDFM_STOP_SHAPES_URL),
        load_idfm_stop_shapes(IDFM_STOP_AREA_SHAPES_URL),
        load_lines_stops_assos(),
        load_stop_assos(),
    )
def load_settings(path: str) -> Settings:
    """Load the YAML configuration file at path into a Settings object.

    NOTE(review): duplicated in backend/dependencies.py — consider
    sharing a single helper.
    """
    with open(path, "r") as config_file:
        config = safe_load(config_file)
    return Settings(**config)
async def main() -> None:
    """Entry point: connect to the db, run every loader, disconnect."""
    settings = load_settings(CONFIG_PATH)
    # Second argument is True here — presumably "clear static data";
    # confirm against db.connect()'s signature.
    await db.connect(settings.db, True)
    begin_ts = time()
    await prepare(db)
    logger.info(f"Elapsed time: {time() - begin_ts}s")
    await db.disconnect()
if __name__ == "__main__":
run(main())

38
backend/dependencies.py Normal file
View File

@@ -0,0 +1,38 @@
from os import environ
from fastapi_cache.backends.redis import RedisBackend
from redis import asyncio as aioredis
from yaml import safe_load
from backend.db import db
from backend.idfm_interface.idfm_interface import IdfmInterface
from backend.settings import CacheSettings, Settings
CONFIG_PATH = environ.get("CONFIG_PATH", "./config.sample.yaml")
def load_settings(path: str) -> Settings:
    """Load the YAML configuration file at path into a Settings object.

    NOTE(review): duplicated in backend/db_updater/fill_db.py — consider
    sharing a single helper.
    """
    with open(path, "r") as config_file:
        config = safe_load(config_file)
    return Settings(**config)
settings = load_settings(CONFIG_PATH)
idfm_interface = IdfmInterface(settings.idfm_api_key.get_secret_value(), db)
def init_redis_backend(settings: CacheSettings) -> RedisBackend:
    """Build a fastapi-cache RedisBackend from the cache settings.

    Credentials are optional; the settings validator guarantees user and
    password are either both set or both unset.
    """
    if settings.user is not None and settings.password is not None:
        # get_secret_value() is required here: interpolating a pydantic
        # SecretStr directly yields the masked "**********" placeholder,
        # which produced an unusable connection URL.
        login = f"{settings.user}:{settings.password.get_secret_value()}@"
    else:
        login = ""
    url = f"redis://{login}{settings.host}:{settings.port}"
    redis_connections_pool = aioredis.from_url(
        url, encoding="utf8", decode_responses=True
    )
    return RedisBackend(redis_connections_pool)
redis_backend = init_redis_backend(settings.cache)

View File

@@ -1,290 +1,93 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
import logging
from collections import defaultdict
from datetime import datetime
from os import environ, EX_USAGE
from typing import Sequence
import uvicorn import uvicorn
from fastapi import FastAPI, HTTPException from contextlib import asynccontextmanager
from fastapi import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles from fastapi.staticfiles import StaticFiles
from fastapi_cache import FastAPICache
from opentelemetry import trace from opentelemetry import trace
from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter
from opentelemetry.instrumentation.fastapi import FastAPIInstrumentor
from opentelemetry.sdk.resources import Resource, SERVICE_NAME
from opentelemetry.sdk.trace import TracerProvider from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import BatchSpanProcessor from opentelemetry.sdk.trace.export import BatchSpanProcessor
from opentelemetry.instrumentation.fastapi import FastAPIInstrumentor
from opentelemetry.instrumentation.logging import LoggingInstrumentor
from opentelemetry.sdk.resources import Resource as OtResource
from opentelemetry.sdk.trace import TracerProvider as OtTracerProvider
from opentelemetry.sdk.trace.export import BatchSpanProcessor
from rich import print
from starlette.types import ASGIApp
from backend.db import db from backend.db import db
from backend.idfm_interface import Destinations as IdfmDestinations, IdfmInterface from dependencies import idfm_interface, redis_backend, settings
from backend.models import Line, Stop, StopArea, StopShape from routers import line, stop
from backend.schemas import (
Line as LineSchema,
TransportMode,
NextPassage as NextPassageSchema,
NextPassages as NextPassagesSchema,
Stop as StopSchema,
StopArea as StopAreaSchema,
StopShape as StopShapeSchema,
)
API_KEY = environ.get("API_KEY")
if API_KEY is None:
print('No "API_KEY" environment variable set... abort.')
exit(EX_USAGE)
APP_NAME = environ.get("APP_NAME", "app")
MODE = environ.get("MODE", "grpc")
COLLECTOR_ENDPOINT_GRPC_ENDPOINT = environ.get(
"COLLECTOR_ENDPOINT_GRPC_ENDPOINT", "127.0.0.1:14250" # "jaeger-collector:14250"
)
# CREATE DATABASE "carrramba-encore-rate";
# CREATE USER cer WITH ENCRYPTED PASSWORD 'cer_password';
# GRANT ALL PRIVILEGES ON DATABASE "carrramba-encore-rate" TO cer;
# \c "carrramba-encore-rate";
# GRANT ALL ON schema public TO cer;
# CREATE EXTENSION IF NOT EXISTS pg_trgm;
# TODO: Remove postgresql+psycopg from environ variable
DB_PATH = "postgresql+psycopg://cer:cer_password@127.0.0.1:5432/carrramba-encore-rate"
app = FastAPI() @asynccontextmanager
async def lifespan(app: FastAPI):
FastAPICache.init(redis_backend, prefix="api", enable=settings.cache.enable)
await db.connect(settings.db, settings.clear_static_data)
if settings.clear_static_data:
await idfm_interface.startup()
yield
await db.disconnect()
app = FastAPI(lifespan=lifespan)
app.add_middleware( app.add_middleware(
CORSMiddleware, CORSMiddleware,
allow_origins=[ allow_origins=["https://localhost:4443", "https://localhost:3000"],
"https://localhost:4443",
"https://localhost:3000",
],
allow_credentials=True, allow_credentials=True,
allow_methods=["*"], allow_methods=["*"],
allow_headers=["*"], allow_headers=["*"],
) )
trace.set_tracer_provider(TracerProvider()) app.mount("/widget", StaticFiles(directory="../frontend/", html=True), name="widget")
trace.get_tracer_provider().add_span_processor(BatchSpanProcessor(OTLPSpanExporter()))
tracer = trace.get_tracer(APP_NAME)
with tracer.start_as_current_span("foo"): # The cache-control header entry is not managed properly by fastapi-cache:
print("Hello world!") # For now, a request with a cache-control set to no-cache
# is interpreted as disabling the use of the server cache.
# Cf. Improve Cache-Control header parsing and handling
# https://github.com/long2ice/fastapi-cache/issues/144 workaround
@app.middleware("http")
async def fastapi_cache_issue_144_workaround(request: Request, call_next):
entries = request.headers.__dict__["_list"]
new_entries = [
entry for entry in entries if entry[0].decode().lower() != "cache-control"
]
request.headers.__dict__["_list"] = new_entries
return await call_next(request)
idfm_interface = IdfmInterface(API_KEY, db)
app.include_router(line.router)
app.include_router(stop.router)
@app.on_event("startup")
async def startup():
await db.connect(DB_PATH, clear_static_data=True)
await idfm_interface.startup()
# await db.connect(DB_PATH, clear_static_data=False)
print("Connected")
@app.on_event("shutdown")
async def shutdown():
await db.disconnect()
STATIC_ROOT = "../frontend/"
app.mount("/widget", StaticFiles(directory=STATIC_ROOT, html=True), name="widget")
def optional_datetime_to_ts(dt: datetime | None) -> int | None:
return int(dt.timestamp()) if dt else None
@app.get("/line/{line_id}", response_model=LineSchema)
async def get_line(line_id: str) -> LineSchema:
line: Line | None = await Line.get_by_id(line_id)
if line is None:
raise HTTPException(status_code=404, detail=f'Line "{line_id}" not found')
return LineSchema(
id=line.id,
shortName=line.short_name,
name=line.name,
status=line.status,
transportMode=TransportMode.from_idfm_transport_mode(
line.transport_mode, line.transport_submode
),
backColorHexa=line.colour_web_hexa,
foreColorHexa=line.text_colour_hexa,
operatorId=line.operator_id,
accessibility=line.accessibility,
visualSignsAvailable=line.visual_signs_available,
audibleSignsAvailable=line.audible_signs_available,
stopIds=[stop.id for stop in line.stops],
)
def _format_stop(stop: Stop) -> StopSchema:
return StopSchema(
id=stop.id,
name=stop.name,
town=stop.town_name,
epsg3857_x=stop.epsg3857_x,
epsg3857_y=stop.epsg3857_y,
lines=[line.id for line in stop.lines],
)
@app.get("/stop/")
async def get_stop(
name: str = "", limit: int = 10
) -> Sequence[StopAreaSchema | StopSchema]:
# TODO: Add limit support
formatted: list[StopAreaSchema | StopSchema] = []
matching_stops = await Stop.get_by_name(name)
# print(matching_stops, flush=True)
stop_areas: dict[int, StopArea] = {}
stops: dict[int, Stop] = {}
for stop in matching_stops:
# print(f"{stop.__dict__ = }", flush=True)
dst = stop_areas if isinstance(stop, StopArea) else stops
dst[stop.id] = stop
for stop_area in stop_areas.values():
formatted_stops = []
for stop in stop_area.stops:
formatted_stops.append(_format_stop(stop))
try:
del stops[stop.id]
except KeyError as err:
print(err)
formatted.append(
StopAreaSchema(
id=stop_area.id,
name=stop_area.name,
town=stop_area.town_name,
type=stop_area.type,
lines=[line.id for line in stop_area.lines],
stops=formatted_stops,
)
)
# print(f"{stops = }", flush=True)
formatted.extend(_format_stop(stop) for stop in stops.values())
return formatted
# TODO: Cache response for 30 secs ?
@app.get("/stop/nextPassages/{stop_id}")
async def get_next_passages(stop_id: str) -> NextPassagesSchema | None:
    """Fetch real-time next passages for a stop from the IDFM interface.

    Returns the passages grouped first by line id, then by displayed
    destination name, or None when the IDFM interface returns nothing.
    Raises HTTPException(404) when a journey's LineRef value does not
    match IdfmInterface.LINE_RE.
    """
    res = await idfm_interface.get_next_passages(stop_id)
    if res is None:
        return None
    service_delivery = res.Siri.ServiceDelivery
    stop_monitoring_deliveries = service_delivery.StopMonitoringDelivery
    # line id -> destination name -> passages for that (line, destination).
    by_line_by_dst_passages: dict[
        str, dict[str, list[NextPassageSchema]]
    ] = defaultdict(lambda: defaultdict(list))
    for delivery in stop_monitoring_deliveries:
        for stop_visit in delivery.MonitoredStopVisit:
            journey = stop_visit.MonitoredVehicleJourney
            # re.match will return None if the given journey.LineRef.value is not valid.
            # Calling .group(1) on that None raises AttributeError, caught below.
            try:
                line_id = IdfmInterface.LINE_RE.match(journey.LineRef.value).group(1)
            except AttributeError as exc:
                raise HTTPException(
                    status_code=404, detail=f'Line "{journey.LineRef.value}" not found'
                ) from exc
            call = journey.MonitoredCall
            dst_names = call.DestinationDisplay
            dsts = [dst.value for dst in dst_names] if dst_names else []
            arrivalPlatformName = (
                call.ArrivalPlatformName.value if call.ArrivalPlatformName else None
            )
            next_passage = NextPassageSchema(
                line=line_id,
                operator=journey.OperatorRef.value,
                destinations=dsts,
                atStop=call.VehicleAtStop,
                aimedArrivalTs=optional_datetime_to_ts(call.AimedArrivalTime),
                expectedArrivalTs=optional_datetime_to_ts(call.ExpectedArrivalTime),
                arrivalPlatformName=arrivalPlatformName,
                aimedDepartTs=optional_datetime_to_ts(call.AimedDepartureTime),
                expectedDepartTs=optional_datetime_to_ts(call.ExpectedDepartureTime),
                arrivalStatus=call.ArrivalStatus.value,
                departStatus=call.DepartureStatus.value,
            )
            by_line_passages = by_line_by_dst_passages[line_id]
            # TODO: by_line_passages[dst].extend(dsts) instead ?
            # The same passage object is appended once per displayed destination.
            for dst in dsts:
                by_line_passages[dst].append(next_passage)
    return NextPassagesSchema(
        ts=service_delivery.ResponseTimestamp.timestamp(),
        passages=by_line_by_dst_passages,
    )
@app.get("/stop/{stop_id}/destinations")
async def get_stop_destinations(
    stop_id: int,
) -> IdfmDestinations | None:
    """Proxy the IDFM destinations lookup for the given stop id."""
    return await idfm_interface.get_destinations(stop_id)
# TODO: Rename endpoint -> /stop/{stop_id}/shape
@app.get("/stop_shape/{stop_id}")
async def get_stop_shape(stop_id: int) -> StopShapeSchema | None:
    """Return the geographic shape associated with a stop or stop area.

    For a plain stop the shape of its connection area is looked up. For a
    stop area, every member stop must share a single connection area;
    otherwise HTTPException(500) is raised. HTTPException(404) is raised
    when no shape can be resolved.
    """
    connection_area = None
    if (stop := await Stop.get_by_id(stop_id)) is not None:
        connection_area = stop.connection_area
    elif (stop_area := await StopArea.get_by_id(stop_id)) is not None:
        # The shape is only unambiguous when all stops of the area share
        # exactly one connection area.
        connection_areas = {stop.connection_area for stop in stop_area.stops}
        connection_areas_len = len(connection_areas)
        if connection_areas_len == 1:
            connection_area = connection_areas.pop()
        else:
            prefix = "More than one" if connection_areas_len else "No"
            msg = f"{prefix} connection area has been found for stop area #{stop_id}"
            raise HTTPException(status_code=500, detail=msg)
    if (
        connection_area is not None
        and (shape := await StopShape.get_by_id(connection_area.id)) is not None
    ):
        return StopShapeSchema(
            id=shape.id,
            type=shape.type,
            epsg3857_bbox=shape.epsg3857_bbox,
            epsg3857_points=shape.epsg3857_points,
        )
    msg = f"No shape found for stop {stop_id}"
    raise HTTPException(status_code=404, detail=msg)
if settings.tracing.enable:
    # NOTE(review): this statement appeared twice on one line in the source;
    # instrumenting the app once is sufficient.
    FastAPIInstrumentor.instrument_app(app)
    # Register a tracer provider tagged with the application name, exporting
    # spans in batches over OTLP.
    trace.set_tracer_provider(
        TracerProvider(resource=Resource.create({SERVICE_NAME: settings.app_name}))
    )
    trace.get_tracer_provider().add_span_processor(
        BatchSpanProcessor(OTLPSpanExporter())
    )

# trace.get_tracer returns a no-op tracer when no provider was configured.
tracer = trace.get_tracer(settings.app_name)
if __name__ == "__main__":
    # NOTE(review): the original two lines fused both sides of a diff;
    # this is the reconstructed (new) configuration-driven entry point.
    # proxy_headers=True honours X-Forwarded-* headers when the app runs
    # behind a reverse proxy.
    http_settings = settings.http
    config = uvicorn.Config(
        app=app,
        host=http_settings.host,
        port=http_settings.port,
        ssl_certfile=http_settings.cert,
        proxy_headers=True,
    )
    server = uvicorn.Server(config)
    server.run()

View File

@@ -4,18 +4,14 @@ version = "0.1.0"
description = "" description = ""
authors = ["Adrien SUEUR <me@adrien.run>"] authors = ["Adrien SUEUR <me@adrien.run>"]
readme = "README.md" readme = "README.md"
packages = [{include = "backend"}]
[tool.poetry.dependencies] [tool.poetry.dependencies]
python = "^3.11" python = "^3.11"
aiohttp = "^3.8.3" aiohttp = "^3.8.3"
rich = "^12.6.0"
aiofiles = "^22.1.0" aiofiles = "^22.1.0"
fastapi = "^0.88.0" fastapi = "^0.95.0"
uvicorn = "^0.20.0" uvicorn = "^0.20.0"
msgspec = "^0.12.0" msgspec = "^0.12.0"
pyshp = "^2.3.1"
pyproj = "^3.5.0"
opentelemetry-instrumentation-fastapi = "^0.38b0" opentelemetry-instrumentation-fastapi = "^0.38b0"
sqlalchemy-utils = "^0.41.1" sqlalchemy-utils = "^0.41.1"
opentelemetry-instrumentation-logging = "^0.38b0" opentelemetry-instrumentation-logging = "^0.38b0"
@@ -25,6 +21,26 @@ opentelemetry-exporter-otlp-proto-http = "^1.17.0"
opentelemetry-instrumentation-sqlalchemy = "^0.38b0" opentelemetry-instrumentation-sqlalchemy = "^0.38b0"
sqlalchemy = "^2.0.12" sqlalchemy = "^2.0.12"
psycopg = "^3.1.9" psycopg = "^3.1.9"
pyyaml = "^6.0"
fastapi-cache2 = {extras = ["redis"], version = "^0.2.1"}
[tool.poetry.group.db_updater.dependencies]
aiofiles = "^22.1.0"
aiohttp = "^3.8.3"
fastapi = "^0.95.0"
msgspec = "^0.12.0"
opentelemetry-instrumentation-fastapi = "^0.38b0"
opentelemetry-instrumentation-sqlalchemy = "^0.38b0"
opentelemetry-sdk = "^1.17.0"
opentelemetry-api = "^1.17.0"
psycopg = "^3.1.9"
pyproj = "^3.5.0"
pyshp = "^2.3.1"
python = "^3.11"
pyyaml = "^6.0"
sqlalchemy = "^2.0.12"
sqlalchemy-utils = "^0.41.1"
tqdm = "^4.65.0"
[build-system] [build-system]
requires = ["poetry-core"] requires = ["poetry-core"]
@@ -47,6 +63,10 @@ pyflakes = "^3.0.1"
yapf = "^0.32.0" yapf = "^0.32.0"
whatthepatch = "^1.0.4" whatthepatch = "^1.0.4"
mypy = "^1.0.0" mypy = "^1.0.0"
icecream = "^2.1.3"
types-sqlalchemy-utils = "^1.0.1"
types-pyyaml = "^6.0.12.9"
types-tqdm = "^4.65.0.1"
[tool.mypy] [tool.mypy]
plugins = "sqlalchemy.ext.mypy.plugin" plugins = "sqlalchemy.ext.mypy.plugin"

View File

34
backend/routers/line.py Normal file
View File

@@ -0,0 +1,34 @@
from fastapi import APIRouter, HTTPException
from fastapi_cache.decorator import cache
from backend.models import Line
from backend.schemas import Line as LineSchema, TransportMode
router = APIRouter(prefix="/line", tags=["line"])
@router.get("/{line_id}", response_model=LineSchema)
@cache(namespace="line")
async def get_line(line_id: int) -> LineSchema:
    """Return the line identified by ``line_id``.

    Responses are cached under the "line" namespace. Raises
    HTTPException(404) when the line is unknown.
    """
    line: Line | None = await Line.get_by_id(line_id)
    if line is None:
        raise HTTPException(status_code=404, detail=f'Line "{line_id}" not found')
    # Translate the raw IDFM mode/submode pair into the public enum.
    mode = TransportMode.from_idfm_transport_mode(
        line.transport_mode, line.transport_submode
    )
    return LineSchema(
        id=line.id,
        shortName=line.short_name,
        name=line.name,
        status=line.status,
        transportMode=mode,
        backColorHexa=line.colour_web_hexa,
        foreColorHexa=line.text_colour_hexa,
        operatorId=line.operator_id,
        accessibility=line.accessibility,
        visualSignsAvailable=line.visual_signs_available,
        audibleSignsAvailable=line.audible_signs_available,
        stopIds=[stop.id for stop in line.stops],
    )

176
backend/routers/stop.py Normal file
View File

@@ -0,0 +1,176 @@
from collections import defaultdict
from datetime import datetime
from typing import Sequence
from fastapi import APIRouter, HTTPException
from fastapi_cache.decorator import cache
from backend.idfm_interface import Destinations as IdfmDestinations, TrainStatus
from backend.models import Stop, StopArea, StopShape
from backend.schemas import (
NextPassage as NextPassageSchema,
NextPassages as NextPassagesSchema,
Stop as StopSchema,
StopArea as StopAreaSchema,
StopShape as StopShapeSchema,
)
from dependencies import idfm_interface
router = APIRouter(prefix="/stop", tags=["stop"])
def _format_stop(stop: Stop) -> StopSchema:
    """Build the public StopSchema view of a Stop ORM object."""
    lines_served = [line.id for line in stop.lines]
    return StopSchema(
        id=stop.id,
        name=stop.name,
        town=stop.town_name,
        epsg3857_x=stop.epsg3857_x,
        epsg3857_y=stop.epsg3857_y,
        lines=lines_served,
    )
def optional_datetime_to_ts(dt: datetime | None) -> int | None:
return int(dt.timestamp()) if dt else None
# TODO: Add limit support (the ``limit`` parameter is currently ignored).
@router.get("/")
@cache(namespace="stop")
async def get_stop(
    name: str = "", limit: int = 10
) -> Sequence[StopAreaSchema | StopSchema] | None:
    """Search stops and stop areas by name.

    Stop areas are returned with their member stops nested; a stop that
    belongs to a matched area is not repeated at the top level. Returns
    None when the name lookup itself yields None.
    """
    matching_stops = await Stop.get_by_name(name)
    if matching_stops is None:
        return None
    formatted: list[StopAreaSchema | StopSchema] = []
    # Partition the matches: stop areas vs. plain stops.
    stop_areas: dict[int, StopArea] = {}
    stops: dict[int, Stop] = {}
    for stop in matching_stops:
        if isinstance(stop, StopArea):
            stop_areas[stop.id] = stop
        elif isinstance(stop, Stop):
            stops[stop.id] = stop
    for stop_area in stop_areas.values():
        formatted_stops = []
        for stop in stop_area.stops:
            formatted_stops.append(_format_stop(stop))
            # A stop nested in a matched area must not also appear at the
            # top level; pop() replaces the old try/del/print(err) debug
            # pattern and silently ignores stops not in the dict.
            stops.pop(stop.id, None)
        formatted.append(
            StopAreaSchema(
                id=stop_area.id,
                name=stop_area.name,
                town=stop_area.town_name,
                type=stop_area.type,
                lines=[line.id for line in stop_area.lines],
                stops=formatted_stops,
            )
        )
    formatted.extend(_format_stop(stop) for stop in stops.values())
    return formatted
@router.get("/{stop_id}/nextPassages")
@cache(namespace="stop-nextPassages", expire=30)
async def get_next_passages(stop_id: int) -> NextPassagesSchema | None:
    """Fetch real-time next passages for a stop from the IDFM interface.

    Returns the passages grouped first by numeric line id, then by
    displayed destination name, or None when the IDFM interface returns
    nothing. Responses are cached for 30 seconds. Raises
    HTTPException(404) when a journey's LineRef does not yield a valid
    integer line id.
    """
    res = await idfm_interface.get_next_passages(stop_id)
    if res is None:
        return None
    service_delivery = res.Siri.ServiceDelivery
    stop_monitoring_deliveries = service_delivery.StopMonitoringDelivery
    # line id -> destination name -> passages for that (line, destination).
    by_line_by_dst_passages: dict[
        int, dict[str, list[NextPassageSchema]]
    ] = defaultdict(lambda: defaultdict(list))
    for delivery in stop_monitoring_deliveries:
        for stop_visit in delivery.MonitoredStopVisit:
            journey = stop_visit.MonitoredVehicleJourney
            # re.match will return None if the given journey.LineRef.value is not valid.
            # AttributeError covers a failed match, TypeError/ValueError a
            # group that cannot be converted to int.
            try:
                line_id_match = idfm_interface.LINE_RE.match(journey.LineRef.value)
                line_id = int(line_id_match.group(1))  # type: ignore
            except (AttributeError, TypeError, ValueError) as err:
                raise HTTPException(
                    status_code=404, detail=f'Line "{journey.LineRef.value}" not found'
                ) from err
            call = journey.MonitoredCall
            dst_names = call.DestinationDisplay
            dsts = [dst.value for dst in dst_names] if dst_names else []
            arrivalPlatformName = (
                call.ArrivalPlatformName.value if call.ArrivalPlatformName else None
            )
            next_passage = NextPassageSchema(
                line=line_id,
                operator=journey.OperatorRef.value,
                destinations=dsts,
                atStop=call.VehicleAtStop,
                aimedArrivalTs=optional_datetime_to_ts(call.AimedArrivalTime),
                expectedArrivalTs=optional_datetime_to_ts(call.ExpectedArrivalTime),
                arrivalPlatformName=arrivalPlatformName,
                aimedDepartTs=optional_datetime_to_ts(call.AimedDepartureTime),
                expectedDepartTs=optional_datetime_to_ts(call.ExpectedDepartureTime),
                # Fall back to TrainStatus.unknown when IDFM omits a status.
                arrivalStatus=call.ArrivalStatus
                if call.ArrivalStatus is not None
                else TrainStatus.unknown,
                departStatus=call.DepartureStatus
                if call.DepartureStatus is not None
                else TrainStatus.unknown,
            )
            by_line_passages = by_line_by_dst_passages[line_id]
            # TODO: by_line_passages[dst].extend(dsts) instead ?
            # The same passage object is appended once per displayed destination.
            for dst in dsts:
                by_line_passages[dst].append(next_passage)
    return NextPassagesSchema(
        ts=int(service_delivery.ResponseTimestamp.timestamp()),
        passages=by_line_by_dst_passages,
    )
@router.get("/{stop_id}/destinations")
@cache(namespace="stop-destinations", expire=30)
async def get_stop_destinations(
    stop_id: int,
) -> IdfmDestinations | None:
    """Proxy the IDFM destinations lookup for a stop (cached 30 s)."""
    return await idfm_interface.get_destinations(stop_id)
@router.get("/{stop_id}/shape")
@cache(namespace="stop-shape")
async def get_stop_shape(stop_id: int) -> StopShapeSchema | None:
    """Return the shape registered for a known stop or stop area.

    Raises HTTPException(404) when the id matches neither a stop nor a
    stop area, or when no shape is stored for it.
    """
    # Short-circuit: the StopArea lookup only runs when no Stop matched.
    is_stop = (await Stop.get_by_id(stop_id)) is not None
    is_area = not is_stop and (await StopArea.get_by_id(stop_id)) is not None
    if is_stop or is_area:
        shape = await StopShape.get_by_id(stop_id)
        if shape is not None:
            return StopShapeSchema(
                id=shape.id,
                type=shape.type,
                epsg3857_bbox=shape.epsg3857_bbox,
                epsg3857_points=shape.epsg3857_points,
            )
    msg = f"No shape found for stop {stop_id}"
    raise HTTPException(status_code=404, detail=msg)

View File

@@ -15,8 +15,15 @@ services:
ports: ports:
- "127.0.0.1:5432:5432" - "127.0.0.1:5432:5432"
volumes: volumes:
- ./docker/database/docker-entrypoint-initdb.d:/docker-entrypoint-initdb.d - ./backend/docker/database/docker-entrypoint-initdb.d:/docker-entrypoint-initdb.d
- ./docker/database/data:/var/lib/postgresql/data - ./backend/docker/database/data:/var/lib/postgresql/data
redis:
image: redis:latest
restart: always
command: redis-server --loglevel warning
ports:
- "127.0.0.1:6379:6379"
jaeger-agent: jaeger-agent:
image: jaegertracing/jaeger-agent:latest image: jaegertracing/jaeger-agent:latest
@@ -45,10 +52,6 @@ services:
ports: ports:
- "127.0.0.1:4317:4317" - "127.0.0.1:4317:4317"
- "127.0.0.1:4318:4318" - "127.0.0.1:4318:4318"
# - "127.0.0.1:9411:9411"
# - "127.0.0.1:14250:14250"
# - "127.0.0.1:14268:14268"
# - "127.0.0.1:14269:14269"
restart: on-failure restart: on-failure
depends_on: depends_on:
- cassandra-schema - cassandra-schema
@@ -68,7 +71,29 @@ services:
- "--cassandra.servers=cassandra" - "--cassandra.servers=cassandra"
ports: ports:
- "127.0.0.1:16686:16686" - "127.0.0.1:16686:16686"
# - "127.0.0.1:16687:16687"
restart: on-failure restart: on-failure
depends_on: depends_on:
- cassandra-schema - cassandra-schema
carrramba-encore-rate-api:
build:
context: ./backend/
dockerfile: Dockerfile.backend
environment:
- CONFIG_PATH=./config.local.yaml
- IDFM_API_KEY=set_your_idfm_key_here
ports:
- "127.0.0.1:8080:8080"
carrramba-encore-rate-frontend:
build:
context: ./frontend/
ports:
- "127.0.0.1:80:8081"
carrramba-encore-rate-db-updater:
build:
context: ./backend/
dockerfile: Dockerfile.db_updater
environment:
- CONFIG_PATH=./config.local.yaml

4
frontend/Dockerfile Normal file
View File

@@ -0,0 +1,4 @@
# pull the latest official nginx image
FROM nginx:mainline-alpine-slim
COPY dist /usr/share/nginx/html

View File

@@ -1,4 +1,4 @@
import { Component } from 'solid-js'; import { Component, createSignal } from 'solid-js';
import { IVisibilityActionRequest, MatrixCapabilities, WidgetApi, WidgetApiToWidgetAction } from 'matrix-widget-api'; import { IVisibilityActionRequest, MatrixCapabilities, WidgetApi, WidgetApiToWidgetAction } from 'matrix-widget-api';
import { HopeProvider } from "@hope-ui/solid"; import { HopeProvider } from "@hope-ui/solid";
@@ -6,9 +6,10 @@ import { BusinessDataProvider } from './businessData';
import { AppContextProvider } from './appContext'; import { AppContextProvider } from './appContext';
import { PassagesDisplay } from './passagesDisplay'; import { PassagesDisplay } from './passagesDisplay';
import { StopsSearchMenu } from './stopsSearchMenu'; import { StopsSearchMenu } from './stopsSearchMenu/stopsSearchMenu';
import "./App.scss"; import "./App.scss";
import { onCleanup, onMount } from 'solid-js';
function parseFragment() { function parseFragment() {
@@ -43,6 +44,33 @@ const App: Component = () => {
api.transport.reply(ev.detail, {}); api.transport.reply(ev.detail, {});
}); });
createSignal({
height: window.innerHeight,
width: window.innerWidth
});
const onResize = () => {
const body = document.body;
if (window.innerWidth * 9 / 16 < window.innerHeight) {
body.style['height'] = 'auto';
body.style['width'] = '100vw';
}
else {
body.style['height'] = '100vh';
body.style['width'] = 'auto';
}
};
onMount(() => {
window.addEventListener('resize', onResize);
onResize();
});
onCleanup(() => {
window.removeEventListener('resize', onResize);
})
return ( return (
<BusinessDataProvider> <BusinessDataProvider>
<AppContextProvider> <AppContextProvider>

View File

@@ -21,7 +21,7 @@ export interface BusinessDataStore {
getStop: (stopId: number) => Stop | undefined; getStop: (stopId: number) => Stop | undefined;
searchStopByName: (name: string) => Promise<Stops>; searchStopByName: (name: string) => Promise<Stops>;
getStopDestinations: (stopId: number) => Promise<StopDestinations>; getStopDestinations: (stopId: number) => Promise<StopDestinations | undefined>;
getStopShape: (stopId: number) => Promise<StopShape | undefined>; getStopShape: (stopId: number) => Promise<StopShape | undefined>;
}; };
@@ -44,11 +44,20 @@ export function BusinessDataProvider(props: { children: JSX.Element }) {
let line = store.lines[lineId]; let line = store.lines[lineId];
if (line === undefined) { if (line === undefined) {
console.log(`${lineId} not found... fetch it from backend.`); console.log(`${lineId} not found... fetch it from backend.`);
const data = await fetch(`${serverUrl()}/line/${lineId}`, {
const response = await fetch(`${serverUrl()}/line/${lineId}`, {
headers: { 'Content-Type': 'application/json' } headers: { 'Content-Type': 'application/json' }
}); });
line = await data.json();
setStore('lines', lineId, line); const json = await response.json();
if (response.ok) {
setStore('lines', lineId, json);
line = json;
}
else {
console.warn(`No line found for ${lineId} line id:`, json);
}
} }
return line; return line;
} }
@@ -91,12 +100,19 @@ export function BusinessDataProvider(props: { children: JSX.Element }) {
} }
const refreshPassages = async (stopId: number): Promise<void> => { const refreshPassages = async (stopId: number): Promise<void> => {
const httpOptions = { headers: { "Content-Type": "application/json" } };
console.log(`Fetching data for ${stopId}`); console.log(`Fetching data for ${stopId}`);
const data = await fetch(`${serverUrl()}/stop/nextPassages/${stopId}`, httpOptions); const httpOptions = { headers: { "Content-Type": "application/json" } };
const response = await data.json(); const response = await fetch(`${serverUrl()}/stop/${stopId}/nextPassages`, httpOptions);
_cleanupPassages(response.passages);
addPassages(response.passages); const json = await response.json();
if (response.ok) {
_cleanupPassages(json.passages);
addPassages(json.passages);
}
else {
console.warn(`No passage found for ${stopId} stop:`, json);
}
} }
const addPassages = (passages: Passages): void => { const addPassages = (passages: Passages): void => {
@@ -155,39 +171,58 @@ ${linePassagesDestination.length} here... refresh all them.`);
} }
const searchStopByName = async (name: string): Promise<Stops> => { const searchStopByName = async (name: string): Promise<Stops> => {
const data = await fetch(`${serverUrl()}/stop/?name=${name}`, { const byIdStops: Stops = {};
const response = await fetch(`${serverUrl()}/stop/?name=${name}`, {
headers: { 'Content-Type': 'application/json' } headers: { 'Content-Type': 'application/json' }
}); });
const stops = await data.json();
const byIdStops: Stops = {}; const json = await response.json();
for (const stop of stops) {
if (response.ok) {
for (const stop of json) {
byIdStops[stop.id] = stop; byIdStops[stop.id] = stop;
setStore('stops', stop.id, stop); setStore('stops', stop.id, stop);
if (stop.stops !== undefined) {
for (const innerStop of stop.stops) { for (const innerStop of stop.stops) {
setStore('stops', innerStop.id, innerStop); setStore('stops', innerStop.id, innerStop);
} }
} }
}
}
else {
console.warn(`No stop found for '${name}' query:`, json);
}
return byIdStops; return byIdStops;
} }
const getStopDestinations = async (stopId: number): Promise<StopDestinations> => { const getStopDestinations = async (stopId: number): Promise<StopDestinations | undefined> => {
const data = await fetch(`${serverUrl()}/stop/${stopId}/destinations`, { const response = await fetch(`${serverUrl()}/stop/${stopId}/destinations`, {
headers: { 'Content-Type': 'application/json' } headers: { 'Content-Type': 'application/json' }
}); });
const response = await data.json(); const destinations = response.ok ? await response.json() : undefined;
return response; return destinations;
} }
const getStopShape = async (stopId: number): Promise<StopShape | undefined> => { const getStopShape = async (stopId: number): Promise<StopShape | undefined> => {
let shape = store.stopShapes[stopId]; let shape = store.stopShapes[stopId];
if (shape === undefined) { if (shape === undefined) {
console.log(`No shape found for ${stopId} stop... fetch it from backend.`); console.log(`No shape found for ${stopId} stop... fetch it from backend.`);
const data = await fetch(`${serverUrl()}/stop_shape/${stopId}`, { const response = await fetch(`${serverUrl()}/stop/${stopId}/shape`, {
headers: { 'Content-Type': 'application/json' } headers: { 'Content-Type': 'application/json' }
}); });
shape = await data.json();
setStore('stopShapes', stopId, shape); const json = await response.json();
if (response.ok) {
setStore('stopShapes', stopId, json);
shape = json;
}
else {
console.warn(`No shape found for ${stopId} stop:`, json);
}
} }
return shape; return shape;
} }

View File

@@ -13,10 +13,8 @@
src: url(/public/fonts/IDFVoyageur-Medium.otf); src: url(/public/fonts/IDFVoyageur-Medium.otf);
} }
body { html, body {
aspect-ratio: 16/9; aspect-ratio: 16/9;
width: 100vw;
height: none;
margin: 0; margin: 0;