From e34355e8beaf8da1fb9d41af4ba449491d2b3a44 Mon Sep 17 00:00:00 2001 From: Adrien Date: Wed, 8 Feb 2023 22:10:21 +0100 Subject: [PATCH] =?UTF-8?q?=F0=9F=8F=B7=EF=B8=8F=20Make=20python=20linters?= =?UTF-8?q?=20happy?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- backend/backend/__init__.py | 0 backend/backend/db/__init__.py | 2 + backend/backend/db/base_class.py | 55 ++++---- backend/backend/db/db.py | 88 ++++--------- backend/backend/idfm_interface/__init__.py | 68 +++++++++- .../backend/idfm_interface/idfm_interface.py | 122 ++++++++--------- backend/backend/idfm_interface/idfm_types.py | 42 +++--- backend/backend/models/__init__.py | 3 + backend/backend/models/line.py | 123 ++++++++++-------- backend/backend/models/stop.py | 103 ++++++++++----- backend/backend/models/user.py | 8 +- backend/backend/py.typed | 0 backend/backend/schemas/__init__.py | 2 + backend/backend/schemas/line.py | 10 +- backend/backend/schemas/next_passage.py | 10 +- backend/backend/schemas/stop.py | 2 +- backend/main.py | 42 +++--- backend/pyproject.toml | 10 +- 18 files changed, 400 insertions(+), 290 deletions(-) create mode 100644 backend/backend/__init__.py create mode 100644 backend/backend/py.typed diff --git a/backend/backend/__init__.py b/backend/backend/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/backend/backend/db/__init__.py b/backend/backend/db/__init__.py index 8d9bf84..eb5ec72 100644 --- a/backend/backend/db/__init__.py +++ b/backend/backend/db/__init__.py @@ -1,4 +1,6 @@ from .db import Database from .base_class import Base +__all__ = ["Base"] + db = Database() diff --git a/backend/backend/db/base_class.py b/backend/backend/db/base_class.py index 2334a75..4025d82 100644 --- a/backend/backend/db/base_class.py +++ b/backend/backend/db/base_class.py @@ -1,34 +1,39 @@ -from collections.abc import Iterable +from __future__ import annotations + +from typing import Iterable, Self, TYPE_CHECKING from sqlalchemy import select from sqlalchemy.exc import IntegrityError -from sqlalchemy.orm import declarative_base -from typing import Iterable, Self +from sqlalchemy.orm import DeclarativeBase -Base = declarative_base() -Base.db = None +if TYPE_CHECKING: + from .db import Database -async def base_add(cls, stops: Self | Iterable[Self]) -> bool: - try: - method = ( - cls.db.session.add_all - if isinstance(stops, Iterable) - else cls.db.session.add - ) - method(stops) - await cls.db.session.commit() - except IntegrityError as err: - print(err) +class Base(DeclarativeBase): + db: Database | None = None + @classmethod + async def add(cls, stops: Self | Iterable[Self]) -> bool: + try: + if isinstance(stops, Iterable): + cls.db.session.add_all(stops) # type: ignore + else: + cls.db.session.add(stops) # type: ignore + await cls.db.session.commit() # type: ignore + except (AttributeError, IntegrityError) as err: + print(err) + return False -Base.add = classmethod(base_add) + return True - -async def base_get_by_id(cls, id_: int | str) -> None | Base: - res = await cls.db.session.execute(select(cls).where(cls.id == id_)) - element = res.scalar_one_or_none() - return element - - -Base.get_by_id = classmethod(base_get_by_id) + @classmethod + async def get_by_id(cls, id_: int | str) -> Self | None: + try: + stmt = select(cls).where(cls.id == id_) # type: ignore + res = await cls.db.session.execute(stmt) # type: ignore + element = res.scalar_one_or_none() + except AttributeError as err: + print(err) + element = None + return element diff --git a/backend/backend/db/db.py b/backend/backend/db/db.py index e050258..e1f8772 100644 --- a/backend/backend/db/db.py +++ b/backend/backend/db/db.py @@ -1,80 +1,48 @@ -from asyncio import gather as asyncio_gather -from functools import wraps -from pathlib import Path -from time import time -from typing import Callable, Iterable, Optional - -from rich import print -from sqlalchemy import event, select, tuple_ -from sqlalchemy.engine import Engine -from sqlalchemy.exc import IntegrityError -from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine -from sqlalchemy.orm import ( - selectinload, - sessionmaker, - with_polymorphic, +from sqlalchemy import text +from sqlalchemy.ext.asyncio import ( + async_sessionmaker, + AsyncEngine, + AsyncSession, + create_async_engine, ) -from sqlalchemy.orm.attributes import set_committed_value from .base_class import Base -# import logging - -# logging.basicConfig() -# logger = logging.getLogger("bot.sqltime") -# logger.setLevel(logging.DEBUG) - - -# @event.listens_for(Engine, "before_cursor_execute") -# def before_cursor_execute(conn, cursor, statement, parameters, context, executemany): -# conn.info.setdefault("query_start_time", []).append(time()) -# logger.debug("Start Query: %s", statement) - - -# @event.listens_for(Engine, "after_cursor_execute") -# def after_cursor_execute(conn, cursor, statement, parameters, context, executemany): -# total = time() - conn.info["query_start_time"].pop(-1) -# logger.debug("Query Complete!") -# logger.debug("Total Time: %f", total) - - class Database: def __init__(self) -> None: - self._engine = None - self._session_maker = None - self._session = None + self._engine: AsyncEngine | None = None + self._session_maker: async_sessionmaker[AsyncSession] | None = None + self._session: AsyncSession | None = None @property - def session(self) -> None: - if self._session is None: - self._session = self._session_maker() + def session(self) -> AsyncSession | None: + if self._session is None and (session_maker := self._session_maker) is not None: + self._session = session_maker() return self._session - def use_session(func: Callable): - @wraps(func) - async def wrapper(self, *args, **kwargs): - if self._check_session() is not None: - return await func(self, *args, **kwargs) - # TODO: Raise an exception ? + async def connect(self, db_path: str, clear_static_data: bool = False) -> bool: - return wrapper - - async def connect(self, db_path: str, clear_static_data: bool = False) -> None: # TODO: Preserve UserLastStopSearchResults table from drop. self._engine = create_async_engine(db_path) - self._session_maker = sessionmaker( - self._engine, expire_on_commit=False, class_=AsyncSession - ) - await self.session.execute("CREATE EXTENSION IF NOT EXISTS pg_trgm;") + if self._engine is not None: + self._session_maker = async_sessionmaker( + self._engine, expire_on_commit=False, class_=AsyncSession + ) + if (session := self.session) is not None: + await session.execute(text("CREATE EXTENSION IF NOT EXISTS pg_trgm;")) - async with self._engine.begin() as conn: - if clear_static_data: - await conn.run_sync(Base.metadata.drop_all) - await conn.run_sync(Base.metadata.create_all) + async with self._engine.begin() as conn: + if clear_static_data: + await conn.run_sync(Base.metadata.drop_all) + await conn.run_sync(Base.metadata.create_all) + + return True async def disconnect(self) -> None: if self._session is not None: await self._session.close() self._session = None - await self._engine.dispose() + + if self._engine is not None: + await self._engine.dispose() diff --git a/backend/backend/idfm_interface/__init__.py b/backend/backend/idfm_interface/__init__.py index c97a846..070099a 100644 --- a/backend/backend/idfm_interface/__init__.py +++ b/backend/backend/idfm_interface/__init__.py @@ -1,2 +1,68 @@ from .idfm_interface import IdfmInterface -from .idfm_types import * + +from .idfm_types import ( + Coordinate, + FramedVehicleJourney, + IdfmLineState, + IdfmOperator, + IdfmResponse, + IdfmState, + LinePicto, + LineFields, + Line, + MonitoredCall, + MonitoredVehicleJourney, + Point, + Siri, + ServiceDelivery, + Stop, + StopArea, + StopAreaFields, + StopAreaStopAssociation, + StopAreaStopAssociationFields, + StopAreaType, + StopDelivery, + StopFields, + StopLineAsso, + StopLineAssoFields, + StopMonitoringDelivery, + TrainNumber, + TrainStatus, + TransportMode, + TransportSubMode, + Value, +) + +__all__ = [ + "Coordinate", + "FramedVehicleJourney", + "IdfmInterface", + "IdfmLineState", + "IdfmOperator", + "IdfmResponse", + "IdfmState", + "LinePicto", + "LineFields", + "Line", + "MonitoredCall", + "MonitoredVehicleJourney", + "Point", + "Siri", + "ServiceDelivery", + "Stop", + "StopArea", + "StopAreaFields", + "StopAreaStopAssociation", + "StopAreaStopAssociationFields", + "StopAreaType", + "StopDelivery", + "StopFields", + "StopLineAsso", + "StopLineAssoFields", + "StopMonitoringDelivery", + "TrainNumber", + "TrainStatus", + "TransportMode", + "TransportSubMode", + "Value", +] diff --git a/backend/backend/idfm_interface/idfm_interface.py b/backend/backend/idfm_interface/idfm_interface.py index 5744be9..faf2769 100644 --- a/backend/backend/idfm_interface/idfm_interface.py +++ b/backend/backend/idfm_interface/idfm_interface.py @@ -1,7 +1,6 @@ -from pathlib import Path from re import compile as re_compile from time import time -from typing import ByteString, Iterable, List, Optional +from typing import AsyncIterator, ByteString, Callable, Iterable, List, Type from aiofiles import open as async_open from aiohttp import ClientSession @@ -16,14 +15,14 @@ from .idfm_types import ( IdfmLineState, IdfmResponse, Line as IdfmLine, - MonitoredVehicleJourney, LinePicto as IdfmPicto, IdfmState, Stop as IdfmStop, StopArea as IdfmStopArea, StopAreaStopAssociation, + StopAreaType, StopLineAsso as IdfmStopLineAsso, - Stops, + TransportMode, ) from .ratp_types import Picto as RatpPicto @@ -40,7 +39,10 @@ class IdfmInterface: IDFM_PICTO_URL = f"{IDFM_ROOT_URL}/referentiel-des-lignes/files" RATP_ROOT_URL = "https://data.ratp.fr/explore/dataset" - RATP_PICTO_URL = f"{RATP_ROOT_URL}/pictogrammes-des-lignes-de-metro-rer-tramway-bus-et-noctilien/files" + RATP_PICTO_URL = ( + f"{RATP_ROOT_URL}" + "/pictogrammes-des-lignes-de-metro-rer-tramway-bus-et-noctilien/files" + ) OPERATOR_RE = re_compile(r"[^:]+:Operator::([^:]+):") LINE_RE = re_compile(r"[^:]+:Line::([^:]+):") @@ -64,7 +66,7 @@ class IdfmInterface: async def startup(self) -> None: BATCH_SIZE = 10000 - STEPS = ( + STEPS: tuple[tuple[Type[Stop] | Type[StopArea], Callable, Callable], ...] = ( ( StopArea, self._request_idfm_stop_areas, @@ -132,13 +134,13 @@ class IdfmInterface: pictos.append(picto) if len(pictos) == batch_size: formatted_pictos = IdfmInterface._format_ratp_pictos(*pictos) - await LinePicto.add(formatted_pictos.values()) + await LinePicto.add(map(lambda picto: picto[1], formatted_pictos)) await Line.add_pictos(formatted_pictos) pictos.clear() if pictos: formatted_pictos = IdfmInterface._format_ratp_pictos(*pictos) - await LinePicto.add(formatted_pictos.values()) + await LinePicto.add(map(lambda picto: picto[1], formatted_pictos)) await Line.add_pictos(formatted_pictos) async def _load_lines_stops_assos(self, batch_size: int = 5000) -> None: @@ -174,16 +176,18 @@ class IdfmInterface: assos.append((int(fields.zdaid), int(fields.arrid))) if len(assos) == batch_size: total_assos_nb += batch_size - total_found_nb += await StopArea.add_stops(assos) + if (found_nb := await StopArea.add_stops(assos)) is not None: + total_found_nb += found_nb assos.clear() if assos: total_assos_nb += len(assos) - total_found_nb += await StopArea.add_stops(assos) + if (found_nb := await StopArea.add_stops(assos)) is not None: + total_found_nb += found_nb print(f"{total_found_nb} stop area <-> stop ({total_assos_nb = } found)") - async def _request_idfm_stops(self): + async def _request_idfm_stops(self) -> AsyncIterator[IdfmStop]: # headers = {"Accept": "application/json", "apikey": self._api_key} # async with ClientSession(headers=headers) as session: # async with session.get(self.STOPS_URL) as response: @@ -196,19 +200,21 @@ class IdfmInterface: for element in self._json_stops_decoder.decode(await raw.read()): yield element - async def _request_idfm_stop_areas(self): + async def _request_idfm_stop_areas(self) -> AsyncIterator[IdfmStopArea]: # TODO: Use HTTP async with async_open("./tests/datasets/zones-d-arrets.json", "rb") as raw: for element in self._json_stop_areas_decoder.decode(await raw.read()): yield element - async def _request_idfm_lines(self): + async def _request_idfm_lines(self) -> AsyncIterator[IdfmLine]: # TODO: Use HTTP async with async_open("./tests/datasets/lines_dataset.json", "rb") as raw: for element in self._json_lines_decoder.decode(await raw.read()): yield element - async def _request_idfm_stops_lines_associations(self): + async def _request_idfm_stops_lines_associations( + self, + ) -> AsyncIterator[IdfmStopLineAsso]: # TODO: Use HTTP async with async_open("./tests/datasets/arrets-lignes.json", "rb") as raw: for element in self._json_stops_lines_assos_decoder.decode( @@ -216,7 +222,9 @@ class IdfmInterface: ): yield element - async def _request_idfm_stop_area_stop_associations(self): + async def _request_idfm_stop_area_stop_associations( + self, + ) -> AsyncIterator[StopAreaStopAssociation]: # TODO: Use HTTP async with async_open("./tests/datasets/relations.json", "rb") as raw: for element in self._json_stop_area_stop_asso_decoder.decode( @@ -224,7 +232,7 @@ class IdfmInterface: ): yield element - async def _request_ratp_pictos(self): + async def _request_ratp_pictos(self) -> AsyncIterator[RatpPicto]: # TODO: Use HTTP async with async_open( "./tests/datasets/pictogrammes-des-lignes-de-metro-rer-tramway-bus-et-noctilien.json", @@ -254,20 +262,25 @@ class IdfmInterface: return ret @classmethod - def _format_ratp_pictos(cls, *pictos: RatpPicto) -> dict[str, None | LinePicto]: - ret = {} + def _format_ratp_pictos(cls, *pictos: RatpPicto) -> Iterable[tuple[str, LinePicto]]: + ret = [] for picto in pictos: if (fields := picto.fields.noms_des_fichiers) is not None: - ret[picto.fields.indices_commerciaux] = LinePicto( - id=fields.id_, - mime_type=f"image/{fields.format.lower()}", - height_px=fields.height, - width_px=fields.width, - filename=fields.filename, - url=f"{cls.RATP_PICTO_URL}/{fields.id_}/download", - thumbnail=fields.thumbnail, - format=fields.format, + ret.append( + ( + picto.fields.indices_commerciaux, + LinePicto( + id=fields.id_, + mime_type=f"image/{fields.format.lower()}", + height_px=fields.height, + width_px=fields.width, + filename=fields.filename, + url=f"{cls.RATP_PICTO_URL}/{fields.id_}/download", + thumbnail=fields.thumbnail, + format=fields.format, + ), + ) ) return ret @@ -289,7 +302,7 @@ class IdfmInterface: short_name=fields.shortname_line, name=fields.name_line, status=IdfmLineState(fields.status.value), - transport_mode=fields.transportmode.value, + transport_mode=TransportMode(fields.transportmode.value), transport_submode=optional_value(fields.transportsubmode), network_name=optional_value(fields.networkname), group_of_lines_id=optional_value(fields.id_groupoflines), @@ -300,9 +313,13 @@ class IdfmInterface: text_colour_hexa=fields.textcolourprint_hexa, operator_id=optional_value(fields.operatorref), operator_name=optional_value(fields.operatorname), - accessibility=fields.accessibility.value, - visual_signs_available=fields.visualsigns_available.value, - audible_signs_available=fields.audiblesigns_available.value, + accessibility=IdfmState(fields.accessibility.value), + visual_signs_available=IdfmState( + fields.visualsigns_available.value + ), + audible_signs_available=IdfmState( + fields.audiblesigns_available.value + ), picto_id=fields.picto.id_ if fields.picto is not None else None, picto=picto, record_id=line.recordid, @@ -317,7 +334,7 @@ class IdfmInterface: for stop in stops: fields = stop.fields try: - created_ts = int(fields.arrcreated.timestamp()) + created_ts = int(fields.arrcreated.timestamp()) # type: ignore except AttributeError: created_ts = None yield Stop( @@ -329,13 +346,13 @@ class IdfmInterface: postal_region=fields.arrpostalregion, xepsg2154=fields.arrxepsg2154, yepsg2154=fields.arryepsg2154, - transport_mode=fields.arrtype.value, + transport_mode=TransportMode(fields.arrtype.value), version=fields.arrversion, created_ts=created_ts, changed_ts=int(fields.arrchanged.timestamp()), - accessibility=fields.arraccessibility.value, - visual_signs_available=fields.arrvisualsigns.value, - audible_signs_available=fields.arraudiblesignals.value, + accessibility=IdfmState(fields.arraccessibility.value), + visual_signs_available=IdfmState(fields.arrvisualsigns.value), + audible_signs_available=IdfmState(fields.arraudiblesignals.value), record_id=stop.recordid, record_ts=int(stop.record_timestamp.timestamp()), ) @@ -345,7 +362,7 @@ class IdfmInterface: for stop_area in stop_areas: fields = stop_area.fields try: - created_ts = int(fields.arrcreated.timestamp()) + created_ts = int(fields.zdacreated.timestamp()) # type: ignore except AttributeError: created_ts = None yield StopArea( @@ -355,7 +372,7 @@ class IdfmInterface: postal_region=fields.zdapostalregion, xepsg2154=fields.zdaxepsg2154, yepsg2154=fields.zdayepsg2154, - type=fields.zdatype.value, + type=StopAreaType(fields.zdatype.value), version=fields.zdaversion, created_ts=created_ts, changed_ts=int(fields.zdachanged.timestamp()), @@ -368,22 +385,22 @@ class IdfmInterface: picto = line.picto if picto is not None: - picto_data = await self._get_line_picto(line) - async with async_open(target, "wb") as fd: - await fd.write(picto_data) - line_picto_path = target - line_picto_format = picto.mime_type + if (picto_data := await self._get_line_picto(line)) is not None: + async with async_open(target, "wb") as fd: + await fd.write(bytes(picto_data)) + line_picto_path = target + line_picto_format = picto.mime_type print(f"render_line_picto: {time() - begin_ts}") return (line_picto_path, line_picto_format) - async def _get_line_picto(self, line: Line) -> Optional[ByteString]: + async def _get_line_picto(self, line: Line) -> ByteString | None: print("---------------------------------------------------------------------") begin_ts = time() data = None picto = line.picto - if picto is not None: + if picto is not None and picto.url is not None: headers = ( self._http_headers if picto.url.startswith(self.IDFM_ROOT_URL) else None ) @@ -401,31 +418,18 @@ class IdfmInterface: print("---------------------------------------------------------------------") return data - async def get_next_passages(self, stop_point_id: str) -> Optional[IdfmResponse]: - # print("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!") - begin_ts = time() + async def get_next_passages(self, stop_point_id: str) -> IdfmResponse | None: ret = None params = {"MonitoringRef": f"STIF:StopPoint:Q:{stop_point_id}:"} - session_begin_ts = time() async with ClientSession(headers=self._http_headers) as session: - session_creation_ts = time() - # print(f"Session creation {session_creation_ts - session_begin_ts}") async with session.get(self.IDFM_STOP_MON_URL, params=params) as response: - get_end_ts = time() - # print(f"GET {get_end_ts - session_creation_ts}") if response.status == 200: - get_end_ts = time() - # print(f"GET {get_end_ts - session_creation_ts}") data = await response.read() - # print(data) try: ret = self._response_json_decoder.decode(data) except ValidationError as err: print(err) - # print(f"read {time() - get_end_ts}") - # print(f"get_next_passages: {time() - begin_ts}") - # print("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!") return ret async def get_destinations(self, stop_point_id: str) -> Iterable[str]: diff --git a/backend/backend/idfm_interface/idfm_types.py b/backend/backend/idfm_interface/idfm_types.py index b703e1f..82b9cbb 100644 --- a/backend/backend/idfm_interface/idfm_types.py +++ b/backend/backend/idfm_interface/idfm_types.py @@ -2,7 +2,7 @@ from __future__ import annotations from datetime import datetime from enum import Enum, StrEnum -from typing import Any, Literal, Optional, NamedTuple +from typing import Any, NamedTuple from msgspec import Struct @@ -88,7 +88,7 @@ class Stop(Struct): Stops = dict[str, Stop] -class StopAreaType(Enum): +class StopAreaType(StrEnum): metroStation = "metroStation" onstreetBus = "onstreetBus" onstreetTram = "onstreetTram" @@ -101,7 +101,7 @@ class StopAreaFields(Struct, kw_only=True): zdatown: str zdaversion: str zdaid: str - zdacreated: Optional[datetime] = None + zdacreated: datetime | None = None zdatype: StopAreaType zdayepsg2154: int zdapostalregion: str @@ -118,13 +118,13 @@ class StopArea(Struct): class StopAreaStopAssociationFields(Struct, kw_only=True): arrid: str # TODO: use int ? - artid: Optional[str] = None + artid: str | None = None arrversion: str zdcid: str version: int zdaid: str zdaversion: str - artversion: Optional[str] = None + artversion: str | None = None class StopAreaStopAssociation(Struct): @@ -153,20 +153,20 @@ class LineFields(Struct, kw_only=True): name_line: str status: IdfmLineState accessibility: IdfmState - shortname_groupoflines: Optional[str] = None + shortname_groupoflines: str | None = None transportmode: TransportMode colourweb_hexa: str textcolourprint_hexa: str - transportsubmode: Optional[TransportSubMode] = TransportSubMode.unknown - operatorref: Optional[str] = None + transportsubmode: TransportSubMode | None = TransportSubMode.unknown + operatorref: str | None = None visualsigns_available: IdfmState - networkname: Optional[str] = None + networkname: str | None = None id_line: str - id_groupoflines: Optional[str] = None - operatorname: Optional[str] = None + id_groupoflines: str | None = None + operatorname: str | None = None audiblesigns_available: IdfmState shortname_line: str - picto: Optional[LinePicto] = None + picto: LinePicto | None = None class Line(Struct): @@ -220,17 +220,17 @@ class TrainNumber(Struct): class MonitoredCall(Struct, kw_only=True): - Order: Optional[int] = None + Order: int | None = None StopPointName: list[Value] VehicleAtStop: bool DestinationDisplay: list[Value] - AimedArrivalTime: Optional[datetime] = None - ExpectedArrivalTime: Optional[datetime] = None - ArrivalPlatformName: Optional[Value] = None - AimedDepartureTime: Optional[datetime] = None - ExpectedDepartureTime: Optional[datetime] = None - ArrivalStatus: TrainStatus = None - DepartureStatus: TrainStatus = None + AimedArrivalTime: datetime | None = None + ExpectedArrivalTime: datetime | None = None + ArrivalPlatformName: Value | None = None + AimedDepartureTime: datetime | None = None + ExpectedDepartureTime: datetime | None = None + ArrivalStatus: TrainStatus | None = None + DepartureStatus: TrainStatus | None = None class MonitoredVehicleJourney(Struct, kw_only=True): @@ -240,7 +240,7 @@ class MonitoredVehicleJourney(Struct, kw_only=True): DestinationRef: Value DestinationName: list[Value] | None = None JourneyNote: list[Value] | None = None - TrainNumbers: Optional[TrainNumber] = None + TrainNumbers: TrainNumber | None = None MonitoredCall: MonitoredCall diff --git a/backend/backend/models/__init__.py b/backend/backend/models/__init__.py index c6060f7..ef1a352 100644 --- a/backend/backend/models/__init__.py +++ b/backend/backend/models/__init__.py @@ -1,3 +1,6 @@ from .line import Line, LinePicto from .stop import Stop, StopArea from .user import UserLastStopSearchResults + + +__all__ = ["Line", "LinePicto", "Stop", "StopArea", "UserLastStopSearchResults"] diff --git a/backend/backend/models/line.py b/backend/backend/models/line.py index 7527188..d3108cd 100644 --- a/backend/backend/models/line.py +++ b/backend/backend/models/line.py @@ -1,6 +1,6 @@ from asyncio import gather as asyncio_gather from collections import defaultdict -from typing import Iterable, Self +from typing import Iterable, Self, Sequence from sqlalchemy import ( BigInteger, @@ -13,8 +13,7 @@ from sqlalchemy import ( String, Table, ) -from sqlalchemy.orm import Mapped, relationship, selectinload -from sqlalchemy.orm.attributes import set_committed_value +from sqlalchemy.orm import Mapped, mapped_column, relationship, selectinload from sqlalchemy.sql.expression import tuple_ from ..db import Base, db @@ -38,14 +37,14 @@ class LinePicto(Base): db = db - id = Column(String, primary_key=True) - mime_type = Column(String, nullable=False) - height_px = Column(Integer, nullable=False) - width_px = Column(Integer, nullable=False) - filename = Column(String, nullable=False) - url = Column(String, nullable=False) - thumbnail = Column(Boolean, nullable=False) - format = Column(String, nullable=False) + id = mapped_column(String, primary_key=True) + mime_type = mapped_column(String, nullable=False) + height_px = mapped_column(Integer, nullable=False) + width_px = mapped_column(Integer, nullable=False) + filename = mapped_column(String, nullable=False) + url = mapped_column(String, nullable=False) + thumbnail = mapped_column(Boolean, nullable=False) + format = mapped_column(String, nullable=False) __tablename__ = "line_pictos" @@ -54,35 +53,35 @@ class Line(Base): db = db - id = Column(String, primary_key=True) + id = mapped_column(String, primary_key=True) - short_name = Column(String) - name = Column(String, nullable=False) - status = Column(Enum(IdfmLineState), nullable=False) - transport_mode = Column(Enum(TransportMode), nullable=False) - transport_submode = Column(Enum(TransportSubMode), nullable=False) + short_name = mapped_column(String) + name = mapped_column(String, nullable=False) + status = mapped_column(Enum(IdfmLineState), nullable=False) + transport_mode = mapped_column(Enum(TransportMode), nullable=False) + transport_submode = mapped_column(Enum(TransportSubMode), nullable=False) - network_name = Column(String) - group_of_lines_id = Column(String) - group_of_lines_shortname = Column(String) + network_name = mapped_column(String) + group_of_lines_id = mapped_column(String) + group_of_lines_shortname = mapped_column(String) - colour_web_hexa = Column(String, nullable=False) - text_colour_hexa = Column(String, nullable=False) + colour_web_hexa = mapped_column(String, nullable=False) + text_colour_hexa = mapped_column(String, nullable=False) - operator_id = Column(String) - operator_name = Column(String) + operator_id = mapped_column(String) + operator_name = mapped_column(String) - accessibility = Column(Enum(IdfmState), nullable=False) - visual_signs_available = Column(Enum(IdfmState), nullable=False) - audible_signs_available = Column(Enum(IdfmState), nullable=False) + accessibility = mapped_column(Enum(IdfmState), nullable=False) + visual_signs_available = mapped_column(Enum(IdfmState), nullable=False) + audible_signs_available = mapped_column(Enum(IdfmState), nullable=False) - picto_id = Column(String, ForeignKey("line_pictos.id")) + picto_id = mapped_column(String, ForeignKey("line_pictos.id")) picto: Mapped[LinePicto] = relationship(LinePicto, lazy="selectin") - record_id = Column(String, nullable=False) - record_ts = Column(BigInteger, nullable=False) + record_id = mapped_column(String, nullable=False) + record_ts = mapped_column(BigInteger, nullable=False) - stops: Mapped[list["_Stop"]] = relationship( + stops: Mapped[list[_Stop]] = relationship( "_Stop", secondary=line_stop_association_table, back_populates="lines", @@ -94,67 +93,81 @@ class Line(Base): @classmethod async def get_by_name( cls, name: str, operator_name: None | str = None - ) -> list[Self]: + ) -> Sequence[Self] | None: + session = cls.db.session + if session is None: + return None + filters = {"name": name} if operator_name is not None: filters["operator_name"] = operator_name - lines = None stmt = ( - select(Line) + select(cls) .filter_by(**filters) - .options(selectinload(Line.stops), selectinload(Line.picto)) + .options(selectinload(cls.stops), selectinload(cls.picto)) ) - res = await cls.db.session.execute(stmt) + res = await session.execute(stmt) lines = res.scalars().all() + return lines @classmethod async def _add_picto_to_line(cls, line: str | Self, picto: LinePicto) -> None: + formatted_line: Self | None = None if isinstance(line, str): if (lines := await cls.get_by_name(line)) is not None: if len(lines) == 1: - line = lines[0] + formatted_line = lines[0] else: for candidate_line in lines: if candidate_line.operator_name == "RATP": - line = candidate_line + formatted_line = candidate_line break + else: + formatted_line = line - if isinstance(line, Line) and line.picto is None: - line.picto = picto - line.picto_id = picto.id + if isinstance(formatted_line, Line) and formatted_line.picto is None: + formatted_line.picto = picto + formatted_line.picto_id = picto.id @classmethod - async def add_pictos(cls, line_to_pictos: dict[str | Self, LinePicto]) -> None: + async def add_pictos(cls, line_to_pictos: Iterable[tuple[str, LinePicto]]) -> bool: + session = cls.db.session + if session is None: + return False + await asyncio_gather( - *[ - cls._add_picto_to_line(line, picto) - for line, picto in line_to_pictos.items() - ] + *[cls._add_picto_to_line(line, picto) for line, picto in line_to_pictos] ) - await cls.db.session.commit() + await session.commit() + + return True @classmethod - async def add_stops(cls, line_to_stop_ids: Iterable[tuple[str, str, str]]) -> int: + async def add_stops(cls, line_to_stop_ids: Iterable[tuple[str, str, int]]) -> int: + session = cls.db.session + if session is None: + return 0 + line_names_ops, stop_ids = set(), set() for line_name, operator_name, stop_id in line_to_stop_ids: line_names_ops.add((line_name, operator_name)) stop_ids.add(stop_id) - res = await cls.db.session.execute( + lines_res = await session.execute( select(Line).where( tuple_(Line.name, Line.operator_name).in_(line_names_ops) ) ) lines = defaultdict(list) - for line in res.scalars(): + for line in lines_res.scalars(): lines[(line.name, line.operator_name)].append(line) - res = await cls.db.session.execute(select(_Stop).where(_Stop.id.in_(stop_ids))) - stops = {stop.id: stop for stop in res.scalars()} + stops_res = await session.execute(select(_Stop).where(_Stop.id.in_(stop_ids))) + stops = {stop.id: stop for stop in stops_res.scalars()} found = 0 for line_name, operator_name, stop_id in line_to_stop_ids: @@ -167,8 +180,10 @@ class Line(Base): print(f"No line found for {line_name}/{operator_name}") else: print( - f"No stop found for {stop_id} id (used by {line_name}/{operator_name})" + f"No stop found for {stop_id} id" + f"(used by {line_name}/{operator_name})" ) - await cls.db.session.commit() + await session.commit() + return found diff --git a/backend/backend/models/stop.py b/backend/backend/models/stop.py index 6bb9543..bf7daa1 100644 --- a/backend/backend/models/stop.py +++ b/backend/backend/models/stop.py @@ -1,4 +1,6 @@ -from typing import Iterable, Self +from __future__ import annotations + +from typing import Iterable, Self, Sequence, TYPE_CHECKING from sqlalchemy import ( BigInteger, @@ -10,12 +12,22 @@ from sqlalchemy import ( String, Table, ) -from sqlalchemy.orm import Mapped, relationship, selectinload, with_polymorphic +from sqlalchemy.orm import ( + mapped_column, + Mapped, + relationship, + selectinload, + with_polymorphic, +) from sqlalchemy.schema import Index from ..db import Base, db from ..idfm_interface.idfm_types import TransportMode, IdfmState, StopAreaType +if TYPE_CHECKING: + from .line import Line + + stop_area_stop_association_table = Table( "stop_area_stop_association_table", Base.metadata, @@ -28,18 +40,18 @@ class _Stop(Base): db = db - id = Column(BigInteger, primary_key=True) - kind = Column(String) + id = mapped_column(BigInteger, primary_key=True) + kind = mapped_column(String) - name = Column(String, nullable=False, index=True) - town_name = Column(String, nullable=False) - postal_region = Column(String, nullable=False) - xepsg2154 = Column(BigInteger, nullable=False) - yepsg2154 = Column(BigInteger, nullable=False) - version = Column(String, nullable=False) - created_ts = Column(BigInteger) - changed_ts = Column(BigInteger, nullable=False) - lines: Mapped[list["Line"]] = relationship( + name = mapped_column(String, nullable=False, index=True) + town_name = mapped_column(String, nullable=False) + postal_region = mapped_column(String, nullable=False) + xepsg2154 = mapped_column(BigInteger, nullable=False) + yepsg2154 = mapped_column(BigInteger, nullable=False) + version = mapped_column(String, nullable=False) + created_ts = mapped_column(BigInteger) + changed_ts = mapped_column(BigInteger, nullable=False) + lines: Mapped[list[Line]] = relationship( "Line", secondary="line_stop_association_table", back_populates="stops", @@ -65,7 +77,11 @@ class _Stop(Base): # TODO: Test https://www.cybertec-postgresql.com/en/postgresql-more-performance-for-like-and-ilike-statements/ # TODO: Should be able to remove with_polymorphic ? @classmethod - async def get_by_name(cls, name: str) -> list[Self]: + async def get_by_name(cls, name: str) -> Sequence[type[_Stop]] | None: + session = cls.db.session + if session is None: + return None + stop_stop_area = with_polymorphic(_Stop, [Stop, StopArea]) stmt = ( select(stop_stop_area) @@ -75,22 +91,25 @@ class _Stop(Base): selectinload(stop_stop_area.lines), ) ) - res = await cls.db.session.execute(stmt) - return res.scalars() + + res = await session.execute(stmt) + stops = res.scalars().all() + + return stops class Stop(_Stop): - id = Column(BigInteger, ForeignKey("_stops.id"), primary_key=True) + id = mapped_column(BigInteger, ForeignKey("_stops.id"), primary_key=True) - latitude = Column(Float, nullable=False) - longitude = Column(Float, nullable=False) - transport_mode = Column(Enum(TransportMode), nullable=False) - accessibility = Column(Enum(IdfmState), nullable=False) - visual_signs_available = Column(Enum(IdfmState), nullable=False) - audible_signs_available = Column(Enum(IdfmState), nullable=False) - record_id = Column(String, nullable=False) - record_ts = Column(BigInteger, nullable=False) + latitude = mapped_column(Float, nullable=False) + longitude = mapped_column(Float, nullable=False) + transport_mode = mapped_column(Enum(TransportMode), nullable=False) + accessibility = mapped_column(Enum(IdfmState), nullable=False) + visual_signs_available = mapped_column(Enum(IdfmState), nullable=False) + audible_signs_available = mapped_column(Enum(IdfmState), nullable=False) + record_id = mapped_column(String, nullable=False) + record_ts = mapped_column(BigInteger, nullable=False) __tablename__ = "stops" __mapper_args__ = {"polymorphic_identity": "stops", "polymorphic_load": "inline"} @@ -98,11 +117,11 @@ class Stop(_Stop): class StopArea(_Stop): - id = Column(BigInteger, ForeignKey("_stops.id"), primary_key=True) + id = mapped_column(BigInteger, ForeignKey("_stops.id"), primary_key=True) - type = Column(Enum(StopAreaType), nullable=False) - stops: Mapped[list[_Stop]] = relationship( - _Stop, + type = mapped_column(Enum(StopAreaType), nullable=False) + stops: Mapped[list["_Stop"]] = relationship( + "_Stop", secondary=stop_area_stop_association_table, back_populates="areas", lazy="selectin", @@ -110,24 +129,35 @@ class StopArea(_Stop): ) __tablename__ = "stop_areas" - __mapper_args__ = {"polymorphic_identity": "stop_areas", "polymorphic_load": "inline"} + __mapper_args__ = { + "polymorphic_identity": "stop_areas", + "polymorphic_load": "inline", + } @classmethod - async def add_stops(cls, stop_area_to_stop_ids: Iterable[tuple[str, str]]) -> int: + async def add_stops( + cls, stop_area_to_stop_ids: Iterable[tuple[int, int]] + ) -> int | None: + session = cls.db.session + if session is None: + return None + stop_area_ids, stop_ids = set(), set() for stop_area_id, stop_id in stop_area_to_stop_ids: stop_area_ids.add(stop_area_id) stop_ids.add(stop_id) - res = await cls.db.session.execute( + stop_areas_res = await session.execute( select(StopArea) .where(StopArea.id.in_(stop_area_ids)) .options(selectinload(StopArea.stops)) ) - stop_areas = {stop_area.id: stop_area for stop_area in res.scalars()} + stop_areas: dict[int, StopArea] = { + stop_area.id: stop_area for stop_area in stop_areas_res.scalars() + } - res = await cls.db.session.execute(select(_Stop).where(_Stop.id.in_(stop_ids))) - stops = {stop.id: stop for stop in res.scalars()} + stop_res = await session.execute(select(_Stop).where(_Stop.id.in_(stop_ids))) + stops: dict[int, _Stop] = {stop.id: stop for stop in stop_res.scalars()} found = 0 for stop_area_id, stop_id in stop_area_to_stop_ids: @@ -140,5 +170,6 @@ class StopArea(_Stop): else: print(f"No stop area found for {stop_area_id}") - await cls.db.session.commit() + await session.commit() + return found diff --git a/backend/backend/models/user.py b/backend/backend/models/user.py index 4705d8e..5dedb8f 100644 --- a/backend/backend/models/user.py +++ b/backend/backend/models/user.py @@ -1,5 +1,5 @@ from sqlalchemy import Column, ForeignKey, String, Table -from sqlalchemy.orm import Mapped, relationship +from sqlalchemy.orm import Mapped, mapped_column, relationship from ..db import Base, db from .stop import _Stop @@ -18,8 +18,8 @@ class UserLastStopSearchResults(Base): __tablename__ = "user_last_stop_search_results" - user_mxid = Column(String, primary_key=True) - request_content = Column(String, nullable=False) - stops: Mapped[list[_Stop]] = relationship( + user_mxid = mapped_column(String, primary_key=True) + request_content = mapped_column(String, nullable=False) + stops: Mapped[_Stop] = relationship( _Stop, secondary=user_last_stop_search_stops_associations_table ) diff --git a/backend/backend/py.typed b/backend/backend/py.typed new file mode 100644 index 0000000..e69de29 diff --git a/backend/backend/schemas/__init__.py b/backend/backend/schemas/__init__.py index 232658b..010c439 100644 --- a/backend/backend/schemas/__init__.py +++ b/backend/backend/schemas/__init__.py @@ -1,3 +1,5 @@ from .line import Line, TransportMode from .next_passage import NextPassage, NextPassages from .stop import Stop, StopArea + +__all__ = ["Line", "NextPassage", "NextPassages", "Stop", "StopArea", "TransportMode"] diff --git a/backend/backend/schemas/line.py b/backend/backend/schemas/line.py index 74f6369..8a65841 100644 --- a/backend/backend/schemas/line.py +++ b/backend/backend/schemas/line.py @@ -1,5 +1,4 @@ from enum import StrEnum -from typing import Self from pydantic import BaseModel @@ -29,10 +28,11 @@ class TransportMode(StrEnum): # idfm_types.TransportMode.rail + idfm_types.TransportSubMode.railShuttle val = "val" + # Self return type replaced by "TransportMode" to fix following mypy error: + # Incompatible return value type (got "TransportMode", expected "Self") + # TODO: Is it the good fix ? @classmethod - def from_idfm_transport_mode( - cls, mode: IdfmTransportMode, sub_mode: IdfmTransportSubMode - ) -> Self: + def from_idfm_transport_mode(cls, mode: str, sub_mode: str) -> "TransportMode": if mode == IdfmTransportMode.rail: if sub_mode == IdfmTransportSubMode.regionalRail: return cls.rail_ter @@ -42,7 +42,7 @@ class TransportMode(StrEnum): return cls.rail_transilien if sub_mode == IdfmTransportSubMode.railShuttle: return cls.val - return TransportMode(mode) + return cls(mode) class Line(BaseModel): diff --git a/backend/backend/schemas/next_passage.py b/backend/backend/schemas/next_passage.py index 4bf32a9..68d0298 100644 --- a/backend/backend/schemas/next_passage.py +++ b/backend/backend/schemas/next_passage.py @@ -8,11 +8,11 @@ class NextPassage(BaseModel): operator: str destinations: list[str] atStop: bool - aimedArrivalTs: None | int - expectedArrivalTs: None | int - arrivalPlatformName: None | str - aimedDepartTs: None | int - expectedDepartTs: None | int + aimedArrivalTs: int | None + expectedArrivalTs: int | None + arrivalPlatformName: str | None + aimedDepartTs: int | None + expectedDepartTs: int | None arrivalStatus: TrainStatus departStatus: TrainStatus diff --git a/backend/backend/schemas/stop.py b/backend/backend/schemas/stop.py index f239223..f4e9277 100644 --- a/backend/backend/schemas/stop.py +++ b/backend/backend/schemas/stop.py @@ -1,6 +1,6 @@ from pydantic import BaseModel -from ..idfm_interface import IdfmLineState, IdfmState, StopAreaType, TransportMode +from ..idfm_interface import StopAreaType class Stop(BaseModel): diff --git a/backend/main.py b/backend/main.py index d5a6f14..fa9b01e 100644 --- a/backend/main.py +++ b/backend/main.py @@ -1,10 +1,10 @@ from collections import defaultdict from datetime import datetime -from os import environ +from os import environ, EX_USAGE +from typing import Sequence from fastapi import FastAPI, HTTPException from fastapi.middleware.cors import CORSMiddleware -from fastapi.responses import JSONResponse from fastapi.staticfiles import StaticFiles from rich import print @@ -21,7 +21,9 @@ from backend.schemas import ( ) API_KEY = environ.get("API_KEY") -# TODO: Add error message if no key is given. +if API_KEY is None: + print('No "API_KEY" environment variable set... abort.') + exit(EX_USAGE) # TODO: Remove postgresql+asyncpg from environ variable DB_PATH = "postgresql+asyncpg://cer_user:cer_password@127.0.0.1:5438/cer_db" @@ -44,9 +46,9 @@ idfm_interface = IdfmInterface(API_KEY, db) @app.on_event("startup") async def startup(): - # await db.connect(DB_PATH, clear_static_data=True) - # await idfm_interface.startup() - await db.connect(DB_PATH, clear_static_data=False) + await db.connect(DB_PATH, clear_static_data=True) + await idfm_interface.startup() + # await db.connect(DB_PATH, clear_static_data=False) print("Connected") @@ -61,12 +63,12 @@ STATIC_ROOT = "../frontend/" app.mount("/widget", StaticFiles(directory=STATIC_ROOT, html=True), name="widget") -def optional_datetime_to_ts(dt: datetime) -> int | None: - return dt.timestamp() if dt else None +def optional_datetime_to_ts(dt: datetime | None) -> int | None: + return int(dt.timestamp()) if dt else None @app.get("/line/{line_id}", response_model=LineSchema) -async def get_line(line_id: str) -> JSONResponse: +async def get_line(line_id: str) -> LineSchema: line: Line | None = await Line.get_by_id(line_id) if line is None: @@ -91,7 +93,7 @@ async def get_line(line_id: str) -> JSONResponse: def _format_stop(stop: Stop) -> StopSchema: - print(stop.__dict__) + # print(stop.__dict__) return StopSchema( id=stop.id, name=stop.name, @@ -103,15 +105,17 @@ def _format_stop(stop: Stop) -> StopSchema: lines=[line.id for line in stop.lines], ) + # châtelet + @app.get("/stop/") async def get_stop( name: str = "", limit: int = 10 -) -> list[StopAreaSchema | StopSchema]: +) -> Sequence[StopAreaSchema | StopSchema]: # TODO: Add limit support - formatted = [] + formatted: list[StopAreaSchema | StopSchema] = [] matching_stops = await Stop.get_by_name(name) # print(matching_stops, flush=True) @@ -153,15 +157,17 @@ async def get_stop( # TODO: Cache response for 30 secs ? @app.get("/stop/nextPassages/{stop_id}") -async def get_next_passages(stop_id: str) -> JSONResponse: +async def get_next_passages(stop_id: str) -> NextPassagesSchema | None: res = await idfm_interface.get_next_passages(stop_id) - - # print(res) + if res is None: + return None service_delivery = res.Siri.ServiceDelivery stop_monitoring_deliveries = service_delivery.StopMonitoringDelivery - by_line_by_dst_passages = defaultdict(lambda: defaultdict(list)) + by_line_by_dst_passages: dict[ + str, dict[str, list[NextPassageSchema]] + ] = defaultdict(lambda: defaultdict(list)) for delivery in stop_monitoring_deliveries: for stop_visit in delivery.MonitoredStopVisit: @@ -190,7 +196,9 @@ async def get_next_passages(stop_id: str) -> JSONResponse: atStop=call.VehicleAtStop, aimedArrivalTs=optional_datetime_to_ts(call.AimedArrivalTime), expectedArrivalTs=optional_datetime_to_ts(call.ExpectedArrivalTime), - arrivalPlatformName=call.ArrivalPlatformName.value if call.ArrivalPlatformName else None, + arrivalPlatformName=call.ArrivalPlatformName.value + if call.ArrivalPlatformName + else None, aimedDepartTs=optional_datetime_to_ts(call.AimedDepartureTime), expectedDepartTs=optional_datetime_to_ts(call.ExpectedDepartureTime), arrivalStatus=call.ArrivalStatus.value, diff --git a/backend/pyproject.toml b/backend/pyproject.toml index 074ede0..fbd10f6 100644 --- a/backend/pyproject.toml +++ b/backend/pyproject.toml @@ -11,7 +11,7 @@ python = "^3.11" aiohttp = "^3.8.3" rich = "^12.6.0" aiofiles = "^22.1.0" -sqlalchemy = {extras = ["asyncio"], version = "^1.4.46"} +sqlalchemy = {extras = ["asyncio"], version = "^2.0.1"} fastapi = "^0.88.0" uvicorn = "^0.20.0" asyncpg = "^0.27.0" @@ -28,7 +28,6 @@ rope = "^1.3.0" python-lsp-black = "^1.2.1" black = "^22.10.0" types-aiofiles = "^22.1.0.2" -sqlalchemy-stubs = "^0.4" wrapt = "^1.14.1" pydocstyle = "^6.2.2" dill = "^0.3.6" @@ -38,9 +37,16 @@ autopep8 = "^2.0.1" pyflakes = "^3.0.1" yapf = "^0.32.0" whatthepatch = "^1.0.4" +sqlalchemy = {extras = ["mypy"], version = "^2.0.1"} +mypy = "^1.0.0" + +[tool.mypy] +plugins = "sqlalchemy.ext.mypy.plugin" [tool.black] target-version = ['py311'] [tool.ruff] line-length = 88 +[too.ruff.per-file-ignores] +"__init__.py" = ["E401"]