Use of db merge when adds fails due to single key violations

This commit is contained in:
2023-06-11 22:28:15 +02:00
parent c42b687870
commit 2eaf0f4ed5

View File

@@ -1,7 +1,7 @@
from __future__ import annotations from __future__ import annotations
from logging import getLogger from logging import getLogger
from typing import Iterable, Self, TYPE_CHECKING from typing import Self, Sequence, TYPE_CHECKING
from sqlalchemy import select from sqlalchemy import select
from sqlalchemy.exc import IntegrityError from sqlalchemy.exc import IntegrityError
@@ -17,21 +17,34 @@ class Base(DeclarativeBase):
db: Database | None = None db: Database | None = None
@classmethod @classmethod
async def add(cls, objs: Self | Iterable[Self]) -> bool: async def add(cls, objs: Sequence[Self]) -> bool:
if cls.db is not None and (session := await cls.db.get_session()) is not None:
try:
async with session.begin():
session.add_all(objs)
except IntegrityError as err:
logger.warning(err)
return await cls.merge(objs)
except AttributeError as err:
logger.error(err)
return False
return True
@classmethod
async def merge(cls, objs: Sequence[Self]) -> bool:
if cls.db is not None and (session := await cls.db.get_session()) is not None: if cls.db is not None and (session := await cls.db.get_session()) is not None:
async with session.begin(): async with session.begin():
try: for obj in objs:
if isinstance(objs, Iterable): await session.merge(obj)
session.add_all(objs)
else:
session.add(objs)
except (AttributeError, IntegrityError) as err: return True
logger.error(err)
return False
return True return False
@classmethod @classmethod
async def get_by_id(cls, id_: int | str) -> Self | None: async def get_by_id(cls, id_: int | str) -> Self | None: