From 293a1391bcf82c1ed823ee1066df28130773bbab Mon Sep 17 00:00:00 2001 From: Adrien Date: Wed, 12 Apr 2023 23:30:14 +0200 Subject: [PATCH] =?UTF-8?q?=E2=9C=A8=20Add=20ConnectionArea=20and=20StopSh?= =?UTF-8?q?ape=20models=20+=20Stop-ConnectionArea=20relationship?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- backend/backend/models/__init__.py | 12 +++- backend/backend/models/stop.py | 89 +++++++++++++++++++++++++++++- 2 files changed, 97 insertions(+), 4 deletions(-) diff --git a/backend/backend/models/__init__.py b/backend/backend/models/__init__.py index ef1a352..a7c455a 100644 --- a/backend/backend/models/__init__.py +++ b/backend/backend/models/__init__.py @@ -1,6 +1,14 @@ from .line import Line, LinePicto -from .stop import Stop, StopArea +from .stop import ConnectionArea, Stop, StopArea, StopShape from .user import UserLastStopSearchResults -__all__ = ["Line", "LinePicto", "Stop", "StopArea", "UserLastStopSearchResults"] +__all__ = [ + "ConnectionArea", + "Line", + "LinePicto", + "Stop", + "StopArea", + "StopShape", + "UserLastStopSearchResults", +] diff --git a/backend/backend/models/stop.py b/backend/backend/models/stop.py index bf7daa1..f75e1ea 100644 --- a/backend/backend/models/stop.py +++ b/backend/backend/models/stop.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Iterable, Self, Sequence, TYPE_CHECKING +from typing import Iterable, Sequence, TYPE_CHECKING from sqlalchemy import ( BigInteger, @@ -8,6 +8,8 @@ from sqlalchemy import ( Enum, Float, ForeignKey, + Integer, + JSON, select, String, Table, @@ -48,19 +50,26 @@ class _Stop(Base): postal_region = mapped_column(String, nullable=False) xepsg2154 = mapped_column(BigInteger, nullable=False) yepsg2154 = mapped_column(BigInteger, nullable=False) + version = mapped_column(String, nullable=False) created_ts = mapped_column(BigInteger) changed_ts = mapped_column(BigInteger, nullable=False) + lines: Mapped[list[Line]] = relationship( "Line", secondary="line_stop_association_table", back_populates="stops", - # lazy="joined", lazy="selectin", ) areas: Mapped[list["StopArea"]] = relationship( "StopArea", secondary=stop_area_stop_association_table, back_populates="stops" ) + connection_area_id: Mapped[int] = mapped_column( + ForeignKey("connection_areas.id"), nullable=True + ) + connection_area: Mapped["ConnectionArea"] = relationship( + back_populates="stops", lazy="selectin" + ) __tablename__ = "_stops" __mapper_args__ = {"polymorphic_identity": "_stops", "polymorphic_on": kind} @@ -108,6 +117,7 @@ class Stop(_Stop): accessibility = mapped_column(Enum(IdfmState), nullable=False) visual_signs_available = mapped_column(Enum(IdfmState), nullable=False) audible_signs_available = mapped_column(Enum(IdfmState), nullable=False) + record_id = mapped_column(String, nullable=False) record_ts = mapped_column(BigInteger, nullable=False) @@ -173,3 +183,78 @@ class StopArea(_Stop): await session.commit() return found + + +class StopShape(Base): + + db = db + + id = mapped_column(BigInteger, primary_key=True) # Same id than ConnectionArea + type = mapped_column(Integer, nullable=False) + bounding_box = mapped_column(JSON) + points = mapped_column(JSON) + + __tablename__ = "stop_shapes" + + +class ConnectionArea(Base): + + db = db + + id = mapped_column(BigInteger, primary_key=True) + + name = mapped_column(String, nullable=False) + town_name = mapped_column(String, nullable=False) + postal_region = mapped_column(String, nullable=False) + xepsg2154 = mapped_column(BigInteger, nullable=False) + yepsg2154 = mapped_column(BigInteger, nullable=False) + transport_mode = mapped_column(Enum(StopAreaType), nullable=False) + + version = mapped_column(String, nullable=False) + created_ts = mapped_column(BigInteger) + changed_ts = mapped_column(BigInteger, nullable=False) + + stops: Mapped[list["_Stop"]] = relationship(back_populates="connection_area") + + __tablename__ = "connection_areas" + + # TODO: Merge with StopArea.add_stops + @classmethod + async def add_stops( + cls, conn_area_to_stop_ids: Iterable[tuple[int, int]] + ) -> int | None: + session = cls.db.session + if session is None: + return None + + conn_area_ids, stop_ids = set(), set() + for conn_area_id, stop_id in conn_area_to_stop_ids: + conn_area_ids.add(conn_area_id) + stop_ids.add(stop_id) + + conn_area_res = await session.execute( + select(ConnectionArea) + .where(ConnectionArea.id.in_(conn_area_ids)) + .options(selectinload(ConnectionArea.stops)) + ) + conn_areas: dict[int, ConnectionArea] = { + conn.id: conn for conn in conn_area_res.scalars() + } + + stop_res = await session.execute(select(_Stop).where(_Stop.id.in_(stop_ids))) + stops: dict[int, _Stop] = {stop.id: stop for stop in stop_res.scalars()} + + found = 0 + for conn_area_id, stop_id in conn_area_to_stop_ids: + if (conn_area := conn_areas.get(conn_area_id)) is not None: + if (stop := stops.get(stop_id)) is not None: + conn_area.stops.append(stop) + found += 1 + else: + print(f"No stop found for {stop_id} id") + else: + print(f"No connection area found for {conn_area_id}") + + await session.commit() + + return found