Files
carrramba-encore-rate/backend/db_updater.py

576 lines
18 KiB
Python
Executable File

#!/usr/bin/env python3
from asyncio import run, gather
from logging import getLogger, INFO, Handler as LoggingHandler, NOTSET
from itertools import islice
from time import time
from os import environ
from typing import Callable, Iterable, List, Type
from aiofiles.tempfile import NamedTemporaryFile
from aiohttp import ClientSession
from msgspec import ValidationError
from msgspec.json import Decoder
from pyproj import Transformer
from shapefile import Reader as ShapeFileReader, ShapeRecord # type: ignore
from tqdm import tqdm
from yaml import safe_load
from api.db import Base, db, Database
from api.models import ConnectionArea, Line, LinePicto, Stop, StopArea, StopShape
from api.idfm_interface.idfm_types import (
ConnectionArea as IdfmConnectionArea,
IdfmLineState,
Line as IdfmLine,
LinePicto as IdfmPicto,
IdfmState,
Stop as IdfmStop,
StopArea as IdfmStopArea,
StopAreaStopAssociation,
StopAreaType,
StopLineAsso as IdfmStopLineAsso,
TransportMode,
)
from api.idfm_interface.ratp_types import Picto as RatpPicto
from api.settings import Settings
# Path of the YAML settings file read by load_settings(); can be overridden
# with the CONFIG_PATH environment variable.
CONFIG_PATH = environ.get("CONFIG_PATH", "./config.sample.yaml")
# Number of records accumulated before each database insertion.
BATCH_SIZE = 1000
# Root of the IDFM (iledefrance-mobilites.fr) open-data portal; the dataset
# URLs below are all derived from it.
IDFM_ROOT_URL = "https://data.iledefrance-mobilites.fr/explore/dataset"
# JSON export of the connection areas dataset.
IDFM_CONNECTION_AREAS_URL = (
f"{IDFM_ROOT_URL}/zones-de-correspondance/download/?format=json"
)
# JSON export of the lines dataset.
IDFM_LINES_URL = f"{IDFM_ROOT_URL}/referentiel-des-lignes/download/?format=json"
# Base URL used by format_idfm_pictos() to build pictogram download links.
IDFM_PICTO_URL = f"{IDFM_ROOT_URL}/referentiel-des-lignes/files"
# JSON export of the stop areas dataset.
IDFM_STOP_AREAS_URL = f"{IDFM_ROOT_URL}/zones-d-arrets/download/?format=json"
# Zipped shapefiles fetched by load_idfm_stop_shapes().
IDFM_STOP_SHAPES_URL = "https://eu.ftp.opendatasoft.com/stif/Reflex/REF_ArR.zip"
IDFM_STOP_AREA_SHAPES_URL = "https://eu.ftp.opendatasoft.com/stif/Reflex/REF_ZdA.zip"
# JSON export of the stop <-> stop area / connection area relations.
IDFM_STOP_STOP_AREAS_ASSOS_URL = f"{IDFM_ROOT_URL}/relations/download/?format=json"
# JSON export of the stop <-> line associations.
IDFM_STOPS_LINES_ASSOS_URL = f"{IDFM_ROOT_URL}/arrets-lignes/download/?format=json"
# JSON export of the stops dataset.
IDFM_STOPS_URL = f"{IDFM_ROOT_URL}/arrets/download/?format=json"
# Root of the RATP open-data API and JSON export of its line pictograms.
RATP_ROOT_URL = "https://data.ratp.fr/api/explore/v2.1/catalog/datasets"
RATP_PICTOS_URL = (
f"{RATP_ROOT_URL}"
"/pictogrammes-des-lignes-de-metro-rer-tramway-bus-et-noctilien/exports/json?lang=fr"
)
# Adapted from https://stackoverflow.com/a/38739634
class TqdmLoggingHandler(LoggingHandler):
    """Logging handler that prints formatted records through ``tqdm.write``."""

    def __init__(self, level=NOTSET):
        """Create the handler, accepting records from ``level`` upwards."""
        super().__init__(level=level)

    def emit(self, record):
        """Format ``record`` and print it with ``tqdm.write``.

        Any exception raised while formatting or writing is delegated to
        ``handleError``.
        """
        try:
            tqdm.write(self.format(record))
            self.flush()
        except Exception:
            self.handleError(record)
