# carrramba-encore-rate/backend/main.py

from collections import defaultdict
from datetime import datetime
from os import environ, EX_USAGE
from pathlib import Path
from typing import Sequence

from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
from rich import print

from backend.db import db
from backend.models import Line, Stop, StopArea
from backend.idfm_interface import Destinations as IdfmDestinations, IdfmInterface
from backend.schemas import (
    Line as LineSchema,
    TransportMode,
    NextPassage as NextPassageSchema,
    NextPassages as NextPassagesSchema,
    Stop as StopSchema,
    StopArea as StopAreaSchema,
)
# The IDFM interface (see `IdfmInterface(API_KEY, db)`) cannot work without an
# API key, so refuse to start when none is provided.
API_KEY = environ.get("API_KEY")
if API_KEY is None:
    print('No "API_KEY" environment variable set... abort.')
    # `exit()` is a convenience injected by the `site` module for interactive
    # sessions; raising SystemExit directly works even when `site` is absent.
    raise SystemExit(EX_USAGE)
# Database URL passed to `db.connect`. It can be overridden through the
# DB_PATH environment variable; the default points at a local PostgreSQL
# instance (asyncpg driver) so existing setups keep working unchanged.
# TODO: Remove postgresql+asyncpg from environ variable
DB_PATH = environ.get(
    "DB_PATH", "postgresql+asyncpg://cer_user:cer_password@127.0.0.1:5438/cer_db"
)
# Browser origins allowed to call this API: the two hosts that appear in the
# /addwidget URLs further down in this module.
app = FastAPI()
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_origins=["https://localhost:4443", "https://localhost:3000"],
)

# One IDFM client for the whole module, started in `startup` and used by the
# next-passages and destinations endpoints.
idfm_interface = IdfmInterface(API_KEY, db)
@app.on_event("startup")
async def startup():
    """Connect to the database, then start the IDFM interface.

    The connection is opened with ``clear_static_data=True`` on every boot.
    Passing ``False`` instead is the alternative the code previously kept
    commented out (presumably to reuse stored static data — confirm in
    ``backend.db``).
    """
    await db.connect(DB_PATH, clear_static_data=True)
    await idfm_interface.startup()
    print("Connected")
@app.on_event("shutdown")
async def shutdown():
    """Close the database connection that ``startup`` opened."""
    await db.disconnect()
# /addwidget https://localhost:4443/static/#?widgetId=$matrix_widget_id&userId=$matrix_user_id
# /addwidget https://localhost:3000/widget?widgetId=$matrix_widget_id&userId=$matrix_user_id

# Frontend directory served under /widget. It is resolved relative to this
# file (backend/main.py -> <repo>/frontend) instead of the process's current
# working directory. That is the same location the former "../frontend/"
# pointed to when the server was launched from backend/, but it now also
# works from any other directory.
STATIC_ROOT = str(Path(__file__).resolve().parent.parent / "frontend")
app.mount("/widget", StaticFiles(directory=STATIC_ROOT, html=True), name="widget")
def optional_datetime_to_ts(dt: datetime | None) -> int | None:
return int(dt.timestamp()) if dt else None
@app.get("/line/{line_id}", response_model=LineSchema)
async def get_line(line_id: str) -> LineSchema:
    """Return the line whose id is *line_id*.

    Raises:
        HTTPException: 404 when ``Line.get_by_id`` finds nothing.
    """
    found: Line | None = await Line.get_by_id(line_id)
    if found is None:
        raise HTTPException(status_code=404, detail=f'Line "{line_id}" not found')

    mode = TransportMode.from_idfm_transport_mode(
        found.transport_mode, found.transport_submode
    )
    member_stop_ids = [member.id for member in found.stops]

    return LineSchema(
        id=found.id,
        shortName=found.short_name,
        name=found.name,
        status=found.status,
        transportMode=mode,
        backColorHexa=found.colour_web_hexa,
        foreColorHexa=found.text_colour_hexa,
        operatorId=found.operator_id,
        accessibility=found.accessibility,
        visualSignsAvailable=found.visual_signs_available,
        audibleSignsAvailable=found.audible_signs_available,
        stopIds=member_stop_ids,
    )
def _format_stop(stop: Stop) -> StopSchema:
    """Build the API representation of a single ``Stop``.

    Only latitude/longitude are exposed; the projected coordinates
    (xepsg2154/yepsg2154) are not part of the output.
    """
    served_line_ids = [served.id for served in stop.lines]
    return StopSchema(
        id=stop.id,
        name=stop.name,
        town=stop.town_name,
        lat=stop.latitude,
        lon=stop.longitude,
        lines=served_line_ids,
    )
