🎨 Reorganize back-end code

This commit is contained in:
2023-09-20 22:08:32 +02:00
parent bdbc72ab39
commit 3434802b31
28 changed files with 29 additions and 36 deletions

View File

@@ -0,0 +1,14 @@
from .line import Line, LinePicto
from .stop import ConnectionArea, Stop, StopArea, StopShape
from .user import UserLastStopSearchResults
__all__ = [
"ConnectionArea",
"Line",
"LinePicto",
"Stop",
"StopArea",
"StopShape",
"UserLastStopSearchResults",
]

196
backend/api/models/line.py Normal file
View File

@@ -0,0 +1,196 @@
from asyncio import gather as asyncio_gather
from collections import defaultdict
from typing import Iterable, Self, Sequence
from sqlalchemy import (
BigInteger,
Boolean,
Enum,
ForeignKey,
Integer,
select,
String,
)
from sqlalchemy.orm import Mapped, mapped_column, relationship, selectinload
from sqlalchemy.sql.expression import tuple_
from db import Base, db
from idfm_interface.idfm_types import (
IdfmState,
IdfmLineState,
TransportMode,
TransportSubMode,
)
from .stop import _Stop
class LineStopAssociations(Base):
id = mapped_column(BigInteger, primary_key=True)
line_id = mapped_column(BigInteger, ForeignKey("lines.id"))
stop_id = mapped_column(BigInteger, ForeignKey("_stops.id"))
__tablename__ = "line_stop_associations"
class LinePicto(Base):
db = db
id = mapped_column(String, primary_key=True)
mime_type = mapped_column(String, nullable=False)
height_px = mapped_column(Integer, nullable=False)
width_px = mapped_column(Integer, nullable=False)
filename = mapped_column(String, nullable=False)
url = mapped_column(String, nullable=False)
thumbnail = mapped_column(Boolean, nullable=False)
format = mapped_column(String, nullable=False)
__tablename__ = "line_pictos"
class Line(Base):
db = db
id = mapped_column(BigInteger, primary_key=True)
short_name = mapped_column(String)
name = mapped_column(String, nullable=False)
status = mapped_column(Enum(IdfmLineState), nullable=False)
transport_mode = mapped_column(Enum(TransportMode), nullable=False)
transport_submode = mapped_column(Enum(TransportSubMode), nullable=False)
network_name = mapped_column(String)
group_of_lines_id = mapped_column(String)
group_of_lines_shortname = mapped_column(String)
colour_web_hexa = mapped_column(String, nullable=False)
text_colour_hexa = mapped_column(String, nullable=False)
operator_id = mapped_column(Integer)
operator_name = mapped_column(String)
accessibility = mapped_column(Enum(IdfmState), nullable=False)
visual_signs_available = mapped_column(Enum(IdfmState), nullable=False)
audible_signs_available = mapped_column(Enum(IdfmState), nullable=False)
picto_id = mapped_column(String, ForeignKey("line_pictos.id"))
picto: Mapped[LinePicto] = relationship(LinePicto, lazy="selectin")
record_id = mapped_column(String, nullable=False)
record_ts = mapped_column(BigInteger, nullable=False)
stops: Mapped[list[_Stop]] = relationship(
"_Stop",
secondary="line_stop_associations",
back_populates="lines",
lazy="selectin",
)
__tablename__ = "lines"
@classmethod
async def get_by_name(
cls, name: str, operator_name: None | str = None
) -> Sequence[Self] | None:
if (session := await cls.db.get_session()) is not None:
async with session.begin():
filters = {"name": name}
if operator_name is not None:
filters["operator_name"] = operator_name
stmt = (
select(cls)
.filter_by(**filters)
.options(selectinload(cls.stops), selectinload(cls.picto))
)
res = await session.execute(stmt)
lines = res.scalars().all()
return lines
return None
@classmethod
async def _add_picto_to_line(cls, line: str | Self, picto: LinePicto) -> None:
formatted_line: Self | None = None
if isinstance(line, str):
if (lines := await cls.get_by_name(line)) is not None:
if len(lines) == 1:
formatted_line = lines[0]
else:
for candidate_line in lines:
if candidate_line.operator_name == "RATP":
formatted_line = candidate_line
break
else:
formatted_line = line
if isinstance(formatted_line, Line) and formatted_line.picto is None:
formatted_line.picto = picto
formatted_line.picto_id = picto.id
@classmethod
async def add_pictos(cls, line_to_pictos: Iterable[tuple[str, LinePicto]]) -> bool:
if (session := await cls.db.get_session()) is not None:
async with session.begin():
await asyncio_gather(
*[
cls._add_picto_to_line(line, picto)
for line, picto in line_to_pictos
]
)
return True
return False
@classmethod
async def add_stops(cls, line_to_stop_ids: Iterable[tuple[str, str, int]]) -> int:
if (session := await cls.db.get_session()) is not None:
async with session.begin():
line_names_ops, stop_ids = set(), set()
for line_name, operator_name, stop_id in line_to_stop_ids:
line_names_ops.add((line_name, operator_name))
stop_ids.add(stop_id)
lines_res = await session.execute(
select(Line).where(
tuple_(Line.name, Line.operator_name).in_(line_names_ops)
)
)
lines = defaultdict(list)
for line in lines_res.scalars():
lines[(line.name, line.operator_name)].append(line)
stops_res = await session.execute(
select(_Stop).where(_Stop.id.in_(stop_ids))
)
stops = {stop.id: stop for stop in stops_res.scalars()}
found = 0
for line_name, operator_name, stop_id in line_to_stop_ids:
if (stop := stops.get(stop_id)) is not None:
if (
stop_lines := lines.get((line_name, operator_name))
) is not None:
for stop_line in stop_lines:
stop_line.stops.append(stop)
found += 1
else:
print(f"No line found for {line_name}/{operator_name}")
else:
print(
f"No stop found for {stop_id} id"
f"(used by {line_name}/{operator_name})"
)
return found
return 0

