🏷️ Make python linters happy

This commit is contained in:
2023-02-08 22:10:21 +01:00
parent d1db97554c
commit e34355e8be
18 changed files with 400 additions and 290 deletions

View File

View File

@@ -1,4 +1,6 @@
from .db import Database
from .base_class import Base
__all__ = ["Base"]
db = Database()

View File

@@ -1,34 +1,39 @@
from collections.abc import Iterable
from __future__ import annotations
from typing import Iterable, Self, TYPE_CHECKING
from sqlalchemy import select
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import declarative_base
from typing import Iterable, Self
from sqlalchemy.orm import DeclarativeBase
Base = declarative_base()
Base.db = None
if TYPE_CHECKING:
from .db import Database
async def base_add(cls, stops: Self | Iterable[Self]) -> bool:
class Base(DeclarativeBase):
db: Database | None = None
@classmethod
async def add(cls, stops: Self | Iterable[Self]) -> bool:
try:
method = (
cls.db.session.add_all
if isinstance(stops, Iterable)
else cls.db.session.add
)
method(stops)
await cls.db.session.commit()
except IntegrityError as err:
if isinstance(stops, Iterable):
cls.db.session.add_all(stops) # type: ignore
else:
cls.db.session.add(stops) # type: ignore
await cls.db.session.commit() # type: ignore
except (AttributeError, IntegrityError) as err:
print(err)
return False
return True
Base.add = classmethod(base_add)
async def base_get_by_id(cls, id_: int | str) -> None | Base:
res = await cls.db.session.execute(select(cls).where(cls.id == id_))
@classmethod
async def get_by_id(cls, id_: int | str) -> Self | None:
try:
stmt = select(cls).where(cls.id == id_) # type: ignore
res = await cls.db.session.execute(stmt) # type: ignore
element = res.scalar_one_or_none()
except AttributeError as err:
print(err)
element = None
return element
Base.get_by_id = classmethod(base_get_by_id)

View File

