🏷️ Fix some type issues (mypy)

This commit is contained in:
2023-05-11 21:40:38 +02:00
parent b437bbbf70
commit 5e0d7b174c
2 changed files with 23 additions and 11 deletions

View File

@@ -4,7 +4,7 @@ from ..idfm_interface.idfm_types import TrainStatus
class NextPassage(BaseModel): class NextPassage(BaseModel):
line: str line: int
operator: str operator: str
destinations: list[str] destinations: list[str]
atStop: bool atStop: bool
@@ -19,4 +19,4 @@ class NextPassage(BaseModel):
class NextPassages(BaseModel): class NextPassages(BaseModel):
ts: int ts: int
passages: dict[str, dict[str, list[NextPassage]]] passages: dict[int, dict[str, list[NextPassage]]]

View File

@@ -4,7 +4,11 @@ from typing import Sequence
from fastapi import APIRouter, HTTPException from fastapi import APIRouter, HTTPException
from backend.idfm_interface import Destinations as IdfmDestinations, IdfmInterface from backend.idfm_interface import (
Destinations as IdfmDestinations,
IdfmInterface,
TrainStatus,
)
from backend.models import Stop, StopArea, StopShape from backend.models import Stop, StopArea, StopShape
from backend.schemas import ( from backend.schemas import (
NextPassage as NextPassageSchema, NextPassage as NextPassageSchema,
@@ -38,16 +42,20 @@ def optional_datetime_to_ts(dt: datetime | None) -> int | None:
@router.get("/") @router.get("/")
async def get_stop( async def get_stop(
name: str = "", limit: int = 10 name: str = "", limit: int = 10
) -> Sequence[StopAreaSchema | StopSchema]: ) -> Sequence[StopAreaSchema | StopSchema] | None:
matching_stops = await Stop.get_by_name(name)
if matching_stops is None:
return None
formatted: list[StopAreaSchema | StopSchema] = [] formatted: list[StopAreaSchema | StopSchema] = []
matching_stops = await Stop.get_by_name(name)
stop_areas: dict[int, StopArea] = {} stop_areas: dict[int, StopArea] = {}
stops: dict[int, Stop] = {} stops: dict[int, Stop] = {}
for stop in matching_stops: for stop in matching_stops:
dst = stop_areas if isinstance(stop, StopArea) else stops if isinstance(stop, StopArea):
dst[stop.id] = stop stop_areas[stop.id] = stop
elif isinstance(stop, Stop):
stops[stop.id] = stop
for stop_area in stop_areas.values(): for stop_area in stop_areas.values():
@@ -121,8 +129,12 @@ async def get_next_passages(stop_id: int) -> NextPassagesSchema | None:
arrivalPlatformName=arrivalPlatformName, arrivalPlatformName=arrivalPlatformName,
aimedDepartTs=optional_datetime_to_ts(call.AimedDepartureTime), aimedDepartTs=optional_datetime_to_ts(call.AimedDepartureTime),
expectedDepartTs=optional_datetime_to_ts(call.ExpectedDepartureTime), expectedDepartTs=optional_datetime_to_ts(call.ExpectedDepartureTime),
arrivalStatus=call.ArrivalStatus.value, arrivalStatus=call.ArrivalStatus
departStatus=call.DepartureStatus.value, if call.ArrivalStatus is not None
else TrainStatus.unknown,
departStatus=call.DepartureStatus
if call.DepartureStatus is not None
else TrainStatus.unknown,
) )
by_line_passages = by_line_by_dst_passages[line_id] by_line_passages = by_line_by_dst_passages[line_id]
@@ -131,7 +143,7 @@ async def get_next_passages(stop_id: int) -> NextPassagesSchema | None:
by_line_passages[dst].append(next_passage) by_line_passages[dst].append(next_passage)
return NextPassagesSchema( return NextPassagesSchema(
ts=service_delivery.ResponseTimestamp.timestamp(), ts=int(service_delivery.ResponseTimestamp.timestamp()),
passages=by_line_by_dst_passages, passages=by_line_by_dst_passages,
) )