275
backend/api/models/stop.py Normal file
View File

@@ -0,0 +1,275 @@
from __future__ import annotations
from logging import getLogger
from typing import Iterable, Sequence, TYPE_CHECKING
from sqlalchemy import (
BigInteger,
Computed,
desc,
Enum,
Float,
ForeignKey,
func,
Integer,
JSON,
select,
String,
)
from sqlalchemy.orm import (
mapped_column,
Mapped,
relationship,
selectinload,
with_polymorphic,
)
from sqlalchemy.schema import Index
from sqlalchemy_utils.types.ts_vector import TSVectorType
from db import Base, db
from idfm_interface.idfm_types import TransportMode, IdfmState, StopAreaType
if TYPE_CHECKING:
from .line import Line
logger = getLogger(__name__)
class StopAreaStopAssociations(Base):
id = mapped_column(BigInteger, primary_key=True)
stop_id = mapped_column(BigInteger, ForeignKey("_stops.id"))
stop_area_id = mapped_column(BigInteger, ForeignKey("stop_areas.id"))
__tablename__ = "stop_area_stop_associations"
class _Stop(Base):
db = db
id = mapped_column(BigInteger, primary_key=True)
kind = mapped_column(String)
name = mapped_column(String, nullable=False, index=True)
town_name = mapped_column(String, nullable=False)
postal_region = mapped_column(Integer, nullable=False)
epsg3857_x = mapped_column(Float, nullable=False)
epsg3857_y = mapped_column(Float, nullable=False)
version = mapped_column(String, nullable=False)
created_ts = mapped_column(BigInteger)
changed_ts = mapped_column(BigInteger, nullable=False)
lines: Mapped[list[Line]] = relationship(
"Line",
secondary="line_stop_associations",
back_populates="stops",
lazy="selectin",
)
areas: Mapped[list["StopArea"]] = relationship(
"StopArea",
secondary="stop_area_stop_associations",
back_populates="stops",
)
connection_area_id: Mapped[int] = mapped_column(
ForeignKey("connection_areas.id"), nullable=True
)
connection_area: Mapped["ConnectionArea"] = relationship(
back_populates="stops", lazy="selectin"
)
names_tsv = mapped_column(
TSVectorType("name", "town_name", regconfig="french"),
Computed("to_tsvector('french', name || ' ' || town_name)", persisted=True),
)
__tablename__ = "_stops"
__mapper_args__ = {"polymorphic_identity": "_stops", "polymorphic_on": kind}
__table_args__ = (
Index(
"names_tsv_idx",
names_tsv,
postgresql_ops={"name": "gin_trgm_ops"},
postgresql_using="gin",
),
)
@classmethod
async def get_by_name(cls, name: str) -> Sequence[_Stop] | None:
if (session := await cls.db.get_session()) is not None:
async with session.begin():
descendants = with_polymorphic(_Stop, "*")
match_stmt = descendants.names_tsv.match(
name, postgresql_regconfig="french"
)
ranking_stmt = func.ts_rank_cd(
descendants.names_tsv, func.plainto_tsquery("french", name)
)
stmt = (
select(descendants).filter(match_stmt).order_by(desc(ranking_stmt))
)
res = await session.execute(stmt)
stops = res.scalars().all()
return stops
return None
class Stop(_Stop):
id = mapped_column(BigInteger, ForeignKey("_stops.id"), primary_key=True)
transport_mode = mapped_column(Enum(TransportMode), nullable=False)
accessibility = mapped_column(Enum(IdfmState), nullable=False)
visual_signs_available = mapped_column(Enum(IdfmState), nullable=False)
audible_signs_available = mapped_column(Enum(IdfmState), nullable=False)
record_id = mapped_column(String, nullable=False)
record_ts = mapped_column(BigInteger, nullable=False)
__tablename__ = "stops"
__mapper_args__ = {"polymorphic_identity": "stops", "polymorphic_load": "inline"}
class StopArea(_Stop):
id = mapped_column(BigInteger, ForeignKey("_stops.id"), primary_key=True)
type = mapped_column(Enum(StopAreaType), nullable=False)
stops: Mapped[list["Stop"]] = relationship(
"Stop",
secondary="stop_area_stop_associations",
back_populates="areas",
lazy="selectin",
)
__tablename__ = "stop_areas"
__mapper_args__ = {
"polymorphic_identity": "stop_areas",
"polymorphic_load": "inline",
}
@classmethod
async def add_stops(
cls, stop_area_to_stop_ids: Iterable[tuple[int, int]]
) -> int | None:
if (session := await cls.db.get_session()) is not None:
async with session.begin():
stop_area_ids, stop_ids = set(), set()
for stop_area_id, stop_id in stop_area_to_stop_ids:
stop_area_ids.add(stop_area_id)
stop_ids.add(stop_id)
stop_areas_res = await session.scalars(
select(StopArea)
.where(StopArea.id.in_(stop_area_ids))
.options(selectinload(StopArea.stops))
)
stop_areas: dict[int, StopArea] = {
stop_area.id: stop_area for stop_area in stop_areas_res.all()
}
stop_res = await session.execute(
select(Stop).where(Stop.id.in_(stop_ids))
)
stops: dict[int, Stop] = {stop.id: stop for stop in stop_res.scalars()}
found = 0
for stop_area_id, stop_id in stop_area_to_stop_ids:
if (stop_area := stop_areas.get(stop_area_id)) is not None:
if (stop := stops.get(stop_id)) is not None:
stop_area.stops.append(stop)
found += 1
else:
print(f"No stop found for {stop_id} id")
else:
print(f"No stop area found for {stop_area_id}")
return found
return None
class StopShape(Base):
db = db
id = mapped_column(BigInteger, primary_key=True) # Same id than ConnectionArea
type = mapped_column(Integer, nullable=False)
epsg3857_bbox = mapped_column(JSON)
epsg3857_points = mapped_column(JSON)
__tablename__ = "stop_shapes"
class ConnectionArea(Base):
db = db
id = mapped_column(BigInteger, primary_key=True)
name = mapped_column(String, nullable=False)
town_name = mapped_column(String, nullable=False)
postal_region = mapped_column(String, nullable=False)
epsg3857_x = mapped_column(Float, nullable=False)
epsg3857_y = mapped_column(Float, nullable=False)
transport_mode = mapped_column(Enum(StopAreaType), nullable=False)
version = mapped_column(String, nullable=False)
created_ts = mapped_column(BigInteger)
changed_ts = mapped_column(BigInteger, nullable=False)
stops: Mapped[list["_Stop"]] = relationship(back_populates="connection_area")
__tablename__ = "connection_areas"
# TODO: Merge with StopArea.add_stops
@classmethod
async def add_stops(
cls, conn_area_to_stop_ids: Iterable[tuple[int, int]]
) -> int | None:
if (session := await cls.db.get_session()) is not None:
async with session.begin():
conn_area_ids, stop_ids = set(), set()
for conn_area_id, stop_id in conn_area_to_stop_ids:
conn_area_ids.add(conn_area_id)
stop_ids.add(stop_id)
conn_area_res = await session.execute(
select(ConnectionArea)
.where(ConnectionArea.id.in_(conn_area_ids))
.options(selectinload(ConnectionArea.stops))
)
conn_areas: dict[int, ConnectionArea] = {
conn.id: conn for conn in conn_area_res.scalars()
}
stop_res = await session.execute(
select(Stop).where(Stop.id.in_(stop_ids))
)
stops: dict[int, Stop] = {stop.id: stop for stop in stop_res.scalars()}
found = 0
for conn_area_id, stop_id in conn_area_to_stop_ids:
if (conn_area := conn_areas.get(conn_area_id)) is not None:
if (stop := stops.get(stop_id)) is not None:
conn_area.stops.append(stop)
found += 1
else:
print(f"No stop found for {stop_id} id")
else:
print(f"No connection area found for {conn_area_id}")
return found
return None

View File

@@ -0,0 +1,27 @@
from sqlalchemy import BigInteger, ForeignKey, String
from sqlalchemy.orm import Mapped, mapped_column, relationship
from db import Base, db
from .stop import _Stop
class UserLastStopSearchStopAssociations(Base):
id = mapped_column(BigInteger, primary_key=True)
user_mxid = mapped_column(
String, ForeignKey("user_last_stop_search_results.user_mxid")
)
stop_id = mapped_column(BigInteger, ForeignKey("_stops.id"))
__tablename__ = "user_last_stop_search_stop_associations"
class UserLastStopSearchResults(Base):
db = db
user_mxid = mapped_column(String, primary_key=True)
request_content = mapped_column(String, nullable=False)
stops: Mapped[_Stop] = relationship(
_Stop, secondary="user_last_stop_search_stop_associations"
)
__tablename__ = "user_last_stop_search_results"