🎨 Reorganize back-end code

This commit is contained in:
2023-09-20 22:08:32 +02:00
parent bdbc72ab39
commit 3434802b31
28 changed files with 29 additions and 36 deletions

0
backend/api/__init__.py Normal file
View File

View File

@@ -0,0 +1,21 @@
app_name: carrramba-encore-rate
clear_static_data: false
http:
host: 0.0.0.0
port: 8080
cert: ./config/cert.pem
db:
name: carrramba-encore-rate
host: 127.0.0.1
port: 5432
driver: postgresql+psycopg
user: cer
password: cer_password
cache:
enable: true
tracing:
enable: false

View File

@@ -0,0 +1,21 @@
app_name: carrramba-encore-rate
clear_static_data: false
http:
host: 0.0.0.0
port: 8080
# cert: ./config/cert.pem
db:
name: carrramba-encore-rate
host: postgres
port: 5432
user: cer
cache:
enable: true
host: redis
# TODO: Add user credentials
tracing:
enable: false

View File

@@ -0,0 +1,6 @@
from .db import Database
from .base_class import Base
__all__ = ["Base"]
db = Database()

View File

@@ -0,0 +1,58 @@
from __future__ import annotations
from logging import getLogger
from typing import Self, Sequence, TYPE_CHECKING
from sqlalchemy import select
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import DeclarativeBase
if TYPE_CHECKING:
from .db import Database
logger = getLogger(__name__)
class Base(DeclarativeBase):
db: Database | None = None
@classmethod
async def add(cls, objs: Sequence[Self]) -> bool:
if cls.db is not None and (session := await cls.db.get_session()) is not None:
try:
async with session.begin():
session.add_all(objs)
except IntegrityError as err:
logger.warning(err)
return await cls.merge(objs)
except AttributeError as err:
logger.error(err)
return False
return True
@classmethod
async def merge(cls, objs: Sequence[Self]) -> bool:
if cls.db is not None and (session := await cls.db.get_session()) is not None:
async with session.begin():
for obj in objs:
await session.merge(obj)
return True
return False
@classmethod
async def get_by_id(cls, id_: int | str) -> Self | None:
if cls.db is not None and (session := await cls.db.get_session()) is not None:
async with session.begin():
stmt = select(cls).where(cls.id == id_)
res = await session.execute(stmt)
return res.scalar_one_or_none()
return None

76
backend/api/db/db.py Normal file
View File

@@ -0,0 +1,76 @@
from asyncio import sleep
from logging import getLogger
from typing import Annotated, AsyncIterator
from fastapi import Depends
from opentelemetry.instrumentation.sqlalchemy import SQLAlchemyInstrumentor
from sqlalchemy import text
from sqlalchemy.exc import OperationalError, SQLAlchemyError
from sqlalchemy.ext.asyncio import (
async_sessionmaker,
AsyncEngine,
AsyncSession,
create_async_engine,
)
from .base_class import Base
from settings import DatabaseSettings
logger = getLogger(__name__)
class Database:
def __init__(self) -> None:
self._async_engine: AsyncEngine | None = None
self._async_session_local: async_sessionmaker[AsyncSession] | None = None
async def get_session(self) -> AsyncSession | None:
try:
return self._async_session_local() # type: ignore
except (SQLAlchemyError, AttributeError) as e:
logger.exception(e)
return None
# TODO: Preserve UserLastStopSearchResults table from drop.
async def connect(
self, settings: DatabaseSettings, clear_static_data: bool = False
) -> bool:
password = settings.password
path = (
f"{settings.driver}://{settings.user}:"
f"{password.get_secret_value() if password is not None else ''}"
f"@{settings.host}:{settings.port}/{settings.name}"
)
self._async_engine = create_async_engine(
path, pool_pre_ping=True, pool_size=10, max_overflow=20
)
if self._async_engine is not None:
SQLAlchemyInstrumentor().instrument(engine=self._async_engine.sync_engine)
self._async_session_local = async_sessionmaker(
bind=self._async_engine,
# autoflush=False,
expire_on_commit=False,
class_=AsyncSession,
)
ret = False
while not ret:
try:
async with self._async_engine.begin() as session:
if clear_static_data:
await session.run_sync(Base.metadata.drop_all)
await session.run_sync(Base.metadata.create_all)
ret = True
except OperationalError as err:
logger.error(err)
await sleep(1)
return True
async def disconnect(self) -> None:
if self._async_engine is not None:
await self._async_engine.dispose()

View File

@@ -0,0 +1,38 @@
from os import environ
from fastapi_cache.backends.redis import RedisBackend
from redis import asyncio as aioredis
from yaml import safe_load
from db import db
from idfm_interface.idfm_interface import IdfmInterface
from settings import CacheSettings, Settings
CONFIG_PATH = environ.get("CONFIG_PATH", "./config.sample.yaml")
def load_settings(path: str) -> Settings:
with open(path, "r") as config_file:
config = safe_load(config_file)
return Settings(**config)
settings = load_settings(CONFIG_PATH)
idfm_interface = IdfmInterface(settings.idfm_api_key.get_secret_value(), db)
def init_redis_backend(settings: CacheSettings) -> RedisBackend:
login = f"{settings.user}:{settings.password}@" if settings.user is not None else ""
url = f"redis://{login}{settings.host}:{settings.port}"
redis_connections_pool = aioredis.from_url(
url, encoding="utf8", decode_responses=True
)
return RedisBackend(redis_connections_pool)
redis_backend = init_redis_backend(settings.cache)

View File

