296 lines
9.0 KiB
Python
296 lines
9.0 KiB
Python
from __future__ import annotations
|
|
|
|
from logging import getLogger
|
|
from typing import Iterable, Sequence, TYPE_CHECKING
|
|
|
|
from sqlalchemy import (
|
|
BigInteger,
|
|
Computed,
|
|
desc,
|
|
Enum,
|
|
Float,
|
|
ForeignKey,
|
|
func,
|
|
Integer,
|
|
JSON,
|
|
select,
|
|
String,
|
|
)
|
|
from sqlalchemy.orm import (
|
|
mapped_column,
|
|
Mapped,
|
|
relationship,
|
|
selectinload,
|
|
with_polymorphic,
|
|
)
|
|
from sqlalchemy.schema import Index
|
|
from sqlalchemy_utils.types.ts_vector import TSVectorType
|
|
|
|
from ..db import Base, db
|
|
from ..idfm_interface.idfm_types import TransportMode, IdfmState, StopAreaType
|
|
|
|
if TYPE_CHECKING:
|
|
from .line import Line
|
|
|
|
|
|
logger = getLogger(__name__)
|
|
|
|
# import cProfile
|
|
# import io
|
|
# import pstats
|
|
# import contextlib
|
|
|
|
|
|
# @contextlib.contextmanager
|
|
# def profiled():
|
|
# pr = cProfile.Profile()
|
|
# pr.enable()
|
|
# yield
|
|
# pr.disable()
|
|
# s = io.StringIO()
|
|
# ps = pstats.Stats(pr, stream=s).sort_stats("cumulative")
|
|
# ps.print_stats()
|
|
# # uncomment this to see who's calling what
|
|
# # ps.print_callers()
|
|
# print(s.getvalue())
|
|
|
|
|
|
class StopAreaStopAssociations(Base):
    """Association rows pairing a stop area with one of its stops.

    The table name declared here is the one used as ``secondary`` by the
    ``_Stop.areas`` and ``StopArea.stops`` relationships in this module.
    """

    __tablename__ = "stop_area_stop_associations"

    # Surrogate key of the association row itself.
    id = mapped_column(BigInteger, primary_key=True)

    # The two sides of the link.
    stop_id = mapped_column(BigInteger, ForeignKey("_stops.id"))
    stop_area_id = mapped_column(BigInteger, ForeignKey("stop_areas.id"))
|
|
|
|
|
|
class _Stop(Base):
    """Polymorphic base mapping stored in the ``_stops`` table.

    ``Stop`` and ``StopArea`` extend this class with their own tables whose
    primary key is a foreign key to ``_stops.id``; the ``kind`` column is the
    polymorphic discriminator.
    """

    __tablename__ = "_stops"

    # Database handle used by the class-level query helpers.
    db = db

    id = mapped_column(BigInteger, primary_key=True)
    # Discriminator column (see ``__mapper_args__`` below).
    kind = mapped_column(String)

    name = mapped_column(String, nullable=False, index=True)
    town_name = mapped_column(String, nullable=False)
    postal_region = mapped_column(Integer, nullable=False)
    epsg3857_x = mapped_column(Float, nullable=False)
    epsg3857_y = mapped_column(Float, nullable=False)

    version = mapped_column(String, nullable=False)
    created_ts = mapped_column(BigInteger)
    changed_ts = mapped_column(BigInteger, nullable=False)

    # Lines serving this stop, linked through ``line_stop_associations``.
    lines: Mapped[list[Line]] = relationship(
        "Line",
        secondary="line_stop_associations",
        back_populates="stops",
        lazy="selectin",
    )
    # Stop areas this stop belongs to, linked through
    # ``stop_area_stop_associations``.
    areas: Mapped[list["StopArea"]] = relationship(
        "StopArea",
        secondary="stop_area_stop_associations",
        back_populates="stops",
    )
    # Optional link to a connection area (the column is nullable).
    connection_area_id: Mapped[int] = mapped_column(
        ForeignKey("connection_areas.id"), nullable=True
    )
    connection_area: Mapped["ConnectionArea"] = relationship(
        back_populates="stops", lazy="selectin"
    )

    # Server-computed, persisted tsvector built from ``name`` and
    # ``town_name`` with the ``french`` text search configuration.
    names_tsv = mapped_column(
        TSVectorType("name", "town_name", regconfig="french"),
        Computed("to_tsvector('french', name || ' ' || town_name)", persisted=True),
    )

    __mapper_args__ = {"polymorphic_identity": "_stops", "polymorphic_on": kind}
    __table_args__ = (
        # GIN index backing the full-text search done by ``get_by_name``.
        # NOTE(review): the ``postgresql_ops`` key is "name" while the indexed
        # column is ``names_tsv`` -- confirm whether this operator class is
        # actually applied (and intended) for a tsvector column.
        Index(
            "names_tsv_idx",
            names_tsv,
            postgresql_ops={"name": "gin_trgm_ops"},
            postgresql_using="gin",
        ),
    )

    @classmethod
    async def get_by_name(cls, name: str) -> Sequence[_Stop] | None:
        """Full-text search of stops (all polymorphic subtypes) by name.

        Rows whose ``names_tsv`` matches ``name`` under the ``french``
        configuration are returned, ordered by descending ``ts_rank_cd``
        against ``plainto_tsquery('french', name)``.

        Args:
            name: Text to search for.

        Returns:
            The matching entities, or ``None`` when ``db.get_session()``
            yields no session.
        """
        session = await cls.db.get_session()
        if session is None:
            return None

        async with session.begin():
            # Query the base and every subclass table in one statement.
            any_stop = with_polymorphic(_Stop, "*")
            tsv_column = any_stop.names_tsv

            is_match = tsv_column.match(name, postgresql_regconfig="french")
            rank = func.ts_rank_cd(tsv_column, func.plainto_tsquery("french", name))

            query = select(any_stop).filter(is_match).order_by(desc(rank))
            result = await session.execute(query)
            return result.scalars().all()
