From 2eaf0f4ed585497469867389f388ab93c54c7dfb Mon Sep 17 00:00:00 2001 From: Adrien Date: Sun, 11 Jun 2023 22:28:15 +0200 Subject: [PATCH] =?UTF-8?q?=E2=9C=A8=20Use=20of=20db=20merge=20when=20adds?= =?UTF-8?q?=20fails=20due=20to=20single=20key=20violations?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- backend/backend/db/base_class.py | 35 ++++++++++++++++++++++---------- 1 file changed, 24 insertions(+), 11 deletions(-) diff --git a/backend/backend/db/base_class.py b/backend/backend/db/base_class.py index 71b9703..69b1d62 100644 --- a/backend/backend/db/base_class.py +++ b/backend/backend/db/base_class.py @@ -1,7 +1,7 @@ from __future__ import annotations from logging import getLogger -from typing import Iterable, Self, TYPE_CHECKING +from typing import Self, Sequence, TYPE_CHECKING from sqlalchemy import select from sqlalchemy.exc import IntegrityError @@ -17,21 +17,34 @@ class Base(DeclarativeBase): db: Database | None = None @classmethod - async def add(cls, objs: Self | Iterable[Self]) -> bool: + async def add(cls, objs: Sequence[Self]) -> bool: + if cls.db is not None and (session := await cls.db.get_session()) is not None: + + try: + async with session.begin(): + session.add_all(objs) + + except IntegrityError as err: + logger.warning(err) + return await cls.merge(objs) + + except AttributeError as err: + logger.error(err) + return False + + return True + + @classmethod + async def merge(cls, objs: Sequence[Self]) -> bool: if cls.db is not None and (session := await cls.db.get_session()) is not None: async with session.begin(): - try: - if isinstance(objs, Iterable): - session.add_all(objs) - else: - session.add(objs) + for obj in objs: + await session.merge(obj) - except (AttributeError, IntegrityError) as err: - logger.error(err) - return False + return True - return True + return False @classmethod async def get_by_id(cls, id_: int | str) -> Self | None: