"""ORM models for lines, line pictograms and line/stop associations."""

from asyncio import gather as asyncio_gather
from collections import defaultdict
from typing import Iterable, Self, Sequence

from sqlalchemy import (
    BigInteger,
    Boolean,
    Enum,
    ForeignKey,
    Integer,
    select,
    String,
)
from sqlalchemy.orm import Mapped, mapped_column, relationship, selectinload
from sqlalchemy.sql.expression import tuple_

from ..db import Base, db
from ..idfm_interface.idfm_types import (
    IdfmState,
    IdfmLineState,
    TransportMode,
    TransportSubMode,
)
from .stop import _Stop


class LineStopAssociations(Base):
    """Row of the ``line_stop_associations`` table.

    Used as the ``secondary`` table of the ``Line.stops`` relationship: each
    row references one ``lines.id`` and one ``_stops.id``.
    """

    id = mapped_column(BigInteger, primary_key=True)
    line_id = mapped_column(BigInteger, ForeignKey("lines.id"))
    stop_id = mapped_column(BigInteger, ForeignKey("_stops.id"))

    __tablename__ = "line_stop_associations"


class LinePicto(Base):
    """Row of the ``line_pictos`` table describing a picture tied to a line."""

    # Database handle shared with the other models of this module.
    db = db

    id = mapped_column(String, primary_key=True)
    mime_type = mapped_column(String, nullable=False)
    height_px = mapped_column(Integer, nullable=False)
    width_px = mapped_column(Integer, nullable=False)
    filename = mapped_column(String, nullable=False)
    url = mapped_column(String, nullable=False)
    thumbnail = mapped_column(Boolean, nullable=False)
    format = mapped_column(String, nullable=False)

    __tablename__ = "line_pictos"


