16 Commits

Author SHA1 Message Date
5da918c04b 👽️ Take the last IDFM format into account 2023-06-11 22:41:44 +02:00
2eaf0f4ed5 Use of db merge when adds fails due to single key violations 2023-06-11 22:28:15 +02:00
c42b687870 🐛 Fix IdfmInterface circular import issue 2023-06-11 22:24:09 +02:00
d8adb4f52d ♻️ Remove code in charge or db filling from IdfmInterface 2023-06-11 22:22:05 +02:00
5e7f440b54 ♻️ Add the db_updater package 2023-06-11 22:18:47 +02:00
824536ddbe 💥 Rename API_KEY to IDFM_API_KEY 2023-05-28 12:45:03 +02:00
7fbdd0606c ️ Reduce the size of the backend docker image 2023-05-28 12:40:10 +02:00
581f6b7b8f 🐛 Add workaround for fastapi-cache issue #144 2023-05-28 10:45:14 +02:00
404b228cbf 🔥 Remove env variables from backend dockerfile 2023-05-26 23:55:58 +02:00
e2ff90cd5f ️ Use Redis to cache REST responses 2023-05-26 18:10:47 +02:00
cd700ebd42 🐛 The backend shall serve requests once the database reachable 2023-05-26 18:09:24 +02:00
c44a52b7ae ♻️ Add backend and frontend to docker-compose 2023-05-26 18:01:04 +02:00
b3b36bc3de Replace rich with icecream for temporary tracing 2023-05-11 21:44:58 +02:00
5e0d7b174c 🏷️ Fix some type issues (mypy) 2023-05-11 21:40:38 +02:00
b437bbbf70 🎨 Split main into several APIRouters 2023-05-11 21:17:02 +02:00
85fdb28cc6 🐛 Set default value to Settings.clear_static_data 2023-05-11 20:31:24 +02:00
22 changed files with 1144 additions and 785 deletions

12
backend/.dockerignore Normal file
View File

@@ -0,0 +1,12 @@
.dir-locals.el
.dockerignore
.gitignore
**/.mypy_cache
**/.ruff_cache
.venv
**/__pycache__
config
docker
poetry.lock
tests
Dockerfile

View File

@@ -0,0 +1,38 @@
# ---- Build stage: install the Python dependencies with Poetry ----------------
FROM python:3.11-slim AS builder

# --no-cache-dir keeps pip's download cache out of the image layer.
RUN pip install --no-cache-dir poetry

# Create the virtualenv inside the project directory (/app/.venv) so that the
# runtime stage can copy it as one self-contained directory.
ENV POETRY_NO_INTERACTION=1 \
    POETRY_VIRTUALENVS_IN_PROJECT=1 \
    POETRY_VIRTUALENVS_CREATE=1 \
    POETRY_CACHE_DIR=/tmp/poetry_cache

WORKDIR /app

# Only the dependency manifest is needed here: application sources are copied
# in the runtime stage, so source changes do not invalidate this layer.
COPY ./pyproject.toml /app/

# Install the "main" dependency group only (--no-root: the project itself is
# not installed) and drop Poetry's cache in the same layer.
RUN poetry install --only=main --no-root && \
    rm -rf "${POETRY_CACHE_DIR}"

# ---- Runtime stage: slim image with the virtualenv and the sources ----------
FROM python:3.11-slim AS runtime

WORKDIR /app

