🎉 First commit !!!

This commit is contained in:
2023-01-22 16:53:45 +01:00
commit dde835760a
68 changed files with 3250 additions and 0 deletions

View File

@@ -0,0 +1,4 @@
from .db import Database
from .base_class import Base
db = Database()

View File

@@ -0,0 +1,34 @@
from collections.abc import Iterable
from sqlalchemy import select
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import declarative_base
from typing import Iterable, Self
Base = declarative_base()
Base.db = None
async def base_add(cls, stops: Self | Iterable[Self]) -> bool:
try:
method = (
cls.db.session.add_all
if isinstance(stops, Iterable)
else cls.db.session.add
)
method(stops)
await cls.db.session.commit()
except IntegrityError as err:
print(err)
Base.add = classmethod(base_add)
async def base_get_by_id(cls, id_: int | str) -> None | Base:
res = await cls.db.session.execute(select(cls).where(cls.id == id_))
element = res.scalar_one_or_none()
return element
Base.get_by_id = classmethod(base_get_by_id)

View File

@@ -0,0 +1,80 @@
from asyncio import gather as asyncio_gather
from functools import wraps
from pathlib import Path
from time import time
from typing import Callable, Iterable, Optional
from rich import print
from sqlalchemy import event, select, tuple_
from sqlalchemy.engine import Engine
from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
from sqlalchemy.orm import (
selectinload,
sessionmaker,
with_polymorphic,
)
from sqlalchemy.orm.attributes import set_committed_value
from .base_class import Base
# import logging
# logging.basicConfig()
# logger = logging.getLogger("bot.sqltime")
# logger.setLevel(logging.DEBUG)
# @event.listens_for(Engine, "before_cursor_execute")
# def before_cursor_execute(conn, cursor, statement, parameters, context, executemany):
# conn.info.setdefault("query_start_time", []).append(time())
# logger.debug("Start Query: %s", statement)
# @event.listens_for(Engine, "after_cursor_execute")
# def after_cursor_execute(conn, cursor, statement, parameters, context, executemany):
# total = time() - conn.info["query_start_time"].pop(-1)
# logger.debug("Query Complete!")
# logger.debug("Total Time: %f", total)
class Database:
def __init__(self) -> None:
self._engine = None
self._session_maker = None
self._session = None
@property
def session(self) -> None:
if self._session is None:
self._session = self._session_maker()
return self._session
def use_session(func: Callable):
@wraps(func)
async def wrapper(self, *args, **kwargs):
if self._check_session() is not None:
return await func(self, *args, **kwargs)
# TODO: Raise an exception ?
return wrapper
async def connect(self, db_path: str, clear_static_data: bool = False) -> None:
# TODO: Preserve UserLastStopSearchResults table from drop.
self._engine = create_async_engine(db_path)
self._session_maker = sessionmaker(
self._engine, expire_on_commit=False, class_=AsyncSession
)
await self.session.execute("CREATE EXTENSION IF NOT EXISTS pg_trgm;")
async with self._engine.begin() as conn:
if clear_static_data:
await conn.run_sync(Base.metadata.drop_all)
await conn.run_sync(Base.metadata.create_all)
async def disconnect(self) -> None:
if self._session is not None:
await self._session.close()
self._session = None
await self._engine.dispose()