"""Declarative base class with small async persistence helpers.

Model classes subclass :class:`Base`; the helpers obtain a session from the
``Database`` instance assigned to ``Base.db``.
"""

from __future__ import annotations

from logging import getLogger
from typing import TYPE_CHECKING, ClassVar, Self, Sequence

from sqlalchemy import select
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import DeclarativeBase

if TYPE_CHECKING:
    from .db import Database

logger = getLogger(__name__)


class Base(DeclarativeBase):
    """Declarative base whose subclasses get async add/merge/lookup helpers.

    ``db`` must be assigned a ``Database`` before the helpers do anything;
    while it is ``None`` (or it hands out no session) every helper is a no-op
    that reports failure (``False`` / ``None``).
    """

    # Shared by every model class and only ever read as ``cls.db``. ClassVar
    # keeps SQLAlchemy's annotation scanning and type checkers from treating
    # it as a per-instance (mapped) attribute.
    db: ClassVar[Database | None] = None

    @classmethod
    async def _get_session(cls):
        """Return a session from ``cls.db``, or ``None`` if none is available."""
        # NOTE(review): these helpers never close the sessions they obtain;
        # confirm that Database.get_session manages their lifetime.
        if cls.db is None:
            return None
        return await cls.db.get_session()

    @classmethod
    async def add(cls, objs: Sequence[Self]) -> bool:
        """Insert *objs* in a single transaction, falling back to a merge.

        If the insert raises ``IntegrityError`` (e.g. a row with the same key
        already exists), the transaction is rolled back by ``session.begin()``
        and the objects are persisted via :meth:`merge` instead.

        Returns:
            ``True`` if the insert committed; the result of :meth:`merge` on
            the ``IntegrityError`` fallback path; ``False`` if no session was
            available or an ``AttributeError`` was raised while adding.
        """
        session = await cls._get_session()
        if session is None:
            # Previously fell through without an explicit result; report
            # failure, consistent with merge().
            return False
        try:
            async with session.begin():
                session.add_all(objs)
        except IntegrityError as err:
            logger.warning(err)
            return await cls.merge(objs)
        except AttributeError as err:
            # NOTE(review): presumably guards against non-model objects in
            # *objs*; confirm which call is expected to raise this.
            logger.error(err)
            return False
        return True

    @classmethod
    async def merge(cls, objs: Sequence[Self]) -> bool:
        """Merge each of *objs* into the database in a single transaction.

        Returns:
            ``True`` once the transaction commits, ``False`` if no session was
            available. Database errors raised by the merge propagate.
        """
        session = await cls._get_session()
        if session is None:
            return False
        async with session.begin():
            for obj in objs:
                await session.merge(obj)
        return True

    @classmethod
    async def get_by_id(cls, id_: int | str) -> Self | None:
        """Return the instance of this model whose ``id`` equals *id_*.

        Relies on the concrete subclass defining an ``id`` column.

        Returns:
            The matching instance, or ``None`` if no row matches or no
            session was available.

        Raises:
            sqlalchemy.exc.MultipleResultsFound: if more than one row matches.
        """
        session = await cls._get_session()
        if session is None:
            return None
        async with session.begin():
            stmt = select(cls).where(cls.id == id_)
            res = await session.execute(stmt)
            return res.scalar_one_or_none()