class Line(Base):
    """Row of the ``lines`` table, with its pictogram and its stops."""

    # Database handle used by the class-level helpers below.
    db = db

    id = mapped_column(BigInteger, primary_key=True)

    short_name = mapped_column(String)
    name = mapped_column(String, nullable=False)
    status = mapped_column(Enum(IdfmLineState), nullable=False)
    transport_mode = mapped_column(Enum(TransportMode), nullable=False)
    transport_submode = mapped_column(Enum(TransportSubMode), nullable=False)

    network_name = mapped_column(String)
    group_of_lines_id = mapped_column(String)
    group_of_lines_shortname = mapped_column(String)

    colour_web_hexa = mapped_column(String, nullable=False)
    text_colour_hexa = mapped_column(String, nullable=False)

    operator_id = mapped_column(Integer)
    operator_name = mapped_column(String)

    accessibility = mapped_column(Enum(IdfmState), nullable=False)
    visual_signs_available = mapped_column(Enum(IdfmState), nullable=False)
    audible_signs_available = mapped_column(Enum(IdfmState), nullable=False)

    # Optional pictogram, eagerly loaded with a SELECT IN query.
    picto_id = mapped_column(String, ForeignKey("line_pictos.id"))
    picto: Mapped[LinePicto] = relationship(LinePicto, lazy="selectin")

    record_id = mapped_column(String, nullable=False)
    record_ts = mapped_column(BigInteger, nullable=False)

    # Many-to-many link to stops through ``line_stop_associations``.
    stops: Mapped[list[_Stop]] = relationship(
        "_Stop",
        secondary="line_stop_associations",
        back_populates="lines",
        lazy="selectin",
    )

    __tablename__ = "lines"

    @classmethod
    async def get_by_name(
        cls, name: str, operator_name: None | str = None
    ) -> Sequence[Self] | None:
        """Fetch the lines matching ``name`` and, optionally, an operator.

        Args:
            name: Exact value of the ``name`` column to match.
            operator_name: When not ``None``, also require this exact value
                in the ``operator_name`` column.

        Returns:
            The matching lines (possibly empty), or ``None`` when no session
            could be obtained from ``cls.db``.
        """
        if (session := await cls.db.get_session()) is None:
            return None

        filters = {"name": name}
        if operator_name is not None:
            filters["operator_name"] = operator_name

        async with session.begin():
            stmt = (
                select(cls)
                .filter_by(**filters)
                .options(selectinload(cls.stops), selectinload(cls.picto))
            )
            res = await session.execute(stmt)
            lines = res.scalars().all()

        return lines

    @classmethod
    async def _add_picto_to_line(cls, line: str | Self, picto: LinePicto) -> None:
        """Attach ``picto`` to ``line`` unless the line already has one.

        Args:
            line: Either a line instance or a line name. A name is resolved
                with ``get_by_name``; when it does not match exactly one
                line, the first match whose operator is ``"RATP"`` is used.
            picto: Pictogram to attach.
        """
        formatted_line: Self | None = None

        if isinstance(line, str):
            if (lines := await cls.get_by_name(line)) is not None:
                if len(lines) == 1:
                    formatted_line = lines[0]
                else:
                    # Ambiguous (or empty) result: settle on the RATP line,
                    # leaving ``formatted_line`` unset when there is none.
                    for candidate_line in lines:
                        if candidate_line.operator_name == "RATP":
                            formatted_line = candidate_line
                            break
        else:
            formatted_line = line

        # Never overwrite a pictogram that is already set.
        if isinstance(formatted_line, Line) and formatted_line.picto is None:
            formatted_line.picto = picto
            formatted_line.picto_id = picto.id

    @classmethod
    async def add_pictos(cls, line_to_pictos: Iterable[tuple[str, LinePicto]]) -> bool:
        """Attach pictograms to lines inside a single transaction.

        Args:
            line_to_pictos: ``(line, picto)`` pairs forwarded to
                ``_add_picto_to_line``.

        Returns:
            ``True`` once every pair has been processed, ``False`` when no
            session could be obtained from ``cls.db``.
        """
        if (session := await cls.db.get_session()) is None:
            return False

        async with session.begin():
            # NOTE(review): the per-line lookups run concurrently and each of
            # them asks ``cls.db`` for a session; confirm ``get_session`` is
            # safe to use that way.
            await asyncio_gather(
                *[
                    cls._add_picto_to_line(line, picto)
                    for line, picto in line_to_pictos
                ]
            )

        return True

    @classmethod
    async def add_stops(cls, line_to_stop_ids: Iterable[tuple[str, str, int]]) -> int:
        """Append stops to the lines they belong to.

        Args:
            line_to_stop_ids: ``(line_name, operator_name, stop_id)`` triples.
                Any iterable is accepted, including one-shot iterators.

        Returns:
            The number of stop/line links appended, or ``0`` when no session
            could be obtained from ``cls.db``. Triples referring to an
            unknown stop or an unknown line are reported on stdout and
            skipped.
        """
        if (session := await cls.db.get_session()) is None:
            return 0

        # The triples are walked twice (key collection, then linking), so a
        # one-shot iterable must be materialised first; otherwise the second
        # pass would silently see nothing.
        associations = list(line_to_stop_ids)

        async with session.begin():
            line_names_ops, stop_ids = set(), set()
            for line_name, operator_name, stop_id in associations:
                line_names_ops.add((line_name, operator_name))
                stop_ids.add(stop_id)

            # Load every referenced line with one query, grouped by
            # (name, operator) since several rows may share that pair.
            lines_res = await session.execute(
                select(Line).where(
                    tuple_(Line.name, Line.operator_name).in_(line_names_ops)
                )
            )
            lines = defaultdict(list)
            for line in lines_res.scalars():
                lines[(line.name, line.operator_name)].append(line)

            # Load every referenced stop with one query.
            stops_res = await session.execute(
                select(_Stop).where(_Stop.id.in_(stop_ids))
            )
            stops = {stop.id: stop for stop in stops_res.scalars()}

            found = 0
            for line_name, operator_name, stop_id in associations:
                if (stop := stops.get(stop_id)) is not None:
                    # ``.get`` avoids creating empty entries in the
                    # defaultdict for unknown lines.
                    if (
                        stop_lines := lines.get((line_name, operator_name))
                    ) is not None:
                        for stop_line in stop_lines:
                            stop_line.stops.append(stop)
                            found += 1
                    else:
                        print(f"No line found for {line_name}/{operator_name}")
                else:
                    print(
                        f"No stop found for {stop_id} id"
                        f"(used by {line_name}/{operator_name})"
                    )

        return found