# Module logger at INFO level; its records are printed with tqdm.write by
# TqdmLoggingHandler.
logger = getLogger(__name__)
logger.setLevel(INFO)
logger.addHandler(TqdmLoggingHandler())
# Coordinate transformer from EPSG:2154 to EPSG:3857, shared by the format_*
# helpers.
epsg2154_epsg3857_transformer = Transformer.from_crs(2154, 3857)
# msgspec JSON decoders, one per downloaded dataset; each decodes a whole
# payload into a list of typed records.
json_stops_decoder = Decoder(type=List[IdfmStop])
json_stop_areas_decoder = Decoder(type=List[IdfmStopArea])
json_connection_areas_decoder = Decoder(type=List[IdfmConnectionArea])
json_lines_decoder = Decoder(type=List[IdfmLine])
json_stops_lines_assos_decoder = Decoder(type=List[IdfmStopLineAsso])
json_ratp_pictos_decoder = Decoder(type=List[RatpPicto])
json_stop_area_stop_asso_decoder = Decoder(type=List[StopAreaStopAssociation])
def format_idfm_pictos(*pictos: IdfmPicto) -> Iterable[LinePicto]:
    """Convert IDFM pictogram records into ``LinePicto`` models.

    Args:
        *pictos: Pictogram records taken from the IDFM lines dataset.

    Returns:
        A list with one ``LinePicto`` per input, in input order. Each ``url``
        is built from ``IDFM_PICTO_URL`` and the pictogram id.
    """
    return [
        LinePicto(
            id=picto.id_,
            mime_type=picto.mimetype,
            height_px=picto.height,
            width_px=picto.width,
            filename=picto.filename,
            url=f"{IDFM_PICTO_URL}/{picto.id_}/download",
            thumbnail=picto.thumbnail,
            format=picto.format,
        )
        for picto in pictos
    ]
def format_ratp_pictos(*pictos: RatpPicto) -> Iterable[tuple[str, LinePicto]]:
    """Convert RATP pictogram records into ``(line name, LinePicto)`` pairs.

    Records whose ``noms_des_fichiers`` is ``None`` are skipped.

    Args:
        *pictos: Pictogram records decoded from the RATP dataset.

    Returns:
        A list of ``(indices_commerciaux, LinePicto)`` tuples, in input order.
    """
    pairs = []
    for picto in pictos:
        fields = picto.noms_des_fichiers
        if fields is None:
            # No file description: nothing to build a LinePicto from.
            continue

        # NOTE(review): RATP_PICTOS_URL is the JSON export endpoint (it ends
        # with a query string), so this download URL looks malformed.
        # Confirm the expected file URL scheme.
        line_picto = LinePicto(
            id=fields.id_,
            mime_type=f"image/{fields.format.lower()}",
            height_px=fields.height,
            width_px=fields.width,
            filename=fields.filename,
            url=f"{RATP_PICTOS_URL}/{fields.id_}/download",
            thumbnail=fields.thumbnail,
            format=fields.format,
        )
        pairs.append((picto.indices_commerciaux, line_picto))
    return pairs