@@ -1,80 +1,48 @@
from asyncio import gather as asyncio_gather
from functools import wraps
from pathlib import Path
from time import time
from typing import Callable, Iterable, Optional
from rich import print
from sqlalchemy import event, select, tuple_
from sqlalchemy.engine import Engine
from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
from sqlalchemy.orm import (
selectinload,
sessionmaker,
with_polymorphic,
from sqlalchemy import text
from sqlalchemy.ext.asyncio import (
async_sessionmaker,
AsyncEngine,
AsyncSession,
create_async_engine,
)
from sqlalchemy.orm.attributes import set_committed_value
from .base_class import Base
# import logging
# logging.basicConfig()
# logger = logging.getLogger("bot.sqltime")
# logger.setLevel(logging.DEBUG)
# @event.listens_for(Engine, "before_cursor_execute")
# def before_cursor_execute(conn, cursor, statement, parameters, context, executemany):
# conn.info.setdefault("query_start_time", []).append(time())
# logger.debug("Start Query: %s", statement)
# @event.listens_for(Engine, "after_cursor_execute")
# def after_cursor_execute(conn, cursor, statement, parameters, context, executemany):
# total = time() - conn.info["query_start_time"].pop(-1)
# logger.debug("Query Complete!")
# logger.debug("Total Time: %f", total)
class Database:
def __init__(self) -> None:
self._engine = None
self._session_maker = None
self._session = None
self._engine: AsyncEngine | None = None
self._session_maker: async_sessionmaker[AsyncSession] | None = None
self._session: AsyncSession | None = None
@property
def session(self) -> None:
if self._session is None:
self._session = self._session_maker()
def session(self) -> AsyncSession | None:
if self._session is None and (session_maker := self._session_maker) is not None:
self._session = session_maker()
return self._session
def use_session(func: Callable):
@wraps(func)
async def wrapper(self, *args, **kwargs):
if self._check_session() is not None:
return await func(self, *args, **kwargs)
# TODO: Raise an exception ?
async def connect(self, db_path: str, clear_static_data: bool = False) -> bool:
return wrapper
async def connect(self, db_path: str, clear_static_data: bool = False) -> None:
# TODO: Preserve UserLastStopSearchResults table from drop.
self._engine = create_async_engine(db_path)
self._session_maker = sessionmaker(
if self._engine is not None:
self._session_maker = async_sessionmaker(
self._engine, expire_on_commit=False, class_=AsyncSession
)
await self.session.execute("CREATE EXTENSION IF NOT EXISTS pg_trgm;")
if (session := self.session) is not None:
await session.execute(text("CREATE EXTENSION IF NOT EXISTS pg_trgm;"))
async with self._engine.begin() as conn:
if clear_static_data:
await conn.run_sync(Base.metadata.drop_all)
await conn.run_sync(Base.metadata.create_all)
return True
async def disconnect(self) -> None:
if self._session is not None:
await self._session.close()
self._session = None
if self._engine is not None:
await self._engine.dispose()

View File

@@ -1,2 +1,68 @@
from .idfm_interface import IdfmInterface
from .idfm_types import *
from .idfm_types import (
Coordinate,
FramedVehicleJourney,
IdfmLineState,
IdfmOperator,
IdfmResponse,
IdfmState,
LinePicto,
LineFields,
Line,
MonitoredCall,
MonitoredVehicleJourney,
Point,
Siri,
ServiceDelivery,
Stop,
StopArea,
StopAreaFields,
StopAreaStopAssociation,
StopAreaStopAssociationFields,
StopAreaType,
StopDelivery,
StopFields,
StopLineAsso,
StopLineAssoFields,
StopMonitoringDelivery,
TrainNumber,
TrainStatus,
TransportMode,
TransportSubMode,
Value,
)
__all__ = [
"Coordinate",
"FramedVehicleJourney",
"IdfmInterface",
"IdfmLineState",
"IdfmOperator",
"IdfmResponse",
"IdfmState",
"LinePicto",
"LineFields",
"Line",
"MonitoredCall",
"MonitoredVehicleJourney",
"Point",
"Siri",
"ServiceDelivery",
"Stop",
"StopArea",
"StopAreaFields",
"StopAreaStopAssociation",
"StopAreaStopAssociationFields",
"StopAreaType",
"StopDelivery",
"StopFields",
"StopLineAsso",
"StopLineAssoFields",
"StopMonitoringDelivery",
"TrainNumber",
"TrainStatus",
"TransportMode",
"TransportSubMode",
"Value",
]

View File

@@ -1,7 +1,6 @@
from pathlib import Path
from re import compile as re_compile
from time import time
from typing import ByteString, Iterable, List, Optional
from typing import AsyncIterator, ByteString, Callable, Iterable, List, Type
from aiofiles import open as async_open
from aiohttp import ClientSession
@@ -16,14 +15,14 @@ from .idfm_types import (
IdfmLineState,
IdfmResponse,
Line as IdfmLine,
MonitoredVehicleJourney,
LinePicto as IdfmPicto,
IdfmState,
Stop as IdfmStop,
StopArea as IdfmStopArea,
StopAreaStopAssociation,
StopAreaType,
StopLineAsso as IdfmStopLineAsso,
Stops,
TransportMode,
)
from .ratp_types import Picto as RatpPicto
@@ -40,7 +39,10 @@ class IdfmInterface:
IDFM_PICTO_URL = f"{IDFM_ROOT_URL}/referentiel-des-lignes/files"
RATP_ROOT_URL = "https://data.ratp.fr/explore/dataset"
RATP_PICTO_URL = f"{RATP_ROOT_URL}/pictogrammes-des-lignes-de-metro-rer-tramway-bus-et-noctilien/files"
RATP_PICTO_URL = (
f"{RATP_ROOT_URL}"
"/pictogrammes-des-lignes-de-metro-rer-tramway-bus-et-noctilien/files"
)
OPERATOR_RE = re_compile(r"[^:]+:Operator::([^:]+):")
LINE_RE = re_compile(r"[^:]+:Line::([^:]+):")
@@ -64,7 +66,7 @@ class IdfmInterface:
async def startup(self) -> None:
BATCH_SIZE = 10000
STEPS = (
STEPS: tuple[tuple[Type[Stop] | Type[StopArea], Callable, Callable], ...] = (
(
StopArea,
self._request_idfm_stop_areas,
@@ -132,13 +134,13 @@ class IdfmInterface:
pictos.append(picto)
if len(pictos) == batch_size:
formatted_pictos = IdfmInterface._format_ratp_pictos(*pictos)
await LinePicto.add(formatted_pictos.values())
await LinePicto.add(map(lambda picto: picto[1], formatted_pictos))
await Line.add_pictos(formatted_pictos)
pictos.clear()
if pictos:
formatted_pictos = IdfmInterface._format_ratp_pictos(*pictos)
await LinePicto.add(formatted_pictos.values())
await LinePicto.add(map(lambda picto: picto[1], formatted_pictos))
await Line.add_pictos(formatted_pictos)
async def _load_lines_stops_assos(self, batch_size: int = 5000) -> None:
@@ -174,16 +176,18 @@ class IdfmInterface:
assos.append((int(fields.zdaid), int(fields.arrid)))
if len(assos) == batch_size:
total_assos_nb += batch_size
total_found_nb += await StopArea.add_stops(assos)
if (found_nb := await StopArea.add_stops(assos)) is not None:
total_found_nb += found_nb
assos.clear()
if assos:
total_assos_nb += len(assos)
total_found_nb += await StopArea.add_stops(assos)
if (found_nb := await StopArea.add_stops(assos)) is not None:
total_found_nb += found_nb
print(f"{total_found_nb} stop area <-> stop ({total_assos_nb = } found)")
async def _request_idfm_stops(self):
async def _request_idfm_stops(self) -> AsyncIterator[IdfmStop]:
# headers = {"Accept": "application/json", "apikey": self._api_key}
# async with ClientSession(headers=headers) as session:
# async with session.get(self.STOPS_URL) as response:
@@ -196,19 +200,21 @@ class IdfmInterface:
for element in self._json_stops_decoder.decode(await raw.read()):
yield element
async def _request_idfm_stop_areas(self):
async def _request_idfm_stop_areas(self) -> AsyncIterator[IdfmStopArea]:
# TODO: Use HTTP
async with async_open("./tests/datasets/zones-d-arrets.json", "rb") as raw:
for element in self._json_stop_areas_decoder.decode(await raw.read()):
yield element
async def _request_idfm_lines(self):
async def _request_idfm_lines(self) -> AsyncIterator[IdfmLine]:
# TODO: Use HTTP
async with async_open("./tests/datasets/lines_dataset.json", "rb") as raw:
for element in self._json_lines_decoder.decode(await raw.read()):
yield element
async def _request_idfm_stops_lines_associations(self):
async def _request_idfm_stops_lines_associations(
self,
) -> AsyncIterator[IdfmStopLineAsso]:
# TODO: Use HTTP
async with async_open("./tests/datasets/arrets-lignes.json", "rb") as raw:
for element in self._json_stops_lines_assos_decoder.decode(
@@ -216,7 +222,9 @@ class IdfmInterface:
):
yield element
async def _request_idfm_stop_area_stop_associations(self):
async def _request_idfm_stop_area_stop_associations(
self,
) -> AsyncIterator[StopAreaStopAssociation]:
# TODO: Use HTTP
async with async_open("./tests/datasets/relations.json", "rb") as raw:
for element in self._json_stop_area_stop_asso_decoder.decode(
@@ -224,7 +232,7 @@ class IdfmInterface:
):
yield element
async def _request_ratp_pictos(self):
async def _request_ratp_pictos(self) -> AsyncIterator[RatpPicto]:
# TODO: Use HTTP
async with async_open(
"./tests/datasets/pictogrammes-des-lignes-de-metro-rer-tramway-bus-et-noctilien.json",
@@ -254,12 +262,15 @@ class IdfmInterface:
return ret
@classmethod
def _format_ratp_pictos(cls, *pictos: RatpPicto) -> dict[str, None | LinePicto]:
ret = {}
def _format_ratp_pictos(cls, *pictos: RatpPicto) -> Iterable[tuple[str, LinePicto]]:
ret = []
for picto in pictos:
if (fields := picto.fields.noms_des_fichiers) is not None:
ret[picto.fields.indices_commerciaux] = LinePicto(
ret.append(
(
picto.fields.indices_commerciaux,
LinePicto(
id=fields.id_,
mime_type=f"image/{fields.format.lower()}",
height_px=fields.height,
@@ -268,6 +279,8 @@ class IdfmInterface:
url=f"{cls.RATP_PICTO_URL}/{fields.id_}/download",
thumbnail=fields.thumbnail,
format=fields.format,
),
)
)
return ret
@@ -289,7 +302,7 @@ class IdfmInterface:
short_name=fields.shortname_line,
name=fields.name_line,
status=IdfmLineState(fields.status.value),
transport_mode=fields.transportmode.value,
transport_mode=TransportMode(fields.transportmode.value),
transport_submode=optional_value(fields.transportsubmode),
network_name=optional_value(fields.networkname),
group_of_lines_id=optional_value(fields.id_groupoflines),
@@ -300,9 +313,13 @@ class IdfmInterface:
text_colour_hexa=fields.textcolourprint_hexa,
operator_id=optional_value(fields.operatorref),
operator_name=optional_value(fields.operatorname),
accessibility=fields.accessibility.value,
visual_signs_available=fields.visualsigns_available.value,
audible_signs_available=fields.audiblesigns_available.value,
accessibility=IdfmState(fields.accessibility.value),
visual_signs_available=IdfmState(
fields.visualsigns_available.value
),
audible_signs_available=IdfmState(
fields.audiblesigns_available.value
),
picto_id=fields.picto.id_ if fields.picto is not None else None,
picto=picto,
record_id=line.recordid,
@@ -317,7 +334,7 @@ class IdfmInterface:
for stop in stops:
fields = stop.fields
try:
created_ts = int(fields.arrcreated.timestamp())
created_ts = int(fields.arrcreated.timestamp()) # type: ignore
except AttributeError:
created_ts = None
yield Stop(
@@ -329,13 +346,13 @@ class IdfmInterface:
postal_region=fields.arrpostalregion,
xepsg2154=fields.arrxepsg2154,
yepsg2154=fields.arryepsg2154,
transport_mode=fields.arrtype.value,
transport_mode=TransportMode(fields.arrtype.value),
version=fields.arrversion,
created_ts=created_ts,
changed_ts=int(fields.arrchanged.timestamp()),
accessibility=fields.arraccessibility.value,
visual_signs_available=fields.arrvisualsigns.value,
audible_signs_available=fields.arraudiblesignals.value,
accessibility=IdfmState(fields.arraccessibility.value),
visual_signs_available=IdfmState(fields.arrvisualsigns.value),
audible_signs_available=IdfmState(fields.arraudiblesignals.value),
record_id=stop.recordid,
record_ts=int(stop.record_timestamp.timestamp()),
)
@@ -345,7 +362,7 @@ class IdfmInterface:
for stop_area in stop_areas:
fields = stop_area.fields
try:
created_ts = int(fields.arrcreated.timestamp())
created_ts = int(fields.zdacreated.timestamp()) # type: ignore
except AttributeError:
created_ts = None
yield StopArea(
@@ -355,7 +372,7 @@ class IdfmInterface:
postal_region=fields.zdapostalregion,
xepsg2154=fields.zdaxepsg2154,
yepsg2154=fields.zdayepsg2154,
type=fields.zdatype.value,
type=StopAreaType(fields.zdatype.value),
version=fields.zdaversion,
created_ts=created_ts,
changed_ts=int(fields.zdachanged.timestamp()),
@@ -368,22 +385,22 @@ class IdfmInterface:
picto = line.picto
if picto is not None:
picto_data = await self._get_line_picto(line)
if (picto_data := await self._get_line_picto(line)) is not None:
async with async_open(target, "wb") as fd:
await fd.write(picto_data)
await fd.write(bytes(picto_data))
line_picto_path = target
line_picto_format = picto.mime_type
print(f"render_line_picto: {time() - begin_ts}")
return (line_picto_path, line_picto_format)
async def _get_line_picto(self, line: Line) -> Optional[ByteString]:
async def _get_line_picto(self, line: Line) -> ByteString | None:
print("---------------------------------------------------------------------")
begin_ts = time()
data = None
picto = line.picto
if picto is not None:
if picto is not None and picto.url is not None:
headers = (
self._http_headers if picto.url.startswith(self.IDFM_ROOT_URL) else None
)
@@ -401,31 +418,18 @@ class IdfmInterface:
print("---------------------------------------------------------------------")
return data
async def get_next_passages(self, stop_point_id: str) -> Optional[IdfmResponse]:
# print("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
begin_ts = time()
async def get_next_passages(self, stop_point_id: str) -> IdfmResponse | None:
ret = None
params = {"MonitoringRef": f"STIF:StopPoint:Q:{stop_point_id}:"}
session_begin_ts = time()
async with ClientSession(headers=self._http_headers) as session:
session_creation_ts = time()
# print(f"Session creation {session_creation_ts - session_begin_ts}")
async with session.get(self.IDFM_STOP_MON_URL, params=params) as response:
get_end_ts = time()
# print(f"GET {get_end_ts - session_creation_ts}")
if response.status == 200:
get_end_ts = time()
# print(f"GET {get_end_ts - session_creation_ts}")
data = await response.read()
# print(data)
try:
ret = self._response_json_decoder.decode(data)
except ValidationError as err:
print(err)
# print(f"read {time() - get_end_ts}")
# print(f"get_next_passages: {time() - begin_ts}")
# print("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
return ret
async def get_destinations(self, stop_point_id: str) -> Iterable[str]:

View File

@@ -2,7 +2,7 @@ from __future__ import annotations
from datetime import datetime
from enum import Enum, StrEnum
from typing import Any, Literal, Optional, NamedTuple
from typing import Any, NamedTuple
from msgspec import Struct
@@ -88,7 +88,7 @@ class Stop(Struct):
Stops = dict[str, Stop]
class StopAreaType(Enum):
class StopAreaType(StrEnum):
metroStation = "metroStation"
onstreetBus = "onstreetBus"
onstreetTram = "onstreetTram"
@@ -101,7 +101,7 @@ class StopAreaFields(Struct, kw_only=True):
zdatown: str
zdaversion: str
zdaid: str
zdacreated: Optional[datetime] = None
zdacreated: datetime | None = None
zdatype: StopAreaType
zdayepsg2154: int
zdapostalregion: str
@@ -118,13 +118,13 @@ class StopArea(Struct):
class StopAreaStopAssociationFields(Struct, kw_only=True):
arrid: str # TODO: use int ?
artid: Optional[str] = None
artid: str | None = None
arrversion: str
zdcid: str
version: int
zdaid: str
zdaversion: str
artversion: Optional[str] = None
artversion: str | None = None
class StopAreaStopAssociation(Struct):
@@ -153,20 +153,20 @@ class LineFields(Struct, kw_only=True):
name_line: str
status: IdfmLineState
accessibility: IdfmState
shortname_groupoflines: Optional[str] = None
shortname_groupoflines: str | None = None
transportmode: TransportMode
colourweb_hexa: str
textcolourprint_hexa: str
transportsubmode: Optional[TransportSubMode] = TransportSubMode.unknown
operatorref: Optional[str] = None
transportsubmode: TransportSubMode | None = TransportSubMode.unknown
operatorref: str | None = None
visualsigns_available: IdfmState
networkname: Optional[str] = None
networkname: str | None = None
id_line: str
id_groupoflines: Optional[str] = None
operatorname: Optional[str] = None
id_groupoflines: str | None = None
operatorname: str | None = None
audiblesigns_available: IdfmState
shortname_line: str
picto: Optional[LinePicto] = None
picto: LinePicto | None = None
class Line(Struct):
@@ -220,17 +220,17 @@ class TrainNumber(Struct):
class MonitoredCall(Struct, kw_only=True):
Order: Optional[int] = None
Order: int | None = None
StopPointName: list[Value]
VehicleAtStop: bool
DestinationDisplay: list[Value]
AimedArrivalTime: Optional[datetime] = None
ExpectedArrivalTime: Optional[datetime] = None
ArrivalPlatformName: Optional[Value] = None
AimedDepartureTime: Optional[datetime] = None
ExpectedDepartureTime: Optional[datetime] = None
ArrivalStatus: TrainStatus = None
DepartureStatus: TrainStatus = None
AimedArrivalTime: datetime | None = None
ExpectedArrivalTime: datetime | None = None
ArrivalPlatformName: Value | None = None
AimedDepartureTime: datetime | None = None
ExpectedDepartureTime: datetime | None = None
ArrivalStatus: TrainStatus | None = None
DepartureStatus: TrainStatus | None = None
class MonitoredVehicleJourney(Struct, kw_only=True):
@@ -240,7 +240,7 @@ class MonitoredVehicleJourney(Struct, kw_only=True):
DestinationRef: Value
DestinationName: list[Value] | None = None
JourneyNote: list[Value] | None = None
TrainNumbers: Optional[TrainNumber] = None
TrainNumbers: TrainNumber | None = None
MonitoredCall: MonitoredCall

View File

@@ -1,3 +1,6 @@
from .line import Line, LinePicto
from .stop import Stop, StopArea
from .user import UserLastStopSearchResults
__all__ = ["Line", "LinePicto", "Stop", "StopArea", "UserLastStopSearchResults"]

View File

@@ -1,6 +1,6 @@
from asyncio import gather as asyncio_gather
from collections import defaultdict
from typing import Iterable, Self
from typing import Iterable, Self, Sequence
from sqlalchemy import (
BigInteger,
@@ -13,8 +13,7 @@ from sqlalchemy import (
String,
Table,
)
from sqlalchemy.orm import Mapped, relationship, selectinload
from sqlalchemy.orm.attributes import set_committed_value
from sqlalchemy.orm import Mapped, mapped_column, relationship, selectinload
from sqlalchemy.sql.expression import tuple_
from ..db import Base, db
@@ -38,14 +37,14 @@ class LinePicto(Base):
db = db
id = Column(String, primary_key=True)
mime_type = Column(String, nullable=False)
height_px = Column(Integer, nullable=False)
width_px = Column(Integer, nullable=False)
filename = Column(String, nullable=False)
url = Column(String, nullable=False)
thumbnail = Column(Boolean, nullable=False)
format = Column(String, nullable=False)
id = mapped_column(String, primary_key=True)
mime_type = mapped_column(String, nullable=False)
height_px = mapped_column(Integer, nullable=False)
width_px = mapped_column(Integer, nullable=False)
filename = mapped_column(String, nullable=False)
url = mapped_column(String, nullable=False)
thumbnail = mapped_column(Boolean, nullable=False)
format = mapped_column(String, nullable=False)
__tablename__ = "line_pictos"
@@ -54,35 +53,35 @@ class Line(Base):
db = db
id = Column(String, primary_key=True)
id = mapped_column(String, primary_key=True)
short_name = Column(String)
name = Column(String, nullable=False)
status = Column(Enum(IdfmLineState), nullable=False)
transport_mode = Column(Enum(TransportMode), nullable=False)
transport_submode = Column(Enum(TransportSubMode), nullable=False)
short_name = mapped_column(String)
name = mapped_column(String, nullable=False)
status = mapped_column(Enum(IdfmLineState), nullable=False)
transport_mode = mapped_column(Enum(TransportMode), nullable=False)
transport_submode = mapped_column(Enum(TransportSubMode), nullable=False)
network_name = Column(String)
group_of_lines_id = Column(String)
group_of_lines_shortname = Column(String)
network_name = mapped_column(String)
group_of_lines_id = mapped_column(String)
group_of_lines_shortname = mapped_column(String)
colour_web_hexa = Column(String, nullable=False)
text_colour_hexa = Column(String, nullable=False)
colour_web_hexa = mapped_column(String, nullable=False)
text_colour_hexa = mapped_column(String, nullable=False)
operator_id = Column(String)
operator_name = Column(String)
operator_id = mapped_column(String)
operator_name = mapped_column(String)
accessibility = Column(Enum(IdfmState), nullable=False)
visual_signs_available = Column(Enum(IdfmState), nullable=False)
audible_signs_available = Column(Enum(IdfmState), nullable=False)
accessibility = mapped_column(Enum(IdfmState), nullable=False)
visual_signs_available = mapped_column(Enum(IdfmState), nullable=False)
audible_signs_available = mapped_column(Enum(IdfmState), nullable=False)
picto_id = Column(String, ForeignKey("line_pictos.id"))
picto_id = mapped_column(String, ForeignKey("line_pictos.id"))
picto: Mapped[LinePicto] = relationship(LinePicto, lazy="selectin")
record_id = Column(String, nullable=False)
record_ts = Column(BigInteger, nullable=False)
record_id = mapped_column(String, nullable=False)
record_ts = mapped_column(BigInteger, nullable=False)
stops: Mapped[list["_Stop"]] = relationship(
stops: Mapped[list[_Stop]] = relationship(
"_Stop",
secondary=line_stop_association_table,
back_populates="lines",
@@ -94,67 +93,81 @@ class Line(Base):
@classmethod
async def get_by_name(
cls, name: str, operator_name: None | str = None
) -> list[Self]:
) -> Sequence[Self] | None:
session = cls.db.session
if session is None:
return None
filters = {"name": name}
if operator_name is not None:
filters["operator_name"] = operator_name
lines = None
stmt = (
select(Line)
select(cls)
.filter_by(**filters)
.options(selectinload(Line.stops), selectinload(Line.picto))
.options(selectinload(cls.stops), selectinload(cls.picto))
)
res = await cls.db.session.execute(stmt)
res = await session.execute(stmt)
lines = res.scalars().all()
return lines
@classmethod
async def _add_picto_to_line(cls, line: str | Self, picto: LinePicto) -> None:
formatted_line: Self | None = None
if isinstance(line, str):
if (lines := await cls.get_by_name(line)) is not None:
if len(lines) == 1:
line = lines[0]
formatted_line = lines[0]
else:
for candidate_line in lines:
if candidate_line.operator_name == "RATP":
line = candidate_line
formatted_line = candidate_line
break
else:
formatted_line = line
if isinstance(line, Line) and line.picto is None:
line.picto = picto
line.picto_id = picto.id
if isinstance(formatted_line, Line) and formatted_line.picto is None:
formatted_line.picto = picto
formatted_line.picto_id = picto.id
@classmethod
async def add_pictos(cls, line_to_pictos: dict[str | Self, LinePicto]) -> None:
async def add_pictos(cls, line_to_pictos: Iterable[tuple[str, LinePicto]]) -> bool:
session = cls.db.session
if session is None:
return False
await asyncio_gather(
*[
cls._add_picto_to_line(line, picto)
for line, picto in line_to_pictos.items()
]
*[cls._add_picto_to_line(line, picto) for line, picto in line_to_pictos]
)
await cls.db.session.commit()
await session.commit()
return True
@classmethod
async def add_stops(cls, line_to_stop_ids: Iterable[tuple[str, str, str]]) -> int:
async def add_stops(cls, line_to_stop_ids: Iterable[tuple[str, str, int]]) -> int:
session = cls.db.session
if session is None:
return 0
line_names_ops, stop_ids = set(), set()
for line_name, operator_name, stop_id in line_to_stop_ids:
line_names_ops.add((line_name, operator_name))
stop_ids.add(stop_id)
res = await cls.db.session.execute(
lines_res = await session.execute(
select(Line).where(
tuple_(Line.name, Line.operator_name).in_(line_names_ops)
)
)
lines = defaultdict(list)
for line in res.scalars():
for line in lines_res.scalars():
lines[(line.name, line.operator_name)].append(line)
res = await cls.db.session.execute(select(_Stop).where(_Stop.id.in_(stop_ids)))
stops = {stop.id: stop for stop in res.scalars()}
stops_res = await session.execute(select(_Stop).where(_Stop.id.in_(stop_ids)))
stops = {stop.id: stop for stop in stops_res.scalars()}
found = 0
for line_name, operator_name, stop_id in line_to_stop_ids:
@@ -167,8 +180,10 @@ class Line(Base):
print(f"No line found for {line_name}/{operator_name}")
else:
print(
f"No stop found for {stop_id} id (used by {line_name}/{operator_name})"
f"No stop found for {stop_id} id"
f"(used by {line_name}/{operator_name})"
)
await cls.db.session.commit()
await session.commit()
return found

View File

@@ -1,4 +1,6 @@
from typing import Iterable, Self
from __future__ import annotations
from typing import Iterable, Self, Sequence, TYPE_CHECKING
from sqlalchemy import (
BigInteger,
@@ -10,12 +12,22 @@ from sqlalchemy import (
String,
Table,
)
from sqlalchemy.orm import Mapped, relationship, selectinload, with_polymorphic
from sqlalchemy.orm import (
mapped_column,
Mapped,
relationship,
selectinload,
with_polymorphic,
)
from sqlalchemy.schema import Index
from ..db import Base, db
from ..idfm_interface.idfm_types import TransportMode, IdfmState, StopAreaType
if TYPE_CHECKING:
from .line import Line
stop_area_stop_association_table = Table(
"stop_area_stop_association_table",
Base.metadata,
@@ -28,18 +40,18 @@ class _Stop(Base):
db = db
id = Column(BigInteger, primary_key=True)
kind = Column(String)
id = mapped_column(BigInteger, primary_key=True)
kind = mapped_column(String)
name = Column(String, nullable=False, index=True)
town_name = Column(String, nullable=False)
postal_region = Column(String, nullable=False)
xepsg2154 = Column(BigInteger, nullable=False)
yepsg2154 = Column(BigInteger, nullable=False)
version = Column(String, nullable=False)
created_ts = Column(BigInteger)
changed_ts = Column(BigInteger, nullable=False)
lines: Mapped[list["Line"]] = relationship(
name = mapped_column(String, nullable=False, index=True)
town_name = mapped_column(String, nullable=False)
postal_region = mapped_column(String, nullable=False)
xepsg2154 = mapped_column(BigInteger, nullable=False)
yepsg2154 = mapped_column(BigInteger, nullable=False)
version = mapped_column(String, nullable=False)
created_ts = mapped_column(BigInteger)
changed_ts = mapped_column(BigInteger, nullable=False)
lines: Mapped[list[Line]] = relationship(
"Line",
secondary="line_stop_association_table",
back_populates="stops",
@@ -65,7 +77,11 @@ class _Stop(Base):
# TODO: Test https://www.cybertec-postgresql.com/en/postgresql-more-performance-for-like-and-ilike-statements/
# TODO: Should be able to remove with_polymorphic ?
@classmethod
async def get_by_name(cls, name: str) -> list[Self]:
async def get_by_name(cls, name: str) -> Sequence[type[_Stop]] | None:
session = cls.db.session
if session is None:
return None
stop_stop_area = with_polymorphic(_Stop, [Stop, StopArea])
stmt = (
select(stop_stop_area)
@@ -75,22 +91,25 @@ class _Stop(Base):
selectinload(stop_stop_area.lines),
)
)
res = await cls.db.session.execute(stmt)
return res.scalars()
res = await session.execute(stmt)
stops = res.scalars().all()
return stops
class Stop(_Stop):
id = Column(BigInteger, ForeignKey("_stops.id"), primary_key=True)
id = mapped_column(BigInteger, ForeignKey("_stops.id"), primary_key=True)
latitude = Column(Float, nullable=False)
longitude = Column(Float, nullable=False)
transport_mode = Column(Enum(TransportMode), nullable=False)
accessibility = Column(Enum(IdfmState), nullable=False)
visual_signs_available = Column(Enum(IdfmState), nullable=False)
audible_signs_available = Column(Enum(IdfmState), nullable=False)
record_id = Column(String, nullable=False)
record_ts = Column(BigInteger, nullable=False)
latitude = mapped_column(Float, nullable=False)
longitude = mapped_column(Float, nullable=False)
transport_mode = mapped_column(Enum(TransportMode), nullable=False)
accessibility = mapped_column(Enum(IdfmState), nullable=False)
visual_signs_available = mapped_column(Enum(IdfmState), nullable=False)
audible_signs_available = mapped_column(Enum(IdfmState), nullable=False)
record_id = mapped_column(String, nullable=False)
record_ts = mapped_column(BigInteger, nullable=False)
__tablename__ = "stops"
__mapper_args__ = {"polymorphic_identity": "stops", "polymorphic_load": "inline"}
@@ -98,11 +117,11 @@ class Stop(_Stop):
class StopArea(_Stop):
id = Column(BigInteger, ForeignKey("_stops.id"), primary_key=True)
id = mapped_column(BigInteger, ForeignKey("_stops.id"), primary_key=True)
type = Column(Enum(StopAreaType), nullable=False)
stops: Mapped[list[_Stop]] = relationship(
_Stop,
type = mapped_column(Enum(StopAreaType), nullable=False)
stops: Mapped[list["_Stop"]] = relationship(
"_Stop",
secondary=stop_area_stop_association_table,
back_populates="areas",
lazy="selectin",
@@ -110,24 +129,35 @@ class StopArea(_Stop):
)
__tablename__ = "stop_areas"
__mapper_args__ = {"polymorphic_identity": "stop_areas", "polymorphic_load": "inline"}
__mapper_args__ = {
"polymorphic_identity": "stop_areas",
"polymorphic_load": "inline",
}
@classmethod
async def add_stops(cls, stop_area_to_stop_ids: Iterable[tuple[str, str]]) -> int:
async def add_stops(
cls, stop_area_to_stop_ids: Iterable[tuple[int, int]]
) -> int | None:
session = cls.db.session
if session is None:
return None
stop_area_ids, stop_ids = set(), set()
for stop_area_id, stop_id in stop_area_to_stop_ids:
stop_area_ids.add(stop_area_id)
stop_ids.add(stop_id)
res = await cls.db.session.execute(
stop_areas_res = await session.execute(
select(StopArea)
.where(StopArea.id.in_(stop_area_ids))
.options(selectinload(StopArea.stops))
)
stop_areas = {stop_area.id: stop_area for stop_area in res.scalars()}
stop_areas: dict[int, StopArea] = {
stop_area.id: stop_area for stop_area in stop_areas_res.scalars()
}
res = await cls.db.session.execute(select(_Stop).where(_Stop.id.in_(stop_ids)))
stops = {stop.id: stop for stop in res.scalars()}
stop_res = await session.execute(select(_Stop).where(_Stop.id.in_(stop_ids)))
stops: dict[int, _Stop] = {stop.id: stop for stop in stop_res.scalars()}
found = 0
for stop_area_id, stop_id in stop_area_to_stop_ids:
@@ -140,5 +170,6 @@ class StopArea(_Stop):
else:
print(f"No stop area found for {stop_area_id}")
await cls.db.session.commit()
await session.commit()
return found

View File

@@ -1,5 +1,5 @@
from sqlalchemy import Column, ForeignKey, String, Table
from sqlalchemy.orm import Mapped, relationship
from sqlalchemy.orm import Mapped, mapped_column, relationship
from ..db import Base, db
from .stop import _Stop
@@ -18,8 +18,8 @@ class UserLastStopSearchResults(Base):
__tablename__ = "user_last_stop_search_results"
user_mxid = Column(String, primary_key=True)
request_content = Column(String, nullable=False)
stops: Mapped[list[_Stop]] = relationship(
user_mxid = mapped_column(String, primary_key=True)
request_content = mapped_column(String, nullable=False)
stops: Mapped[_Stop] = relationship(
_Stop, secondary=user_last_stop_search_stops_associations_table
)

0
backend/backend/py.typed Normal file
View File

View File

@@ -1,3 +1,5 @@
from .line import Line, TransportMode
from .next_passage import NextPassage, NextPassages
from .stop import Stop, StopArea
__all__ = ["Line", "NextPassage", "NextPassages", "Stop", "StopArea", "TransportMode"]

View File

@@ -1,5 +1,4 @@
from enum import StrEnum
from typing import Self
from pydantic import BaseModel
@@ -29,10 +28,11 @@ class TransportMode(StrEnum):
# idfm_types.TransportMode.rail + idfm_types.TransportSubMode.railShuttle
val = "val"
# Self return type replaced by "TransportMode" to fix following mypy error:
# Incompatible return value type (got "TransportMode", expected "Self")
# TODO: Is it the good fix ?
@classmethod
def from_idfm_transport_mode(
cls, mode: IdfmTransportMode, sub_mode: IdfmTransportSubMode
) -> Self:
def from_idfm_transport_mode(cls, mode: str, sub_mode: str) -> "TransportMode":
if mode == IdfmTransportMode.rail:
if sub_mode == IdfmTransportSubMode.regionalRail:
return cls.rail_ter
@@ -42,7 +42,7 @@ class TransportMode(StrEnum):
return cls.rail_transilien
if sub_mode == IdfmTransportSubMode.railShuttle:
return cls.val
return TransportMode(mode)
return cls(mode)
class Line(BaseModel):

View File

@@ -8,11 +8,11 @@ class NextPassage(BaseModel):
operator: str
destinations: list[str]
atStop: bool
aimedArrivalTs: None | int
expectedArrivalTs: None | int
arrivalPlatformName: None | str
aimedDepartTs: None | int
expectedDepartTs: None | int
aimedArrivalTs: int | None
expectedArrivalTs: int | None
arrivalPlatformName: str | None
aimedDepartTs: int | None
expectedDepartTs: int | None
arrivalStatus: TrainStatus
departStatus: TrainStatus

View File

@@ -1,6 +1,6 @@
from pydantic import BaseModel
from ..idfm_interface import IdfmLineState, IdfmState, StopAreaType, TransportMode
from ..idfm_interface import StopAreaType
class Stop(BaseModel):

View File

@@ -1,10 +1,10 @@
from collections import defaultdict
from datetime import datetime
from os import environ
from os import environ, EX_USAGE
from typing import Sequence
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from fastapi.staticfiles import StaticFiles
from rich import print
@@ -21,7 +21,9 @@ from backend.schemas import (
)
API_KEY = environ.get("API_KEY")
# TODO: Add error message if no key is given.
if API_KEY is None:
print('No "API_KEY" environment variable set... abort.')
exit(EX_USAGE)
# TODO: Remove postgresql+asyncpg from environ variable
DB_PATH = "postgresql+asyncpg://cer_user:cer_password@127.0.0.1:5438/cer_db"
@@ -44,9 +46,9 @@ idfm_interface = IdfmInterface(API_KEY, db)
@app.on_event("startup")
async def startup():
# await db.connect(DB_PATH, clear_static_data=True)
# await idfm_interface.startup()
await db.connect(DB_PATH, clear_static_data=False)
await db.connect(DB_PATH, clear_static_data=True)
await idfm_interface.startup()
# await db.connect(DB_PATH, clear_static_data=False)
print("Connected")
@@ -61,12 +63,12 @@ STATIC_ROOT = "../frontend/"
app.mount("/widget", StaticFiles(directory=STATIC_ROOT, html=True), name="widget")
def optional_datetime_to_ts(dt: datetime) -> int | None:
return dt.timestamp() if dt else None
def optional_datetime_to_ts(dt: datetime | None) -> int | None:
return int(dt.timestamp()) if dt else None
@app.get("/line/{line_id}", response_model=LineSchema)
async def get_line(line_id: str) -> JSONResponse:
async def get_line(line_id: str) -> LineSchema:
line: Line | None = await Line.get_by_id(line_id)
if line is None:
@@ -91,7 +93,7 @@ async def get_line(line_id: str) -> JSONResponse:
def _format_stop(stop: Stop) -> StopSchema:
print(stop.__dict__)
# print(stop.__dict__)
return StopSchema(
id=stop.id,
name=stop.name,
@@ -103,15 +105,17 @@ def _format_stop(stop: Stop) -> StopSchema:
lines=[line.id for line in stop.lines],
)
# châtelet
@app.get("/stop/")
async def get_stop(
name: str = "", limit: int = 10
) -> list[StopAreaSchema | StopSchema]:
) -> Sequence[StopAreaSchema | StopSchema]:
# TODO: Add limit support
formatted = []
formatted: list[StopAreaSchema | StopSchema] = []
matching_stops = await Stop.get_by_name(name)
# print(matching_stops, flush=True)
@@ -153,15 +157,17 @@ async def get_stop(
# TODO: Cache response for 30 secs ?
@app.get("/stop/nextPassages/{stop_id}")
async def get_next_passages(stop_id: str) -> JSONResponse:
async def get_next_passages(stop_id: str) -> NextPassagesSchema | None:
res = await idfm_interface.get_next_passages(stop_id)
# print(res)
if res is None:
return None
service_delivery = res.Siri.ServiceDelivery
stop_monitoring_deliveries = service_delivery.StopMonitoringDelivery
by_line_by_dst_passages = defaultdict(lambda: defaultdict(list))
by_line_by_dst_passages: dict[
str, dict[str, list[NextPassageSchema]]
] = defaultdict(lambda: defaultdict(list))
for delivery in stop_monitoring_deliveries:
for stop_visit in delivery.MonitoredStopVisit:
@@ -190,7 +196,9 @@ async def get_next_passages(stop_id: str) -> JSONResponse:
atStop=call.VehicleAtStop,
aimedArrivalTs=optional_datetime_to_ts(call.AimedArrivalTime),
expectedArrivalTs=optional_datetime_to_ts(call.ExpectedArrivalTime),
arrivalPlatformName=call.ArrivalPlatformName.value if call.ArrivalPlatformName else None,
arrivalPlatformName=call.ArrivalPlatformName.value
if call.ArrivalPlatformName
else None,
aimedDepartTs=optional_datetime_to_ts(call.AimedDepartureTime),
expectedDepartTs=optional_datetime_to_ts(call.ExpectedDepartureTime),
arrivalStatus=call.ArrivalStatus.value,

View File

@@ -11,7 +11,7 @@ python = "^3.11"
aiohttp = "^3.8.3"
rich = "^12.6.0"
aiofiles = "^22.1.0"
sqlalchemy = {extras = ["asyncio"], version = "^1.4.46"}
sqlalchemy = {extras = ["asyncio"], version = "^2.0.1"}
fastapi = "^0.88.0"
uvicorn = "^0.20.0"
asyncpg = "^0.27.0"
@@ -28,7 +28,6 @@ rope = "^1.3.0"
python-lsp-black = "^1.2.1"
black = "^22.10.0"
types-aiofiles = "^22.1.0.2"
sqlalchemy-stubs = "^0.4"
wrapt = "^1.14.1"
pydocstyle = "^6.2.2"
dill = "^0.3.6"
@@ -38,9 +37,16 @@ autopep8 = "^2.0.1"
pyflakes = "^3.0.1"
yapf = "^0.32.0"
whatthepatch = "^1.0.4"
sqlalchemy = {extras = ["mypy"], version = "^2.0.1"}
mypy = "^1.0.0"
[tool.mypy]
plugins = "sqlalchemy.ext.mypy.plugin"
[tool.black]
target-version = ['py311']
[tool.ruff]
line-length = 88
[too.ruff.per-file-ignores]
"__init__.py" = ["E401"]