diff --git a/backend/backend/db/base_class.py b/backend/backend/db/base_class.py index 71b9703..69b1d62 100644 --- a/backend/backend/db/base_class.py +++ b/backend/backend/db/base_class.py @@ -1,7 +1,7 @@ from __future__ import annotations from logging import getLogger -from typing import Iterable, Self, TYPE_CHECKING +from typing import Self, Sequence, TYPE_CHECKING from sqlalchemy import select from sqlalchemy.exc import IntegrityError @@ -17,21 +17,34 @@ class Base(DeclarativeBase): db: Database | None = None @classmethod - async def add(cls, objs: Self | Iterable[Self]) -> bool: + async def add(cls, objs: Sequence[Self]) -> bool: + if cls.db is not None and (session := await cls.db.get_session()) is not None: + + try: + async with session.begin(): + session.add_all(objs) + + except IntegrityError as err: + logger.warning(err) + return await cls.merge(objs) + + except AttributeError as err: + logger.error(err) + return False + + return True + + @classmethod + async def merge(cls, objs: Sequence[Self]) -> bool: if cls.db is not None and (session := await cls.db.get_session()) is not None: async with session.begin(): - try: - if isinstance(objs, Iterable): - session.add_all(objs) - else: - session.add(objs) + for obj in objs: + await session.merge(obj) - except (AttributeError, IntegrityError) as err: - logger.error(err) - return False + return True - return True + return False @classmethod async def get_by_id(cls, id_: int | str) -> Self | None: