🏷️ Make python linters happy

This commit is contained in:
2023-02-08 22:10:21 +01:00
parent d1db97554c
commit e34355e8be
18 changed files with 400 additions and 290 deletions

View File

@@ -1,10 +1,10 @@
from collections import defaultdict
from datetime import datetime
from os import environ
from os import environ, EX_USAGE
from typing import Sequence
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from fastapi.staticfiles import StaticFiles
from rich import print
@@ -21,7 +21,9 @@ from backend.schemas import (
)
API_KEY = environ.get("API_KEY")
# TODO: Add error message if no key is given.
if API_KEY is None:
print('No "API_KEY" environment variable set... abort.')
exit(EX_USAGE)
# TODO: Remove postgresql+asyncpg from environ variable
DB_PATH = "postgresql+asyncpg://cer_user:cer_password@127.0.0.1:5438/cer_db"
@@ -44,9 +46,9 @@ idfm_interface = IdfmInterface(API_KEY, db)
@app.on_event("startup")
async def startup():
# await db.connect(DB_PATH, clear_static_data=True)
# await idfm_interface.startup()
await db.connect(DB_PATH, clear_static_data=False)
await db.connect(DB_PATH, clear_static_data=True)
await idfm_interface.startup()
# await db.connect(DB_PATH, clear_static_data=False)
print("Connected")
@@ -61,12 +63,12 @@ STATIC_ROOT = "../frontend/"
app.mount("/widget", StaticFiles(directory=STATIC_ROOT, html=True), name="widget")
def optional_datetime_to_ts(dt: datetime) -> int | None:
return dt.timestamp() if dt else None
def optional_datetime_to_ts(dt: datetime | None) -> int | None:
return int(dt.timestamp()) if dt else None
@app.get("/line/{line_id}", response_model=LineSchema)
async def get_line(line_id: str) -> JSONResponse:
async def get_line(line_id: str) -> LineSchema:
line: Line | None = await Line.get_by_id(line_id)
if line is None:
@@ -91,7 +93,7 @@ async def get_line(line_id: str) -> JSONResponse:
def _format_stop(stop: Stop) -> StopSchema:
print(stop.__dict__)
# print(stop.__dict__)
return StopSchema(
id=stop.id,
name=stop.name,
@@ -103,15 +105,17 @@ def _format_stop(stop: Stop) -> StopSchema:
lines=[line.id for line in stop.lines],
)
# châtelet
@app.get("/stop/")
async def get_stop(
name: str = "", limit: int = 10
) -> list[StopAreaSchema | StopSchema]:
) -> Sequence[StopAreaSchema | StopSchema]:
# TODO: Add limit support
formatted = []
formatted: list[StopAreaSchema | StopSchema] = []
matching_stops = await Stop.get_by_name(name)
# print(matching_stops, flush=True)
@@ -153,15 +157,17 @@ async def get_stop(
# TODO: Cache response for 30 secs ?
@app.get("/stop/nextPassages/{stop_id}")
async def get_next_passages(stop_id: str) -> JSONResponse:
async def get_next_passages(stop_id: str) -> NextPassagesSchema | None:
res = await idfm_interface.get_next_passages(stop_id)
# print(res)
if res is None:
return None
service_delivery = res.Siri.ServiceDelivery
stop_monitoring_deliveries = service_delivery.StopMonitoringDelivery
by_line_by_dst_passages = defaultdict(lambda: defaultdict(list))
by_line_by_dst_passages: dict[
str, dict[str, list[NextPassageSchema]]
] = defaultdict(lambda: defaultdict(list))
for delivery in stop_monitoring_deliveries:
for stop_visit in delivery.MonitoredStopVisit:
@@ -190,7 +196,9 @@ async def get_next_passages(stop_id: str) -> JSONResponse:
atStop=call.VehicleAtStop,
aimedArrivalTs=optional_datetime_to_ts(call.AimedArrivalTime),
expectedArrivalTs=optional_datetime_to_ts(call.ExpectedArrivalTime),
arrivalPlatformName=call.ArrivalPlatformName.value if call.ArrivalPlatformName else None,
arrivalPlatformName=call.ArrivalPlatformName.value
if call.ArrivalPlatformName
else None,
aimedDepartTs=optional_datetime_to_ts(call.AimedDepartureTime),
expectedDepartTs=optional_datetime_to_ts(call.ExpectedDepartureTime),
arrivalStatus=call.ArrivalStatus.value,