# châtelet
@app.get("/stop/")
async def get_stop(
    name: str = "", limit: int = 10
) -> Sequence[StopAreaSchema | StopSchema]:
    """Search stop areas and stops matching *name*.

    Stop areas come first, each embedding all of its member stops. A matching
    stop that belongs to one of the returned stop areas is not repeated at the
    top level. The remaining matching stops follow.

    Args:
        name: Search string forwarded to ``Stop.get_by_name``.
        limit: Currently ignored; every match is returned (see TODO).
    """
    # TODO: Add limit support
    matching_stops = await Stop.get_by_name(name)

    # Split matches into stop areas and plain stops, keyed by id.
    stop_areas: dict[int, StopArea] = {}
    stops: dict[int, Stop] = {}
    for stop in matching_stops:
        dst = stop_areas if isinstance(stop, StopArea) else stops
        dst[stop.id] = stop

    formatted: list[StopAreaSchema | StopSchema] = []
    for stop_area in stop_areas.values():
        formatted_stops = []
        for stop in stop_area.stops:
            formatted_stops.append(_format_stop(stop))
            # The member is already reported inside its stop area, so drop it
            # from the top-level list. A member that did not itself match
            # `name` is simply absent from `stops`. That is expected, not an
            # error, so there is nothing to report.
            stops.pop(stop.id, None)
        formatted.append(
            StopAreaSchema(
                id=stop_area.id,
                name=stop_area.name,
                town=stop_area.town_name,
                type=stop_area.type,
                lines=[line.id for line in stop_area.lines],
                stops=formatted_stops,
            )
        )

    formatted.extend(_format_stop(stop) for stop in stops.values())
    return formatted
# TODO: Cache response for 30 secs ?
@app.get("/stop/nextPassages/{stop_id}")
async def get_next_passages(stop_id: str) -> NextPassagesSchema | None:
    """Return upcoming passages at *stop_id*, grouped by line id then destination.

    Returns ``None`` when ``idfm_interface.get_next_passages`` returns ``None``.
    A passage with several destination names is listed under each of them.

    Raises:
        HTTPException: 404 when a monitored journey's ``LineRef`` does not
            match ``IdfmInterface.LINE_RE``.
    """
    res = await idfm_interface.get_next_passages(stop_id)
    if res is None:
        return None

    service_delivery = res.Siri.ServiceDelivery

    # line id -> destination name -> passages
    by_line_by_dst_passages: dict[
        str, dict[str, list[NextPassageSchema]]
    ] = defaultdict(lambda: defaultdict(list))

    for delivery in service_delivery.StopMonitoringDelivery:
        for stop_visit in delivery.MonitoredStopVisit:
            journey = stop_visit.MonitoredVehicleJourney

            # Check the regex result explicitly rather than catching the
            # AttributeError raised by `None.group(...)`. Catching would also
            # hide unrelated AttributeErrors, and the old handler re-read
            # `journey.LineRef.value` inside the except clause.
            line_ref = journey.LineRef.value
            line_match = IdfmInterface.LINE_RE.match(line_ref)
            if line_match is None:
                raise HTTPException(
                    status_code=404, detail=f'Line "{line_ref}" not found'
                )
            line_id = line_match.group(1)

            call = journey.MonitoredCall
            dst_names = call.DestinationDisplay
            dsts = [dst.value for dst in dst_names] if dst_names else []

            next_passage = NextPassageSchema(
                line=line_id,
                operator=journey.OperatorRef.value,
                destinations=dsts,
                atStop=call.VehicleAtStop,
                aimedArrivalTs=optional_datetime_to_ts(call.AimedArrivalTime),
                expectedArrivalTs=optional_datetime_to_ts(call.ExpectedArrivalTime),
                arrivalPlatformName=call.ArrivalPlatformName.value
                if call.ArrivalPlatformName
                else None,
                aimedDepartTs=optional_datetime_to_ts(call.AimedDepartureTime),
                expectedDepartTs=optional_datetime_to_ts(call.ExpectedDepartureTime),
                arrivalStatus=call.ArrivalStatus.value,
                departStatus=call.DepartureStatus.value,
            )

            # Indexing the defaultdict registers the line even when the
            # passage has no destination name (it then maps to an empty dict).
            by_line_passages = by_line_by_dst_passages[line_id]
            for dst in dsts:
                by_line_passages[dst].append(next_passage)

    return NextPassagesSchema(
        ts=service_delivery.ResponseTimestamp.timestamp(),
        passages=by_line_by_dst_passages,
    )
@app.get("/stop/{stop_id}/destinations")
async def get_stop_destinations(
    stop_id: int,
) -> IdfmDestinations | None:
    """Relay ``idfm_interface.get_destinations`` for *stop_id* (may be ``None``)."""
    return await idfm_interface.get_destinations(stop_id)