@@ -0,0 +1,67 @@
from .idfm_types import (
Coordinate,
Destinations,
FramedVehicleJourney,
IdfmLineState,
IdfmOperator,
IdfmResponse,
IdfmState,
LinePicto,
LineFields,
Line,
MonitoredCall,
MonitoredVehicleJourney,
Point,
Siri,
ServiceDelivery,
Stop,
StopArea,
StopAreaFields,
StopAreaStopAssociation,
StopAreaStopAssociationFields,
StopAreaType,
StopDelivery,
StopFields,
StopLineAsso,
StopLineAssoFields,
StopMonitoringDelivery,
TrainNumber,
TrainStatus,
TransportMode,
TransportSubMode,
Value,
)
__all__ = [
"Coordinate",
"Destinations",
"FramedVehicleJourney",
"IdfmLineState",
"IdfmOperator",
"IdfmResponse",
"IdfmState",
"LinePicto",
"LineFields",
"Line",
"MonitoredCall",
"MonitoredVehicleJourney",
"Point",
"Siri",
"ServiceDelivery",
"Stop",
"StopArea",
"StopAreaFields",
"StopAreaStopAssociation",
"StopAreaStopAssociationFields",
"StopAreaType",
"StopDelivery",
"StopFields",
"StopLineAsso",
"StopLineAssoFields",
"StopMonitoringDelivery",
"TrainNumber",
"TrainStatus",
"TransportMode",
"TransportSubMode",
"Value",
]

View File

@@ -0,0 +1,115 @@
from collections import defaultdict
from re import compile as re_compile
from typing import ByteString
from aiofiles import open as async_open
from aiohttp import ClientSession
from msgspec import ValidationError
from msgspec.json import Decoder
from .idfm_types import Destinations as IdfmDestinations, IdfmResponse, IdfmState
from db import Database
from models import Line, Stop, StopArea
class IdfmInterface:
IDFM_ROOT_URL = "https://prim.iledefrance-mobilites.fr/marketplace"
IDFM_STOP_MON_URL = f"{IDFM_ROOT_URL}/stop-monitoring"
OPERATOR_RE = re_compile(r"[^:]+:Operator::([^:]+):")
LINE_RE = re_compile(r"[^:]+:Line::C([^:]+):")
def __init__(self, api_key: str, database: Database) -> None:
self._api_key = api_key
self._database = database
self._http_headers = {"Accept": "application/json", "apikey": self._api_key}
self._response_json_decoder = Decoder(type=IdfmResponse)
async def startup(self) -> None:
...
@staticmethod
def _format_line_id(line_id: str) -> int:
return int(line_id[1:] if line_id[0] == "C" else line_id)
async def render_line_picto(self, line: Line) -> tuple[None | str, None | str]:
line_picto_path = line_picto_format = None
target = f"/tmp/{line.id}_repr"
picto = line.picto
if picto is not None:
if (picto_data := await self._get_line_picto(line)) is not None:
async with async_open(target, "wb") as fd:
await fd.write(bytes(picto_data))
line_picto_path = target
line_picto_format = picto.mime_type
return (line_picto_path, line_picto_format)
async def _get_line_picto(self, line: Line) -> ByteString | None:
data = None
picto = line.picto
if picto is not None and picto.url is not None:
headers = (
self._http_headers if picto.url.startswith(self.IDFM_ROOT_URL) else None
)
async with ClientSession(headers=headers) as session:
async with session.get(picto.url) as response:
data = await response.read()
return data
async def get_next_passages(self, stop_point_id: int) -> IdfmResponse | None:
ret = None
params = {"MonitoringRef": f"STIF:StopPoint:Q:{stop_point_id}:"}
async with ClientSession(headers=self._http_headers) as session:
async with session.get(self.IDFM_STOP_MON_URL, params=params) as response:
if response.status == 200:
data = await response.read()
try:
ret = self._response_json_decoder.decode(data)
except ValidationError as err:
print(err)
return ret
async def get_destinations(self, stop_id: int) -> IdfmDestinations | None:
destinations: IdfmDestinations = defaultdict(set)
if (stop := await Stop.get_by_id(stop_id)) is not None:
expected_stop_ids = {stop.id}
elif (stop_area := await StopArea.get_by_id(stop_id)) is not None:
expected_stop_ids = {stop.id for stop in stop_area.stops}
else:
return None
if (res := await self.get_next_passages(stop_id)) is not None:
for delivery in res.Siri.ServiceDelivery.StopMonitoringDelivery:
if delivery.Status == IdfmState.true:
for stop_visit in delivery.MonitoredStopVisit:
monitoring_ref = stop_visit.MonitoringRef.value
try:
monitored_stop_id = int(monitoring_ref.split(":")[-2])
except (IndexError, ValueError):
print(f"Unable to get stop id from {monitoring_ref}")
continue
journey = stop_visit.MonitoredVehicleJourney
if (
dst_names := journey.DestinationName
) and monitored_stop_id in expected_stop_ids:
raw_line_id = journey.LineRef.value.split(":")[-2]
line_id = IdfmInterface._format_line_id(raw_line_id)
destinations[line_id].add(dst_names[0].value)
return destinations

View File