def format_idfm_lines(*lines: IdfmLine) -> Iterable[Line]:
    """Convert IDFM line records into ``Line`` models.

    Lines whose id cannot be converted to an integer (after stripping an
    optional leading ``"C"``) are logged and skipped. An operator reference
    that cannot be converted is logged and replaced by ``0``.

    Args:
        *lines: Line records decoded from the IDFM lines dataset.

    Returns:
        A list of ``Line`` models, in input order, without the skipped lines.
    """
    ret = []
    optional_value = IdfmLine.optional_value

    for line in lines:
        fields = line.fields
        line_id = fields.id_line

        try:
            # startswith() instead of line_id[0]: an empty id then reaches
            # int("") and is skipped through the ValueError path below,
            # rather than aborting the whole batch with an IndexError.
            formatted_line_id = int(
                line_id[1:] if line_id.startswith("C") else line_id
            )
        except ValueError:
            logger.warning("Unable to format %s line id.", line_id)
            continue

        try:
            operator_id = int(fields.operatorref)  # type: ignore
        except (ValueError, TypeError):
            logger.warning("Unable to format %s operator id.", fields.operatorref)
            operator_id = 0

        ret.append(
            Line(
                id=formatted_line_id,
                short_name=fields.shortname_line,
                name=fields.name_line,
                status=IdfmLineState(fields.status.value),
                transport_mode=TransportMode(fields.transportmode.value),
                transport_submode=optional_value(fields.transportsubmode),
                network_name=optional_value(fields.networkname),
                group_of_lines_id=optional_value(fields.id_groupoflines),
                group_of_lines_shortname=optional_value(fields.shortname_groupoflines),
                colour_web_hexa=fields.colourweb_hexa,
                text_colour_hexa=fields.textcolourprint_hexa,
                operator_id=operator_id,
                operator_name=optional_value(fields.operatorname),
                accessibility=IdfmState(fields.accessibility.value),
                visual_signs_available=IdfmState(fields.visualsigns_available.value),
                audible_signs_available=IdfmState(fields.audiblesigns_available.value),
                picto_id=fields.picto.id_ if fields.picto is not None else None,
                record_id=line.recordid,
                record_ts=int(line.record_timestamp.timestamp()),
            )
        )

    return ret
def format_idfm_stops(*stops: IdfmStop) -> Iterable[Stop]:
    """Lazily convert IDFM stop records into ``Stop`` models.

    Coordinates are converted with ``epsg2154_epsg3857_transformer``. A stop
    whose postal region cannot be converted to an integer is logged and
    skipped.

    Args:
        *stops: Stop records decoded from the IDFM stops dataset.

    Yields:
        One ``Stop`` per input record that could be formatted.
    """
    for stop in stops:
        data = stop.fields

        # Presumably arrcreated may be None; it then has no .timestamp() and
        # the creation timestamp is left unset.
        try:
            created_ts = int(data.arrcreated.timestamp())  # type: ignore
        except AttributeError:
            created_ts = None

        x, y = epsg2154_epsg3857_transformer.transform(
            data.arrxepsg2154, data.arryepsg2154
        )

        try:
            postal_region = int(data.arrpostalregion)
        except ValueError:
            logger.warning("Unable to format %s postal region.", data.arrpostalregion)
            continue

        yield Stop(
            id=int(data.arrid),
            name=data.arrname,
            epsg3857_x=x,
            epsg3857_y=y,
            town_name=data.arrtown,
            postal_region=postal_region,
            transport_mode=TransportMode(data.arrtype.value),
            version=data.arrversion,
            created_ts=created_ts,
            changed_ts=int(data.arrchanged.timestamp()),
            accessibility=IdfmState(data.arraccessibility.value),
            visual_signs_available=IdfmState(data.arrvisualsigns.value),
            audible_signs_available=IdfmState(data.arraudiblesignals.value),
            record_id=stop.recordid,
            record_ts=int(stop.record_timestamp.timestamp()),
        )
