🎨 Reorganize back-end code
This commit is contained in:
0
backend/api/__init__.py
Normal file
0
backend/api/__init__.py
Normal file
21
backend/api/config.local.yaml
Normal file
21
backend/api/config.local.yaml
Normal file
@@ -0,0 +1,21 @@
|
||||
app_name: carrramba-encore-rate
|
||||
clear_static_data: false
|
||||
|
||||
http:
|
||||
host: 0.0.0.0
|
||||
port: 8080
|
||||
cert: ./config/cert.pem
|
||||
|
||||
db:
|
||||
name: carrramba-encore-rate
|
||||
host: 127.0.0.1
|
||||
port: 5432
|
||||
driver: postgresql+psycopg
|
||||
user: cer
|
||||
password: cer_password
|
||||
|
||||
cache:
|
||||
enable: true
|
||||
|
||||
tracing:
|
||||
enable: false
|
21
backend/api/config.sample.yaml
Normal file
21
backend/api/config.sample.yaml
Normal file
@@ -0,0 +1,21 @@
|
||||
app_name: carrramba-encore-rate
|
||||
clear_static_data: false
|
||||
|
||||
http:
|
||||
host: 0.0.0.0
|
||||
port: 8080
|
||||
# cert: ./config/cert.pem
|
||||
|
||||
db:
|
||||
name: carrramba-encore-rate
|
||||
host: postgres
|
||||
port: 5432
|
||||
user: cer
|
||||
|
||||
cache:
|
||||
enable: true
|
||||
host: redis
|
||||
# TODO: Add user credentials
|
||||
|
||||
tracing:
|
||||
enable: false
|
6
backend/api/db/__init__.py
Normal file
6
backend/api/db/__init__.py
Normal file
@@ -0,0 +1,6 @@
|
||||
from .db import Database
|
||||
from .base_class import Base
|
||||
|
||||
__all__ = ["Base"]
|
||||
|
||||
db = Database()
|
58
backend/api/db/base_class.py
Normal file
58
backend/api/db/base_class.py
Normal file
@@ -0,0 +1,58 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from logging import getLogger
|
||||
from typing import Self, Sequence, TYPE_CHECKING
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.orm import DeclarativeBase
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .db import Database
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
|
||||
class Base(DeclarativeBase):
|
||||
db: Database | None = None
|
||||
|
||||
@classmethod
|
||||
async def add(cls, objs: Sequence[Self]) -> bool:
|
||||
if cls.db is not None and (session := await cls.db.get_session()) is not None:
|
||||
|
||||
try:
|
||||
async with session.begin():
|
||||
session.add_all(objs)
|
||||
|
||||
except IntegrityError as err:
|
||||
logger.warning(err)
|
||||
return await cls.merge(objs)
|
||||
|
||||
except AttributeError as err:
|
||||
logger.error(err)
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
async def merge(cls, objs: Sequence[Self]) -> bool:
|
||||
if cls.db is not None and (session := await cls.db.get_session()) is not None:
|
||||
|
||||
async with session.begin():
|
||||
for obj in objs:
|
||||
await session.merge(obj)
|
||||
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
async def get_by_id(cls, id_: int | str) -> Self | None:
|
||||
if cls.db is not None and (session := await cls.db.get_session()) is not None:
|
||||
|
||||
async with session.begin():
|
||||
stmt = select(cls).where(cls.id == id_)
|
||||
res = await session.execute(stmt)
|
||||
return res.scalar_one_or_none()
|
||||
|
||||
return None
|
76
backend/api/db/db.py
Normal file
76
backend/api/db/db.py
Normal file
@@ -0,0 +1,76 @@
|
||||
from asyncio import sleep
|
||||
from logging import getLogger
|
||||
from typing import Annotated, AsyncIterator
|
||||
|
||||
from fastapi import Depends
|
||||
from opentelemetry.instrumentation.sqlalchemy import SQLAlchemyInstrumentor
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.exc import OperationalError, SQLAlchemyError
|
||||
from sqlalchemy.ext.asyncio import (
|
||||
async_sessionmaker,
|
||||
AsyncEngine,
|
||||
AsyncSession,
|
||||
create_async_engine,
|
||||
)
|
||||
|
||||
from .base_class import Base
|
||||
from settings import DatabaseSettings
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
|
||||
class Database:
|
||||
def __init__(self) -> None:
|
||||
self._async_engine: AsyncEngine | None = None
|
||||
self._async_session_local: async_sessionmaker[AsyncSession] | None = None
|
||||
|
||||
async def get_session(self) -> AsyncSession | None:
|
||||
try:
|
||||
return self._async_session_local() # type: ignore
|
||||
|
||||
except (SQLAlchemyError, AttributeError) as e:
|
||||
logger.exception(e)
|
||||
|
||||
return None
|
||||
|
||||
# TODO: Preserve UserLastStopSearchResults table from drop.
|
||||
async def connect(
|
||||
self, settings: DatabaseSettings, clear_static_data: bool = False
|
||||
) -> bool:
|
||||
password = settings.password
|
||||
path = (
|
||||
f"{settings.driver}://{settings.user}:"
|
||||
f"{password.get_secret_value() if password is not None else ''}"
|
||||
f"@{settings.host}:{settings.port}/{settings.name}"
|
||||
)
|
||||
self._async_engine = create_async_engine(
|
||||
path, pool_pre_ping=True, pool_size=10, max_overflow=20
|
||||
)
|
||||
|
||||
if self._async_engine is not None:
|
||||
SQLAlchemyInstrumentor().instrument(engine=self._async_engine.sync_engine)
|
||||
|
||||
self._async_session_local = async_sessionmaker(
|
||||
bind=self._async_engine,
|
||||
# autoflush=False,
|
||||
expire_on_commit=False,
|
||||
class_=AsyncSession,
|
||||
)
|
||||
|
||||
ret = False
|
||||
while not ret:
|
||||
try:
|
||||
async with self._async_engine.begin() as session:
|
||||
if clear_static_data:
|
||||
await session.run_sync(Base.metadata.drop_all)
|
||||
await session.run_sync(Base.metadata.create_all)
|
||||
ret = True
|
||||
except OperationalError as err:
|
||||
logger.error(err)
|
||||
await sleep(1)
|
||||
|
||||
return True
|
||||
|
||||
async def disconnect(self) -> None:
|
||||
if self._async_engine is not None:
|
||||
await self._async_engine.dispose()
|
38
backend/api/dependencies.py
Normal file
38
backend/api/dependencies.py
Normal file
@@ -0,0 +1,38 @@
|
||||
from os import environ
|
||||
|
||||
from fastapi_cache.backends.redis import RedisBackend
|
||||
from redis import asyncio as aioredis
|
||||
from yaml import safe_load
|
||||
|
||||
from db import db
|
||||
from idfm_interface.idfm_interface import IdfmInterface
|
||||
from settings import CacheSettings, Settings
|
||||
|
||||
|
||||
CONFIG_PATH = environ.get("CONFIG_PATH", "./config.sample.yaml")
|
||||
|
||||
|
||||
def load_settings(path: str) -> Settings:
|
||||
with open(path, "r") as config_file:
|
||||
config = safe_load(config_file)
|
||||
|
||||
return Settings(**config)
|
||||
|
||||
|
||||
settings = load_settings(CONFIG_PATH)
|
||||
|
||||
idfm_interface = IdfmInterface(settings.idfm_api_key.get_secret_value(), db)
|
||||
|
||||
|
||||
def init_redis_backend(settings: CacheSettings) -> RedisBackend:
|
||||
login = f"{settings.user}:{settings.password}@" if settings.user is not None else ""
|
||||
|
||||
url = f"redis://{login}{settings.host}:{settings.port}"
|
||||
|
||||
redis_connections_pool = aioredis.from_url(
|
||||
url, encoding="utf8", decode_responses=True
|
||||
)
|
||||
return RedisBackend(redis_connections_pool)
|
||||
|
||||
|
||||
redis_backend = init_redis_backend(settings.cache)
|
67
backend/api/idfm_interface/__init__.py
Normal file
67
backend/api/idfm_interface/__init__.py
Normal file
@@ -0,0 +1,67 @@
|
||||
from .idfm_types import (
|
||||
Coordinate,
|
||||
Destinations,
|
||||
FramedVehicleJourney,
|
||||
IdfmLineState,
|
||||
IdfmOperator,
|
||||
IdfmResponse,
|
||||
IdfmState,
|
||||
LinePicto,
|
||||
LineFields,
|
||||
Line,
|
||||
MonitoredCall,
|
||||
MonitoredVehicleJourney,
|
||||
Point,
|
||||
Siri,
|
||||
ServiceDelivery,
|
||||
Stop,
|
||||
StopArea,
|
||||
StopAreaFields,
|
||||
StopAreaStopAssociation,
|
||||
StopAreaStopAssociationFields,
|
||||
StopAreaType,
|
||||
StopDelivery,
|
||||
StopFields,
|
||||
StopLineAsso,
|
||||
StopLineAssoFields,
|
||||
StopMonitoringDelivery,
|
||||
TrainNumber,
|
||||
TrainStatus,
|
||||
TransportMode,
|
||||
TransportSubMode,
|
||||
Value,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"Coordinate",
|
||||
"Destinations",
|
||||
"FramedVehicleJourney",
|
||||
"IdfmLineState",
|
||||
"IdfmOperator",
|
||||
"IdfmResponse",
|
||||
"IdfmState",
|
||||
"LinePicto",
|
||||
"LineFields",
|
||||
"Line",
|
||||
"MonitoredCall",
|
||||
"MonitoredVehicleJourney",
|
||||
"Point",
|
||||
"Siri",
|
||||
"ServiceDelivery",
|
||||
"Stop",
|
||||
"StopArea",
|
||||
"StopAreaFields",
|
||||
"StopAreaStopAssociation",
|
||||
"StopAreaStopAssociationFields",
|
||||
"StopAreaType",
|
||||
"StopDelivery",
|
||||
"StopFields",
|
||||
"StopLineAsso",
|
||||
"StopLineAssoFields",
|
||||
"StopMonitoringDelivery",
|
||||
"TrainNumber",
|
||||
"TrainStatus",
|
||||
"TransportMode",
|
||||
"TransportSubMode",
|
||||
"Value",
|
||||
]
|
115
backend/api/idfm_interface/idfm_interface.py
Normal file
115
backend/api/idfm_interface/idfm_interface.py
Normal file
@@ -0,0 +1,115 @@
|
||||
from collections import defaultdict
|
||||
from re import compile as re_compile
|
||||
from typing import ByteString
|
||||
|
||||
from aiofiles import open as async_open
|
||||
from aiohttp import ClientSession
|
||||
from msgspec import ValidationError
|
||||
from msgspec.json import Decoder
|
||||
|
||||
from .idfm_types import Destinations as IdfmDestinations, IdfmResponse, IdfmState
|
||||
from db import Database
|
||||
from models import Line, Stop, StopArea
|
||||
|
||||
|
||||
class IdfmInterface:
|
||||
|
||||
IDFM_ROOT_URL = "https://prim.iledefrance-mobilites.fr/marketplace"
|
||||
IDFM_STOP_MON_URL = f"{IDFM_ROOT_URL}/stop-monitoring"
|
||||
|
||||
OPERATOR_RE = re_compile(r"[^:]+:Operator::([^:]+):")
|
||||
LINE_RE = re_compile(r"[^:]+:Line::C([^:]+):")
|
||||
|
||||
def __init__(self, api_key: str, database: Database) -> None:
|
||||
self._api_key = api_key
|
||||
self._database = database
|
||||
|
||||
self._http_headers = {"Accept": "application/json", "apikey": self._api_key}
|
||||
|
||||
self._response_json_decoder = Decoder(type=IdfmResponse)
|
||||
|
||||
async def startup(self) -> None:
|
||||
...
|
||||
|
||||
@staticmethod
|
||||
def _format_line_id(line_id: str) -> int:
|
||||
return int(line_id[1:] if line_id[0] == "C" else line_id)
|
||||
|
||||
async def render_line_picto(self, line: Line) -> tuple[None | str, None | str]:
|
||||
line_picto_path = line_picto_format = None
|
||||
target = f"/tmp/{line.id}_repr"
|
||||
|
||||
picto = line.picto
|
||||
if picto is not None:
|
||||
if (picto_data := await self._get_line_picto(line)) is not None:
|
||||
async with async_open(target, "wb") as fd:
|
||||
await fd.write(bytes(picto_data))
|
||||
line_picto_path = target
|
||||
line_picto_format = picto.mime_type
|
||||
|
||||
return (line_picto_path, line_picto_format)
|
||||
|
||||
async def _get_line_picto(self, line: Line) -> ByteString | None:
|
||||
data = None
|
||||
|
||||
picto = line.picto
|
||||
if picto is not None and picto.url is not None:
|
||||
headers = (
|
||||
self._http_headers if picto.url.startswith(self.IDFM_ROOT_URL) else None
|
||||
)
|
||||
|
||||
async with ClientSession(headers=headers) as session:
|
||||
async with session.get(picto.url) as response:
|
||||
data = await response.read()
|
||||
|
||||
return data
|
||||
|
||||
async def get_next_passages(self, stop_point_id: int) -> IdfmResponse | None:
|
||||
ret = None
|
||||
params = {"MonitoringRef": f"STIF:StopPoint:Q:{stop_point_id}:"}
|
||||
|
||||
async with ClientSession(headers=self._http_headers) as session:
|
||||
async with session.get(self.IDFM_STOP_MON_URL, params=params) as response:
|
||||
|
||||
if response.status == 200:
|
||||
data = await response.read()
|
||||
try:
|
||||
ret = self._response_json_decoder.decode(data)
|
||||
except ValidationError as err:
|
||||
print(err)
|
||||
|
||||
return ret
|
||||
|
||||
async def get_destinations(self, stop_id: int) -> IdfmDestinations | None:
|
||||
destinations: IdfmDestinations = defaultdict(set)
|
||||
|
||||
if (stop := await Stop.get_by_id(stop_id)) is not None:
|
||||
expected_stop_ids = {stop.id}
|
||||
elif (stop_area := await StopArea.get_by_id(stop_id)) is not None:
|
||||
expected_stop_ids = {stop.id for stop in stop_area.stops}
|
||||
else:
|
||||
return None
|
||||
|
||||
if (res := await self.get_next_passages(stop_id)) is not None:
|
||||
|
||||
for delivery in res.Siri.ServiceDelivery.StopMonitoringDelivery:
|
||||
if delivery.Status == IdfmState.true:
|
||||
for stop_visit in delivery.MonitoredStopVisit:
|
||||
|
||||
monitoring_ref = stop_visit.MonitoringRef.value
|
||||
|
||||
try:
|
||||
monitored_stop_id = int(monitoring_ref.split(":")[-2])
|
||||
except (IndexError, ValueError):
|
||||
print(f"Unable to get stop id from {monitoring_ref}")
|
||||
continue
|
||||
|
||||
journey = stop_visit.MonitoredVehicleJourney
|
||||
if (
|
||||
dst_names := journey.DestinationName
|
||||
) and monitored_stop_id in expected_stop_ids:
|
||||
raw_line_id = journey.LineRef.value.split(":")[-2]
|
||||
line_id = IdfmInterface._format_line_id(raw_line_id)
|
||||
destinations[line_id].add(dst_names[0].value)
|
||||
|
||||
return destinations
|
300
backend/api/idfm_interface/idfm_types.py
Normal file
300
backend/api/idfm_interface/idfm_types.py
Normal file
@@ -0,0 +1,300 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from enum import Enum, StrEnum
|
||||
from typing import Any, NamedTuple
|
||||
|
||||
from msgspec import Struct
|
||||
|
||||
|
||||
class Coordinate(NamedTuple):
|
||||
lat: float
|
||||
lon: float
|
||||
|
||||
|
||||
class IdfmState(Enum):
|
||||
unknown = "unknown"
|
||||
false = "false"
|
||||
partial = "partial"
|
||||
true = "true"
|
||||
|
||||
|
||||
class TrainStatus(Enum):
|
||||
unknown = ""
|
||||
arrived = "arrived"
|
||||
onTime = "onTime"
|
||||
delayed = "delayed"
|
||||
noReport = "noReport"
|
||||
early = "early"
|
||||
cancelled = "cancelled"
|
||||
undefined = "undefined"
|
||||
|
||||
|
||||
class TransportMode(StrEnum):
|
||||
bus = "bus"
|
||||
tram = "tram"
|
||||
metro = "metro"
|
||||
rail = "rail"
|
||||
funicular = "funicular"
|
||||
|
||||
|
||||
class TransportSubMode(Enum):
|
||||
unknown = "unknown"
|
||||
|
||||
localBus = "localBus"
|
||||
regionalBus = "regionalBus"
|
||||
highFrequencyBus = "highFrequencyBus"
|
||||
expressBus = "expressBus"
|
||||
nightBus = "nightBus"
|
||||
demandAndResponseBus = "demandAndResponseBus"
|
||||
airportLinkBus = "airportLinkBus"
|
||||
|
||||
regionalRail = "regionalRail"
|
||||
railShuttle = "railShuttle"
|
||||
suburbanRailway = "suburbanRailway"
|
||||
|
||||
local = "local"
|
||||
|
||||
|
||||
class StopFields(Struct, kw_only=True):
|
||||
arrgeopoint: Coordinate
|
||||
arrtown: str
|
||||
arrcreated: None | datetime = None
|
||||
arryepsg2154: int
|
||||
arrpostalregion: str
|
||||
arrid: str
|
||||
arrxepsg2154: int
|
||||
arraccessibility: IdfmState
|
||||
arrvisualsigns: IdfmState
|
||||
arrtype: TransportMode
|
||||
arrname: str
|
||||
arrversion: str
|
||||
arrchanged: datetime
|
||||
arraudiblesignals: IdfmState
|
||||
|
||||
|
||||
class Point(Struct):
|
||||
coordinates: Coordinate
|
||||
|
||||
|
||||
class Stop(Struct):
|
||||
datasetid: str
|
||||
recordid: str
|
||||
fields: StopFields
|
||||
record_timestamp: datetime
|
||||
# geometry: Union[Point]
|
||||
|
||||
|
||||
Stops = dict[str, Stop]
|
||||
|
||||
|
||||
class StopAreaType(StrEnum):
|
||||
metroStation = "metroStation"
|
||||
onstreetBus = "onstreetBus"
|
||||
onstreetTram = "onstreetTram"
|
||||
railStation = "railStation"
|
||||
|
||||
|
||||
class StopAreaFields(Struct, kw_only=True):
|
||||
zdaname: str
|
||||
zdcid: str
|
||||
zdatown: str
|
||||
zdaversion: str
|
||||
zdaid: str
|
||||
zdacreated: datetime | None = None
|
||||
zdatype: StopAreaType
|
||||
zdayepsg2154: int
|
||||
zdapostalregion: str
|
||||
zdachanged: datetime
|
||||
zdaxepsg2154: int
|
||||
|
||||
|
||||
class StopArea(Struct):
|
||||
datasetid: str
|
||||
recordid: str
|
||||
fields: StopAreaFields
|
||||
record_timestamp: datetime
|
||||
|
||||
|
||||
class ConnectionAreaFields(Struct, kw_only=True):
|
||||
zdcid: str
|
||||
zdcversion: str
|
||||
zdccreated: datetime
|
||||
zdcchanged: datetime
|
||||
zdcname: str
|
||||
zdcxepsg2154: int | None = None
|
||||
zdcyepsg2154: int | None = None
|
||||
zdctown: str
|
||||
zdcpostalregion: str
|
||||
zdctype: StopAreaType
|
||||
|
||||
|
||||
class ConnectionArea(Struct):
|
||||
datasetid: str
|
||||
recordid: str
|
||||
fields: ConnectionAreaFields
|
||||
record_timestamp: datetime
|
||||
|
||||
|
||||
class StopAreaStopAssociationFields(Struct, kw_only=True):
|
||||
arrid: str # TODO: use int ?
|
||||
artid: str | None = None
|
||||
arrversion: str
|
||||
zdcid: str
|
||||
version: int
|
||||
zdaid: str
|
||||
zdaversion: str
|
||||
artversion: str | None = None
|
||||
|
||||
|
||||
class StopAreaStopAssociation(Struct):
|
||||
datasetid: str
|
||||
recordid: str
|
||||
fields: StopAreaStopAssociationFields
|
||||
record_timestamp: datetime
|
||||
|
||||
|
||||
class IdfmLineState(Enum):
|
||||
active = "active"
|
||||
available_soon = "prochainement active"
|
||||
|
||||
|
||||
class LinePicto(Struct, rename={"id_": "id"}):
|
||||
id_: str
|
||||
mimetype: str
|
||||
height: int
|
||||
width: int
|
||||
filename: str
|
||||
thumbnail: bool
|
||||
format: str
|
||||
# color_summary: list[str]
|
||||
|
||||
|
||||
class LineFields(Struct, kw_only=True):
|
||||
name_line: str
|
||||
status: IdfmLineState
|
||||
accessibility: IdfmState
|
||||
shortname_groupoflines: str | None = None
|
||||
transportmode: TransportMode
|
||||
colourweb_hexa: str
|
||||
textcolourprint_hexa: str
|
||||
transportsubmode: TransportSubMode | None = TransportSubMode.unknown
|
||||
operatorref: str | None = None
|
||||
visualsigns_available: IdfmState
|
||||
networkname: str | None = None
|
||||
id_line: str
|
||||
id_groupoflines: str | None = None
|
||||
operatorname: str | None = None
|
||||
audiblesigns_available: IdfmState
|
||||
shortname_line: str
|
||||
picto: LinePicto | None = None
|
||||
|
||||
|
||||
class Line(Struct):
|
||||
datasetid: str
|
||||
recordid: str
|
||||
fields: LineFields
|
||||
record_timestamp: datetime
|
||||
|
||||
@staticmethod
|
||||
def optional_value(value: Any) -> Any:
|
||||
if value:
|
||||
return value.value if isinstance(value, Enum) else value
|
||||
return "NULL"
|
||||
|
||||
|
||||
Lines = dict[str, Line]
|
||||
|
||||
Destinations = dict[str, set[str]]
|
||||
|
||||
|
||||
# TODO: Set structs frozen
|
||||
class StopLineAssoFields(Struct):
|
||||
pointgeo: Coordinate
|
||||
stop_id: str
|
||||
stop_name: str
|
||||
operatorname: str
|
||||
nom_commune: str
|
||||
route_long_name: str
|
||||
id: str
|
||||
stop_lat: str
|
||||
stop_lon: str
|
||||
code_insee: str
|
||||
|
||||
|
||||
class StopLineAsso(Struct):
|
||||
datasetid: str
|
||||
recordid: str
|
||||
fields: StopLineAssoFields
|
||||
# geometry: Union[Point]
|
||||
|
||||
|
||||
class Value(Struct):
|
||||
value: str
|
||||
|
||||
|
||||
class FramedVehicleJourney(Struct):
|
||||
DataFrameRef: Value
|
||||
DatedVehicleJourneyRef: str
|
||||
|
||||
|
||||
class TrainNumber(Struct):
|
||||
TrainNumberRef: list[Value]
|
||||
|
||||
|
||||
class MonitoredCall(Struct, kw_only=True):
|
||||
Order: int | None = None
|
||||
StopPointName: list[Value]
|
||||
VehicleAtStop: bool
|
||||
DestinationDisplay: list[Value]
|
||||
AimedArrivalTime: datetime | None = None
|
||||
ExpectedArrivalTime: datetime | None = None
|
||||
ArrivalPlatformName: Value | None = None
|
||||
AimedDepartureTime: datetime | None = None
|
||||
ExpectedDepartureTime: datetime | None = None
|
||||
ArrivalStatus: TrainStatus | None = None
|
||||
DepartureStatus: TrainStatus | None = None
|
||||
|
||||
|
||||
class MonitoredVehicleJourney(Struct, kw_only=True):
|
||||
LineRef: Value
|
||||
OperatorRef: Value
|
||||
FramedVehicleJourneyRef: FramedVehicleJourney
|
||||
DestinationRef: Value
|
||||
DestinationName: list[Value] | None = None
|
||||
JourneyNote: list[Value] | None = None
|
||||
TrainNumbers: TrainNumber | None = None
|
||||
MonitoredCall: MonitoredCall
|
||||
|
||||
|
||||
class StopDelivery(Struct):
|
||||
RecordedAtTime: datetime
|
||||
ItemIdentifier: str
|
||||
MonitoringRef: Value
|
||||
MonitoredVehicleJourney: MonitoredVehicleJourney
|
||||
|
||||
|
||||
class StopMonitoringDelivery(Struct):
|
||||
ResponseTimestamp: datetime
|
||||
Version: str
|
||||
Status: IdfmState
|
||||
MonitoredStopVisit: list[StopDelivery]
|
||||
|
||||
|
||||
class ServiceDelivery(Struct):
|
||||
ResponseTimestamp: datetime
|
||||
ProducerRef: str
|
||||
ResponseMessageIdentifier: str
|
||||
StopMonitoringDelivery: list[StopMonitoringDelivery]
|
||||
|
||||
|
||||
class Siri(Struct):
|
||||
ServiceDelivery: ServiceDelivery
|
||||
|
||||
|
||||
class IdfmOperator(Enum):
|
||||
SNCF = "SNCF"
|
||||
|
||||
|
||||
class IdfmResponse(Struct):
|
||||
Siri: Siri
|
15
backend/api/idfm_interface/ratp_types.py
Normal file
15
backend/api/idfm_interface/ratp_types.py
Normal file
@@ -0,0 +1,15 @@
|
||||
from msgspec import Struct
|
||||
|
||||
|
||||
class PictoFieldsFile(Struct, rename={"id_": "id"}):
|
||||
id_: str
|
||||
height: int
|
||||
width: int
|
||||
filename: str
|
||||
thumbnail: bool
|
||||
format: str
|
||||
|
||||
|
||||
class Picto(Struct):
|
||||
indices_commerciaux: str
|
||||
noms_des_fichiers: PictoFieldsFile | None = None
|
89
backend/api/main.py
Executable file
89
backend/api/main.py
Executable file
@@ -0,0 +1,89 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
import uvicorn
|
||||
from contextlib import asynccontextmanager
|
||||
from fastapi import FastAPI, Request
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi_cache import FastAPICache
|
||||
from opentelemetry import trace
|
||||
from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter
|
||||
from opentelemetry.instrumentation.fastapi import FastAPIInstrumentor
|
||||
from opentelemetry.sdk.resources import Resource, SERVICE_NAME
|
||||
from opentelemetry.sdk.trace import TracerProvider
|
||||
from opentelemetry.sdk.trace.export import BatchSpanProcessor
|
||||
|
||||
from db import db
|
||||
from dependencies import idfm_interface, redis_backend, settings
|
||||
from routers import line, stop
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
FastAPICache.init(redis_backend, prefix="api", enable=settings.cache.enable)
|
||||
|
||||
await db.connect(settings.db, settings.clear_static_data)
|
||||
if settings.clear_static_data:
|
||||
await idfm_interface.startup()
|
||||
|
||||
yield
|
||||
|
||||
await db.disconnect()
|
||||
|
||||
|
||||
app = FastAPI(lifespan=lifespan)
|
||||
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["http://carrramba.adrien.run", "https://carrramba.adrien.run"],
|
||||
allow_credentials=True,
|
||||
allow_methods=["OPTIONS", "GET"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
# The cache-control header entry is not managed properly by fastapi-cache:
|
||||
# For now, a request with a cache-control set to no-cache
|
||||
# is interpreted as disabling the use of the server cache.
|
||||
# Cf. Improve Cache-Control header parsing and handling
|
||||
# https://github.com/long2ice/fastapi-cache/issues/144 workaround
|
||||
@app.middleware("http")
|
||||
async def fastapi_cache_issue_144_workaround(request: Request, call_next):
|
||||
entries = request.headers.__dict__["_list"]
|
||||
new_entries = [
|
||||
entry for entry in entries if entry[0].decode().lower() != "cache-control"
|
||||
]
|
||||
|
||||
request.headers.__dict__["_list"] = new_entries
|
||||
|
||||
return await call_next(request)
|
||||
|
||||
|
||||
app.include_router(line.router)
|
||||
app.include_router(stop.router)
|
||||
|
||||
|
||||
if settings.tracing.enable:
|
||||
FastAPIInstrumentor.instrument_app(app)
|
||||
|
||||
trace.set_tracer_provider(
|
||||
TracerProvider(resource=Resource.create({SERVICE_NAME: settings.app_name}))
|
||||
)
|
||||
trace.get_tracer_provider().add_span_processor(
|
||||
BatchSpanProcessor(OTLPSpanExporter())
|
||||
)
|
||||
tracer = trace.get_tracer(settings.app_name)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
http_settings = settings.http
|
||||
|
||||
config = uvicorn.Config(
|
||||
app=app,
|
||||
host=http_settings.host,
|
||||
port=http_settings.port,
|
||||
ssl_certfile=http_settings.cert,
|
||||
proxy_headers=True,
|
||||
)
|
||||
|
||||
server = uvicorn.Server(config)
|
||||
|
||||
server.run()
|
14
backend/api/models/__init__.py
Normal file
14
backend/api/models/__init__.py
Normal file
@@ -0,0 +1,14 @@
|
||||
from .line import Line, LinePicto
|
||||
from .stop import ConnectionArea, Stop, StopArea, StopShape
|
||||
from .user import UserLastStopSearchResults
|
||||
|
||||
|
||||
__all__ = [
|
||||
"ConnectionArea",
|
||||
"Line",
|
||||
"LinePicto",
|
||||
"Stop",
|
||||
"StopArea",
|
||||
"StopShape",
|
||||
"UserLastStopSearchResults",
|
||||
]
|
196
backend/api/models/line.py
Normal file
196
backend/api/models/line.py
Normal file
@@ -0,0 +1,196 @@
|
||||
from asyncio import gather as asyncio_gather
|
||||
from collections import defaultdict
|
||||
from typing import Iterable, Self, Sequence
|
||||
|
||||
from sqlalchemy import (
|
||||
BigInteger,
|
||||
Boolean,
|
||||
Enum,
|
||||
ForeignKey,
|
||||
Integer,
|
||||
select,
|
||||
String,
|
||||
)
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship, selectinload
|
||||
from sqlalchemy.sql.expression import tuple_
|
||||
|
||||
from db import Base, db
|
||||
from idfm_interface.idfm_types import (
|
||||
IdfmState,
|
||||
IdfmLineState,
|
||||
TransportMode,
|
||||
TransportSubMode,
|
||||
)
|
||||
from .stop import _Stop
|
||||
|
||||
|
||||
class LineStopAssociations(Base):
|
||||
|
||||
id = mapped_column(BigInteger, primary_key=True)
|
||||
line_id = mapped_column(BigInteger, ForeignKey("lines.id"))
|
||||
stop_id = mapped_column(BigInteger, ForeignKey("_stops.id"))
|
||||
|
||||
__tablename__ = "line_stop_associations"
|
||||
|
||||
|
||||
class LinePicto(Base):
|
||||
|
||||
db = db
|
||||
|
||||
id = mapped_column(String, primary_key=True)
|
||||
mime_type = mapped_column(String, nullable=False)
|
||||
height_px = mapped_column(Integer, nullable=False)
|
||||
width_px = mapped_column(Integer, nullable=False)
|
||||
filename = mapped_column(String, nullable=False)
|
||||
url = mapped_column(String, nullable=False)
|
||||
thumbnail = mapped_column(Boolean, nullable=False)
|
||||
format = mapped_column(String, nullable=False)
|
||||
|
||||
__tablename__ = "line_pictos"
|
||||
|
||||
|
||||
class Line(Base):
|
||||
|
||||
db = db
|
||||
|
||||
id = mapped_column(BigInteger, primary_key=True)
|
||||
|
||||
short_name = mapped_column(String)
|
||||
name = mapped_column(String, nullable=False)
|
||||
status = mapped_column(Enum(IdfmLineState), nullable=False)
|
||||
transport_mode = mapped_column(Enum(TransportMode), nullable=False)
|
||||
transport_submode = mapped_column(Enum(TransportSubMode), nullable=False)
|
||||
|
||||
network_name = mapped_column(String)
|
||||
group_of_lines_id = mapped_column(String)
|
||||
group_of_lines_shortname = mapped_column(String)
|
||||
|
||||
colour_web_hexa = mapped_column(String, nullable=False)
|
||||
text_colour_hexa = mapped_column(String, nullable=False)
|
||||
|
||||
operator_id = mapped_column(Integer)
|
||||
operator_name = mapped_column(String)
|
||||
|
||||
accessibility = mapped_column(Enum(IdfmState), nullable=False)
|
||||
visual_signs_available = mapped_column(Enum(IdfmState), nullable=False)
|
||||
audible_signs_available = mapped_column(Enum(IdfmState), nullable=False)
|
||||
|
||||
picto_id = mapped_column(String, ForeignKey("line_pictos.id"))
|
||||
picto: Mapped[LinePicto] = relationship(LinePicto, lazy="selectin")
|
||||
|
||||
record_id = mapped_column(String, nullable=False)
|
||||
record_ts = mapped_column(BigInteger, nullable=False)
|
||||
|
||||
stops: Mapped[list[_Stop]] = relationship(
|
||||
"_Stop",
|
||||
secondary="line_stop_associations",
|
||||
back_populates="lines",
|
||||
lazy="selectin",
|
||||
)
|
||||
|
||||
__tablename__ = "lines"
|
||||
|
||||
@classmethod
|
||||
async def get_by_name(
|
||||
cls, name: str, operator_name: None | str = None
|
||||
) -> Sequence[Self] | None:
|
||||
if (session := await cls.db.get_session()) is not None:
|
||||
|
||||
async with session.begin():
|
||||
filters = {"name": name}
|
||||
if operator_name is not None:
|
||||
filters["operator_name"] = operator_name
|
||||
|
||||
stmt = (
|
||||
select(cls)
|
||||
.filter_by(**filters)
|
||||
.options(selectinload(cls.stops), selectinload(cls.picto))
|
||||
)
|
||||
res = await session.execute(stmt)
|
||||
lines = res.scalars().all()
|
||||
|
||||
return lines
|
||||
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
async def _add_picto_to_line(cls, line: str | Self, picto: LinePicto) -> None:
|
||||
formatted_line: Self | None = None
|
||||
if isinstance(line, str):
|
||||
if (lines := await cls.get_by_name(line)) is not None:
|
||||
if len(lines) == 1:
|
||||
formatted_line = lines[0]
|
||||
else:
|
||||
for candidate_line in lines:
|
||||
if candidate_line.operator_name == "RATP":
|
||||
formatted_line = candidate_line
|
||||
break
|
||||
else:
|
||||
formatted_line = line
|
||||
|
||||
if isinstance(formatted_line, Line) and formatted_line.picto is None:
|
||||
formatted_line.picto = picto
|
||||
formatted_line.picto_id = picto.id
|
||||
|
||||
@classmethod
|
||||
async def add_pictos(cls, line_to_pictos: Iterable[tuple[str, LinePicto]]) -> bool:
|
||||
if (session := await cls.db.get_session()) is not None:
|
||||
|
||||
async with session.begin():
|
||||
await asyncio_gather(
|
||||
*[
|
||||
cls._add_picto_to_line(line, picto)
|
||||
for line, picto in line_to_pictos
|
||||
]
|
||||
)
|
||||
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
async def add_stops(cls, line_to_stop_ids: Iterable[tuple[str, str, int]]) -> int:
|
||||
if (session := await cls.db.get_session()) is not None:
|
||||
|
||||
async with session.begin():
|
||||
|
||||
line_names_ops, stop_ids = set(), set()
|
||||
for line_name, operator_name, stop_id in line_to_stop_ids:
|
||||
line_names_ops.add((line_name, operator_name))
|
||||
stop_ids.add(stop_id)
|
||||
|
||||
lines_res = await session.execute(
|
||||
select(Line).where(
|
||||
tuple_(Line.name, Line.operator_name).in_(line_names_ops)
|
||||
)
|
||||
)
|
||||
|
||||
lines = defaultdict(list)
|
||||
for line in lines_res.scalars():
|
||||
lines[(line.name, line.operator_name)].append(line)
|
||||
|
||||
stops_res = await session.execute(
|
||||
select(_Stop).where(_Stop.id.in_(stop_ids))
|
||||
)
|
||||
stops = {stop.id: stop for stop in stops_res.scalars()}
|
||||
|
||||
found = 0
|
||||
for line_name, operator_name, stop_id in line_to_stop_ids:
|
||||
if (stop := stops.get(stop_id)) is not None:
|
||||
if (
|
||||
stop_lines := lines.get((line_name, operator_name))
|
||||
) is not None:
|
||||
for stop_line in stop_lines:
|
||||
stop_line.stops.append(stop)
|
||||
found += 1
|
||||
else:
|
||||
print(f"No line found for {line_name}/{operator_name}")
|
||||
else:
|
||||
print(
|
||||
f"No stop found for {stop_id} id"
|
||||
f"(used by {line_name}/{operator_name})"
|
||||
)
|
||||
|
||||
return found
|
||||
|
||||
return 0
|
275
backend/api/models/stop.py
Normal file
275
backend/api/models/stop.py
Normal file
@@ -0,0 +1,275 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from logging import getLogger
|
||||
from typing import Iterable, Sequence, TYPE_CHECKING
|
||||
|
||||
from sqlalchemy import (
|
||||
BigInteger,
|
||||
Computed,
|
||||
desc,
|
||||
Enum,
|
||||
Float,
|
||||
ForeignKey,
|
||||
func,
|
||||
Integer,
|
||||
JSON,
|
||||
select,
|
||||
String,
|
||||
)
|
||||
from sqlalchemy.orm import (
|
||||
mapped_column,
|
||||
Mapped,
|
||||
relationship,
|
||||
selectinload,
|
||||
with_polymorphic,
|
||||
)
|
||||
from sqlalchemy.schema import Index
|
||||
from sqlalchemy_utils.types.ts_vector import TSVectorType
|
||||
|
||||
from db import Base, db
|
||||
from idfm_interface.idfm_types import TransportMode, IdfmState, StopAreaType
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .line import Line
|
||||
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
class StopAreaStopAssociations(Base):
|
||||
|
||||
id = mapped_column(BigInteger, primary_key=True)
|
||||
stop_id = mapped_column(BigInteger, ForeignKey("_stops.id"))
|
||||
stop_area_id = mapped_column(BigInteger, ForeignKey("stop_areas.id"))
|
||||
|
||||
__tablename__ = "stop_area_stop_associations"
|
||||
|
||||
|
||||
class _Stop(Base):
|
||||
|
||||
db = db
|
||||
|
||||
id = mapped_column(BigInteger, primary_key=True)
|
||||
kind = mapped_column(String)
|
||||
|
||||
name = mapped_column(String, nullable=False, index=True)
|
||||
town_name = mapped_column(String, nullable=False)
|
||||
postal_region = mapped_column(Integer, nullable=False)
|
||||
epsg3857_x = mapped_column(Float, nullable=False)
|
||||
epsg3857_y = mapped_column(Float, nullable=False)
|
||||
|
||||
version = mapped_column(String, nullable=False)
|
||||
created_ts = mapped_column(BigInteger)
|
||||
changed_ts = mapped_column(BigInteger, nullable=False)
|
||||
|
||||
lines: Mapped[list[Line]] = relationship(
|
||||
"Line",
|
||||
secondary="line_stop_associations",
|
||||
back_populates="stops",
|
||||
lazy="selectin",
|
||||
)
|
||||
areas: Mapped[list["StopArea"]] = relationship(
|
||||
"StopArea",
|
||||
secondary="stop_area_stop_associations",
|
||||
back_populates="stops",
|
||||
)
|
||||
connection_area_id: Mapped[int] = mapped_column(
|
||||
ForeignKey("connection_areas.id"), nullable=True
|
||||
)
|
||||
connection_area: Mapped["ConnectionArea"] = relationship(
|
||||
back_populates="stops", lazy="selectin"
|
||||
)
|
||||
|
||||
names_tsv = mapped_column(
|
||||
TSVectorType("name", "town_name", regconfig="french"),
|
||||
Computed("to_tsvector('french', name || ' ' || town_name)", persisted=True),
|
||||
)
|
||||
|
||||
__tablename__ = "_stops"
|
||||
__mapper_args__ = {"polymorphic_identity": "_stops", "polymorphic_on": kind}
|
||||
__table_args__ = (
|
||||
Index(
|
||||
"names_tsv_idx",
|
||||
names_tsv,
|
||||
postgresql_ops={"name": "gin_trgm_ops"},
|
||||
postgresql_using="gin",
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def get_by_name(cls, name: str) -> Sequence[_Stop] | None:
|
||||
if (session := await cls.db.get_session()) is not None:
|
||||
|
||||
async with session.begin():
|
||||
descendants = with_polymorphic(_Stop, "*")
|
||||
|
||||
match_stmt = descendants.names_tsv.match(
|
||||
name, postgresql_regconfig="french"
|
||||
)
|
||||
ranking_stmt = func.ts_rank_cd(
|
||||
descendants.names_tsv, func.plainto_tsquery("french", name)
|
||||
)
|
||||
stmt = (
|
||||
select(descendants).filter(match_stmt).order_by(desc(ranking_stmt))
|
||||
)
|
||||
|
||||
res = await session.execute(stmt)
|
||||
stops = res.scalars().all()
|
||||
|
||||
return stops
|
||||
|
||||
return None
|
||||
|
||||
|
||||
class Stop(_Stop):
|
||||
|
||||
id = mapped_column(BigInteger, ForeignKey("_stops.id"), primary_key=True)
|
||||
|
||||
transport_mode = mapped_column(Enum(TransportMode), nullable=False)
|
||||
accessibility = mapped_column(Enum(IdfmState), nullable=False)
|
||||
visual_signs_available = mapped_column(Enum(IdfmState), nullable=False)
|
||||
audible_signs_available = mapped_column(Enum(IdfmState), nullable=False)
|
||||
|
||||
record_id = mapped_column(String, nullable=False)
|
||||
record_ts = mapped_column(BigInteger, nullable=False)
|
||||
|
||||
__tablename__ = "stops"
|
||||
__mapper_args__ = {"polymorphic_identity": "stops", "polymorphic_load": "inline"}
|
||||
|
||||
|
||||
class StopArea(_Stop):
|
||||
|
||||
id = mapped_column(BigInteger, ForeignKey("_stops.id"), primary_key=True)
|
||||
|
||||
type = mapped_column(Enum(StopAreaType), nullable=False)
|
||||
|
||||
stops: Mapped[list["Stop"]] = relationship(
|
||||
"Stop",
|
||||
secondary="stop_area_stop_associations",
|
||||
back_populates="areas",
|
||||
lazy="selectin",
|
||||
)
|
||||
|
||||
__tablename__ = "stop_areas"
|
||||
__mapper_args__ = {
|
||||
"polymorphic_identity": "stop_areas",
|
||||
"polymorphic_load": "inline",
|
||||
}
|
||||
|
||||
@classmethod
|
||||
async def add_stops(
|
||||
cls, stop_area_to_stop_ids: Iterable[tuple[int, int]]
|
||||
) -> int | None:
|
||||
if (session := await cls.db.get_session()) is not None:
|
||||
|
||||
async with session.begin():
|
||||
|
||||
stop_area_ids, stop_ids = set(), set()
|
||||
for stop_area_id, stop_id in stop_area_to_stop_ids:
|
||||
stop_area_ids.add(stop_area_id)
|
||||
stop_ids.add(stop_id)
|
||||
|
||||
stop_areas_res = await session.scalars(
|
||||
select(StopArea)
|
||||
.where(StopArea.id.in_(stop_area_ids))
|
||||
.options(selectinload(StopArea.stops))
|
||||
)
|
||||
stop_areas: dict[int, StopArea] = {
|
||||
stop_area.id: stop_area for stop_area in stop_areas_res.all()
|
||||
}
|
||||
|
||||
stop_res = await session.execute(
|
||||
select(Stop).where(Stop.id.in_(stop_ids))
|
||||
)
|
||||
stops: dict[int, Stop] = {stop.id: stop for stop in stop_res.scalars()}
|
||||
|
||||
found = 0
|
||||
for stop_area_id, stop_id in stop_area_to_stop_ids:
|
||||
if (stop_area := stop_areas.get(stop_area_id)) is not None:
|
||||
if (stop := stops.get(stop_id)) is not None:
|
||||
stop_area.stops.append(stop)
|
||||
found += 1
|
||||
else:
|
||||
print(f"No stop found for {stop_id} id")
|
||||
else:
|
||||
print(f"No stop area found for {stop_area_id}")
|
||||
|
||||
return found
|
||||
|
||||
return None
|
||||
|
||||
|
||||
class StopShape(Base):
|
||||
|
||||
db = db
|
||||
|
||||
id = mapped_column(BigInteger, primary_key=True) # Same id than ConnectionArea
|
||||
type = mapped_column(Integer, nullable=False)
|
||||
epsg3857_bbox = mapped_column(JSON)
|
||||
epsg3857_points = mapped_column(JSON)
|
||||
|
||||
__tablename__ = "stop_shapes"
|
||||
|
||||
|
||||
class ConnectionArea(Base):
|
||||
|
||||
db = db
|
||||
|
||||
id = mapped_column(BigInteger, primary_key=True)
|
||||
|
||||
name = mapped_column(String, nullable=False)
|
||||
town_name = mapped_column(String, nullable=False)
|
||||
postal_region = mapped_column(String, nullable=False)
|
||||
epsg3857_x = mapped_column(Float, nullable=False)
|
||||
epsg3857_y = mapped_column(Float, nullable=False)
|
||||
transport_mode = mapped_column(Enum(StopAreaType), nullable=False)
|
||||
|
||||
version = mapped_column(String, nullable=False)
|
||||
created_ts = mapped_column(BigInteger)
|
||||
changed_ts = mapped_column(BigInteger, nullable=False)
|
||||
|
||||
stops: Mapped[list["_Stop"]] = relationship(back_populates="connection_area")
|
||||
|
||||
__tablename__ = "connection_areas"
|
||||
|
||||
# TODO: Merge with StopArea.add_stops
|
||||
@classmethod
|
||||
async def add_stops(
|
||||
cls, conn_area_to_stop_ids: Iterable[tuple[int, int]]
|
||||
) -> int | None:
|
||||
if (session := await cls.db.get_session()) is not None:
|
||||
|
||||
async with session.begin():
|
||||
|
||||
conn_area_ids, stop_ids = set(), set()
|
||||
for conn_area_id, stop_id in conn_area_to_stop_ids:
|
||||
conn_area_ids.add(conn_area_id)
|
||||
stop_ids.add(stop_id)
|
||||
|
||||
conn_area_res = await session.execute(
|
||||
select(ConnectionArea)
|
||||
.where(ConnectionArea.id.in_(conn_area_ids))
|
||||
.options(selectinload(ConnectionArea.stops))
|
||||
)
|
||||
conn_areas: dict[int, ConnectionArea] = {
|
||||
conn.id: conn for conn in conn_area_res.scalars()
|
||||
}
|
||||
|
||||
stop_res = await session.execute(
|
||||
select(Stop).where(Stop.id.in_(stop_ids))
|
||||
)
|
||||
stops: dict[int, Stop] = {stop.id: stop for stop in stop_res.scalars()}
|
||||
|
||||
found = 0
|
||||
for conn_area_id, stop_id in conn_area_to_stop_ids:
|
||||
if (conn_area := conn_areas.get(conn_area_id)) is not None:
|
||||
if (stop := stops.get(stop_id)) is not None:
|
||||
conn_area.stops.append(stop)
|
||||
found += 1
|
||||
else:
|
||||
print(f"No stop found for {stop_id} id")
|
||||
else:
|
||||
print(f"No connection area found for {conn_area_id}")
|
||||
|
||||
return found
|
||||
|
||||
return None
|
27
backend/api/models/user.py
Normal file
27
backend/api/models/user.py
Normal file
@@ -0,0 +1,27 @@
|
||||
from sqlalchemy import BigInteger, ForeignKey, String
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from db import Base, db
|
||||
from .stop import _Stop
|
||||
|
||||
|
||||
class UserLastStopSearchStopAssociations(Base):
|
||||
id = mapped_column(BigInteger, primary_key=True)
|
||||
user_mxid = mapped_column(
|
||||
String, ForeignKey("user_last_stop_search_results.user_mxid")
|
||||
)
|
||||
stop_id = mapped_column(BigInteger, ForeignKey("_stops.id"))
|
||||
|
||||
__tablename__ = "user_last_stop_search_stop_associations"
|
||||
|
||||
|
||||
class UserLastStopSearchResults(Base):
|
||||
db = db
|
||||
|
||||
user_mxid = mapped_column(String, primary_key=True)
|
||||
request_content = mapped_column(String, nullable=False)
|
||||
stops: Mapped[_Stop] = relationship(
|
||||
_Stop, secondary="user_last_stop_search_stop_associations"
|
||||
)
|
||||
|
||||
__tablename__ = "user_last_stop_search_results"
|
0
backend/api/py.typed
Normal file
0
backend/api/py.typed
Normal file
0
backend/api/routers/__init__.py
Normal file
0
backend/api/routers/__init__.py
Normal file
34
backend/api/routers/line.py
Normal file
34
backend/api/routers/line.py
Normal file
@@ -0,0 +1,34 @@
|
||||
from fastapi import APIRouter, HTTPException
|
||||
from fastapi_cache.decorator import cache
|
||||
|
||||
from models import Line
|
||||
from schemas import Line as LineSchema, TransportMode
|
||||
|
||||
|
||||
router = APIRouter(prefix="/line", tags=["line"])
|
||||
|
||||
|
||||
@router.get("/{line_id}", response_model=LineSchema)
|
||||
@cache(namespace="line")
|
||||
async def get_line(line_id: int) -> LineSchema:
|
||||
line: Line | None = await Line.get_by_id(line_id)
|
||||
|
||||
if line is None:
|
||||
raise HTTPException(status_code=404, detail=f'Line "{line_id}" not found')
|
||||
|
||||
return LineSchema(
|
||||
id=line.id,
|
||||
shortName=line.short_name,
|
||||
name=line.name,
|
||||
status=line.status,
|
||||
transportMode=TransportMode.from_idfm_transport_mode(
|
||||
line.transport_mode, line.transport_submode
|
||||
),
|
||||
backColorHexa=line.colour_web_hexa,
|
||||
foreColorHexa=line.text_colour_hexa,
|
||||
operatorId=line.operator_id,
|
||||
accessibility=line.accessibility,
|
||||
visualSignsAvailable=line.visual_signs_available,
|
||||
audibleSignsAvailable=line.audible_signs_available,
|
||||
stopIds=[stop.id for stop in line.stops],
|
||||
)
|
176
backend/api/routers/stop.py
Normal file
176
backend/api/routers/stop.py
Normal file
@@ -0,0 +1,176 @@
|
||||
from collections import defaultdict
|
||||
from datetime import datetime
|
||||
from typing import Sequence
|
||||
|
||||
from fastapi import APIRouter, HTTPException
|
||||
from fastapi_cache.decorator import cache
|
||||
|
||||
from idfm_interface import Destinations as IdfmDestinations, TrainStatus
|
||||
from models import Stop, StopArea, StopShape
|
||||
from schemas import (
|
||||
NextPassage as NextPassageSchema,
|
||||
NextPassages as NextPassagesSchema,
|
||||
Stop as StopSchema,
|
||||
StopArea as StopAreaSchema,
|
||||
StopShape as StopShapeSchema,
|
||||
)
|
||||
from dependencies import idfm_interface
|
||||
|
||||
|
||||
router = APIRouter(prefix="/stop", tags=["stop"])
|
||||
|
||||
|
||||
def _format_stop(stop: Stop) -> StopSchema:
|
||||
return StopSchema(
|
||||
id=stop.id,
|
||||
name=stop.name,
|
||||
town=stop.town_name,
|
||||
epsg3857_x=stop.epsg3857_x,
|
||||
epsg3857_y=stop.epsg3857_y,
|
||||
lines=[line.id for line in stop.lines],
|
||||
)
|
||||
|
||||
|
||||
def optional_datetime_to_ts(dt: datetime | None) -> int | None:
|
||||
return int(dt.timestamp()) if dt else None
|
||||
|
||||
|
||||
# TODO: Add limit support
|
||||
@router.get("/")
|
||||
@cache(namespace="stop")
|
||||
async def get_stop(
|
||||
name: str = "", limit: int = 10
|
||||
) -> Sequence[StopAreaSchema | StopSchema] | None:
|
||||
|
||||
matching_stops = await Stop.get_by_name(name)
|
||||
if matching_stops is None:
|
||||
return None
|
||||
|
||||
formatted: list[StopAreaSchema | StopSchema] = []
|
||||
stop_areas: dict[int, StopArea] = {}
|
||||
stops: dict[int, Stop] = {}
|
||||
for stop in matching_stops:
|
||||
if isinstance(stop, StopArea):
|
||||
stop_areas[stop.id] = stop
|
||||
elif isinstance(stop, Stop):
|
||||
stops[stop.id] = stop
|
||||
|
||||
for stop_area in stop_areas.values():
|
||||
|
||||
formatted_stops = []
|
||||
for stop in stop_area.stops:
|
||||
formatted_stops.append(_format_stop(stop))
|
||||
try:
|
||||
del stops[stop.id]
|
||||
except KeyError as err:
|
||||
print(err)
|
||||
|
||||
formatted.append(
|
||||
StopAreaSchema(
|
||||
id=stop_area.id,
|
||||
name=stop_area.name,
|
||||
town=stop_area.town_name,
|
||||
type=stop_area.type,
|
||||
lines=[line.id for line in stop_area.lines],
|
||||
stops=formatted_stops,
|
||||
)
|
||||
)
|
||||
|
||||
formatted.extend(_format_stop(stop) for stop in stops.values())
|
||||
|
||||
return formatted
|
||||
|
||||
|
||||
@router.get("/{stop_id}/nextPassages")
|
||||
@cache(namespace="stop-nextPassages", expire=30)
|
||||
async def get_next_passages(stop_id: int) -> NextPassagesSchema | None:
|
||||
res = await idfm_interface.get_next_passages(stop_id)
|
||||
if res is None:
|
||||
return None
|
||||
|
||||
service_delivery = res.Siri.ServiceDelivery
|
||||
stop_monitoring_deliveries = service_delivery.StopMonitoringDelivery
|
||||
|
||||
by_line_by_dst_passages: dict[
|
||||
int, dict[str, list[NextPassageSchema]]
|
||||
] = defaultdict(lambda: defaultdict(list))
|
||||
|
||||
for delivery in stop_monitoring_deliveries:
|
||||
for stop_visit in delivery.MonitoredStopVisit:
|
||||
|
||||
journey = stop_visit.MonitoredVehicleJourney
|
||||
|
||||
# re.match will return None if the given journey.LineRef.value is not valid.
|
||||
try:
|
||||
line_id_match = idfm_interface.LINE_RE.match(journey.LineRef.value)
|
||||
line_id = int(line_id_match.group(1)) # type: ignore
|
||||
except (AttributeError, TypeError, ValueError) as err:
|
||||
raise HTTPException(
|
||||
status_code=404, detail=f'Line "{journey.LineRef.value}" not found'
|
||||
) from err
|
||||
|
||||
call = journey.MonitoredCall
|
||||
|
||||
dst_names = call.DestinationDisplay
|
||||
dsts = [dst.value for dst in dst_names] if dst_names else []
|
||||
arrivalPlatformName = (
|
||||
call.ArrivalPlatformName.value if call.ArrivalPlatformName else None
|
||||
)
|
||||
|
||||
next_passage = NextPassageSchema(
|
||||
line=line_id,
|
||||
operator=journey.OperatorRef.value,
|
||||
destinations=dsts,
|
||||
atStop=call.VehicleAtStop,
|
||||
aimedArrivalTs=optional_datetime_to_ts(call.AimedArrivalTime),
|
||||
expectedArrivalTs=optional_datetime_to_ts(call.ExpectedArrivalTime),
|
||||
arrivalPlatformName=arrivalPlatformName,
|
||||
aimedDepartTs=optional_datetime_to_ts(call.AimedDepartureTime),
|
||||
expectedDepartTs=optional_datetime_to_ts(call.ExpectedDepartureTime),
|
||||
arrivalStatus=call.ArrivalStatus
|
||||
if call.ArrivalStatus is not None
|
||||
else TrainStatus.unknown,
|
||||
departStatus=call.DepartureStatus
|
||||
if call.DepartureStatus is not None
|
||||
else TrainStatus.unknown,
|
||||
)
|
||||
|
||||
by_line_passages = by_line_by_dst_passages[line_id]
|
||||
# TODO: by_line_passages[dst].extend(dsts) instead ?
|
||||
for dst in dsts:
|
||||
by_line_passages[dst].append(next_passage)
|
||||
|
||||
return NextPassagesSchema(
|
||||
ts=int(service_delivery.ResponseTimestamp.timestamp()),
|
||||
passages=by_line_by_dst_passages,
|
||||
)
|
||||
|
||||
|
||||
@router.get("/{stop_id}/destinations")
|
||||
@cache(namespace="stop-destinations", expire=30)
|
||||
async def get_stop_destinations(
|
||||
stop_id: int,
|
||||
) -> IdfmDestinations | None:
|
||||
destinations = await idfm_interface.get_destinations(stop_id)
|
||||
|
||||
return destinations
|
||||
|
||||
|
||||
@router.get("/{stop_id}/shape")
|
||||
@cache(namespace="stop-shape")
|
||||
async def get_stop_shape(stop_id: int) -> StopShapeSchema | None:
|
||||
if (await Stop.get_by_id(stop_id)) is not None or (
|
||||
await StopArea.get_by_id(stop_id)
|
||||
) is not None:
|
||||
shape_id = stop_id
|
||||
|
||||
if (shape := await StopShape.get_by_id(shape_id)) is not None:
|
||||
return StopShapeSchema(
|
||||
id=shape.id,
|
||||
type=shape.type,
|
||||
epsg3857_bbox=shape.epsg3857_bbox,
|
||||
epsg3857_points=shape.epsg3857_points,
|
||||
)
|
||||
|
||||
msg = f"No shape found for stop {stop_id}"
|
||||
raise HTTPException(status_code=404, detail=msg)
|
14
backend/api/schemas/__init__.py
Normal file
14
backend/api/schemas/__init__.py
Normal file
@@ -0,0 +1,14 @@
|
||||
from .line import Line, TransportMode
|
||||
from .next_passage import NextPassage, NextPassages
|
||||
from .stop import Stop, StopArea, StopShape
|
||||
|
||||
|
||||
__all__ = [
|
||||
"Line",
|
||||
"NextPassage",
|
||||
"NextPassages",
|
||||
"Stop",
|
||||
"StopArea",
|
||||
"StopShape",
|
||||
"TransportMode",
|
||||
]
|
60
backend/api/schemas/line.py
Normal file
60
backend/api/schemas/line.py
Normal file
@@ -0,0 +1,60 @@
|
||||
from enum import StrEnum
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from idfm_interface import (
|
||||
IdfmLineState,
|
||||
IdfmState,
|
||||
TransportMode as IdfmTransportMode,
|
||||
TransportSubMode as IdfmTransportSubMode,
|
||||
)
|
||||
|
||||
|
||||
class TransportMode(StrEnum):
|
||||
"""Computed transport mode from
|
||||
idfm_interface.TransportMode and idfm_interface.TransportSubMode.
|
||||
"""
|
||||
|
||||
bus = "bus"
|
||||
tram = "tram"
|
||||
metro = "metro"
|
||||
funicular = "funicular"
|
||||
# idfm_types.TransportMode.rail + idfm_types.TransportSubMode.regionalRail
|
||||
rail_ter = "ter"
|
||||
# idfm_types.TransportMode.rail + idfm_types.TransportSubMode.local
|
||||
rail_rer = "rer"
|
||||
# idfm_types.TransportMode.rail + idfm_types.TransportSubMode.suburbanRailway
|
||||
rail_transilien = "transilien"
|
||||
# idfm_types.TransportMode.rail + idfm_types.TransportSubMode.railShuttle
|
||||
val = "val"
|
||||
|
||||
# Self return type replaced by "TransportMode" to fix following mypy error:
|
||||
# Incompatible return value type (got "TransportMode", expected "Self")
|
||||
# TODO: Is it the good fix ?
|
||||
@classmethod
|
||||
def from_idfm_transport_mode(cls, mode: str, sub_mode: str) -> "TransportMode":
|
||||
if mode == IdfmTransportMode.rail:
|
||||
if sub_mode == IdfmTransportSubMode.regionalRail:
|
||||
return cls.rail_ter
|
||||
if sub_mode == IdfmTransportSubMode.local:
|
||||
return cls.rail_rer
|
||||
if sub_mode == IdfmTransportSubMode.suburbanRailway:
|
||||
return cls.rail_transilien
|
||||
if sub_mode == IdfmTransportSubMode.railShuttle:
|
||||
return cls.val
|
||||
return cls(mode)
|
||||
|
||||
|
||||
class Line(BaseModel):
|
||||
id: int
|
||||
shortName: str
|
||||
name: str
|
||||
status: IdfmLineState
|
||||
transportMode: TransportMode
|
||||
backColorHexa: str
|
||||
foreColorHexa: str
|
||||
operatorId: int
|
||||
accessibility: IdfmState
|
||||
visualSignsAvailable: IdfmState
|
||||
audibleSignsAvailable: IdfmState
|
||||
stopIds: list[int]
|
22
backend/api/schemas/next_passage.py
Normal file
22
backend/api/schemas/next_passage.py
Normal file
@@ -0,0 +1,22 @@
|
||||
from pydantic import BaseModel
|
||||
|
||||
from idfm_interface.idfm_types import TrainStatus
|
||||
|
||||
|
||||
class NextPassage(BaseModel):
|
||||
line: int
|
||||
operator: str
|
||||
destinations: list[str]
|
||||
atStop: bool
|
||||
aimedArrivalTs: int | None
|
||||
expectedArrivalTs: int | None
|
||||
arrivalPlatformName: str | None
|
||||
aimedDepartTs: int | None
|
||||
expectedDepartTs: int | None
|
||||
arrivalStatus: TrainStatus
|
||||
departStatus: TrainStatus
|
||||
|
||||
|
||||
class NextPassages(BaseModel):
|
||||
ts: int
|
||||
passages: dict[int, dict[str, list[NextPassage]]]
|
31
backend/api/schemas/stop.py
Normal file
31
backend/api/schemas/stop.py
Normal file
@@ -0,0 +1,31 @@
|
||||
from pydantic import BaseModel
|
||||
|
||||
from idfm_interface import StopAreaType
|
||||
|
||||
|
||||
class Stop(BaseModel):
|
||||
id: int
|
||||
name: str
|
||||
town: str
|
||||
epsg3857_x: float
|
||||
epsg3857_y: float
|
||||
lines: list[int]
|
||||
|
||||
|
||||
class StopArea(BaseModel):
|
||||
id: int
|
||||
name: str
|
||||
town: str
|
||||
type: StopAreaType
|
||||
lines: list[int] # SNCF lines are linked to stop areas and not stops.
|
||||
stops: list[Stop]
|
||||
|
||||
|
||||
Point = tuple[float, float]
|
||||
|
||||
|
||||
class StopShape(BaseModel):
|
||||
id: int
|
||||
type: int
|
||||
epsg3857_bbox: list[Point]
|
||||
epsg3857_points: list[Point]
|
74
backend/api/settings.py
Normal file
74
backend/api/settings.py
Normal file
@@ -0,0 +1,74 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Annotated
|
||||
|
||||
from pydantic import BaseModel, SecretStr
|
||||
from pydantic.functional_validators import model_validator
|
||||
from pydantic_settings import (
|
||||
BaseSettings,
|
||||
PydanticBaseSettingsSource,
|
||||
SettingsConfigDict,
|
||||
)
|
||||
|
||||
|
||||
class HttpSettings(BaseModel):
|
||||
host: str = "127.0.0.1"
|
||||
port: int = 8080
|
||||
cert: str | None = None
|
||||
|
||||
|
||||
class DatabaseSettings(BaseModel):
|
||||
name: str
|
||||
host: str
|
||||
port: int
|
||||
driver: str = "postgresql+psycopg"
|
||||
user: str
|
||||
password: Annotated[SecretStr, check_user_password]
|
||||
|
||||
|
||||
class CacheSettings(BaseModel):
|
||||
enable: bool = False
|
||||
host: str = "127.0.0.1"
|
||||
port: int = 6379
|
||||
user: str | None = None
|
||||
password: Annotated[SecretStr | None, check_user_password] = None
|
||||
|
||||
|
||||
@model_validator(mode="after")
|
||||
def check_user_password(self) -> DatabaseSettings | CacheSettings:
|
||||
if self.user is not None and self.password is None:
|
||||
raise ValueError("user is set, password shall be set too.")
|
||||
|
||||
if self.password is not None and self.user is None:
|
||||
raise ValueError("password is set, user shall be set too.")
|
||||
|
||||
return self
|
||||
|
||||
|
||||
class TracingSettings(BaseModel):
|
||||
enable: bool = False
|
||||
|
||||
|
||||
class Settings(BaseSettings):
|
||||
app_name: str
|
||||
|
||||
idfm_api_key: SecretStr
|
||||
clear_static_data: bool
|
||||
|
||||
http: HttpSettings
|
||||
db: DatabaseSettings
|
||||
cache: CacheSettings
|
||||
tracing: TracingSettings
|
||||
|
||||
model_config = SettingsConfigDict(env_prefix="CER__", env_nested_delimiter="__")
|
||||
|
||||
@classmethod
|
||||
def settings_customise_sources(
|
||||
cls,
|
||||
settings_cls: type[BaseSettings],
|
||||
init_settings: PydanticBaseSettingsSource,
|
||||
env_settings: PydanticBaseSettingsSource,
|
||||
dotenv_settings: PydanticBaseSettingsSource,
|
||||
file_secret_settings: PydanticBaseSettingsSource,
|
||||
) -> tuple[PydanticBaseSettingsSource, ...]:
|
||||
return env_settings, init_settings, file_secret_settings
|
Reference in New Issue
Block a user