Files
2023-01-22 16:53:45 +01:00

81 lines
2.5 KiB
Python

from asyncio import gather as asyncio_gather
from functools import wraps
from pathlib import Path
from time import time
from typing import Callable, Iterable, Optional

from rich import print
from sqlalchemy import event, select, text, tuple_
from sqlalchemy.engine import Engine
from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
from sqlalchemy.orm import (
    selectinload,
    sessionmaker,
    with_polymorphic,
)
from sqlalchemy.orm.attributes import set_committed_value

from .base_class import Base
# import logging
# logging.basicConfig()
# logger = logging.getLogger("bot.sqltime")
# logger.setLevel(logging.DEBUG)
# @event.listens_for(Engine, "before_cursor_execute")
# def before_cursor_execute(conn, cursor, statement, parameters, context, executemany):
# conn.info.setdefault("query_start_time", []).append(time())
# logger.debug("Start Query: %s", statement)
# @event.listens_for(Engine, "after_cursor_execute")
# def after_cursor_execute(conn, cursor, statement, parameters, context, executemany):
# total = time() - conn.info["query_start_time"].pop(-1)
# logger.debug("Query Complete!")
# logger.debug("Total Time: %f", total)
class Database:
    """Own the async SQLAlchemy engine and a lazily created, shared session.

    Call :meth:`connect` before using :attr:`session`, and :meth:`disconnect`
    on shutdown to close the session and dispose of the engine.
    """

    def __init__(self) -> None:
        self._engine = None  # set by connect() via create_async_engine()
        self._session_maker = None  # set by connect(); builds AsyncSession objects
        self._session = None  # single session cached by the `session` property

    @property
    def session(self) -> "AsyncSession":
        """Return the shared session, creating it on first access.

        Raises:
            RuntimeError: if :meth:`connect` has not been called yet.
        """
        if self._session is None:
            if self._session_maker is None:
                raise RuntimeError(
                    "Database.connect() must be called before using the session"
                )
            self._session = self._session_maker()
        return self._session

    def _check_session(self) -> Optional["AsyncSession"]:
        """Return the shared session, or ``None`` if no session can be obtained.

        Used by :func:`use_session`; unlike :attr:`session` it never raises
        when the database has not been connected.
        """
        if self._session is None and self._session_maker is None:
            return None
        return self.session

    def use_session(func: Callable):
        """Wrap an async method so it only runs when a session is available.

        Takes no ``self``: it is applied at class-definition time, e.g. as
        ``@use_session`` later in this class body or ``@Database.use_session``
        in a subclass. When no session is available the wrapped coroutine is
        skipped and ``None`` is returned.
        """

        @wraps(func)
        async def wrapper(self, *args, **kwargs):
            if self._check_session() is not None:
                return await func(self, *args, **kwargs)
            # TODO: Raise an exception instead of silently returning None?
            return None

        return wrapper

    async def connect(self, db_path: str, clear_static_data: bool = False) -> None:
        """Create the engine and session factory, then prepare the schema.

        Args:
            db_path: database URL handed to ``create_async_engine``.
            clear_static_data: when true, drop every table registered on
                ``Base.metadata`` before recreating them.
        """
        # TODO: Preserve UserLastStopSearchResults table from drop.
        self._engine = create_async_engine(db_path)
        self._session_maker = sessionmaker(
            self._engine, expire_on_commit=False, class_=AsyncSession
        )
        async with self._engine.begin() as conn:
            # Executed on this connection so it is committed together with the
            # schema changes and exists before create_all runs; on the shared
            # session it would sit in an uncommitted transaction. text() is
            # required because SQLAlchemy 2.0 rejects plain SQL strings.
            await conn.execute(text("CREATE EXTENSION IF NOT EXISTS pg_trgm;"))
            if clear_static_data:
                await conn.run_sync(Base.metadata.drop_all)
            await conn.run_sync(Base.metadata.create_all)

    async def disconnect(self) -> None:
        """Close the shared session (if any) and dispose of the engine.

        Safe to call even if :meth:`connect` was never run.
        """
        if self._session is not None:
            await self._session.close()
            self._session = None
        if self._engine is not None:
            await self._engine.dispose()