23 Commits

Author SHA1 Message Date
1bb75b28eb ♻️ Use of relative imports for api modules 2023-10-22 23:34:58 +02:00
0a7337a313 ♻️ Put api_server and db_updater scripts on the backend root 2023-10-22 23:31:35 +02:00
3434802b31 🎨 Reorganize back-end code 2023-09-20 22:08:32 +02:00
bdbc72ab39 🐛 Front: Fix URL used to fetch transport mode representation 2023-09-10 12:25:38 +02:00
4cc8f60076 🐛 Front: Use the public API server to fetch data 2023-09-10 12:17:48 +02:00
cf5c4c6224 🔒️ Fix CORS allowed origins and methods 2023-09-10 12:07:20 +02:00
f69aee1c9c 🔒️ Remove driver and password from configuration file
Password will be provided by vault using an env variable.
2023-09-10 12:04:25 +02:00
8c493f8fab ♻️ Remove pg_trgm creation from the db session init
The pg_trgm extension will be created during db init, by the db-updated image.
2023-09-10 11:46:24 +02:00
4fce832db5 ♻️ Rename docker file building api image 2023-09-10 11:45:08 +02:00
bfc669cd11 ♻️ Use pydantic-settings to handle config file 2023-09-09 23:35:18 +02:00
4056b3a739 🐛 Error raised by frontend Map component if no stop found 2023-09-09 23:18:03 +02:00
f7f0fdb980 ️ Use of integer to store Line and Stop id
Update Line and Stop schemas.
2023-09-09 23:05:18 +02:00
6c149e844b 💥 Remove /widget static endpoint
This endpoint shall be served by a dedicated static HTTP server.
2023-06-13 05:45:33 +02:00
f5529bba24 Merge branch 'remove-db-filling-from-backend' into develop 2023-06-13 05:44:00 +02:00
5da918c04b 👽️ Take the last IDFM format into account 2023-06-11 22:41:44 +02:00
2eaf0f4ed5 Use of db merge when adds fails due to single key violations 2023-06-11 22:28:15 +02:00
c42b687870 🐛 Fix IdfmInterface circular import issue 2023-06-11 22:24:09 +02:00
d8adb4f52d ♻️ Remove code in charge or db filling from IdfmInterface 2023-06-11 22:22:05 +02:00
5e7f440b54 ♻️ Add the db_updater package 2023-06-11 22:18:47 +02:00
824536ddbe 💥 Rename API_KEY to IDFM_API_KEY 2023-05-28 12:45:03 +02:00
7fbdd0606c ️ Reduce the size of the backend docker image 2023-05-28 12:40:10 +02:00
581f6b7b8f 🐛 Add workaround for fastapi-cache issue #144 2023-05-28 10:45:14 +02:00
404b228cbf 🔥 Remove env variables from backend dockerfile 2023-05-26 23:55:58 +02:00
38 changed files with 998 additions and 815 deletions

View File

@@ -1,7 +1,12 @@
docker .dir-locals.el
**/__pycache__ .dockerignore
poetry.lock .gitignore
Dockerfile **/.mypy_cache
docker-compose.yml **/.ruff_cache
config
.venv .venv
**/__pycache__
config
docker
poetry.lock
tests
Dockerfile

View File

@@ -1,26 +0,0 @@
FROM python:3.11-slim as builder
WORKDIR /app
COPY . /app
RUN apt update && apt install -y proj-bin
RUN pip install --upgrade poetry && \
poetry config virtualenvs.create false && \
poetry install --only=main && \
poetry export -f requirements.txt >> requirements.txt
FROM python:3.11-slim as runtime
COPY . /app
COPY --from=builder /app/requirements.txt /app
RUN apt update && apt install -y postgresql libpq5
RUN pip install --no-cache-dir -r /app/requirements.txt
WORKDIR /app
ENV CONFIG_PATH=./config.sample.yaml
ENV API_KEY=MwP7lbljnXIYAnmmmPRzasHsIknaiKqD
CMD ["python", "./main.py"]

View File

