59 lines
1.6 KiB
Python
59 lines
1.6 KiB
Python
from __future__ import annotations
|
|
|
|
from logging import getLogger
|
|
from typing import Self, Sequence, TYPE_CHECKING
|
|
|
|
from sqlalchemy import select
|
|
from sqlalchemy.exc import IntegrityError
|
|
from sqlalchemy.orm import DeclarativeBase
|
|
|
|
if TYPE_CHECKING:
|
|
from .db import Database

# Module-level logger, named after this module's dotted import path so its
# records can be filtered/configured per module.
logger = getLogger(name=__name__)


class Base(DeclarativeBase):
    """Declarative base adding small async persistence helpers to every model.

    ``db`` must be assigned (e.g. ``Base.db = database``) before the helpers
    can touch the database. While ``db`` is ``None``, or while
    ``db.get_session()`` returns ``None``, ``add``/``merge`` return ``False``
    and ``get_by_id`` returns ``None`` without doing anything.
    """

    # Shared database handle. It is read through ``cls``, so subclasses see
    # whatever was assigned on ``Base`` unless they override it.
    db: Database | None = None

    @classmethod
    async def add(cls, objs: Sequence[Self]) -> bool:
        """Insert ``objs`` in a single transaction.

        If the insert raises ``IntegrityError`` (e.g. a row already exists),
        the error is logged and the objects are stored via :meth:`merge`
        instead.

        Args:
            objs: Model instances to persist.

        Returns:
            ``True`` if the objects were stored directly, or the result of
            :meth:`merge` after an ``IntegrityError``. ``False`` if no
            database/session is available or an ``AttributeError`` was raised
            while adding.
        """
        if cls.db is None or (session := await cls.db.get_session()) is None:
            # Report failure like ``merge`` does. Previously this path never
            # returned ``False``: it either claimed success or returned None.
            return False

        # NOTE(review): the session obtained here is never closed in this
        # class; confirm that Database.get_session() owns its lifetime.
        try:
            async with session.begin():
                session.add_all(objs)
        except IntegrityError as err:
            logger.warning(err)
            # ``session.begin()`` rolled the failed transaction back on exit,
            # so fall back to merging each object instead of inserting.
            return await cls.merge(objs)
        except AttributeError as err:
            # Most likely a programming error (e.g. a non-model object), so
            # keep the traceback in the log.
            logger.exception(err)
            return False

        return True

    @classmethod
    async def merge(cls, objs: Sequence[Self]) -> bool:
        """Merge each of ``objs`` into the database in a single transaction.

        Uses ``session.merge``, which reconciles each object with any existing
        row sharing its primary key. Exceptions raised while merging or
        committing propagate to the caller.

        Args:
            objs: Model instances to merge.

        Returns:
            ``True`` once the transaction has completed, ``False`` if no
            database/session is available.
        """
        if cls.db is None or (session := await cls.db.get_session()) is None:
            return False

        async with session.begin():
            for obj in objs:
                await session.merge(obj)

        return True

    @classmethod
    async def get_by_id(cls, id_: int | str) -> Self | None:
        """Fetch the instance of this model whose ``id`` equals ``id_``.

        Args:
            id_: Value compared against the model's ``id`` column.

        Returns:
            The matching instance, or ``None`` if no row matches or no
            database/session is available.

        Raises:
            sqlalchemy.exc.MultipleResultsFound: if more than one row matches
                (raised by ``scalar_one_or_none``).
        """
        if cls.db is None or (session := await cls.db.get_session()) is None:
            return None

        async with session.begin():
            # Assumes each concrete model declares an ``id`` column; ``Base``
            # itself does not define one.
            stmt = select(cls).where(cls.id == id_)
            result = await session.execute(stmt)
            return result.scalar_one_or_none()