2 Commits

Author SHA1 Message Date
b713042359 🗃️ Use of dedicated db sessions 2023-05-07 12:18:12 +02:00
5505209760 Replace asyncpg with psycopg 2023-05-07 11:24:02 +02:00
7 changed files with 213 additions and 190 deletions

View File

@@ -1,5 +1,6 @@
from __future__ import annotations
from logging import getLogger
from typing import Iterable, Self, TYPE_CHECKING
from sqlalchemy import select
@@ -9,31 +10,36 @@ from sqlalchemy.orm import DeclarativeBase
if TYPE_CHECKING:
from .db import Database
logger = getLogger(__name__)
class Base(DeclarativeBase):
db: Database | None = None
@classmethod
async def add(cls, stops: Self | Iterable[Self]) -> bool:
try:
if isinstance(stops, Iterable):
cls.db.session.add_all(stops) # type: ignore
else:
cls.db.session.add(stops) # type: ignore
await cls.db.session.commit() # type: ignore
except (AttributeError, IntegrityError) as err:
print(err)
return False
async def add(cls, objs: Self | Iterable[Self]) -> bool:
if cls.db is not None and (session := await cls.db.get_session()) is not None:
async with session.begin():
try:
if isinstance(objs, Iterable):
session.add_all(objs)
else:
session.add(objs)
except (AttributeError, IntegrityError) as err:
logger.error(err)
return False
return True
@classmethod
async def get_by_id(cls, id_: int | str) -> Self | None:
try:
stmt = select(cls).where(cls.id == id_) # type: ignore
res = await cls.db.session.execute(stmt) # type: ignore
element = res.scalar_one_or_none()
except AttributeError as err:
print(err)
element = None
return element
if cls.db is not None and (session := await cls.db.get_session()) is not None:
async with session.begin():
stmt = select(cls).where(cls.id == id_)
res = await session.execute(stmt)
return res.scalar_one_or_none()
return None

View File