|
|
|
|
|
|
class Stop(_Stop):
    """``_Stop`` subtype stored in the ``stops`` table.

    Its primary key is a foreign key to ``_stops.id`` and it is loaded inline
    with the base entity (``polymorphic_load="inline"``).
    """

    __tablename__ = "stops"
    __mapper_args__ = {"polymorphic_identity": "stops", "polymorphic_load": "inline"}

    id = mapped_column(BigInteger, ForeignKey("_stops.id"), primary_key=True)

    transport_mode = mapped_column(Enum(TransportMode), nullable=False)
    # The three columns below all use the ``IdfmState`` enum.
    accessibility = mapped_column(Enum(IdfmState), nullable=False)
    visual_signs_available = mapped_column(Enum(IdfmState), nullable=False)
    audible_signs_available = mapped_column(Enum(IdfmState), nullable=False)

    record_id = mapped_column(String, nullable=False)
    record_ts = mapped_column(BigInteger, nullable=False)
|
|
|
|
|
|
class StopArea(_Stop):
    """``_Stop`` subtype stored in the ``stop_areas`` table.

    A stop area groups ``Stop`` rows through the
    ``stop_area_stop_associations`` table.
    """

    __tablename__ = "stop_areas"
    __mapper_args__ = {
        "polymorphic_identity": "stop_areas",
        "polymorphic_load": "inline",
    }

    id = mapped_column(BigInteger, ForeignKey("_stops.id"), primary_key=True)

    type = mapped_column(Enum(StopAreaType), nullable=False)

    # Stops grouped under this area; mirrored by ``_Stop.areas``.
    stops: Mapped[list["Stop"]] = relationship(
        "Stop",
        secondary="stop_area_stop_associations",
        back_populates="areas",
        lazy="selectin",
    )

    @classmethod
    async def add_stops(
        cls, stop_area_to_stop_ids: Iterable[tuple[int, int]]
    ) -> int | None:
        """Attach stops to stop areas from ``(stop_area_id, stop_id)`` pairs.

        Both sides of every pair are looked up in bulk; a pair is applied
        only when both the stop area and the stop exist. Pairs with a
        missing side are skipped and reported through the module logger.

        Args:
            stop_area_to_stop_ids: ``(stop_area_id, stop_id)`` pairs. Any
                iterable is accepted, including one-shot generators.

        Returns:
            The number of pairs applied, or ``None`` when
            ``db.get_session()`` yields no session.
        """
        session = await cls.db.get_session()
        if session is None:
            return None

        # The pairs are walked twice (once to collect ids, once to link), so
        # materialize them first: a generator would otherwise be exhausted
        # after the first pass and nothing would ever be linked.
        pairs = list(stop_area_to_stop_ids)
        stop_area_ids = {stop_area_id for stop_area_id, _ in pairs}
        stop_ids = {stop_id for _, stop_id in pairs}

        async with session.begin():
            stop_areas_res = await session.scalars(
                select(StopArea)
                .where(StopArea.id.in_(stop_area_ids))
                .options(selectinload(StopArea.stops))
            )
            stop_areas: dict[int, StopArea] = {
                stop_area.id: stop_area for stop_area in stop_areas_res.all()
            }

            stop_res = await session.execute(
                select(Stop).where(Stop.id.in_(stop_ids))
            )
            stops: dict[int, Stop] = {stop.id: stop for stop in stop_res.scalars()}

            found = 0
            for stop_area_id, stop_id in pairs:
                stop_area = stop_areas.get(stop_area_id)
                if stop_area is None:
                    logger.warning("No stop area found for %s", stop_area_id)
                    continue

                stop = stops.get(stop_id)
                if stop is None:
                    logger.warning("No stop found for %s id", stop_id)
                    continue

                # NOTE(review): no membership check is done before appending;
                # confirm callers never pass an already-linked pair.
                stop_area.stops.append(stop)
                found += 1

            return found
