#!/usr/bin/env python3
"""Download IDFM / RATP open-data referentials and load them into the database."""
from asyncio import run, gather
from logging import getLogger, INFO, Handler as LoggingHandler, NOTSET
from itertools import islice
from time import time
from os import environ
from typing import Callable, Iterable, List, Type

from aiofiles.tempfile import NamedTemporaryFile
from aiohttp import ClientSession
from msgspec import ValidationError
from msgspec.json import Decoder
from pyproj import Transformer
from shapefile import Reader as ShapeFileReader, ShapeRecord  # type: ignore
from tqdm import tqdm
from yaml import safe_load

from backend.db import Base, db, Database
from backend.models import ConnectionArea, Line, LinePicto, Stop, StopArea, StopShape
from backend.idfm_interface.idfm_types import (
    ConnectionArea as IdfmConnectionArea,
    IdfmLineState,
    Line as IdfmLine,
    LinePicto as IdfmPicto,
    IdfmState,
    Stop as IdfmStop,
    StopArea as IdfmStopArea,
    StopAreaStopAssociation,
    StopAreaType,
    StopLineAsso as IdfmStopLineAsso,
    TransportMode,
)
from backend.idfm_interface.ratp_types import Picto as RatpPicto
from backend.settings import Settings

CONFIG_PATH = environ.get("CONFIG_PATH", "./config.sample.yaml")

# Number of decoded records accumulated before each bulk DB insert.
BATCH_SIZE = 1000

IDFM_ROOT_URL = "https://data.iledefrance-mobilites.fr/explore/dataset"
IDFM_CONNECTION_AREAS_URL = (
    f"{IDFM_ROOT_URL}/zones-de-correspondance/download/?format=json"
)
IDFM_LINES_URL = f"{IDFM_ROOT_URL}/referentiel-des-lignes/download/?format=json"
IDFM_PICTO_URL = f"{IDFM_ROOT_URL}/referentiel-des-lignes/files"
IDFM_STOP_AREAS_URL = f"{IDFM_ROOT_URL}/zones-d-arrets/download/?format=json"
IDFM_STOP_SHAPES_URL = "https://eu.ftp.opendatasoft.com/stif/Reflex/REF_ArR.zip"
IDFM_STOP_AREA_SHAPES_URL = "https://eu.ftp.opendatasoft.com/stif/Reflex/REF_ZdA.zip"
IDFM_STOP_STOP_AREAS_ASSOS_URL = f"{IDFM_ROOT_URL}/relations/download/?format=json"
IDFM_STOPS_LINES_ASSOS_URL = f"{IDFM_ROOT_URL}/arrets-lignes/download/?format=json"
IDFM_STOPS_URL = f"{IDFM_ROOT_URL}/arrets/download/?format=json"

RATP_ROOT_URL = "https://data.ratp.fr/api/explore/v2.1/catalog/datasets"
RATP_PICTOS_URL = (
    f"{RATP_ROOT_URL}"
    "/pictogrammes-des-lignes-de-metro-rer-tramway-bus-et-noctilien/exports/json?lang=fr"
)


# From https://stackoverflow.com/a/38739634
class TqdmLoggingHandler(LoggingHandler):
    """Logging handler that emits records via tqdm.write.

    Using tqdm.write keeps log lines from corrupting active tqdm progress
    bars drawn by the download/DB-fill loops below.
    """

    def __init__(self, level=NOTSET):
        super().__init__(level)

    def emit(self, record):
        try:
            msg = self.format(record)
            tqdm.write(msg)
            self.flush()
        except Exception:
            self.handleError(record)


logger = getLogger(__name__)
logger.setLevel(INFO)
logger.addHandler(TqdmLoggingHandler())

# IDFM geometries come in Lambert-93 (EPSG:2154); the DB stores
# Web-Mercator (EPSG:3857) coordinates.
epsg2154_epsg3857_transformer = Transformer.from_crs(2154, 3857)

# Pre-built msgspec decoders, one per downloaded dataset.
json_stops_decoder = Decoder(type=List[IdfmStop])
json_stop_areas_decoder = Decoder(type=List[IdfmStopArea])
json_connection_areas_decoder = Decoder(type=List[IdfmConnectionArea])
json_lines_decoder = Decoder(type=List[IdfmLine])
json_stops_lines_assos_decoder = Decoder(type=List[IdfmStopLineAsso])
json_ratp_pictos_decoder = Decoder(type=List[RatpPicto])
json_stop_area_stop_asso_decoder = Decoder(type=List[StopAreaStopAssociation])


def format_idfm_pictos(*pictos: IdfmPicto) -> Iterable[LinePicto]:
    """Convert raw IDFM pictogram records into LinePicto model instances."""
    return [
        LinePicto(
            id=picto.id_,
            mime_type=picto.mimetype,
            height_px=picto.height,
            width_px=picto.width,
            filename=picto.filename,
            url=f"{IDFM_PICTO_URL}/{picto.id_}/download",
            thumbnail=picto.thumbnail,
            format=picto.format,
        )
        for picto in pictos
    ]


def format_ratp_pictos(*pictos: RatpPicto) -> Iterable[tuple[str, LinePicto]]:
    """Convert raw RATP pictogram records into (commercial code, LinePicto) pairs.

    Records without file metadata (``noms_des_fichiers`` is None) are skipped.
    """
    ret = []
    for picto in pictos:
        if (fields := picto.noms_des_fichiers) is not None:
            ret.append(
                (
                    picto.indices_commerciaux,
                    LinePicto(
                        id=fields.id_,
                        mime_type=f"image/{fields.format.lower()}",
                        height_px=fields.height,
                        width_px=fields.width,
                        filename=fields.filename,
                        # NOTE(review): RATP_PICTOS_URL already ends with
                        # "/exports/json?lang=fr", so appending
                        # "/<id>/download" looks wrong — confirm against
                        # the RATP explore API before relying on this URL.
                        url=f"{RATP_PICTOS_URL}/{fields.id_}/download",
                        thumbnail=fields.thumbnail,
                        format=fields.format,
                    ),
                )
            )
    return ret


def format_idfm_lines(*lines: IdfmLine) -> Iterable[Line]:
    """Convert raw IDFM line records into Line model instances.

    Records whose line id cannot be parsed are skipped with a warning;
    an unparseable operator id degrades to 0 instead of dropping the line.
    """
    ret = []
    optional_value = IdfmLine.optional_value
    for line in lines:
        fields = line.fields
        line_id = fields.id_line
        try:
            # Line ids may carry a leading "C" prefix (e.g. "C01234").
            formatted_line_id = int(line_id[1:] if line_id[0] == "C" else line_id)
        except ValueError:
            logger.warning("Unable to format %s line id.", line_id)
            continue
        try:
            operator_id = int(fields.operatorref)  # type: ignore
        except (ValueError, TypeError):
            logger.warning("Unable to format %s operator id.", fields.operatorref)
            operator_id = 0
        ret.append(
            Line(
                id=formatted_line_id,
                short_name=fields.shortname_line,
                name=fields.name_line,
                status=IdfmLineState(fields.status.value),
                transport_mode=TransportMode(fields.transportmode.value),
                transport_submode=optional_value(fields.transportsubmode),
                network_name=optional_value(fields.networkname),
                group_of_lines_id=optional_value(fields.id_groupoflines),
                group_of_lines_shortname=optional_value(fields.shortname_groupoflines),
                colour_web_hexa=fields.colourweb_hexa,
                text_colour_hexa=fields.textcolourprint_hexa,
                operator_id=operator_id,
                operator_name=optional_value(fields.operatorname),
                accessibility=IdfmState(fields.accessibility.value),
                visual_signs_available=IdfmState(fields.visualsigns_available.value),
                audible_signs_available=IdfmState(fields.audiblesigns_available.value),
                picto_id=fields.picto.id_ if fields.picto is not None else None,
                record_id=line.recordid,
                record_ts=int(line.record_timestamp.timestamp()),
            )
        )
    return ret


