#!/usr/bin/env python3
"""FastAPI application exposing transport lines, stops, stop areas, stop
shapes and next passages, and serving the frontend widget from STATIC_ROOT.
"""

import logging
from collections import defaultdict
from datetime import datetime
from os import environ, EX_USAGE
from typing import Sequence

import uvicorn
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
from opentelemetry import trace
from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import BatchSpanProcessor
from opentelemetry.instrumentation.fastapi import FastAPIInstrumentor
from opentelemetry.instrumentation.logging import LoggingInstrumentor
from opentelemetry.sdk.resources import Resource as OtResource
from opentelemetry.sdk.trace import TracerProvider as OtTracerProvider
from opentelemetry.sdk.trace.export import BatchSpanProcessor
from rich import print
from starlette.types import ASGIApp

from backend.db import db
from backend.idfm_interface import Destinations as IdfmDestinations, IdfmInterface
from backend.models import Line, Stop, StopArea, StopShape
from backend.schemas import (
    Line as LineSchema,
    TransportMode,
    NextPassage as NextPassageSchema,
    NextPassages as NextPassagesSchema,
    Stop as StopSchema,
    StopArea as StopAreaSchema,
    StopShape as StopShapeSchema,
)


API_KEY = environ.get("API_KEY")
if API_KEY is None:
    print('No "API_KEY" environment variable set... abort.')
    # ``raise SystemExit`` is exactly what the ``exit()`` helper does, but it
    # does not depend on the ``site`` module having injected ``exit`` into
    # builtins (it is missing under ``python -S`` and in some frozen builds).
    raise SystemExit(EX_USAGE)

APP_NAME = environ.get("APP_NAME", "app")
# NOTE(review): MODE and COLLECTOR_ENDPOINT_GRPC_ENDPOINT are not passed to the
# OTLPSpanExporter() built below (it is constructed without arguments).
MODE = environ.get("MODE", "grpc")
COLLECTOR_ENDPOINT_GRPC_ENDPOINT = environ.get(
    "COLLECTOR_ENDPOINT_GRPC_ENDPOINT", "127.0.0.1:14250"  # "jaeger-collector:14250"
)

# CREATE DATABASE "carrramba-encore-rate";
# CREATE USER cer WITH ENCRYPTED PASSWORD 'cer_password';
# GRANT ALL PRIVILEGES ON DATABASE "carrramba-encore-rate" TO cer;
# \c "carrramba-encore-rate";
# GRANT ALL ON schema public TO cer;
# CREATE EXTENSION IF NOT EXISTS pg_trgm;

# TODO: Remove postgresql+psycopg from environ variable
DB_PATH = "postgresql+psycopg://cer:cer_password@127.0.0.1:5432/carrramba-encore-rate"


app = FastAPI()