# libpq5 provides the PostgreSQL client shared library at runtime.
# apt-get (not apt) is the stable command-line interface for scripted use.
RUN apt-get update && \
    apt-get install -y --no-install-recommends libpq5 && \
    apt-get clean && \
    rm -rf /var/lib/apt/lists/*

# Put the virtualenv first on PATH so that "python" resolves to it.
ENV VIRTUAL_ENV=/app/.venv \
    PATH="/app/.venv/bin:$PATH"

COPY --from=builder ${VIRTUAL_ENV} ${VIRTUAL_ENV}

COPY backend /app/backend
COPY dependencies.py /app/
COPY config.sample.yaml /app/
COPY routers/ /app/routers
COPY main.py /app/

CMD ["python", "./main.py"]

View File

@@ -0,0 +1,43 @@
# ---- Build stage: install the db_updater dependencies with Poetry ------------
FROM python:3.11-slim AS builder

# proj-bin is installed in the build stage only.
# apt-get (not apt) is the stable command-line interface for scripted use.
RUN apt-get update && \
    apt-get install -y --no-install-recommends proj-bin && \
    apt-get clean && \
    rm -rf /var/lib/apt/lists/*

# --no-cache-dir keeps pip's download cache out of the image layer.
RUN pip install --no-cache-dir poetry

# Create the virtualenv inside the project directory (/app/.venv) so that the
# runtime stage can copy it as one self-contained directory.
ENV POETRY_NO_INTERACTION=1 \
    POETRY_VIRTUALENVS_IN_PROJECT=1 \
    POETRY_VIRTUALENVS_CREATE=1 \
    POETRY_CACHE_DIR=/tmp/poetry_cache

WORKDIR /app

# Only the dependency manifest is needed here: application sources are copied
# in the runtime stage, so source changes do not invalidate this layer.
COPY ./pyproject.toml /app/

# Install the "db_updater" dependency group only (--no-root: the project itself
# is not installed) and drop Poetry's cache in the same layer.
RUN poetry install --only=db_updater --no-root && \
    rm -rf "${POETRY_CACHE_DIR}"

# ---- Runtime stage: slim image with the virtualenv and the sources ----------
FROM python:3.11-slim AS runtime

WORKDIR /app

# libpq5 provides the PostgreSQL client shared library at runtime.
RUN apt-get update && \
    apt-get install -y --no-install-recommends libpq5 && \
    apt-get clean && \
    rm -rf /var/lib/apt/lists/*

# Put the virtualenv first on PATH so that "python" resolves to it.
ENV VIRTUAL_ENV=/app/.venv \
    PATH="/app/.venv/bin:$PATH"

COPY --from=builder ${VIRTUAL_ENV} ${VIRTUAL_ENV}

COPY backend /app/backend
COPY dependencies.py /app/
COPY config.sample.yaml /app/
# NOTE(review): config.local.yaml is baked into the image; if it holds real
# credentials, consider mounting it at run time instead.
COPY config.local.yaml /app/
COPY db_updater /app/db_updater

CMD ["python", "-m", "db_updater.fill_db"]

View File

@@ -1,7 +1,7 @@
from __future__ import annotations
from logging import getLogger
from typing import Iterable, Self, TYPE_CHECKING
from typing import Self, Sequence, TYPE_CHECKING
from sqlalchemy import select
from sqlalchemy.exc import IntegrityError
@@ -17,21 +17,34 @@ class Base(DeclarativeBase):
db: Database | None = None
@classmethod
async def add(cls, objs: Self | Iterable[Self]) -> bool:
async def add(cls, objs: Sequence[Self]) -> bool:
if cls.db is not None and (session := await cls.db.get_session()) is not None:
try:
async with session.begin():
session.add_all(objs)
except IntegrityError as err:
logger.warning(err)
return await cls.merge(objs)
except AttributeError as err:
logger.error(err)
return False
return True
@classmethod
async def merge(cls, objs: Sequence[Self]) -> bool:
if cls.db is not None and (session := await cls.db.get_session()) is not None:
async with session.begin():
try:
if isinstance(objs, Iterable):
session.add_all(objs)
else:
session.add(objs)
for obj in objs:
await session.merge(obj)
except (AttributeError, IntegrityError) as err:
logger.error(err)
return False
return True
return True
return False
@classmethod
async def get_by_id(cls, id_: int | str) -> Self | None:

View File

@@ -1,10 +1,11 @@
from asyncio import sleep
from logging import getLogger
from typing import Annotated, AsyncIterator
from fastapi import Depends
from opentelemetry.instrumentation.sqlalchemy import SQLAlchemyInstrumentor
from sqlalchemy import text
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.exc import OperationalError, SQLAlchemyError
from sqlalchemy.ext.asyncio import (
async_sessionmaker,
AsyncEngine,
@@ -56,11 +57,20 @@ class Database:
class_=AsyncSession,
)
async with self._async_engine.begin() as session:
await session.execute(text("CREATE EXTENSION IF NOT EXISTS pg_trgm;"))
if clear_static_data:
await session.run_sync(Base.metadata.drop_all)
await session.run_sync(Base.metadata.create_all)
ret = False
while not ret:
try:
async with self._async_engine.begin() as session:
await session.execute(
text("CREATE EXTENSION IF NOT EXISTS pg_trgm;")
)
if clear_static_data:
await session.run_sync(Base.metadata.drop_all)
await session.run_sync(Base.metadata.create_all)
ret = True
except OperationalError as err:
logger.error(err)
await sleep(1)
return True

View File

@@ -1,5 +1,3 @@
from .idfm_interface import IdfmInterface
from .idfm_types import (
Coordinate,
Destinations,
@@ -38,7 +36,6 @@ __all__ = [
"Coordinate",
"Destinations",
"FramedVehicleJourney",
"IdfmInterface",
"IdfmLineState",
"IdfmOperator",
"IdfmResponse",

View File

@@ -1,44 +1,15 @@
from collections import defaultdict
from logging import getLogger
from re import compile as re_compile
from time import time
from typing import (
AsyncIterator,
ByteString,
Callable,
Iterable,
List,
Type,
)
from typing import ByteString
from aiofiles import open as async_open
from aiohttp import ClientSession
from msgspec import ValidationError
from msgspec.json import Decoder
from pyproj import Transformer
from shapefile import Reader as ShapeFileReader, ShapeRecord # type: ignore
from ..db import Database
from ..models import ConnectionArea, Line, LinePicto, Stop, StopArea, StopShape
from .idfm_types import (
ConnectionArea as IdfmConnectionArea,
Destinations as IdfmDestinations,
IdfmLineState,
IdfmResponse,
Line as IdfmLine,
LinePicto as IdfmPicto,
IdfmState,
Stop as IdfmStop,
StopArea as IdfmStopArea,
StopAreaStopAssociation,
StopAreaType,
StopLineAsso as IdfmStopLineAsso,
TransportMode,
)
from .ratp_types import Picto as RatpPicto
logger = getLogger(__name__)
from ..models import Line, Stop, StopArea
from .idfm_types import Destinations as IdfmDestinations, IdfmResponse, IdfmState
class IdfmInterface:
@@ -46,18 +17,6 @@ class IdfmInterface:
IDFM_ROOT_URL = "https://prim.iledefrance-mobilites.fr/marketplace"
IDFM_STOP_MON_URL = f"{IDFM_ROOT_URL}/stop-monitoring"
IDFM_ROOT_URL = "https://data.iledefrance-mobilites.fr/explore/dataset"
IDFM_STOPS_URL = (
f"{IDFM_ROOT_URL}/arrets/download/?format=json&timezone=Europe/Berlin"
)
IDFM_PICTO_URL = f"{IDFM_ROOT_URL}/referentiel-des-lignes/files"
RATP_ROOT_URL = "https://data.ratp.fr/explore/dataset"
RATP_PICTO_URL = (
f"{RATP_ROOT_URL}"
"/pictogrammes-des-lignes-de-metro-rer-tramway-bus-et-noctilien/files"
)
OPERATOR_RE = re_compile(r"[^:]+:Operator::([^:]+):")
LINE_RE = re_compile(r"[^:]+:Line::C([^:]+):")
@@ -67,459 +26,12 @@ class IdfmInterface:
self._http_headers = {"Accept": "application/json", "apikey": self._api_key}
self._epsg2154_epsg3857_transformer = Transformer.from_crs(2154, 3857)
self._json_stops_decoder = Decoder(type=List[IdfmStop])
self._json_stop_areas_decoder = Decoder(type=List[IdfmStopArea])
self._json_connection_areas_decoder = Decoder(type=List[IdfmConnectionArea])
self._json_lines_decoder = Decoder(type=List[IdfmLine])
self._json_stops_lines_assos_decoder = Decoder(type=List[IdfmStopLineAsso])
self._json_ratp_pictos_decoder = Decoder(type=List[RatpPicto])
self._json_stop_area_stop_asso_decoder = Decoder(
type=List[StopAreaStopAssociation]
)
self._response_json_decoder = Decoder(type=IdfmResponse)
async def startup(self) -> None:
BATCH_SIZE = 10000
STEPS: tuple[
tuple[
Type[ConnectionArea] | Type[Stop] | Type[StopArea] | Type[StopShape],
Callable,
Callable,
],
...,
] = (
(
StopShape,
self._request_stop_shapes,
self._format_idfm_stop_shapes,
),
(
ConnectionArea,
self._request_idfm_connection_areas,
self._format_idfm_connection_areas,
),
(
StopArea,
self._request_idfm_stop_areas,
self._format_idfm_stop_areas,
),
(Stop, self._request_idfm_stops, self._format_idfm_stops),
)
for model, get_method, format_method in STEPS:
step_begin_ts = time()
elements = []
async for element in get_method():
elements.append(element)
if len(elements) == BATCH_SIZE:
await model.add(format_method(*elements))
elements.clear()
if elements:
await model.add(format_method(*elements))
print(f"Add {model.__name__}s: {time() - step_begin_ts}s")
begin_ts = time()
await self._load_lines()
print(f"Add Lines and IDFM LinePictos: {time() - begin_ts}s")
begin_ts = time()
await self._load_ratp_pictos(30)
print(f"Add RATP LinePictos: {time() - begin_ts}s")
begin_ts = time()
await self._load_lines_stops_assos()
print(f"Link Stops to Lines: {time() - begin_ts}s")
begin_ts = time()
await self._load_stop_assos()
print(f"Link Stops to StopAreas: {time() - begin_ts}s")
async def _load_lines(self, batch_size: int = 5000) -> None:
lines, pictos = [], []
picto_ids = set()
async for line in self._request_idfm_lines():
if (picto := line.fields.picto) is not None and picto.id_ not in picto_ids:
picto_ids.add(picto.id_)
pictos.append(picto)
lines.append(line)
if len(lines) == batch_size:
await LinePicto.add(IdfmInterface._format_idfm_pictos(*pictos))
await Line.add(await self._format_idfm_lines(*lines))
lines.clear()
pictos.clear()
if pictos:
await LinePicto.add(IdfmInterface._format_idfm_pictos(*pictos))
if lines:
await Line.add(await self._format_idfm_lines(*lines))
async def _load_ratp_pictos(self, batch_size: int = 5) -> None:
pictos = []
async for picto in self._request_ratp_pictos():
pictos.append(picto)
if len(pictos) == batch_size:
formatted_pictos = IdfmInterface._format_ratp_pictos(*pictos)
await LinePicto.add(map(lambda picto: picto[1], formatted_pictos))
await Line.add_pictos(formatted_pictos)
pictos.clear()
if pictos:
formatted_pictos = IdfmInterface._format_ratp_pictos(*pictos)
await LinePicto.add(map(lambda picto: picto[1], formatted_pictos))
await Line.add_pictos(formatted_pictos)
async def _load_lines_stops_assos(self, batch_size: int = 5000) -> None:
total_assos_nb = total_found_nb = 0
assos = []
async for asso in self._request_idfm_stops_lines_associations():
fields = asso.fields
try:
stop_id = int(fields.stop_id.rsplit(":", 1)[-1])
except ValueError as err:
print(err)
print(f"{fields.stop_id = }")
continue
assos.append((fields.route_long_name, fields.operatorname, stop_id))
if len(assos) == batch_size:
total_assos_nb += batch_size
total_found_nb += await Line.add_stops(assos)
assos.clear()
if assos:
total_assos_nb += len(assos)
total_found_nb += await Line.add_stops(assos)
print(f"{total_found_nb} line <-> stop ({total_assos_nb = } found)")
async def _load_stop_assos(self, batch_size: int = 5000) -> None:
total_assos_nb = area_stop_assos_nb = conn_stop_assos_nb = 0
area_stop_assos = []
connection_stop_assos = []
async for asso in self._request_idfm_stop_area_stop_associations():
fields = asso.fields
stop_id = int(fields.arrid)
area_stop_assos.append((int(fields.zdaid), stop_id))
connection_stop_assos.append((int(fields.zdcid), stop_id))
if len(area_stop_assos) == batch_size:
total_assos_nb += batch_size
if (found_nb := await StopArea.add_stops(area_stop_assos)) is not None:
area_stop_assos_nb += found_nb
area_stop_assos.clear()
if (
found_nb := await ConnectionArea.add_stops(connection_stop_assos)
) is not None:
conn_stop_assos_nb += found_nb
connection_stop_assos.clear()
if area_stop_assos:
total_assos_nb += len(area_stop_assos)
if (found_nb := await StopArea.add_stops(area_stop_assos)) is not None:
area_stop_assos_nb += found_nb
if (
found_nb := await ConnectionArea.add_stops(connection_stop_assos)
) is not None:
conn_stop_assos_nb += found_nb
print(f"{area_stop_assos_nb} stop area <-> stop ({total_assos_nb = } found)")
print(f"{conn_stop_assos_nb} stop area <-> stop ({total_assos_nb = } found)")
# TODO: This method is synchronous due to the shapefile library.
# It's not a blocking issue but it could be nice to find an alternative.
async def _request_stop_shapes(self) -> AsyncIterator[ShapeRecord]:
# TODO: Use HTTP
with ShapeFileReader("./tests/datasets/REF_LDA.zip") as reader:
for record in reader.shapeRecords():
yield record
async def _request_idfm_stops(self) -> AsyncIterator[IdfmStop]:
# headers = {"Accept": "application/json", "apikey": self._api_key}
# async with ClientSession(headers=headers) as session:
# async with session.get(self.STOPS_URL) as response:
# # print("Status:", response.status)
# if response.status == 200:
# for point in self._json_stops_decoder.decode(await response.read()):
# yield point
# TODO: Use HTTP
async with async_open("./tests/datasets/stops_dataset.json", "rb") as raw:
for element in self._json_stops_decoder.decode(await raw.read()):
yield element
async def _request_idfm_stop_areas(self) -> AsyncIterator[IdfmStopArea]:
# TODO: Use HTTP
async with async_open("./tests/datasets/zones-d-arrets.json", "rb") as raw:
for element in self._json_stop_areas_decoder.decode(await raw.read()):
yield element
async def _request_idfm_connection_areas(self) -> AsyncIterator[IdfmConnectionArea]:
async with async_open(
"./tests/datasets/zones-de-correspondance.json", "rb"
) as raw:
for element in self._json_connection_areas_decoder.decode(await raw.read()):
yield element
async def _request_idfm_lines(self) -> AsyncIterator[IdfmLine]:
# TODO: Use HTTP
async with async_open("./tests/datasets/lines_dataset.json", "rb") as raw:
for element in self._json_lines_decoder.decode(await raw.read()):
yield element
async def _request_idfm_stops_lines_associations(
self,
) -> AsyncIterator[IdfmStopLineAsso]:
# TODO: Use HTTP
async with async_open("./tests/datasets/arrets-lignes.json", "rb") as raw:
for element in self._json_stops_lines_assos_decoder.decode(
await raw.read()
):
yield element
async def _request_idfm_stop_area_stop_associations(
self,
) -> AsyncIterator[StopAreaStopAssociation]:
# TODO: Use HTTP
async with async_open("./tests/datasets/relations.json", "rb") as raw:
for element in self._json_stop_area_stop_asso_decoder.decode(
await raw.read()
):
yield element
async def _request_ratp_pictos(self) -> AsyncIterator[RatpPicto]:
# TODO: Use HTTP
async with async_open(
"./tests/datasets/pictogrammes-des-lignes-de-metro-rer-tramway-bus-et-noctilien.json",
"rb",
) as fd:
for element in self._json_ratp_pictos_decoder.decode(await fd.read()):
yield element
@classmethod
def _format_idfm_pictos(cls, *pictos: IdfmPicto) -> Iterable[LinePicto]:
ret = []
for picto in pictos:
ret.append(
LinePicto(
id=picto.id_,
mime_type=picto.mimetype,
height_px=picto.height,
width_px=picto.width,
filename=picto.filename,
url=f"{cls.IDFM_PICTO_URL}/{picto.id_}/download",
thumbnail=picto.thumbnail,
format=picto.format,
)
)
return ret
@classmethod
def _format_ratp_pictos(cls, *pictos: RatpPicto) -> Iterable[tuple[str, LinePicto]]:
ret = []
for picto in pictos:
if (fields := picto.fields.noms_des_fichiers) is not None:
ret.append(
(
picto.fields.indices_commerciaux,
LinePicto(
id=fields.id_,
mime_type=f"image/{fields.format.lower()}",
height_px=fields.height,
width_px=fields.width,
filename=fields.filename,
url=f"{cls.RATP_PICTO_URL}/{fields.id_}/download",
thumbnail=fields.thumbnail,
format=fields.format,
),
)
)
return ret
async def _format_idfm_lines(self, *lines: IdfmLine) -> Iterable[Line]:
ret = []
optional_value = IdfmLine.optional_value
for line in lines:
fields = line.fields
picto_id = fields.picto.id_ if fields.picto is not None else None
line_id = fields.id_line
try:
formatted_line_id = int(line_id[1:] if line_id[0] == "C" else line_id)
except ValueError:
logger.warning("Unable to format %s line id.", line_id)
continue
try:
operator_id = int(fields.operatorref) # type: ignore
except (ValueError, TypeError):
logger.warning("Unable to format %s operator id.", fields.operatorref)
operator_id = 0
ret.append(
Line(
id=formatted_line_id,
short_name=fields.shortname_line,
name=fields.name_line,
status=IdfmLineState(fields.status.value),
transport_mode=TransportMode(fields.transportmode.value),
transport_submode=optional_value(fields.transportsubmode),
network_name=optional_value(fields.networkname),
group_of_lines_id=optional_value(fields.id_groupoflines),
group_of_lines_shortname=optional_value(
fields.shortname_groupoflines
),
colour_web_hexa=fields.colourweb_hexa,
text_colour_hexa=fields.textcolourprint_hexa,
operator_id=operator_id,
operator_name=optional_value(fields.operatorname),
accessibility=IdfmState(fields.accessibility.value),
visual_signs_available=IdfmState(
fields.visualsigns_available.value
),
audible_signs_available=IdfmState(
fields.audiblesigns_available.value
),
picto_id=fields.picto.id_ if fields.picto is not None else None,
record_id=line.recordid,
record_ts=int(line.record_timestamp.timestamp()),
)
)
return ret
def _format_idfm_stops(self, *stops: IdfmStop) -> Iterable[Stop]:
for stop in stops:
fields = stop.fields
try:
created_ts = int(fields.arrcreated.timestamp()) # type: ignore
except AttributeError:
created_ts = None
epsg3857_point = self._epsg2154_epsg3857_transformer.transform(
fields.arrxepsg2154, fields.arryepsg2154
)
try:
postal_region = int(fields.arrpostalregion)
except ValueError:
logger.warning(
"Unable to format %s postal region.", fields.arrpostalregion
)
continue
yield Stop(
id=int(fields.arrid),
name=fields.arrname,
epsg3857_x=epsg3857_point[0],
epsg3857_y=epsg3857_point[1],
town_name=fields.arrtown,
postal_region=postal_region,
transport_mode=TransportMode(fields.arrtype.value),
version=fields.arrversion,
created_ts=created_ts,
changed_ts=int(fields.arrchanged.timestamp()),
accessibility=IdfmState(fields.arraccessibility.value),
visual_signs_available=IdfmState(fields.arrvisualsigns.value),
audible_signs_available=IdfmState(fields.arraudiblesignals.value),
record_id=stop.recordid,
record_ts=int(stop.record_timestamp.timestamp()),
)
def _format_idfm_stop_areas(self, *stop_areas: IdfmStopArea) -> Iterable[StopArea]:
for stop_area in stop_areas:
fields = stop_area.fields
try:
created_ts = int(fields.zdacreated.timestamp()) # type: ignore
except AttributeError:
created_ts = None
epsg3857_point = self._epsg2154_epsg3857_transformer.transform(
fields.zdaxepsg2154, fields.zdayepsg2154
)
yield StopArea(
id=int(fields.zdaid),
name=fields.zdaname,
town_name=fields.zdatown,
postal_region=fields.zdapostalregion,
epsg3857_x=epsg3857_point[0],
epsg3857_y=epsg3857_point[1],
type=StopAreaType(fields.zdatype.value),
version=fields.zdaversion,
created_ts=created_ts,
changed_ts=int(fields.zdachanged.timestamp()),
)
def _format_idfm_connection_areas(
self,
*connection_areas: IdfmConnectionArea,
) -> Iterable[ConnectionArea]:
for connection_area in connection_areas:
epsg3857_point = self._epsg2154_epsg3857_transformer.transform(
connection_area.zdcxepsg2154, connection_area.zdcyepsg2154
)
yield ConnectionArea(
id=int(connection_area.zdcid),
name=connection_area.zdcname,
town_name=connection_area.zdctown,
postal_region=connection_area.zdcpostalregion,
epsg3857_x=epsg3857_point[0],
epsg3857_y=epsg3857_point[1],
transport_mode=StopAreaType(connection_area.zdctype.value),
version=connection_area.zdcversion,
created_ts=int(connection_area.zdccreated.timestamp()),
changed_ts=int(connection_area.zdcchanged.timestamp()),
)
def _format_idfm_stop_shapes(
self, *shape_records: ShapeRecord
) -> Iterable[StopShape]:
for shape_record in shape_records:
epsg3857_points = [
self._epsg2154_epsg3857_transformer.transform(*point)
for point in shape_record.shape.points
]
bbox_it = iter(shape_record.shape.bbox)
epsg3857_bbox = [
self._epsg2154_epsg3857_transformer.transform(*point)
for point in zip(bbox_it, bbox_it)
]
yield StopShape(
id=shape_record.record[1],
type=shape_record.shape.shapeType,
epsg3857_bbox=epsg3857_bbox,
epsg3857_points=epsg3857_points,
)
...
async def render_line_picto(self, line: Line) -> tuple[None | str, None | str]:
begin_ts = time()
line_picto_path = line_picto_format = None
target = f"/tmp/{line.id}_repr"
@@ -531,12 +43,9 @@ class IdfmInterface:
line_picto_path = target
line_picto_format = picto.mime_type
print(f"render_line_picto: {time() - begin_ts}")
return (line_picto_path, line_picto_format)
async def _get_line_picto(self, line: Line) -> ByteString | None:
print("---------------------------------------------------------------------")
begin_ts = time()
data = None
picto = line.picto
@@ -544,25 +53,20 @@ class IdfmInterface:
headers = (
self._http_headers if picto.url.startswith(self.IDFM_ROOT_URL) else None
)
session_begin_ts = time()
async with ClientSession(headers=headers) as session:
session_creation_ts = time()
print(f"Session creation {session_creation_ts - session_begin_ts}")
async with session.get(picto.url) as response:
get_end_ts = time()
print(f"GET {get_end_ts - session_creation_ts}")
data = await response.read()
print(f"read {time() - get_end_ts}")
print(f"render_line_picto: {time() - begin_ts}")
print("---------------------------------------------------------------------")
async with ClientSession(headers=headers) as session:
async with session.get(picto.url) as response:
data = await response.read()
return data
async def get_next_passages(self, stop_point_id: int) -> IdfmResponse | None:
ret = None
params = {"MonitoringRef": f"STIF:StopPoint:Q:{stop_point_id}:"}
async with ClientSession(headers=self._http_headers) as session:
async with session.get(self.IDFM_STOP_MON_URL, params=params) as response:
if response.status == 200:
data = await response.read()
try:
@@ -573,8 +77,6 @@ class IdfmInterface:
return ret
async def get_destinations(self, stop_id: int) -> IdfmDestinations | None:
begin_ts = time()
destinations: IdfmDestinations = defaultdict(set)
if (stop := await Stop.get_by_id(stop_id)) is not None:
@@ -582,7 +84,6 @@ class IdfmInterface:
elif (stop_area := await StopArea.get_by_id(stop_id)) is not None:
expected_stop_ids = {stop.id for stop in stop_area.stops}
else:
return None
@@ -593,6 +94,7 @@ class IdfmInterface:
for stop_visit in delivery.MonitoredStopVisit:
monitoring_ref = stop_visit.MonitoringRef.value
try:
monitored_stop_id = int(monitoring_ref.split(":")[-2])
except (IndexError, ValueError):
@@ -603,9 +105,7 @@ class IdfmInterface:
if (
dst_names := journey.DestinationName
) and monitored_stop_id in expected_stop_ids:
line_id = journey.LineRef.value.split(":")[-2]
destinations[line_id].add(dst_names[0].value)
print(f"get_next_passages: {time() - begin_ts}")
return destinations

View File

@@ -116,19 +116,26 @@ class StopArea(Struct):
record_timestamp: datetime
class ConnectionArea(Struct):
class ConnectionAreaFields(Struct, kw_only=True):
zdcid: str
zdcversion: str
zdccreated: datetime
zdcchanged: datetime
zdcname: str
zdcxepsg2154: int
zdcyepsg2154: int
zdcxepsg2154: int | None = None
zdcyepsg2154: int | None = None
zdctown: str
zdcpostalregion: str
zdctype: StopAreaType
class ConnectionArea(Struct):
datasetid: str
recordid: str
fields: ConnectionAreaFields
record_timestamp: datetime
class StopAreaStopAssociationFields(Struct, kw_only=True):
arrid: str # TODO: use int ?
artid: str | None = None
@@ -149,6 +156,7 @@ class StopAreaStopAssociation(Struct):
class IdfmLineState(Enum):
active = "active"
available_soon = "prochainement active"
class LinePicto(Struct, rename={"id_": "id"}):

View File

@@ -1,6 +1,3 @@
from datetime import datetime
from typing import Optional
from msgspec import Struct
@@ -13,13 +10,6 @@ class PictoFieldsFile(Struct, rename={"id_": "id"}):
format: str
class PictoFields(Struct):
indices_commerciaux: str
noms_des_fichiers: Optional[PictoFieldsFile] = None
class Picto(Struct):
datasetid: str
recordid: str
fields: PictoFields
record_timestamp: datetime
indices_commerciaux: str
noms_des_fichiers: PictoFieldsFile | None = None

View File

@@ -4,7 +4,7 @@ from ..idfm_interface.idfm_types import TrainStatus
class NextPassage(BaseModel):
line: str
line: int
operator: str
destinations: list[str]
atStop: bool
@@ -19,4 +19,4 @@ class NextPassage(BaseModel):
class NextPassages(BaseModel):
ts: int
passages: dict[str, dict[str, list[NextPassage]]]
passages: dict[int, dict[str, list[NextPassage]]]

View File

@@ -1,4 +1,6 @@
from pydantic import BaseModel, BaseSettings, Field, SecretStr
from typing import Any
from pydantic import BaseModel, BaseSettings, Field, root_validator, SecretStr
class HttpSettings(BaseModel):
@@ -7,20 +9,51 @@ class HttpSettings(BaseModel):
cert: str | None = None
def check_user_password(cls, values: dict[str, Any]) -> dict[str, Any]:
user = values.get("user")
password = values.get("password")
if user is not None and password is None:
raise ValueError("user is set, password shall be set too.")
if password is not None and user is None:
raise ValueError("password is set, user shall be set too.")
return values
class DatabaseSettings(BaseModel):
name: str = "carrramba-encore-rate"
host: str = "127.0.0.1"
port: int = 5432
driver: str = "postgresql+psycopg"
user: str = "cer"
user: str | None = None
password: SecretStr | None = None
_user_password_validation = root_validator(allow_reuse=True)(check_user_password)
class CacheSettings(BaseModel):
enable: bool = False
host: str = "127.0.0.1"
port: int = 6379
user: str | None = None
password: SecretStr | None = None
_user_password_validation = root_validator(allow_reuse=True)(check_user_password)
class TracingSettings(BaseModel):
enable: bool = False
class Settings(BaseSettings):
app_name: str
idfm_api_key: SecretStr = Field(..., env="API_KEY")
clear_static_data: bool = Field(..., env="CLEAR_STATIC_DATA")
idfm_api_key: SecretStr = Field(..., env="IDFM_API_KEY")
clear_static_data: bool = Field(False, env="CLEAR_STATIC_DATA")
http: HttpSettings = HttpSettings()
db: DatabaseSettings = DatabaseSettings()
cache: CacheSettings = CacheSettings()
tracing: TracingSettings = TracingSettings()

View File

@@ -1,8 +1,9 @@
app_name: carrramba-encore-rate
clear_static_data: false
http:
host: 0.0.0.0
port: 4443
port: 8080
cert: ./config/cert.pem
db:
@@ -12,3 +13,9 @@ db:
driver: postgresql+psycopg
user: cer
password: cer_password
cache:
enable: true
tracing:
enable: false

View File

@@ -0,0 +1,23 @@
app_name: carrramba-encore-rate
clear_static_data: false
http:
host: 0.0.0.0
port: 8080
# cert: ./config/cert.pem
db:
name: carrramba-encore-rate
host: postgres
port: 5432
driver: postgresql+psycopg
user: cer
password: cer_password
cache:
enable: true
host: redis
# TODO: Add user credentials
tracing:
enable: false

575
backend/db_updater/fill_db.py Executable file
View File

@@ -0,0 +1,575 @@
#!/usr/bin/env python3
from asyncio import run, gather
from logging import getLogger, INFO, Handler as LoggingHandler, NOTSET
from itertools import islice
from time import time
from os import environ
from typing import Callable, Iterable, List, Type
from aiofiles.tempfile import NamedTemporaryFile
from aiohttp import ClientSession
from msgspec import ValidationError
from msgspec.json import Decoder
from pyproj import Transformer
from shapefile import Reader as ShapeFileReader, ShapeRecord # type: ignore
from tqdm import tqdm
from yaml import safe_load
from backend.db import Base, db, Database
from backend.models import ConnectionArea, Line, LinePicto, Stop, StopArea, StopShape
from backend.idfm_interface.idfm_types import (
ConnectionArea as IdfmConnectionArea,
IdfmLineState,
Line as IdfmLine,
LinePicto as IdfmPicto,
IdfmState,
Stop as IdfmStop,
StopArea as IdfmStopArea,
StopAreaStopAssociation,
StopAreaType,
StopLineAsso as IdfmStopLineAsso,
TransportMode,
)
from backend.idfm_interface.ratp_types import Picto as RatpPicto
from backend.settings import Settings
# Path of the YAML configuration file; overridable through the CONFIG_PATH
# environment variable.
CONFIG_PATH = environ.get("CONFIG_PATH", "./config.sample.yaml")
# Presumably the number of records handled per batch -- its use is outside
# this view; TODO confirm.
BATCH_SIZE = 1000
# Base URL of the IDFM open-data portal; the dataset URLs below derive from it.
IDFM_ROOT_URL = "https://data.iledefrance-mobilites.fr/explore/dataset"
# JSON download of the "zones-de-correspondance" dataset.
IDFM_CONNECTION_AREAS_URL = (
f"{IDFM_ROOT_URL}/zones-de-correspondance/download/?format=json"
)
# JSON download of the "referentiel-des-lignes" dataset.
IDFM_LINES_URL = f"{IDFM_ROOT_URL}/referentiel-des-lignes/download/?format=json"
# Base URL of the files attached to the "referentiel-des-lignes" dataset.
IDFM_PICTO_URL = f"{IDFM_ROOT_URL}/referentiel-des-lignes/files"
# JSON download of the "zones-d-arrets" dataset.
IDFM_STOP_AREAS_URL = f"{IDFM_ROOT_URL}/zones-d-arrets/download/?format=json"
# Zip archives hosted on the opendatasoft FTP mirror (REF_ArR / REF_ZdA).
IDFM_STOP_SHAPES_URL = "https://eu.ftp.opendatasoft.com/stif/Reflex/REF_ArR.zip"
IDFM_STOP_AREA_SHAPES_URL = "https://eu.ftp.opendatasoft.com/stif/Reflex/REF_ZdA.zip"
# JSON download of the "relations" dataset.
IDFM_STOP_STOP_AREAS_ASSOS_URL = f"{IDFM_ROOT_URL}/relations/download/?format=json"
# JSON download of the "arrets-lignes" dataset.
IDFM_STOPS_LINES_ASSOS_URL = f"{IDFM_ROOT_URL}/arrets-lignes/download/?format=json"
# JSON download of the "arrets" dataset.
IDFM_STOPS_URL = f"{IDFM_ROOT_URL}/arrets/download/?format=json"
# Base URL of the RATP open-data "explore" API (v2.1) dataset catalog.
RATP_ROOT_URL = "https://data.ratp.fr/api/explore/v2.1/catalog/datasets"
# JSON export of the RATP line pictograms dataset.
RATP_PICTOS_URL = (
f"{RATP_ROOT_URL}"
"/pictogrammes-des-lignes-de-metro-rer-tramway-bus-et-noctilien/exports/json?lang=fr"
)
# From https://stackoverflow.com/a/38739634
class TqdmLoggingHandler(LoggingHandler):
    """Logging handler that emits every record through ``tqdm.write``.

    The constructor is inherited unchanged from ``logging.Handler``; it accepts
    the same optional ``level`` argument (``NOTSET`` by default).
    """

    def emit(self, record):
        """Format ``record`` and print it with ``tqdm.write``.

        Any exception raised while formatting or writing is delegated to
        ``handleError``, as the standard logging handlers do.
        """
        try:
            tqdm.write(self.format(record))
            self.flush()
        except Exception:
            self.handleError(record)
# Module logger: INFO level, with records printed through TqdmLoggingHandler.
logger = getLogger(__name__)
logger.setLevel(INFO)
logger.addHandler(TqdmLoggingHandler())
# Coordinate transformer from CRS EPSG:2154 to CRS EPSG:3857.
epsg2154_epsg3857_transformer = Transformer.from_crs(2154, 3857)
# msgspec JSON decoders, one per payload type; each decodes a JSON array into
# a list of the given struct type.
json_stops_decoder = Decoder(type=List[IdfmStop])
json_stop_areas_decoder = Decoder(type=List[IdfmStopArea])
json_connection_areas_decoder = Decoder(type=List[IdfmConnectionArea])
json_lines_decoder = Decoder(type=List[IdfmLine])
json_stops_lines_assos_decoder = Decoder(type=List[IdfmStopLineAsso])
json_ratp_pictos_decoder = Decoder(type=List[RatpPicto])
json_stop_area_stop_asso_decoder = Decoder(type=List[StopAreaStopAssociation])
def format_idfm_pictos(*pictos: IdfmPicto) -> Iterable[LinePicto]:
    """Build one ``LinePicto`` model per IDFM picto record.

    Args:
        *pictos: IDFM picto records to convert.

    Returns:
        A list of ``LinePicto`` objects, in input order. Each download URL is
        derived from ``IDFM_PICTO_URL`` and the picto id.
    """
    return [
        LinePicto(
            id=idfm_picto.id_,
            mime_type=idfm_picto.mimetype,
            height_px=idfm_picto.height,
            width_px=idfm_picto.width,
            filename=idfm_picto.filename,
            url=f"{IDFM_PICTO_URL}/{idfm_picto.id_}/download",
            thumbnail=idfm_picto.thumbnail,
            format=idfm_picto.format,
        )
        for idfm_picto in pictos
    ]
# Base URL of the files attached to the RATP picto dataset. The download URL
# of a picto cannot be derived from RATP_PICTOS_URL: that constant is the JSON
# export endpoint and ends with a query string ("?lang=fr").
RATP_PICTO_FILES_URL = (
    "https://data.ratp.fr/explore/dataset"
    "/pictogrammes-des-lignes-de-metro-rer-tramway-bus-et-noctilien/files"
)


def format_ratp_pictos(*pictos: RatpPicto) -> Iterable[tuple[str, LinePicto]]:
    """Build ``(indices_commerciaux, LinePicto)`` pairs from RATP picto records.

    Records without an attached file (``noms_des_fichiers`` is ``None``) are
    skipped.

    Args:
        *pictos: RATP picto records to convert.

    Returns:
        A list of ``(indices_commerciaux, LinePicto)`` tuples, in input order.
    """
    ret = []

    for picto in pictos:
        if (fields := picto.noms_des_fichiers) is not None:
            ret.append(
                (
                    picto.indices_commerciaux,
                    LinePicto(
                        id=fields.id_,
                        # The dataset gives a bare format name; turn it into
                        # an image MIME type.
                        mime_type=f"image/{fields.format.lower()}",
                        height_px=fields.height,
                        width_px=fields.width,
                        filename=fields.filename,
                        url=f"{RATP_PICTO_FILES_URL}/{fields.id_}/download",
                        thumbnail=fields.thumbnail,
                        format=fields.format,
                    ),
                )
            )

    return ret
def format_idfm_lines(*lines: IdfmLine) -> Iterable[Line]:
    """Convert IDFM line records into ``Line`` model instances.

    A line whose id cannot be parsed as an integer (after removing an optional
    leading ``"C"``) is logged and skipped. An operator reference that cannot
    be parsed is logged and replaced by operator id ``0``.

    Args:
        *lines: IDFM line records to convert.

    Returns:
        A list of ``Line`` objects for the records that could be converted.
    """
    ret = []

    optional_value = IdfmLine.optional_value

    for line in lines:
        fields = line.fields

        line_id = fields.id_line
        try:
            # removeprefix() leaves ids without the "C" prefix untouched and,
            # unlike indexing line_id[0], does not raise IndexError on an empty
            # id: the empty string reaches int(), which raises ValueError.
            formatted_line_id = int(line_id.removeprefix("C"))
        except ValueError:
            logger.warning("Unable to format %s line id.", line_id)
            continue

        try:
            operator_id = int(fields.operatorref)  # type: ignore
        except (ValueError, TypeError):
            logger.warning("Unable to format %s operator id.", fields.operatorref)
            operator_id = 0

        ret.append(
            Line(
                id=formatted_line_id,
                short_name=fields.shortname_line,
                name=fields.name_line,
                status=IdfmLineState(fields.status.value),
                transport_mode=TransportMode(fields.transportmode.value),
                transport_submode=optional_value(fields.transportsubmode),
                network_name=optional_value(fields.networkname),
                group_of_lines_id=optional_value(fields.id_groupoflines),
                group_of_lines_shortname=optional_value(fields.shortname_groupoflines),
                colour_web_hexa=fields.colourweb_hexa,
                text_colour_hexa=fields.textcolourprint_hexa,
                operator_id=operator_id,
                operator_name=optional_value(fields.operatorname),
                accessibility=IdfmState(fields.accessibility.value),
                visual_signs_available=IdfmState(fields.visualsigns_available.value),
                audible_signs_available=IdfmState(fields.audiblesigns_available.value),
                picto_id=fields.picto.id_ if fields.picto is not None else None,
                record_id=line.recordid,
                record_ts=int(line.record_timestamp.timestamp()),
            )
        )

    return ret
def format_idfm_stops(*stops: IdfmStop) -> Iterable[Stop]:
    """Lazily convert IDFM stop records into ``Stop`` model instances.

    The record coordinates are passed through
    ``epsg2154_epsg3857_transformer``. A record whose postal region is not a
    valid integer is skipped after logging a warning.

    Args:
        *stops: IDFM stop records to convert.

    Yields:
        One ``Stop`` per convertible record.
    """
    for record in stops:
        attrs = record.fields

        # The creation date may be missing: fall back to None in that case.
        try:
            creation_ts = int(attrs.arrcreated.timestamp())  # type: ignore
        except AttributeError:
            creation_ts = None

        projected = epsg2154_epsg3857_transformer.transform(
            attrs.arrxepsg2154, attrs.arryepsg2154
        )

        try:
            postal_region = int(attrs.arrpostalregion)
        except ValueError:
            logger.warning("Unable to format %s postal region.", attrs.arrpostalregion)
            continue

        yield Stop(
            id=int(attrs.arrid),
            name=attrs.arrname,
            epsg3857_x=projected[0],
            epsg3857_y=projected[1],
            town_name=attrs.arrtown,
            postal_region=postal_region,
            transport_mode=TransportMode(attrs.arrtype.value),
            version=attrs.arrversion,
            created_ts=creation_ts,
            changed_ts=int(attrs.arrchanged.timestamp()),
            accessibility=IdfmState(attrs.arraccessibility.value),
            visual_signs_available=IdfmState(attrs.arrvisualsigns.value),
            audible_signs_available=IdfmState(attrs.arraudiblesignals.value),
            record_id=record.recordid,
            record_ts=int(record.record_timestamp.timestamp()),
        )
def format_idfm_stop_areas(*stop_areas: IdfmStopArea) -> Iterable[StopArea]:
    """Lazily convert IDFM stop area records into ``StopArea`` model instances.

    Args:
        *stop_areas: IDFM stop area records to convert.

    Yields:
        One ``StopArea`` per record, with coordinates passed through
        ``epsg2154_epsg3857_transformer``.
    """
    for record in stop_areas:
        attrs = record.fields

        # The creation date may be missing: fall back to None in that case.
        try:
            creation_ts = int(attrs.zdacreated.timestamp())  # type: ignore
        except AttributeError:
            creation_ts = None

        projected = epsg2154_epsg3857_transformer.transform(
            attrs.zdaxepsg2154, attrs.zdayepsg2154
        )

        yield StopArea(
            id=int(attrs.zdaid),
            name=attrs.zdaname,
            town_name=attrs.zdatown,
            postal_region=attrs.zdapostalregion,
            epsg3857_x=projected[0],
            epsg3857_y=projected[1],
            type=StopAreaType(attrs.zdatype.value),
            version=attrs.zdaversion,
            created_ts=creation_ts,
            changed_ts=int(attrs.zdachanged.timestamp()),
        )
def format_idfm_connection_areas(
    *connection_areas: IdfmConnectionArea,
) -> Iterable[ConnectionArea]:
    """Lazily convert IDFM connection area records into ``ConnectionArea`` instances.

    Args:
        *connection_areas: IDFM connection area records to convert.

    Yields:
        One ``ConnectionArea`` per record, with coordinates passed through
        ``epsg2154_epsg3857_transformer``.
    """
    for record in connection_areas:
        attrs = record.fields

        projected = epsg2154_epsg3857_transformer.transform(
            attrs.zdcxepsg2154, attrs.zdcyepsg2154
        )

        yield ConnectionArea(
            id=int(attrs.zdcid),
            name=attrs.zdcname,
            town_name=attrs.zdctown,
            postal_region=attrs.zdcpostalregion,
            epsg3857_x=projected[0],
            epsg3857_y=projected[1],
            transport_mode=StopAreaType(attrs.zdctype.value),
            version=attrs.zdcversion,
            created_ts=int(attrs.zdccreated.timestamp()),
            changed_ts=int(attrs.zdcchanged.timestamp()),
        )
def format_idfm_stop_shapes(*shape_records: ShapeRecord) -> Iterable[StopShape]:
    """Lazily convert shapefile records into ``StopShape`` model instances.

    Args:
        *shape_records: Shape records to convert.

    Yields:
        One ``StopShape`` per record. Its points and bounding box are passed
        through ``epsg2154_epsg3857_transformer``; the bounding box is empty
        when the shape exposes none.
    """
    for record in shape_records:
        shape = record.shape

        points = [
            epsg2154_epsg3857_transformer.transform(*coords) for coords in shape.points
        ]

        try:
            # Zipping one iterator with itself consumes the flat bbox values
            # two by two, i.e. pair by pair.
            flat_bbox = iter(shape.bbox)
            bbox = [
                epsg2154_epsg3857_transformer.transform(*corner)
                for corner in zip(flat_bbox, flat_bbox)
            ]
        except AttributeError:
            # Handle stop shapes for which no bbox is provided
            bbox = []

        yield StopShape(
            id=record.record[1],
            type=shape.shapeType,
            epsg3857_bbox=bbox,
            epsg3857_points=points,
        )
async def http_get(url: str) -> str | None:
    """Download ``url`` and return the response body as text.

    The body is streamed in 1 MiB chunks while a progress bar is updated.

    Args:
        url: Address of the resource to fetch (JSON is requested through the
            ``Accept`` header).

    Returns:
        The decoded body, or ``None`` when the response status is not 200.
    """
    chunks: list[bytes] = []

    headers = {"Accept": "application/json"}
    async with ClientSession(headers=headers) as session:
        async with session.get(url) as response:
            if response.status != 200:
                return None

            size = int(response.headers.get("content-length", 0)) or None

            # The context manager closes the progress bar once the download
            # is over (or fails).
            with tqdm(desc=f"Downloading {url}", total=size) as progress_bar:
                async for chunk in response.content.iter_chunked(1024 * 1024):
                    chunks.append(chunk)
                    progress_bar.update(len(chunk))

    # Decode the whole payload at once: a chunk boundary may fall in the middle
    # of a multi-byte character, so decoding chunk by chunk could fail.
    return b"".join(chunks).decode()
async def http_request(
    url: str, decode: Callable, format_method: Callable, model: Type[Base]
) -> bool:
    """Download, decode and store the elements published at ``url``.

    Decoded elements are formatted and handed to ``model.add`` in batches of
    ``BATCH_SIZE``.

    Args:
        url: Address of the resource to download.
        decode: Callable turning the downloaded text into an iterable of
            elements.
        format_method: Callable converting a batch of elements (passed as
            positional arguments) into model instances.
        model: Model class whose ``add`` coroutine stores the instances.

    Returns:
        ``False`` when the download fails or a ``ValidationError`` is raised,
        ``True`` otherwise.
    """
    elements = []

    data = await http_get(url)
    if data is None:
        return False

    try:
        for element in decode(data):
            elements.append(element)

            if len(elements) == BATCH_SIZE:
                # Materialize the batch, as the other loaders of this module
                # do, so that add() receives a concrete list rather than a
                # one-shot generator.
                await model.add(list(format_method(*elements)))
                elements.clear()

        if elements:
            await model.add(list(format_method(*elements)))

    except ValidationError as err:
        logger.warning(err)
        return False

    return True
async def load_idfm_stops() -> bool:
    """Download the IDFM stops and store them as ``Stop`` rows.

    Returns:
        The outcome reported by ``http_request``.
    """
    return await http_request(
        url=IDFM_STOPS_URL,
        decode=json_stops_decoder.decode,
        format_method=format_idfm_stops,
        model=Stop,
    )
async def load_idfm_stop_areas() -> bool:
    """Download the IDFM stop areas and store them as ``StopArea`` rows.

    Returns:
        The outcome reported by ``http_request``.
    """
    return await http_request(
        url=IDFM_STOP_AREAS_URL,
        decode=json_stop_areas_decoder.decode,
        format_method=format_idfm_stop_areas,
        model=StopArea,
    )
async def load_idfm_connection_areas() -> bool:
    """Download the IDFM connection areas and store them as ``ConnectionArea`` rows.

    Returns:
        The outcome reported by ``http_request``.
    """
    return await http_request(
        url=IDFM_CONNECTION_AREAS_URL,
        decode=json_connection_areas_decoder.decode,
        format_method=format_idfm_connection_areas,
        model=ConnectionArea,
    )
async def load_idfm_stop_shapes(url: str) -> None:
    """Download a zipped shapefile and store its records as ``StopShape`` rows.

    The archive is streamed to a temporary file, read back with
    ``ShapeFileReader`` and inserted in batches of ``BATCH_SIZE``. Nothing is
    done when the response status is not 200.

    Args:
        url: Address of the zip archive to download.
    """
    async with ClientSession(headers={"Accept": "application/zip"}) as session:
        async with session.get(url) as response:
            if response.status != 200:
                return None

            size = int(response.headers.get("content-length", 0)) or None

            async with NamedTemporaryFile(suffix=".zip") as tmp_file:
                with tqdm(desc=f"Downloading {url}", total=size) as dl_progress_bar:
                    async for chunk in response.content.iter_chunked(1024 * 1024):
                        await tmp_file.write(chunk)
                        dl_progress_bar.update(len(chunk))

                # The archive is reopened by name below: push any buffered
                # data to the file first so the reader sees all of it.
                await tmp_file.flush()

                with ShapeFileReader(tmp_file.name) as reader:
                    step_begin_ts = time()

                    shapes = reader.shapeRecords()
                    shapes_len = len(shapes)

                    with tqdm(
                        desc=f"Filling db with {shapes_len} StopShapes",
                        total=shapes_len,
                    ) as db_progress_bar:
                        # range() stops exactly at the last record, so no
                        # trailing empty batch is sent to the database.
                        for begin in range(0, shapes_len, BATCH_SIZE):
                            elements = islice(shapes, begin, begin + BATCH_SIZE)
                            formatteds = list(format_idfm_stop_shapes(*elements))
                            await StopShape.add(formatteds)

                            # Report the real batch size: the last batch may
                            # be shorter than BATCH_SIZE.
                            db_progress_bar.update(len(formatteds))

                    logger.info(
                        f"Add {StopShape.__name__}s: {time() - step_begin_ts}s"
                    )
async def load_idfm_lines() -> None:
    """Download the IDFM lines and store them with their pictograms.

    Lines are stored in batches of ``BATCH_SIZE``. Each pictogram is stored
    once (first occurrence of its id) and always before the lines of its batch.
    Nothing is done when the download fails.
    """
    payload = await http_get(IDFM_LINES_URL)
    if payload is None:
        return None

    pending_lines: list = []
    pending_pictos: list = []
    seen_picto_ids = set()

    for record in json_lines_decoder.decode(payload):
        picto = record.fields.picto
        if picto is not None and picto.id_ not in seen_picto_ids:
            seen_picto_ids.add(picto.id_)
            pending_pictos.append(picto)

        pending_lines.append(record)
        if len(pending_lines) != BATCH_SIZE:
            continue

        # Full batch: store the pictos first, then the lines.
        await LinePicto.add(list(format_idfm_pictos(*pending_pictos)))
        await Line.add(list(format_idfm_lines(*pending_lines)))
        pending_lines.clear()
        pending_pictos.clear()

    # Store what remains of the last, incomplete batch.
    if pending_pictos:
        await LinePicto.add(list(format_idfm_pictos(*pending_pictos)))
    if pending_lines:
        await Line.add(list(format_idfm_lines(*pending_lines)))
async def load_ratp_pictos(batch_size: int = 5) -> None:
    """Download the RATP pictograms and attach them to the lines.

    Args:
        batch_size: Number of pictogram records handled per database round trip.
    """
    payload = await http_get(RATP_PICTOS_URL)
    if payload is None:
        return None

    async def _store(batch: list) -> None:
        # Store the pictos themselves, then link them through Line.add_pictos.
        pairs = format_ratp_pictos(*batch)
        await LinePicto.add([pair[1] for pair in pairs])
        await Line.add_pictos(pairs)

    pending = []
    for record in json_ratp_pictos_decoder.decode(payload):
        pending.append(record)
        if len(pending) == batch_size:
            await _store(pending)
            pending.clear()

    # Store what remains of the last, incomplete batch.
    if pending:
        await _store(pending)
async def load_lines_stops_assos(batch_size: int = 5000) -> None:
    """Download the line/stop associations and register them on the lines.

    The stop id is the text following the last ``:`` of the record's
    ``stop_id``; records for which it is not an integer are logged and skipped.

    Args:
        batch_size: Number of associations handed to ``Line.add_stops`` at once.
    """
    payload = await http_get(IDFM_STOPS_LINES_ASSOS_URL)
    if payload is None:
        return None

    total_assos_nb = 0
    linked_nb = 0
    pending = []

    for record in json_stops_lines_assos_decoder.decode(payload):
        fields = record.fields
        try:
            stop_id = int(fields.stop_id.rpartition(":")[2])
        except ValueError as err:
            logger.error(err)
            logger.error(f"{fields.stop_id = }")
            continue

        pending.append((fields.route_long_name, fields.operatorname, stop_id))
        if len(pending) != batch_size:
            continue

        total_assos_nb += batch_size
        linked_nb += await Line.add_stops(pending)
        pending.clear()

    # Register what remains of the last, incomplete batch.
    if pending:
        total_assos_nb += len(pending)
        linked_nb += await Line.add_stops(pending)

    logger.info(f"{linked_nb} line <-> stop ({total_assos_nb = } found)")
async def load_stop_assos(batch_size: int = 5000) -> None:
    """Download the stop associations and link stops to their areas.

    Each record links one stop to both a stop area and a connection area. The
    links are stored in batches through ``StopArea.add_stops`` and
    ``ConnectionArea.add_stops``.

    Args:
        batch_size: Number of associations handed to each ``add_stops`` call.
    """
    data = await http_get(IDFM_STOP_STOP_AREAS_ASSOS_URL)
    if data is None:
        return None

    total_assos_nb = area_stop_assos_nb = conn_stop_assos_nb = 0
    area_stop_assos = []
    connection_stop_assos = []

    for asso in json_stop_area_stop_asso_decoder.decode(data):
        fields = asso.fields

        stop_id = int(fields.arrid)

        area_stop_assos.append((int(fields.zdaid), stop_id))
        connection_stop_assos.append((int(fields.zdcid), stop_id))

        if len(area_stop_assos) == batch_size:
            total_assos_nb += batch_size

            if (found_nb := await StopArea.add_stops(area_stop_assos)) is not None:
                area_stop_assos_nb += found_nb
            area_stop_assos.clear()

            if (
                found_nb := await ConnectionArea.add_stops(connection_stop_assos)
            ) is not None:
                conn_stop_assos_nb += found_nb
            connection_stop_assos.clear()

    # Both lists grow and are cleared together, hence hold the same number of
    # pending associations here.
    if area_stop_assos:
        total_assos_nb += len(area_stop_assos)
        if (found_nb := await StopArea.add_stops(area_stop_assos)) is not None:
            area_stop_assos_nb += found_nb

    if connection_stop_assos:
        if (
            found_nb := await ConnectionArea.add_stops(connection_stop_assos)
        ) is not None:
            conn_stop_assos_nb += found_nb

    logger.info(f"{area_stop_assos_nb} stop area <-> stop ({total_assos_nb = } found)")
    # This counter tracks connection area links: label it as such instead of
    # repeating the "stop area" wording of the previous line.
    logger.info(
        f"{conn_stop_assos_nb} connection area <-> stop ({total_assos_nb = } found)"
    )
async def prepare(db: Database) -> None:
    """Fill the database in three steps.

    Lines are loaded first. Stops, stop areas, connection areas and RATP
    pictos are then loaded concurrently, followed by the shapes and the
    associations (also concurrently).

    Args:
        db: Database handle; not used by the body itself.
    """
    # Step 1: lines.
    await load_idfm_lines()

    # Step 2: stops, areas and RATP pictos.
    await gather(
        load_idfm_stops(),
        load_idfm_stop_areas(),
        load_idfm_connection_areas(),
        load_ratp_pictos(),
    )

    # Step 3: shapes and associations.
    await gather(
        load_idfm_stop_shapes(IDFM_STOP_SHAPES_URL),
        load_idfm_stop_shapes(IDFM_STOP_AREA_SHAPES_URL),
        load_lines_stops_assos(),
        load_stop_assos(),
    )
def load_settings(path: str) -> Settings:
    """Build the application settings from the YAML file located at ``path``.

    Args:
        path: Location of the YAML configuration file.

    Returns:
        The ``Settings`` instance built from the parsed file content.
    """
    with open(path, "r") as stream:
        raw_settings = safe_load(stream)

    return Settings(**raw_settings)
async def main() -> None:
    """Connect to the database, fill it and log the elapsed time.

    The database connection is released even when the filling step fails.
    """
    settings = load_settings(CONFIG_PATH)

    await db.connect(settings.db, True)

    try:
        begin_ts = time()
        await prepare(db)
        logger.info(f"Elapsed time: {time() - begin_ts}s")
    finally:
        # Always disconnect, otherwise a failure in prepare() would leave the
        # connection open.
        await db.disconnect()
# Script entry point: execute main() through run() when the module is run directly.
if __name__ == "__main__":
run(main())

38
backend/dependencies.py Normal file
View File

@@ -0,0 +1,38 @@
from os import environ
from fastapi_cache.backends.redis import RedisBackend
from redis import asyncio as aioredis
from yaml import safe_load
from backend.db import db
from backend.idfm_interface.idfm_interface import IdfmInterface
from backend.settings import CacheSettings, Settings
# Path of the YAML settings file, taken from the CONFIG_PATH environment
# variable; the sample configuration is used when it is not set.
CONFIG_PATH = environ.get("CONFIG_PATH", "./config.sample.yaml")
def load_settings(path: str) -> Settings:
    """Parse the YAML file at ``path`` and return the matching ``Settings``.

    Args:
        path: Location of the YAML configuration file.

    Returns:
        The ``Settings`` instance built from the parsed file content.
    """
    with open(path, "r") as yaml_file:
        parsed = safe_load(yaml_file)

    return Settings(**parsed)
# Module-level objects, created once when this module is first imported.
settings = load_settings(CONFIG_PATH)
# The IDFM API key is unwrapped with get_secret_value() before being handed
# to the interface.
idfm_interface = IdfmInterface(settings.idfm_api_key.get_secret_value(), db)
def init_redis_backend(settings: CacheSettings) -> RedisBackend:
    """Create the Redis cache backend described by ``settings``.

    Args:
        settings: Cache settings providing ``host`` and ``port`` and,
            optionally, ``user`` and ``password``.

    Returns:
        A ``RedisBackend`` wrapping the client built from the Redis URL.
    """
    from urllib.parse import quote

    login = ""
    if settings.user is not None:
        # Percent-encode the credentials so that characters such as "@", ":"
        # or "/" cannot corrupt the URL.
        login = quote(str(settings.user), safe="")
        # Leave the password out when it is not set, instead of sending the
        # literal "None".
        if settings.password is not None:
            login += f":{quote(str(settings.password), safe='')}"
        login += "@"

    url = f"redis://{login}{settings.host}:{settings.port}"

    redis_connections_pool = aioredis.from_url(
        url, encoding="utf8", decode_responses=True
    )

    return RedisBackend(redis_connections_pool)
# Cache backend built at import time from the cache section of the settings.
redis_backend = init_redis_backend(settings.cache)

View File

@@ -1,54 +1,27 @@
#!/usr/bin/env python3
import logging
from collections import defaultdict
from datetime import datetime
from os import environ, EX_USAGE
from typing import Sequence
import uvicorn
from contextlib import asynccontextmanager
from fastapi import FastAPI, HTTPException
from fastapi import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
from fastapi_cache import FastAPICache
from opentelemetry import trace
from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter
from opentelemetry.instrumentation.fastapi import FastAPIInstrumentor
from opentelemetry.sdk.resources import Resource, SERVICE_NAME
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import BatchSpanProcessor
from rich import print
from yaml import safe_load
from backend.db import db
from backend.idfm_interface import Destinations as IdfmDestinations, IdfmInterface
from backend.models import Line, Stop, StopArea, StopShape
from backend.schemas import (
Line as LineSchema,
TransportMode,
NextPassage as NextPassageSchema,
NextPassages as NextPassagesSchema,
Stop as StopSchema,
StopArea as StopAreaSchema,
StopShape as StopShapeSchema,
)
from backend.settings import Settings
CONFIG_PATH = environ.get("CONFIG_PATH", "./config.sample.yaml")
def load_settings(path: str) -> Settings:
with open(path, "r") as config_file:
config = safe_load(config_file)
return Settings(**config)
settings = load_settings(CONFIG_PATH)
from dependencies import idfm_interface, redis_backend, settings
from routers import line, stop
@asynccontextmanager
async def lifespan(app: FastAPI):
FastAPICache.init(redis_backend, prefix="api", enable=settings.cache.enable)
await db.connect(settings.db, settings.clear_static_data)
if settings.clear_static_data:
await idfm_interface.startup()
@@ -70,208 +43,51 @@ app.add_middleware(
app.mount("/widget", StaticFiles(directory="../frontend/", html=True), name="widget")
FastAPIInstrumentor.instrument_app(app)
# The cache-control header entry is not managed properly by fastapi-cache:
# For now, a request with a cache-control set to no-cache
# is interpreted as disabling the use of the server cache.
# Cf. Improve Cache-Control header parsing and handling
# https://github.com/long2ice/fastapi-cache/issues/144 workaround
@app.middleware("http")
async def fastapi_cache_issue_144_workaround(request: Request, call_next):
entries = request.headers.__dict__["_list"]
new_entries = [
entry for entry in entries if entry[0].decode().lower() != "cache-control"
]
trace.set_tracer_provider(
TracerProvider(resource=Resource.create({SERVICE_NAME: settings.app_name}))
)
trace.get_tracer_provider().add_span_processor(BatchSpanProcessor(OTLPSpanExporter()))
tracer = trace.get_tracer(settings.app_name)
request.headers.__dict__["_list"] = new_entries
idfm_interface = IdfmInterface(settings.idfm_api_key.get_secret_value(), db)
return await call_next(request)
def optional_datetime_to_ts(dt: datetime | None) -> int | None:
return int(dt.timestamp()) if dt else None
app.include_router(line.router)
app.include_router(stop.router)
@app.get("/line/{line_id}", response_model=LineSchema)
async def get_line(line_id: int) -> LineSchema:
line: Line | None = await Line.get_by_id(line_id)
if settings.tracing.enable:
FastAPIInstrumentor.instrument_app(app)
if line is None:
raise HTTPException(status_code=404, detail=f'Line "{line_id}" not found')
return LineSchema(
id=line.id,
shortName=line.short_name,
name=line.name,
status=line.status,
transportMode=TransportMode.from_idfm_transport_mode(
line.transport_mode, line.transport_submode
),
backColorHexa=line.colour_web_hexa,
foreColorHexa=line.text_colour_hexa,
operatorId=line.operator_id,
accessibility=line.accessibility,
visualSignsAvailable=line.visual_signs_available,
audibleSignsAvailable=line.audible_signs_available,
stopIds=[stop.id for stop in line.stops],
trace.set_tracer_provider(
TracerProvider(resource=Resource.create({SERVICE_NAME: settings.app_name}))
)
def _format_stop(stop: Stop) -> StopSchema:
return StopSchema(
id=stop.id,
name=stop.name,
town=stop.town_name,
epsg3857_x=stop.epsg3857_x,
epsg3857_y=stop.epsg3857_y,
lines=[line.id for line in stop.lines],
trace.get_tracer_provider().add_span_processor(
BatchSpanProcessor(OTLPSpanExporter())
)
@app.get("/stop/")
async def get_stop(
name: str = "", limit: int = 10
) -> Sequence[StopAreaSchema | StopSchema]:
# TODO: Add limit support
formatted: list[StopAreaSchema | StopSchema] = []
matching_stops = await Stop.get_by_name(name)
# print(matching_stops, flush=True)
stop_areas: dict[int, StopArea] = {}
stops: dict[int, Stop] = {}
for stop in matching_stops:
# print(f"{stop.__dict__ = }", flush=True)
dst = stop_areas if isinstance(stop, StopArea) else stops
dst[stop.id] = stop
for stop_area in stop_areas.values():
formatted_stops = []
for stop in stop_area.stops:
formatted_stops.append(_format_stop(stop))
try:
del stops[stop.id]
except KeyError as err:
print(err)
formatted.append(
StopAreaSchema(
id=stop_area.id,
name=stop_area.name,
town=stop_area.town_name,
type=stop_area.type,
lines=[line.id for line in stop_area.lines],
stops=formatted_stops,
)
)
formatted.extend(_format_stop(stop) for stop in stops.values())
return formatted
# TODO: Cache response for 30 secs ?
@app.get("/stop/{stop_id}/nextPassages")
async def get_next_passages(stop_id: int) -> NextPassagesSchema | None:
res = await idfm_interface.get_next_passages(stop_id)
if res is None:
return None
service_delivery = res.Siri.ServiceDelivery
stop_monitoring_deliveries = service_delivery.StopMonitoringDelivery
by_line_by_dst_passages: dict[
str, dict[str, list[NextPassageSchema]]
] = defaultdict(lambda: defaultdict(list))
for delivery in stop_monitoring_deliveries:
for stop_visit in delivery.MonitoredStopVisit:
journey = stop_visit.MonitoredVehicleJourney
# re.match will return None if the given journey.LineRef.value is not valid.
try:
line_id_match = IdfmInterface.LINE_RE.match(journey.LineRef.value)
line_id = int(line_id_match.group(1)) # type: ignore
except (AttributeError, TypeError, ValueError) as err:
raise HTTPException(
status_code=404, detail=f'Line "{journey.LineRef.value}" not found'
) from err
call = journey.MonitoredCall
dst_names = call.DestinationDisplay
dsts = [dst.value for dst in dst_names] if dst_names else []
arrivalPlatformName = (
call.ArrivalPlatformName.value if call.ArrivalPlatformName else None
)
next_passage = NextPassageSchema(
line=line_id,
operator=journey.OperatorRef.value,
destinations=dsts,
atStop=call.VehicleAtStop,
aimedArrivalTs=optional_datetime_to_ts(call.AimedArrivalTime),
expectedArrivalTs=optional_datetime_to_ts(call.ExpectedArrivalTime),
arrivalPlatformName=arrivalPlatformName,
aimedDepartTs=optional_datetime_to_ts(call.AimedDepartureTime),
expectedDepartTs=optional_datetime_to_ts(call.ExpectedDepartureTime),
arrivalStatus=call.ArrivalStatus.value,
departStatus=call.DepartureStatus.value,
)
by_line_passages = by_line_by_dst_passages[line_id]
# TODO: by_line_passages[dst].extend(dsts) instead ?
for dst in dsts:
by_line_passages[dst].append(next_passage)
return NextPassagesSchema(
ts=service_delivery.ResponseTimestamp.timestamp(),
passages=by_line_by_dst_passages,
)
@app.get("/stop/{stop_id}/destinations")
async def get_stop_destinations(
stop_id: int,
) -> IdfmDestinations | None:
destinations = await idfm_interface.get_destinations(stop_id)
return destinations
@app.get("/stop/{stop_id}/shape")
async def get_stop_shape(stop_id: int) -> StopShapeSchema | None:
connection_area = None
if (stop := await Stop.get_by_id(stop_id)) is not None:
connection_area = stop.connection_area
elif (stop_area := await StopArea.get_by_id(stop_id)) is not None:
connection_areas = {stop.connection_area for stop in stop_area.stops}
connection_areas_len = len(connection_areas)
if connection_areas_len == 1:
connection_area = connection_areas.pop()
else:
prefix = "More than one" if connection_areas_len else "No"
msg = f"{prefix} connection area has been found for stop area #{stop_id}"
raise HTTPException(status_code=500, detail=msg)
if (
connection_area is not None
and (shape := await StopShape.get_by_id(connection_area.id)) is not None
):
return StopShapeSchema(
id=shape.id,
type=shape.type,
epsg3857_bbox=shape.epsg3857_bbox,
epsg3857_points=shape.epsg3857_points,
)
msg = f"No shape found for stop {stop_id}"
raise HTTPException(status_code=404, detail=msg)
tracer = trace.get_tracer(settings.app_name)
if __name__ == "__main__":
http_settings = settings.http
uvicorn.run(
app,
config = uvicorn.Config(
app=app,
host=http_settings.host,
port=http_settings.port,
ssl_certfile=http_settings.cert,
proxy_headers=True,
)
server = uvicorn.Server(config)
server.run()

View File

@@ -4,18 +4,14 @@ version = "0.1.0"
description = ""
authors = ["Adrien SUEUR <me@adrien.run>"]
readme = "README.md"
packages = [{include = "backend"}]
[tool.poetry.dependencies]
python = "^3.11"
aiohttp = "^3.8.3"
rich = "^12.6.0"
aiofiles = "^22.1.0"
fastapi = "^0.95.0"
uvicorn = "^0.20.0"
msgspec = "^0.12.0"
pyshp = "^2.3.1"
pyproj = "^3.5.0"
opentelemetry-instrumentation-fastapi = "^0.38b0"
sqlalchemy-utils = "^0.41.1"
opentelemetry-instrumentation-logging = "^0.38b0"
@@ -26,6 +22,25 @@ opentelemetry-instrumentation-sqlalchemy = "^0.38b0"
sqlalchemy = "^2.0.12"
psycopg = "^3.1.9"
pyyaml = "^6.0"
fastapi-cache2 = {extras = ["redis"], version = "^0.2.1"}
[tool.poetry.group.db_updater.dependencies]
aiofiles = "^22.1.0"
aiohttp = "^3.8.3"
fastapi = "^0.95.0"
msgspec = "^0.12.0"
opentelemetry-instrumentation-fastapi = "^0.38b0"
opentelemetry-instrumentation-sqlalchemy = "^0.38b0"
opentelemetry-sdk = "^1.17.0"
opentelemetry-api = "^1.17.0"
psycopg = "^3.1.9"
pyproj = "^3.5.0"
pyshp = "^2.3.1"
python = "^3.11"
pyyaml = "^6.0"
sqlalchemy = "^2.0.12"
sqlalchemy-utils = "^0.41.1"
tqdm = "^4.65.0"
[build-system]
requires = ["poetry-core"]
@@ -48,8 +63,10 @@ pyflakes = "^3.0.1"
yapf = "^0.32.0"
whatthepatch = "^1.0.4"
mypy = "^1.0.0"
icecream = "^2.1.3"
types-sqlalchemy-utils = "^1.0.1"
types-pyyaml = "^6.0.12.9"
types-tqdm = "^4.65.0.1"
[tool.mypy]
plugins = "sqlalchemy.ext.mypy.plugin"

View File

34
backend/routers/line.py Normal file
View File

@@ -0,0 +1,34 @@
from fastapi import APIRouter, HTTPException
from fastapi_cache.decorator import cache
from backend.models import Line
from backend.schemas import Line as LineSchema, TransportMode
# Router grouping the endpoints served under the /line prefix.
router = APIRouter(prefix="/line", tags=["line"])
@router.get("/{line_id}", response_model=LineSchema)
@cache(namespace="line")
async def get_line(line_id: int) -> LineSchema:
    """Return the description of the line identified by ``line_id``.

    Args:
        line_id: Identifier of the requested line.

    Returns:
        The schema describing the line, including the ids of its stops.

    Raises:
        HTTPException: With status 404 when no such line is known.
    """
    db_line: Line | None = await Line.get_by_id(line_id)
    if db_line is None:
        raise HTTPException(status_code=404, detail=f'Line "{line_id}" not found')

    transport_mode = TransportMode.from_idfm_transport_mode(
        db_line.transport_mode, db_line.transport_submode
    )
    stop_ids = [stop.id for stop in db_line.stops]

    return LineSchema(
        id=db_line.id,
        shortName=db_line.short_name,
        name=db_line.name,
        status=db_line.status,
        transportMode=transport_mode,
        backColorHexa=db_line.colour_web_hexa,
        foreColorHexa=db_line.text_colour_hexa,
        operatorId=db_line.operator_id,
        accessibility=db_line.accessibility,
        visualSignsAvailable=db_line.visual_signs_available,
        audibleSignsAvailable=db_line.audible_signs_available,
        stopIds=stop_ids,
    )

176
backend/routers/stop.py Normal file
View File

@@ -0,0 +1,176 @@
from collections import defaultdict
from datetime import datetime
from typing import Sequence
from fastapi import APIRouter, HTTPException
from fastapi_cache.decorator import cache
from backend.idfm_interface import Destinations as IdfmDestinations, TrainStatus
from backend.models import Stop, StopArea, StopShape
from backend.schemas import (
NextPassage as NextPassageSchema,
NextPassages as NextPassagesSchema,
Stop as StopSchema,
StopArea as StopAreaSchema,
StopShape as StopShapeSchema,
)
from dependencies import idfm_interface
# Router grouping the endpoints served under the /stop prefix.
router = APIRouter(prefix="/stop", tags=["stop"])
def _format_stop(stop: Stop) -> StopSchema:
    """Build the API schema describing ``stop``.

    Args:
        stop: Stop model instance to expose.

    Returns:
        The ``StopSchema`` carrying the stop attributes and the ids of its lines.
    """
    line_ids = [line.id for line in stop.lines]

    return StopSchema(
        id=stop.id,
        name=stop.name,
        town=stop.town_name,
        epsg3857_x=stop.epsg3857_x,
        epsg3857_y=stop.epsg3857_y,
        lines=line_ids,
    )
def optional_datetime_to_ts(dt: datetime | None) -> int | None:
return int(dt.timestamp()) if dt else None
# TODO: Add limit support
@router.get("/")
@cache(namespace="stop")
async def get_stop(
    name: str = "", limit: int = 10
) -> Sequence[StopAreaSchema | StopSchema] | None:
    """Return the stops and stop areas matching ``name``.

    Stops belonging to a matching stop area are reported inside that area
    only, not a second time as standalone stops.

    Args:
        name: Name to look for.
        limit: Not used yet by the body (limit support is still to be added).

    Returns:
        The matching stop areas followed by the remaining stops, or ``None``
        when the lookup itself returns ``None``.
    """
    matching_stops = await Stop.get_by_name(name)
    if matching_stops is None:
        return None

    formatted: list[StopAreaSchema | StopSchema] = []
    stop_areas: dict[int, StopArea] = {}
    stops: dict[int, Stop] = {}

    for stop in matching_stops:
        if isinstance(stop, StopArea):
            stop_areas[stop.id] = stop
        elif isinstance(stop, Stop):
            stops[stop.id] = stop

    for stop_area in stop_areas.values():
        formatted_stops = []
        for stop in stop_area.stops:
            formatted_stops.append(_format_stop(stop))
            # A stop of the area is not necessarily one of the matching stops:
            # drop it when present, without printing the expected KeyError on
            # the standard output.
            stops.pop(stop.id, None)

        formatted.append(
            StopAreaSchema(
                id=stop_area.id,
                name=stop_area.name,
                town=stop_area.town_name,
                type=stop_area.type,
                lines=[line.id for line in stop_area.lines],
                stops=formatted_stops,
            )
        )

    formatted.extend(_format_stop(stop) for stop in stops.values())

    return formatted
@router.get("/{stop_id}/nextPassages")
@cache(namespace="stop-nextPassages", expire=30)
async def get_next_passages(stop_id: int) -> NextPassagesSchema | None:
    """Return the next passages at ``stop_id``, grouped by line then destination.

    Args:
        stop_id: Identifier of the stop to monitor.

    Returns:
        The passages with the response timestamp, or ``None`` when the IDFM
        interface returns no response.

    Raises:
        HTTPException: With status 404 when the line reference of a monitored
            journey cannot be converted into a line id.
    """
    response = await idfm_interface.get_next_passages(stop_id)
    if response is None:
        return None

    service_delivery = response.Siri.ServiceDelivery
    deliveries = service_delivery.StopMonitoringDelivery

    passages_by_line: dict[int, dict[str, list[NextPassageSchema]]] = defaultdict(
        lambda: defaultdict(list)
    )

    for delivery in deliveries:
        for visit in delivery.MonitoredStopVisit:
            journey = visit.MonitoredVehicleJourney

            # re.match will return None if the given journey.LineRef.value is not valid.
            try:
                matched = idfm_interface.LINE_RE.match(journey.LineRef.value)
                line_id = int(matched.group(1))  # type: ignore
            except (AttributeError, TypeError, ValueError) as err:
                raise HTTPException(
                    status_code=404, detail=f'Line "{journey.LineRef.value}" not found'
                ) from err

            call = journey.MonitoredCall

            displays = call.DestinationDisplay
            destinations = [display.value for display in displays] if displays else []

            platform_name = None
            if call.ArrivalPlatformName:
                platform_name = call.ArrivalPlatformName.value

            # Missing statuses are reported as unknown.
            arrival_status = call.ArrivalStatus
            if arrival_status is None:
                arrival_status = TrainStatus.unknown

            departure_status = call.DepartureStatus
            if departure_status is None:
                departure_status = TrainStatus.unknown

            passage = NextPassageSchema(
                line=line_id,
                operator=journey.OperatorRef.value,
                destinations=destinations,
                atStop=call.VehicleAtStop,
                aimedArrivalTs=optional_datetime_to_ts(call.AimedArrivalTime),
                expectedArrivalTs=optional_datetime_to_ts(call.ExpectedArrivalTime),
                arrivalPlatformName=platform_name,
                aimedDepartTs=optional_datetime_to_ts(call.AimedDepartureTime),
                expectedDepartTs=optional_datetime_to_ts(call.ExpectedDepartureTime),
                arrivalStatus=arrival_status,
                departStatus=departure_status,
            )

            passages_by_dst = passages_by_line[line_id]

            # TODO: by_line_passages[dst].extend(dsts) instead ?
            for destination in destinations:
                passages_by_dst[destination].append(passage)

    return NextPassagesSchema(
        ts=int(service_delivery.ResponseTimestamp.timestamp()),
        passages=passages_by_line,
    )
@router.get("/{stop_id}/destinations")
@cache(namespace="stop-destinations", expire=30)
async def get_stop_destinations(
    stop_id: int,
) -> IdfmDestinations | None:
    """Return the destinations reported by the IDFM interface for ``stop_id``.

    Args:
        stop_id: Identifier of the stop.

    Returns:
        Whatever ``idfm_interface.get_destinations`` returns for this stop.
    """
    return await idfm_interface.get_destinations(stop_id)
@router.get("/{stop_id}/shape")
@cache(namespace="stop-shape")
async def get_stop_shape(stop_id: int) -> StopShapeSchema | None:
    """Return the shape stored for the stop or stop area ``stop_id``.

    Args:
        stop_id: Identifier of a stop or of a stop area.

    Returns:
        The schema describing the shape.

    Raises:
        HTTPException: With status 404 when the id matches neither a stop nor
            a stop area, or when no shape is stored for it.
    """
    # The stop area lookup only runs when no stop carries this id.
    is_known = (await Stop.get_by_id(stop_id)) is not None
    if not is_known:
        is_known = (await StopArea.get_by_id(stop_id)) is not None

    if is_known:
        # The shape is looked up with the same id as the stop / stop area.
        shape = await StopShape.get_by_id(stop_id)
        if shape is not None:
            return StopShapeSchema(
                id=shape.id,
                type=shape.type,
                epsg3857_bbox=shape.epsg3857_bbox,
                epsg3857_points=shape.epsg3857_points,
            )

    raise HTTPException(status_code=404, detail=f"No shape found for stop {stop_id}")

View File

@@ -15,8 +15,15 @@ services:
ports:
- "127.0.0.1:5432:5432"
volumes:
- ./docker/database/docker-entrypoint-initdb.d:/docker-entrypoint-initdb.d
- ./docker/database/data:/var/lib/postgresql/data
- ./backend/docker/database/docker-entrypoint-initdb.d:/docker-entrypoint-initdb.d
- ./backend/docker/database/data:/var/lib/postgresql/data
redis:
image: redis:latest
restart: always
command: redis-server --loglevel warning
ports:
- "127.0.0.1:6379:6379"
jaeger-agent:
image: jaegertracing/jaeger-agent:latest
@@ -45,10 +52,6 @@ services:
ports:
- "127.0.0.1:4317:4317"
- "127.0.0.1:4318:4318"
# - "127.0.0.1:9411:9411"
# - "127.0.0.1:14250:14250"
# - "127.0.0.1:14268:14268"
# - "127.0.0.1:14269:14269"
restart: on-failure
depends_on:
- cassandra-schema
@@ -68,7 +71,29 @@ services:
- "--cassandra.servers=cassandra"
ports:
- "127.0.0.1:16686:16686"
# - "127.0.0.1:16687:16687"
restart: on-failure
depends_on:
- cassandra-schema
carrramba-encore-rate-api:
build:
context: ./backend/
dockerfile: Dockerfile.backend
environment:
- CONFIG_PATH=./config.local.yaml
- IDFM_API_KEY=set_your_idfm_key_here
ports:
- "127.0.0.1:8080:8080"
carrramba-encore-rate-frontend:
build:
context: ./frontend/
ports:
- "127.0.0.1:80:8081"
carrramba-encore-rate-db-updater:
build:
context: ./backend/
dockerfile: Dockerfile.db_updater
environment:
- CONFIG_PATH=./config.local.yaml

4
frontend/Dockerfile Normal file
View File

@@ -0,0 +1,4 @@
# Serve the pre-built frontend bundle with the official nginx image
# (mainline-alpine-slim tag).
# NOTE(review): this is a floating tag, so a rebuild may pick up a newer nginx
# release; consider pinning a version or a digest.
FROM nginx:mainline-alpine-slim
# The dist directory must exist in the build context before the image is built.
# It is copied to /usr/share/nginx/html, presumably the document root served by
# the image's default configuration (to be confirmed).
COPY dist /usr/share/nginx/html