def format_idfm_stops(*stops: IdfmStop) -> Iterable[Stop]:
    """Yield Stop model instances built from raw IDFM stop records.

    Coordinates are reprojected from EPSG:2154 to EPSG:3857. Records with
    an unparseable postal region are skipped with a warning.
    """
    for stop in stops:
        fields = stop.fields
        try:
            created_ts = int(fields.arrcreated.timestamp())  # type: ignore
        except AttributeError:
            # Creation date is optional in the source dataset.
            created_ts = None
        epsg3857_point = epsg2154_epsg3857_transformer.transform(
            fields.arrxepsg2154, fields.arryepsg2154
        )
        try:
            postal_region = int(fields.arrpostalregion)
        except ValueError:
            logger.warning("Unable to format %s postal region.", fields.arrpostalregion)
            continue
        yield Stop(
            id=int(fields.arrid),
            name=fields.arrname,
            epsg3857_x=epsg3857_point[0],
            epsg3857_y=epsg3857_point[1],
            town_name=fields.arrtown,
            postal_region=postal_region,
            transport_mode=TransportMode(fields.arrtype.value),
            version=fields.arrversion,
            created_ts=created_ts,
            changed_ts=int(fields.arrchanged.timestamp()),
            accessibility=IdfmState(fields.arraccessibility.value),
            visual_signs_available=IdfmState(fields.arrvisualsigns.value),
            audible_signs_available=IdfmState(fields.arraudiblesignals.value),
            record_id=stop.recordid,
            record_ts=int(stop.record_timestamp.timestamp()),
        )


def format_idfm_stop_areas(*stop_areas: IdfmStopArea) -> Iterable[StopArea]:
    """Yield StopArea model instances built from raw IDFM stop-area records."""
    for stop_area in stop_areas:
        fields = stop_area.fields
        try:
            created_ts = int(fields.zdacreated.timestamp())  # type: ignore
        except AttributeError:
            # Creation date is optional in the source dataset.
            created_ts = None
        epsg3857_point = epsg2154_epsg3857_transformer.transform(
            fields.zdaxepsg2154, fields.zdayepsg2154
        )
        yield StopArea(
            id=int(fields.zdaid),
            name=fields.zdaname,
            town_name=fields.zdatown,
            postal_region=fields.zdapostalregion,
            epsg3857_x=epsg3857_point[0],
            epsg3857_y=epsg3857_point[1],
            type=StopAreaType(fields.zdatype.value),
            version=fields.zdaversion,
            created_ts=created_ts,
            changed_ts=int(fields.zdachanged.timestamp()),
        )


def format_idfm_connection_areas(
    *connection_areas: IdfmConnectionArea,
) -> Iterable[ConnectionArea]:
    """Yield ConnectionArea model instances built from raw IDFM records."""
    for connection_area in connection_areas:
        fields = connection_area.fields
        epsg3857_point = epsg2154_epsg3857_transformer.transform(
            fields.zdcxepsg2154, fields.zdcyepsg2154
        )
        yield ConnectionArea(
            id=int(fields.zdcid),
            name=fields.zdcname,
            town_name=fields.zdctown,
            postal_region=fields.zdcpostalregion,
            epsg3857_x=epsg3857_point[0],
            epsg3857_y=epsg3857_point[1],
            transport_mode=StopAreaType(fields.zdctype.value),
            version=fields.zdcversion,
            created_ts=int(fields.zdccreated.timestamp()),
            changed_ts=int(fields.zdcchanged.timestamp()),
        )