def format_idfm_stop_areas(*stop_areas: IdfmStopArea) -> Iterable[StopArea]:
    """Lazily convert IDFM stop area records into ``StopArea`` models.

    Coordinates are converted with ``epsg2154_epsg3857_transformer``.

    Args:
        *stop_areas: Records decoded from the IDFM stop areas dataset.

    Yields:
        One ``StopArea`` per input record, in input order.
    """
    for area in stop_areas:
        data = area.fields

        # Presumably zdacreated may be None; it then has no .timestamp() and
        # the creation timestamp is left unset.
        try:
            created_ts = int(data.zdacreated.timestamp())  # type: ignore
        except AttributeError:
            created_ts = None

        x, y = epsg2154_epsg3857_transformer.transform(
            data.zdaxepsg2154, data.zdayepsg2154
        )

        yield StopArea(
            id=int(data.zdaid),
            name=data.zdaname,
            town_name=data.zdatown,
            postal_region=data.zdapostalregion,
            epsg3857_x=x,
            epsg3857_y=y,
            type=StopAreaType(data.zdatype.value),
            version=data.zdaversion,
            created_ts=created_ts,
            changed_ts=int(data.zdachanged.timestamp()),
        )
def format_idfm_connection_areas(
    *connection_areas: IdfmConnectionArea,
) -> Iterable[ConnectionArea]:
    """Lazily convert IDFM connection area records into ``ConnectionArea`` models.

    Coordinates are converted with ``epsg2154_epsg3857_transformer``.

    Args:
        *connection_areas: Records decoded from the IDFM connection areas
            dataset.

    Yields:
        One ``ConnectionArea`` per input record, in input order.
    """
    for area in connection_areas:
        data = area.fields
        x, y = epsg2154_epsg3857_transformer.transform(
            data.zdcxepsg2154, data.zdcyepsg2154
        )

        yield ConnectionArea(
            id=int(data.zdcid),
            name=data.zdcname,
            town_name=data.zdctown,
            postal_region=data.zdcpostalregion,
            epsg3857_x=x,
            epsg3857_y=y,
            transport_mode=StopAreaType(data.zdctype.value),
            version=data.zdcversion,
            created_ts=int(data.zdccreated.timestamp()),
            changed_ts=int(data.zdcchanged.timestamp()),
        )
def format_idfm_stop_shapes(*shape_records: ShapeRecord) -> Iterable[StopShape]:
    """Lazily convert shapefile records into ``StopShape`` models.

    Shape points and bounding box corners are converted with
    ``epsg2154_epsg3857_transformer``. The shape id is read from the second
    field of the attribute record.

    Args:
        *shape_records: Shape records read from a shapefile.

    Yields:
        One ``StopShape`` per input record, in input order.
    """
    for record in shape_records:
        shape = record.shape
        points = [
            epsg2154_epsg3857_transformer.transform(*point) for point in shape.points
        ]

        try:
            # Consume the flat bbox two values at a time, one pair per corner.
            coordinates = iter(shape.bbox)
            bbox = [
                epsg2154_epsg3857_transformer.transform(*corner)
                for corner in zip(coordinates, coordinates)
            ]
        except AttributeError:
            # Some shapes come without a bbox attribute.
            bbox = []

        yield StopShape(
            id=record.record[1],
            type=shape.shapeType,
            epsg3857_bbox=bbox,
            epsg3857_points=points,
        )
async def http_get(url: str) -> str | None:
    """Download ``url`` and return its body as text.

    The body is streamed in 1 MiB chunks while a tqdm progress bar is
    updated. It is decoded once, after the download has completed.

    Args:
        url: Address to fetch with an ``Accept: application/json`` header.

    Returns:
        The decoded response body, or ``None`` when the server does not
        answer with HTTP 200.
    """
    chunks: list[bytes] = []
    headers = {"Accept": "application/json"}

    async with ClientSession(headers=headers) as session:
        async with session.get(url) as response:
            if response.status != 200:
                logger.warning("GET %s failed with HTTP status %s", url, response.status)
                return None

            # A missing content-length gives total=None (open-ended bar).
            size = int(response.headers.get("content-length", 0)) or None

            # The context manager closes the bar even if the download fails.
            with tqdm(desc=f"Downloading {url}", total=size) as progress_bar:
                async for chunk in response.content.iter_chunked(1024 * 1024):
                    # Keep raw bytes: a multi-byte character may straddle two
                    # chunks, so decoding chunk by chunk could fail.
                    chunks.append(chunk)
                    progress_bar.update(len(chunk))

    return b"".join(chunks).decode()
