from collections import defaultdict
from datetime import datetime
from os import environ

from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from fastapi.staticfiles import StaticFiles
from rich import print

from backend.db import db
from backend.idfm_interface import IdfmInterface
from backend.models import Line, Stop, StopArea
from backend.schemas import (
    Line as LineSchema,
    TransportMode,
    NextPassage as NextPassageSchema,
    NextPassages as NextPassagesSchema,
    Stop as StopSchema,
    StopArea as StopAreaSchema,
)

API_KEY = environ.get("API_KEY")
if API_KEY is None:
    # The IDFM client below is still constructed with a None key; warn so a
    # missing configuration is visible at startup instead of failing silently.
    print("Warning: the API_KEY environment variable is not set.")

# TODO: Remove postgresql+asyncpg from environ variable
DB_PATH = "postgresql+asyncpg://cer_user:cer_password@127.0.0.1:5438/cer_db"

app = FastAPI()

app.add_middleware(
    CORSMiddleware,
    allow_origins=[
        "https://localhost:4443",
        "https://localhost:3000",
    ],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

idfm_interface = IdfmInterface(API_KEY, db)


@app.on_event("startup")
async def startup():
    """Connect to the database when the application starts.

    Static data is kept as-is (``clear_static_data=False``).
    """
    # Alternative startup path, kept for reference: clear the static data and
    # run the IDFM interface startup.
    # await db.connect(DB_PATH, clear_static_data=True)
    # await idfm_interface.startup()
    await db.connect(DB_PATH, clear_static_data=False)
    print("Connected")


@app.on_event("shutdown")
async def shutdown():
    """Disconnect from the database when the application stops."""
    await db.disconnect()


# /addwidget https://localhost:4443/static/#?widgetId=$matrix_widget_id&userId=$matrix_user_id
# /addwidget https://localhost:3000/widget?widgetId=$matrix_widget_id&userId=$matrix_user_id

STATIC_ROOT = "../frontend/"
app.mount("/widget", StaticFiles(directory=STATIC_ROOT, html=True), name="widget")


def optional_datetime_to_ts(dt: datetime | None) -> float | None:
    """Return the POSIX timestamp of *dt*, or None when *dt* is None.

    ``datetime.timestamp()`` returns a float, hence the ``float`` return type.
    """
    return dt.timestamp() if dt is not None else None


@app.get("/line/{line_id}", response_model=LineSchema)
async def get_line(line_id: str) -> LineSchema:
    """Return the line identified by *line_id*.

    Raises:
        HTTPException: 404 if no line with this id is found.
    """
    line: Line | None = await Line.get_by_id(line_id)

    if line is None:
        raise HTTPException(status_code=404, detail=f'Line "{line_id}" not found')

    return LineSchema(
        id=line.id,
        shortName=line.short_name,
        name=line.name,
        status=line.status,
        transportMode=TransportMode.from_idfm_transport_mode(
            line.transport_mode, line.transport_submode
        ),
        backColorHexa=line.colour_web_hexa,
        foreColorHexa=line.text_colour_hexa,
        operatorId=line.operator_id,
        accessibility=line.accessibility,
        visualSignsAvailable=line.visual_signs_available,
        audibleSignsAvailable=line.audible_signs_available,
        stopIds=[stop.id for stop in line.stops],
    )


def _format_stop(stop: Stop) -> StopSchema:
    """Convert a ``Stop`` model into its API schema (with served line ids)."""
    return StopSchema(
        id=stop.id,
        name=stop.name,
        town=stop.town_name,
        # xepsg2154=stop.xepsg2154,
        # yepsg2154=stop.yepsg2154,
        lat=stop.latitude,
        lon=stop.longitude,
        lines=[line.id for line in stop.lines],
    )


# châtelet
@app.get("/stop/")
async def get_stop(
    name: str = "", limit: int = 10
) -> list[StopAreaSchema | StopSchema]:
    """Search stops and stop areas matching *name*.

    Each matching stop area is returned with all of its child stops nested
    inside it. Matching stops that are not a child of any returned stop area
    are appended afterwards as standalone entries.

    ``limit`` is accepted but not applied yet.
    """
    # TODO: Add limit support
    matching_stops = await Stop.get_by_name(name)

    # Split the results into stop areas and plain stops, keyed by id.
    stop_areas: dict[int, StopArea] = {}
    stops: dict[int, Stop] = {}
    for stop in matching_stops:
        dst = stop_areas if isinstance(stop, StopArea) else stops
        dst[stop.id] = stop

    formatted: list[StopAreaSchema | StopSchema] = []
    for stop_area in stop_areas.values():
        formatted_stops = []
        for stop in stop_area.stops:
            formatted_stops.append(_format_stop(stop))
            # The stop is already reported inside its area, so drop it from
            # the standalone list. It may be absent there because it did not
            # match *name* itself; that is expected, not an error.
            stops.pop(stop.id, None)

        formatted.append(
            StopAreaSchema(
                id=stop_area.id,
                name=stop_area.name,
                town=stop_area.town_name,
                # xepsg2154=stop_area.xepsg2154,
                # yepsg2154=stop_area.yepsg2154,
                type=stop_area.type,
                lines=[line.id for line in stop_area.lines],
                stops=formatted_stops,
            )
        )

    formatted.extend(_format_stop(stop) for stop in stops.values())

    return formatted


# TODO: Cache response for 30 secs ?
@app.get("/stop/nextPassages/{stop_id}")
async def get_next_passages(stop_id: str) -> JSONResponse:
    """Return the next passages at *stop_id*, grouped by line then destination.

    The passages are fetched through ``idfm_interface.get_next_passages``.
    Each ``MonitoredStopVisit`` of every ``StopMonitoringDelivery`` is
    converted into a ``NextPassageSchema``.

    Raises:
        HTTPException: 404 if a visit's ``LineRef`` does not match
            ``IdfmInterface.LINE_RE``.
    """
    # NOTE(review): the function actually returns a NextPassagesSchema. The
    # JSONResponse annotation is kept because switching it would make FastAPI
    # infer a response_model and change response handling.
    res = await idfm_interface.get_next_passages(stop_id)

    service_delivery = res.Siri.ServiceDelivery
    stop_monitoring_deliveries = service_delivery.StopMonitoringDelivery

    # line id -> destination name -> list of passages
    by_line_by_dst_passages = defaultdict(lambda: defaultdict(list))

    for delivery in stop_monitoring_deliveries:
        for stop_visit in delivery.MonitoredStopVisit:
            journey = stop_visit.MonitoredVehicleJourney

            # re.match returns None when the LineRef value is not in the
            # expected format. Check that explicitly rather than catching an
            # AttributeError, which could hide unrelated errors.
            line_ref = journey.LineRef.value
            line_match = IdfmInterface.LINE_RE.match(line_ref)
            if line_match is None:
                raise HTTPException(
                    status_code=404, detail=f'Line "{line_ref}" not found'
                )
            line_id = line_match.group(1)

            call = journey.MonitoredCall

            dst_names = call.DestinationDisplay
            dsts = [dst.value for dst in dst_names] if dst_names else []

            next_passage = NextPassageSchema(
                line=line_id,
                operator=journey.OperatorRef.value,
                destinations=dsts,
                atStop=call.VehicleAtStop,
                aimedArrivalTs=optional_datetime_to_ts(call.AimedArrivalTime),
                expectedArrivalTs=optional_datetime_to_ts(call.ExpectedArrivalTime),
                arrivalPlatformName=call.ArrivalPlatformName.value
                if call.ArrivalPlatformName
                else None,
                aimedDepartTs=optional_datetime_to_ts(call.AimedDepartureTime),
                expectedDepartTs=optional_datetime_to_ts(call.ExpectedDepartureTime),
                arrivalStatus=call.ArrivalStatus.value,
                departStatus=call.DepartureStatus.value,
            )

            by_line_passages = by_line_by_dst_passages[line_id]

            # TODO: by_line_passages[dst].extend(dsts) instead ?
            # NOTE(review): a passage without any DestinationDisplay is not
            # stored under any destination and is therefore not returned.
            for dst in dsts:
                by_line_passages[dst].append(next_passage)

    return NextPassagesSchema(
        ts=service_delivery.ResponseTimestamp.timestamp(),
        passages=by_line_by_dst_passages,
    )