def format_idfm_stop_shapes(*shape_records: ShapeRecord) -> Iterable[StopShape]:
    """Yield StopShape model instances from shapefile records.

    Outline points and the bounding box are reprojected to EPSG:3857;
    shapes without a bbox attribute get an empty bbox.
    """
    for shape_record in shape_records:
        epsg3857_points = [
            epsg2154_epsg3857_transformer.transform(*point)
            for point in shape_record.shape.points
        ]
        try:
            # bbox is a flat [xmin, ymin, xmax, ymax]; pair it up into
            # (x, y) points by zipping one iterator with itself.
            bbox_it = iter(shape_record.shape.bbox)
            epsg3857_bbox = [
                epsg2154_epsg3857_transformer.transform(*point)
                for point in zip(bbox_it, bbox_it)
            ]
        except AttributeError:
            # Handle stop shapes for which no bbox is provided
            epsg3857_bbox = []
        yield StopShape(
            id=shape_record.record[1],
            type=shape_record.shape.shapeType,
            epsg3857_bbox=epsg3857_bbox,
            epsg3857_points=epsg3857_points,
        )


async def http_get(url: str) -> str | None:
    """Download *url* and return the body as text, or None on a non-200 status.

    Shows a tqdm progress bar sized from the Content-Length header when
    available.
    """
    chunks: list[bytes] = []
    headers = {"Accept": "application/json"}
    async with ClientSession(headers=headers) as session:
        async with session.get(url) as response:
            size = int(response.headers.get("content-length", 0)) or None
            # Context manager ensures the bar is closed on every path
            # (the original leaked it).
            with tqdm(desc=f"Downloading {url}", total=size) as progress_bar:
                if response.status != 200:
                    return None
                async for chunk in response.content.iter_chunked(1024 * 1024):
                    chunks.append(chunk)
                    progress_bar.update(len(chunk))
    # Decode once over the whole payload: decoding per-chunk would raise
    # UnicodeDecodeError when a multi-byte UTF-8 sequence straddles a
    # chunk boundary.
    return b"".join(chunks).decode()


async def http_request(
    url: str, decode: Callable, format_method: Callable, model: Type[Base]
) -> bool:
    """Download *url*, decode it, and bulk-insert records through *model*.

    Records are inserted in batches of BATCH_SIZE via
    ``model.add(format_method(*batch))``. Returns True on success, False
    when the download failed or the payload did not validate.
    """
    elements = []
    data = await http_get(url)
    if data is None:
        return False
    try:
        for element in decode(data):
            elements.append(element)
            if len(elements) == BATCH_SIZE:
                await model.add(format_method(*elements))
                elements.clear()
        # Flush the last, possibly partial, batch.
        if elements:
            await model.add(format_method(*elements))
    except ValidationError as err:
        logger.warning(err)
        return False
    return True


async def load_idfm_stops() -> bool:
    """Load the IDFM stops referential into the Stop table."""
    return await http_request(
        IDFM_STOPS_URL, json_stops_decoder.decode, format_idfm_stops, Stop
    )


async def load_idfm_stop_areas() -> bool:
    """Load the IDFM stop-areas referential into the StopArea table."""
    return await http_request(
        IDFM_STOP_AREAS_URL,
        json_stop_areas_decoder.decode,
        format_idfm_stop_areas,
        StopArea,
    )


async def load_idfm_connection_areas() -> bool:
    """Load the IDFM connection-areas referential into the ConnectionArea table."""
    return await http_request(
        IDFM_CONNECTION_AREAS_URL,
        json_connection_areas_decoder.decode,
        format_idfm_connection_areas,
        ConnectionArea,
    )