async def http_request(
    url: str, decode: Callable, format_method: Callable, model: Type[Base]
) -> bool:
    """Download a dataset and insert it into the database by batches.

    Args:
        url: Address of the dataset, fetched with ``http_get``.
        decode: Callable turning the downloaded text into an iterable of
            records.
        format_method: Callable receiving a batch of records as positional
            arguments and returning the models to insert.
        model: Model class whose ``add`` coroutine stores the formatted batch.

    Returns:
        ``False`` when the download failed or a ``ValidationError`` was
        raised, ``True`` otherwise.
    """
    payload = await http_get(url)
    if payload is None:
        return False

    try:
        records = iter(decode(payload))
        # Pull at most BATCH_SIZE records at a time until none are left.
        while batch := list(islice(records, BATCH_SIZE)):
            await model.add(format_method(*batch))
    except ValidationError as err:
        logger.warning(err)
        return False

    return True
async def load_idfm_stops() -> bool:
    """Download the IDFM stops dataset and store it as ``Stop`` rows.

    Returns:
        The outcome reported by ``http_request``.
    """
    succeeded = await http_request(
        url=IDFM_STOPS_URL,
        decode=json_stops_decoder.decode,
        format_method=format_idfm_stops,
        model=Stop,
    )
    return succeeded
async def load_idfm_stop_areas() -> bool:
    """Download the IDFM stop areas dataset and store it as ``StopArea`` rows.

    Returns:
        The outcome reported by ``http_request``.
    """
    succeeded = await http_request(
        url=IDFM_STOP_AREAS_URL,
        decode=json_stop_areas_decoder.decode,
        format_method=format_idfm_stop_areas,
        model=StopArea,
    )
    return succeeded
async def load_idfm_connection_areas() -> bool:
    """Download the IDFM connection areas and store them as ``ConnectionArea`` rows.

    Returns:
        The outcome reported by ``http_request``.
    """
    succeeded = await http_request(
        url=IDFM_CONNECTION_AREAS_URL,
        decode=json_connection_areas_decoder.decode,
        format_method=format_idfm_connection_areas,
        model=ConnectionArea,
    )
    return succeeded
async def load_idfm_stop_shapes(url: str) -> None:
    """Download a zipped shapefile and store its records as ``StopShape`` rows.

    The archive is streamed into a temporary ``.zip`` file, read back with
    the shapefile reader and inserted by batches of ``BATCH_SIZE`` records.
    Nothing is stored when the server does not answer with HTTP 200.

    Args:
        url: Address of the zipped shapefile.
    """
    async with ClientSession(headers={"Accept": "application/zip"}) as session:
        async with session.get(url) as response:
            if response.status != 200:
                logger.warning("GET %s failed with HTTP status %s", url, response.status)
                return

            # A missing content-length gives total=None (open-ended bar).
            size = int(response.headers.get("content-length", 0)) or None

            async with NamedTemporaryFile(suffix=".zip") as tmp_file:
                with tqdm(desc=f"Downloading {url}", total=size) as dl_progress_bar:
                    async for chunk in response.content.iter_chunked(1024 * 1024):
                        await tmp_file.write(chunk)
                        dl_progress_bar.update(len(chunk))

                # The reader reopens the file by name, so buffered writes
                # must reach the disk first or the archive may be truncated.
                await tmp_file.flush()

                with ShapeFileReader(tmp_file.name) as reader:
                    step_begin_ts = time()

                    shapes = reader.shapeRecords()
                    shapes_len = len(shapes)

                    with tqdm(
                        desc=f"Filling db with {shapes_len} StopShapes",
                        total=shapes_len,
                    ) as db_progress_bar:
                        # range() stops exactly at the last record, so no
                        # empty trailing batch is sent to the database.
                        for begin in range(0, shapes_len, BATCH_SIZE):
                            elements = shapes[begin : begin + BATCH_SIZE]
                            formatteds = list(format_idfm_stop_shapes(*elements))
                            await StopShape.add(formatteds)
                            # Count the real batch length so the bar never
                            # overshoots its total.
                            db_progress_bar.update(len(elements))

                    logger.info(
                        f"Add {StopShape.__name__}s: {time() - step_begin_ts}s"
                    )