@@ -0,0 +1,300 @@
from __future__ import annotations
from datetime import datetime
from enum import Enum, StrEnum
from typing import Any, NamedTuple
from msgspec import Struct
class Coordinate(NamedTuple):
lat: float
lon: float
class IdfmState(Enum):
unknown = "unknown"
false = "false"
partial = "partial"
true = "true"
class TrainStatus(Enum):
unknown = ""
arrived = "arrived"
onTime = "onTime"
delayed = "delayed"
noReport = "noReport"
early = "early"
cancelled = "cancelled"
undefined = "undefined"
class TransportMode(StrEnum):
bus = "bus"
tram = "tram"
metro = "metro"
rail = "rail"
funicular = "funicular"
class TransportSubMode(Enum):
unknown = "unknown"
localBus = "localBus"
regionalBus = "regionalBus"
highFrequencyBus = "highFrequencyBus"
expressBus = "expressBus"
nightBus = "nightBus"
demandAndResponseBus = "demandAndResponseBus"
airportLinkBus = "airportLinkBus"
regionalRail = "regionalRail"
railShuttle = "railShuttle"
suburbanRailway = "suburbanRailway"
local = "local"
class StopFields(Struct, kw_only=True):
arrgeopoint: Coordinate
arrtown: str
arrcreated: None | datetime = None
arryepsg2154: int
arrpostalregion: str
arrid: str
arrxepsg2154: int
arraccessibility: IdfmState
arrvisualsigns: IdfmState
arrtype: TransportMode
arrname: str
arrversion: str
arrchanged: datetime
arraudiblesignals: IdfmState
class Point(Struct):
coordinates: Coordinate
class Stop(Struct):
datasetid: str
recordid: str
fields: StopFields
record_timestamp: datetime
# geometry: Union[Point]
Stops = dict[str, Stop]
class StopAreaType(StrEnum):
metroStation = "metroStation"
onstreetBus = "onstreetBus"
onstreetTram = "onstreetTram"
railStation = "railStation"
class StopAreaFields(Struct, kw_only=True):
zdaname: str
zdcid: str
zdatown: str
zdaversion: str
zdaid: str
zdacreated: datetime | None = None
zdatype: StopAreaType
zdayepsg2154: int
zdapostalregion: str
zdachanged: datetime
zdaxepsg2154: int
class StopArea(Struct):
datasetid: str
recordid: str
fields: StopAreaFields
record_timestamp: datetime
class ConnectionAreaFields(Struct, kw_only=True):
zdcid: str
zdcversion: str
zdccreated: datetime
zdcchanged: datetime
zdcname: str
zdcxepsg2154: int | None = None
zdcyepsg2154: int | None = None
zdctown: str
zdcpostalregion: str
zdctype: StopAreaType
class ConnectionArea(Struct):
datasetid: str
recordid: str
fields: ConnectionAreaFields
record_timestamp: datetime
class StopAreaStopAssociationFields(Struct, kw_only=True):
arrid: str # TODO: use int ?
artid: str | None = None
arrversion: str
zdcid: str
version: int
zdaid: str
zdaversion: str
artversion: str | None = None
class StopAreaStopAssociation(Struct):
datasetid: str
recordid: str
fields: StopAreaStopAssociationFields
record_timestamp: datetime
class IdfmLineState(Enum):
active = "active"
available_soon = "prochainement active"
class LinePicto(Struct, rename={"id_": "id"}):
id_: str
mimetype: str
height: int
width: int
filename: str
thumbnail: bool
format: str
# color_summary: list[str]
class LineFields(Struct, kw_only=True):
name_line: str
status: IdfmLineState
accessibility: IdfmState
shortname_groupoflines: str | None = None
transportmode: TransportMode
colourweb_hexa: str
textcolourprint_hexa: str
transportsubmode: TransportSubMode | None = TransportSubMode.unknown
operatorref: str | None = None
visualsigns_available: IdfmState
networkname: str | None = None
id_line: str
id_groupoflines: str | None = None
operatorname: str | None = None
audiblesigns_available: IdfmState
shortname_line: str
picto: LinePicto | None = None
class Line(Struct):
datasetid: str
recordid: str
fields: LineFields
record_timestamp: datetime
@staticmethod
def optional_value(value: Any) -> Any:
if value:
return value.value if isinstance(value, Enum) else value
return "NULL"
Lines = dict[str, Line]
Destinations = dict[str, set[str]]
# TODO: Set structs frozen
class StopLineAssoFields(Struct):
pointgeo: Coordinate
stop_id: str
stop_name: str
operatorname: str
nom_commune: str
route_long_name: str
id: str
stop_lat: str
stop_lon: str
code_insee: str
class StopLineAsso(Struct):
datasetid: str
recordid: str
fields: StopLineAssoFields
# geometry: Union[Point]
class Value(Struct):
value: str
class FramedVehicleJourney(Struct):
DataFrameRef: Value
DatedVehicleJourneyRef: str
class TrainNumber(Struct):
TrainNumberRef: list[Value]
class MonitoredCall(Struct, kw_only=True):
Order: int | None = None
StopPointName: list[Value]
VehicleAtStop: bool
DestinationDisplay: list[Value]
AimedArrivalTime: datetime | None = None
ExpectedArrivalTime: datetime | None = None
ArrivalPlatformName: Value | None = None
AimedDepartureTime: datetime | None = None
ExpectedDepartureTime: datetime | None = None
ArrivalStatus: TrainStatus | None = None
DepartureStatus: TrainStatus | None = None
class MonitoredVehicleJourney(Struct, kw_only=True):
LineRef: Value
OperatorRef: Value
FramedVehicleJourneyRef: FramedVehicleJourney
DestinationRef: Value
DestinationName: list[Value] | None = None
JourneyNote: list[Value] | None = None
TrainNumbers: TrainNumber | None = None
MonitoredCall: MonitoredCall
class StopDelivery(Struct):
RecordedAtTime: datetime
ItemIdentifier: str
MonitoringRef: Value
MonitoredVehicleJourney: MonitoredVehicleJourney
class StopMonitoringDelivery(Struct):
ResponseTimestamp: datetime
Version: str
Status: IdfmState
MonitoredStopVisit: list[StopDelivery]
class ServiceDelivery(Struct):
ResponseTimestamp: datetime
ProducerRef: str
ResponseMessageIdentifier: str
StopMonitoringDelivery: list[StopMonitoringDelivery]
class Siri(Struct):
ServiceDelivery: ServiceDelivery
class IdfmOperator(Enum):
SNCF = "SNCF"
class IdfmResponse(Struct):
Siri: Siri

View File

@@ -0,0 +1,15 @@
from msgspec import Struct
class PictoFieldsFile(Struct, rename={"id_": "id"}):
id_: str
height: int
width: int
filename: str
thumbnail: bool
format: str
class Picto(Struct):
indices_commerciaux: str
noms_des_fichiers: PictoFieldsFile | None = None

