Files
carrramba-encore-rate/backend/api/db/base_class.py
2023-09-20 22:08:32 +02:00

59 lines
1.6 KiB
Python

from __future__ import annotations
from logging import getLogger
from typing import Self, Sequence, TYPE_CHECKING
from sqlalchemy import select
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import DeclarativeBase
if TYPE_CHECKING:
    # Annotations in this module are lazy (``from __future__ import
    # annotations``), so ``Database`` is only needed by type checkers; not
    # importing it at runtime presumably avoids an import cycle with the
    # ``db`` module (NOTE(review): confirm).
    from .db import Database

# Logger named after this module's import path.
logger = getLogger(name=__name__)
class Base(DeclarativeBase):
    """Declarative base for the ORM models, with async persistence helpers.

    Every helper is a classmethod that obtains a session through the
    class-level ``db`` handle. ``db`` defaults to ``None``; while it is unset
    (or ``db.get_session()`` yields ``None``) the helpers report failure
    (``False`` / ``None``) instead of raising.

    NOTE(review): sessions obtained from ``db.get_session()`` are never
    explicitly closed here -- confirm whether ``get_session`` expects its
    callers to close them.
    """

    # Database handle shared by all model classes; ``None`` until assigned.
    db: Database | None = None

    @classmethod
    async def add(cls, objs: Sequence[Self]) -> bool:
        """Insert ``objs`` in a single transaction.

        If the insert raises ``IntegrityError`` (e.g. a row with the same key
        already exists), the transaction is rolled back and the objects are
        persisted through :meth:`merge` instead.

        Args:
            objs: Model instances to persist.

        Returns:
            ``True`` once the objects are committed (directly or through
            :meth:`merge`); ``False`` when no database/session is available
            or an ``AttributeError`` aborts the insert.
        """
        if cls.db is None or (session := await cls.db.get_session()) is None:
            # Nothing can be written without a session: report failure, as
            # ``merge`` does, instead of claiming success.
            return False

        try:
            # ``begin()`` commits on clean exit and rolls back if the block
            # (including the final flush) raises.
            async with session.begin():
                session.add_all(objs)
        except IntegrityError as err:
            logger.warning(err)
            return await cls.merge(objs)
        except AttributeError as err:
            # NOTE(review): the reason for catching AttributeError is not
            # visible in this file; it may mask programming errors -- confirm.
            logger.error(err)
            return False

        return True

    @classmethod
    async def merge(cls, objs: Sequence[Self]) -> bool:
        """Merge ``objs`` into the database in a single transaction.

        Each instance is passed to ``session.merge``, which reconciles it with
        the persisted row sharing its primary key (inserting when none exists).

        Args:
            objs: Model instances to merge.

        Returns:
            ``True`` once the transaction commits; ``False`` when no
            database/session is available. Database errors propagate.
        """
        if cls.db is None or (session := await cls.db.get_session()) is None:
            return False

        async with session.begin():
            for obj in objs:
                await session.merge(obj)
        return True

    @classmethod
    async def get_by_id(cls, id_: int | str) -> Self | None:
        """Fetch the instance whose ``id`` attribute equals ``id_``.

        Relies on the concrete model defining an ``id`` column; ``Base``
        itself does not declare one.

        Args:
            id_: Value compared against ``cls.id``.

        Returns:
            The matching instance, or ``None`` when no row matches or no
            database/session is available. ``scalar_one_or_none`` raises if
            several rows match.
        """
        if cls.db is None or (session := await cls.db.get_session()) is None:
            return None

        stmt = select(cls).where(cls.id == id_)
        # NOTE(review): the instance is returned after the transaction
        # commits; if the session factory keeps SQLAlchemy's default
        # expire_on_commit=True, later attribute access may trigger a lazy
        # refresh that fails under asyncio -- confirm the session settings.
        async with session.begin():
            res = await session.execute(stmt)
            return res.scalar_one_or_none()