🗃️ Use of dedicated db sessions

This commit is contained in:
2023-05-07 12:18:12 +02:00
parent 5505209760
commit b713042359
5 changed files with 201 additions and 185 deletions

View File

@@ -1,5 +1,6 @@
from __future__ import annotations from __future__ import annotations
from logging import getLogger
from typing import Iterable, Self, TYPE_CHECKING from typing import Iterable, Self, TYPE_CHECKING
from sqlalchemy import select from sqlalchemy import select
@@ -9,31 +10,36 @@ from sqlalchemy.orm import DeclarativeBase
if TYPE_CHECKING: if TYPE_CHECKING:
from .db import Database from .db import Database
logger = getLogger(__name__)
class Base(DeclarativeBase): class Base(DeclarativeBase):
db: Database | None = None db: Database | None = None
@classmethod @classmethod
async def add(cls, stops: Self | Iterable[Self]) -> bool: async def add(cls, objs: Self | Iterable[Self]) -> bool:
try: if cls.db is not None and (session := await cls.db.get_session()) is not None:
if isinstance(stops, Iterable):
cls.db.session.add_all(stops) # type: ignore async with session.begin():
else: try:
cls.db.session.add(stops) # type: ignore if isinstance(objs, Iterable):
await cls.db.session.commit() # type: ignore session.add_all(objs)
except (AttributeError, IntegrityError) as err: else:
print(err) session.add(objs)
return False
except (AttributeError, IntegrityError) as err:
logger.error(err)
return False
return True return True
@classmethod @classmethod
async def get_by_id(cls, id_: int | str) -> Self | None: async def get_by_id(cls, id_: int | str) -> Self | None:
try: if cls.db is not None and (session := await cls.db.get_session()) is not None:
stmt = select(cls).where(cls.id == id_) # type: ignore
res = await cls.db.session.execute(stmt) # type: ignore async with session.begin():
element = res.scalar_one_or_none() stmt = select(cls).where(cls.id == id_)
except AttributeError as err: res = await session.execute(stmt)
print(err) return res.scalar_one_or_none()
element = None
return element return None

View File

