🏷️ Make python linters happy
This commit is contained in:
0
backend/backend/__init__.py
Normal file
0
backend/backend/__init__.py
Normal file
@@ -1,4 +1,6 @@
|
||||
from .db import Database
|
||||
from .base_class import Base
|
||||
|
||||
__all__ = ["Base"]
|
||||
|
||||
db = Database()
|
||||
|
@@ -1,34 +1,39 @@
|
||||
from collections.abc import Iterable
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Iterable, Self, TYPE_CHECKING
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.orm import declarative_base
|
||||
from typing import Iterable, Self
|
||||
from sqlalchemy.orm import DeclarativeBase
|
||||
|
||||
Base = declarative_base()
|
||||
Base.db = None
|
||||
if TYPE_CHECKING:
|
||||
from .db import Database
|
||||
|
||||
|
||||
async def base_add(cls, stops: Self | Iterable[Self]) -> bool:
|
||||
try:
|
||||
method = (
|
||||
cls.db.session.add_all
|
||||
if isinstance(stops, Iterable)
|
||||
else cls.db.session.add
|
||||
)
|
||||
method(stops)
|
||||
await cls.db.session.commit()
|
||||
except IntegrityError as err:
|
||||
print(err)
|
||||
class Base(DeclarativeBase):
|
||||
db: Database | None = None
|
||||
|
||||
@classmethod
|
||||
async def add(cls, stops: Self | Iterable[Self]) -> bool:
|
||||
try:
|
||||
if isinstance(stops, Iterable):
|
||||
cls.db.session.add_all(stops) # type: ignore
|
||||
else:
|
||||
cls.db.session.add(stops) # type: ignore
|
||||
await cls.db.session.commit() # type: ignore
|
||||
except (AttributeError, IntegrityError) as err:
|
||||
print(err)
|
||||
return False
|
||||
|
||||
Base.add = classmethod(base_add)
|
||||
return True
|
||||
|
||||
|
||||
async def base_get_by_id(cls, id_: int | str) -> None | Base:
|
||||
res = await cls.db.session.execute(select(cls).where(cls.id == id_))
|
||||
element = res.scalar_one_or_none()
|
||||
return element
|
||||
|
||||
|
||||
Base.get_by_id = classmethod(base_get_by_id)
|
||||
@classmethod
|
||||
async def get_by_id(cls, id_: int | str) -> Self | None:
|
||||
try:
|
||||
stmt = select(cls).where(cls.id == id_) # type: ignore
|
||||
res = await cls.db.session.execute(stmt) # type: ignore
|
||||
element = res.scalar_one_or_none()
|
||||
except AttributeError as err:
|
||||
print(err)
|
||||
element = None
|
||||
return element
|
||||
|
@@ -1,80 +1,48 @@
|
||||
from asyncio import gather as asyncio_gather
|
||||
from functools import wraps
|
||||
from pathlib import Path
|
||||
from time import time
|
||||
from typing import Callable, Iterable, Optional
|
||||
|
||||
from rich import print
|
||||
from sqlalchemy import event, select, tuple_
|
||||
from sqlalchemy.engine import Engine
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
|
||||
from sqlalchemy.orm import (
|
||||
selectinload,
|
||||
sessionmaker,
|
||||
with_polymorphic,
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.ext.asyncio import (
|
||||
async_sessionmaker,
|
||||
AsyncEngine,
|
||||
AsyncSession,
|
||||
create_async_engine,
|
||||
)
|
||||
from sqlalchemy.orm.attributes import set_committed_value
|
||||
|
||||
from .base_class import Base
|
||||
|
||||
|
||||
# import logging
|
||||
|
||||
# logging.basicConfig()
|
||||
# logger = logging.getLogger("bot.sqltime")
|
||||
# logger.setLevel(logging.DEBUG)
|
||||
|
||||
|
||||
# @event.listens_for(Engine, "before_cursor_execute")
|
||||
# def before_cursor_execute(conn, cursor, statement, parameters, context, executemany):
|
||||
# conn.info.setdefault("query_start_time", []).append(time())
|
||||
# logger.debug("Start Query: %s", statement)
|
||||
|
||||
|
||||
# @event.listens_for(Engine, "after_cursor_execute")
|
||||
# def after_cursor_execute(conn, cursor, statement, parameters, context, executemany):
|
||||
# total = time() - conn.info["query_start_time"].pop(-1)
|
||||
# logger.debug("Query Complete!")
|
||||
# logger.debug("Total Time: %f", total)
|
||||
|
||||
|
||||
class Database:
|
||||
def __init__(self) -> None:
|
||||
self._engine = None
|
||||
self._session_maker = None
|
||||
self._session = None
|
||||
self._engine: AsyncEngine | None = None
|
||||
self._session_maker: async_sessionmaker[AsyncSession] | None = None
|
||||
self._session: AsyncSession | None = None
|
||||
|
||||
@property
|
||||
def session(self) -> None:
|
||||
if self._session is None:
|
||||
self._session = self._session_maker()
|
||||
def session(self) -> AsyncSession | None:
|
||||
if self._session is None and (session_maker := self._session_maker) is not None:
|
||||
self._session = session_maker()
|
||||
return self._session
|
||||
|
||||
def use_session(func: Callable):
|
||||
@wraps(func)
|
||||
async def wrapper(self, *args, **kwargs):
|
||||
if self._check_session() is not None:
|
||||
return await func(self, *args, **kwargs)
|
||||
# TODO: Raise an exception ?
|
||||
async def connect(self, db_path: str, clear_static_data: bool = False) -> bool:
|
||||
|
||||
return wrapper
|
||||
|
||||
async def connect(self, db_path: str, clear_static_data: bool = False) -> None:
|
||||
# TODO: Preserve UserLastStopSearchResults table from drop.
|
||||
self._engine = create_async_engine(db_path)
|
||||
self._session_maker = sessionmaker(
|
||||
self._engine, expire_on_commit=False, class_=AsyncSession
|
||||
)
|
||||
await self.session.execute("CREATE EXTENSION IF NOT EXISTS pg_trgm;")
|
||||
if self._engine is not None:
|
||||
self._session_maker = async_sessionmaker(
|
||||
self._engine, expire_on_commit=False, class_=AsyncSession
|
||||
)
|
||||
if (session := self.session) is not None:
|
||||
await session.execute(text("CREATE EXTENSION IF NOT EXISTS pg_trgm;"))
|
||||
|
||||
async with self._engine.begin() as conn:
|
||||
if clear_static_data:
|
||||
await conn.run_sync(Base.metadata.drop_all)
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
async with self._engine.begin() as conn:
|
||||
if clear_static_data:
|
||||
await conn.run_sync(Base.metadata.drop_all)
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
|
||||
return True
|
||||
|
||||
async def disconnect(self) -> None:
|
||||
if self._session is not None:
|
||||
await self._session.close()
|
||||
self._session = None
|
||||
await self._engine.dispose()
|
||||
|
||||
if self._engine is not None:
|
||||
await self._engine.dispose()
|
||||
|
@@ -1,2 +1,68 @@
|
||||
from .idfm_interface import IdfmInterface
|
||||
from .idfm_types import *
|
||||
|
||||
from .idfm_types import (
|
||||
Coordinate,
|
||||
FramedVehicleJourney,
|
||||
IdfmLineState,
|
||||
IdfmOperator,
|
||||
IdfmResponse,
|
||||
IdfmState,
|
||||
LinePicto,
|
||||
LineFields,
|
||||
Line,
|
||||
MonitoredCall,
|
||||
MonitoredVehicleJourney,
|
||||
Point,
|
||||
Siri,
|
||||
ServiceDelivery,
|
||||
Stop,
|
||||
StopArea,
|
||||
StopAreaFields,
|
||||
StopAreaStopAssociation,
|
||||
StopAreaStopAssociationFields,
|
||||
StopAreaType,
|
||||
StopDelivery,
|
||||
StopFields,
|
||||
StopLineAsso,
|
||||
StopLineAssoFields,
|
||||
StopMonitoringDelivery,
|
||||
TrainNumber,
|
||||
TrainStatus,
|
||||
TransportMode,
|
||||
TransportSubMode,
|
||||
Value,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"Coordinate",
|
||||
"FramedVehicleJourney",
|
||||
"IdfmInterface",
|
||||
"IdfmLineState",
|
||||
"IdfmOperator",
|
||||
"IdfmResponse",
|
||||
"IdfmState",
|
||||
"LinePicto",
|
||||
"LineFields",
|
||||
"Line",
|
||||
"MonitoredCall",
|
||||
"MonitoredVehicleJourney",
|
||||
"Point",
|
||||
"Siri",
|
||||
"ServiceDelivery",
|
||||
"Stop",
|
||||
"StopArea",
|
||||
"StopAreaFields",
|
||||
"StopAreaStopAssociation",
|
||||
"StopAreaStopAssociationFields",
|
||||
"StopAreaType",
|
||||
"StopDelivery",
|
||||
"StopFields",
|
||||
"StopLineAsso",
|
||||
"StopLineAssoFields",
|
||||
"StopMonitoringDelivery",
|
||||
"TrainNumber",
|
||||
"TrainStatus",
|
||||
"TransportMode",
|
||||
"TransportSubMode",
|
||||
"Value",
|
||||
]
|
||||
|
@@ -1,7 +1,6 @@
|
||||
from pathlib import Path
|
||||
from re import compile as re_compile
|
||||
from time import time
|
||||
from typing import ByteString, Iterable, List, Optional
|
||||
from typing import AsyncIterator, ByteString, Callable, Iterable, List, Type
|
||||
|
||||
from aiofiles import open as async_open
|
||||
from aiohttp import ClientSession
|
||||
@@ -16,14 +15,14 @@ from .idfm_types import (
|
||||
IdfmLineState,
|
||||
IdfmResponse,
|
||||
Line as IdfmLine,
|
||||
MonitoredVehicleJourney,
|
||||
LinePicto as IdfmPicto,
|
||||
IdfmState,
|
||||
Stop as IdfmStop,
|
||||
StopArea as IdfmStopArea,
|
||||
StopAreaStopAssociation,
|
||||
StopAreaType,
|
||||
StopLineAsso as IdfmStopLineAsso,
|
||||
Stops,
|
||||
TransportMode,
|
||||
)
|
||||
from .ratp_types import Picto as RatpPicto
|
||||
|
||||
@@ -40,7 +39,10 @@ class IdfmInterface:
|
||||
IDFM_PICTO_URL = f"{IDFM_ROOT_URL}/referentiel-des-lignes/files"
|
||||
|
||||
RATP_ROOT_URL = "https://data.ratp.fr/explore/dataset"
|
||||
RATP_PICTO_URL = f"{RATP_ROOT_URL}/pictogrammes-des-lignes-de-metro-rer-tramway-bus-et-noctilien/files"
|
||||
RATP_PICTO_URL = (
|
||||
f"{RATP_ROOT_URL}"
|
||||
"/pictogrammes-des-lignes-de-metro-rer-tramway-bus-et-noctilien/files"
|
||||
)
|
||||
|
||||
OPERATOR_RE = re_compile(r"[^:]+:Operator::([^:]+):")
|
||||
LINE_RE = re_compile(r"[^:]+:Line::([^:]+):")
|
||||
@@ -64,7 +66,7 @@ class IdfmInterface:
|
||||
|
||||
async def startup(self) -> None:
|
||||
BATCH_SIZE = 10000
|
||||
STEPS = (
|
||||
STEPS: tuple[tuple[Type[Stop] | Type[StopArea], Callable, Callable], ...] = (
|
||||
(
|
||||
StopArea,
|
||||
self._request_idfm_stop_areas,
|
||||
@@ -132,13 +134,13 @@ class IdfmInterface:
|
||||
pictos.append(picto)
|
||||
if len(pictos) == batch_size:
|
||||
formatted_pictos = IdfmInterface._format_ratp_pictos(*pictos)
|
||||
await LinePicto.add(formatted_pictos.values())
|
||||
await LinePicto.add(map(lambda picto: picto[1], formatted_pictos))
|
||||
await Line.add_pictos(formatted_pictos)
|
||||
pictos.clear()
|
||||
|
||||
if pictos:
|
||||
formatted_pictos = IdfmInterface._format_ratp_pictos(*pictos)
|
||||
await LinePicto.add(formatted_pictos.values())
|
||||
await LinePicto.add(map(lambda picto: picto[1], formatted_pictos))
|
||||
await Line.add_pictos(formatted_pictos)
|
||||
|
||||
async def _load_lines_stops_assos(self, batch_size: int = 5000) -> None:
|
||||
@@ -174,16 +176,18 @@ class IdfmInterface:
|
||||
assos.append((int(fields.zdaid), int(fields.arrid)))
|
||||
if len(assos) == batch_size:
|
||||
total_assos_nb += batch_size
|
||||
total_found_nb += await StopArea.add_stops(assos)
|
||||
if (found_nb := await StopArea.add_stops(assos)) is not None:
|
||||
total_found_nb += found_nb
|
||||
assos.clear()
|
||||
|
||||
if assos:
|
||||
total_assos_nb += len(assos)
|
||||
total_found_nb += await StopArea.add_stops(assos)
|
||||
if (found_nb := await StopArea.add_stops(assos)) is not None:
|
||||
total_found_nb += found_nb
|
||||
|
||||
print(f"{total_found_nb} stop area <-> stop ({total_assos_nb = } found)")
|
||||
|
||||
async def _request_idfm_stops(self):
|
||||
async def _request_idfm_stops(self) -> AsyncIterator[IdfmStop]:
|
||||
# headers = {"Accept": "application/json", "apikey": self._api_key}
|
||||
# async with ClientSession(headers=headers) as session:
|
||||
# async with session.get(self.STOPS_URL) as response:
|
||||
@@ -196,19 +200,21 @@ class IdfmInterface:
|
||||
for element in self._json_stops_decoder.decode(await raw.read()):
|
||||
yield element
|
||||
|
||||
async def _request_idfm_stop_areas(self):
|
||||
async def _request_idfm_stop_areas(self) -> AsyncIterator[IdfmStopArea]:
|
||||
# TODO: Use HTTP
|
||||
async with async_open("./tests/datasets/zones-d-arrets.json", "rb") as raw:
|
||||
for element in self._json_stop_areas_decoder.decode(await raw.read()):
|
||||
yield element
|
||||
|
||||
async def _request_idfm_lines(self):
|
||||
async def _request_idfm_lines(self) -> AsyncIterator[IdfmLine]:
|
||||
# TODO: Use HTTP
|
||||
async with async_open("./tests/datasets/lines_dataset.json", "rb") as raw:
|
||||
for element in self._json_lines_decoder.decode(await raw.read()):
|
||||
yield element
|
||||
|
||||
async def _request_idfm_stops_lines_associations(self):
|
||||
async def _request_idfm_stops_lines_associations(
|
||||
self,
|
||||
) -> AsyncIterator[IdfmStopLineAsso]:
|
||||
# TODO: Use HTTP
|
||||
async with async_open("./tests/datasets/arrets-lignes.json", "rb") as raw:
|
||||
for element in self._json_stops_lines_assos_decoder.decode(
|
||||
@@ -216,7 +222,9 @@ class IdfmInterface:
|
||||
):
|
||||
yield element
|
||||
|
||||
async def _request_idfm_stop_area_stop_associations(self):
|
||||
async def _request_idfm_stop_area_stop_associations(
|
||||
self,
|
||||
) -> AsyncIterator[StopAreaStopAssociation]:
|
||||
# TODO: Use HTTP
|
||||
async with async_open("./tests/datasets/relations.json", "rb") as raw:
|
||||
for element in self._json_stop_area_stop_asso_decoder.decode(
|
||||
@@ -224,7 +232,7 @@ class IdfmInterface:
|
||||
):
|
||||
yield element
|
||||
|
||||
async def _request_ratp_pictos(self):
|
||||
async def _request_ratp_pictos(self) -> AsyncIterator[RatpPicto]:
|
||||
# TODO: Use HTTP
|
||||
async with async_open(
|
||||
"./tests/datasets/pictogrammes-des-lignes-de-metro-rer-tramway-bus-et-noctilien.json",
|
||||
@@ -254,20 +262,25 @@ class IdfmInterface:
|
||||
return ret
|
||||
|
||||
@classmethod
|
||||
def _format_ratp_pictos(cls, *pictos: RatpPicto) -> dict[str, None | LinePicto]:
|
||||
ret = {}
|
||||
def _format_ratp_pictos(cls, *pictos: RatpPicto) -> Iterable[tuple[str, LinePicto]]:
|
||||
ret = []
|
||||
|
||||
for picto in pictos:
|
||||
if (fields := picto.fields.noms_des_fichiers) is not None:
|
||||
ret[picto.fields.indices_commerciaux] = LinePicto(
|
||||
id=fields.id_,
|
||||
mime_type=f"image/{fields.format.lower()}",
|
||||
height_px=fields.height,
|
||||
width_px=fields.width,
|
||||
filename=fields.filename,
|
||||
url=f"{cls.RATP_PICTO_URL}/{fields.id_}/download",
|
||||
thumbnail=fields.thumbnail,
|
||||
format=fields.format,
|
||||
ret.append(
|
||||
(
|
||||
picto.fields.indices_commerciaux,
|
||||
LinePicto(
|
||||
id=fields.id_,
|
||||
mime_type=f"image/{fields.format.lower()}",
|
||||
height_px=fields.height,
|
||||
width_px=fields.width,
|
||||
filename=fields.filename,
|
||||
url=f"{cls.RATP_PICTO_URL}/{fields.id_}/download",
|
||||
thumbnail=fields.thumbnail,
|
||||
format=fields.format,
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
return ret
|
||||
@@ -289,7 +302,7 @@ class IdfmInterface:
|
||||
short_name=fields.shortname_line,
|
||||
name=fields.name_line,
|
||||
status=IdfmLineState(fields.status.value),
|
||||
transport_mode=fields.transportmode.value,
|
||||
transport_mode=TransportMode(fields.transportmode.value),
|
||||
transport_submode=optional_value(fields.transportsubmode),
|
||||
network_name=optional_value(fields.networkname),
|
||||
group_of_lines_id=optional_value(fields.id_groupoflines),
|
||||
@@ -300,9 +313,13 @@ class IdfmInterface:
|
||||
text_colour_hexa=fields.textcolourprint_hexa,
|
||||
operator_id=optional_value(fields.operatorref),
|
||||
operator_name=optional_value(fields.operatorname),
|
||||
accessibility=fields.accessibility.value,
|
||||
visual_signs_available=fields.visualsigns_available.value,
|
||||
audible_signs_available=fields.audiblesigns_available.value,
|
||||
accessibility=IdfmState(fields.accessibility.value),
|
||||
visual_signs_available=IdfmState(
|
||||
fields.visualsigns_available.value
|
||||
),
|
||||
audible_signs_available=IdfmState(
|
||||
fields.audiblesigns_available.value
|
||||
),
|
||||
picto_id=fields.picto.id_ if fields.picto is not None else None,
|
||||
picto=picto,
|
||||
record_id=line.recordid,
|
||||
@@ -317,7 +334,7 @@ class IdfmInterface:
|
||||
for stop in stops:
|
||||
fields = stop.fields
|
||||
try:
|
||||
created_ts = int(fields.arrcreated.timestamp())
|
||||
created_ts = int(fields.arrcreated.timestamp()) # type: ignore
|
||||
except AttributeError:
|
||||
created_ts = None
|
||||
yield Stop(
|
||||
@@ -329,13 +346,13 @@ class IdfmInterface:
|
||||
postal_region=fields.arrpostalregion,
|
||||
xepsg2154=fields.arrxepsg2154,
|
||||
yepsg2154=fields.arryepsg2154,
|
||||
transport_mode=fields.arrtype.value,
|
||||
transport_mode=TransportMode(fields.arrtype.value),
|
||||
version=fields.arrversion,
|
||||
created_ts=created_ts,
|
||||
changed_ts=int(fields.arrchanged.timestamp()),
|
||||
accessibility=fields.arraccessibility.value,
|
||||
visual_signs_available=fields.arrvisualsigns.value,
|
||||
audible_signs_available=fields.arraudiblesignals.value,
|
||||
accessibility=IdfmState(fields.arraccessibility.value),
|
||||
visual_signs_available=IdfmState(fields.arrvisualsigns.value),
|
||||
audible_signs_available=IdfmState(fields.arraudiblesignals.value),
|
||||
record_id=stop.recordid,
|
||||
record_ts=int(stop.record_timestamp.timestamp()),
|
||||
)
|
||||
@@ -345,7 +362,7 @@ class IdfmInterface:
|
||||
for stop_area in stop_areas:
|
||||
fields = stop_area.fields
|
||||
try:
|
||||
created_ts = int(fields.arrcreated.timestamp())
|
||||
created_ts = int(fields.zdacreated.timestamp()) # type: ignore
|
||||
except AttributeError:
|
||||
created_ts = None
|
||||
yield StopArea(
|
||||
@@ -355,7 +372,7 @@ class IdfmInterface:
|
||||
postal_region=fields.zdapostalregion,
|
||||
xepsg2154=fields.zdaxepsg2154,
|
||||
yepsg2154=fields.zdayepsg2154,
|
||||
type=fields.zdatype.value,
|
||||
type=StopAreaType(fields.zdatype.value),
|
||||
version=fields.zdaversion,
|
||||
created_ts=created_ts,
|
||||
changed_ts=int(fields.zdachanged.timestamp()),
|
||||
@@ -368,22 +385,22 @@ class IdfmInterface:
|
||||
|
||||
picto = line.picto
|
||||
if picto is not None:
|
||||
picto_data = await self._get_line_picto(line)
|
||||
async with async_open(target, "wb") as fd:
|
||||
await fd.write(picto_data)
|
||||
line_picto_path = target
|
||||
line_picto_format = picto.mime_type
|
||||
if (picto_data := await self._get_line_picto(line)) is not None:
|
||||
async with async_open(target, "wb") as fd:
|
||||
await fd.write(bytes(picto_data))
|
||||
line_picto_path = target
|
||||
line_picto_format = picto.mime_type
|
||||
|
||||
print(f"render_line_picto: {time() - begin_ts}")
|
||||
return (line_picto_path, line_picto_format)
|
||||
|
||||
async def _get_line_picto(self, line: Line) -> Optional[ByteString]:
|
||||
async def _get_line_picto(self, line: Line) -> ByteString | None:
|
||||
print("---------------------------------------------------------------------")
|
||||
begin_ts = time()
|
||||
data = None
|
||||
|
||||
picto = line.picto
|
||||
if picto is not None:
|
||||
if picto is not None and picto.url is not None:
|
||||
headers = (
|
||||
self._http_headers if picto.url.startswith(self.IDFM_ROOT_URL) else None
|
||||
)
|
||||
@@ -401,31 +418,18 @@ class IdfmInterface:
|
||||
print("---------------------------------------------------------------------")
|
||||
return data
|
||||
|
||||
async def get_next_passages(self, stop_point_id: str) -> Optional[IdfmResponse]:
|
||||
# print("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
|
||||
begin_ts = time()
|
||||
async def get_next_passages(self, stop_point_id: str) -> IdfmResponse | None:
|
||||
ret = None
|
||||
params = {"MonitoringRef": f"STIF:StopPoint:Q:{stop_point_id}:"}
|
||||
session_begin_ts = time()
|
||||
async with ClientSession(headers=self._http_headers) as session:
|
||||
session_creation_ts = time()
|
||||
# print(f"Session creation {session_creation_ts - session_begin_ts}")
|
||||
async with session.get(self.IDFM_STOP_MON_URL, params=params) as response:
|
||||
get_end_ts = time()
|
||||
# print(f"GET {get_end_ts - session_creation_ts}")
|
||||
if response.status == 200:
|
||||
get_end_ts = time()
|
||||
# print(f"GET {get_end_ts - session_creation_ts}")
|
||||
data = await response.read()
|
||||
# print(data)
|
||||
try:
|
||||
ret = self._response_json_decoder.decode(data)
|
||||
except ValidationError as err:
|
||||
print(err)
|
||||
# print(f"read {time() - get_end_ts}")
|
||||
|
||||
# print(f"get_next_passages: {time() - begin_ts}")
|
||||
# print("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
|
||||
return ret
|
||||
|
||||
async def get_destinations(self, stop_point_id: str) -> Iterable[str]:
|
||||
|
@@ -2,7 +2,7 @@ from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from enum import Enum, StrEnum
|
||||
from typing import Any, Literal, Optional, NamedTuple
|
||||
from typing import Any, NamedTuple
|
||||
|
||||
from msgspec import Struct
|
||||
|
||||
@@ -88,7 +88,7 @@ class Stop(Struct):
|
||||
Stops = dict[str, Stop]
|
||||
|
||||
|
||||
class StopAreaType(Enum):
|
||||
class StopAreaType(StrEnum):
|
||||
metroStation = "metroStation"
|
||||
onstreetBus = "onstreetBus"
|
||||
onstreetTram = "onstreetTram"
|
||||
@@ -101,7 +101,7 @@ class StopAreaFields(Struct, kw_only=True):
|
||||
zdatown: str
|
||||
zdaversion: str
|
||||
zdaid: str
|
||||
zdacreated: Optional[datetime] = None
|
||||
zdacreated: datetime | None = None
|
||||
zdatype: StopAreaType
|
||||
zdayepsg2154: int
|
||||
zdapostalregion: str
|
||||
@@ -118,13 +118,13 @@ class StopArea(Struct):
|
||||
|
||||
class StopAreaStopAssociationFields(Struct, kw_only=True):
|
||||
arrid: str # TODO: use int ?
|
||||
artid: Optional[str] = None
|
||||
artid: str | None = None
|
||||
arrversion: str
|
||||
zdcid: str
|
||||
version: int
|
||||
zdaid: str
|
||||
zdaversion: str
|
||||
artversion: Optional[str] = None
|
||||
artversion: str | None = None
|
||||
|
||||
|
||||
class StopAreaStopAssociation(Struct):
|
||||
@@ -153,20 +153,20 @@ class LineFields(Struct, kw_only=True):
|
||||
name_line: str
|
||||
status: IdfmLineState
|
||||
accessibility: IdfmState
|
||||
shortname_groupoflines: Optional[str] = None
|
||||
shortname_groupoflines: str | None = None
|
||||
transportmode: TransportMode
|
||||
colourweb_hexa: str
|
||||
textcolourprint_hexa: str
|
||||
transportsubmode: Optional[TransportSubMode] = TransportSubMode.unknown
|
||||
operatorref: Optional[str] = None
|
||||
transportsubmode: TransportSubMode | None = TransportSubMode.unknown
|
||||
operatorref: str | None = None
|
||||
visualsigns_available: IdfmState
|
||||
networkname: Optional[str] = None
|
||||
networkname: str | None = None
|
||||
id_line: str
|
||||
id_groupoflines: Optional[str] = None
|
||||
operatorname: Optional[str] = None
|
||||
id_groupoflines: str | None = None
|
||||
operatorname: str | None = None
|
||||
audiblesigns_available: IdfmState
|
||||
shortname_line: str
|
||||
picto: Optional[LinePicto] = None
|
||||
picto: LinePicto | None = None
|
||||
|
||||
|
||||
class Line(Struct):
|
||||
@@ -220,17 +220,17 @@ class TrainNumber(Struct):
|
||||
|
||||
|
||||
class MonitoredCall(Struct, kw_only=True):
|
||||
Order: Optional[int] = None
|
||||
Order: int | None = None
|
||||
StopPointName: list[Value]
|
||||
VehicleAtStop: bool
|
||||
DestinationDisplay: list[Value]
|
||||
AimedArrivalTime: Optional[datetime] = None
|
||||
ExpectedArrivalTime: Optional[datetime] = None
|
||||
ArrivalPlatformName: Optional[Value] = None
|
||||
AimedDepartureTime: Optional[datetime] = None
|
||||
ExpectedDepartureTime: Optional[datetime] = None
|
||||
ArrivalStatus: TrainStatus = None
|
||||
DepartureStatus: TrainStatus = None
|
||||
AimedArrivalTime: datetime | None = None
|
||||
ExpectedArrivalTime: datetime | None = None
|
||||
ArrivalPlatformName: Value | None = None
|
||||
AimedDepartureTime: datetime | None = None
|
||||
ExpectedDepartureTime: datetime | None = None
|
||||
ArrivalStatus: TrainStatus | None = None
|
||||
DepartureStatus: TrainStatus | None = None
|
||||
|
||||
|
||||
class MonitoredVehicleJourney(Struct, kw_only=True):
|
||||
@@ -240,7 +240,7 @@ class MonitoredVehicleJourney(Struct, kw_only=True):
|
||||
DestinationRef: Value
|
||||
DestinationName: list[Value] | None = None
|
||||
JourneyNote: list[Value] | None = None
|
||||
TrainNumbers: Optional[TrainNumber] = None
|
||||
TrainNumbers: TrainNumber | None = None
|
||||
MonitoredCall: MonitoredCall
|
||||
|
||||
|
||||
|
@@ -1,3 +1,6 @@
|
||||
from .line import Line, LinePicto
|
||||
from .stop import Stop, StopArea
|
||||
from .user import UserLastStopSearchResults
|
||||
|
||||
|
||||
__all__ = ["Line", "LinePicto", "Stop", "StopArea", "UserLastStopSearchResults"]
|
||||
|
@@ -1,6 +1,6 @@
|
||||
from asyncio import gather as asyncio_gather
|
||||
from collections import defaultdict
|
||||
from typing import Iterable, Self
|
||||
from typing import Iterable, Self, Sequence
|
||||
|
||||
from sqlalchemy import (
|
||||
BigInteger,
|
||||
@@ -13,8 +13,7 @@ from sqlalchemy import (
|
||||
String,
|
||||
Table,
|
||||
)
|
||||
from sqlalchemy.orm import Mapped, relationship, selectinload
|
||||
from sqlalchemy.orm.attributes import set_committed_value
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship, selectinload
|
||||
from sqlalchemy.sql.expression import tuple_
|
||||
|
||||
from ..db import Base, db
|
||||
@@ -38,14 +37,14 @@ class LinePicto(Base):
|
||||
|
||||
db = db
|
||||
|
||||
id = Column(String, primary_key=True)
|
||||
mime_type = Column(String, nullable=False)
|
||||
height_px = Column(Integer, nullable=False)
|
||||
width_px = Column(Integer, nullable=False)
|
||||
filename = Column(String, nullable=False)
|
||||
url = Column(String, nullable=False)
|
||||
thumbnail = Column(Boolean, nullable=False)
|
||||
format = Column(String, nullable=False)
|
||||
id = mapped_column(String, primary_key=True)
|
||||
mime_type = mapped_column(String, nullable=False)
|
||||
height_px = mapped_column(Integer, nullable=False)
|
||||
width_px = mapped_column(Integer, nullable=False)
|
||||
filename = mapped_column(String, nullable=False)
|
||||
url = mapped_column(String, nullable=False)
|
||||
thumbnail = mapped_column(Boolean, nullable=False)
|
||||
format = mapped_column(String, nullable=False)
|
||||
|
||||
__tablename__ = "line_pictos"
|
||||
|
||||
@@ -54,35 +53,35 @@ class Line(Base):
|
||||
|
||||
db = db
|
||||
|
||||
id = Column(String, primary_key=True)
|
||||
id = mapped_column(String, primary_key=True)
|
||||
|
||||
short_name = Column(String)
|
||||
name = Column(String, nullable=False)
|
||||
status = Column(Enum(IdfmLineState), nullable=False)
|
||||
transport_mode = Column(Enum(TransportMode), nullable=False)
|
||||
transport_submode = Column(Enum(TransportSubMode), nullable=False)
|
||||
short_name = mapped_column(String)
|
||||
name = mapped_column(String, nullable=False)
|
||||
status = mapped_column(Enum(IdfmLineState), nullable=False)
|
||||
transport_mode = mapped_column(Enum(TransportMode), nullable=False)
|
||||
transport_submode = mapped_column(Enum(TransportSubMode), nullable=False)
|
||||
|
||||
network_name = Column(String)
|
||||
group_of_lines_id = Column(String)
|
||||
group_of_lines_shortname = Column(String)
|
||||
network_name = mapped_column(String)
|
||||
group_of_lines_id = mapped_column(String)
|
||||
group_of_lines_shortname = mapped_column(String)
|
||||
|
||||
colour_web_hexa = Column(String, nullable=False)
|
||||
text_colour_hexa = Column(String, nullable=False)
|
||||
colour_web_hexa = mapped_column(String, nullable=False)
|
||||
text_colour_hexa = mapped_column(String, nullable=False)
|
||||
|
||||
operator_id = Column(String)
|
||||
operator_name = Column(String)
|
||||
operator_id = mapped_column(String)
|
||||
operator_name = mapped_column(String)
|
||||
|
||||
accessibility = Column(Enum(IdfmState), nullable=False)
|
||||
visual_signs_available = Column(Enum(IdfmState), nullable=False)
|
||||
audible_signs_available = Column(Enum(IdfmState), nullable=False)
|
||||
accessibility = mapped_column(Enum(IdfmState), nullable=False)
|
||||
visual_signs_available = mapped_column(Enum(IdfmState), nullable=False)
|
||||
audible_signs_available = mapped_column(Enum(IdfmState), nullable=False)
|
||||
|
||||
picto_id = Column(String, ForeignKey("line_pictos.id"))
|
||||
picto_id = mapped_column(String, ForeignKey("line_pictos.id"))
|
||||
picto: Mapped[LinePicto] = relationship(LinePicto, lazy="selectin")
|
||||
|
||||
record_id = Column(String, nullable=False)
|
||||
record_ts = Column(BigInteger, nullable=False)
|
||||
record_id = mapped_column(String, nullable=False)
|
||||
record_ts = mapped_column(BigInteger, nullable=False)
|
||||
|
||||
stops: Mapped[list["_Stop"]] = relationship(
|
||||
stops: Mapped[list[_Stop]] = relationship(
|
||||
"_Stop",
|
||||
secondary=line_stop_association_table,
|
||||
back_populates="lines",
|
||||
@@ -94,67 +93,81 @@ class Line(Base):
|
||||
@classmethod
|
||||
async def get_by_name(
|
||||
cls, name: str, operator_name: None | str = None
|
||||
) -> list[Self]:
|
||||
) -> Sequence[Self] | None:
|
||||
session = cls.db.session
|
||||
if session is None:
|
||||
return None
|
||||
|
||||
filters = {"name": name}
|
||||
if operator_name is not None:
|
||||
filters["operator_name"] = operator_name
|
||||
|
||||
lines = None
|
||||
stmt = (
|
||||
select(Line)
|
||||
select(cls)
|
||||
.filter_by(**filters)
|
||||
.options(selectinload(Line.stops), selectinload(Line.picto))
|
||||
.options(selectinload(cls.stops), selectinload(cls.picto))
|
||||
)
|
||||
res = await cls.db.session.execute(stmt)
|
||||
res = await session.execute(stmt)
|
||||
lines = res.scalars().all()
|
||||
|
||||
return lines
|
||||
|
||||
@classmethod
|
||||
async def _add_picto_to_line(cls, line: str | Self, picto: LinePicto) -> None:
|
||||
formatted_line: Self | None = None
|
||||
if isinstance(line, str):
|
||||
if (lines := await cls.get_by_name(line)) is not None:
|
||||
if len(lines) == 1:
|
||||
line = lines[0]
|
||||
formatted_line = lines[0]
|
||||
else:
|
||||
for candidate_line in lines:
|
||||
if candidate_line.operator_name == "RATP":
|
||||
line = candidate_line
|
||||
formatted_line = candidate_line
|
||||
break
|
||||
else:
|
||||
formatted_line = line
|
||||
|
||||
if isinstance(line, Line) and line.picto is None:
|
||||
line.picto = picto
|
||||
line.picto_id = picto.id
|
||||
if isinstance(formatted_line, Line) and formatted_line.picto is None:
|
||||
formatted_line.picto = picto
|
||||
formatted_line.picto_id = picto.id
|
||||
|
||||
@classmethod
|
||||
async def add_pictos(cls, line_to_pictos: dict[str | Self, LinePicto]) -> None:
|
||||
async def add_pictos(cls, line_to_pictos: Iterable[tuple[str, LinePicto]]) -> bool:
|
||||
session = cls.db.session
|
||||
if session is None:
|
||||
return False
|
||||
|
||||
await asyncio_gather(
|
||||
*[
|
||||
cls._add_picto_to_line(line, picto)
|
||||
for line, picto in line_to_pictos.items()
|
||||
]
|
||||
*[cls._add_picto_to_line(line, picto) for line, picto in line_to_pictos]
|
||||
)
|
||||
|
||||
await cls.db.session.commit()
|
||||
await session.commit()
|
||||
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
async def add_stops(cls, line_to_stop_ids: Iterable[tuple[str, str, str]]) -> int:
|
||||
async def add_stops(cls, line_to_stop_ids: Iterable[tuple[str, str, int]]) -> int:
|
||||
session = cls.db.session
|
||||
if session is None:
|
||||
return 0
|
||||
|
||||
line_names_ops, stop_ids = set(), set()
|
||||
for line_name, operator_name, stop_id in line_to_stop_ids:
|
||||
line_names_ops.add((line_name, operator_name))
|
||||
stop_ids.add(stop_id)
|
||||
|
||||
res = await cls.db.session.execute(
|
||||
lines_res = await session.execute(
|
||||
select(Line).where(
|
||||
tuple_(Line.name, Line.operator_name).in_(line_names_ops)
|
||||
)
|
||||
)
|
||||
|
||||
lines = defaultdict(list)
|
||||
for line in res.scalars():
|
||||
for line in lines_res.scalars():
|
||||
lines[(line.name, line.operator_name)].append(line)
|
||||
|
||||
res = await cls.db.session.execute(select(_Stop).where(_Stop.id.in_(stop_ids)))
|
||||
stops = {stop.id: stop for stop in res.scalars()}
|
||||
stops_res = await session.execute(select(_Stop).where(_Stop.id.in_(stop_ids)))
|
||||
stops = {stop.id: stop for stop in stops_res.scalars()}
|
||||
|
||||
found = 0
|
||||
for line_name, operator_name, stop_id in line_to_stop_ids:
|
||||
@@ -167,8 +180,10 @@ class Line(Base):
|
||||
print(f"No line found for {line_name}/{operator_name}")
|
||||
else:
|
||||
print(
|
||||
f"No stop found for {stop_id} id (used by {line_name}/{operator_name})"
|
||||
f"No stop found for {stop_id} id"
|
||||
f"(used by {line_name}/{operator_name})"
|
||||
)
|
||||
|
||||
await cls.db.session.commit()
|
||||
await session.commit()
|
||||
|
||||
return found
|
||||
|
@@ -1,4 +1,6 @@
|
||||
from typing import Iterable, Self
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Iterable, Self, Sequence, TYPE_CHECKING
|
||||
|
||||
from sqlalchemy import (
|
||||
BigInteger,
|
||||
@@ -10,12 +12,22 @@ from sqlalchemy import (
|
||||
String,
|
||||
Table,
|
||||
)
|
||||
from sqlalchemy.orm import Mapped, relationship, selectinload, with_polymorphic
|
||||
from sqlalchemy.orm import (
|
||||
mapped_column,
|
||||
Mapped,
|
||||
relationship,
|
||||
selectinload,
|
||||
with_polymorphic,
|
||||
)
|
||||
from sqlalchemy.schema import Index
|
||||
|
||||
from ..db import Base, db
|
||||
from ..idfm_interface.idfm_types import TransportMode, IdfmState, StopAreaType
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .line import Line
|
||||
|
||||
|
||||
stop_area_stop_association_table = Table(
|
||||
"stop_area_stop_association_table",
|
||||
Base.metadata,
|
||||
@@ -28,18 +40,18 @@ class _Stop(Base):
|
||||
|
||||
db = db
|
||||
|
||||
id = Column(BigInteger, primary_key=True)
|
||||
kind = Column(String)
|
||||
id = mapped_column(BigInteger, primary_key=True)
|
||||
kind = mapped_column(String)
|
||||
|
||||
name = Column(String, nullable=False, index=True)
|
||||
town_name = Column(String, nullable=False)
|
||||
postal_region = Column(String, nullable=False)
|
||||
xepsg2154 = Column(BigInteger, nullable=False)
|
||||
yepsg2154 = Column(BigInteger, nullable=False)
|
||||
version = Column(String, nullable=False)
|
||||
created_ts = Column(BigInteger)
|
||||
changed_ts = Column(BigInteger, nullable=False)
|
||||
lines: Mapped[list["Line"]] = relationship(
|
||||
name = mapped_column(String, nullable=False, index=True)
|
||||
town_name = mapped_column(String, nullable=False)
|
||||
postal_region = mapped_column(String, nullable=False)
|
||||
xepsg2154 = mapped_column(BigInteger, nullable=False)
|
||||
yepsg2154 = mapped_column(BigInteger, nullable=False)
|
||||
version = mapped_column(String, nullable=False)
|
||||
created_ts = mapped_column(BigInteger)
|
||||
changed_ts = mapped_column(BigInteger, nullable=False)
|
||||
lines: Mapped[list[Line]] = relationship(
|
||||
"Line",
|
||||
secondary="line_stop_association_table",
|
||||
back_populates="stops",
|
||||
@@ -65,7 +77,11 @@ class _Stop(Base):
|
||||
# TODO: Test https://www.cybertec-postgresql.com/en/postgresql-more-performance-for-like-and-ilike-statements/
|
||||
# TODO: Should be able to remove with_polymorphic ?
|
||||
@classmethod
|
||||
async def get_by_name(cls, name: str) -> list[Self]:
|
||||
async def get_by_name(cls, name: str) -> Sequence[type[_Stop]] | None:
|
||||
session = cls.db.session
|
||||
if session is None:
|
||||
return None
|
||||
|
||||
stop_stop_area = with_polymorphic(_Stop, [Stop, StopArea])
|
||||
stmt = (
|
||||
select(stop_stop_area)
|
||||
@@ -75,22 +91,25 @@ class _Stop(Base):
|
||||
selectinload(stop_stop_area.lines),
|
||||
)
|
||||
)
|
||||
res = await cls.db.session.execute(stmt)
|
||||
return res.scalars()
|
||||
|
||||
res = await session.execute(stmt)
|
||||
stops = res.scalars().all()
|
||||
|
||||
return stops
|
||||
|
||||
|
||||
class Stop(_Stop):
|
||||
|
||||
id = Column(BigInteger, ForeignKey("_stops.id"), primary_key=True)
|
||||
id = mapped_column(BigInteger, ForeignKey("_stops.id"), primary_key=True)
|
||||
|
||||
latitude = Column(Float, nullable=False)
|
||||
longitude = Column(Float, nullable=False)
|
||||
transport_mode = Column(Enum(TransportMode), nullable=False)
|
||||
accessibility = Column(Enum(IdfmState), nullable=False)
|
||||
visual_signs_available = Column(Enum(IdfmState), nullable=False)
|
||||
audible_signs_available = Column(Enum(IdfmState), nullable=False)
|
||||
record_id = Column(String, nullable=False)
|
||||
record_ts = Column(BigInteger, nullable=False)
|
||||
latitude = mapped_column(Float, nullable=False)
|
||||
longitude = mapped_column(Float, nullable=False)
|
||||
transport_mode = mapped_column(Enum(TransportMode), nullable=False)
|
||||
accessibility = mapped_column(Enum(IdfmState), nullable=False)
|
||||
visual_signs_available = mapped_column(Enum(IdfmState), nullable=False)
|
||||
audible_signs_available = mapped_column(Enum(IdfmState), nullable=False)
|
||||
record_id = mapped_column(String, nullable=False)
|
||||
record_ts = mapped_column(BigInteger, nullable=False)
|
||||
|
||||
__tablename__ = "stops"
|
||||
__mapper_args__ = {"polymorphic_identity": "stops", "polymorphic_load": "inline"}
|
||||
@@ -98,11 +117,11 @@ class Stop(_Stop):
|
||||
|
||||
class StopArea(_Stop):
|
||||
|
||||
id = Column(BigInteger, ForeignKey("_stops.id"), primary_key=True)
|
||||
id = mapped_column(BigInteger, ForeignKey("_stops.id"), primary_key=True)
|
||||
|
||||
type = Column(Enum(StopAreaType), nullable=False)
|
||||
stops: Mapped[list[_Stop]] = relationship(
|
||||
_Stop,
|
||||
type = mapped_column(Enum(StopAreaType), nullable=False)
|
||||
stops: Mapped[list["_Stop"]] = relationship(
|
||||
"_Stop",
|
||||
secondary=stop_area_stop_association_table,
|
||||
back_populates="areas",
|
||||
lazy="selectin",
|
||||
@@ -110,24 +129,35 @@ class StopArea(_Stop):
|
||||
)
|
||||
|
||||
__tablename__ = "stop_areas"
|
||||
__mapper_args__ = {"polymorphic_identity": "stop_areas", "polymorphic_load": "inline"}
|
||||
__mapper_args__ = {
|
||||
"polymorphic_identity": "stop_areas",
|
||||
"polymorphic_load": "inline",
|
||||
}
|
||||
|
||||
@classmethod
|
||||
async def add_stops(cls, stop_area_to_stop_ids: Iterable[tuple[str, str]]) -> int:
|
||||
async def add_stops(
|
||||
cls, stop_area_to_stop_ids: Iterable[tuple[int, int]]
|
||||
) -> int | None:
|
||||
session = cls.db.session
|
||||
if session is None:
|
||||
return None
|
||||
|
||||
stop_area_ids, stop_ids = set(), set()
|
||||
for stop_area_id, stop_id in stop_area_to_stop_ids:
|
||||
stop_area_ids.add(stop_area_id)
|
||||
stop_ids.add(stop_id)
|
||||
|
||||
res = await cls.db.session.execute(
|
||||
stop_areas_res = await session.execute(
|
||||
select(StopArea)
|
||||
.where(StopArea.id.in_(stop_area_ids))
|
||||
.options(selectinload(StopArea.stops))
|
||||
)
|
||||
stop_areas = {stop_area.id: stop_area for stop_area in res.scalars()}
|
||||
stop_areas: dict[int, StopArea] = {
|
||||
stop_area.id: stop_area for stop_area in stop_areas_res.scalars()
|
||||
}
|
||||
|
||||
res = await cls.db.session.execute(select(_Stop).where(_Stop.id.in_(stop_ids)))
|
||||
stops = {stop.id: stop for stop in res.scalars()}
|
||||
stop_res = await session.execute(select(_Stop).where(_Stop.id.in_(stop_ids)))
|
||||
stops: dict[int, _Stop] = {stop.id: stop for stop in stop_res.scalars()}
|
||||
|
||||
found = 0
|
||||
for stop_area_id, stop_id in stop_area_to_stop_ids:
|
||||
@@ -140,5 +170,6 @@ class StopArea(_Stop):
|
||||
else:
|
||||
print(f"No stop area found for {stop_area_id}")
|
||||
|
||||
await cls.db.session.commit()
|
||||
await session.commit()
|
||||
|
||||
return found
|
||||
|
@@ -1,5 +1,5 @@
|
||||
from sqlalchemy import Column, ForeignKey, String, Table
|
||||
from sqlalchemy.orm import Mapped, relationship
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from ..db import Base, db
|
||||
from .stop import _Stop
|
||||
@@ -18,8 +18,8 @@ class UserLastStopSearchResults(Base):
|
||||
|
||||
__tablename__ = "user_last_stop_search_results"
|
||||
|
||||
user_mxid = Column(String, primary_key=True)
|
||||
request_content = Column(String, nullable=False)
|
||||
stops: Mapped[list[_Stop]] = relationship(
|
||||
user_mxid = mapped_column(String, primary_key=True)
|
||||
request_content = mapped_column(String, nullable=False)
|
||||
stops: Mapped[_Stop] = relationship(
|
||||
_Stop, secondary=user_last_stop_search_stops_associations_table
|
||||
)
|
||||
|
0
backend/backend/py.typed
Normal file
0
backend/backend/py.typed
Normal file
@@ -1,3 +1,5 @@
|
||||
from .line import Line, TransportMode
|
||||
from .next_passage import NextPassage, NextPassages
|
||||
from .stop import Stop, StopArea
|
||||
|
||||
__all__ = ["Line", "NextPassage", "NextPassages", "Stop", "StopArea", "TransportMode"]
|
||||
|
@@ -1,5 +1,4 @@
|
||||
from enum import StrEnum
|
||||
from typing import Self
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
@@ -29,10 +28,11 @@ class TransportMode(StrEnum):
|
||||
# idfm_types.TransportMode.rail + idfm_types.TransportSubMode.railShuttle
|
||||
val = "val"
|
||||
|
||||
# Self return type replaced by "TransportMode" to fix following mypy error:
|
||||
# Incompatible return value type (got "TransportMode", expected "Self")
|
||||
# TODO: Is it the good fix ?
|
||||
@classmethod
|
||||
def from_idfm_transport_mode(
|
||||
cls, mode: IdfmTransportMode, sub_mode: IdfmTransportSubMode
|
||||
) -> Self:
|
||||
def from_idfm_transport_mode(cls, mode: str, sub_mode: str) -> "TransportMode":
|
||||
if mode == IdfmTransportMode.rail:
|
||||
if sub_mode == IdfmTransportSubMode.regionalRail:
|
||||
return cls.rail_ter
|
||||
@@ -42,7 +42,7 @@ class TransportMode(StrEnum):
|
||||
return cls.rail_transilien
|
||||
if sub_mode == IdfmTransportSubMode.railShuttle:
|
||||
return cls.val
|
||||
return TransportMode(mode)
|
||||
return cls(mode)
|
||||
|
||||
|
||||
class Line(BaseModel):
|
||||
|
@@ -8,11 +8,11 @@ class NextPassage(BaseModel):
|
||||
operator: str
|
||||
destinations: list[str]
|
||||
atStop: bool
|
||||
aimedArrivalTs: None | int
|
||||
expectedArrivalTs: None | int
|
||||
arrivalPlatformName: None | str
|
||||
aimedDepartTs: None | int
|
||||
expectedDepartTs: None | int
|
||||
aimedArrivalTs: int | None
|
||||
expectedArrivalTs: int | None
|
||||
arrivalPlatformName: str | None
|
||||
aimedDepartTs: int | None
|
||||
expectedDepartTs: int | None
|
||||
arrivalStatus: TrainStatus
|
||||
departStatus: TrainStatus
|
||||
|
||||
|
@@ -1,6 +1,6 @@
|
||||
from pydantic import BaseModel
|
||||
|
||||
from ..idfm_interface import IdfmLineState, IdfmState, StopAreaType, TransportMode
|
||||
from ..idfm_interface import StopAreaType
|
||||
|
||||
|
||||
class Stop(BaseModel):
|
||||
|
@@ -1,10 +1,10 @@
|
||||
from collections import defaultdict
|
||||
from datetime import datetime
|
||||
from os import environ
|
||||
from os import environ, EX_USAGE
|
||||
from typing import Sequence
|
||||
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import JSONResponse
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
from rich import print
|
||||
|
||||
@@ -21,7 +21,9 @@ from backend.schemas import (
|
||||
)
|
||||
|
||||
API_KEY = environ.get("API_KEY")
|
||||
# TODO: Add error message if no key is given.
|
||||
if API_KEY is None:
|
||||
print('No "API_KEY" environment variable set... abort.')
|
||||
exit(EX_USAGE)
|
||||
|
||||
# TODO: Remove postgresql+asyncpg from environ variable
|
||||
DB_PATH = "postgresql+asyncpg://cer_user:cer_password@127.0.0.1:5438/cer_db"
|
||||
@@ -44,9 +46,9 @@ idfm_interface = IdfmInterface(API_KEY, db)
|
||||
|
||||
@app.on_event("startup")
|
||||
async def startup():
|
||||
# await db.connect(DB_PATH, clear_static_data=True)
|
||||
# await idfm_interface.startup()
|
||||
await db.connect(DB_PATH, clear_static_data=False)
|
||||
await db.connect(DB_PATH, clear_static_data=True)
|
||||
await idfm_interface.startup()
|
||||
# await db.connect(DB_PATH, clear_static_data=False)
|
||||
print("Connected")
|
||||
|
||||
|
||||
@@ -61,12 +63,12 @@ STATIC_ROOT = "../frontend/"
|
||||
app.mount("/widget", StaticFiles(directory=STATIC_ROOT, html=True), name="widget")
|
||||
|
||||
|
||||
def optional_datetime_to_ts(dt: datetime) -> int | None:
|
||||
return dt.timestamp() if dt else None
|
||||
def optional_datetime_to_ts(dt: datetime | None) -> int | None:
|
||||
return int(dt.timestamp()) if dt else None
|
||||
|
||||
|
||||
@app.get("/line/{line_id}", response_model=LineSchema)
|
||||
async def get_line(line_id: str) -> JSONResponse:
|
||||
async def get_line(line_id: str) -> LineSchema:
|
||||
line: Line | None = await Line.get_by_id(line_id)
|
||||
|
||||
if line is None:
|
||||
@@ -91,7 +93,7 @@ async def get_line(line_id: str) -> JSONResponse:
|
||||
|
||||
|
||||
def _format_stop(stop: Stop) -> StopSchema:
|
||||
print(stop.__dict__)
|
||||
# print(stop.__dict__)
|
||||
return StopSchema(
|
||||
id=stop.id,
|
||||
name=stop.name,
|
||||
@@ -103,15 +105,17 @@ def _format_stop(stop: Stop) -> StopSchema:
|
||||
lines=[line.id for line in stop.lines],
|
||||
)
|
||||
|
||||
|
||||
# châtelet
|
||||
|
||||
|
||||
@app.get("/stop/")
|
||||
async def get_stop(
|
||||
name: str = "", limit: int = 10
|
||||
) -> list[StopAreaSchema | StopSchema]:
|
||||
) -> Sequence[StopAreaSchema | StopSchema]:
|
||||
# TODO: Add limit support
|
||||
|
||||
formatted = []
|
||||
formatted: list[StopAreaSchema | StopSchema] = []
|
||||
matching_stops = await Stop.get_by_name(name)
|
||||
# print(matching_stops, flush=True)
|
||||
|
||||
@@ -153,15 +157,17 @@ async def get_stop(
|
||||
|
||||
# TODO: Cache response for 30 secs ?
|
||||
@app.get("/stop/nextPassages/{stop_id}")
|
||||
async def get_next_passages(stop_id: str) -> JSONResponse:
|
||||
async def get_next_passages(stop_id: str) -> NextPassagesSchema | None:
|
||||
res = await idfm_interface.get_next_passages(stop_id)
|
||||
|
||||
# print(res)
|
||||
if res is None:
|
||||
return None
|
||||
|
||||
service_delivery = res.Siri.ServiceDelivery
|
||||
stop_monitoring_deliveries = service_delivery.StopMonitoringDelivery
|
||||
|
||||
by_line_by_dst_passages = defaultdict(lambda: defaultdict(list))
|
||||
by_line_by_dst_passages: dict[
|
||||
str, dict[str, list[NextPassageSchema]]
|
||||
] = defaultdict(lambda: defaultdict(list))
|
||||
|
||||
for delivery in stop_monitoring_deliveries:
|
||||
for stop_visit in delivery.MonitoredStopVisit:
|
||||
@@ -190,7 +196,9 @@ async def get_next_passages(stop_id: str) -> JSONResponse:
|
||||
atStop=call.VehicleAtStop,
|
||||
aimedArrivalTs=optional_datetime_to_ts(call.AimedArrivalTime),
|
||||
expectedArrivalTs=optional_datetime_to_ts(call.ExpectedArrivalTime),
|
||||
arrivalPlatformName=call.ArrivalPlatformName.value if call.ArrivalPlatformName else None,
|
||||
arrivalPlatformName=call.ArrivalPlatformName.value
|
||||
if call.ArrivalPlatformName
|
||||
else None,
|
||||
aimedDepartTs=optional_datetime_to_ts(call.AimedDepartureTime),
|
||||
expectedDepartTs=optional_datetime_to_ts(call.ExpectedDepartureTime),
|
||||
arrivalStatus=call.ArrivalStatus.value,
|
||||
|
@@ -11,7 +11,7 @@ python = "^3.11"
|
||||
aiohttp = "^3.8.3"
|
||||
rich = "^12.6.0"
|
||||
aiofiles = "^22.1.0"
|
||||
sqlalchemy = {extras = ["asyncio"], version = "^1.4.46"}
|
||||
sqlalchemy = {extras = ["asyncio"], version = "^2.0.1"}
|
||||
fastapi = "^0.88.0"
|
||||
uvicorn = "^0.20.0"
|
||||
asyncpg = "^0.27.0"
|
||||
@@ -28,7 +28,6 @@ rope = "^1.3.0"
|
||||
python-lsp-black = "^1.2.1"
|
||||
black = "^22.10.0"
|
||||
types-aiofiles = "^22.1.0.2"
|
||||
sqlalchemy-stubs = "^0.4"
|
||||
wrapt = "^1.14.1"
|
||||
pydocstyle = "^6.2.2"
|
||||
dill = "^0.3.6"
|
||||
@@ -38,9 +37,16 @@ autopep8 = "^2.0.1"
|
||||
pyflakes = "^3.0.1"
|
||||
yapf = "^0.32.0"
|
||||
whatthepatch = "^1.0.4"
|
||||
sqlalchemy = {extras = ["mypy"], version = "^2.0.1"}
|
||||
mypy = "^1.0.0"
|
||||
|
||||
[tool.mypy]
|
||||
plugins = "sqlalchemy.ext.mypy.plugin"
|
||||
|
||||
[tool.black]
|
||||
target-version = ['py311']
|
||||
|
||||
[tool.ruff]
|
||||
line-length = 88
|
||||
[too.ruff.per-file-ignores]
|
||||
"__init__.py" = ["E401"]
|
||||
|
Reference in New Issue
Block a user