🎨 Reorganize back-end code
This commit is contained in:
6
backend/api/db/__init__.py
Normal file
6
backend/api/db/__init__.py
Normal file
@@ -0,0 +1,6 @@
|
||||
from .db import Database
|
||||
from .base_class import Base
|
||||
|
||||
__all__ = ["Base"]
|
||||
|
||||
db = Database()
|
58
backend/api/db/base_class.py
Normal file
58
backend/api/db/base_class.py
Normal file
@@ -0,0 +1,58 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from logging import getLogger
|
||||
from typing import Self, Sequence, TYPE_CHECKING
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.orm import DeclarativeBase
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .db import Database
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
|
||||
class Base(DeclarativeBase):
|
||||
db: Database | None = None
|
||||
|
||||
@classmethod
|
||||
async def add(cls, objs: Sequence[Self]) -> bool:
|
||||
if cls.db is not None and (session := await cls.db.get_session()) is not None:
|
||||
|
||||
try:
|
||||
async with session.begin():
|
||||
session.add_all(objs)
|
||||
|
||||
except IntegrityError as err:
|
||||
logger.warning(err)
|
||||
return await cls.merge(objs)
|
||||
|
||||
except AttributeError as err:
|
||||
logger.error(err)
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
async def merge(cls, objs: Sequence[Self]) -> bool:
|
||||
if cls.db is not None and (session := await cls.db.get_session()) is not None:
|
||||
|
||||
async with session.begin():
|
||||
for obj in objs:
|
||||
await session.merge(obj)
|
||||
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
async def get_by_id(cls, id_: int | str) -> Self | None:
|
||||
if cls.db is not None and (session := await cls.db.get_session()) is not None:
|
||||
|
||||
async with session.begin():
|
||||
stmt = select(cls).where(cls.id == id_)
|
||||
res = await session.execute(stmt)
|
||||
return res.scalar_one_or_none()
|
||||
|
||||
return None
|
76
backend/api/db/db.py
Normal file
76
backend/api/db/db.py
Normal file
@@ -0,0 +1,76 @@
|
||||
from asyncio import sleep
|
||||
from logging import getLogger
|
||||
from typing import Annotated, AsyncIterator
|
||||
|
||||
from fastapi import Depends
|
||||
from opentelemetry.instrumentation.sqlalchemy import SQLAlchemyInstrumentor
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.exc import OperationalError, SQLAlchemyError
|
||||
from sqlalchemy.ext.asyncio import (
|
||||
async_sessionmaker,
|
||||
AsyncEngine,
|
||||
AsyncSession,
|
||||
create_async_engine,
|
||||
)
|
||||
|
||||
from .base_class import Base
|
||||
from settings import DatabaseSettings
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
|
||||
class Database:
|
||||
def __init__(self) -> None:
|
||||
self._async_engine: AsyncEngine | None = None
|
||||
self._async_session_local: async_sessionmaker[AsyncSession] | None = None
|
||||
|
||||
async def get_session(self) -> AsyncSession | None:
|
||||
try:
|
||||
return self._async_session_local() # type: ignore
|
||||
|
||||
except (SQLAlchemyError, AttributeError) as e:
|
||||
logger.exception(e)
|
||||
|
||||
return None
|
||||
|
||||
# TODO: Preserve UserLastStopSearchResults table from drop.
|
||||
async def connect(
|
||||
self, settings: DatabaseSettings, clear_static_data: bool = False
|
||||
) -> bool:
|
||||
password = settings.password
|
||||
path = (
|
||||
f"{settings.driver}://{settings.user}:"
|
||||
f"{password.get_secret_value() if password is not None else ''}"
|
||||
f"@{settings.host}:{settings.port}/{settings.name}"
|
||||
)
|
||||
self._async_engine = create_async_engine(
|
||||
path, pool_pre_ping=True, pool_size=10, max_overflow=20
|
||||
)
|
||||
|
||||
if self._async_engine is not None:
|
||||
SQLAlchemyInstrumentor().instrument(engine=self._async_engine.sync_engine)
|
||||
|
||||
self._async_session_local = async_sessionmaker(
|
||||
bind=self._async_engine,
|
||||
# autoflush=False,
|
||||
expire_on_commit=False,
|
||||
class_=AsyncSession,
|
||||
)
|
||||
|
||||
ret = False
|
||||
while not ret:
|
||||
try:
|
||||
async with self._async_engine.begin() as session:
|
||||
if clear_static_data:
|
||||
await session.run_sync(Base.metadata.drop_all)
|
||||
await session.run_sync(Base.metadata.create_all)
|
||||
ret = True
|
||||
except OperationalError as err:
|
||||
logger.error(err)
|
||||
await sleep(1)
|
||||
|
||||
return True
|
||||
|
||||
async def disconnect(self) -> None:
|
||||
if self._async_engine is not None:
|
||||
await self._async_engine.dispose()
|
Reference in New Issue
Block a user