|
|
|
|
|
|
class StopShape(Base):
    """Shape record stored in the ``stop_shapes`` table.

    The bounding box and the points are both kept as JSON columns.
    """

    __tablename__ = "stop_shapes"

    # Database handle exposed on the class.
    db = db

    id = mapped_column(BigInteger, primary_key=True)  # Same id as ConnectionArea
    type = mapped_column(Integer, nullable=False)
    epsg3857_bbox = mapped_column(JSON)
    epsg3857_points = mapped_column(JSON)
|
|
|
|
|
|
class ConnectionArea(Base):
    """Connection area stored in the ``connection_areas`` table.

    ``_Stop`` rows reference a connection area through
    ``_Stop.connection_area_id``; the ``stops`` relationship is the reverse
    side of ``_Stop.connection_area``.
    """

    __tablename__ = "connection_areas"

    # Database handle used by the class-level helpers.
    db = db

    id = mapped_column(BigInteger, primary_key=True)

    name = mapped_column(String, nullable=False)
    town_name = mapped_column(String, nullable=False)
    # NOTE(review): declared as String here but as Integer on ``_Stop`` --
    # confirm the difference is intended.
    postal_region = mapped_column(String, nullable=False)
    epsg3857_x = mapped_column(Float, nullable=False)
    epsg3857_y = mapped_column(Float, nullable=False)
    # NOTE(review): named ``transport_mode`` but typed with ``StopAreaType``
    # (not ``TransportMode``) -- confirm this is intended.
    transport_mode = mapped_column(Enum(StopAreaType), nullable=False)

    version = mapped_column(String, nullable=False)
    created_ts = mapped_column(BigInteger)
    changed_ts = mapped_column(BigInteger, nullable=False)

    stops: Mapped[list["_Stop"]] = relationship(back_populates="connection_area")

    # TODO: Merge with StopArea.add_stops
    @classmethod
    async def add_stops(
        cls, conn_area_to_stop_ids: Iterable[tuple[int, int]]
    ) -> int | None:
        """Attach stops to connection areas from ``(conn_area_id, stop_id)`` pairs.

        Both sides of every pair are looked up in bulk; a pair is applied
        only when both the connection area and the stop exist. Pairs with a
        missing side are skipped and reported through the module logger.

        Args:
            conn_area_to_stop_ids: ``(conn_area_id, stop_id)`` pairs. Any
                iterable is accepted, including one-shot generators.

        Returns:
            The number of pairs applied, or ``None`` when
            ``db.get_session()`` yields no session.
        """
        session = await cls.db.get_session()
        if session is None:
            return None

        # The pairs are walked twice (once to collect ids, once to link), so
        # materialize them first: a generator would otherwise be exhausted
        # after the first pass and nothing would ever be linked.
        pairs = list(conn_area_to_stop_ids)
        conn_area_ids = {conn_area_id for conn_area_id, _ in pairs}
        stop_ids = {stop_id for _, stop_id in pairs}

        async with session.begin():
            conn_area_res = await session.execute(
                select(ConnectionArea)
                .where(ConnectionArea.id.in_(conn_area_ids))
                .options(selectinload(ConnectionArea.stops))
            )
            conn_areas: dict[int, ConnectionArea] = {
                conn.id: conn for conn in conn_area_res.scalars()
            }

            stop_res = await session.execute(
                select(Stop).where(Stop.id.in_(stop_ids))
            )
            stops: dict[int, Stop] = {stop.id: stop for stop in stop_res.scalars()}

            found = 0
            for conn_area_id, stop_id in pairs:
                conn_area = conn_areas.get(conn_area_id)
                if conn_area is None:
                    logger.warning("No connection area found for %s", conn_area_id)
                    continue

                stop = stops.get(stop_id)
                if stop is None:
                    logger.warning("No stop found for %s id", stop_id)
                    continue

                conn_area.stops.append(stop)
                found += 1

            return found
|