81 lines
2.5 KiB
Python
81 lines
2.5 KiB
Python
from asyncio import gather as asyncio_gather
|
|
from functools import wraps
|
|
from pathlib import Path
|
|
from time import time
|
|
from typing import Callable, Iterable, Optional
|
|
|
|
from rich import print
|
|
from sqlalchemy import event, select, tuple_
|
|
from sqlalchemy.engine import Engine
|
|
from sqlalchemy.exc import IntegrityError
|
|
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
|
|
from sqlalchemy.orm import (
|
|
selectinload,
|
|
sessionmaker,
|
|
with_polymorphic,
|
|
)
|
|
from sqlalchemy.orm.attributes import set_committed_value
|
|
|
|
from .base_class import Base
|
|
|
|
|
|
# import logging
|
|
|
|
# logging.basicConfig()
|
|
# logger = logging.getLogger("bot.sqltime")
|
|
# logger.setLevel(logging.DEBUG)
|
|
|
|
|
|
# @event.listens_for(Engine, "before_cursor_execute")
|
|
# def before_cursor_execute(conn, cursor, statement, parameters, context, executemany):
|
|
# conn.info.setdefault("query_start_time", []).append(time())
|
|
# logger.debug("Start Query: %s", statement)
|
|
|
|
|
|
# @event.listens_for(Engine, "after_cursor_execute")
|
|
# def after_cursor_execute(conn, cursor, statement, parameters, context, executemany):
|
|
# total = time() - conn.info["query_start_time"].pop(-1)
|
|
# logger.debug("Query Complete!")
|
|
# logger.debug("Total Time: %f", total)
|
|
|
|
|
|
class Database:
    """Owns an async SQLAlchemy engine, its session factory and one shared session.

    Call :meth:`connect` before touching :attr:`session`; call
    :meth:`disconnect` to close the shared session and dispose of the engine.
    """

    def __init__(self) -> None:
        # All of these stay None until connect() runs.
        self._engine = None  # AsyncEngine created by connect()
        self._session_maker: Optional[Callable[[], "AsyncSession"]] = None
        # Lazily created by the ``session`` property; cleared by disconnect().
        self._session: Optional["AsyncSession"] = None

    @property
    def session(self) -> "AsyncSession":
        """Return the shared session, creating it on first access.

        The same session object is returned on every access until
        :meth:`disconnect` clears it. NOTE: SQLAlchemy's ``AsyncSession`` is
        not safe for concurrent use by multiple tasks, so callers sharing it
        concurrently must serialize access themselves.

        Raises:
            RuntimeError: if :meth:`connect` has not been called yet.
        """
        if self._session is None:
            if self._session_maker is None:
                raise RuntimeError("Database is not connected; call connect() first.")
            self._session = self._session_maker()
        return self._session

    def _check_session(self) -> Optional["AsyncSession"]:
        """Return the shared session if the database is connected, else None.

        Unlike the ``session`` property this never raises; ``use_session``
        uses it as a connectivity probe.
        """
        if self._session is None and self._session_maker is None:
            return None
        return self.session

    def use_session(func: Callable):
        """Decorate an async method so it only runs when the database is connected.

        Defined without ``self`` so it can be applied as ``@use_session``
        inside this class body; from outside (e.g. subclasses) reference it as
        ``Database.use_session``.

        Raises:
            RuntimeError: (from the wrapped call) if no session is available.
        """

        @wraps(func)
        async def wrapper(self, *args, **kwargs):
            if self._check_session() is None:
                raise RuntimeError(
                    f"{func.__qualname__} requires a connected database; "
                    "call connect() first."
                )
            return await func(self, *args, **kwargs)

        return wrapper

    async def connect(self, db_path: str, clear_static_data: bool = False) -> None:
        """Create the engine and session factory, enable pg_trgm, create tables.

        Args:
            db_path: database URL handed to ``create_async_engine``.
            clear_static_data: if true, drop every table registered on
                ``Base.metadata`` before recreating them.
        """
        from sqlalchemy import text

        # TODO: Preserve UserLastStopSearchResults table from drop.
        self._engine = create_async_engine(db_path)
        self._session_maker = sessionmaker(
            self._engine, expire_on_commit=False, class_=AsyncSession
        )

        async with self._engine.begin() as conn:
            # Run the extension DDL inside this transaction so it is committed
            # together with (and before) the schema DDL. Issuing it through
            # the shared session left it in a never-committed transaction, and
            # SQLAlchemy requires text() for textual SQL.
            await conn.execute(text("CREATE EXTENSION IF NOT EXISTS pg_trgm;"))
            if clear_static_data:
                await conn.run_sync(Base.metadata.drop_all)
            await conn.run_sync(Base.metadata.create_all)

    async def disconnect(self) -> None:
        """Close the shared session (if any) and dispose of the engine (if any).

        Safe to call on an instance whose :meth:`connect` never ran.
        """
        if self._session is not None:
            await self._session.close()
            self._session = None
        if self._engine is not None:
            await self._engine.dispose()
|