#!/usr/bin/env python3

import logging
from collections import defaultdict
from datetime import datetime
from os import environ, EX_USAGE
from typing import Sequence

import uvicorn
from contextlib import asynccontextmanager
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
from opentelemetry import trace
from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter
from opentelemetry.instrumentation.fastapi import FastAPIInstrumentor
from opentelemetry.sdk.resources import Resource, SERVICE_NAME
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import BatchSpanProcessor
from rich import print
from yaml import safe_load

from backend.db import db
from backend.idfm_interface import Destinations as IdfmDestinations, IdfmInterface
from backend.models import Line, Stop, StopArea, StopShape
from backend.schemas import (
    Line as LineSchema,
    TransportMode,
    NextPassage as NextPassageSchema,
    NextPassages as NextPassagesSchema,
    Stop as StopSchema,
    StopArea as StopAreaSchema,
    StopShape as StopShapeSchema,
)
from backend.settings import Settings


# Path of the YAML configuration file; overridable through the environment.
CONFIG_PATH = environ.get("CONFIG_PATH", "./config.sample.yaml")


def load_settings(path: str) -> Settings:
    """Load the YAML file at *path* and build a ``Settings`` object from it.

    An empty YAML document (``safe_load`` returns ``None``) is treated as an
    empty mapping so that ``Settings`` reports the missing fields itself
    instead of failing with an opaque ``TypeError`` on ``**None``.
    """
    with open(path, "r", encoding="utf-8") as config_file:
        config = safe_load(config_file) or {}

    return Settings(**config)


