77 lines
2.5 KiB
Python
77 lines
2.5 KiB
Python
from asyncio import sleep
|
|
from logging import getLogger
|
|
from typing import Annotated, AsyncIterator
|
|
|
|
from fastapi import Depends
|
|
from opentelemetry.instrumentation.sqlalchemy import SQLAlchemyInstrumentor
|
|
from sqlalchemy import text
|
|
from sqlalchemy.exc import OperationalError, SQLAlchemyError
|
|
from sqlalchemy.ext.asyncio import (
|
|
async_sessionmaker,
|
|
AsyncEngine,
|
|
AsyncSession,
|
|
create_async_engine,
|
|
)
|
|
|
|
from .base_class import Base
|
|
from settings import DatabaseSettings
|
|
|
|
# Module-scoped logger, named after this module's import path.
logger = getLogger(name=__name__)
|
|
|
|
|
|
class Database:
    """Owns the async SQLAlchemy engine and session factory for the app.

    Call :meth:`connect` once at startup and :meth:`disconnect` at shutdown.
    Use :meth:`get_session` in between to obtain sessions.
    """

    def __init__(self) -> None:
        # Both stay None until connect() succeeds in creating them, and are
        # reset to None by disconnect().
        self._async_engine: AsyncEngine | None = None
        self._async_session_local: async_sessionmaker[AsyncSession] | None = None

    async def get_session(self) -> AsyncSession | None:
        """Return a new ``AsyncSession`` from the factory, or ``None``.

        ``None`` is returned (and logged) when :meth:`connect` has not been
        called, or when the factory raises a SQLAlchemy error. This class does
        not keep a reference to the returned session, so closing it is up to
        the caller.
        """
        factory = self._async_session_local
        if factory is None:
            # Calling None would raise TypeError, which the except clause
            # below would not catch; handle the "not connected" case explicitly.
            logger.error("get_session() called before connect(); no session factory")
            return None
        try:
            return factory()
        except SQLAlchemyError as e:
            logger.exception(e)
            return None

    # TODO: Preserve UserLastStopSearchResults table from drop.
    async def connect(
        self,
        settings: DatabaseSettings,
        clear_static_data: bool = False,
        *,
        max_attempts: int | None = None,
        retry_delay: float = 1.0,
    ) -> bool:
        """Create the engine and session factory, then ensure the schema exists.

        Args:
            settings: Connection parameters (driver, user, password, host,
                port and database name).
            clear_static_data: If true, drop every table registered on
                ``Base.metadata`` before recreating them.
            max_attempts: Maximum number of schema-creation attempts while
                ``OperationalError`` keeps being raised. ``None`` (the default)
                retries forever, as this method always has.
            retry_delay: Seconds to sleep between failed attempts.

        Returns:
            ``True`` once the tables have been created. ``False`` if
            ``max_attempts`` was exhausted; the engine and session factory
            remain assigned in that case.
        """
        # NOTE(review): a second connect() replaces the previous engine without
        # disposing it; call disconnect() first if re-connecting.
        engine = create_async_engine(
            self._build_url(settings),
            pool_pre_ping=True,
            pool_size=10,
            max_overflow=20,
        )
        self._async_engine = engine
        SQLAlchemyInstrumentor().instrument(engine=engine.sync_engine)

        self._async_session_local = async_sessionmaker(
            bind=engine,
            expire_on_commit=False,
            class_=AsyncSession,
        )

        attempt = 0
        while True:
            attempt += 1
            try:
                await self._create_schema(engine, clear_static_data)
                return True
            except OperationalError as err:
                logger.error("Schema setup attempt %d failed: %s", attempt, err)
                if max_attempts is not None and attempt >= max_attempts:
                    return False
                await sleep(retry_delay)

    @staticmethod
    def _build_url(settings: DatabaseSettings) -> str:
        """Render ``settings`` as a SQLAlchemy URL string with encoded credentials."""
        from urllib.parse import quote

        password = settings.password
        secret = password.get_secret_value() if password is not None else ""
        # Credentials may contain URL delimiters ('@', ':', '/', '%'); encode
        # every reserved character so the URL parser splits at the right places.
        user = quote(str(settings.user), safe="")
        return (
            f"{settings.driver}://{user}:{quote(secret, safe='')}"
            f"@{settings.host}:{settings.port}/{settings.name}"
        )

    @staticmethod
    async def _create_schema(engine: AsyncEngine, clear_static_data: bool) -> None:
        """Create all ``Base.metadata`` tables in a single transaction, optionally dropping them first."""
        async with engine.begin() as conn:
            if clear_static_data:
                await conn.run_sync(Base.metadata.drop_all)
            await conn.run_sync(Base.metadata.create_all)

    async def disconnect(self) -> None:
        """Dispose of the engine's connection pool and forget the session factory.

        Safe to call when :meth:`connect` was never called. Afterwards,
        :meth:`get_session` returns ``None`` until :meth:`connect` runs again.
        """
        engine, self._async_engine = self._async_engine, None
        self._async_session_local = None
        if engine is not None:
            await engine.dispose()
|