from __future__ import annotations

from logging import getLogger
from typing import Iterable, Optional, Sequence, TYPE_CHECKING

from sqlalchemy import (
    BigInteger,
    Computed,
    desc,
    Enum,
    Float,
    ForeignKey,
    func,
    Integer,
    JSON,
    select,
    String,
)
from sqlalchemy.orm import (
    mapped_column,
    Mapped,
    relationship,
    selectinload,
    with_polymorphic,
)
from sqlalchemy.schema import Index
from sqlalchemy_utils.types.ts_vector import TSVectorType

from db import Base, db
from idfm_interface.idfm_types import TransportMode, IdfmState, StopAreaType

if TYPE_CHECKING:
    from sqlalchemy.ext.asyncio import AsyncSession

    from .line import Line

logger = getLogger(__name__)


class StopAreaStopAssociations(Base):
    """Association table linking ``_stops`` rows to ``stop_areas`` rows.

    Used as the ``secondary`` table of the ``_Stop.areas`` /
    ``StopArea.stops`` many-to-many relationship.
    """

    id = mapped_column(BigInteger, primary_key=True)
    stop_id = mapped_column(BigInteger, ForeignKey("_stops.id"))
    stop_area_id = mapped_column(BigInteger, ForeignKey("stop_areas.id"))

    __tablename__ = "stop_area_stop_associations"


class _Stop(Base):
    """Polymorphic base of :class:`Stop` and :class:`StopArea`.

    Holds the columns shared by both kinds. Subclasses use joined-table
    inheritance (their ``id`` is a foreign key to ``_stops.id``), and the
    ``kind`` column stores each row's polymorphic identity.
    """

    db = db

    id = mapped_column(BigInteger, primary_key=True)
    kind = mapped_column(String)
    name = mapped_column(String, nullable=False, index=True)
    town_name = mapped_column(String, nullable=False)
    postal_region = mapped_column(Integer, nullable=False)
    epsg3857_x = mapped_column(Float, nullable=False)
    epsg3857_y = mapped_column(Float, nullable=False)

    version = mapped_column(String, nullable=False)
    created_ts = mapped_column(BigInteger)
    changed_ts = mapped_column(BigInteger, nullable=False)

    lines: Mapped[list[Line]] = relationship(
        "Line",
        secondary="line_stop_associations",
        back_populates="stops",
        lazy="selectin",
    )
    areas: Mapped[list["StopArea"]] = relationship(
        "StopArea",
        secondary="stop_area_stop_associations",
        back_populates="stops",
    )

    # Optional: a stop is not required to belong to a connection area.
    connection_area_id: Mapped[Optional[int]] = mapped_column(
        ForeignKey("connection_areas.id"), nullable=True
    )
    connection_area: Mapped[Optional["ConnectionArea"]] = relationship(
        back_populates="stops", lazy="selectin"
    )

    # Stored generated column: French text-search vector over name + town.
    names_tsv = mapped_column(
        TSVectorType("name", "town_name", regconfig="french"),
        Computed("to_tsvector('french', name || ' ' || town_name)", persisted=True),
    )

    __tablename__ = "_stops"
    __mapper_args__ = {"polymorphic_identity": "_stops", "polymorphic_on": kind}

    __table_args__ = (
        # NOTE(review): postgresql_ops is keyed by indexed column name, but the
        # only indexed column here is ``names_tsv``; the "name" entry therefore
        # looks like it never applies (gin_trgm_ops also targets text, not
        # tsvector). Left unchanged to avoid DDL drift — confirm and clean up.
        Index(
            "names_tsv_idx",
            names_tsv,
            postgresql_ops={"name": "gin_trgm_ops"},
            postgresql_using="gin",
        ),
    )

    @classmethod
    async def get_by_name(cls, name: str) -> Sequence[_Stop] | None:
        """Full-text search over stops and stop areas.

        Matches ``name`` against ``names_tsv`` with the French text-search
        configuration. Results are ordered by ``ts_rank_cd``, best match
        first.

        Args:
            name: Free-text query.

        Returns:
            The matching rows (any ``_Stop`` subclass), or ``None`` when
            ``db.get_session()`` yields no session.
        """
        if (session := await cls.db.get_session()) is not None:
            async with session.begin():
                # Query all subclasses at once so Stop and StopArea rows are
                # returned with their subclass columns loaded.
                descendants = with_polymorphic(_Stop, "*")

                match_stmt = descendants.names_tsv.match(
                    name, postgresql_regconfig="french"
                )
                ranking_stmt = func.ts_rank_cd(
                    descendants.names_tsv, func.plainto_tsquery("french", name)
                )
                stmt = (
                    select(descendants).filter(match_stmt).order_by(desc(ranking_stmt))
                )

                res = await session.execute(stmt)
                stops = res.scalars().all()

                return stops

        return None


class Stop(_Stop):
    """A single stop; subclass of :class:`_Stop` stored in the ``stops`` table."""

    id = mapped_column(BigInteger, ForeignKey("_stops.id"), primary_key=True)

    transport_mode = mapped_column(Enum(TransportMode), nullable=False)
    accessibility = mapped_column(Enum(IdfmState), nullable=False)
    visual_signs_available = mapped_column(Enum(IdfmState), nullable=False)
    audible_signs_available = mapped_column(Enum(IdfmState), nullable=False)

    record_id = mapped_column(String, nullable=False)
    record_ts = mapped_column(BigInteger, nullable=False)

    __tablename__ = "stops"
    __mapper_args__ = {"polymorphic_identity": "stops", "polymorphic_load": "inline"}


