from asyncio import gather as asyncio_gather
from functools import wraps
from pathlib import Path
from time import time
from typing import Callable, Iterable, Optional

from rich import print
from sqlalchemy import event, select, text, tuple_
from sqlalchemy.engine import Engine
from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
from sqlalchemy.orm import (
    selectinload,
    sessionmaker,
    with_polymorphic,
)
from sqlalchemy.orm.attributes import set_committed_value

from .base_class import Base

# Debug helper: uncomment to log every SQL statement and its execution time.
# import logging
# logging.basicConfig()
# logger = logging.getLogger("bot.sqltime")
# logger.setLevel(logging.DEBUG)

# @event.listens_for(Engine, "before_cursor_execute")
# def before_cursor_execute(conn, cursor, statement, parameters, context, executemany):
#     conn.info.setdefault("query_start_time", []).append(time())
#     logger.debug("Start Query: %s", statement)

# @event.listens_for(Engine, "after_cursor_execute")
# def after_cursor_execute(conn, cursor, statement, parameters, context, executemany):
#     total = time() - conn.info["query_start_time"].pop(-1)
#     logger.debug("Query Complete!")
#     logger.debug("Total Time: %f", total)


class Database:
    """Own an async SQLAlchemy engine and a single, lazily created session.

    Call :meth:`connect` before using :attr:`session`, and :meth:`disconnect`
    on shutdown.
    """

    def __init__(self) -> None:
        self._engine = None  # async engine, created by connect()
        self._session_maker = None  # sessionmaker producing AsyncSession objects
        self._session = None  # shared AsyncSession, created on first access

    @property
    def session(self) -> AsyncSession:
        """Return the shared session, creating it on first access.

        Raises:
            RuntimeError: if :meth:`connect` has not been called yet.
        """
        if self._session is None:
            if self._session_maker is None:
                # Fail loudly instead of "'NoneType' object is not callable".
                raise RuntimeError(
                    "Database.connect() must be awaited before using the session"
                )
            self._session = self._session_maker()
        return self._session

    def use_session(func: Callable):
        """Decorate an async method so it only runs when ``_check_session()`` passes.

        Defined in the class body so that later methods can use it as a
        decorator. ``_check_session`` is not visible in this chunk; the
        wrapper calls ``func`` only when it returns something other than
        ``None``.
        """

        @wraps(func)
        async def wrapper(self, *args, **kwargs):
            if self._check_session() is not None:
                return await func(self, *args, **kwargs)
            # TODO: Raise an exception ? (currently returns None silently)

        return wrapper

    async def connect(self, db_path: str, clear_static_data: bool = False) -> None:
        """Create the engine and session factory, then ensure the schema exists.

        Args:
            db_path: SQLAlchemy async database URL passed to ``create_async_engine``.
            clear_static_data: if True, drop every table known to ``Base.metadata``
                before recreating them.
        """
        # TODO: Preserve UserLastStopSearchResults table from drop.
        self._engine = create_async_engine(db_path)
        self._session_maker = sessionmaker(
            self._engine, expire_on_commit=False, class_=AsyncSession
        )
        async with self._engine.begin() as conn:
            # Create the extension in the same transaction that runs create_all
            # and commits on exit. Previously it ran in the shared session,
            # which never committed: the extension stayed invisible to this
            # connection, and the session was left idle-in-transaction.
            # Presumably some models rely on pg_trgm (trigram ops/indexes);
            # TODO confirm. text() is required for raw SQL in SQLAlchemy 1.4+/2.0.
            await conn.execute(text("CREATE EXTENSION IF NOT EXISTS pg_trgm;"))
            if clear_static_data:
                await conn.run_sync(Base.metadata.drop_all)
            await conn.run_sync(Base.metadata.create_all)

    async def disconnect(self) -> None:
        """Close the shared session (if any) and dispose of the engine's pool.

        Does nothing for the engine if :meth:`connect` was never called.
        """
        if self._session is not None:
            await self._session.close()
            self._session = None
        if self._engine is not None:
            await self._engine.dispose()