from asyncio import sleep
from logging import getLogger
from urllib.parse import quote

from opentelemetry.instrumentation.sqlalchemy import SQLAlchemyInstrumentor
from sqlalchemy.exc import OperationalError, SQLAlchemyError
from sqlalchemy.ext.asyncio import (
    async_sessionmaker,
    AsyncEngine,
    AsyncSession,
    create_async_engine,
)

from .base_class import Base
from ..settings import DatabaseSettings

logger = getLogger(__name__)


class Database:
    """Own an async SQLAlchemy engine and the session factory bound to it.

    Call ``connect`` before ``get_session``, and ``disconnect`` on shutdown
    to release pooled connections.
    """

    def __init__(self) -> None:
        # Both are populated by ``connect``; ``None`` means "not connected yet".
        self._async_engine: AsyncEngine | None = None
        self._async_session_local: async_sessionmaker[AsyncSession] | None = None

    async def get_session(self) -> AsyncSession | None:
        """Return a new ``AsyncSession`` built by the session factory.

        Raises:
            TypeError: if ``connect`` has not been called yet (the session
                factory is still ``None``).
            SQLAlchemyError: if building the session fails.

        Every failure is logged with its traceback before being re-raised.
        """
        try:
            return self._async_session_local()  # type: ignore
        except (SQLAlchemyError, AttributeError, TypeError) as e:
            # TypeError is what calling a ``None`` factory actually raises;
            # without it here, the not-connected case escaped unlogged.
            logger.exception(e)
            raise

    @staticmethod
    def _build_url(settings: DatabaseSettings) -> str:
        """Build the connection URL from ``settings``, percent-encoding the password.

        SQLAlchemy splits credentials from the host at ``@`` and percent-decodes
        the password, so a raw password containing ``@`` or ``%`` would be
        parsed incorrectly unless it is encoded first.
        """
        password = settings.password
        secret = (
            quote(password.get_secret_value(), safe="")
            if password is not None
            else ""
        )
        return (
            f"{settings.driver}://{settings.user}:{secret}"
            f"@{settings.host}:{settings.port}/{settings.name}"
        )

    # TODO: Preserve UserLastStopSearchResults table from drop.
    async def connect(
        self,
        settings: DatabaseSettings,
        clear_static_data: bool = False,
        *,
        max_retries: int | None = None,
        retry_delay: float = 1.0,
    ) -> bool:
        """Create the engine and session factory, then create the schema.

        Args:
            settings: connection parameters (driver, user, password, host,
                port and database name).
            clear_static_data: if true, drop every table registered on
                ``Base.metadata`` before recreating them.
            max_retries: how many extra attempts to make after an
                ``OperationalError``. ``None`` (the default) retries forever.
            retry_delay: seconds to wait between attempts.

        Returns:
            ``True`` once the schema statements have completed.

        Raises:
            OperationalError: the last failure, once ``max_retries`` retries
                have been used up.
        """
        if self._async_engine is not None:
            # Reconnecting: release the previous engine's pool instead of leaking it.
            await self._async_engine.dispose()

        self._async_engine = create_async_engine(
            self._build_url(settings),
            pool_pre_ping=True,
            pool_size=10,
            max_overflow=20,
        )
        SQLAlchemyInstrumentor().instrument(engine=self._async_engine.sync_engine)
        self._async_session_local = async_sessionmaker(
            bind=self._async_engine,
            expire_on_commit=False,
            class_=AsyncSession,
        )

        failures = 0
        while True:
            try:
                # ``begin()`` yields a connection inside a transaction that
                # commits when the block exits without an error.
                async with self._async_engine.begin() as conn:
                    if clear_static_data:
                        await conn.run_sync(Base.metadata.drop_all)
                    await conn.run_sync(Base.metadata.create_all)
                return True
            except OperationalError as err:
                failures += 1
                logger.error(err)
                if max_retries is not None and failures > max_retries:
                    raise
                await sleep(retry_delay)

    async def disconnect(self) -> None:
        """Dispose the engine's connection pool, if ``connect`` created one."""
        if self._async_engine is not None:
            await self._async_engine.dispose()