async def load_idfm_lines() -> None:
    """Download the IDFM lines dataset and store lines and their pictograms.

    Each pictogram id is stored only once across the whole dataset. For
    every batch, pictograms are inserted before the lines. Nothing is stored
    when the download fails.
    """
    payload = await http_get(IDFM_LINES_URL)
    if payload is None:
        return None

    seen_picto_ids = set()
    pending_lines = []
    pending_pictos = []

    for idfm_line in json_lines_decoder.decode(payload):
        picto = idfm_line.fields.picto
        if picto is not None and picto.id_ not in seen_picto_ids:
            seen_picto_ids.add(picto.id_)
            pending_pictos.append(picto)

        pending_lines.append(idfm_line)
        if len(pending_lines) != BATCH_SIZE:
            continue

        # Full batch: flush pictograms first, then the lines.
        await LinePicto.add(list(format_idfm_pictos(*pending_pictos)))
        await Line.add(list(format_idfm_lines(*pending_lines)))
        pending_lines.clear()
        pending_pictos.clear()

    # Flush whatever remains of the last, partial batch.
    if pending_pictos:
        await LinePicto.add(list(format_idfm_pictos(*pending_pictos)))
    if pending_lines:
        await Line.add(list(format_idfm_lines(*pending_lines)))
async def load_ratp_pictos(batch_size: int = 5) -> None:
    """Download the RATP pictograms, store them and attach them to lines.

    Nothing is stored when the download fails.

    Args:
        batch_size: Number of pictogram records handled per database round
            trip.
    """

    async def store(batch) -> None:
        # Insert the pictograms, then hand the (name, picto) pairs to Line.
        formatteds = format_ratp_pictos(*batch)
        await LinePicto.add([line_picto for _, line_picto in formatteds])
        await Line.add_pictos(formatteds)

    payload = await http_get(RATP_PICTOS_URL)
    if payload is None:
        return None

    pending = []
    for ratp_picto in json_ratp_pictos_decoder.decode(payload):
        pending.append(ratp_picto)
        if len(pending) != batch_size:
            continue
        await store(pending)
        pending.clear()

    # Flush the last, partial batch.
    if pending:
        await store(pending)
async def load_lines_stops_assos(batch_size: int = 5000) -> None:
    """Download the stop <-> line associations and attach stops to lines.

    Associations whose stop id cannot be converted to an integer are logged
    and skipped. A summary is logged once every batch has been processed.
    Nothing is stored when the download fails.

    Args:
        batch_size: Number of associations sent per ``Line.add_stops`` call.
    """
    data = await http_get(IDFM_STOPS_LINES_ASSOS_URL)
    if data is None:
        return None

    total_assos_nb = 0
    total_found_nb = 0
    pending = []

    for asso in json_stops_lines_assos_decoder.decode(data):
        fields = asso.fields

        try:
            # Keep only what follows the last ":" of the stop identifier.
            stop_id = int(fields.stop_id.rsplit(":", 1)[-1])
        except ValueError as err:
            logger.error(err)
            logger.error(f"{fields.stop_id = }")
            continue

        pending.append((fields.route_long_name, fields.operatorname, stop_id))
        if len(pending) != batch_size:
            continue

        total_assos_nb += batch_size
        total_found_nb += await Line.add_stops(pending)
        pending.clear()

    # Flush the last, partial batch.
    if pending:
        total_assos_nb += len(pending)
        total_found_nb += await Line.add_stops(pending)

    logger.info(f"{total_found_nb} line <-> stop ({total_assos_nb = } found)")
async def load_stop_assos(batch_size: int = 5000) -> None:
    """Download stop relations and attach stops to stop and connection areas.

    Every relation produces one ``(stop area id, stop id)`` pair and one
    ``(connection area id, stop id)`` pair. Both lists are flushed together,
    by batches, through ``StopArea.add_stops`` and
    ``ConnectionArea.add_stops``. Nothing is stored when the download fails.

    Args:
        batch_size: Number of relations accumulated before each flush.
    """
    data = await http_get(IDFM_STOP_STOP_AREAS_ASSOS_URL)
    if data is None:
        return None

    total_assos_nb = area_stop_assos_nb = conn_stop_assos_nb = 0
    area_stop_assos = []
    connection_stop_assos = []

    for asso in json_stop_area_stop_asso_decoder.decode(data):
        fields = asso.fields
        stop_id = int(fields.arrid)

        area_stop_assos.append((int(fields.zdaid), stop_id))
        connection_stop_assos.append((int(fields.zdcid), stop_id))

        if len(area_stop_assos) == batch_size:
            total_assos_nb += batch_size

            # add_stops may return None; such a result is not counted.
            if (found_nb := await StopArea.add_stops(area_stop_assos)) is not None:
                area_stop_assos_nb += found_nb
            area_stop_assos.clear()

            if (
                found_nb := await ConnectionArea.add_stops(connection_stop_assos)
            ) is not None:
                conn_stop_assos_nb += found_nb
            connection_stop_assos.clear()

    # Both lists grow and are cleared together, so testing one is enough.
    if area_stop_assos:
        total_assos_nb += len(area_stop_assos)

        if (found_nb := await StopArea.add_stops(area_stop_assos)) is not None:
            area_stop_assos_nb += found_nb

        if (
            found_nb := await ConnectionArea.add_stops(connection_stop_assos)
        ) is not None:
            conn_stop_assos_nb += found_nb

    logger.info(f"{area_stop_assos_nb} stop area <-> stop ({total_assos_nb = } found)")
    # This line used to repeat the "stop area" label for connection areas.
    logger.info(
        f"{conn_stop_assos_nb} connection area <-> stop ({total_assos_nb = } found)"
    )
async def prepare(db: Database) -> None:
    """Fill the database in three successive stages.

    Lines are loaded first and on their own. Stops, stop areas, connection
    areas and RATP pictograms are then loaded concurrently. Shapes and
    associations are loaded concurrently last.

    Args:
        db: Database handle; not used by this function's body.
    """
    # Stage 1: lines.
    await load_idfm_lines()

    # Stage 2: independent entities and RATP pictograms.
    await gather(
        load_idfm_stops(),
        load_idfm_stop_areas(),
        load_idfm_connection_areas(),
        load_ratp_pictos(),
    )

    # Stage 3: shapes and associations.
    await gather(
        load_idfm_stop_shapes(IDFM_STOP_SHAPES_URL),
        load_idfm_stop_shapes(IDFM_STOP_AREA_SHAPES_URL),
        load_lines_stops_assos(),
        load_stop_assos(),
    )
def load_settings(path: str) -> Settings:
    """Read the YAML file at ``path`` and build the ``Settings`` from it.

    Args:
        path: Location of the YAML configuration file.

    Returns:
        A ``Settings`` instance built from the top-level keys of the file.
    """
    with open(path, "r") as stream:
        raw_config = safe_load(stream)

    settings = Settings(**raw_config)
    return settings
async def main() -> None:
    """Load the settings, connect to the database and run the full update.

    The elapsed time of the update is logged on success. The database
    connection is released even when the update fails.
    """
    settings = load_settings(CONFIG_PATH)

    # NOTE(review): the meaning of the second positional argument (True) is
    # defined by Database.connect, which is not visible here. Confirm.
    await db.connect(settings.db, True)

    try:
        begin_ts = time()
        await prepare(db)
        logger.info(f"Elapsed time: {time() - begin_ts}s")
    finally:
        # Always disconnect, including when prepare() raises.
        await db.disconnect()


if __name__ == "__main__":
    run(main())