async def load_idfm_stop_shapes(url: str) -> None:
    """Download a zipped shapefile from *url* and load it into StopShape.

    The archive is spooled to a temporary file, read with pyshp, and
    inserted in BATCH_SIZE batches.
    """
    async with ClientSession(headers={"Accept": "application/zip"}) as session:
        async with session.get(url) as response:
            size = int(response.headers.get("content-length", 0)) or None
            dl_progress_bar = tqdm(desc=f"Downloading {url}", total=size)
            if response.status == 200:
                async with NamedTemporaryFile(suffix=".zip") as tmp_file:
                    async for chunk in response.content.iter_chunked(1024 * 1024):
                        await tmp_file.write(chunk)
                        dl_progress_bar.update(len(chunk))
                    with ShapeFileReader(tmp_file.name) as reader:
                        step_begin_ts = time()
                        shapes = reader.shapeRecords()
                        shapes_len = len(shapes)
                        db_progress_bar = tqdm(
                            desc=f"Filling db with {shapes_len} StopShapes",
                            total=shapes_len,
                        )
                        # range-based batching fixes the original off-by-one
                        # that issued one extra empty StopShape.add() call
                        # when shapes_len was a multiple of BATCH_SIZE.
                        for begin in range(0, shapes_len, BATCH_SIZE):
                            elements = islice(shapes, begin, begin + BATCH_SIZE)
                            formatteds = list(format_idfm_stop_shapes(*elements))
                            await StopShape.add(formatteds)
                            # Update by actual batch size so the bar does not
                            # overshoot on the final partial batch.
                            db_progress_bar.update(len(formatteds))
                        logger.info(
                            f"Add {StopShape.__name__}s: {time() - step_begin_ts}s"
                        )


async def load_idfm_lines() -> None:
    """Load the IDFM lines referential, inserting pictos before their lines.

    Pictos are de-duplicated by id across the whole dataset and inserted
    first so line rows can reference them.
    """
    data = await http_get(IDFM_LINES_URL)
    if data is None:
        return None
    lines, pictos = [], []
    picto_ids = set()
    for line in json_lines_decoder.decode(data):
        if (picto := line.fields.picto) is not None and picto.id_ not in picto_ids:
            picto_ids.add(picto.id_)
            pictos.append(picto)
        lines.append(line)
        if len(lines) == BATCH_SIZE:
            await LinePicto.add(list(format_idfm_pictos(*pictos)))
            await Line.add(list(format_idfm_lines(*lines)))
            lines.clear()
            pictos.clear()
    # Flush the last, possibly partial, batches.
    if pictos:
        await LinePicto.add(list(format_idfm_pictos(*pictos)))
    if lines:
        await Line.add(list(format_idfm_lines(*lines)))


async def load_ratp_pictos(batch_size: int = 5) -> None:
    """Load RATP pictos and attach them to existing lines.

    Args:
        batch_size: number of picto records per bulk insert.
    """
    data = await http_get(RATP_PICTOS_URL)
    if data is None:
        return None
    pictos = []
    for picto in json_ratp_pictos_decoder.decode(data):
        pictos.append(picto)
        if len(pictos) == batch_size:
            formatteds = format_ratp_pictos(*pictos)
            await LinePicto.add([picto[1] for picto in formatteds])
            await Line.add_pictos(formatteds)
            pictos.clear()
    # Flush the last, possibly partial, batch.
    if pictos:
        formatteds = format_ratp_pictos(*pictos)
        await LinePicto.add([picto[1] for picto in formatteds])
        await Line.add_pictos(formatteds)
async def load_lines_stops_assos(batch_size: int = 5000) -> None:
    """Load line <-> stop associations and attach stops to lines.

    Args:
        batch_size: number of associations per ``Line.add_stops`` call.

    Stop ids arrive as prefixed identifiers (e.g. "IDFM:1234"); only the
    trailing numeric part is kept. Unparseable ids are logged and skipped.
    """
    data = await http_get(IDFM_STOPS_LINES_ASSOS_URL)
    if data is None:
        return None
    total_assos_nb = total_found_nb = 0
    assos = []
    for asso in json_stops_lines_assos_decoder.decode(data):
        fields = asso.fields
        try:
            stop_id = int(fields.stop_id.rsplit(":", 1)[-1])
        except ValueError as err:
            logger.error(err)
            logger.error(f"{fields.stop_id = }")
            continue
        assos.append((fields.route_long_name, fields.operatorname, stop_id))
        if len(assos) == batch_size:
            total_assos_nb += batch_size
            total_found_nb += await Line.add_stops(assos)
            assos.clear()
    # Flush the last, possibly partial, batch.
    if assos:
        total_assos_nb += len(assos)
        total_found_nb += await Line.add_stops(assos)
    logger.info(f"{total_found_nb} line <-> stop ({total_assos_nb = } found)")