89
backend/api/main.py Executable file
View File

@@ -0,0 +1,89 @@
#!/usr/bin/env python3
import uvicorn
from contextlib import asynccontextmanager
from fastapi import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi_cache import FastAPICache
from opentelemetry import trace
from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter
from opentelemetry.instrumentation.fastapi import FastAPIInstrumentor
from opentelemetry.sdk.resources import Resource, SERVICE_NAME
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import BatchSpanProcessor
from db import db
from dependencies import idfm_interface, redis_backend, settings
from routers import line, stop
@asynccontextmanager
async def lifespan(app: FastAPI):
FastAPICache.init(redis_backend, prefix="api", enable=settings.cache.enable)
await db.connect(settings.db, settings.clear_static_data)
if settings.clear_static_data:
await idfm_interface.startup()
yield
await db.disconnect()
app = FastAPI(lifespan=lifespan)
app.add_middleware(
CORSMiddleware,
allow_origins=["http://carrramba.adrien.run", "https://carrramba.adrien.run"],
allow_credentials=True,
allow_methods=["OPTIONS", "GET"],
allow_headers=["*"],
)
# The cache-control header entry is not managed properly by fastapi-cache:
# For now, a request with a cache-control set to no-cache
# is interpreted as disabling the use of the server cache.
# Cf. Improve Cache-Control header parsing and handling
# https://github.com/long2ice/fastapi-cache/issues/144 workaround
@app.middleware("http")
async def fastapi_cache_issue_144_workaround(request: Request, call_next):
entries = request.headers.__dict__["_list"]
new_entries = [
entry for entry in entries if entry[0].decode().lower() != "cache-control"
]
request.headers.__dict__["_list"] = new_entries
return await call_next(request)
app.include_router(line.router)
app.include_router(stop.router)
if settings.tracing.enable:
FastAPIInstrumentor.instrument_app(app)
trace.set_tracer_provider(
TracerProvider(resource=Resource.create({SERVICE_NAME: settings.app_name}))
)
trace.get_tracer_provider().add_span_processor(
BatchSpanProcessor(OTLPSpanExporter())
)
tracer = trace.get_tracer(settings.app_name)
if __name__ == "__main__":
http_settings = settings.http
config = uvicorn.Config(
app=app,
host=http_settings.host,
port=http_settings.port,
ssl_certfile=http_settings.cert,
proxy_headers=True,
)
server = uvicorn.Server(config)
server.run()

View File

@@ -0,0 +1,14 @@
from .line import Line, LinePicto
from .stop import ConnectionArea, Stop, StopArea, StopShape
from .user import UserLastStopSearchResults
__all__ = [
"ConnectionArea",
"Line",
"LinePicto",
"Stop",
"StopArea",
"StopShape",
"UserLastStopSearchResults",
]

196
backend/api/models/line.py Normal file
View File

@@ -0,0 +1,196 @@
from asyncio import gather as asyncio_gather
from collections import defaultdict
from typing import Iterable, Self, Sequence
from sqlalchemy import (
BigInteger,
Boolean,
Enum,
ForeignKey,
Integer,
select,
String,
)
from sqlalchemy.orm import Mapped, mapped_column, relationship, selectinload
from sqlalchemy.sql.expression import tuple_
from db import Base, db
from idfm_interface.idfm_types import (
IdfmState,
IdfmLineState,
TransportMode,
TransportSubMode,
)
from .stop import _Stop
class LineStopAssociations(Base):
id = mapped_column(BigInteger, primary_key=True)
line_id = mapped_column(BigInteger, ForeignKey("lines.id"))
stop_id = mapped_column(BigInteger, ForeignKey("_stops.id"))
__tablename__ = "line_stop_associations"
class LinePicto(Base):
db = db
id = mapped_column(String, primary_key=True)
mime_type = mapped_column(String, nullable=False)
height_px = mapped_column(Integer, nullable=False)
width_px = mapped_column(Integer, nullable=False)
filename = mapped_column(String, nullable=False)
url = mapped_column(String, nullable=False)
thumbnail = mapped_column(Boolean, nullable=False)
format = mapped_column(String, nullable=False)
__tablename__ = "line_pictos"
class Line(Base):
db = db
id = mapped_column(BigInteger, primary_key=True)
short_name = mapped_column(String)
name = mapped_column(String, nullable=False)
status = mapped_column(Enum(IdfmLineState), nullable=False)
transport_mode = mapped_column(Enum(TransportMode), nullable=False)
transport_submode = mapped_column(Enum(TransportSubMode), nullable=False)
network_name = mapped_column(String)
group_of_lines_id = mapped_column(String)
group_of_lines_shortname = mapped_column(String)
colour_web_hexa = mapped_column(String, nullable=False)
text_colour_hexa = mapped_column(String, nullable=False)
operator_id = mapped_column(Integer)
operator_name = mapped_column(String)
accessibility = mapped_column(Enum(IdfmState), nullable=False)
visual_signs_available = mapped_column(Enum(IdfmState), nullable=False)
audible_signs_available = mapped_column(Enum(IdfmState), nullable=False)
picto_id = mapped_column(String, ForeignKey("line_pictos.id"))
picto: Mapped[LinePicto] = relationship(LinePicto, lazy="selectin")
record_id = mapped_column(String, nullable=False)
record_ts = mapped_column(BigInteger, nullable=False)
stops: Mapped[list[_Stop]] = relationship(
"_Stop",
secondary="line_stop_associations",
back_populates="lines",
lazy="selectin",
)
__tablename__ = "lines"
@classmethod
async def get_by_name(
cls, name: str, operator_name: None | str = None
) -> Sequence[Self] | None:
if (session := await cls.db.get_session()) is not None:
async with session.begin():
filters = {"name": name}
if operator_name is not None:
filters["operator_name"] = operator_name
stmt = (
select(cls)
.filter_by(**filters)
.options(selectinload(cls.stops), selectinload(cls.picto))
)
res = await session.execute(stmt)
lines = res.scalars().all()
return lines
return None
@classmethod
async def _add_picto_to_line(cls, line: str | Self, picto: LinePicto) -> None:
formatted_line: Self | None = None
if isinstance(line, str):
if (lines := await cls.get_by_name(line)) is not None:
if len(lines) == 1:
formatted_line = lines[0]
else:
for candidate_line in lines:
if candidate_line.operator_name == "RATP":
formatted_line = candidate_line
break
else:
formatted_line = line
if isinstance(formatted_line, Line) and formatted_line.picto is None:
formatted_line.picto = picto
formatted_line.picto_id = picto.id
@classmethod
async def add_pictos(cls, line_to_pictos: Iterable[tuple[str, LinePicto]]) -> bool:
if (session := await cls.db.get_session()) is not None:
async with session.begin():
await asyncio_gather(
*[
cls._add_picto_to_line(line, picto)
for line, picto in line_to_pictos
]
)
return True
return False
@classmethod
async def add_stops(cls, line_to_stop_ids: Iterable[tuple[str, str, int]]) -> int:
if (session := await cls.db.get_session()) is not None:
async with session.begin():
line_names_ops, stop_ids = set(), set()
for line_name, operator_name, stop_id in line_to_stop_ids:
line_names_ops.add((line_name, operator_name))
stop_ids.add(stop_id)
lines_res = await session.execute(
select(Line).where(
tuple_(Line.name, Line.operator_name).in_(line_names_ops)
)
)
lines = defaultdict(list)
for line in lines_res.scalars():
lines[(line.name, line.operator_name)].append(line)
stops_res = await session.execute(
select(_Stop).where(_Stop.id.in_(stop_ids))
)
stops = {stop.id: stop for stop in stops_res.scalars()}
found = 0
for line_name, operator_name, stop_id in line_to_stop_ids:
if (stop := stops.get(stop_id)) is not None:
if (
stop_lines := lines.get((line_name, operator_name))
) is not None:
for stop_line in stop_lines:
stop_line.stops.append(stop)
found += 1
else:
print(f"No line found for {line_name}/{operator_name}")
else:
print(
f"No stop found for {stop_id} id"
f"(used by {line_name}/{operator_name})"
)
return found
return 0

