197 lines
6.4 KiB
Python
197 lines
6.4 KiB
Python
from asyncio import gather as asyncio_gather
|
|
from collections import defaultdict
|
|
from typing import Iterable, Self, Sequence
|
|
|
|
from sqlalchemy import (
|
|
BigInteger,
|
|
Boolean,
|
|
Enum,
|
|
ForeignKey,
|
|
Integer,
|
|
select,
|
|
String,
|
|
)
|
|
from sqlalchemy.orm import Mapped, mapped_column, relationship, selectinload
|
|
from sqlalchemy.sql.expression import tuple_
|
|
|
|
from db import Base, db
|
|
from idfm_interface.idfm_types import (
|
|
IdfmState,
|
|
IdfmLineState,
|
|
TransportMode,
|
|
TransportSubMode,
|
|
)
|
|
from .stop import _Stop
|
|
|
|
|
|
class LineStopAssociations(Base):
    """Association row linking one line to one stop.

    Its table is the ``secondary`` of ``Line.stops`` (the many-to-many link
    between the ``lines`` and ``_stops`` tables).
    """

    # Surrogate primary key. Nothing declared on this class (no unique
    # constraint) prevents the same (line_id, stop_id) pair from being stored
    # twice.
    id = mapped_column(BigInteger, primary_key=True)

    # FK to lines.id, the table of the Line model.
    line_id = mapped_column(BigInteger, ForeignKey("lines.id"))

    # FK to _stops.id, presumably the table of the _Stop model imported from
    # .stop (TODO confirm there).
    stop_id = mapped_column(BigInteger, ForeignKey("_stops.id"))

    __tablename__ = "line_stop_associations"
|
|
|
|
|
|
class LinePicto(Base):
    """Picto metadata stored in the ``line_pictos`` table.

    ``Line.picto_id`` references ``LinePicto.id``. Presumably the picto is the
    line's icon image (TODO confirm against the data source).
    """

    # Database handle imported from the ``db`` module, exposed on the class in
    # the same way as ``Line.db`` (whose classmethods call
    # ``cls.db.get_session()``). No use of it is visible on this class itself.
    db = db

    # String primary key, the target of Line.picto_id.
    id = mapped_column(String, primary_key=True)

    mime_type = mapped_column(String, nullable=False)

    # Image dimensions; the "_px" suffix suggests pixels.
    height_px = mapped_column(Integer, nullable=False)
    width_px = mapped_column(Integer, nullable=False)

    filename = mapped_column(String, nullable=False)
    url = mapped_column(String, nullable=False)

    # NOTE(review): presumably flags whether this record is a thumbnail
    # rendition of the picto — confirm against the data source.
    thumbnail = mapped_column(Boolean, nullable=False)

    # Attribute named after the column; it shadows the builtin ``format`` only
    # inside this class body.
    format = mapped_column(String, nullable=False)

    __tablename__ = "line_pictos"