async def load_stop_assos(batch_size: int = 5000) -> None:
    """Load stop-area <-> stop and connection-area <-> stop associations.

    Args:
        batch_size: number of associations per bulk ``add_stops`` call.

    Both association lists are filled and flushed in lockstep, so a single
    length check covers them.
    """
    data = await http_get(IDFM_STOP_STOP_AREAS_ASSOS_URL)
    if data is None:
        return None
    total_assos_nb = area_stop_assos_nb = conn_stop_assos_nb = 0
    area_stop_assos = []
    connection_stop_assos = []
    for asso in json_stop_area_stop_asso_decoder.decode(data):
        fields = asso.fields
        stop_id = int(fields.arrid)
        area_stop_assos.append((int(fields.zdaid), stop_id))
        connection_stop_assos.append((int(fields.zdcid), stop_id))
        if len(area_stop_assos) == batch_size:
            total_assos_nb += batch_size
            if (found_nb := await StopArea.add_stops(area_stop_assos)) is not None:
                area_stop_assos_nb += found_nb
            area_stop_assos.clear()
            if (
                found_nb := await ConnectionArea.add_stops(connection_stop_assos)
            ) is not None:
                conn_stop_assos_nb += found_nb
            connection_stop_assos.clear()
    # Flush the last, possibly partial, batches.
    if area_stop_assos:
        total_assos_nb += len(area_stop_assos)
        if (found_nb := await StopArea.add_stops(area_stop_assos)) is not None:
            area_stop_assos_nb += found_nb
        if (
            found_nb := await ConnectionArea.add_stops(connection_stop_assos)
        ) is not None:
            conn_stop_assos_nb += found_nb
    logger.info(f"{area_stop_assos_nb} stop area <-> stop ({total_assos_nb = } found)")
    # Fixed copy-paste: this message previously also said "stop area" while
    # reporting the connection-area counter.
    logger.info(
        f"{conn_stop_assos_nb} connection area <-> stop ({total_assos_nb = } found)"
    )


async def prepare(db: Database) -> None:
    """Run the full import: lines first, then the datasets that reference them.

    Lines (and their pictos) must exist before associations are loaded;
    independent datasets are fetched concurrently.
    """
    await load_idfm_lines()
    await gather(
        *(
            load_idfm_stops(),
            load_idfm_stop_areas(),
            load_idfm_connection_areas(),
            load_ratp_pictos(),
        )
    )
    await gather(
        *(
            load_idfm_stop_shapes(IDFM_STOP_SHAPES_URL),
            load_idfm_stop_shapes(IDFM_STOP_AREA_SHAPES_URL),
            load_lines_stops_assos(),
            load_stop_assos(),
        )
    )


def load_settings(path: str) -> Settings:
    """Read the YAML config file at *path* and build a Settings object."""
    with open(path, "r") as config_file:
        config = safe_load(config_file)
    return Settings(**config)


async def main() -> None:
    """Entry point: connect to the DB, run the import, report elapsed time."""
    settings = load_settings(CONFIG_PATH)
    await db.connect(settings.db, True)
    begin_ts = time()
    await prepare(db)
    logger.info(f"Elapsed time: {time() - begin_ts}s")
    await db.disconnect()


if __name__ == "__main__":
    run(main())