275
backend/api/models/stop.py Normal file
View File

@@ -0,0 +1,275 @@
from __future__ import annotations
from logging import getLogger
from typing import Iterable, Sequence, TYPE_CHECKING
from sqlalchemy import (
BigInteger,
Computed,
desc,
Enum,
Float,
ForeignKey,
func,
Integer,
JSON,
select,
String,
)
from sqlalchemy.orm import (
mapped_column,
Mapped,
relationship,
selectinload,
with_polymorphic,
)
from sqlalchemy.schema import Index
from sqlalchemy_utils.types.ts_vector import TSVectorType
from db import Base, db
from idfm_interface.idfm_types import TransportMode, IdfmState, StopAreaType
if TYPE_CHECKING:
from .line import Line
logger = getLogger(__name__)
class StopAreaStopAssociations(Base):
id = mapped_column(BigInteger, primary_key=True)
stop_id = mapped_column(BigInteger, ForeignKey("_stops.id"))
stop_area_id = mapped_column(BigInteger, ForeignKey("stop_areas.id"))
__tablename__ = "stop_area_stop_associations"
class _Stop(Base):
db = db
id = mapped_column(BigInteger, primary_key=True)
kind = mapped_column(String)
name = mapped_column(String, nullable=False, index=True)
town_name = mapped_column(String, nullable=False)
postal_region = mapped_column(Integer, nullable=False)
epsg3857_x = mapped_column(Float, nullable=False)
epsg3857_y = mapped_column(Float, nullable=False)
version = mapped_column(String, nullable=False)
created_ts = mapped_column(BigInteger)
changed_ts = mapped_column(BigInteger, nullable=False)
lines: Mapped[list[Line]] = relationship(
"Line",
secondary="line_stop_associations",
back_populates="stops",
lazy="selectin",
)
areas: Mapped[list["StopArea"]] = relationship(
"StopArea",
secondary="stop_area_stop_associations",
back_populates="stops",
)
connection_area_id: Mapped[int] = mapped_column(
ForeignKey("connection_areas.id"), nullable=True
)
connection_area: Mapped["ConnectionArea"] = relationship(
back_populates="stops", lazy="selectin"
)
names_tsv = mapped_column(
TSVectorType("name", "town_name", regconfig="french"),
Computed("to_tsvector('french', name || ' ' || town_name)", persisted=True),
)
__tablename__ = "_stops"
__mapper_args__ = {"polymorphic_identity": "_stops", "polymorphic_on": kind}
__table_args__ = (
Index(
"names_tsv_idx",
names_tsv,
postgresql_ops={"name": "gin_trgm_ops"},
postgresql_using="gin",
),
)
@classmethod
async def get_by_name(cls, name: str) -> Sequence[_Stop] | None:
if (session := await cls.db.get_session()) is not None:
async with session.begin():
descendants = with_polymorphic(_Stop, "*")
match_stmt = descendants.names_tsv.match(
name, postgresql_regconfig="french"
)
ranking_stmt = func.ts_rank_cd(
descendants.names_tsv, func.plainto_tsquery("french", name)
)
stmt = (
select(descendants).filter(match_stmt).order_by(desc(ranking_stmt))
)
res = await session.execute(stmt)
stops = res.scalars().all()
return stops
return None
class Stop(_Stop):
id = mapped_column(BigInteger, ForeignKey("_stops.id"), primary_key=True)
transport_mode = mapped_column(Enum(TransportMode), nullable=False)
accessibility = mapped_column(Enum(IdfmState), nullable=False)
visual_signs_available = mapped_column(Enum(IdfmState), nullable=False)
audible_signs_available = mapped_column(Enum(IdfmState), nullable=False)
record_id = mapped_column(String, nullable=False)
record_ts = mapped_column(BigInteger, nullable=False)
__tablename__ = "stops"
__mapper_args__ = {"polymorphic_identity": "stops", "polymorphic_load": "inline"}
class StopArea(_Stop):
id = mapped_column(BigInteger, ForeignKey("_stops.id"), primary_key=True)
type = mapped_column(Enum(StopAreaType), nullable=False)
stops: Mapped[list["Stop"]] = relationship(
"Stop",
secondary="stop_area_stop_associations",
back_populates="areas",
lazy="selectin",
)
__tablename__ = "stop_areas"
__mapper_args__ = {
"polymorphic_identity": "stop_areas",
"polymorphic_load": "inline",
}
@classmethod
async def add_stops(
cls, stop_area_to_stop_ids: Iterable[tuple[int, int]]
) -> int | None:
if (session := await cls.db.get_session()) is not None:
async with session.begin():
stop_area_ids, stop_ids = set(), set()
for stop_area_id, stop_id in stop_area_to_stop_ids:
stop_area_ids.add(stop_area_id)
stop_ids.add(stop_id)
stop_areas_res = await session.scalars(
select(StopArea)
.where(StopArea.id.in_(stop_area_ids))
.options(selectinload(StopArea.stops))
)
stop_areas: dict[int, StopArea] = {
stop_area.id: stop_area for stop_area in stop_areas_res.all()
}
stop_res = await session.execute(
select(Stop).where(Stop.id.in_(stop_ids))
)
stops: dict[int, Stop] = {stop.id: stop for stop in stop_res.scalars()}
found = 0
for stop_area_id, stop_id in stop_area_to_stop_ids:
if (stop_area := stop_areas.get(stop_area_id)) is not None:
if (stop := stops.get(stop_id)) is not None:
stop_area.stops.append(stop)
found += 1
else:
print(f"No stop found for {stop_id} id")
else:
print(f"No stop area found for {stop_area_id}")
return found
return None
class StopShape(Base):
db = db
id = mapped_column(BigInteger, primary_key=True) # Same id than ConnectionArea
type = mapped_column(Integer, nullable=False)
epsg3857_bbox = mapped_column(JSON)
epsg3857_points = mapped_column(JSON)
__tablename__ = "stop_shapes"
class ConnectionArea(Base):
db = db
id = mapped_column(BigInteger, primary_key=True)
name = mapped_column(String, nullable=False)
town_name = mapped_column(String, nullable=False)
postal_region = mapped_column(String, nullable=False)
epsg3857_x = mapped_column(Float, nullable=False)
epsg3857_y = mapped_column(Float, nullable=False)
transport_mode = mapped_column(Enum(StopAreaType), nullable=False)
version = mapped_column(String, nullable=False)
created_ts = mapped_column(BigInteger)
changed_ts = mapped_column(BigInteger, nullable=False)
stops: Mapped[list["_Stop"]] = relationship(back_populates="connection_area")
__tablename__ = "connection_areas"
# TODO: Merge with StopArea.add_stops
@classmethod
async def add_stops(
cls, conn_area_to_stop_ids: Iterable[tuple[int, int]]
) -> int | None:
if (session := await cls.db.get_session()) is not None:
async with session.begin():
conn_area_ids, stop_ids = set(), set()
for conn_area_id, stop_id in conn_area_to_stop_ids:
conn_area_ids.add(conn_area_id)
stop_ids.add(stop_id)
conn_area_res = await session.execute(
select(ConnectionArea)
.where(ConnectionArea.id.in_(conn_area_ids))
.options(selectinload(ConnectionArea.stops))
)
conn_areas: dict[int, ConnectionArea] = {
conn.id: conn for conn in conn_area_res.scalars()
}
stop_res = await session.execute(
select(Stop).where(Stop.id.in_(stop_ids))
)
stops: dict[int, Stop] = {stop.id: stop for stop in stop_res.scalars()}
found = 0
for conn_area_id, stop_id in conn_area_to_stop_ids:
if (conn_area := conn_areas.get(conn_area_id)) is not None:
if (stop := stops.get(stop_id)) is not None:
conn_area.stops.append(stop)
found += 1
else:
print(f"No stop found for {stop_id} id")
else:
print(f"No connection area found for {conn_area_id}")
return found
return None