@@ -1,5 +1,10 @@
from logging import getLogger
from typing import Annotated, AsyncIterator
from fastapi import Depends
from opentelemetry.instrumentation.sqlalchemy import SQLAlchemyInstrumentor
from sqlalchemy import text
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.ext.asyncio import (
async_sessionmaker,
AsyncEngine,
@@ -10,42 +15,47 @@ from sqlalchemy.ext.asyncio import (
from .base_class import Base
logger = getLogger(__name__)
class Database:
def __init__(self) -> None:
self._engine: AsyncEngine | None = None
self._session_maker: async_sessionmaker[AsyncSession] | None = None
self._session: AsyncSession | None = None
self._async_engine: AsyncEngine | None = None
self._async_session_local: async_sessionmaker[AsyncSession] | None = None
@property
def session(self) -> AsyncSession | None:
if self._session is None and (session_maker := self._session_maker) is not None:
self._session = session_maker()
return self._session
async def get_session(self) -> AsyncSession | None:
    """Create a new session from the configured session factory.

    Returns:
        A fresh session produced by ``self._async_session_local``, or None
        when no factory has been set yet or when creating the session fails.
    """
    session_factory = self._async_session_local
    if session_factory is None:
        # Calling None raises TypeError, which the handler below does not
        # cover; check explicitly so callers get the documented None.
        return None

    try:
        return session_factory()
    except (SQLAlchemyError, AttributeError) as e:
        logger.exception(e)
        return None
# TODO: Preserve UserLastStopSearchResults table from drop.
async def connect(self, db_path: str, clear_static_data: bool = False) -> bool:
self._async_engine = create_async_engine(
db_path, pool_pre_ping=True, pool_size=10, max_overflow=20
)
# TODO: Preserve UserLastStopSearchResults table from drop.
self._engine = create_async_engine(db_path)
if self._engine is not None:
SQLAlchemyInstrumentor().instrument(engine=self._engine.sync_engine)
if self._async_engine is not None:
SQLAlchemyInstrumentor().instrument(engine=self._async_engine.sync_engine)
self._session_maker = async_sessionmaker(
self._engine, expire_on_commit=False, class_=AsyncSession
self._async_session_local = async_sessionmaker(
bind=self._async_engine,
# autoflush=False,
expire_on_commit=False,
class_=AsyncSession,
)
if (session := self.session) is not None:
await session.execute(text("CREATE EXTENSION IF NOT EXISTS pg_trgm;"))
async with self._engine.begin() as conn:
async with self._async_engine.begin() as session:
await session.execute(text("CREATE EXTENSION IF NOT EXISTS pg_trgm;"))
if clear_static_data:
await conn.run_sync(Base.metadata.drop_all)
await conn.run_sync(Base.metadata.create_all)
await session.run_sync(Base.metadata.drop_all)
await session.run_sync(Base.metadata.create_all)
return True
async def disconnect(self) -> None:
if self._session is not None:
await self._session.close()
self._session = None
if self._engine is not None:
await self._engine.dispose()
if self._async_engine is not None:
await self._async_engine.dispose()

View File

@@ -15,7 +15,7 @@ from aiohttp import ClientSession
from msgspec import ValidationError
from msgspec.json import Decoder
from pyproj import Transformer
from shapefile import Reader as ShapeFileReader, ShapeRecord
from shapefile import Reader as ShapeFileReader, ShapeRecord # type: ignore
from ..db import Database
from ..models import ConnectionArea, Line, LinePicto, Stop, StopArea, StopShape
@@ -357,7 +357,6 @@ class IdfmInterface:
fields = line.fields
picto_id = fields.picto.id_ if fields.picto is not None else None
picto = await LinePicto.get_by_id(picto_id) if picto_id else None
ret.append(
Line(
@@ -384,7 +383,6 @@ class IdfmInterface:
fields.audiblesigns_available.value
),
picto_id=fields.picto.id_ if fields.picto is not None else None,
picto=picto,
record_id=line.recordid,
record_ts=int(line.record_timestamp.timestamp()),
)

View File

@@ -94,23 +94,24 @@ class Line(Base):
async def get_by_name(
cls, name: str, operator_name: None | str = None
) -> Sequence[Self] | None:
session = cls.db.session
if session is None:
return None
if (session := await cls.db.get_session()) is not None:
filters = {"name": name}
if operator_name is not None:
filters["operator_name"] = operator_name
async with session.begin():
filters = {"name": name}
if operator_name is not None:
filters["operator_name"] = operator_name
stmt = (
select(cls)
.filter_by(**filters)
.options(selectinload(cls.stops), selectinload(cls.picto))
)
res = await session.execute(stmt)
lines = res.scalars().all()
stmt = (
select(cls)
.filter_by(**filters)
.options(selectinload(cls.stops), selectinload(cls.picto))
)
res = await session.execute(stmt)
lines = res.scalars().all()
return lines
return lines
return None
@classmethod
async def _add_picto_to_line(cls, line: str | Self, picto: LinePicto) -> None:
@@ -133,57 +134,63 @@ class Line(Base):
@classmethod
async def add_pictos(cls, line_to_pictos: Iterable[tuple[str, LinePicto]]) -> bool:
session = cls.db.session
if session is None:
return False
if (session := await cls.db.get_session()) is not None:
await asyncio_gather(
*[cls._add_picto_to_line(line, picto) for line, picto in line_to_pictos]
)
async with session.begin():
await asyncio_gather(
*[
cls._add_picto_to_line(line, picto)
for line, picto in line_to_pictos
]
)
await session.commit()
return True
return True
return False
@classmethod
async def add_stops(cls, line_to_stop_ids: Iterable[tuple[str, str, int]]) -> int:
session = cls.db.session
if session is None:
return 0
if (session := await cls.db.get_session()) is not None:
line_names_ops, stop_ids = set(), set()
for line_name, operator_name, stop_id in line_to_stop_ids:
line_names_ops.add((line_name, operator_name))
stop_ids.add(stop_id)
async with session.begin():
lines_res = await session.execute(
select(Line).where(
tuple_(Line.name, Line.operator_name).in_(line_names_ops)
)
)
line_names_ops, stop_ids = set(), set()
for line_name, operator_name, stop_id in line_to_stop_ids:
line_names_ops.add((line_name, operator_name))
stop_ids.add(stop_id)
lines = defaultdict(list)
for line in lines_res.scalars():
lines[(line.name, line.operator_name)].append(line)
stops_res = await session.execute(select(_Stop).where(_Stop.id.in_(stop_ids)))
stops = {stop.id: stop for stop in stops_res.scalars()}
found = 0
for line_name, operator_name, stop_id in line_to_stop_ids:
if (stop := stops.get(stop_id)) is not None:
if (stop_lines := lines.get((line_name, operator_name))) is not None:
for stop_line in stop_lines:
stop_line.stops.append(stop)
found += 1
else:
print(f"No line found for {line_name}/{operator_name}")
else:
print(
f"No stop found for {stop_id} id"
f"(used by {line_name}/{operator_name})"
lines_res = await session.execute(
select(Line).where(
tuple_(Line.name, Line.operator_name).in_(line_names_ops)
)
)
await session.commit()
lines = defaultdict(list)
for line in lines_res.scalars():
lines[(line.name, line.operator_name)].append(line)
return found
stops_res = await session.execute(
select(_Stop).where(_Stop.id.in_(stop_ids))
)
stops = {stop.id: stop for stop in stops_res.scalars()}
found = 0
for line_name, operator_name, stop_id in line_to_stop_ids:
if (stop := stops.get(stop_id)) is not None:
if (
stop_lines := lines.get((line_name, operator_name))
) is not None:
for stop_line in stop_lines:
stop_line.stops.append(stop)
found += 1
else:
print(f"No line found for {line_name}/{operator_name}")
else:
print(
f"No stop found for {stop_id} id"
f"(used by {line_name}/{operator_name})"
)
return found
return 0

View File

@@ -1,6 +1,7 @@
from __future__ import annotations
from typing import Iterable, Sequence, TYPE_CHECKING
from logging import getLogger
from typing import Annotated, Iterable, Sequence, TYPE_CHECKING
from sqlalchemy import (
BigInteger,
@@ -22,7 +23,6 @@ from sqlalchemy.orm import (
Mapped,
relationship,
selectinload,
with_polymorphic,
)
from sqlalchemy.schema import Index
from sqlalchemy_utils.types.ts_vector import TSVectorType
@@ -34,6 +34,8 @@ if TYPE_CHECKING:
from .line import Line
logger = getLogger(__name__)
stop_area_stop_association_table = Table(
"stop_area_stop_association_table",
Base.metadata,
@@ -91,34 +93,23 @@ class _Stop(Base):
),
)
# TODO: Test https://www.cybertec-postgresql.com/en/postgresql-more-performance-for-like-and-ilike-statements/
# TODO: Should be able to remove with_polymorphic ?
@classmethod
async def get_by_name(cls, name: str) -> Sequence[type[_Stop]] | None:
session = cls.db.session
if session is None:
return None
async def get_by_name(cls, name: str) -> Sequence[_Stop] | None:
if (session := await cls.db.get_session()) is not None:
stop_stop_area = with_polymorphic(_Stop, [Stop, StopArea])
match_stmt = stop_stop_area.names_tsv.match(name, postgresql_regconfig="french")
ranking_stmt = func.ts_rank_cd(
stop_stop_area.names_tsv, func.plainto_tsquery("french", name)
)
async with session.begin():
match_stmt = cls.names_tsv.match(name, postgresql_regconfig="french")
ranking_stmt = func.ts_rank_cd(
cls.names_tsv, func.plainto_tsquery("french", name)
)
stmt = select(cls).filter(match_stmt).order_by(desc(ranking_stmt))
stmt = (
select(stop_stop_area)
.filter(match_stmt)
.order_by(desc(ranking_stmt))
.options(
selectinload(stop_stop_area.areas),
selectinload(stop_stop_area.lines),
)
)
res = await session.execute(stmt)
stops = res.scalars().all()
res = await session.execute(stmt)
stops = res.scalars().all()
return stops
return stops
return None
class Stop(_Stop):
@@ -160,41 +151,43 @@ class StopArea(_Stop):
async def add_stops(
cls, stop_area_to_stop_ids: Iterable[tuple[int, int]]
) -> int | None:
session = cls.db.session
if session is None:
return None
if (session := await cls.db.get_session()) is not None:
stop_area_ids, stop_ids = set(), set()
for stop_area_id, stop_id in stop_area_to_stop_ids:
stop_area_ids.add(stop_area_id)
stop_ids.add(stop_id)
async with session.begin():
stop_areas_res = await session.scalars(
select(StopArea)
.where(StopArea.id.in_(stop_area_ids))
.options(selectinload(StopArea.stops))
)
stop_areas: dict[int, StopArea] = {
stop_area.id: stop_area for stop_area in stop_areas_res.all()
}
stop_area_ids, stop_ids = set(), set()
for stop_area_id, stop_id in stop_area_to_stop_ids:
stop_area_ids.add(stop_area_id)
stop_ids.add(stop_id)
stop_res = await session.execute(select(Stop).where(Stop.id.in_(stop_ids)))
stops: dict[int, Stop] = {stop.id: stop for stop in stop_res.scalars()}
stop_areas_res = await session.scalars(
select(StopArea)
.where(StopArea.id.in_(stop_area_ids))
.options(selectinload(StopArea.stops))
)
stop_areas: dict[int, StopArea] = {
stop_area.id: stop_area for stop_area in stop_areas_res.all()
}
found = 0
for stop_area_id, stop_id in stop_area_to_stop_ids:
if (stop_area := stop_areas.get(stop_area_id)) is not None:
if (stop := stops.get(stop_id)) is not None:
stop_area.stops.append(stop)
found += 1
else:
print(f"No stop found for {stop_id} id")
else:
print(f"No stop area found for {stop_area_id}")
stop_res = await session.execute(
select(Stop).where(Stop.id.in_(stop_ids))
)
stops: dict[int, Stop] = {stop.id: stop for stop in stop_res.scalars()}
await session.commit()
found = 0
for stop_area_id, stop_id in stop_area_to_stop_ids:
if (stop_area := stop_areas.get(stop_area_id)) is not None:
if (stop := stops.get(stop_id)) is not None:
stop_area.stops.append(stop)
found += 1
else:
print(f"No stop found for {stop_id} id")
else:
print(f"No stop area found for {stop_area_id}")
return found
return found
return None
class StopShape(Base):
@@ -235,38 +228,40 @@ class ConnectionArea(Base):
async def add_stops(
cls, conn_area_to_stop_ids: Iterable[tuple[int, int]]
) -> int | None:
session = cls.db.session
if session is None:
return None
if (session := await cls.db.get_session()) is not None:
conn_area_ids, stop_ids = set(), set()
for conn_area_id, stop_id in conn_area_to_stop_ids:
conn_area_ids.add(conn_area_id)
stop_ids.add(stop_id)
async with session.begin():
conn_area_res = await session.execute(
select(ConnectionArea)
.where(ConnectionArea.id.in_(conn_area_ids))
.options(selectinload(ConnectionArea.stops))
)
conn_areas: dict[int, ConnectionArea] = {
conn.id: conn for conn in conn_area_res.scalars()
}
conn_area_ids, stop_ids = set(), set()
for conn_area_id, stop_id in conn_area_to_stop_ids:
conn_area_ids.add(conn_area_id)
stop_ids.add(stop_id)
stop_res = await session.execute(select(_Stop).where(_Stop.id.in_(stop_ids)))
stops: dict[int, _Stop] = {stop.id: stop for stop in stop_res.scalars()}
conn_area_res = await session.execute(
select(ConnectionArea)
.where(ConnectionArea.id.in_(conn_area_ids))
.options(selectinload(ConnectionArea.stops))
)
conn_areas: dict[int, ConnectionArea] = {
conn.id: conn for conn in conn_area_res.scalars()
}
found = 0
for conn_area_id, stop_id in conn_area_to_stop_ids:
if (conn_area := conn_areas.get(conn_area_id)) is not None:
if (stop := stops.get(stop_id)) is not None:
conn_area.stops.append(stop)
found += 1
else:
print(f"No stop found for {stop_id} id")
else:
print(f"No connection area found for {conn_area_id}")
stop_res = await session.execute(
select(Stop).where(Stop.id.in_(stop_ids))
)
stops: dict[int, Stop] = {stop.id: stop for stop in stop_res.scalars()}
await session.commit()
found = 0
for conn_area_id, stop_id in conn_area_to_stop_ids:
if (conn_area := conn_areas.get(conn_area_id)) is not None:
if (stop := stops.get(stop_id)) is not None:
conn_area.stops.append(stop)
found += 1
else:
print(f"No stop found for {stop_id} id")
else:
print(f"No connection area found for {conn_area_id}")
return found
return found
return None

View File

@@ -45,8 +45,16 @@ MODE = environ.get("MODE", "grpc")
COLLECTOR_ENDPOINT_GRPC_ENDPOINT = environ.get(
"COLLECTOR_ENDPOINT_GRPC_ENDPOINT", "127.0.0.1:14250" # "jaeger-collector:14250"
)
# TODO: Remove postgresql+asyncpg from environ variable
DB_PATH = "postgresql+asyncpg://cer_user:cer_password@127.0.0.1:5438/cer_db"
# CREATE DATABASE "carrramba-encore-rate";
# CREATE USER cer WITH ENCRYPTED PASSWORD 'cer_password';
# GRANT ALL PRIVILEGES ON DATABASE "carrramba-encore-rate" TO cer;
# \c "carrramba-encore-rate";
# GRANT ALL ON schema public TO cer;
# CREATE EXTENSION IF NOT EXISTS pg_trgm;
# TODO: Remove postgresql+psycopg from environ variable
DB_PATH = "postgresql+psycopg://cer:cer_password@127.0.0.1:5432/carrramba-encore-rate"
app = FastAPI()

View File

@@ -11,10 +11,8 @@ python = "^3.11"
aiohttp = "^3.8.3"
rich = "^12.6.0"
aiofiles = "^22.1.0"
sqlalchemy = {extras = ["asyncio"], version = "^2.0.1"}
fastapi = "^0.88.0"
uvicorn = "^0.20.0"
asyncpg = "^0.27.0"
msgspec = "^0.12.0"
pyshp = "^2.3.1"
pyproj = "^3.5.0"
@@ -25,6 +23,8 @@ opentelemetry-sdk = "^1.17.0"
opentelemetry-api = "^1.17.0"
opentelemetry-exporter-otlp-proto-http = "^1.17.0"
opentelemetry-instrumentation-sqlalchemy = "^0.38b0"
sqlalchemy = "^2.0.12"
psycopg = "^3.1.9"
[build-system]
requires = ["poetry-core"]
@@ -46,7 +46,6 @@ autopep8 = "^2.0.1"
pyflakes = "^3.0.1"
yapf = "^0.32.0"
whatthepatch = "^1.0.4"
sqlalchemy = {extras = ["mypy"], version = "^2.0.1"}
mypy = "^1.0.0"
[tool.mypy]