import sys
from collections import defaultdict
from datetime import datetime
from os import environ, EX_USAGE
from typing import Sequence

from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
from rich import print

from backend.db import db
from backend.idfm_interface import Destinations as IdfmDestinations, IdfmInterface
from backend.models import Line, Stop, StopArea, StopShape
from backend.schemas import (
    Line as LineSchema,
    TransportMode,
    NextPassage as NextPassageSchema,
    NextPassages as NextPassagesSchema,
    Stop as StopSchema,
    StopArea as StopAreaSchema,
    StopShape as StopShapeSchema,
)

API_KEY = environ.get("API_KEY")
if API_KEY is None:
    print('No "API_KEY" environment variable set... abort.')
    # sys.exit is always available; the bare `exit` builtin is injected by the
    # `site` module and is missing under `python -S` or in frozen apps.
    sys.exit(EX_USAGE)

# Database URL, overridable through the DB_PATH environment variable so the
# credentials do not have to live in the source. The previous hard-coded value
# is kept as the default for backward compatibility.
# TODO: Remove postgresql+asyncpg from environ variable
DB_PATH = environ.get(
    "DB_PATH",
    "postgresql+asyncpg://cer_user:cer_password@127.0.0.1:5438/cer_db",
)

app = FastAPI()

app.add_middleware(
    CORSMiddleware,
    allow_origins=[
        "https://localhost:4443",
        "https://localhost:3000",
    ],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

idfm_interface = IdfmInterface(API_KEY, db)


@app.on_event("startup")
async def startup():
    """Connect to the database (clearing static data) and start the IDFM interface."""
    await db.connect(DB_PATH, clear_static_data=True)
    await idfm_interface.startup()
    print("Connected")


@app.on_event("shutdown")
async def shutdown():
    """Disconnect from the database."""
    await db.disconnect()


# /addwidget https://localhost:4443/static/#?widgetId=$matrix_widget_id&userId=$matrix_user_id
# /addwidget https://localhost:3000/widget?widgetId=$matrix_widget_id&userId=$matrix_user_id
# NOTE(review): path is resolved relative to the process working directory.
STATIC_ROOT = "../frontend/"
app.mount("/widget", StaticFiles(directory=STATIC_ROOT, html=True), name="widget")


def optional_datetime_to_ts(dt: datetime | None) -> int | None:
    """Convert *dt* to an integer POSIX timestamp, or return None if *dt* is None."""
    return int(dt.timestamp()) if dt is not None else None


@app.get("/line/{line_id}", response_model=LineSchema)
async def get_line(line_id: str) -> LineSchema:
    """Return the line identified by *line_id*.

    Raises:
        HTTPException: 404 if no line with this id exists.
    """
    line: Line | None = await Line.get_by_id(line_id)
    if line is None:
        raise HTTPException(status_code=404, detail=f'Line "{line_id}" not found')

    return LineSchema(
        id=line.id,
        shortName=line.short_name,
        name=line.name,
        status=line.status,
        transportMode=TransportMode.from_idfm_transport_mode(
            line.transport_mode, line.transport_submode
        ),
        backColorHexa=line.colour_web_hexa,
        foreColorHexa=line.text_colour_hexa,
        operatorId=line.operator_id,
        accessibility=line.accessibility,
        visualSignsAvailable=line.visual_signs_available,
        audibleSignsAvailable=line.audible_signs_available,
        stopIds=[stop.id for stop in line.stops],
    )


def _format_stop(stop: Stop) -> StopSchema:
    """Convert a `Stop` model into its API schema."""
    return StopSchema(
        id=stop.id,
        name=stop.name,
        town=stop.town_name,
        lat=stop.latitude,
        lon=stop.longitude,
        lines=[line.id for line in stop.lines],
    )


# châtelet
@app.get("/stop/")
async def get_stop(
    name: str = "", limit: int = 10
) -> Sequence[StopAreaSchema | StopSchema]:
    """Search stops and stop areas by *name*.

    Results from `Stop.get_by_name` are split into stop areas and plain stops.
    Each stop area is returned with all of its child stops nested inside it.
    A matching stop that is a child of a returned area is therefore not
    repeated at top level. The remaining plain stops follow the areas.

    *limit* is accepted but not applied yet.
    """
    # TODO: Add limit support
    matching_stops = await Stop.get_by_name(name)

    stop_areas: dict[int, StopArea] = {}
    stops: dict[int, Stop] = {}
    for stop in matching_stops:
        dst = stop_areas if isinstance(stop, StopArea) else stops
        dst[stop.id] = stop

    formatted: list[StopAreaSchema | StopSchema] = []
    for stop_area in stop_areas.values():
        formatted_stops = []
        for child in stop_area.stops:
            formatted_stops.append(_format_stop(child))
            # The child is reported inside its area. It is absent from `stops`
            # when it did not itself match the search, which is expected and
            # not an error.
            stops.pop(child.id, None)

        formatted.append(
            StopAreaSchema(
                id=stop_area.id,
                name=stop_area.name,
                town=stop_area.town_name,
                type=stop_area.type,
                lines=[line.id for line in stop_area.lines],
                stops=formatted_stops,
            )
        )

    formatted.extend(_format_stop(stop) for stop in stops.values())
    return formatted


# TODO: Cache response for 30 secs ?
@app.get("/stop/nextPassages/{stop_id}")
async def get_next_passages(stop_id: str) -> NextPassagesSchema | None:
    """Return the real-time next passages at *stop_id*.

    Passages are grouped first by line id, then by destination name. A passage
    with several destinations is listed under each of them. A passage with no
    destination display is not recorded.

    Returns:
        None when the IDFM interface returns no response.

    Raises:
        HTTPException: 404 if a journey's LineRef does not match
            ``IdfmInterface.LINE_RE``.
    """
    res = await idfm_interface.get_next_passages(stop_id)
    if res is None:
        return None

    service_delivery = res.Siri.ServiceDelivery
    stop_monitoring_deliveries = service_delivery.StopMonitoringDelivery

    by_line_by_dst_passages: dict[
        str, dict[str, list[NextPassageSchema]]
    ] = defaultdict(lambda: defaultdict(list))

    for delivery in stop_monitoring_deliveries:
        for stop_visit in delivery.MonitoredStopVisit:
            journey = stop_visit.MonitoredVehicleJourney

            # Check the match explicitly rather than catching AttributeError,
            # which would also swallow unrelated attribute errors.
            line_ref = journey.LineRef.value
            line_match = IdfmInterface.LINE_RE.match(line_ref)
            if line_match is None:
                raise HTTPException(
                    status_code=404, detail=f'Line "{line_ref}" not found'
                )
            line_id = line_match.group(1)

            call = journey.MonitoredCall

            dst_names = call.DestinationDisplay
            dsts = [dst.value for dst in dst_names] if dst_names else []

            next_passage = NextPassageSchema(
                line=line_id,
                operator=journey.OperatorRef.value,
                destinations=dsts,
                atStop=call.VehicleAtStop,
                aimedArrivalTs=optional_datetime_to_ts(call.AimedArrivalTime),
                expectedArrivalTs=optional_datetime_to_ts(call.ExpectedArrivalTime),
                arrivalPlatformName=call.ArrivalPlatformName.value
                if call.ArrivalPlatformName
                else None,
                aimedDepartTs=optional_datetime_to_ts(call.AimedDepartureTime),
                expectedDepartTs=optional_datetime_to_ts(call.ExpectedDepartureTime),
                arrivalStatus=call.ArrivalStatus.value,
                departStatus=call.DepartureStatus.value,
            )

            by_line_passages = by_line_by_dst_passages[line_id]
            # TODO: by_line_passages[dst].extend(dsts) instead ?
            for dst in dsts:
                by_line_passages[dst].append(next_passage)

    return NextPassagesSchema(
        ts=service_delivery.ResponseTimestamp.timestamp(),
        passages=by_line_by_dst_passages,
    )


@app.get("/stop/{stop_id}/destinations")
async def get_stop_destinations(
    stop_id: int,
) -> IdfmDestinations | None:
    """Return the destinations served from *stop_id*, as provided by the IDFM interface."""
    destinations = await idfm_interface.get_destinations(stop_id)
    return destinations


# TODO: Rename endpoint -> /stop/{stop_id}/shape
@app.get("/stop_shape/{stop_id}")
async def get_stop_shape(stop_id: int) -> StopShapeSchema | None:
    """Return the shape of the connection area containing *stop_id*.

    *stop_id* is first looked up as a stop, then as a stop area. For a stop
    area, the distinct non-null connection areas of its child stops must
    resolve to exactly one area.

    Raises:
        HTTPException: 500 if a stop area maps to zero or several connection
            areas; 404 if no shape can be found for the resolved area.
    """
    connection_area = None

    if (stop := await Stop.get_by_id(stop_id)) is not None:
        connection_area = stop.connection_area

    elif (stop_area := await StopArea.get_by_id(stop_id)) is not None:
        # Ignore child stops without a connection area. Otherwise `None` would
        # be counted as an extra area and trigger a misleading "More than one"
        # error.
        connection_areas = {
            child.connection_area
            for child in stop_area.stops
            if child.connection_area is not None
        }
        connection_areas_len = len(connection_areas)
        if connection_areas_len == 1:
            connection_area = connection_areas.pop()
        else:
            prefix = "More than one" if connection_areas_len else "No"
            msg = f"{prefix} connection area has been found for stop area #{stop_id}"
            raise HTTPException(status_code=500, detail=msg)

    if (
        connection_area is not None
        and (shape := await StopShape.get_by_id(connection_area.id)) is not None
    ):
        return StopShapeSchema(
            id=shape.id, type=shape.type, bbox=shape.bounding_box, points=shape.points
        )

    msg = f"No shape found for stop {stop_id}"
    raise HTTPException(status_code=404, detail=msg)