View File

@@ -0,0 +1,27 @@
from sqlalchemy import BigInteger, ForeignKey, String
from sqlalchemy.orm import Mapped, mapped_column, relationship
from db import Base, db
from .stop import _Stop
class UserLastStopSearchStopAssociations(Base):
id = mapped_column(BigInteger, primary_key=True)
user_mxid = mapped_column(
String, ForeignKey("user_last_stop_search_results.user_mxid")
)
stop_id = mapped_column(BigInteger, ForeignKey("_stops.id"))
__tablename__ = "user_last_stop_search_stop_associations"
class UserLastStopSearchResults(Base):
db = db
user_mxid = mapped_column(String, primary_key=True)
request_content = mapped_column(String, nullable=False)
stops: Mapped[_Stop] = relationship(
_Stop, secondary="user_last_stop_search_stop_associations"
)
__tablename__ = "user_last_stop_search_results"

0
backend/api/py.typed Normal file
View File

View File

View File

@@ -0,0 +1,34 @@
from fastapi import APIRouter, HTTPException
from fastapi_cache.decorator import cache
from models import Line
from schemas import Line as LineSchema, TransportMode
router = APIRouter(prefix="/line", tags=["line"])
@router.get("/{line_id}", response_model=LineSchema)
@cache(namespace="line")
async def get_line(line_id: int) -> LineSchema:
line: Line | None = await Line.get_by_id(line_id)
if line is None:
raise HTTPException(status_code=404, detail=f'Line "{line_id}" not found')
return LineSchema(
id=line.id,
shortName=line.short_name,
name=line.name,
status=line.status,
transportMode=TransportMode.from_idfm_transport_mode(
line.transport_mode, line.transport_submode
),
backColorHexa=line.colour_web_hexa,
foreColorHexa=line.text_colour_hexa,
operatorId=line.operator_id,
accessibility=line.accessibility,
visualSignsAvailable=line.visual_signs_available,
audibleSignsAvailable=line.audible_signs_available,
stopIds=[stop.id for stop in line.stops],
)