settings = load_settings(CONFIG_PATH)


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Connect the database on startup and disconnect it on shutdown.

    When ``settings.clear_static_data`` is set, the IDFM interface startup
    routine is also run after connecting. ``idfm_interface`` is a module-level
    global defined further down; it is only looked up when the app starts.
    The disconnect runs in ``finally`` so the DB connection is released even
    if the IDFM startup step or the application itself fails.
    """
    await db.connect(settings.db, settings.clear_static_data)
    try:
        if settings.clear_static_data:
            await idfm_interface.startup()

        yield
    finally:
        await db.disconnect()


app = FastAPI(lifespan=lifespan)

app.add_middleware(
    CORSMiddleware,
    allow_origins=["https://localhost:4443", "https://localhost:3000"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

app.mount("/widget", StaticFiles(directory="../frontend/", html=True), name="widget")

# Install the real tracer provider before instrumenting the app, so the
# FastAPI instrumentation picks it up rather than the global default.
trace.set_tracer_provider(
    TracerProvider(resource=Resource.create({SERVICE_NAME: settings.app_name}))
)
trace.get_tracer_provider().add_span_processor(BatchSpanProcessor(OTLPSpanExporter()))
tracer = trace.get_tracer(settings.app_name)

FastAPIInstrumentor.instrument_app(app)

idfm_interface = IdfmInterface(settings.idfm_api_key.get_secret_value(), db)


def optional_datetime_to_ts(dt: datetime | None) -> int | None:
    """Return *dt* as an integer POSIX timestamp, or ``None`` if *dt* is ``None``."""
    return int(dt.timestamp()) if dt is not None else None


@app.get("/line/{line_id}", response_model=LineSchema)
async def get_line(line_id: int) -> LineSchema:
    """Return the line identified by *line_id*.

    Raises:
        HTTPException: 404 when no line with this id exists.
    """
    line: Line | None = await Line.get_by_id(line_id)

    if line is None:
        raise HTTPException(status_code=404, detail=f'Line "{line_id}" not found')

    return LineSchema(
        id=line.id,
        shortName=line.short_name,
        name=line.name,
        status=line.status,
        transportMode=TransportMode.from_idfm_transport_mode(
            line.transport_mode, line.transport_submode
        ),
        backColorHexa=line.colour_web_hexa,
        foreColorHexa=line.text_colour_hexa,
        operatorId=line.operator_id,
        accessibility=line.accessibility,
        visualSignsAvailable=line.visual_signs_available,
        audibleSignsAvailable=line.audible_signs_available,
        stopIds=[stop.id for stop in line.stops],
    )


def _format_stop(stop: Stop) -> StopSchema:
    """Convert a ``Stop`` model into its API schema."""
    return StopSchema(
        id=stop.id,
        name=stop.name,
        town=stop.town_name,
        epsg3857_x=stop.epsg3857_x,
        epsg3857_y=stop.epsg3857_y,
        lines=[line.id for line in stop.lines],
    )


@app.get("/stop/")
async def get_stop(
    name: str = "", limit: int = 10
) -> Sequence[StopAreaSchema | StopSchema]:
    """Search stops and stop areas by *name*.

    Each matching stop area is returned with all of its member stops nested
    inside it. Matching stops that are already nested in a returned stop area
    are not repeated at the top level; the remaining matching stops follow
    the stop areas.

    Note: *limit* is accepted but currently ignored.
    """
    # TODO: Add limit support
    formatted: list[StopAreaSchema | StopSchema] = []

    matching_stops = await Stop.get_by_name(name)

    stop_areas: dict[int, StopArea] = {}
    stops: dict[int, Stop] = {}
    for stop in matching_stops:
        dst = stop_areas if isinstance(stop, StopArea) else stops
        dst[stop.id] = stop

    for stop_area in stop_areas.values():
        formatted_stops: list[StopSchema] = []
        for stop in stop_area.stops:
            formatted_stops.append(_format_stop(stop))
            # A member stop may not have matched the name query itself; that
            # is expected, so drop it from the top-level set only if present.
            stops.pop(stop.id, None)

        formatted.append(
            StopAreaSchema(
                id=stop_area.id,
                name=stop_area.name,
                town=stop_area.town_name,
                type=stop_area.type,
                lines=[line.id for line in stop_area.lines],
                stops=formatted_stops,
            )
        )

    formatted.extend(_format_stop(stop) for stop in stops.values())

    return formatted


# TODO: Cache response for 30 secs ?
@app.get("/stop/{stop_id}/nextPassages")
async def get_next_passages(stop_id: int) -> NextPassagesSchema | None:
    """Return the upcoming passages at *stop_id*, grouped by line id then destination.

    Returns ``None`` when the IDFM interface returns no result.

    Raises:
        HTTPException: 404 when a monitored journey's ``LineRef`` cannot be
            parsed into a line id with ``IdfmInterface.LINE_RE``.
    """
    res = await idfm_interface.get_next_passages(stop_id)
    if res is None:
        return None

    service_delivery = res.Siri.ServiceDelivery
    stop_monitoring_deliveries = service_delivery.StopMonitoringDelivery

    # line id -> destination name -> passages heading there.
    by_line_by_dst_passages: dict[
        int, dict[str, list[NextPassageSchema]]
    ] = defaultdict(lambda: defaultdict(list))

    for delivery in stop_monitoring_deliveries:
        for stop_visit in delivery.MonitoredStopVisit:
            journey = stop_visit.MonitoredVehicleJourney

            # LINE_RE.match returns None for an invalid LineRef, which then
            # surfaces as AttributeError on .group(); int() may raise too.
            try:
                line_id_match = IdfmInterface.LINE_RE.match(journey.LineRef.value)
                line_id = int(line_id_match.group(1))  # type: ignore
            except (AttributeError, TypeError, ValueError) as err:
                raise HTTPException(
                    status_code=404, detail=f'Line "{journey.LineRef.value}" not found'
                ) from err

            call = journey.MonitoredCall

            dst_names = call.DestinationDisplay
            dsts = [dst.value for dst in dst_names] if dst_names else []

            arrival_platform_name = (
                call.ArrivalPlatformName.value if call.ArrivalPlatformName else None
            )

            next_passage = NextPassageSchema(
                line=line_id,
                operator=journey.OperatorRef.value,
                destinations=dsts,
                atStop=call.VehicleAtStop,
                aimedArrivalTs=optional_datetime_to_ts(call.AimedArrivalTime),
                expectedArrivalTs=optional_datetime_to_ts(call.ExpectedArrivalTime),
                arrivalPlatformName=arrival_platform_name,
                aimedDepartTs=optional_datetime_to_ts(call.AimedDepartureTime),
                expectedDepartTs=optional_datetime_to_ts(call.ExpectedDepartureTime),
                arrivalStatus=call.ArrivalStatus.value,
                departStatus=call.DepartureStatus.value,
            )

            by_line_passages = by_line_by_dst_passages[line_id]
            # NOTE(review): a passage with no DestinationDisplay is not stored
            # under any destination and is therefore dropped — confirm intended.
            # TODO: by_line_passages[dst].extend(dsts) instead ?
            for dst in dsts:
                by_line_passages[dst].append(next_passage)

    # NOTE(review): ts is a float here while the per-passage timestamps are
    # ints (optional_datetime_to_ts) — confirm the schema's expected type.
    return NextPassagesSchema(
        ts=service_delivery.ResponseTimestamp.timestamp(),
        passages=by_line_by_dst_passages,
    )


@app.get("/stop/{stop_id}/destinations")
async def get_stop_destinations(
    stop_id: int,
) -> IdfmDestinations | None:
    """Return whatever ``idfm_interface.get_destinations`` yields for *stop_id*."""
    destinations = await idfm_interface.get_destinations(stop_id)

    return destinations


@app.get("/stop/{stop_id}/shape")
async def get_stop_shape(stop_id: int) -> StopShapeSchema | None:
    """Return the shape of the connection area that *stop_id* belongs to.

    *stop_id* is first looked up as a stop, then as a stop area. For a stop
    area, all of its member stops must share exactly one connection area.

    Raises:
        HTTPException: 500 when a stop area has zero or several distinct
            connection areas; 404 when no shape can be found.
    """
    connection_area = None

    if (stop := await Stop.get_by_id(stop_id)) is not None:
        connection_area = stop.connection_area

    elif (stop_area := await StopArea.get_by_id(stop_id)) is not None:
        connection_areas = {member.connection_area for member in stop_area.stops}
        connection_areas_len = len(connection_areas)
        if connection_areas_len == 1:
            connection_area = connection_areas.pop()

        else:
            prefix = "More than one" if connection_areas_len else "No"
            msg = f"{prefix} connection area has been found for stop area #{stop_id}"
            raise HTTPException(status_code=500, detail=msg)

    if (
        connection_area is not None
        and (shape := await StopShape.get_by_id(connection_area.id)) is not None
    ):
        return StopShapeSchema(
            id=shape.id,
            type=shape.type,
            epsg3857_bbox=shape.epsg3857_bbox,
            epsg3857_points=shape.epsg3857_points,
        )

    msg = f"No shape found for stop {stop_id}"
    raise HTTPException(status_code=404, detail=msg)


if __name__ == "__main__":
    http_settings = settings.http

    uvicorn.run(
        app,
        host=http_settings.host,
        port=http_settings.port,
        ssl_certfile=http_settings.cert,
    )