diff --git a/backend/backend/db/base_class.py b/backend/backend/db/base_class.py index 4025d82..71b9703 100644 --- a/backend/backend/db/base_class.py +++ b/backend/backend/db/base_class.py @@ -1,5 +1,6 @@ from __future__ import annotations +from logging import getLogger from typing import Iterable, Self, TYPE_CHECKING from sqlalchemy import select @@ -9,31 +10,36 @@ from sqlalchemy.orm import DeclarativeBase if TYPE_CHECKING: from .db import Database +logger = getLogger(__name__) + class Base(DeclarativeBase): db: Database | None = None @classmethod - async def add(cls, stops: Self | Iterable[Self]) -> bool: - try: - if isinstance(stops, Iterable): - cls.db.session.add_all(stops) # type: ignore - else: - cls.db.session.add(stops) # type: ignore - await cls.db.session.commit() # type: ignore - except (AttributeError, IntegrityError) as err: - print(err) - return False + async def add(cls, objs: Self | Iterable[Self]) -> bool: + if cls.db is not None and (session := await cls.db.get_session()) is not None: + + async with session.begin(): + try: + if isinstance(objs, Iterable): + session.add_all(objs) + else: + session.add(objs) + + except (AttributeError, IntegrityError) as err: + logger.error(err) + return False return True @classmethod async def get_by_id(cls, id_: int | str) -> Self | None: - try: - stmt = select(cls).where(cls.id == id_) # type: ignore - res = await cls.db.session.execute(stmt) # type: ignore - element = res.scalar_one_or_none() - except AttributeError as err: - print(err) - element = None - return element + if cls.db is not None and (session := await cls.db.get_session()) is not None: + + async with session.begin(): + stmt = select(cls).where(cls.id == id_) + res = await session.execute(stmt) + return res.scalar_one_or_none() + + return None diff --git a/backend/backend/db/db.py b/backend/backend/db/db.py index 1d8a0f9..fdf02ce 100644 --- a/backend/backend/db/db.py +++ b/backend/backend/db/db.py @@ -1,5 +1,10 @@ +from logging import getLogger +from typing import Annotated, AsyncIterator + +from fastapi import Depends from opentelemetry.instrumentation.sqlalchemy import SQLAlchemyInstrumentor from sqlalchemy import text +from sqlalchemy.exc import SQLAlchemyError from sqlalchemy.ext.asyncio import ( async_sessionmaker, AsyncEngine, @@ -10,42 +15,47 @@ from sqlalchemy.ext.asyncio import ( from .base_class import Base +logger = getLogger(__name__) + + class Database: def __init__(self) -> None: - self._engine: AsyncEngine | None = None - self._session_maker: async_sessionmaker[AsyncSession] | None = None - self._session: AsyncSession | None = None + self._async_engine: AsyncEngine | None = None + self._async_session_local: async_sessionmaker[AsyncSession] | None = None - @property - def session(self) -> AsyncSession | None: - if self._session is None and (session_maker := self._session_maker) is not None: - self._session = session_maker() - return self._session + async def get_session(self) -> AsyncSession | None: + try: + return self._async_session_local() # type: ignore + except (SQLAlchemyError, AttributeError) as e: + logger.exception(e) + + return None + + # TODO: Preserve UserLastStopSearchResults table from drop. async def connect(self, db_path: str, clear_static_data: bool = False) -> bool: + self._async_engine = create_async_engine( + db_path, pool_pre_ping=True, pool_size=10, max_overflow=20 + ) - # TODO: Preserve UserLastStopSearchResults table from drop. - self._engine = create_async_engine(db_path) - if self._engine is not None: - SQLAlchemyInstrumentor().instrument(engine=self._engine.sync_engine) + if self._async_engine is not None: + SQLAlchemyInstrumentor().instrument(engine=self._async_engine.sync_engine) - self._session_maker = async_sessionmaker( - self._engine, expire_on_commit=False, class_=AsyncSession + self._async_session_local = async_sessionmaker( + bind=self._async_engine, + # autoflush=False, + expire_on_commit=False, + class_=AsyncSession, ) - if (session := self.session) is not None: - await session.execute(text("CREATE EXTENSION IF NOT EXISTS pg_trgm;")) - async with self._engine.begin() as conn: + async with self._async_engine.begin() as session: + await session.execute(text("CREATE EXTENSION IF NOT EXISTS pg_trgm;")) if clear_static_data: - await conn.run_sync(Base.metadata.drop_all) - await conn.run_sync(Base.metadata.create_all) + await session.run_sync(Base.metadata.drop_all) + await session.run_sync(Base.metadata.create_all) return True async def disconnect(self) -> None: - if self._session is not None: - await self._session.close() - self._session = None - - if self._engine is not None: - await self._engine.dispose() + if self._async_engine is not None: + await self._async_engine.dispose() diff --git a/backend/backend/idfm_interface/idfm_interface.py b/backend/backend/idfm_interface/idfm_interface.py index 2c04e99..20ef89d 100644 --- a/backend/backend/idfm_interface/idfm_interface.py +++ b/backend/backend/idfm_interface/idfm_interface.py @@ -15,7 +15,7 @@ from aiohttp import ClientSession from msgspec import ValidationError from msgspec.json import Decoder from pyproj import Transformer -from shapefile import Reader as ShapeFileReader, ShapeRecord +from shapefile import Reader as ShapeFileReader, ShapeRecord # type: ignore from ..db import Database from ..models import ConnectionArea, Line, LinePicto, Stop, StopArea, StopShape @@ -357,7 +357,6 @@ class IdfmInterface: fields = line.fields picto_id = fields.picto.id_ if fields.picto is not None else None - picto = await LinePicto.get_by_id(picto_id) if picto_id else None ret.append( Line( @@ -384,7 +383,6 @@ class IdfmInterface: fields.audiblesigns_available.value ), picto_id=fields.picto.id_ if fields.picto is not None else None, - picto=picto, record_id=line.recordid, record_ts=int(line.record_timestamp.timestamp()), ) diff --git a/backend/backend/models/line.py b/backend/backend/models/line.py index d3108cd..b2fc4bf 100644 --- a/backend/backend/models/line.py +++ b/backend/backend/models/line.py @@ -94,23 +94,24 @@ class Line(Base): async def get_by_name( cls, name: str, operator_name: None | str = None ) -> Sequence[Self] | None: - session = cls.db.session - if session is None: - return None + if (session := await cls.db.get_session()) is not None: - filters = {"name": name} - if operator_name is not None: - filters["operator_name"] = operator_name + async with session.begin(): + filters = {"name": name} + if operator_name is not None: + filters["operator_name"] = operator_name - stmt = ( - select(cls) - .filter_by(**filters) - .options(selectinload(cls.stops), selectinload(cls.picto)) - ) - res = await session.execute(stmt) - lines = res.scalars().all() + stmt = ( + select(cls) + .filter_by(**filters) + .options(selectinload(cls.stops), selectinload(cls.picto)) + ) + res = await session.execute(stmt) + lines = res.scalars().all() - return lines + return lines + + return None @classmethod async def _add_picto_to_line(cls, line: str | Self, picto: LinePicto) -> None: @@ -133,57 +134,63 @@ class Line(Base): @classmethod async def add_pictos(cls, line_to_pictos: Iterable[tuple[str, LinePicto]]) -> bool: - session = cls.db.session - if session is None: - return False + if (session := await cls.db.get_session()) is not None: - await asyncio_gather( - *[cls._add_picto_to_line(line, picto) for line, picto in line_to_pictos] - ) + async with session.begin(): + await asyncio_gather( + *[ + cls._add_picto_to_line(line, picto) + for line, picto in line_to_pictos + ] + ) - await session.commit() + return True - return True + return False @classmethod async def add_stops(cls, line_to_stop_ids: Iterable[tuple[str, str, int]]) -> int: - session = cls.db.session - if session is None: - return 0 + if (session := await cls.db.get_session()) is not None: - line_names_ops, stop_ids = set(), set() - for line_name, operator_name, stop_id in line_to_stop_ids: - line_names_ops.add((line_name, operator_name)) - stop_ids.add(stop_id) + async with session.begin(): - lines_res = await session.execute( - select(Line).where( - tuple_(Line.name, Line.operator_name).in_(line_names_ops) - ) - ) + line_names_ops, stop_ids = set(), set() + for line_name, operator_name, stop_id in line_to_stop_ids: + line_names_ops.add((line_name, operator_name)) + stop_ids.add(stop_id) - lines = defaultdict(list) - for line in lines_res.scalars(): - lines[(line.name, line.operator_name)].append(line) - - stops_res = await session.execute(select(_Stop).where(_Stop.id.in_(stop_ids))) - stops = {stop.id: stop for stop in stops_res.scalars()} - - found = 0 - for line_name, operator_name, stop_id in line_to_stop_ids: - if (stop := stops.get(stop_id)) is not None: - if (stop_lines := lines.get((line_name, operator_name))) is not None: - for stop_line in stop_lines: - stop_line.stops.append(stop) - found += 1 - else: - print(f"No line found for {line_name}/{operator_name}") - else: - print( - f"No stop found for {stop_id} id" - f"(used by {line_name}/{operator_name})" + lines_res = await session.execute( + select(Line).where( + tuple_(Line.name, Line.operator_name).in_(line_names_ops) + ) ) - await session.commit() + lines = defaultdict(list) + for line in lines_res.scalars(): + lines[(line.name, line.operator_name)].append(line) - return found + stops_res = await session.execute( + select(_Stop).where(_Stop.id.in_(stop_ids)) + ) + stops = {stop.id: stop for stop in stops_res.scalars()} + + found = 0 + for line_name, operator_name, stop_id in line_to_stop_ids: + if (stop := stops.get(stop_id)) is not None: + if ( + stop_lines := lines.get((line_name, operator_name)) + ) is not None: + for stop_line in stop_lines: + stop_line.stops.append(stop) + found += 1 + else: + print(f"No line found for {line_name}/{operator_name}") + else: + print( + f"No stop found for {stop_id} id" + f"(used by {line_name}/{operator_name})" + ) + + return found + + return 0 diff --git a/backend/backend/models/stop.py b/backend/backend/models/stop.py index d70fe9e..49c0f30 100644 --- a/backend/backend/models/stop.py +++ b/backend/backend/models/stop.py @@ -1,6 +1,7 @@ from __future__ import annotations -from typing import Iterable, Sequence, TYPE_CHECKING +from logging import getLogger +from typing import Annotated, Iterable, Sequence, TYPE_CHECKING from sqlalchemy import ( BigInteger, @@ -22,7 +23,6 @@ from sqlalchemy.orm import ( Mapped, relationship, selectinload, - with_polymorphic, ) from sqlalchemy.schema import Index from sqlalchemy_utils.types.ts_vector import TSVectorType @@ -34,6 +34,8 @@ if TYPE_CHECKING: from .line import Line +logger = getLogger(__name__) + stop_area_stop_association_table = Table( "stop_area_stop_association_table", Base.metadata, @@ -91,34 +93,23 @@ class _Stop(Base): ), ) - # TODO: Test https://www.cybertec-postgresql.com/en/postgresql-more-performance-for-like-and-ilike-statements/ - # TODO: Should be able to remove with_polymorphic ? @classmethod - async def get_by_name(cls, name: str) -> Sequence[type[_Stop]] | None: - session = cls.db.session - if session is None: - return None + async def get_by_name(cls, name: str) -> Sequence[_Stop] | None: + if (session := await cls.db.get_session()) is not None: - stop_stop_area = with_polymorphic(_Stop, [Stop, StopArea]) - match_stmt = stop_stop_area.names_tsv.match(name, postgresql_regconfig="french") - ranking_stmt = func.ts_rank_cd( - stop_stop_area.names_tsv, func.plainto_tsquery("french", name) - ) + async with session.begin(): + match_stmt = cls.names_tsv.match(name, postgresql_regconfig="french") + ranking_stmt = func.ts_rank_cd( + cls.names_tsv, func.plainto_tsquery("french", name) + ) + stmt = select(cls).filter(match_stmt).order_by(desc(ranking_stmt)) - stmt = ( - select(stop_stop_area) - .filter(match_stmt) - .order_by(desc(ranking_stmt)) - .options( - selectinload(stop_stop_area.areas), - selectinload(stop_stop_area.lines), - ) - ) + res = await session.execute(stmt) + stops = res.scalars().all() - res = await session.execute(stmt) - stops = res.scalars().all() + return stops - return stops + return None class Stop(_Stop): @@ -160,41 +151,43 @@ class StopArea(_Stop): async def add_stops( cls, stop_area_to_stop_ids: Iterable[tuple[int, int]] ) -> int | None: - session = cls.db.session - if session is None: - return None + if (session := await cls.db.get_session()) is not None: - stop_area_ids, stop_ids = set(), set() - for stop_area_id, stop_id in stop_area_to_stop_ids: - stop_area_ids.add(stop_area_id) - stop_ids.add(stop_id) + async with session.begin(): - stop_areas_res = await session.scalars( - select(StopArea) - .where(StopArea.id.in_(stop_area_ids)) - .options(selectinload(StopArea.stops)) - ) - stop_areas: dict[int, StopArea] = { - stop_area.id: stop_area for stop_area in stop_areas_res.all() - } + stop_area_ids, stop_ids = set(), set() + for stop_area_id, stop_id in stop_area_to_stop_ids: + stop_area_ids.add(stop_area_id) + stop_ids.add(stop_id) - stop_res = await session.execute(select(Stop).where(Stop.id.in_(stop_ids))) - stops: dict[int, Stop] = {stop.id: stop for stop in stop_res.scalars()} + stop_areas_res = await session.scalars( + select(StopArea) + .where(StopArea.id.in_(stop_area_ids)) + .options(selectinload(StopArea.stops)) + ) + stop_areas: dict[int, StopArea] = { + stop_area.id: stop_area for stop_area in stop_areas_res.all() + } - found = 0 - for stop_area_id, stop_id in stop_area_to_stop_ids: - if (stop_area := stop_areas.get(stop_area_id)) is not None: - if (stop := stops.get(stop_id)) is not None: - stop_area.stops.append(stop) - found += 1 - else: - print(f"No stop found for {stop_id} id") - else: - print(f"No stop area found for {stop_area_id}") + stop_res = await session.execute( + select(Stop).where(Stop.id.in_(stop_ids)) + ) + stops: dict[int, Stop] = {stop.id: stop for stop in stop_res.scalars()} - await session.commit() + found = 0 + for stop_area_id, stop_id in stop_area_to_stop_ids: + if (stop_area := stop_areas.get(stop_area_id)) is not None: + if (stop := stops.get(stop_id)) is not None: + stop_area.stops.append(stop) + found += 1 + else: + print(f"No stop found for {stop_id} id") + else: + print(f"No stop area found for {stop_area_id}") - return found + return found + + return None class StopShape(Base): @@ -235,38 +228,40 @@ class ConnectionArea(Base): async def add_stops( cls, conn_area_to_stop_ids: Iterable[tuple[int, int]] ) -> int | None: - session = cls.db.session - if session is None: - return None + if (session := await cls.db.get_session()) is not None: - conn_area_ids, stop_ids = set(), set() - for conn_area_id, stop_id in conn_area_to_stop_ids: - conn_area_ids.add(conn_area_id) - stop_ids.add(stop_id) + async with session.begin(): - conn_area_res = await session.execute( - select(ConnectionArea) - .where(ConnectionArea.id.in_(conn_area_ids)) - .options(selectinload(ConnectionArea.stops)) - ) - conn_areas: dict[int, ConnectionArea] = { - conn.id: conn for conn in conn_area_res.scalars() - } + conn_area_ids, stop_ids = set(), set() + for conn_area_id, stop_id in conn_area_to_stop_ids: + conn_area_ids.add(conn_area_id) + stop_ids.add(stop_id) - stop_res = await session.execute(select(_Stop).where(_Stop.id.in_(stop_ids))) - stops: dict[int, _Stop] = {stop.id: stop for stop in stop_res.scalars()} + conn_area_res = await session.execute( + select(ConnectionArea) + .where(ConnectionArea.id.in_(conn_area_ids)) + .options(selectinload(ConnectionArea.stops)) + ) + conn_areas: dict[int, ConnectionArea] = { + conn.id: conn for conn in conn_area_res.scalars() + } - found = 0 - for conn_area_id, stop_id in conn_area_to_stop_ids: - if (conn_area := conn_areas.get(conn_area_id)) is not None: - if (stop := stops.get(stop_id)) is not None: - conn_area.stops.append(stop) - found += 1 - else: - print(f"No stop found for {stop_id} id") - else: - print(f"No connection area found for {conn_area_id}") + stop_res = await session.execute( + select(Stop).where(Stop.id.in_(stop_ids)) + ) + stops: dict[int, Stop] = {stop.id: stop for stop in stop_res.scalars()} - await session.commit() + found = 0 + for conn_area_id, stop_id in conn_area_to_stop_ids: + if (conn_area := conn_areas.get(conn_area_id)) is not None: + if (stop := stops.get(stop_id)) is not None: + conn_area.stops.append(stop) + found += 1 + else: + print(f"No stop found for {stop_id} id") + else: + print(f"No connection area found for {conn_area_id}") - return found + return found + + return None