@@ -0,0 +1,36 @@
FROM python:3.11-slim as builder
RUN pip install poetry
ENV POETRY_NO_INTERACTION=1 \
POETRY_VIRTUALENVS_IN_PROJECT=1 \
POETRY_VIRTUALENVS_CREATE=1 \
POETRY_CACHE_DIR=/tmp/poetry_cache
WORKDIR /app
COPY pyproject.toml /app
RUN poetry install --only=main --no-root && \
rm -rf ${POETRY_CACHE_DIR}
FROM python:3.11-slim as runtime
WORKDIR /app
RUN apt update && \
apt install -y --no-install-recommends libpq5 && \
apt clean && \
rm -rf /var/lib/apt/lists/*
env VIRTUAL_ENV=/app/.venv \
PATH="/app/.venv/bin:$PATH"
COPY --from=builder ${VIRTUAL_ENV} ${VIRTUAL_ENV}
COPY api /app/api
COPY config.sample.yaml .
COPY api_server.py .
CMD ["./api_server.py"]

View File

@@ -0,0 +1,41 @@
FROM python:3.11-slim as builder
RUN apt update && \
apt install -y --no-install-recommends proj-bin && \
apt clean && \
rm -rf /var/lib/apt/lists/*
RUN pip install poetry
ENV POETRY_NO_INTERACTION=1 \
POETRY_VIRTUALENVS_IN_PROJECT=1 \
POETRY_VIRTUALENVS_CREATE=1 \
POETRY_CACHE_DIR=/tmp/poetry_cache
WORKDIR /app
COPY ./pyproject.toml /app
RUN poetry install --only=db_updater --no-root && \
rm -rf ${POETRY_CACHE_DIR}
FROM python:3.11-slim as runtime
WORKDIR /app
RUN apt update && \
apt install -y --no-install-recommends libpq5 && \
apt clean && \
rm -rf /var/lib/apt/lists/*
env VIRTUAL_ENV=/app/.venv \
PATH="/app/.venv/bin:$PATH"
COPY --from=builder ${VIRTUAL_ENV} ${VIRTUAL_ENV}
COPY api /app/api
COPY config.sample.yaml .
COPY db_updater.py .
CMD ["./db_updater.py"]

View File

@@ -1,7 +1,7 @@
from __future__ import annotations from __future__ import annotations
from logging import getLogger from logging import getLogger
from typing import Iterable, Self, TYPE_CHECKING from typing import Self, Sequence, TYPE_CHECKING
from sqlalchemy import select from sqlalchemy import select
from sqlalchemy.exc import IntegrityError from sqlalchemy.exc import IntegrityError
@@ -17,21 +17,34 @@ class Base(DeclarativeBase):
db: Database | None = None db: Database | None = None
@classmethod @classmethod
async def add(cls, objs: Self | Iterable[Self]) -> bool: async def add(cls, objs: Sequence[Self]) -> bool:
if cls.db is not None and (session := await cls.db.get_session()) is not None:
try:
async with session.begin():
session.add_all(objs)
except IntegrityError as err:
logger.warning(err)
return await cls.merge(objs)
except AttributeError as err:
logger.error(err)
return False
return True
@classmethod
async def merge(cls, objs: Sequence[Self]) -> bool:
if cls.db is not None and (session := await cls.db.get_session()) is not None: if cls.db is not None and (session := await cls.db.get_session()) is not None:
async with session.begin(): async with session.begin():
try: for obj in objs:
if isinstance(objs, Iterable): await session.merge(obj)
session.add_all(objs)
else:
session.add(objs)
except (AttributeError, IntegrityError) as err: return True
logger.error(err)
return False
return True return False
@classmethod @classmethod
async def get_by_id(cls, id_: int | str) -> Self | None: async def get_by_id(cls, id_: int | str) -> Self | None:

View File

@@ -61,9 +61,6 @@ class Database:
while not ret: while not ret:
try: try:
async with self._async_engine.begin() as session: async with self._async_engine.begin() as session:
await session.execute(
text("CREATE EXTENSION IF NOT EXISTS pg_trgm;")
)
if clear_static_data: if clear_static_data:
await session.run_sync(Base.metadata.drop_all) await session.run_sync(Base.metadata.drop_all)
await session.run_sync(Base.metadata.create_all) await session.run_sync(Base.metadata.create_all)

View File

@@ -4,9 +4,9 @@ from fastapi_cache.backends.redis import RedisBackend
from redis import asyncio as aioredis from redis import asyncio as aioredis
from yaml import safe_load from yaml import safe_load
from backend.db import db from .db import db
from backend.idfm_interface import IdfmInterface from .idfm_interface.idfm_interface import IdfmInterface
from backend.settings import CacheSettings, Settings from .settings import CacheSettings, Settings
CONFIG_PATH = environ.get("CONFIG_PATH", "./config.sample.yaml") CONFIG_PATH = environ.get("CONFIG_PATH", "./config.sample.yaml")

View File

@@ -1,5 +1,3 @@
from .idfm_interface import IdfmInterface
from .idfm_types import ( from .idfm_types import (
Coordinate, Coordinate,
Destinations, Destinations,
@@ -38,7 +36,6 @@ __all__ = [
"Coordinate", "Coordinate",
"Destinations", "Destinations",
"FramedVehicleJourney", "FramedVehicleJourney",
"IdfmInterface",
"IdfmLineState", "IdfmLineState",
"IdfmOperator", "IdfmOperator",
"IdfmResponse", "IdfmResponse",

View File

@@ -0,0 +1,115 @@
from collections import defaultdict
from re import compile as re_compile
from typing import ByteString
from aiofiles import open as async_open
from aiohttp import ClientSession
from msgspec import ValidationError
from msgspec.json import Decoder
from .idfm_types import Destinations as IdfmDestinations, IdfmResponse, IdfmState
from ..db import Database
from ..models import Line, Stop, StopArea
class IdfmInterface:
IDFM_ROOT_URL = "https://prim.iledefrance-mobilites.fr/marketplace"
IDFM_STOP_MON_URL = f"{IDFM_ROOT_URL}/stop-monitoring"
OPERATOR_RE = re_compile(r"[^:]+:Operator::([^:]+):")
LINE_RE = re_compile(r"[^:]+:Line::C([^:]+):")
def __init__(self, api_key: str, database: Database) -> None:
self._api_key = api_key
self._database = database
self._http_headers = {"Accept": "application/json", "apikey": self._api_key}
self._response_json_decoder = Decoder(type=IdfmResponse)
async def startup(self) -> None:
...
@staticmethod
def _format_line_id(line_id: str) -> int:
return int(line_id[1:] if line_id[0] == "C" else line_id)
async def render_line_picto(self, line: Line) -> tuple[None | str, None | str]:
line_picto_path = line_picto_format = None
target = f"/tmp/{line.id}_repr"
picto = line.picto
if picto is not None:
if (picto_data := await self._get_line_picto(line)) is not None:
async with async_open(target, "wb") as fd:
await fd.write(bytes(picto_data))
line_picto_path = target
line_picto_format = picto.mime_type
return (line_picto_path, line_picto_format)
async def _get_line_picto(self, line: Line) -> ByteString | None:
data = None
picto = line.picto
if picto is not None and picto.url is not None:
headers = (
self._http_headers if picto.url.startswith(self.IDFM_ROOT_URL) else None
)
async with ClientSession(headers=headers) as session:
async with session.get(picto.url) as response:
data = await response.read()
return data
async def get_next_passages(self, stop_point_id: int) -> IdfmResponse | None:
ret = None
params = {"MonitoringRef": f"STIF:StopPoint:Q:{stop_point_id}:"}
async with ClientSession(headers=self._http_headers) as session:
async with session.get(self.IDFM_STOP_MON_URL, params=params) as response:
if response.status == 200:
data = await response.read()
try:
ret = self._response_json_decoder.decode(data)
except ValidationError as err:
print(err)
return ret
async def get_destinations(self, stop_id: int) -> IdfmDestinations | None:
destinations: IdfmDestinations = defaultdict(set)
if (stop := await Stop.get_by_id(stop_id)) is not None:
expected_stop_ids = {stop.id}
elif (stop_area := await StopArea.get_by_id(stop_id)) is not None:
expected_stop_ids = {stop.id for stop in stop_area.stops}
else:
return None
if (res := await self.get_next_passages(stop_id)) is not None:
for delivery in res.Siri.ServiceDelivery.StopMonitoringDelivery:
if delivery.Status == IdfmState.true:
for stop_visit in delivery.MonitoredStopVisit:
monitoring_ref = stop_visit.MonitoringRef.value
try:
monitored_stop_id = int(monitoring_ref.split(":")[-2])
except (IndexError, ValueError):
print(f"Unable to get stop id from {monitoring_ref}")
continue
journey = stop_visit.MonitoredVehicleJourney
if (
dst_names := journey.DestinationName
) and monitored_stop_id in expected_stop_ids:
raw_line_id = journey.LineRef.value.split(":")[-2]
line_id = IdfmInterface._format_line_id(raw_line_id)
destinations[line_id].add(dst_names[0].value)
return destinations

View File

@@ -116,19 +116,26 @@ class StopArea(Struct):
record_timestamp: datetime record_timestamp: datetime
class ConnectionArea(Struct): class ConnectionAreaFields(Struct, kw_only=True):
zdcid: str zdcid: str
zdcversion: str zdcversion: str
zdccreated: datetime zdccreated: datetime
zdcchanged: datetime zdcchanged: datetime
zdcname: str zdcname: str
zdcxepsg2154: int zdcxepsg2154: int | None = None
zdcyepsg2154: int zdcyepsg2154: int | None = None
zdctown: str zdctown: str
zdcpostalregion: str zdcpostalregion: str
zdctype: StopAreaType zdctype: StopAreaType
class ConnectionArea(Struct):
datasetid: str
recordid: str
fields: ConnectionAreaFields
record_timestamp: datetime
class StopAreaStopAssociationFields(Struct, kw_only=True): class StopAreaStopAssociationFields(Struct, kw_only=True):
arrid: str # TODO: use int ? arrid: str # TODO: use int ?
artid: str | None = None artid: str | None = None
@@ -149,6 +156,7 @@ class StopAreaStopAssociation(Struct):
class IdfmLineState(Enum): class IdfmLineState(Enum):
active = "active" active = "active"
available_soon = "prochainement active"
class LinePicto(Struct, rename={"id_": "id"}): class LinePicto(Struct, rename={"id_": "id"}):

View File

@@ -0,0 +1,15 @@
from msgspec import Struct
class PictoFieldsFile(Struct, rename={"id_": "id"}):
id_: str
height: int
width: int
filename: str
thumbnail: bool
format: str
class Picto(Struct):
indices_commerciaux: str
noms_des_fichiers: PictoFieldsFile | None = None

View File

@@ -1,8 +1,8 @@
from fastapi import APIRouter, HTTPException from fastapi import APIRouter, HTTPException
from fastapi_cache.decorator import cache from fastapi_cache.decorator import cache
from backend.models import Line from ..models import Line
from backend.schemas import Line as LineSchema, TransportMode from ..schemas import Line as LineSchema, TransportMode
router = APIRouter(prefix="/line", tags=["line"]) router = APIRouter(prefix="/line", tags=["line"])

View File

@@ -5,20 +5,16 @@ from typing import Sequence
from fastapi import APIRouter, HTTPException from fastapi import APIRouter, HTTPException
from fastapi_cache.decorator import cache from fastapi_cache.decorator import cache
from backend.idfm_interface import ( from ..idfm_interface import Destinations as IdfmDestinations, TrainStatus
Destinations as IdfmDestinations, from ..models import Stop, StopArea, StopShape
IdfmInterface, from ..schemas import (
TrainStatus,
)
from backend.models import Stop, StopArea, StopShape
from backend.schemas import (
NextPassage as NextPassageSchema, NextPassage as NextPassageSchema,
NextPassages as NextPassagesSchema, NextPassages as NextPassagesSchema,
Stop as StopSchema, Stop as StopSchema,
StopArea as StopAreaSchema, StopArea as StopAreaSchema,
StopShape as StopShapeSchema, StopShape as StopShapeSchema,
) )
from dependencies import idfm_interface from ..dependencies import idfm_interface
router = APIRouter(prefix="/stop", tags=["stop"]) router = APIRouter(prefix="/stop", tags=["stop"])
@@ -106,7 +102,7 @@ async def get_next_passages(stop_id: int) -> NextPassagesSchema | None:
# re.match will return None if the given journey.LineRef.value is not valid. # re.match will return None if the given journey.LineRef.value is not valid.
try: try:
line_id_match = IdfmInterface.LINE_RE.match(journey.LineRef.value) line_id_match = idfm_interface.LINE_RE.match(journey.LineRef.value)
line_id = int(line_id_match.group(1)) # type: ignore line_id = int(line_id_match.group(1)) # type: ignore
except (AttributeError, TypeError, ValueError) as err: except (AttributeError, TypeError, ValueError) as err:
raise HTTPException( raise HTTPException(
@@ -163,32 +159,18 @@ async def get_stop_destinations(
@router.get("/{stop_id}/shape") @router.get("/{stop_id}/shape")
@cache(namespace="stop-shape") @cache(namespace="stop-shape")
async def get_stop_shape(stop_id: int) -> StopShapeSchema | None: async def get_stop_shape(stop_id: int) -> StopShapeSchema | None:
connection_area = None if (await Stop.get_by_id(stop_id)) is not None or (
await StopArea.get_by_id(stop_id)
) is not None:
shape_id = stop_id
if (stop := await Stop.get_by_id(stop_id)) is not None: if (shape := await StopShape.get_by_id(shape_id)) is not None:
connection_area = stop.connection_area return StopShapeSchema(
id=shape.id,
elif (stop_area := await StopArea.get_by_id(stop_id)) is not None: type=shape.type,
connection_areas = {stop.connection_area for stop in stop_area.stops} epsg3857_bbox=shape.epsg3857_bbox,
connection_areas_len = len(connection_areas) epsg3857_points=shape.epsg3857_points,
if connection_areas_len == 1: )
connection_area = connection_areas.pop()
else:
prefix = "More than one" if connection_areas_len else "No"
msg = f"{prefix} connection area has been found for stop area #{stop_id}"
raise HTTPException(status_code=500, detail=msg)
if (
connection_area is not None
and (shape := await StopShape.get_by_id(connection_area.id)) is not None
):
return StopShapeSchema(
id=shape.id,
type=shape.type,
epsg3857_bbox=shape.epsg3857_bbox,
epsg3857_points=shape.epsg3857_points,
)
msg = f"No shape found for stop {stop_id}" msg = f"No shape found for stop {stop_id}"
raise HTTPException(status_code=404, detail=msg) raise HTTPException(status_code=404, detail=msg)

View File

@@ -53,8 +53,8 @@ class Line(BaseModel):
transportMode: TransportMode transportMode: TransportMode
backColorHexa: str backColorHexa: str
foreColorHexa: str foreColorHexa: str
operatorId: str operatorId: int
accessibility: IdfmState accessibility: IdfmState
visualSignsAvailable: IdfmState visualSignsAvailable: IdfmState
audibleSignsAvailable: IdfmState audibleSignsAvailable: IdfmState
stopIds: list[str] stopIds: list[int]

View File

@@ -9,7 +9,7 @@ class Stop(BaseModel):
town: str town: str
epsg3857_x: float epsg3857_x: float
epsg3857_y: float epsg3857_y: float
lines: list[str] lines: list[int]
class StopArea(BaseModel): class StopArea(BaseModel):
@@ -17,7 +17,7 @@ class StopArea(BaseModel):
name: str name: str
town: str town: str
type: StopAreaType type: StopAreaType
lines: list[str] # SNCF lines are linked to stop areas and not stops. lines: list[int] # SNCF lines are linked to stop areas and not stops.
stops: list[Stop] stops: list[Stop]

74
backend/api/settings.py Normal file
View File

@@ -0,0 +1,74 @@
from __future__ import annotations
from typing import Annotated
from pydantic import BaseModel, SecretStr
from pydantic.functional_validators import model_validator
from pydantic_settings import (
BaseSettings,
PydanticBaseSettingsSource,
SettingsConfigDict,
)
class HttpSettings(BaseModel):
host: str = "127.0.0.1"
port: int = 8080
cert: str | None = None
class DatabaseSettings(BaseModel):
name: str
host: str
port: int
driver: str = "postgresql+psycopg"
user: str
password: Annotated[SecretStr, check_user_password]
class CacheSettings(BaseModel):
enable: bool = False
host: str = "127.0.0.1"
port: int = 6379
user: str | None = None
password: Annotated[SecretStr | None, check_user_password] = None
@model_validator(mode="after")
def check_user_password(self) -> DatabaseSettings | CacheSettings:
if self.user is not None and self.password is None:
raise ValueError("user is set, password shall be set too.")
if self.password is not None and self.user is None:
raise ValueError("password is set, user shall be set too.")
return self
class TracingSettings(BaseModel):
enable: bool = False
class Settings(BaseSettings):
app_name: str
idfm_api_key: SecretStr
clear_static_data: bool
http: HttpSettings
db: DatabaseSettings
cache: CacheSettings
tracing: TracingSettings
model_config = SettingsConfigDict(env_prefix="CER__", env_nested_delimiter="__")
@classmethod
def settings_customise_sources(
cls,
settings_cls: type[BaseSettings],
init_settings: PydanticBaseSettingsSource,
env_settings: PydanticBaseSettingsSource,
dotenv_settings: PydanticBaseSettingsSource,
file_secret_settings: PydanticBaseSettingsSource,
) -> tuple[PydanticBaseSettingsSource, ...]:
return env_settings, init_settings, file_secret_settings

View File

@@ -2,9 +2,8 @@
import uvicorn import uvicorn
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from fastapi import FastAPI from fastapi import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
from fastapi_cache import FastAPICache from fastapi_cache import FastAPICache
from opentelemetry import trace from opentelemetry import trace
from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter
@@ -13,9 +12,9 @@ from opentelemetry.sdk.resources import Resource, SERVICE_NAME
from opentelemetry.sdk.trace import TracerProvider from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import BatchSpanProcessor from opentelemetry.sdk.trace.export import BatchSpanProcessor
from backend.db import db from api.db import db
from dependencies import idfm_interface, redis_backend, settings from api.dependencies import idfm_interface, redis_backend, settings
from routers import line, stop from api.routers import line, stop
@asynccontextmanager @asynccontextmanager
@@ -35,13 +34,28 @@ app = FastAPI(lifespan=lifespan)
app.add_middleware( app.add_middleware(
CORSMiddleware, CORSMiddleware,
allow_origins=["https://localhost:4443", "https://localhost:3000"], allow_origins=["http://carrramba.adrien.run", "https://carrramba.adrien.run"],
allow_credentials=True, allow_credentials=True,
allow_methods=["*"], allow_methods=["OPTIONS", "GET"],
allow_headers=["*"], allow_headers=["*"],
) )
app.mount("/widget", StaticFiles(directory="../frontend/", html=True), name="widget") # The cache-control header entry is not managed properly by fastapi-cache:
# For now, a request with a cache-control set to no-cache
# is interpreted as disabling the use of the server cache.
# Cf. Improve Cache-Control header parsing and handling
# https://github.com/long2ice/fastapi-cache/issues/144 workaround
@app.middleware("http")
async def fastapi_cache_issue_144_workaround(request: Request, call_next):
entries = request.headers.__dict__["_list"]
new_entries = [
entry for entry in entries if entry[0].decode().lower() != "cache-control"
]
request.headers.__dict__["_list"] = new_entries
return await call_next(request)
app.include_router(line.router) app.include_router(line.router)
app.include_router(stop.router) app.include_router(stop.router)

View File

@@ -1,611 +0,0 @@
from collections import defaultdict
from logging import getLogger
from re import compile as re_compile
from time import time
from typing import (
AsyncIterator,
ByteString,
Callable,
Iterable,
List,
Type,
)
from aiofiles import open as async_open
from aiohttp import ClientSession
from msgspec import ValidationError
from msgspec.json import Decoder
from pyproj import Transformer
from shapefile import Reader as ShapeFileReader, ShapeRecord # type: ignore
from ..db import Database
from ..models import ConnectionArea, Line, LinePicto, Stop, StopArea, StopShape
from .idfm_types import (
ConnectionArea as IdfmConnectionArea,
Destinations as IdfmDestinations,
IdfmLineState,
IdfmResponse,
Line as IdfmLine,
LinePicto as IdfmPicto,
IdfmState,
Stop as IdfmStop,
StopArea as IdfmStopArea,
StopAreaStopAssociation,
StopAreaType,
StopLineAsso as IdfmStopLineAsso,
TransportMode,
)
from .ratp_types import Picto as RatpPicto
logger = getLogger(__name__)
class IdfmInterface:
IDFM_ROOT_URL = "https://prim.iledefrance-mobilites.fr/marketplace"
IDFM_STOP_MON_URL = f"{IDFM_ROOT_URL}/stop-monitoring"
IDFM_ROOT_URL = "https://data.iledefrance-mobilites.fr/explore/dataset"
IDFM_STOPS_URL = (
f"{IDFM_ROOT_URL}/arrets/download/?format=json&timezone=Europe/Berlin"
)
IDFM_PICTO_URL = f"{IDFM_ROOT_URL}/referentiel-des-lignes/files"
RATP_ROOT_URL = "https://data.ratp.fr/explore/dataset"
RATP_PICTO_URL = (
f"{RATP_ROOT_URL}"
"/pictogrammes-des-lignes-de-metro-rer-tramway-bus-et-noctilien/files"
)
OPERATOR_RE = re_compile(r"[^:]+:Operator::([^:]+):")
LINE_RE = re_compile(r"[^:]+:Line::C([^:]+):")
def __init__(self, api_key: str, database: Database) -> None:
self._api_key = api_key
self._database = database
self._http_headers = {"Accept": "application/json", "apikey": self._api_key}
self._epsg2154_epsg3857_transformer = Transformer.from_crs(2154, 3857)
self._json_stops_decoder = Decoder(type=List[IdfmStop])
self._json_stop_areas_decoder = Decoder(type=List[IdfmStopArea])
self._json_connection_areas_decoder = Decoder(type=List[IdfmConnectionArea])
self._json_lines_decoder = Decoder(type=List[IdfmLine])
self._json_stops_lines_assos_decoder = Decoder(type=List[IdfmStopLineAsso])
self._json_ratp_pictos_decoder = Decoder(type=List[RatpPicto])
self._json_stop_area_stop_asso_decoder = Decoder(
type=List[StopAreaStopAssociation]
)
self._response_json_decoder = Decoder(type=IdfmResponse)
async def startup(self) -> None:
BATCH_SIZE = 10000
STEPS: tuple[
tuple[
Type[ConnectionArea] | Type[Stop] | Type[StopArea] | Type[StopShape],
Callable,
Callable,
],
...,
] = (
(
StopShape,
self._request_stop_shapes,
self._format_idfm_stop_shapes,
),
(
ConnectionArea,
self._request_idfm_connection_areas,
self._format_idfm_connection_areas,
),
(
StopArea,
self._request_idfm_stop_areas,
self._format_idfm_stop_areas,
),
(Stop, self._request_idfm_stops, self._format_idfm_stops),
)
for model, get_method, format_method in STEPS:
step_begin_ts = time()
elements = []
async for element in get_method():
elements.append(element)
if len(elements) == BATCH_SIZE:
await model.add(format_method(*elements))
elements.clear()
if elements:
await model.add(format_method(*elements))
print(f"Add {model.__name__}s: {time() - step_begin_ts}s")
begin_ts = time()
await self._load_lines()
print(f"Add Lines and IDFM LinePictos: {time() - begin_ts}s")
begin_ts = time()
await self._load_ratp_pictos(30)
print(f"Add RATP LinePictos: {time() - begin_ts}s")
begin_ts = time()
await self._load_lines_stops_assos()
print(f"Link Stops to Lines: {time() - begin_ts}s")
begin_ts = time()
await self._load_stop_assos()
print(f"Link Stops to StopAreas: {time() - begin_ts}s")
async def _load_lines(self, batch_size: int = 5000) -> None:
lines, pictos = [], []
picto_ids = set()
async for line in self._request_idfm_lines():
if (picto := line.fields.picto) is not None and picto.id_ not in picto_ids:
picto_ids.add(picto.id_)
pictos.append(picto)
lines.append(line)
if len(lines) == batch_size:
await LinePicto.add(IdfmInterface._format_idfm_pictos(*pictos))
await Line.add(await self._format_idfm_lines(*lines))
lines.clear()
pictos.clear()
if pictos:
await LinePicto.add(IdfmInterface._format_idfm_pictos(*pictos))
if lines:
await Line.add(await self._format_idfm_lines(*lines))
async def _load_ratp_pictos(self, batch_size: int = 5) -> None:
pictos = []
async for picto in self._request_ratp_pictos():
pictos.append(picto)
if len(pictos) == batch_size:
formatted_pictos = IdfmInterface._format_ratp_pictos(*pictos)
await LinePicto.add(map(lambda picto: picto[1], formatted_pictos))
await Line.add_pictos(formatted_pictos)
pictos.clear()
if pictos:
formatted_pictos = IdfmInterface._format_ratp_pictos(*pictos)
await LinePicto.add(map(lambda picto: picto[1], formatted_pictos))
await Line.add_pictos(formatted_pictos)
async def _load_lines_stops_assos(self, batch_size: int = 5000) -> None:
total_assos_nb = total_found_nb = 0
assos = []
async for asso in self._request_idfm_stops_lines_associations():
fields = asso.fields
try:
stop_id = int(fields.stop_id.rsplit(":", 1)[-1])
except ValueError as err:
print(err)
print(f"{fields.stop_id = }")
continue
assos.append((fields.route_long_name, fields.operatorname, stop_id))
if len(assos) == batch_size:
total_assos_nb += batch_size
total_found_nb += await Line.add_stops(assos)
assos.clear()
if assos:
total_assos_nb += len(assos)
total_found_nb += await Line.add_stops(assos)
print(f"{total_found_nb} line <-> stop ({total_assos_nb = } found)")
async def _load_stop_assos(self, batch_size: int = 5000) -> None:
total_assos_nb = area_stop_assos_nb = conn_stop_assos_nb = 0
area_stop_assos = []
connection_stop_assos = []
async for asso in self._request_idfm_stop_area_stop_associations():
fields = asso.fields
stop_id = int(fields.arrid)
area_stop_assos.append((int(fields.zdaid), stop_id))
connection_stop_assos.append((int(fields.zdcid), stop_id))
if len(area_stop_assos) == batch_size:
total_assos_nb += batch_size
if (found_nb := await StopArea.add_stops(area_stop_assos)) is not None:
area_stop_assos_nb += found_nb
area_stop_assos.clear()
if (
found_nb := await ConnectionArea.add_stops(connection_stop_assos)
) is not None:
conn_stop_assos_nb += found_nb
connection_stop_assos.clear()
if area_stop_assos:
total_assos_nb += len(area_stop_assos)
if (found_nb := await StopArea.add_stops(area_stop_assos)) is not None:
area_stop_assos_nb += found_nb
if (
found_nb := await ConnectionArea.add_stops(connection_stop_assos)
) is not None:
conn_stop_assos_nb += found_nb
print(f"{area_stop_assos_nb} stop area <-> stop ({total_assos_nb = } found)")
print(f"{conn_stop_assos_nb} stop area <-> stop ({total_assos_nb = } found)")
# TODO: This method is synchronous due to the shapefile library.
# It's not a blocking issue but it could be nice to find an alternative.
async def _request_stop_shapes(self) -> AsyncIterator[ShapeRecord]:
# TODO: Use HTTP
with ShapeFileReader("./tests/datasets/REF_LDA.zip") as reader:
for record in reader.shapeRecords():
yield record
async def _request_idfm_stops(self) -> AsyncIterator[IdfmStop]:
# headers = {"Accept": "application/json", "apikey": self._api_key}
# async with ClientSession(headers=headers) as session:
# async with session.get(self.STOPS_URL) as response:
# # print("Status:", response.status)
# if response.status == 200:
# for point in self._json_stops_decoder.decode(await response.read()):
# yield point
# TODO: Use HTTP
async with async_open("./tests/datasets/stops_dataset.json", "rb") as raw:
for element in self._json_stops_decoder.decode(await raw.read()):
yield element
async def _request_idfm_stop_areas(self) -> AsyncIterator[IdfmStopArea]:
# TODO: Use HTTP
async with async_open("./tests/datasets/zones-d-arrets.json", "rb") as raw:
for element in self._json_stop_areas_decoder.decode(await raw.read()):
yield element
async def _request_idfm_connection_areas(self) -> AsyncIterator[IdfmConnectionArea]:
async with async_open(
"./tests/datasets/zones-de-correspondance.json", "rb"
) as raw:
for element in self._json_connection_areas_decoder.decode(await raw.read()):
yield element
async def _request_idfm_lines(self) -> AsyncIterator[IdfmLine]:
# TODO: Use HTTP
async with async_open("./tests/datasets/lines_dataset.json", "rb") as raw:
for element in self._json_lines_decoder.decode(await raw.read()):
yield element
async def _request_idfm_stops_lines_associations(
self,
) -> AsyncIterator[IdfmStopLineAsso]:
# TODO: Use HTTP
async with async_open("./tests/datasets/arrets-lignes.json", "rb") as raw:
for element in self._json_stops_lines_assos_decoder.decode(
await raw.read()
):
yield element
async def _request_idfm_stop_area_stop_associations(
self,
) -> AsyncIterator[StopAreaStopAssociation]:
# TODO: Use HTTP
async with async_open("./tests/datasets/relations.json", "rb") as raw:
for element in self._json_stop_area_stop_asso_decoder.decode(
await raw.read()
):
yield element
async def _request_ratp_pictos(self) -> AsyncIterator[RatpPicto]:
# TODO: Use HTTP
async with async_open(
"./tests/datasets/pictogrammes-des-lignes-de-metro-rer-tramway-bus-et-noctilien.json",
"rb",
) as fd:
for element in self._json_ratp_pictos_decoder.decode(await fd.read()):
yield element
@classmethod
def _format_idfm_pictos(cls, *pictos: IdfmPicto) -> Iterable[LinePicto]:
ret = []
for picto in pictos:
ret.append(
LinePicto(
id=picto.id_,
mime_type=picto.mimetype,
height_px=picto.height,
width_px=picto.width,
filename=picto.filename,
url=f"{cls.IDFM_PICTO_URL}/{picto.id_}/download",
thumbnail=picto.thumbnail,
format=picto.format,
)
)
return ret
@classmethod
def _format_ratp_pictos(cls, *pictos: RatpPicto) -> Iterable[tuple[str, LinePicto]]:
ret = []
for picto in pictos:
if (fields := picto.fields.noms_des_fichiers) is not None:
ret.append(
(
picto.fields.indices_commerciaux,
LinePicto(
id=fields.id_,
mime_type=f"image/{fields.format.lower()}",
height_px=fields.height,
width_px=fields.width,
filename=fields.filename,
url=f"{cls.RATP_PICTO_URL}/{fields.id_}/download",
thumbnail=fields.thumbnail,
format=fields.format,
),
)
)
return ret
async def _format_idfm_lines(self, *lines: IdfmLine) -> Iterable[Line]:
ret = []
optional_value = IdfmLine.optional_value
for line in lines:
fields = line.fields
picto_id = fields.picto.id_ if fields.picto is not None else None
line_id = fields.id_line
try:
formatted_line_id = int(line_id[1:] if line_id[0] == "C" else line_id)
except ValueError:
logger.warning("Unable to format %s line id.", line_id)
continue
try:
operator_id = int(fields.operatorref) # type: ignore
except (ValueError, TypeError):
logger.warning("Unable to format %s operator id.", fields.operatorref)
operator_id = 0
ret.append(
Line(
id=formatted_line_id,
short_name=fields.shortname_line,
name=fields.name_line,
status=IdfmLineState(fields.status.value),
transport_mode=TransportMode(fields.transportmode.value),
transport_submode=optional_value(fields.transportsubmode),
network_name=optional_value(fields.networkname),
group_of_lines_id=optional_value(fields.id_groupoflines),
group_of_lines_shortname=optional_value(
fields.shortname_groupoflines
),
colour_web_hexa=fields.colourweb_hexa,
text_colour_hexa=fields.textcolourprint_hexa,
operator_id=operator_id,
operator_name=optional_value(fields.operatorname),
accessibility=IdfmState(fields.accessibility.value),
visual_signs_available=IdfmState(
fields.visualsigns_available.value
),
audible_signs_available=IdfmState(
fields.audiblesigns_available.value
),
picto_id=fields.picto.id_ if fields.picto is not None else None,
record_id=line.recordid,
record_ts=int(line.record_timestamp.timestamp()),
)
)
return ret
def _format_idfm_stops(self, *stops: IdfmStop) -> Iterable[Stop]:
for stop in stops:
fields = stop.fields
try:
created_ts = int(fields.arrcreated.timestamp()) # type: ignore
except AttributeError:
created_ts = None
epsg3857_point = self._epsg2154_epsg3857_transformer.transform(
fields.arrxepsg2154, fields.arryepsg2154
)
try:
postal_region = int(fields.arrpostalregion)
except ValueError:
logger.warning(
"Unable to format %s postal region.", fields.arrpostalregion
)
continue
yield Stop(
id=int(fields.arrid),
name=fields.arrname,
epsg3857_x=epsg3857_point[0],
epsg3857_y=epsg3857_point[1],
town_name=fields.arrtown,
postal_region=postal_region,
transport_mode=TransportMode(fields.arrtype.value),
version=fields.arrversion,
created_ts=created_ts,
changed_ts=int(fields.arrchanged.timestamp()),
accessibility=IdfmState(fields.arraccessibility.value),
visual_signs_available=IdfmState(fields.arrvisualsigns.value),
audible_signs_available=IdfmState(fields.arraudiblesignals.value),
record_id=stop.recordid,
record_ts=int(stop.record_timestamp.timestamp()),
)
def _format_idfm_stop_areas(self, *stop_areas: IdfmStopArea) -> Iterable[StopArea]:
for stop_area in stop_areas:
fields = stop_area.fields
try:
created_ts = int(fields.zdacreated.timestamp()) # type: ignore
except AttributeError:
created_ts = None
epsg3857_point = self._epsg2154_epsg3857_transformer.transform(
fields.zdaxepsg2154, fields.zdayepsg2154
)
yield StopArea(
id=int(fields.zdaid),
name=fields.zdaname,
town_name=fields.zdatown,
postal_region=fields.zdapostalregion,
epsg3857_x=epsg3857_point[0],
epsg3857_y=epsg3857_point[1],
type=StopAreaType(fields.zdatype.value),
version=fields.zdaversion,
created_ts=created_ts,
changed_ts=int(fields.zdachanged.timestamp()),
)
def _format_idfm_connection_areas(
self,
*connection_areas: IdfmConnectionArea,
) -> Iterable[ConnectionArea]:
for connection_area in connection_areas:
epsg3857_point = self._epsg2154_epsg3857_transformer.transform(
connection_area.zdcxepsg2154, connection_area.zdcyepsg2154
)
yield ConnectionArea(
id=int(connection_area.zdcid),
name=connection_area.zdcname,
town_name=connection_area.zdctown,
postal_region=connection_area.zdcpostalregion,
epsg3857_x=epsg3857_point[0],
epsg3857_y=epsg3857_point[1],
transport_mode=StopAreaType(connection_area.zdctype.value),
version=connection_area.zdcversion,
created_ts=int(connection_area.zdccreated.timestamp()),
changed_ts=int(connection_area.zdcchanged.timestamp()),
)
def _format_idfm_stop_shapes(
self, *shape_records: ShapeRecord
) -> Iterable[StopShape]:
for shape_record in shape_records:
epsg3857_points = [
self._epsg2154_epsg3857_transformer.transform(*point)
for point in shape_record.shape.points
]
bbox_it = iter(shape_record.shape.bbox)
epsg3857_bbox = [
self._epsg2154_epsg3857_transformer.transform(*point)
for point in zip(bbox_it, bbox_it)
]
yield StopShape(
id=shape_record.record[1],
type=shape_record.shape.shapeType,
epsg3857_bbox=epsg3857_bbox,
epsg3857_points=epsg3857_points,
)
async def render_line_picto(self, line: Line) -> tuple[None | str, None | str]:
begin_ts = time()
line_picto_path = line_picto_format = None
target = f"/tmp/{line.id}_repr"
picto = line.picto
if picto is not None:
if (picto_data := await self._get_line_picto(line)) is not None:
async with async_open(target, "wb") as fd:
await fd.write(bytes(picto_data))
line_picto_path = target
line_picto_format = picto.mime_type
print(f"render_line_picto: {time() - begin_ts}")
return (line_picto_path, line_picto_format)
async def _get_line_picto(self, line: Line) -> ByteString | None:
print("---------------------------------------------------------------------")
begin_ts = time()
data = None
picto = line.picto
if picto is not None and picto.url is not None:
headers = (
self._http_headers if picto.url.startswith(self.IDFM_ROOT_URL) else None
)
session_begin_ts = time()
async with ClientSession(headers=headers) as session:
session_creation_ts = time()
print(f"Session creation {session_creation_ts - session_begin_ts}")
async with session.get(picto.url) as response:
get_end_ts = time()
print(f"GET {get_end_ts - session_creation_ts}")
data = await response.read()
print(f"read {time() - get_end_ts}")
print(f"render_line_picto: {time() - begin_ts}")
print("---------------------------------------------------------------------")
return data
async def get_next_passages(self, stop_point_id: int) -> IdfmResponse | None:
ret = None
params = {"MonitoringRef": f"STIF:StopPoint:Q:{stop_point_id}:"}
async with ClientSession(headers=self._http_headers) as session:
async with session.get(self.IDFM_STOP_MON_URL, params=params) as response:
if response.status == 200:
data = await response.read()
try:
ret = self._response_json_decoder.decode(data)
except ValidationError as err:
print(err)
return ret
async def get_destinations(self, stop_id: int) -> IdfmDestinations | None:
begin_ts = time()
destinations: IdfmDestinations = defaultdict(set)
if (stop := await Stop.get_by_id(stop_id)) is not None:
expected_stop_ids = {stop.id}
elif (stop_area := await StopArea.get_by_id(stop_id)) is not None:
expected_stop_ids = {stop.id for stop in stop_area.stops}
else:
return None
if (res := await self.get_next_passages(stop_id)) is not None:
for delivery in res.Siri.ServiceDelivery.StopMonitoringDelivery:
if delivery.Status == IdfmState.true:
for stop_visit in delivery.MonitoredStopVisit:
monitoring_ref = stop_visit.MonitoringRef.value
try:
monitored_stop_id = int(monitoring_ref.split(":")[-2])
except (IndexError, ValueError):
print(f"Unable to get stop id from {monitoring_ref}")
continue
journey = stop_visit.MonitoredVehicleJourney
if (
dst_names := journey.DestinationName
) and monitored_stop_id in expected_stop_ids:
line_id = journey.LineRef.value.split(":")[-2]
destinations[line_id].add(dst_names[0].value)
print(f"get_next_passages: {time() - begin_ts}")
return destinations

View File

@@ -1,25 +0,0 @@
from datetime import datetime
from typing import Optional
from msgspec import Struct
class PictoFieldsFile(Struct, rename={"id_": "id"}):
id_: str
height: int
width: int
filename: str
thumbnail: bool
format: str
class PictoFields(Struct):
indices_commerciaux: str
noms_des_fichiers: Optional[PictoFieldsFile] = None
class Picto(Struct):
datasetid: str
recordid: str
fields: PictoFields
record_timestamp: datetime

View File

@@ -1,59 +0,0 @@
from typing import Any
from pydantic import BaseModel, BaseSettings, Field, root_validator, SecretStr
class HttpSettings(BaseModel):
host: str = "127.0.0.1"
port: int = 8080
cert: str | None = None
def check_user_password(cls, values: dict[str, Any]) -> dict[str, Any]:
user = values.get("user")
password = values.get("password")
if user is not None and password is None:
raise ValueError("user is set, password shall be set too.")
if password is not None and user is None:
raise ValueError("password is set, user shall be set too.")
return values
class DatabaseSettings(BaseModel):
name: str = "carrramba-encore-rate"
host: str = "127.0.0.1"
port: int = 5432
driver: str = "postgresql+psycopg"
user: str | None = None
password: SecretStr | None = None
_user_password_validation = root_validator(allow_reuse=True)(check_user_password)
class CacheSettings(BaseModel):
enable: bool = False
host: str = "127.0.0.1"
port: int = 6379
user: str | None = None
password: SecretStr | None = None
_user_password_validation = root_validator(allow_reuse=True)(check_user_password)
class TracingSettings(BaseModel):
enable: bool = False
class Settings(BaseSettings):
app_name: str
idfm_api_key: SecretStr = Field(..., env="API_KEY")
clear_static_data: bool = Field(False, env="CLEAR_STATIC_DATA")
http: HttpSettings = HttpSettings()
db: DatabaseSettings = DatabaseSettings()
cache: CacheSettings = CacheSettings()
tracing: TracingSettings = TracingSettings()

View File

@@ -10,9 +10,7 @@ db:
name: carrramba-encore-rate name: carrramba-encore-rate
host: postgres host: postgres
port: 5432 port: 5432
driver: postgresql+psycopg
user: cer user: cer
password: cer_password
cache: cache:
enable: true enable: true

575
backend/db_updater.py Executable file
View File

@@ -0,0 +1,575 @@
#!/usr/bin/env python3
from asyncio import run, gather
from logging import getLogger, INFO, Handler as LoggingHandler, NOTSET
from itertools import islice
from time import time
from os import environ
from typing import Callable, Iterable, List, Type
from aiofiles.tempfile import NamedTemporaryFile
from aiohttp import ClientSession
from msgspec import ValidationError
from msgspec.json import Decoder
from pyproj import Transformer
from shapefile import Reader as ShapeFileReader, ShapeRecord # type: ignore
from tqdm import tqdm
from yaml import safe_load
from api.db import Base, db, Database
from api.models import ConnectionArea, Line, LinePicto, Stop, StopArea, StopShape
from api.idfm_interface.idfm_types import (
ConnectionArea as IdfmConnectionArea,
IdfmLineState,
Line as IdfmLine,
LinePicto as IdfmPicto,
IdfmState,
Stop as IdfmStop,
StopArea as IdfmStopArea,
StopAreaStopAssociation,
StopAreaType,
StopLineAsso as IdfmStopLineAsso,
TransportMode,
)
from api.idfm_interface.ratp_types import Picto as RatpPicto
from api.settings import Settings
CONFIG_PATH = environ.get("CONFIG_PATH", "./config.sample.yaml")
BATCH_SIZE = 1000
IDFM_ROOT_URL = "https://data.iledefrance-mobilites.fr/explore/dataset"
IDFM_CONNECTION_AREAS_URL = (
f"{IDFM_ROOT_URL}/zones-de-correspondance/download/?format=json"
)
IDFM_LINES_URL = f"{IDFM_ROOT_URL}/referentiel-des-lignes/download/?format=json"
IDFM_PICTO_URL = f"{IDFM_ROOT_URL}/referentiel-des-lignes/files"
IDFM_STOP_AREAS_URL = f"{IDFM_ROOT_URL}/zones-d-arrets/download/?format=json"
IDFM_STOP_SHAPES_URL = "https://eu.ftp.opendatasoft.com/stif/Reflex/REF_ArR.zip"
IDFM_STOP_AREA_SHAPES_URL = "https://eu.ftp.opendatasoft.com/stif/Reflex/REF_ZdA.zip"
IDFM_STOP_STOP_AREAS_ASSOS_URL = f"{IDFM_ROOT_URL}/relations/download/?format=json"
IDFM_STOPS_LINES_ASSOS_URL = f"{IDFM_ROOT_URL}/arrets-lignes/download/?format=json"
IDFM_STOPS_URL = f"{IDFM_ROOT_URL}/arrets/download/?format=json"
RATP_ROOT_URL = "https://data.ratp.fr/api/explore/v2.1/catalog/datasets"
RATP_PICTOS_URL = (
f"{RATP_ROOT_URL}"
"/pictogrammes-des-lignes-de-metro-rer-tramway-bus-et-noctilien/exports/json?lang=fr"
)
# From https://stackoverflow.com/a/38739634
class TqdmLoggingHandler(LoggingHandler):
def __init__(self, level=NOTSET):
super().__init__(level)
def emit(self, record):
try:
msg = self.format(record)
tqdm.write(msg)
self.flush()
except Exception:
self.handleError(record)
logger = getLogger(__name__)
logger.setLevel(INFO)
logger.addHandler(TqdmLoggingHandler())
epsg2154_epsg3857_transformer = Transformer.from_crs(2154, 3857)
json_stops_decoder = Decoder(type=List[IdfmStop])
json_stop_areas_decoder = Decoder(type=List[IdfmStopArea])
json_connection_areas_decoder = Decoder(type=List[IdfmConnectionArea])
json_lines_decoder = Decoder(type=List[IdfmLine])
json_stops_lines_assos_decoder = Decoder(type=List[IdfmStopLineAsso])
json_ratp_pictos_decoder = Decoder(type=List[RatpPicto])
json_stop_area_stop_asso_decoder = Decoder(type=List[StopAreaStopAssociation])
def format_idfm_pictos(*pictos: IdfmPicto) -> Iterable[LinePicto]:
ret = []
for picto in pictos:
ret.append(
LinePicto(
id=picto.id_,
mime_type=picto.mimetype,
height_px=picto.height,
width_px=picto.width,
filename=picto.filename,
url=f"{IDFM_PICTO_URL}/{picto.id_}/download",
thumbnail=picto.thumbnail,
format=picto.format,
)
)
return ret
def format_ratp_pictos(*pictos: RatpPicto) -> Iterable[tuple[str, LinePicto]]:
ret = []
for picto in pictos:
if (fields := picto.noms_des_fichiers) is not None:
ret.append(
(
picto.indices_commerciaux,
LinePicto(
id=fields.id_,
mime_type=f"image/{fields.format.lower()}",
height_px=fields.height,
width_px=fields.width,
filename=fields.filename,
url=f"{RATP_PICTOS_URL}/{fields.id_}/download",
thumbnail=fields.thumbnail,
format=fields.format,
),
)
)
return ret
def format_idfm_lines(*lines: IdfmLine) -> Iterable[Line]:
ret = []
optional_value = IdfmLine.optional_value
for line in lines:
fields = line.fields
line_id = fields.id_line
try:
formatted_line_id = int(line_id[1:] if line_id[0] == "C" else line_id)
except ValueError:
logger.warning("Unable to format %s line id.", line_id)
continue
try:
operator_id = int(fields.operatorref) # type: ignore
except (ValueError, TypeError):
logger.warning("Unable to format %s operator id.", fields.operatorref)
operator_id = 0
ret.append(
Line(
id=formatted_line_id,
short_name=fields.shortname_line,
name=fields.name_line,
status=IdfmLineState(fields.status.value),
transport_mode=TransportMode(fields.transportmode.value),
transport_submode=optional_value(fields.transportsubmode),
network_name=optional_value(fields.networkname),
group_of_lines_id=optional_value(fields.id_groupoflines),
group_of_lines_shortname=optional_value(fields.shortname_groupoflines),
colour_web_hexa=fields.colourweb_hexa,
text_colour_hexa=fields.textcolourprint_hexa,
operator_id=operator_id,
operator_name=optional_value(fields.operatorname),
accessibility=IdfmState(fields.accessibility.value),
visual_signs_available=IdfmState(fields.visualsigns_available.value),
audible_signs_available=IdfmState(fields.audiblesigns_available.value),
picto_id=fields.picto.id_ if fields.picto is not None else None,
record_id=line.recordid,
record_ts=int(line.record_timestamp.timestamp()),
)
)
return ret
def format_idfm_stops(*stops: IdfmStop) -> Iterable[Stop]:
for stop in stops:
fields = stop.fields
try:
created_ts = int(fields.arrcreated.timestamp()) # type: ignore
except AttributeError:
created_ts = None
epsg3857_point = epsg2154_epsg3857_transformer.transform(
fields.arrxepsg2154, fields.arryepsg2154
)
try:
postal_region = int(fields.arrpostalregion)
except ValueError:
logger.warning("Unable to format %s postal region.", fields.arrpostalregion)
continue
yield Stop(
id=int(fields.arrid),
name=fields.arrname,
epsg3857_x=epsg3857_point[0],
epsg3857_y=epsg3857_point[1],
town_name=fields.arrtown,
postal_region=postal_region,
transport_mode=TransportMode(fields.arrtype.value),
version=fields.arrversion,
created_ts=created_ts,
changed_ts=int(fields.arrchanged.timestamp()),
accessibility=IdfmState(fields.arraccessibility.value),
visual_signs_available=IdfmState(fields.arrvisualsigns.value),
audible_signs_available=IdfmState(fields.arraudiblesignals.value),
record_id=stop.recordid,
record_ts=int(stop.record_timestamp.timestamp()),
)
def format_idfm_stop_areas(*stop_areas: IdfmStopArea) -> Iterable[StopArea]:
for stop_area in stop_areas:
fields = stop_area.fields
try:
created_ts = int(fields.zdacreated.timestamp()) # type: ignore
except AttributeError:
created_ts = None
epsg3857_point = epsg2154_epsg3857_transformer.transform(
fields.zdaxepsg2154, fields.zdayepsg2154
)
yield StopArea(
id=int(fields.zdaid),
name=fields.zdaname,
town_name=fields.zdatown,
postal_region=fields.zdapostalregion,
epsg3857_x=epsg3857_point[0],
epsg3857_y=epsg3857_point[1],
type=StopAreaType(fields.zdatype.value),
version=fields.zdaversion,
created_ts=created_ts,
changed_ts=int(fields.zdachanged.timestamp()),
)
def format_idfm_connection_areas(
*connection_areas: IdfmConnectionArea,
) -> Iterable[ConnectionArea]:
for connection_area in connection_areas:
fields = connection_area.fields
epsg3857_point = epsg2154_epsg3857_transformer.transform(
fields.zdcxepsg2154, fields.zdcyepsg2154
)
yield ConnectionArea(
id=int(fields.zdcid),
name=fields.zdcname,
town_name=fields.zdctown,
postal_region=fields.zdcpostalregion,
epsg3857_x=epsg3857_point[0],
epsg3857_y=epsg3857_point[1],
transport_mode=StopAreaType(fields.zdctype.value),
version=fields.zdcversion,
created_ts=int(fields.zdccreated.timestamp()),
changed_ts=int(fields.zdcchanged.timestamp()),
)
def format_idfm_stop_shapes(*shape_records: ShapeRecord) -> Iterable[StopShape]:
for shape_record in shape_records:
epsg3857_points = [
epsg2154_epsg3857_transformer.transform(*point)
for point in shape_record.shape.points
]
try:
bbox_it = iter(shape_record.shape.bbox)
epsg3857_bbox = [
epsg2154_epsg3857_transformer.transform(*point)
for point in zip(bbox_it, bbox_it)
]
except AttributeError:
# Handle stop shapes for which no bbox is provided
epsg3857_bbox = []
yield StopShape(
id=shape_record.record[1],
type=shape_record.shape.shapeType,
epsg3857_bbox=epsg3857_bbox,
epsg3857_points=epsg3857_points,
)
async def http_get(url: str) -> str | None:
chunks = []
headers = {"Accept": "application/json"}
async with ClientSession(headers=headers) as session:
async with session.get(url) as response:
size = int(response.headers.get("content-length", 0)) or None
progress_bar = tqdm(desc=f"Downloading {url}", total=size)
if response.status == 200:
async for chunk in response.content.iter_chunked(1024 * 1024):
chunks.append(chunk.decode())
progress_bar.update(len(chunk))
else:
return None
return "".join(chunks)
async def http_request(
url: str, decode: Callable, format_method: Callable, model: Type[Base]
) -> bool:
elements = []
data = await http_get(url)
if data is None:
return False
try:
for element in decode(data):
elements.append(element)
if len(elements) == BATCH_SIZE:
await model.add(format_method(*elements))
elements.clear()
if elements:
await model.add(format_method(*elements))
except ValidationError as err:
logger.warning(err)
return False
return True
async def load_idfm_stops() -> bool:
return await http_request(
IDFM_STOPS_URL, json_stops_decoder.decode, format_idfm_stops, Stop
)
async def load_idfm_stop_areas() -> bool:
return await http_request(
IDFM_STOP_AREAS_URL,
json_stop_areas_decoder.decode,
format_idfm_stop_areas,
StopArea,
)
async def load_idfm_connection_areas() -> bool:
return await http_request(
IDFM_CONNECTION_AREAS_URL,
json_connection_areas_decoder.decode,
format_idfm_connection_areas,
ConnectionArea,
)
async def load_idfm_stop_shapes(url: str) -> None:
async with ClientSession(headers={"Accept": "application/zip"}) as session:
async with session.get(url) as response:
size = int(response.headers.get("content-length", 0)) or None
dl_progress_bar = tqdm(desc=f"Downloading {url}", total=size)
if response.status == 200:
async with NamedTemporaryFile(suffix=".zip") as tmp_file:
async for chunk in response.content.iter_chunked(1024 * 1024):
await tmp_file.write(chunk)
dl_progress_bar.update(len(chunk))
with ShapeFileReader(tmp_file.name) as reader:
step_begin_ts = time()
shapes = reader.shapeRecords()
shapes_len = len(shapes)
db_progress_bar = tqdm(
desc=f"Filling db with {shapes_len} StopShapes",
total=shapes_len,
)
begin, end, finished = 0, BATCH_SIZE, False
while not finished:
elements = islice(shapes, begin, end)
formatteds = list(format_idfm_stop_shapes(*elements))
await StopShape.add(formatteds)
begin = end
end = begin + BATCH_SIZE
finished = begin > len(shapes)
db_progress_bar.update(BATCH_SIZE)
logger.info(
f"Add {StopShape.__name__}s: {time() - step_begin_ts}s"
)
async def load_idfm_lines() -> None:
data = await http_get(IDFM_LINES_URL)
if data is None:
return None
lines, pictos = [], []
picto_ids = set()
for line in json_lines_decoder.decode(data):
if (picto := line.fields.picto) is not None and picto.id_ not in picto_ids:
picto_ids.add(picto.id_)
pictos.append(picto)
lines.append(line)
if len(lines) == BATCH_SIZE:
await LinePicto.add(list(format_idfm_pictos(*pictos)))
await Line.add(list(format_idfm_lines(*lines)))
lines.clear()
pictos.clear()
if pictos:
await LinePicto.add(list(format_idfm_pictos(*pictos)))
if lines:
await Line.add(list(format_idfm_lines(*lines)))
async def load_ratp_pictos(batch_size: int = 5) -> None:
data = await http_get(RATP_PICTOS_URL)
if data is None:
return None
pictos = []
for picto in json_ratp_pictos_decoder.decode(data):
pictos.append(picto)
if len(pictos) == batch_size:
formatteds = format_ratp_pictos(*pictos)
await LinePicto.add([picto[1] for picto in formatteds])
await Line.add_pictos(formatteds)
pictos.clear()
if pictos:
formatteds = format_ratp_pictos(*pictos)
await LinePicto.add([picto[1] for picto in formatteds])
await Line.add_pictos(formatteds)
async def load_lines_stops_assos(batch_size: int = 5000) -> None:
data = await http_get(IDFM_STOPS_LINES_ASSOS_URL)
if data is None:
return None
total_assos_nb = total_found_nb = 0
assos = []
for asso in json_stops_lines_assos_decoder.decode(data):
fields = asso.fields
try:
stop_id = int(fields.stop_id.rsplit(":", 1)[-1])
except ValueError as err:
logger.error(err)
logger.error(f"{fields.stop_id = }")
continue
assos.append((fields.route_long_name, fields.operatorname, stop_id))
if len(assos) == batch_size:
total_assos_nb += batch_size
total_found_nb += await Line.add_stops(assos)
assos.clear()
if assos:
total_assos_nb += len(assos)
total_found_nb += await Line.add_stops(assos)
logger.info(f"{total_found_nb} line <-> stop ({total_assos_nb = } found)")
async def load_stop_assos(batch_size: int = 5000) -> None:
data = await http_get(IDFM_STOP_STOP_AREAS_ASSOS_URL)
if data is None:
return None
total_assos_nb = area_stop_assos_nb = conn_stop_assos_nb = 0
area_stop_assos = []
connection_stop_assos = []
for asso in json_stop_area_stop_asso_decoder.decode(data):
fields = asso.fields
stop_id = int(fields.arrid)
area_stop_assos.append((int(fields.zdaid), stop_id))
connection_stop_assos.append((int(fields.zdcid), stop_id))
if len(area_stop_assos) == batch_size:
total_assos_nb += batch_size
if (found_nb := await StopArea.add_stops(area_stop_assos)) is not None:
area_stop_assos_nb += found_nb
area_stop_assos.clear()
if (
found_nb := await ConnectionArea.add_stops(connection_stop_assos)
) is not None:
conn_stop_assos_nb += found_nb
connection_stop_assos.clear()
if area_stop_assos:
total_assos_nb += len(area_stop_assos)
if (found_nb := await StopArea.add_stops(area_stop_assos)) is not None:
area_stop_assos_nb += found_nb
if (
found_nb := await ConnectionArea.add_stops(connection_stop_assos)
) is not None:
conn_stop_assos_nb += found_nb
logger.info(f"{area_stop_assos_nb} stop area <-> stop ({total_assos_nb = } found)")
logger.info(f"{conn_stop_assos_nb} stop area <-> stop ({total_assos_nb = } found)")
async def prepare(db: Database) -> None:
await load_idfm_lines()
await gather(
*(
load_idfm_stops(),
load_idfm_stop_areas(),
load_idfm_connection_areas(),
load_ratp_pictos(),
)
)
await gather(
*(
load_idfm_stop_shapes(IDFM_STOP_SHAPES_URL),
load_idfm_stop_shapes(IDFM_STOP_AREA_SHAPES_URL),
load_lines_stops_assos(),
load_stop_assos(),
)
)
def load_settings(path: str) -> Settings:
with open(path, "r") as config_file:
config = safe_load(config_file)
return Settings(**config)
async def main() -> None:
settings = load_settings(CONFIG_PATH)
await db.connect(settings.db, True)
begin_ts = time()
await prepare(db)
logger.info(f"Elapsed time: {time() - begin_ts}s")
await db.disconnect()
if __name__ == "__main__":
run(main())

View File

@@ -4,17 +4,14 @@ version = "0.1.0"
description = "" description = ""
authors = ["Adrien SUEUR <me@adrien.run>"] authors = ["Adrien SUEUR <me@adrien.run>"]
readme = "README.md" readme = "README.md"
packages = [{include = "backend"}]
[tool.poetry.dependencies] [tool.poetry.dependencies]
python = "^3.11" python = "^3.11"
aiohttp = "^3.8.3" aiohttp = "^3.8.3"
aiofiles = "^22.1.0" aiofiles = "^22.1.0"
fastapi = "^0.95.0" fastapi = "^0.103.0"
uvicorn = "^0.20.0" uvicorn = "^0.20.0"
msgspec = "^0.12.0" msgspec = "^0.12.0"
pyshp = "^2.3.1"
pyproj = "^3.5.0"
opentelemetry-instrumentation-fastapi = "^0.38b0" opentelemetry-instrumentation-fastapi = "^0.38b0"
sqlalchemy-utils = "^0.41.1" sqlalchemy-utils = "^0.41.1"
opentelemetry-instrumentation-logging = "^0.38b0" opentelemetry-instrumentation-logging = "^0.38b0"
@@ -26,11 +23,33 @@ sqlalchemy = "^2.0.12"
psycopg = "^3.1.9" psycopg = "^3.1.9"
pyyaml = "^6.0" pyyaml = "^6.0"
fastapi-cache2 = {extras = ["redis"], version = "^0.2.1"} fastapi-cache2 = {extras = ["redis"], version = "^0.2.1"}
pydantic-settings = "^2.0.3"
[tool.poetry.group.db_updater.dependencies]
aiofiles = "^22.1.0"
aiohttp = "^3.8.3"
fastapi = "^0.103.0"
msgspec = "^0.12.0"
opentelemetry-instrumentation-fastapi = "^0.38b0"
opentelemetry-instrumentation-sqlalchemy = "^0.38b0"
opentelemetry-sdk = "^1.17.0"
opentelemetry-api = "^1.17.0"
psycopg = "^3.1.9"
pyproj = "^3.5.0"
pyshp = "^2.3.1"
python = "^3.11"
pyyaml = "^6.0"
sqlalchemy = "^2.0.12"
sqlalchemy-utils = "^0.41.1"
tqdm = "^4.65.0"
pydantic-settings = "^2.0.3"
[build-system] [build-system]
requires = ["poetry-core"] requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api" build-backend = "poetry.core.masonry.api"
[tool.poetry.group.dev.dependencies] [tool.poetry.group.dev.dependencies]
pylsp-mypy = "^0.6.2" pylsp-mypy = "^0.6.2"
mccabe = "^0.7.0" mccabe = "^0.7.0"
@@ -51,6 +70,7 @@ mypy = "^1.0.0"
icecream = "^2.1.3" icecream = "^2.1.3"
types-sqlalchemy-utils = "^1.0.1" types-sqlalchemy-utils = "^1.0.1"
types-pyyaml = "^6.0.12.9" types-pyyaml = "^6.0.12.9"
types-tqdm = "^4.65.0.1"
[tool.mypy] [tool.mypy]
plugins = "sqlalchemy.ext.mypy.plugin" plugins = "sqlalchemy.ext.mypy.plugin"

View File

@@ -76,11 +76,24 @@ services:
- cassandra-schema - cassandra-schema
carrramba-encore-rate-api: carrramba-encore-rate-api:
build: ./backend/ build:
context: ./backend/
dockerfile: Dockerfile.backend
environment:
- CONFIG_PATH=./config.local.yaml
- IDFM_API_KEY=set_your_idfm_key_here
ports: ports:
- "127.0.0.1:8080:8080" - "127.0.0.1:8080:8080"
carrramba-encore-rate-frontend: carrramba-encore-rate-frontend:
build: ./frontend/ build:
context: ./frontend/
ports: ports:
- "127.0.0.1:80:8081" - "127.0.0.1:80:8081"
carrramba-encore-rate-db-updater:
build:
context: ./backend/
dockerfile: Dockerfile.db_updater
environment:
- CONFIG_PATH=./config.local.yaml

View File

@@ -28,8 +28,7 @@ export interface BusinessDataStore {
export const BusinessDataContext = createContext<BusinessDataStore>(); export const BusinessDataContext = createContext<BusinessDataStore>();
export function BusinessDataProvider(props: { children: JSX.Element }) { export function BusinessDataProvider(props: { children: JSX.Element }) {
const [serverUrl] = createSignal<string>("https://carrramba.adrien.run/api");
const [serverUrl] = createSignal<string>("https://localhost:4443");
type Store = { type Store = {
lines: Lines; lines: Lines;

View File

@@ -116,7 +116,9 @@ export const Map: ParentComponent<{}> = () => {
const foundStopIds = new Set(); const foundStopIds = new Set();
for (const foundStop of stops) { for (const foundStop of stops) {
foundStopIds.add(foundStop.id); foundStopIds.add(foundStop.id);
foundStop.stops.forEach(s => foundStopIds.add(s.id)); if (foundStop.stops !== undefined) {
foundStop.stops.forEach(s => foundStopIds.add(s.id));
}
} }
for (const [stopIdStr, feature] of Object.entries(displayedFeatures)) { for (const [stopIdStr, feature] of Object.entries(displayedFeatures)) {

View File

@@ -8,7 +8,7 @@ export enum TrafficStatus {
export class Passage { export class Passage {
line: number; line: number;
operator: string; operator: number;
destinations: string[]; destinations: string[];
atStop: boolean; atStop: boolean;
aimedArrivalTs: number; aimedArrivalTs: number;
@@ -19,7 +19,7 @@ export class Passage {
arrivalStatus: string; arrivalStatus: string;
departStatus: string; departStatus: string;
constructor(line: number, operator: string, destinations: string[], atStop: boolean, aimedArrivalTs: number, constructor(line: number, operator: number, destinations: string[], atStop: boolean, aimedArrivalTs: number,
expectedArrivalTs: number, arrivalPlatformName: string, aimedDepartTs: number, expectedDepartTs: number, expectedArrivalTs: number, arrivalPlatformName: string, aimedDepartTs: number, expectedDepartTs: number,
arrivalStatus: string, departStatus: string) { arrivalStatus: string, departStatus: string) {
this.line = line; this.line = line;
@@ -45,9 +45,9 @@ export class Stop {
epsg3857_x: number; epsg3857_x: number;
epsg3857_y: number; epsg3857_y: number;
stops: Stop[]; stops: Stop[];
lines: string[]; lines: number[];
constructor(id: number, name: string, town: string, epsg3857_x: number, epsg3857_y: number, stops: Stop[], lines: string[]) { constructor(id: number, name: string, town: string, epsg3857_x: number, epsg3857_y: number, stops: Stop[], lines: number[]) {
this.id = id; this.id = id;
this.name = name; this.name = name;
this.town = town; this.town = town;
@@ -82,7 +82,7 @@ export class StopShape {
export type StopShapes = Record<number, StopShape>; export type StopShapes = Record<number, StopShape>;
export class Line { export class Line {
id: string; id: number;
shortName: string; shortName: string;
name: string; name: string;
status: string; // TODO: Use an enum status: string; // TODO: Use an enum
@@ -95,7 +95,7 @@ export class Line {
audibleSignsAvailable: string; // TODO: Use an enum audibleSignsAvailable: string; // TODO: Use an enum
stopIds: number[]; stopIds: number[];
constructor(id: string, shortName: string, name: string, status: string, transportMode: string, backColorHexa: string, constructor(id: number, shortName: string, name: string, status: string, transportMode: string, backColorHexa: string,
foreColorHexa: string, operatorId: number, accessibility: boolean, visualSignsAvailable: string, foreColorHexa: string, operatorId: number, accessibility: boolean, visualSignsAvailable: string,
audibleSignsAvailable: string, stopIds: number[]) { audibleSignsAvailable: string, stopIds: number[]) {
this.id = id; this.id = id;

View File

@@ -26,7 +26,7 @@ export const TransportModeWeights: Record<string, number> = {
export function getTransportModeSrc(mode: string, color: boolean = true): string | undefined { export function getTransportModeSrc(mode: string, color: boolean = true): string | undefined {
let ret = undefined; let ret = undefined;
if (validTransportModes.includes(mode)) { if (validTransportModes.includes(mode)) {
return `/carrramba-encore-rate/public/symbole_${mode}_${color ? "" : "support_fonce_"}RVB.svg`; return `/symbole_${mode}_${color ? "" : "support_fonce_"}RVB.svg`;
} }
return ret; return ret;
} }