40 lines
1.1 KiB
Python
40 lines
1.1 KiB
Python
from __future__ import annotations
|
|
|
|
from typing import Iterable, Self, TYPE_CHECKING
|
|
|
|
from sqlalchemy import select
|
|
from sqlalchemy.exc import IntegrityError
|
|
from sqlalchemy.orm import DeclarativeBase
|
|
|
|
if TYPE_CHECKING:
|
|
from .db import Database
|
|
|
|
|
|
class Base(DeclarativeBase):
    """Declarative base class that adds shared async persistence helpers.

    Both helpers reach the database through the class-level ``db`` handle
    and use its ``session`` attribute. While ``db`` is ``None`` the helpers
    report failure (``False`` / ``None``) instead of raising.
    """

    # Shared database handle; remains ``None`` until one is assigned.
    db: Database | None = None

    @classmethod
    async def add(cls, stops: Self | Iterable[Self]) -> bool:
        """Add one instance, or an iterable of instances, and commit.

        Args:
            stops: A single instance or an iterable of instances to add to
                the session.

        Returns:
            ``True`` if the commit succeeded. ``False`` if an
            ``AttributeError`` occurred (for example ``db`` is still
            ``None``) or the commit raised ``IntegrityError``.
        """
        try:
            session = cls.db.session  # type: ignore
            if isinstance(stops, Iterable):
                session.add_all(stops)
            else:
                session.add(stops)
            await session.commit()
        except IntegrityError as err:
            print(err)
            # Follow SQLAlchemy's commit/rollback pattern: explicitly roll
            # back after a failed commit so the session is left clean for
            # later calls instead of carrying the failed transaction state.
            await cls.db.session.rollback()  # type: ignore
            return False
        except AttributeError as err:
            # Typically ``db`` is None, so there is no session to roll back.
            print(err)
            return False

        return True

    @classmethod
    async def get_by_id(cls, id_: int | str) -> Self | None:
        """Return the instance whose ``id`` equals ``id_``.

        Args:
            id_: Value compared against the class's ``id`` attribute.

        Returns:
            The matching instance, or ``None`` if no row matches or an
            ``AttributeError`` occurred (for example ``db`` is still
            ``None``).
        """
        try:
            stmt = select(cls).where(cls.id == id_)  # type: ignore
            res = await cls.db.session.execute(stmt)  # type: ignore
            return res.scalar_one_or_none()
        except AttributeError as err:
            print(err)
            return None
|