176
backend/api/routers/stop.py Normal file
View File

@@ -0,0 +1,176 @@
from collections import defaultdict
from datetime import datetime
from typing import Sequence
from fastapi import APIRouter, HTTPException
from fastapi_cache.decorator import cache
from idfm_interface import Destinations as IdfmDestinations, TrainStatus
from models import Stop, StopArea, StopShape
from schemas import (
NextPassage as NextPassageSchema,
NextPassages as NextPassagesSchema,
Stop as StopSchema,
StopArea as StopAreaSchema,
StopShape as StopShapeSchema,
)
from dependencies import idfm_interface
router = APIRouter(prefix="/stop", tags=["stop"])
def _format_stop(stop: Stop) -> StopSchema:
return StopSchema(
id=stop.id,
name=stop.name,
town=stop.town_name,
epsg3857_x=stop.epsg3857_x,
epsg3857_y=stop.epsg3857_y,
lines=[line.id for line in stop.lines],
)
def optional_datetime_to_ts(dt: datetime | None) -> int | None:
return int(dt.timestamp()) if dt else None
# TODO: Add limit support
@router.get("/")
@cache(namespace="stop")
async def get_stop(
name: str = "", limit: int = 10
) -> Sequence[StopAreaSchema | StopSchema] | None:
matching_stops = await Stop.get_by_name(name)
if matching_stops is None:
return None
formatted: list[StopAreaSchema | StopSchema] = []
stop_areas: dict[int, StopArea] = {}
stops: dict[int, Stop] = {}
for stop in matching_stops:
if isinstance(stop, StopArea):
stop_areas[stop.id] = stop
elif isinstance(stop, Stop):
stops[stop.id] = stop
for stop_area in stop_areas.values():
formatted_stops = []
for stop in stop_area.stops:
formatted_stops.append(_format_stop(stop))
try:
del stops[stop.id]
except KeyError as err:
print(err)
formatted.append(
StopAreaSchema(
id=stop_area.id,
name=stop_area.name,
town=stop_area.town_name,
type=stop_area.type,
lines=[line.id for line in stop_area.lines],
stops=formatted_stops,
)
)
formatted.extend(_format_stop(stop) for stop in stops.values())
return formatted
@router.get("/{stop_id}/nextPassages")
@cache(namespace="stop-nextPassages", expire=30)
async def get_next_passages(stop_id: int) -> NextPassagesSchema | None:
res = await idfm_interface.get_next_passages(stop_id)
if res is None:
return None
service_delivery = res.Siri.ServiceDelivery
stop_monitoring_deliveries = service_delivery.StopMonitoringDelivery
by_line_by_dst_passages: dict[
int, dict[str, list[NextPassageSchema]]
] = defaultdict(lambda: defaultdict(list))
for delivery in stop_monitoring_deliveries:
for stop_visit in delivery.MonitoredStopVisit:
journey = stop_visit.MonitoredVehicleJourney
# re.match will return None if the given journey.LineRef.value is not valid.
try:
line_id_match = idfm_interface.LINE_RE.match(journey.LineRef.value)
line_id = int(line_id_match.group(1)) # type: ignore
except (AttributeError, TypeError, ValueError) as err:
raise HTTPException(
status_code=404, detail=f'Line "{journey.LineRef.value}" not found'
) from err
call = journey.MonitoredCall
dst_names = call.DestinationDisplay
dsts = [dst.value for dst in dst_names] if dst_names else []
arrivalPlatformName = (
call.ArrivalPlatformName.value if call.ArrivalPlatformName else None
)
next_passage = NextPassageSchema(
line=line_id,
operator=journey.OperatorRef.value,
destinations=dsts,
atStop=call.VehicleAtStop,
aimedArrivalTs=optional_datetime_to_ts(call.AimedArrivalTime),
expectedArrivalTs=optional_datetime_to_ts(call.ExpectedArrivalTime),
arrivalPlatformName=arrivalPlatformName,
aimedDepartTs=optional_datetime_to_ts(call.AimedDepartureTime),
expectedDepartTs=optional_datetime_to_ts(call.ExpectedDepartureTime),
arrivalStatus=call.ArrivalStatus
if call.ArrivalStatus is not None
else TrainStatus.unknown,
departStatus=call.DepartureStatus
if call.DepartureStatus is not None
else TrainStatus.unknown,
)
by_line_passages = by_line_by_dst_passages[line_id]
# TODO: by_line_passages[dst].extend(dsts) instead ?
for dst in dsts:
by_line_passages[dst].append(next_passage)
return NextPassagesSchema(
ts=int(service_delivery.ResponseTimestamp.timestamp()),
passages=by_line_by_dst_passages,
)
@router.get("/{stop_id}/destinations")
@cache(namespace="stop-destinations", expire=30)
async def get_stop_destinations(
stop_id: int,
) -> IdfmDestinations | None:
destinations = await idfm_interface.get_destinations(stop_id)
return destinations
@router.get("/{stop_id}/shape")
@cache(namespace="stop-shape")
async def get_stop_shape(stop_id: int) -> StopShapeSchema | None:
if (await Stop.get_by_id(stop_id)) is not None or (
await StopArea.get_by_id(stop_id)
) is not None:
shape_id = stop_id
if (shape := await StopShape.get_by_id(shape_id)) is not None:
return StopShapeSchema(
id=shape.id,
type=shape.type,
epsg3857_bbox=shape.epsg3857_bbox,
epsg3857_points=shape.epsg3857_points,
)
msg = f"No shape found for stop {stop_id}"
raise HTTPException(status_code=404, detail=msg)