app.add_middleware(
    CORSMiddleware,
    allow_origins=[
        "https://localhost:4443",
        "https://localhost:3000",
    ],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

trace.set_tracer_provider(TracerProvider())
trace.get_tracer_provider().add_span_processor(BatchSpanProcessor(OTLPSpanExporter()))
tracer = trace.get_tracer(APP_NAME)

# NOTE(review): looks like a tracing smoke test left in place; consider removing.
with tracer.start_as_current_span("foo"):
    print("Hello world!")


idfm_interface = IdfmInterface(API_KEY, db)


@app.on_event("startup")
async def startup():
    """Connect to the database (clearing static data) and start the IDFM interface."""
    await db.connect(DB_PATH, clear_static_data=True)
    await idfm_interface.startup()
    # await db.connect(DB_PATH, clear_static_data=False)
    print("Connected")


@app.on_event("shutdown")
async def shutdown():
    """Close the database connection."""
    await db.disconnect()


# NOTE(review): relative path, resolved against the process working directory.
STATIC_ROOT = "../frontend/"
app.mount("/widget", StaticFiles(directory=STATIC_ROOT, html=True), name="widget")


def optional_datetime_to_ts(dt: datetime | None) -> int | None:
    """Return *dt* as an integer POSIX timestamp, or None when *dt* is None."""
    return int(dt.timestamp()) if dt is not None else None


@app.get("/line/{line_id}", response_model=LineSchema)
async def get_line(line_id: str) -> LineSchema:
    """Return the line identified by *line_id*.

    Raises:
        HTTPException(404): no line with this id exists.
    """
    line: Line | None = await Line.get_by_id(line_id)

    if line is None:
        raise HTTPException(status_code=404, detail=f'Line "{line_id}" not found')

    return LineSchema(
        id=line.id,
        shortName=line.short_name,
        name=line.name,
        status=line.status,
        transportMode=TransportMode.from_idfm_transport_mode(
            line.transport_mode, line.transport_submode
        ),
        backColorHexa=line.colour_web_hexa,
        foreColorHexa=line.text_colour_hexa,
        operatorId=line.operator_id,
        accessibility=line.accessibility,
        visualSignsAvailable=line.visual_signs_available,
        audibleSignsAvailable=line.audible_signs_available,
        stopIds=[stop.id for stop in line.stops],
    )


def _format_stop(stop: Stop) -> StopSchema:
    """Convert a Stop model into its API schema."""
    return StopSchema(
        id=stop.id,
        name=stop.name,
        town=stop.town_name,
        epsg3857_x=stop.epsg3857_x,
        epsg3857_y=stop.epsg3857_y,
        lines=[line.id for line in stop.lines],
    )


@app.get("/stop/")
async def get_stop(
    name: str = "", limit: int = 10
) -> Sequence[StopAreaSchema | StopSchema]:
    """Search stops and stop areas by *name* (via ``Stop.get_by_name``).

    Stop areas come first, each embedding all of its child stops; matching
    stops that were not already listed under one of those areas follow.

    NOTE: *limit* is accepted but not applied yet (see TODO below).
    """
    # TODO: Add limit support
    matching_stops = await Stop.get_by_name(name)

    # The search may return StopArea instances as well as plain stops: split them.
    stop_areas: dict[int, StopArea] = {}
    stops: dict[int, Stop] = {}
    for stop in matching_stops:
        dst = stop_areas if isinstance(stop, StopArea) else stops
        dst[stop.id] = stop

    formatted: list[StopAreaSchema | StopSchema] = []
    for stop_area in stop_areas.values():
        formatted_stops = []
        for stop in stop_area.stops:
            formatted_stops.append(_format_stop(stop))
            # The child stop is already listed under its area, so drop it from
            # the standalone results. It is often absent (the area matched but
            # this child did not), which is expected and not an error.
            stops.pop(stop.id, None)

        formatted.append(
            StopAreaSchema(
                id=stop_area.id,
                name=stop_area.name,
                town=stop_area.town_name,
                type=stop_area.type,
                lines=[line.id for line in stop_area.lines],
                stops=formatted_stops,
            )
        )

    formatted.extend(_format_stop(stop) for stop in stops.values())

    return formatted


# TODO: Cache response for 30 secs ?
@app.get("/stop/{stop_id}/nextPassages")
async def get_next_passages(stop_id: int) -> NextPassagesSchema | None:
    """Return the upcoming passages at *stop_id*, grouped by line then destination.

    Returns None when ``idfm_interface.get_next_passages`` returns None.

    Raises:
        HTTPException(404): a vehicle journey's ``LineRef`` does not match
            ``IdfmInterface.LINE_RE``.
    """
    res = await idfm_interface.get_next_passages(stop_id)
    if res is None:
        return None

    service_delivery = res.Siri.ServiceDelivery
    stop_monitoring_deliveries = service_delivery.StopMonitoringDelivery

    # line id -> destination name -> passages
    by_line_by_dst_passages: dict[
        str, dict[str, list[NextPassageSchema]]
    ] = defaultdict(lambda: defaultdict(list))

    for delivery in stop_monitoring_deliveries:
        for stop_visit in delivery.MonitoredStopVisit:
            journey = stop_visit.MonitoredVehicleJourney

            line_ref = journey.LineRef.value
            # Test the match result explicitly instead of catching AttributeError
            # around ``.group(1)``: that broad catch would also swallow unrelated
            # attribute errors raised while evaluating the same expression.
            line_match = IdfmInterface.LINE_RE.match(line_ref)
            if line_match is None:
                raise HTTPException(
                    status_code=404, detail=f'Line "{line_ref}" not found'
                )
            line_id = line_match.group(1)

            call = journey.MonitoredCall

            dst_names = call.DestinationDisplay
            dsts = [dst.value for dst in dst_names] if dst_names else []

            arrival_platform_name = (
                call.ArrivalPlatformName.value if call.ArrivalPlatformName else None
            )

            next_passage = NextPassageSchema(
                line=line_id,
                operator=journey.OperatorRef.value,
                destinations=dsts,
                atStop=call.VehicleAtStop,
                aimedArrivalTs=optional_datetime_to_ts(call.AimedArrivalTime),
                expectedArrivalTs=optional_datetime_to_ts(call.ExpectedArrivalTime),
                arrivalPlatformName=arrival_platform_name,
                aimedDepartTs=optional_datetime_to_ts(call.AimedDepartureTime),
                expectedDepartTs=optional_datetime_to_ts(call.ExpectedDepartureTime),
                arrivalStatus=call.ArrivalStatus.value,
                departStatus=call.DepartureStatus.value,
            )

            # Accessing the key registers the line even when it has no destination.
            by_line_passages = by_line_by_dst_passages[line_id]
            # TODO: by_line_passages[dst].extend(dsts) instead ?
            # NOTE(review): a passage without any destination is not stored.
            for dst in dsts:
                by_line_passages[dst].append(next_passage)

    return NextPassagesSchema(
        ts=service_delivery.ResponseTimestamp.timestamp(),
        passages=by_line_by_dst_passages,
    )


@app.get("/stop/{stop_id}/destinations")
async def get_stop_destinations(
    stop_id: int,
) -> IdfmDestinations | None:
    """Return the destinations served at *stop_id*, as given by the IDFM interface."""
    destinations = await idfm_interface.get_destinations(stop_id)

    return destinations


@app.get("/stop/{stop_id}/shape")
async def get_stop_shape(stop_id: int) -> StopShapeSchema | None:
    """Return the shape of the connection area of stop (or stop area) *stop_id*.

    *stop_id* is looked up as a stop first, then as a stop area. A stop area
    must resolve to exactly one connection area across its child stops.

    Raises:
        HTTPException(500): a stop area's stops span zero or several
            connection areas.
        HTTPException(404): no connection area or no shape was found.
    """
    connection_area = None

    if (stop := await Stop.get_by_id(stop_id)) is not None:
        connection_area = stop.connection_area

    elif (stop_area := await StopArea.get_by_id(stop_id)) is not None:
        connection_areas = {child.connection_area for child in stop_area.stops}
        connection_areas_len = len(connection_areas)
        if connection_areas_len == 1:
            connection_area = connection_areas.pop()

        else:
            prefix = "More than one" if connection_areas_len else "No"
            msg = f"{prefix} connection area has been found for stop area #{stop_id}"
            raise HTTPException(status_code=500, detail=msg)

    if (
        connection_area is not None
        and (shape := await StopShape.get_by_id(connection_area.id)) is not None
    ):
        return StopShapeSchema(
            id=shape.id,
            type=shape.type,
            epsg3857_bbox=shape.epsg3857_bbox,
            epsg3857_points=shape.epsg3857_points,
        )

    msg = f"No shape found for stop {stop_id}"
    raise HTTPException(status_code=404, detail=msg)


# Instrument after all routes above have been registered on the app.
FastAPIInstrumentor.instrument_app(app)


if __name__ == "__main__":
    # NOTE(review): no ssl_keyfile is passed, so this presumably relies on
    # cert.pem also containing the private key — confirm.
    uvicorn.run(app, host="0.0.0.0", port=4443, ssl_certfile="./config/cert.pem")