|
|
|
|
|
|
class Line(Base):
    """ORM model for a transport line, its picto and the stops it serves.

    Rows live in the ``lines`` table. Stops are linked many-to-many through
    the ``line_stop_associations`` table.
    """

    # Database handle; the classmethods below open sessions through
    # ``cls.db.get_session()``.
    db = db

    id = mapped_column(BigInteger, primary_key=True)

    short_name = mapped_column(String)
    name = mapped_column(String, nullable=False)
    status = mapped_column(Enum(IdfmLineState), nullable=False)
    transport_mode = mapped_column(Enum(TransportMode), nullable=False)
    transport_submode = mapped_column(Enum(TransportSubMode), nullable=False)

    network_name = mapped_column(String)
    group_of_lines_id = mapped_column(String)
    group_of_lines_shortname = mapped_column(String)

    colour_web_hexa = mapped_column(String, nullable=False)
    text_colour_hexa = mapped_column(String, nullable=False)

    operator_id = mapped_column(Integer)
    operator_name = mapped_column(String)

    accessibility = mapped_column(Enum(IdfmState), nullable=False)
    visual_signs_available = mapped_column(Enum(IdfmState), nullable=False)
    audible_signs_available = mapped_column(Enum(IdfmState), nullable=False)

    # Nullable FK: a line may have no picto yet (add_pictos fills it in), so
    # the relationship is typed as optional.
    picto_id = mapped_column(String, ForeignKey("line_pictos.id"))
    picto: Mapped[LinePicto | None] = relationship(LinePicto, lazy="selectin")

    record_id = mapped_column(String, nullable=False)
    record_ts = mapped_column(BigInteger, nullable=False)

    # Many-to-many through line_stop_associations; lazy="selectin" loads the
    # collection with an extra SELECT whenever lines are loaded.
    stops: Mapped[list[_Stop]] = relationship(
        "_Stop",
        secondary="line_stop_associations",
        back_populates="lines",
        lazy="selectin",
    )

    __tablename__ = "lines"

    @classmethod
    async def get_by_name(
        cls, name: str, operator_name: None | str = None
    ) -> Sequence[Self] | None:
        """Return every line called *name*, optionally for one operator only.

        Args:
            name: Exact value matched against ``Line.name``.
            operator_name: When given, ``Line.operator_name`` must match it
                exactly as well.

        Returns:
            The matching lines (possibly empty) with ``stops`` and ``picto``
            loaded, or ``None`` when no session could be obtained.
        """
        if (session := await cls.db.get_session()) is None:
            return None

        async with session.begin():
            filters = {"name": name}
            if operator_name is not None:
                filters["operator_name"] = operator_name

            stmt = (
                select(cls)
                .filter_by(**filters)
                .options(selectinload(cls.stops), selectinload(cls.picto))
            )
            res = await session.execute(stmt)
            return res.scalars().all()

    @classmethod
    def _pick_line(cls, candidates: Sequence[Self]) -> Self | None:
        """Choose which of several same-named lines a picto belongs to.

        A single candidate is returned as is. Among several, the first whose
        operator is ``"RATP"`` wins. Returns ``None`` when there is no
        candidate, or no RATP one among several.
        """
        if len(candidates) == 1:
            return candidates[0]
        return next(
            (c for c in candidates if c.operator_name == "RATP"),
            None,
        )

    @staticmethod
    def _set_picto_if_missing(line: object, picto: LinePicto) -> None:
        """Assign *picto* to *line* when *line* is a ``Line`` without a picto."""
        if isinstance(line, Line) and line.picto is None:
            line.picto = picto
            # Mirror the FK explicitly; the relationship alone would only
            # populate it at flush time.
            line.picto_id = picto.id

    @classmethod
    async def _add_picto_to_line(cls, line: str | Self, picto: LinePicto) -> None:
        """Give *line* the picto *picto* if it does not have one yet.

        Args:
            line: A line name, resolved with ``get_by_name`` and
                ``_pick_line``, or a ``Line`` instance used as is.
            picto: The picto to attach.

        NOTE(review): when *line* is a name, the line comes from
        ``get_by_name``, whose ``session.begin()`` block has already exited
        when the picto is assigned here. Whether the change is ever flushed
        depends on what ``db.get_session()`` returns — confirm before relying
        on this helper to persist anything. ``add_pictos`` does its own lookup
        inside its transaction for that reason.
        """
        if isinstance(line, str):
            lines = await cls.get_by_name(line)
            target = cls._pick_line(lines) if lines is not None else None
        else:
            target = line
        cls._set_picto_if_missing(target, picto)

    @classmethod
    async def add_pictos(cls, line_to_pictos: Iterable[tuple[str, LinePicto]]) -> bool:
        """Attach pictos to lines that do not have one yet, in one transaction.

        Args:
            line_to_pictos: ``(line_name, picto)`` pairs. Lines are looked up
                by exact name. When several lines share a name, the RATP one
                is chosen. Names matching no line, or several lines none of
                which is RATP, are skipped. A ``Line`` instance given in place
                of a name is updated in memory only; it is not added to this
                method's session.

        Returns:
            ``True`` once the transaction completed, ``False`` when no session
            could be obtained.
        """
        if (session := await cls.db.get_session()) is None:
            return False

        pairs = list(line_to_pictos)
        names = {line for line, _ in pairs if isinstance(line, str)}

        async with session.begin():
            # Load every candidate in this session with a single query, so the
            # assignments below are flushed when the transaction commits. The
            # current picto comes along thanks to lazy="selectin".
            candidates: defaultdict[str, list[Self]] = defaultdict(list)
            if names:
                res = await session.execute(select(cls).where(cls.name.in_(names)))
                for candidate in res.scalars():
                    candidates[candidate.name].append(candidate)

            # Sequential on purpose: a single AsyncSession must not be shared
            # by concurrently running tasks.
            for line, picto in pairs:
                target = (
                    cls._pick_line(candidates.get(line, []))
                    if isinstance(line, str)
                    else line
                )
                cls._set_picto_if_missing(target, picto)

        return True

    @classmethod
    async def add_stops(cls, line_to_stop_ids: Iterable[tuple[str, str, int]]) -> int:
        """Link stops to lines, matching lines by ``(name, operator_name)``.

        Args:
            line_to_stop_ids: ``(line_name, operator_name, stop_id)`` triples.
                Any iterable is accepted, including one-shot generators.

        Returns:
            The number of (line, stop) matches. Every line sharing the given
            name/operator counts once per triple, including links that already
            existed. Returns 0 when no session could be obtained. Unknown
            stops or lines are reported on stdout and skipped.
        """
        if (session := await cls.db.get_session()) is None:
            return 0

        # The triples are walked twice (lookup keys, then linking). Materialise
        # them so a generator is not exhausted by the first pass.
        triples = list(line_to_stop_ids)

        async with session.begin():
            line_keys = {
                (line_name, operator_name) for line_name, operator_name, _ in triples
            }
            stop_ids = {stop_id for _, _, stop_id in triples}

            lines_res = await session.execute(
                select(cls).where(tuple_(cls.name, cls.operator_name).in_(line_keys))
            )
            # Several lines may share a name/operator pair; link all of them.
            lines: defaultdict[tuple[str, str], list[Self]] = defaultdict(list)
            for line in lines_res.scalars():
                lines[(line.name, line.operator_name)].append(line)

            stops_res = await session.execute(
                select(_Stop).where(_Stop.id.in_(stop_ids))
            )
            stops = {stop.id: stop for stop in stops_res.scalars()}

            found = 0
            for line_name, operator_name, stop_id in triples:
                if (stop := stops.get(stop_id)) is None:
                    print(
                        f"No stop found for {stop_id} id "
                        f"(used by {line_name}/{operator_name})"
                    )
                    continue

                if (stop_lines := lines.get((line_name, operator_name))) is None:
                    print(f"No line found for {line_name}/{operator_name}")
                    continue

                for stop_line in stop_lines:
                    # stops is already loaded (lazy="selectin"). Skip existing
                    # links so repeated triples do not queue duplicate
                    # association rows.
                    if stop not in stop_line.stops:
                        stop_line.stops.append(stop)
                    found += 1

            return found
|