View File

@@ -0,0 +1,14 @@
from .line import Line, TransportMode
from .next_passage import NextPassage, NextPassages
from .stop import Stop, StopArea, StopShape
__all__ = [
"Line",
"NextPassage",
"NextPassages",
"Stop",
"StopArea",
"StopShape",
"TransportMode",
]

View File

@@ -0,0 +1,60 @@
from enum import StrEnum
from pydantic import BaseModel
from idfm_interface import (
IdfmLineState,
IdfmState,
TransportMode as IdfmTransportMode,
TransportSubMode as IdfmTransportSubMode,
)
class TransportMode(StrEnum):
"""Computed transport mode from
idfm_interface.TransportMode and idfm_interface.TransportSubMode.
"""
bus = "bus"
tram = "tram"
metro = "metro"
funicular = "funicular"
# idfm_types.TransportMode.rail + idfm_types.TransportSubMode.regionalRail
rail_ter = "ter"
# idfm_types.TransportMode.rail + idfm_types.TransportSubMode.local
rail_rer = "rer"
# idfm_types.TransportMode.rail + idfm_types.TransportSubMode.suburbanRailway
rail_transilien = "transilien"
# idfm_types.TransportMode.rail + idfm_types.TransportSubMode.railShuttle
val = "val"
# Self return type replaced by "TransportMode" to fix following mypy error:
# Incompatible return value type (got "TransportMode", expected "Self")
# TODO: Is it the good fix ?
@classmethod
def from_idfm_transport_mode(cls, mode: str, sub_mode: str) -> "TransportMode":
if mode == IdfmTransportMode.rail:
if sub_mode == IdfmTransportSubMode.regionalRail:
return cls.rail_ter
if sub_mode == IdfmTransportSubMode.local:
return cls.rail_rer
if sub_mode == IdfmTransportSubMode.suburbanRailway:
return cls.rail_transilien
if sub_mode == IdfmTransportSubMode.railShuttle:
return cls.val
return cls(mode)
class Line(BaseModel):
id: int
shortName: str
name: str
status: IdfmLineState
transportMode: TransportMode
backColorHexa: str
foreColorHexa: str
operatorId: int
accessibility: IdfmState
visualSignsAvailable: IdfmState
audibleSignsAvailable: IdfmState
stopIds: list[int]

View File

@@ -0,0 +1,22 @@
from pydantic import BaseModel
from idfm_interface.idfm_types import TrainStatus
class NextPassage(BaseModel):
line: int
operator: str
destinations: list[str]
atStop: bool
aimedArrivalTs: int | None
expectedArrivalTs: int | None
arrivalPlatformName: str | None
aimedDepartTs: int | None
expectedDepartTs: int | None
arrivalStatus: TrainStatus
departStatus: TrainStatus
class NextPassages(BaseModel):
ts: int
passages: dict[int, dict[str, list[NextPassage]]]

View File

@@ -0,0 +1,31 @@
from pydantic import BaseModel
from idfm_interface import StopAreaType
class Stop(BaseModel):
id: int
name: str
town: str
epsg3857_x: float
epsg3857_y: float
lines: list[int]
class StopArea(BaseModel):
id: int
name: str
town: str
type: StopAreaType
lines: list[int] # SNCF lines are linked to stop areas and not stops.
stops: list[Stop]
Point = tuple[float, float]
class StopShape(BaseModel):
id: int
type: int
epsg3857_bbox: list[Point]
epsg3857_points: list[Point]

74
backend/api/settings.py Normal file
View File

@@ -0,0 +1,74 @@
from __future__ import annotations
from typing import Annotated
from pydantic import BaseModel, SecretStr
from pydantic.functional_validators import model_validator
from pydantic_settings import (
BaseSettings,
PydanticBaseSettingsSource,
SettingsConfigDict,
)
class HttpSettings(BaseModel):
host: str = "127.0.0.1"
port: int = 8080
cert: str | None = None
class DatabaseSettings(BaseModel):
name: str
host: str
port: int
driver: str = "postgresql+psycopg"
user: str
password: Annotated[SecretStr, check_user_password]
class CacheSettings(BaseModel):
enable: bool = False
host: str = "127.0.0.1"
port: int = 6379
user: str | None = None
password: Annotated[SecretStr | None, check_user_password] = None
@model_validator(mode="after")
def check_user_password(self) -> DatabaseSettings | CacheSettings:
if self.user is not None and self.password is None:
raise ValueError("user is set, password shall be set too.")
if self.password is not None and self.user is None:
raise ValueError("password is set, user shall be set too.")
return self
class TracingSettings(BaseModel):
enable: bool = False
class Settings(BaseSettings):
app_name: str
idfm_api_key: SecretStr
clear_static_data: bool
http: HttpSettings
db: DatabaseSettings
cache: CacheSettings
tracing: TracingSettings
model_config = SettingsConfigDict(env_prefix="CER__", env_nested_delimiter="__")
@classmethod
def settings_customise_sources(
cls,
settings_cls: type[BaseSettings],
init_settings: PydanticBaseSettingsSource,
env_settings: PydanticBaseSettingsSource,
dotenv_settings: PydanticBaseSettingsSource,
file_secret_settings: PydanticBaseSettingsSource,
) -> tuple[PydanticBaseSettingsSource, ...]:
return env_settings, init_settings, file_secret_settings