async def _attach_stops(
    session: AsyncSession,
    owner_cls: type[StopArea] | type[ConnectionArea],
    pairs: Sequence[tuple[int, int]],
    owner_label: str,
) -> int:
    """Append ``Stop`` rows to ``owner_cls.stops`` for each (owner_id, stop_id).

    Shared implementation of ``StopArea.add_stops`` and
    ``ConnectionArea.add_stops``. It must run inside an open transaction on
    ``session``.

    Args:
        session: Session used to load owners and stops.
        owner_cls: ``StopArea`` or ``ConnectionArea``; must expose ``id`` and
            a ``stops`` relationship.
        pairs: (owner_id, stop_id) pairs. This is iterated twice, so callers
            must pass a re-iterable sequence.
        owner_label: Human-readable owner name used in warning messages.

    Returns:
        The number of pairs for which both the owner and the stop were found
        and linked.
    """
    owner_ids = {owner_id for owner_id, _ in pairs}
    stop_ids = {stop_id for _, stop_id in pairs}

    # Eager-load the existing ``stops`` collections so append() below does
    # not trigger a lazy load per owner.
    owners_res = await session.scalars(
        select(owner_cls)
        .where(owner_cls.id.in_(owner_ids))
        .options(selectinload(owner_cls.stops))
    )
    owners = {owner.id: owner for owner in owners_res}

    stops_res = await session.scalars(select(Stop).where(Stop.id.in_(stop_ids)))
    stops: dict[int, Stop] = {stop.id: stop for stop in stops_res}

    found = 0
    for owner_id, stop_id in pairs:
        owner = owners.get(owner_id)
        if owner is None:
            logger.warning("No %s found for %s", owner_label, owner_id)
            continue

        stop = stops.get(stop_id)
        if stop is None:
            logger.warning("No stop found for %s id", stop_id)
            continue

        # NOTE(review): pairs already linked (or repeated in ``pairs``) are
        # appended again; for the many-to-many StopArea case this presumably
        # inserts a duplicate association row — confirm callers dedupe.
        owner.stops.append(stop)
        found += 1

    return found


class StopArea(_Stop):
    """A group of stops; subclass of :class:`_Stop` stored in ``stop_areas``."""

    id = mapped_column(BigInteger, ForeignKey("_stops.id"), primary_key=True)

    type = mapped_column(Enum(StopAreaType), nullable=False)
    stops: Mapped[list["Stop"]] = relationship(
        "Stop",
        secondary="stop_area_stop_associations",
        back_populates="areas",
        lazy="selectin",
    )

    __tablename__ = "stop_areas"
    __mapper_args__ = {
        "polymorphic_identity": "stop_areas",
        "polymorphic_load": "inline",
    }

    @classmethod
    async def add_stops(
        cls, stop_area_to_stop_ids: Iterable[tuple[int, int]]
    ) -> int | None:
        """Link existing stops to existing stop areas.

        Args:
            stop_area_to_stop_ids: (stop_area_id, stop_id) pairs. Any
                iterable is accepted, including one-shot iterators.

        Returns:
            The number of pairs that were linked, or ``None`` when no
            database session is available. Pairs whose stop area or stop is
            missing are logged and skipped.
        """
        if (session := await cls.db.get_session()) is not None:
            # Materialize once: the pairs are walked twice (id collection,
            # then linking), which would silently link nothing for a
            # generator.
            pairs = tuple(stop_area_to_stop_ids)
            async with session.begin():
                return await _attach_stops(session, StopArea, pairs, "stop area")

        return None


class StopShape(Base):
    """Shape data (bounding box and points, EPSG:3857) for a stop."""

    db = db

    id = mapped_column(BigInteger, primary_key=True)  # Same id as ConnectionArea
    type = mapped_column(Integer, nullable=False)
    epsg3857_bbox = mapped_column(JSON)
    epsg3857_points = mapped_column(JSON)

    __tablename__ = "stop_shapes"


class ConnectionArea(Base):
    """Area grouping stops; each ``_Stop`` references at most one of these."""

    db = db

    id = mapped_column(BigInteger, primary_key=True)

    name = mapped_column(String, nullable=False)
    town_name = mapped_column(String, nullable=False)
    postal_region = mapped_column(String, nullable=False)
    epsg3857_x = mapped_column(Float, nullable=False)
    epsg3857_y = mapped_column(Float, nullable=False)
    transport_mode = mapped_column(Enum(StopAreaType), nullable=False)

    version = mapped_column(String, nullable=False)
    created_ts = mapped_column(BigInteger)
    changed_ts = mapped_column(BigInteger, nullable=False)

    stops: Mapped[list["_Stop"]] = relationship(back_populates="connection_area")

    __tablename__ = "connection_areas"

    @classmethod
    async def add_stops(
        cls, conn_area_to_stop_ids: Iterable[tuple[int, int]]
    ) -> int | None:
        """Attach existing stops to existing connection areas.

        Args:
            conn_area_to_stop_ids: (connection_area_id, stop_id) pairs. Any
                iterable is accepted, including one-shot iterators.

        Returns:
            The number of pairs that were linked, or ``None`` when no
            database session is available. Pairs whose connection area or
            stop is missing are logged and skipped.
        """
        if (session := await cls.db.get_session()) is not None:
            # Materialize once: the pairs are walked twice (id collection,
            # then linking), which would silently link nothing for a
            # generator.
            pairs = tuple(conn_area_to_stop_ids)
            async with session.begin():
                return await _attach_stops(
                    session, ConnectionArea, pairs, "connection area"
                )

        return None