@@ -1,5 +1,10 @@
from logging import getLogger
from typing import Annotated, AsyncIterator
from fastapi import Depends
from opentelemetry.instrumentation.sqlalchemy import SQLAlchemyInstrumentor from opentelemetry.instrumentation.sqlalchemy import SQLAlchemyInstrumentor
from sqlalchemy import text from sqlalchemy import text
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.ext.asyncio import ( from sqlalchemy.ext.asyncio import (
async_sessionmaker, async_sessionmaker,
AsyncEngine, AsyncEngine,
@@ -10,42 +15,47 @@ from sqlalchemy.ext.asyncio import (
from .base_class import Base from .base_class import Base
logger = getLogger(__name__)
class Database: class Database:
def __init__(self) -> None: def __init__(self) -> None:
self._engine: AsyncEngine | None = None self._async_engine: AsyncEngine | None = None
self._session_maker: async_sessionmaker[AsyncSession] | None = None self._async_session_local: async_sessionmaker[AsyncSession] | None = None
self._session: AsyncSession | None = None
@property async def get_session(self) -> AsyncSession | None:
def session(self) -> AsyncSession | None: try:
if self._session is None and (session_maker := self._session_maker) is not None: return self._async_session_local() # type: ignore
self._session = session_maker()
return self._session
except (SQLAlchemyError, AttributeError) as e:
logger.exception(e)
return None
# TODO: Preserve UserLastStopSearchResults table from drop.
async def connect(self, db_path: str, clear_static_data: bool = False) -> bool: async def connect(self, db_path: str, clear_static_data: bool = False) -> bool:
self._async_engine = create_async_engine(
db_path, pool_pre_ping=True, pool_size=10, max_overflow=20
)
# TODO: Preserve UserLastStopSearchResults table from drop. if self._async_engine is not None:
self._engine = create_async_engine(db_path) SQLAlchemyInstrumentor().instrument(engine=self._async_engine.sync_engine)
if self._engine is not None:
SQLAlchemyInstrumentor().instrument(engine=self._engine.sync_engine)
self._session_maker = async_sessionmaker( self._async_session_local = async_sessionmaker(
self._engine, expire_on_commit=False, class_=AsyncSession bind=self._async_engine,
# autoflush=False,
expire_on_commit=False,
class_=AsyncSession,
) )
if (session := self.session) is not None:
await session.execute(text("CREATE EXTENSION IF NOT EXISTS pg_trgm;"))
async with self._engine.begin() as conn: async with self._async_engine.begin() as session:
await session.execute(text("CREATE EXTENSION IF NOT EXISTS pg_trgm;"))
if clear_static_data: if clear_static_data:
await conn.run_sync(Base.metadata.drop_all) await session.run_sync(Base.metadata.drop_all)
await conn.run_sync(Base.metadata.create_all) await session.run_sync(Base.metadata.create_all)
return True return True
async def disconnect(self) -> None: async def disconnect(self) -> None:
if self._session is not None: if self._async_engine is not None:
await self._session.close() await self._async_engine.dispose()
self._session = None
if self._engine is not None:
await self._engine.dispose()

View File

@@ -15,7 +15,7 @@ from aiohttp import ClientSession
from msgspec import ValidationError from msgspec import ValidationError
from msgspec.json import Decoder from msgspec.json import Decoder
from pyproj import Transformer from pyproj import Transformer
from shapefile import Reader as ShapeFileReader, ShapeRecord from shapefile import Reader as ShapeFileReader, ShapeRecord # type: ignore
from ..db import Database from ..db import Database
from ..models import ConnectionArea, Line, LinePicto, Stop, StopArea, StopShape from ..models import ConnectionArea, Line, LinePicto, Stop, StopArea, StopShape
@@ -357,7 +357,6 @@ class IdfmInterface:
fields = line.fields fields = line.fields
picto_id = fields.picto.id_ if fields.picto is not None else None picto_id = fields.picto.id_ if fields.picto is not None else None
picto = await LinePicto.get_by_id(picto_id) if picto_id else None
ret.append( ret.append(
Line( Line(
@@ -384,7 +383,6 @@ class IdfmInterface:
fields.audiblesigns_available.value fields.audiblesigns_available.value
), ),
picto_id=fields.picto.id_ if fields.picto is not None else None, picto_id=fields.picto.id_ if fields.picto is not None else None,
picto=picto,
record_id=line.recordid, record_id=line.recordid,
record_ts=int(line.record_timestamp.timestamp()), record_ts=int(line.record_timestamp.timestamp()),
) )

View File

@@ -94,23 +94,24 @@ class Line(Base):
async def get_by_name( async def get_by_name(
cls, name: str, operator_name: None | str = None cls, name: str, operator_name: None | str = None
) -> Sequence[Self] | None: ) -> Sequence[Self] | None:
session = cls.db.session if (session := await cls.db.get_session()) is not None:
if session is None:
return None
filters = {"name": name} async with session.begin():
if operator_name is not None: filters = {"name": name}
filters["operator_name"] = operator_name if operator_name is not None:
filters["operator_name"] = operator_name
stmt = ( stmt = (
select(cls) select(cls)
.filter_by(**filters) .filter_by(**filters)
.options(selectinload(cls.stops), selectinload(cls.picto)) .options(selectinload(cls.stops), selectinload(cls.picto))
) )
res = await session.execute(stmt) res = await session.execute(stmt)
lines = res.scalars().all() lines = res.scalars().all()
return lines return lines
return None
@classmethod @classmethod
async def _add_picto_to_line(cls, line: str | Self, picto: LinePicto) -> None: async def _add_picto_to_line(cls, line: str | Self, picto: LinePicto) -> None:
@@ -133,57 +134,63 @@ class Line(Base):
@classmethod @classmethod
async def add_pictos(cls, line_to_pictos: Iterable[tuple[str, LinePicto]]) -> bool: async def add_pictos(cls, line_to_pictos: Iterable[tuple[str, LinePicto]]) -> bool:
session = cls.db.session if (session := await cls.db.get_session()) is not None:
if session is None:
return False
await asyncio_gather( async with session.begin():
*[cls._add_picto_to_line(line, picto) for line, picto in line_to_pictos] await asyncio_gather(
) *[
cls._add_picto_to_line(line, picto)
for line, picto in line_to_pictos
]
)
await session.commit() return True
return True return False
@classmethod @classmethod
async def add_stops(cls, line_to_stop_ids: Iterable[tuple[str, str, int]]) -> int: async def add_stops(cls, line_to_stop_ids: Iterable[tuple[str, str, int]]) -> int:
session = cls.db.session if (session := await cls.db.get_session()) is not None:
if session is None:
return 0
line_names_ops, stop_ids = set(), set() async with session.begin():
for line_name, operator_name, stop_id in line_to_stop_ids:
line_names_ops.add((line_name, operator_name))
stop_ids.add(stop_id)
lines_res = await session.execute( line_names_ops, stop_ids = set(), set()
select(Line).where( for line_name, operator_name, stop_id in line_to_stop_ids:
tuple_(Line.name, Line.operator_name).in_(line_names_ops) line_names_ops.add((line_name, operator_name))
) stop_ids.add(stop_id)
)
lines = defaultdict(list) lines_res = await session.execute(
for line in lines_res.scalars(): select(Line).where(
lines[(line.name, line.operator_name)].append(line) tuple_(Line.name, Line.operator_name).in_(line_names_ops)
)
stops_res = await session.execute(select(_Stop).where(_Stop.id.in_(stop_ids)))
stops = {stop.id: stop for stop in stops_res.scalars()}
found = 0
for line_name, operator_name, stop_id in line_to_stop_ids:
if (stop := stops.get(stop_id)) is not None:
if (stop_lines := lines.get((line_name, operator_name))) is not None:
for stop_line in stop_lines:
stop_line.stops.append(stop)
found += 1
else:
print(f"No line found for {line_name}/{operator_name}")
else:
print(
f"No stop found for {stop_id} id"
f"(used by {line_name}/{operator_name})"
) )
await session.commit() lines = defaultdict(list)
for line in lines_res.scalars():
lines[(line.name, line.operator_name)].append(line)
return found stops_res = await session.execute(
select(_Stop).where(_Stop.id.in_(stop_ids))
)
stops = {stop.id: stop for stop in stops_res.scalars()}
found = 0
for line_name, operator_name, stop_id in line_to_stop_ids:
if (stop := stops.get(stop_id)) is not None:
if (
stop_lines := lines.get((line_name, operator_name))
) is not None:
for stop_line in stop_lines:
stop_line.stops.append(stop)
found += 1
else:
print(f"No line found for {line_name}/{operator_name}")
else:
print(
f"No stop found for {stop_id} id"
f"(used by {line_name}/{operator_name})"
)
return found
return 0

View File

@@ -1,6 +1,7 @@
from __future__ import annotations from __future__ import annotations
from typing import Iterable, Sequence, TYPE_CHECKING from logging import getLogger
from typing import Annotated, Iterable, Sequence, TYPE_CHECKING
from sqlalchemy import ( from sqlalchemy import (
BigInteger, BigInteger,
@@ -22,7 +23,6 @@ from sqlalchemy.orm import (
Mapped, Mapped,
relationship, relationship,
selectinload, selectinload,
with_polymorphic,
) )
from sqlalchemy.schema import Index from sqlalchemy.schema import Index
from sqlalchemy_utils.types.ts_vector import TSVectorType from sqlalchemy_utils.types.ts_vector import TSVectorType
@@ -34,6 +34,8 @@ if TYPE_CHECKING:
from .line import Line from .line import Line
logger = getLogger(__name__)
stop_area_stop_association_table = Table( stop_area_stop_association_table = Table(
"stop_area_stop_association_table", "stop_area_stop_association_table",
Base.metadata, Base.metadata,
@@ -91,34 +93,23 @@ class _Stop(Base):
), ),
) )
# TODO: Test https://www.cybertec-postgresql.com/en/postgresql-more-performance-for-like-and-ilike-statements/
# TODO: Should be able to remove with_polymorphic ?
@classmethod @classmethod
async def get_by_name(cls, name: str) -> Sequence[type[_Stop]] | None: async def get_by_name(cls, name: str) -> Sequence[_Stop] | None:
session = cls.db.session if (session := await cls.db.get_session()) is not None:
if session is None:
return None
stop_stop_area = with_polymorphic(_Stop, [Stop, StopArea]) async with session.begin():
match_stmt = stop_stop_area.names_tsv.match(name, postgresql_regconfig="french") match_stmt = cls.names_tsv.match(name, postgresql_regconfig="french")
ranking_stmt = func.ts_rank_cd( ranking_stmt = func.ts_rank_cd(
stop_stop_area.names_tsv, func.plainto_tsquery("french", name) cls.names_tsv, func.plainto_tsquery("french", name)
) )
stmt = select(cls).filter(match_stmt).order_by(desc(ranking_stmt))
stmt = ( res = await session.execute(stmt)
select(stop_stop_area) stops = res.scalars().all()
.filter(match_stmt)
.order_by(desc(ranking_stmt))
.options(
selectinload(stop_stop_area.areas),
selectinload(stop_stop_area.lines),
)
)
res = await session.execute(stmt) return stops
stops = res.scalars().all()
return stops return None
class Stop(_Stop): class Stop(_Stop):
@@ -160,41 +151,43 @@ class StopArea(_Stop):
async def add_stops( async def add_stops(
cls, stop_area_to_stop_ids: Iterable[tuple[int, int]] cls, stop_area_to_stop_ids: Iterable[tuple[int, int]]
) -> int | None: ) -> int | None:
session = cls.db.session if (session := await cls.db.get_session()) is not None:
if session is None:
return None
stop_area_ids, stop_ids = set(), set() async with session.begin():
for stop_area_id, stop_id in stop_area_to_stop_ids:
stop_area_ids.add(stop_area_id)
stop_ids.add(stop_id)
stop_areas_res = await session.scalars( stop_area_ids, stop_ids = set(), set()
select(StopArea) for stop_area_id, stop_id in stop_area_to_stop_ids:
.where(StopArea.id.in_(stop_area_ids)) stop_area_ids.add(stop_area_id)
.options(selectinload(StopArea.stops)) stop_ids.add(stop_id)
)
stop_areas: dict[int, StopArea] = {
stop_area.id: stop_area for stop_area in stop_areas_res.all()
}
stop_res = await session.execute(select(Stop).where(Stop.id.in_(stop_ids))) stop_areas_res = await session.scalars(
stops: dict[int, Stop] = {stop.id: stop for stop in stop_res.scalars()} select(StopArea)
.where(StopArea.id.in_(stop_area_ids))
.options(selectinload(StopArea.stops))
)
stop_areas: dict[int, StopArea] = {
stop_area.id: stop_area for stop_area in stop_areas_res.all()
}
found = 0 stop_res = await session.execute(
for stop_area_id, stop_id in stop_area_to_stop_ids: select(Stop).where(Stop.id.in_(stop_ids))
if (stop_area := stop_areas.get(stop_area_id)) is not None: )
if (stop := stops.get(stop_id)) is not None: stops: dict[int, Stop] = {stop.id: stop for stop in stop_res.scalars()}
stop_area.stops.append(stop)
found += 1
else:
print(f"No stop found for {stop_id} id")
else:
print(f"No stop area found for {stop_area_id}")
await session.commit() found = 0
for stop_area_id, stop_id in stop_area_to_stop_ids:
if (stop_area := stop_areas.get(stop_area_id)) is not None:
if (stop := stops.get(stop_id)) is not None:
stop_area.stops.append(stop)
found += 1
else:
print(f"No stop found for {stop_id} id")
else:
print(f"No stop area found for {stop_area_id}")
return found return found
return None
class StopShape(Base): class StopShape(Base):
@@ -235,38 +228,40 @@ class ConnectionArea(Base):
async def add_stops( async def add_stops(
cls, conn_area_to_stop_ids: Iterable[tuple[int, int]] cls, conn_area_to_stop_ids: Iterable[tuple[int, int]]
) -> int | None: ) -> int | None:
session = cls.db.session if (session := await cls.db.get_session()) is not None:
if session is None:
return None
conn_area_ids, stop_ids = set(), set() async with session.begin():
for conn_area_id, stop_id in conn_area_to_stop_ids:
conn_area_ids.add(conn_area_id)
stop_ids.add(stop_id)
conn_area_res = await session.execute( conn_area_ids, stop_ids = set(), set()
select(ConnectionArea) for conn_area_id, stop_id in conn_area_to_stop_ids:
.where(ConnectionArea.id.in_(conn_area_ids)) conn_area_ids.add(conn_area_id)
.options(selectinload(ConnectionArea.stops)) stop_ids.add(stop_id)
)
conn_areas: dict[int, ConnectionArea] = {
conn.id: conn for conn in conn_area_res.scalars()
}
stop_res = await session.execute(select(_Stop).where(_Stop.id.in_(stop_ids))) conn_area_res = await session.execute(
stops: dict[int, _Stop] = {stop.id: stop for stop in stop_res.scalars()} select(ConnectionArea)
.where(ConnectionArea.id.in_(conn_area_ids))
.options(selectinload(ConnectionArea.stops))
)
conn_areas: dict[int, ConnectionArea] = {
conn.id: conn for conn in conn_area_res.scalars()
}
found = 0 stop_res = await session.execute(
for conn_area_id, stop_id in conn_area_to_stop_ids: select(Stop).where(Stop.id.in_(stop_ids))
if (conn_area := conn_areas.get(conn_area_id)) is not None: )
if (stop := stops.get(stop_id)) is not None: stops: dict[int, Stop] = {stop.id: stop for stop in stop_res.scalars()}
conn_area.stops.append(stop)
found += 1
else:
print(f"No stop found for {stop_id} id")
else:
print(f"No connection area found for {conn_area_id}")
await session.commit() found = 0
for conn_area_id, stop_id in conn_area_to_stop_ids:
if (conn_area := conn_areas.get(conn_area_id)) is not None:
if (stop := stops.get(stop_id)) is not None:
conn_area.stops.append(stop)
found += 1
else:
print(f"No stop found for {stop_id} id")
else:
print(f"No connection area found for {conn_area_id}")
return found return found
return None