from asyncio import sleep
from logging import getLogger
from typing import Annotated, AsyncIterator
from urllib.parse import quote

from fastapi import Depends
from opentelemetry.instrumentation.sqlalchemy import SQLAlchemyInstrumentor
from sqlalchemy import text
from sqlalchemy.exc import OperationalError, SQLAlchemyError
from sqlalchemy.ext.asyncio import (
    async_sessionmaker,
    AsyncEngine,
    AsyncSession,
    create_async_engine,
)

from .base_class import Base
from ..settings import DatabaseSettings

logger = getLogger(__name__)


class Database:
    """Own an async SQLAlchemy engine and the session factory bound to it.

    Call :meth:`connect` before requesting sessions and :meth:`disconnect`
    on shutdown.
    """

    def __init__(self) -> None:
        # Both stay None until connect() runs, and are reset by disconnect().
        self._async_engine: AsyncEngine | None = None
        self._async_session_local: async_sessionmaker[AsyncSession] | None = None

    async def get_session(self) -> AsyncSession | None:
        """Create a new ``AsyncSession`` from the configured session factory.

        Returns:
            A fresh session, or ``None`` when no session factory is configured
            (``connect`` not yet called, or ``disconnect`` already called) or
            when SQLAlchemy fails to create the session.
        """
        session_factory = self._async_session_local
        if session_factory is None:
            # Guard explicitly: calling None would raise TypeError, not a
            # SQLAlchemy error, and would escape the handler below.
            logger.error("Database session requested before connect()")
            return None
        try:
            return session_factory()
        except SQLAlchemyError as e:
            logger.exception(e)
            return None

    @staticmethod
    def _build_url(settings: DatabaseSettings) -> str:
        """Render ``settings`` as a SQLAlchemy database URL string.

        The user name and password are percent-encoded so characters such as
        ``@``, ``:``, ``/`` or ``%`` cannot corrupt the URL; SQLAlchemy decodes
        them again when it parses the URL.
        """
        password = settings.password
        secret = password.get_secret_value() if password is not None else ""
        user = quote(str(settings.user), safe="")
        return (
            f"{settings.driver}://{user}:{quote(secret, safe='')}"
            f"@{settings.host}:{settings.port}/{settings.name}"
        )

    # TODO: Preserve UserLastStopSearchResults table from drop.
    async def connect(
        self,
        settings: DatabaseSettings,
        clear_static_data: bool = False,
        *,
        max_attempts: int | None = None,
        retry_delay: float = 1.0,
    ) -> bool:
        """Create the engine and session factory, then synchronise the schema.

        Schema synchronisation is retried whenever it raises
        ``OperationalError``, sleeping ``retry_delay`` seconds between tries.

        Args:
            settings: Connection parameters (driver, user, password, host,
                port and database name).
            clear_static_data: When true, drop every table registered on
                ``Base.metadata`` before recreating them.
            max_attempts: Maximum number of schema-sync attempts. ``None``
                (the default) retries indefinitely; values below 1 make no
                attempt at all.
            retry_delay: Seconds to wait after a failed attempt.

        Returns:
            ``True`` once the schema has been synchronised, ``False`` if
            ``max_attempts`` was exhausted without success.
        """
        self._async_engine = create_async_engine(
            self._build_url(settings),
            pool_pre_ping=True,
            pool_size=10,
            max_overflow=20,
        )
        SQLAlchemyInstrumentor().instrument(engine=self._async_engine.sync_engine)
        self._async_session_local = async_sessionmaker(
            bind=self._async_engine,
            expire_on_commit=False,
            class_=AsyncSession,
        )

        attempt = 0
        while max_attempts is None or attempt < max_attempts:
            attempt += 1
            try:
                # begin() yields a connection inside a transaction that is
                # committed when the block exits without an exception.
                async with self._async_engine.begin() as conn:
                    if clear_static_data:
                        await conn.run_sync(Base.metadata.drop_all)
                    await conn.run_sync(Base.metadata.create_all)
            except OperationalError as err:
                logger.error(err)
                if max_attempts is not None and attempt >= max_attempts:
                    break  # no point sleeping after the final attempt
                await sleep(retry_delay)
            else:
                return True
        return False

    async def disconnect(self) -> None:
        """Dispose of the engine's connection pool and forget the engine.

        Both the engine and the session factory are cleared. Afterwards
        ``get_session`` returns ``None`` rather than silently opening a fresh
        pool on the engine that was just disposed, and repeated calls are
        harmless.
        """
        engine, self._async_engine = self._async_engine, None
        self._async_session_local = None
        if engine is not None:
            await engine.dispose()