Files
carrramba-encore-rate/backend/backend/models/line.py
2023-02-08 22:10:21 +01:00

190 lines
5.8 KiB
Python

from asyncio import gather as asyncio_gather
from collections import defaultdict
from typing import Iterable, Self, Sequence
from sqlalchemy import (
BigInteger,
Boolean,
Column,
Enum,
ForeignKey,
Integer,
select,
String,
Table,
)
from sqlalchemy.orm import Mapped, mapped_column, relationship, selectinload
from sqlalchemy.sql.expression import tuple_
from ..db import Base, db
from ..idfm_interface.idfm_types import (
IdfmState,
IdfmLineState,
TransportMode,
TransportSubMode,
)
from .stop import _Stop
# Many-to-many association table between the "lines" and "_stops" tables.
# It is registered on Base.metadata and is passed as `secondary=` to the
# Line.stops relationship in this module.
# NOTE(review): neither column is declared as a primary key and no unique
# constraint is defined here, so this schema does not reject duplicate
# (line_id, stop_id) pairs -- confirm this is intended.
line_stop_association_table = Table(
"line_stop_association_table",
Base.metadata,
Column("line_id", ForeignKey("lines.id")),
Column("stop_id", ForeignKey("_stops.id")),
)
class LinePicto(Base):
    """ORM model mapped to the ``line_pictos`` table.

    Each row stores descriptive fields for one picto: MIME type, pixel
    dimensions, file name, URL, thumbnail flag and format. Only ``id`` is the
    primary key; every other column is declared NOT NULL.
    """

    __tablename__ = "line_pictos"

    # Module-level database handle re-exposed as a class attribute.
    db = db

    # Primary key.
    id = mapped_column(String, primary_key=True)

    # Mandatory descriptive columns (declaration order is kept as-is).
    mime_type = mapped_column(String, nullable=False)
    height_px = mapped_column(Integer, nullable=False)
    width_px = mapped_column(Integer, nullable=False)
    filename = mapped_column(String, nullable=False)
    url = mapped_column(String, nullable=False)
    thumbnail = mapped_column(Boolean, nullable=False)
    format = mapped_column(String, nullable=False)
