🎉 First commit !!!
This commit is contained in:
80
backend/idfm_matrix_backend/db/db.py
Normal file
80
backend/idfm_matrix_backend/db/db.py
Normal file
@@ -0,0 +1,80 @@
|
||||
from asyncio import gather as asyncio_gather
|
||||
from functools import wraps
|
||||
from pathlib import Path
|
||||
from time import time
|
||||
from typing import Callable, Iterable, Optional
|
||||
|
||||
from rich import print
|
||||
from sqlalchemy import event, select, tuple_
|
||||
from sqlalchemy.engine import Engine
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
|
||||
from sqlalchemy.orm import (
|
||||
selectinload,
|
||||
sessionmaker,
|
||||
with_polymorphic,
|
||||
)
|
||||
from sqlalchemy.orm.attributes import set_committed_value
|
||||
|
||||
from .base_class import Base
|
||||
|
||||
|
||||
# import logging
|
||||
|
||||
# logging.basicConfig()
|
||||
# logger = logging.getLogger("bot.sqltime")
|
||||
# logger.setLevel(logging.DEBUG)
|
||||
|
||||
|
||||
# @event.listens_for(Engine, "before_cursor_execute")
|
||||
# def before_cursor_execute(conn, cursor, statement, parameters, context, executemany):
|
||||
# conn.info.setdefault("query_start_time", []).append(time())
|
||||
# logger.debug("Start Query: %s", statement)
|
||||
|
||||
|
||||
# @event.listens_for(Engine, "after_cursor_execute")
|
||||
# def after_cursor_execute(conn, cursor, statement, parameters, context, executemany):
|
||||
# total = time() - conn.info["query_start_time"].pop(-1)
|
||||
# logger.debug("Query Complete!")
|
||||
# logger.debug("Total Time: %f", total)
|
||||
|
||||
|
||||
class Database:
|
||||
def __init__(self) -> None:
|
||||
self._engine = None
|
||||
self._session_maker = None
|
||||
self._session = None
|
||||
|
||||
@property
|
||||
def session(self) -> None:
|
||||
if self._session is None:
|
||||
self._session = self._session_maker()
|
||||
return self._session
|
||||
|
||||
def use_session(func: Callable):
|
||||
@wraps(func)
|
||||
async def wrapper(self, *args, **kwargs):
|
||||
if self._check_session() is not None:
|
||||
return await func(self, *args, **kwargs)
|
||||
# TODO: Raise an exception ?
|
||||
|
||||
return wrapper
|
||||
|
||||
async def connect(self, db_path: str, clear_static_data: bool = False) -> None:
|
||||
# TODO: Preserve UserLastStopSearchResults table from drop.
|
||||
self._engine = create_async_engine(db_path)
|
||||
self._session_maker = sessionmaker(
|
||||
self._engine, expire_on_commit=False, class_=AsyncSession
|
||||
)
|
||||
await self.session.execute("CREATE EXTENSION IF NOT EXISTS pg_trgm;")
|
||||
|
||||
async with self._engine.begin() as conn:
|
||||
if clear_static_data:
|
||||
await conn.run_sync(Base.metadata.drop_all)
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
|
||||
async def disconnect(self) -> None:
|
||||
if self._session is not None:
|
||||
await self._session.close()
|
||||
self._session = None
|
||||
await self._engine.dispose()
|
Reference in New Issue
Block a user