From b437bbbf708f2d8d5ec1bb6ddfb91efc19e5abc0 Mon Sep 17 00:00:00 2001 From: Adrien Date: Thu, 11 May 2023 21:17:02 +0200 Subject: [PATCH] =?UTF-8?q?=F0=9F=8E=A8=20Split=20main=20into=20several=20?= =?UTF-8?q?APIRouters?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- backend/dependencies.py | 22 ++++ backend/main.py | 228 ++---------------------------------- backend/routers/__init__.py | 0 backend/routers/line.py | 32 +++++ backend/routers/stop.py | 178 ++++++++++++++++++++++++++++ 5 files changed, 239 insertions(+), 221 deletions(-) create mode 100644 backend/dependencies.py create mode 100644 backend/routers/__init__.py create mode 100644 backend/routers/line.py create mode 100644 backend/routers/stop.py diff --git a/backend/dependencies.py b/backend/dependencies.py new file mode 100644 index 0000000..c2ee1f6 --- /dev/null +++ b/backend/dependencies.py @@ -0,0 +1,22 @@ +from os import environ + +from yaml import safe_load + +from backend.db import db +from backend.idfm_interface import IdfmInterface +from backend.settings import Settings + + +CONFIG_PATH = environ.get("CONFIG_PATH", "./config.sample.yaml") + + +def load_settings(path: str) -> Settings: + with open(path, "r") as config_file: + config = safe_load(config_file) + + return Settings(**config) + + +settings = load_settings(CONFIG_PATH) + +idfm_interface = IdfmInterface(settings.idfm_api_key.get_secret_value(), db) diff --git a/backend/main.py b/backend/main.py index 63c4ece..3ae627f 100755 --- a/backend/main.py +++ b/backend/main.py @@ -1,13 +1,8 @@ #!/usr/bin/env python3 -import logging -from collections import defaultdict -from datetime import datetime -from os import environ, EX_USAGE -from typing import Sequence import uvicorn from contextlib import asynccontextmanager -from fastapi import FastAPI, HTTPException +from fastapi import FastAPI from fastapi.middleware.cors import CORSMiddleware from fastapi.staticfiles import StaticFiles from opentelemetry import trace @@ -16,35 +11,10 @@ from opentelemetry.instrumentation.fastapi import FastAPIInstrumentor from opentelemetry.sdk.resources import Resource, SERVICE_NAME from opentelemetry.sdk.trace import TracerProvider from opentelemetry.sdk.trace.export import BatchSpanProcessor -from rich import print -from yaml import safe_load from backend.db import db -from backend.idfm_interface import Destinations as IdfmDestinations, IdfmInterface -from backend.models import Line, Stop, StopArea, StopShape -from backend.schemas import ( - Line as LineSchema, - TransportMode, - NextPassage as NextPassageSchema, - NextPassages as NextPassagesSchema, - Stop as StopSchema, - StopArea as StopAreaSchema, - StopShape as StopShapeSchema, -) -from backend.settings import Settings - - -CONFIG_PATH = environ.get("CONFIG_PATH", "./config.sample.yaml") - - -def load_settings(path: str) -> Settings: - with open(path, "r") as config_file: - config = safe_load(config_file) - - return Settings(**config) - - -settings = load_settings(CONFIG_PATH) +from dependencies import idfm_interface, settings +from routers import line, stop @asynccontextmanager @@ -70,6 +40,10 @@ app.add_middleware( app.mount("/widget", StaticFiles(directory="../frontend/", html=True), name="widget") +app.include_router(line.router) +app.include_router(stop.router) + + FastAPIInstrumentor.instrument_app(app) trace.set_tracer_provider( @@ -78,194 +52,6 @@ trace.set_tracer_provider( trace.get_tracer_provider().add_span_processor(BatchSpanProcessor(OTLPSpanExporter())) tracer = trace.get_tracer(settings.app_name) -idfm_interface = IdfmInterface(settings.idfm_api_key.get_secret_value(), db) - - -def optional_datetime_to_ts(dt: datetime | None) -> int | None: - return int(dt.timestamp()) if dt else None - - -@app.get("/line/{line_id}", response_model=LineSchema) -async def get_line(line_id: int) -> LineSchema: - line: Line | None = await Line.get_by_id(line_id) - - if line is None: - raise HTTPException(status_code=404, detail=f'Line "{line_id}" not found') - - return LineSchema( - id=line.id, - shortName=line.short_name, - name=line.name, - status=line.status, - transportMode=TransportMode.from_idfm_transport_mode( - line.transport_mode, line.transport_submode - ), - backColorHexa=line.colour_web_hexa, - foreColorHexa=line.text_colour_hexa, - operatorId=line.operator_id, - accessibility=line.accessibility, - visualSignsAvailable=line.visual_signs_available, - audibleSignsAvailable=line.audible_signs_available, - stopIds=[stop.id for stop in line.stops], - ) - - -def _format_stop(stop: Stop) -> StopSchema: - return StopSchema( - id=stop.id, - name=stop.name, - town=stop.town_name, - epsg3857_x=stop.epsg3857_x, - epsg3857_y=stop.epsg3857_y, - lines=[line.id for line in stop.lines], - ) - - -@app.get("/stop/") -async def get_stop( - name: str = "", limit: int = 10 -) -> Sequence[StopAreaSchema | StopSchema]: - # TODO: Add limit support - - formatted: list[StopAreaSchema | StopSchema] = [] - matching_stops = await Stop.get_by_name(name) - # print(matching_stops, flush=True) - - stop_areas: dict[int, StopArea] = {} - stops: dict[int, Stop] = {} - for stop in matching_stops: - # print(f"{stop.__dict__ = }", flush=True) - dst = stop_areas if isinstance(stop, StopArea) else stops - dst[stop.id] = stop - - for stop_area in stop_areas.values(): - - formatted_stops = [] - for stop in stop_area.stops: - formatted_stops.append(_format_stop(stop)) - try: - del stops[stop.id] - except KeyError as err: - print(err) - - formatted.append( - StopAreaSchema( - id=stop_area.id, - name=stop_area.name, - town=stop_area.town_name, - type=stop_area.type, - lines=[line.id for line in stop_area.lines], - stops=formatted_stops, - ) - ) - - formatted.extend(_format_stop(stop) for stop in stops.values()) - - return formatted - - -# TODO: Cache response for 30 secs ? -@app.get("/stop/{stop_id}/nextPassages") -async def get_next_passages(stop_id: int) -> NextPassagesSchema | None: - res = await idfm_interface.get_next_passages(stop_id) - if res is None: - return None - - service_delivery = res.Siri.ServiceDelivery - stop_monitoring_deliveries = service_delivery.StopMonitoringDelivery - - by_line_by_dst_passages: dict[ - str, dict[str, list[NextPassageSchema]] - ] = defaultdict(lambda: defaultdict(list)) - - for delivery in stop_monitoring_deliveries: - for stop_visit in delivery.MonitoredStopVisit: - - journey = stop_visit.MonitoredVehicleJourney - - # re.match will return None if the given journey.LineRef.value is not valid. - try: - line_id_match = IdfmInterface.LINE_RE.match(journey.LineRef.value) - line_id = int(line_id_match.group(1)) # type: ignore - except (AttributeError, TypeError, ValueError) as err: - raise HTTPException( - status_code=404, detail=f'Line "{journey.LineRef.value}" not found' - ) from err - - call = journey.MonitoredCall - - dst_names = call.DestinationDisplay - dsts = [dst.value for dst in dst_names] if dst_names else [] - arrivalPlatformName = ( - call.ArrivalPlatformName.value if call.ArrivalPlatformName else None - ) - - next_passage = NextPassageSchema( - line=line_id, - operator=journey.OperatorRef.value, - destinations=dsts, - atStop=call.VehicleAtStop, - aimedArrivalTs=optional_datetime_to_ts(call.AimedArrivalTime), - expectedArrivalTs=optional_datetime_to_ts(call.ExpectedArrivalTime), - arrivalPlatformName=arrivalPlatformName, - aimedDepartTs=optional_datetime_to_ts(call.AimedDepartureTime), - expectedDepartTs=optional_datetime_to_ts(call.ExpectedDepartureTime), - arrivalStatus=call.ArrivalStatus.value, - departStatus=call.DepartureStatus.value, - ) - - by_line_passages = by_line_by_dst_passages[line_id] - # TODO: by_line_passages[dst].extend(dsts) instead ? - for dst in dsts: - by_line_passages[dst].append(next_passage) - - return NextPassagesSchema( - ts=service_delivery.ResponseTimestamp.timestamp(), - passages=by_line_by_dst_passages, - ) - - -@app.get("/stop/{stop_id}/destinations") -async def get_stop_destinations( - stop_id: int, -) -> IdfmDestinations | None: - destinations = await idfm_interface.get_destinations(stop_id) - - return destinations - - -@app.get("/stop/{stop_id}/shape") -async def get_stop_shape(stop_id: int) -> StopShapeSchema | None: - connection_area = None - - if (stop := await Stop.get_by_id(stop_id)) is not None: - connection_area = stop.connection_area - - elif (stop_area := await StopArea.get_by_id(stop_id)) is not None: - connection_areas = {stop.connection_area for stop in stop_area.stops} - connection_areas_len = len(connection_areas) - if connection_areas_len == 1: - connection_area = connection_areas.pop() - - else: - prefix = "More than one" if connection_areas_len else "No" - msg = f"{prefix} connection area has been found for stop area #{stop_id}" - raise HTTPException(status_code=500, detail=msg) - - if ( - connection_area is not None - and (shape := await StopShape.get_by_id(connection_area.id)) is not None - ): - return StopShapeSchema( - id=shape.id, - type=shape.type, - epsg3857_bbox=shape.epsg3857_bbox, - epsg3857_points=shape.epsg3857_points, - ) - - msg = f"No shape found for stop {stop_id}" - raise HTTPException(status_code=404, detail=msg) - if __name__ == "__main__": http_settings = settings.http diff --git a/backend/routers/__init__.py b/backend/routers/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/backend/routers/line.py b/backend/routers/line.py new file mode 100644 index 0000000..1c8463f --- /dev/null +++ b/backend/routers/line.py @@ -0,0 +1,32 @@ +from fastapi import APIRouter, HTTPException + +from backend.models import Line +from backend.schemas import Line as LineSchema, TransportMode + + +router = APIRouter(prefix="/line", tags=["line"]) + + +@router.get("/{line_id}", response_model=LineSchema) +async def get_line(line_id: int) -> LineSchema: + line: Line | None = await Line.get_by_id(line_id) + + if line is None: + raise HTTPException(status_code=404, detail=f'Line "{line_id}" not found') + + return LineSchema( + id=line.id, + shortName=line.short_name, + name=line.name, + status=line.status, + transportMode=TransportMode.from_idfm_transport_mode( + line.transport_mode, line.transport_submode + ), + backColorHexa=line.colour_web_hexa, + foreColorHexa=line.text_colour_hexa, + operatorId=line.operator_id, + accessibility=line.accessibility, + visualSignsAvailable=line.visual_signs_available, + audibleSignsAvailable=line.audible_signs_available, + stopIds=[stop.id for stop in line.stops], + ) diff --git a/backend/routers/stop.py b/backend/routers/stop.py new file mode 100644 index 0000000..cca2260 --- /dev/null +++ b/backend/routers/stop.py @@ -0,0 +1,178 @@ +from collections import defaultdict +from datetime import datetime +from typing import Sequence + +from fastapi import APIRouter, HTTPException + +from backend.idfm_interface import Destinations as IdfmDestinations, IdfmInterface +from backend.models import Stop, StopArea, StopShape +from backend.schemas import ( + NextPassage as NextPassageSchema, + NextPassages as NextPassagesSchema, + Stop as StopSchema, + StopArea as StopAreaSchema, + StopShape as StopShapeSchema, +) +from dependencies import idfm_interface + + +router = APIRouter(prefix="/stop", tags=["stop"]) + + +def _format_stop(stop: Stop) -> StopSchema: + return StopSchema( + id=stop.id, + name=stop.name, + town=stop.town_name, + epsg3857_x=stop.epsg3857_x, + epsg3857_y=stop.epsg3857_y, + lines=[line.id for line in stop.lines], + ) + + +def optional_datetime_to_ts(dt: datetime | None) -> int | None: + return int(dt.timestamp()) if dt else None + + +# TODO: Add limit support +@router.get("/") +async def get_stop( + name: str = "", limit: int = 10 +) -> Sequence[StopAreaSchema | StopSchema]: + + formatted: list[StopAreaSchema | StopSchema] = [] + matching_stops = await Stop.get_by_name(name) + + stop_areas: dict[int, StopArea] = {} + stops: dict[int, Stop] = {} + for stop in matching_stops: + dst = stop_areas if isinstance(stop, StopArea) else stops + dst[stop.id] = stop + + for stop_area in stop_areas.values(): + + formatted_stops = [] + for stop in stop_area.stops: + formatted_stops.append(_format_stop(stop)) + try: + del stops[stop.id] + except KeyError as err: + print(err) + + formatted.append( + StopAreaSchema( + id=stop_area.id, + name=stop_area.name, + town=stop_area.town_name, + type=stop_area.type, + lines=[line.id for line in stop_area.lines], + stops=formatted_stops, + ) + ) + + formatted.extend(_format_stop(stop) for stop in stops.values()) + + return formatted + + +# TODO: Cache response for 30 secs ? +@router.get("/{stop_id}/nextPassages") +async def get_next_passages(stop_id: int) -> NextPassagesSchema | None: + res = await idfm_interface.get_next_passages(stop_id) + if res is None: + return None + + service_delivery = res.Siri.ServiceDelivery + stop_monitoring_deliveries = service_delivery.StopMonitoringDelivery + + by_line_by_dst_passages: dict[ + int, dict[str, list[NextPassageSchema]] + ] = defaultdict(lambda: defaultdict(list)) + + for delivery in stop_monitoring_deliveries: + for stop_visit in delivery.MonitoredStopVisit: + + journey = stop_visit.MonitoredVehicleJourney + + # re.match will return None if the given journey.LineRef.value is not valid. + try: + line_id_match = IdfmInterface.LINE_RE.match(journey.LineRef.value) + line_id = int(line_id_match.group(1)) # type: ignore + except (AttributeError, TypeError, ValueError) as err: + raise HTTPException( + status_code=404, detail=f'Line "{journey.LineRef.value}" not found' + ) from err + + call = journey.MonitoredCall + + dst_names = call.DestinationDisplay + dsts = [dst.value for dst in dst_names] if dst_names else [] + arrivalPlatformName = ( + call.ArrivalPlatformName.value if call.ArrivalPlatformName else None + ) + + next_passage = NextPassageSchema( + line=line_id, + operator=journey.OperatorRef.value, + destinations=dsts, + atStop=call.VehicleAtStop, + aimedArrivalTs=optional_datetime_to_ts(call.AimedArrivalTime), + expectedArrivalTs=optional_datetime_to_ts(call.ExpectedArrivalTime), + arrivalPlatformName=arrivalPlatformName, + aimedDepartTs=optional_datetime_to_ts(call.AimedDepartureTime), + expectedDepartTs=optional_datetime_to_ts(call.ExpectedDepartureTime), + arrivalStatus=call.ArrivalStatus.value, + departStatus=call.DepartureStatus.value, + ) + + by_line_passages = by_line_by_dst_passages[line_id] + # TODO: by_line_passages[dst].extend(dsts) instead ? + for dst in dsts: + by_line_passages[dst].append(next_passage) + + return NextPassagesSchema( + ts=service_delivery.ResponseTimestamp.timestamp(), + passages=by_line_by_dst_passages, + ) + + +@router.get("/{stop_id}/destinations") +async def get_stop_destinations( + stop_id: int, +) -> IdfmDestinations | None: + destinations = await idfm_interface.get_destinations(stop_id) + + return destinations + + +@router.get("/{stop_id}/shape") +async def get_stop_shape(stop_id: int) -> StopShapeSchema | None: + connection_area = None + + if (stop := await Stop.get_by_id(stop_id)) is not None: + connection_area = stop.connection_area + + elif (stop_area := await StopArea.get_by_id(stop_id)) is not None: + connection_areas = {stop.connection_area for stop in stop_area.stops} + connection_areas_len = len(connection_areas) + if connection_areas_len == 1: + connection_area = connection_areas.pop() + + else: + prefix = "More than one" if connection_areas_len else "No" + msg = f"{prefix} connection area has been found for stop area #{stop_id}" + raise HTTPException(status_code=500, detail=msg) + + if ( + connection_area is not None + and (shape := await StopShape.get_by_id(connection_area.id)) is not None + ): + return StopShapeSchema( + id=shape.id, + type=shape.type, + epsg3857_bbox=shape.epsg3857_bbox, + epsg3857_points=shape.epsg3857_points, + ) + + msg = f"No shape found for stop {stop_id}" + raise HTTPException(status_code=404, detail=msg)