class Line(Base):
    """ORM model mapped to the ``lines`` table.

    A line optionally references one ``LinePicto`` (through ``picto_id``) and
    is linked to ``_Stop`` rows through ``line_stop_association_table``.
    Both relationships are configured with ``lazy="selectin"``.

    The classmethods below reach the database through ``cls.db.session`` and
    degrade gracefully (``None`` / ``False`` / ``0``) when that session is
    ``None``.
    """

    # Module-level database handle re-exposed as a class attribute so the
    # classmethods can use ``cls.db.session``.
    db = db

    __tablename__ = "lines"

    id = mapped_column(String, primary_key=True)

    short_name = mapped_column(String)
    name = mapped_column(String, nullable=False)

    status = mapped_column(Enum(IdfmLineState), nullable=False)

    transport_mode = mapped_column(Enum(TransportMode), nullable=False)
    transport_submode = mapped_column(Enum(TransportSubMode), nullable=False)

    network_name = mapped_column(String)
    group_of_lines_id = mapped_column(String)
    group_of_lines_shortname = mapped_column(String)

    colour_web_hexa = mapped_column(String, nullable=False)
    text_colour_hexa = mapped_column(String, nullable=False)

    operator_id = mapped_column(String)
    operator_name = mapped_column(String)

    accessibility = mapped_column(Enum(IdfmState), nullable=False)
    visual_signs_available = mapped_column(Enum(IdfmState), nullable=False)
    audible_signs_available = mapped_column(Enum(IdfmState), nullable=False)

    # Nullable foreign key: a line may exist without a picto.
    picto_id = mapped_column(String, ForeignKey("line_pictos.id"))
    picto: Mapped[LinePicto] = relationship(LinePicto, lazy="selectin")

    record_id = mapped_column(String, nullable=False)
    record_ts = mapped_column(BigInteger, nullable=False)

    # Many-to-many link to stops; the other side is ``_Stop.lines``.
    stops: Mapped[list[_Stop]] = relationship(
        "_Stop",
        secondary=line_stop_association_table,
        back_populates="lines",
        lazy="selectin",
    )

    @classmethod
    async def get_by_name(
        cls, name: str, operator_name: None | str = None
    ) -> Sequence[Self] | None:
        """Return every line whose ``name`` equals *name*.

        Args:
            name: Exact line name to match.
            operator_name: When not ``None``, additionally require an exact
                match on the ``operator_name`` column.

        Returns:
            The matching lines (possibly empty) with ``stops`` and ``picto``
            eagerly loaded, or ``None`` when no session is available.
        """
        session = cls.db.session
        if session is None:
            return None

        filters = {"name": name}
        if operator_name is not None:
            filters["operator_name"] = operator_name

        stmt = (
            select(cls)
            .filter_by(**filters)
            .options(selectinload(cls.stops), selectinload(cls.picto))
        )
        res = await session.execute(stmt)
        return res.scalars().all()

    @classmethod
    async def _add_picto_to_line(cls, line: str | Self, picto: LinePicto) -> None:
        """Attach *picto* to a line that does not have one yet.

        Args:
            line: Either a ``Line`` instance or a line name. A name is
                resolved with ``get_by_name``: a single match is used as-is;
                otherwise the first match whose ``operator_name`` is
                ``"RATP"`` is chosen. If nothing can be selected the call is
                a no-op.
            picto: Picto to assign.

        The change is only applied in memory; committing is left to the
        caller.
        """
        formatted_line: Self | None = None

        if isinstance(line, str):
            candidates = await cls.get_by_name(line)
            if candidates is not None:
                if len(candidates) == 1:
                    formatted_line = candidates[0]
                else:
                    # Ambiguous (or empty) result: prefer the RATP-operated line.
                    formatted_line = next(
                        (
                            candidate
                            for candidate in candidates
                            if candidate.operator_name == "RATP"
                        ),
                        None,
                    )
        else:
            formatted_line = line

        # Never overwrite a picto that is already set.
        if isinstance(formatted_line, Line) and formatted_line.picto is None:
            formatted_line.picto = picto
            formatted_line.picto_id = picto.id

    @classmethod
    async def add_pictos(cls, line_to_pictos: Iterable[tuple[str, LinePicto]]) -> bool:
        """Assign pictos to lines and commit the session.

        Args:
            line_to_pictos: Pairs of (line name or ``Line``, picto) forwarded
                one by one to ``_add_picto_to_line``.

        Returns:
            ``False`` when no session is available, ``True`` once the session
            has been committed.
        """
        session = cls.db.session
        if session is None:
            return False

        # Deliberately sequential: each helper call issues queries through
        # ``cls.db.session``, and SQLAlchemy documents AsyncSession as not
        # safe for use by several concurrent tasks. Scheduling the helpers
        # with ``asyncio.gather`` would interleave statements on one session.
        for line, picto in line_to_pictos:
            await cls._add_picto_to_line(line, picto)

        await session.commit()
        return True

    @classmethod
    async def add_stops(cls, line_to_stop_ids: Iterable[tuple[str, str, int]]) -> int:
        """Link stops to lines and commit the session.

        Args:
            line_to_stop_ids: Triples of (line name, operator name, stop id).
                Any iterable is accepted, including one-shot generators.

        Returns:
            The number of (line, stop) links appended, or ``0`` when no
            session is available. Triples whose stop or line cannot be found
            are reported with ``print`` and skipped.
        """
        session = cls.db.session
        if session is None:
            return 0

        # The input is walked twice (once to collect lookup keys, once to
        # build the links), so materialise it first: a generator would be
        # exhausted by the first pass and no link would ever be created.
        associations = list(line_to_stop_ids)

        line_names_ops: set[tuple[str, str]] = set()
        stop_ids: set[int] = set()
        for line_name, operator_name, stop_id in associations:
            line_names_ops.add((line_name, operator_name))
            stop_ids.add(stop_id)

        # Fetch every referenced line in one query, grouped by
        # (name, operator_name) since several lines may share that key.
        lines_res = await session.execute(
            select(Line).where(
                tuple_(Line.name, Line.operator_name).in_(line_names_ops)
            )
        )
        lines = defaultdict(list)
        for line in lines_res.scalars():
            lines[(line.name, line.operator_name)].append(line)

        # Fetch every referenced stop in one query, indexed by id.
        stops_res = await session.execute(select(_Stop).where(_Stop.id.in_(stop_ids)))
        stops = {stop.id: stop for stop in stops_res.scalars()}

        found = 0
        for line_name, operator_name, stop_id in associations:
            stop = stops.get(stop_id)
            if stop is None:
                print(
                    f"No stop found for {stop_id} id"
                    f"(used by {line_name}/{operator_name})"
                )
                continue

            # ``.get`` avoids inserting an empty list into the defaultdict.
            stop_lines = lines.get((line_name, operator_name))
            if stop_lines is None:
                print(f"No line found for {line_name}/{operator_name}")
                continue

            for stop_line in stop_lines:
                stop_line.stops.append(stop)
                found += 1

        await session.commit()
        return found