576 lines
18 KiB
Python
Executable File
576 lines
18 KiB
Python
Executable File
#!/usr/bin/env python3
|
|
|
|
from asyncio import run, gather
|
|
from logging import getLogger, INFO, Handler as LoggingHandler, NOTSET
|
|
from itertools import islice
|
|
from time import time
|
|
from os import environ
|
|
from typing import Callable, Iterable, List, Type
|
|
|
|
from aiofiles.tempfile import NamedTemporaryFile
|
|
from aiohttp import ClientSession
|
|
from msgspec import ValidationError
|
|
from msgspec.json import Decoder
|
|
from pyproj import Transformer
|
|
from shapefile import Reader as ShapeFileReader, ShapeRecord # type: ignore
|
|
from tqdm import tqdm
|
|
from yaml import safe_load
|
|
|
|
from api.db import Base, db, Database
|
|
from api.models import ConnectionArea, Line, LinePicto, Stop, StopArea, StopShape
|
|
from api.idfm_interface.idfm_types import (
|
|
ConnectionArea as IdfmConnectionArea,
|
|
IdfmLineState,
|
|
Line as IdfmLine,
|
|
LinePicto as IdfmPicto,
|
|
IdfmState,
|
|
Stop as IdfmStop,
|
|
StopArea as IdfmStopArea,
|
|
StopAreaStopAssociation,
|
|
StopAreaType,
|
|
StopLineAsso as IdfmStopLineAsso,
|
|
TransportMode,
|
|
)
|
|
from api.idfm_interface.ratp_types import Picto as RatpPicto
|
|
from api.settings import Settings
|
|
|
|
|
|
# Path of the YAML settings file; the CONFIG_PATH environment variable
# overrides the sample configuration used by default.
CONFIG_PATH = environ.get("CONFIG_PATH", "./config.sample.yaml")

# Number of decoded elements accumulated before each database insertion.
BATCH_SIZE = 1000

# IDFM open-data endpoints (data.iledefrance-mobilites.fr).
IDFM_ROOT_URL = "https://data.iledefrance-mobilites.fr/explore/dataset"
IDFM_CONNECTION_AREAS_URL = (
    IDFM_ROOT_URL + "/zones-de-correspondance/download/?format=json"
)
IDFM_LINES_URL = IDFM_ROOT_URL + "/referentiel-des-lignes/download/?format=json"
IDFM_PICTO_URL = IDFM_ROOT_URL + "/referentiel-des-lignes/files"
IDFM_STOP_AREAS_URL = IDFM_ROOT_URL + "/zones-d-arrets/download/?format=json"

# Zipped shapefiles, passed to load_idfm_stop_shapes().
IDFM_STOP_SHAPES_URL = "https://eu.ftp.opendatasoft.com/stif/Reflex/REF_ArR.zip"
IDFM_STOP_AREA_SHAPES_URL = "https://eu.ftp.opendatasoft.com/stif/Reflex/REF_ZdA.zip"

# Association datasets (stop <-> stop area / connection area, stop <-> line).
IDFM_STOP_STOP_AREAS_ASSOS_URL = IDFM_ROOT_URL + "/relations/download/?format=json"
IDFM_STOPS_LINES_ASSOS_URL = IDFM_ROOT_URL + "/arrets-lignes/download/?format=json"
IDFM_STOPS_URL = IDFM_ROOT_URL + "/arrets/download/?format=json"

# RATP open-data endpoints (data.ratp.fr).
RATP_ROOT_URL = "https://data.ratp.fr/api/explore/v2.1/catalog/datasets"
RATP_PICTOS_URL = (
    RATP_ROOT_URL
    + "/pictogrammes-des-lignes-de-metro-rer-tramway-bus-et-noctilien"
    + "/exports/json?lang=fr"
)
|
|
|
|
|
|
# From https://stackoverflow.com/a/38739634
|
|
class TqdmLoggingHandler(LoggingHandler):
    """Logging handler that emits formatted records through ``tqdm.write``."""

    def __init__(self, level=NOTSET):
        """Create the handler with the given minimum logging ``level``."""
        super().__init__(level)

    def emit(self, record):
        """Format ``record`` and print it with ``tqdm.write``.

        Any exception raised while formatting or writing is delegated to
        ``handleError`` rather than propagated to the caller.
        """
        try:
            tqdm.write(self.format(record))
            self.flush()
        except Exception:
            self.handleError(record)
|
|
|
|
|
|
# Module logger: INFO and above, emitted through TqdmLoggingHandler so that
# records are printed with tqdm.write().
logger = getLogger(__name__)
logger.setLevel(INFO)
logger.addHandler(TqdmLoggingHandler())
|
|
|
|
|
|
# Coordinate transformer shared by every format_* helper (EPSG:2154 -> EPSG:3857).
epsg2154_epsg3857_transformer = Transformer.from_crs(2154, 3857)

# Typed JSON decoders, one per downloaded dataset. They are built once at import
# time and reused by the load_* coroutines.
json_stops_decoder = Decoder(type=List[IdfmStop])
json_stop_areas_decoder = Decoder(type=List[IdfmStopArea])
json_connection_areas_decoder = Decoder(type=List[IdfmConnectionArea])
json_lines_decoder = Decoder(type=List[IdfmLine])
json_stops_lines_assos_decoder = Decoder(type=List[IdfmStopLineAsso])
json_ratp_pictos_decoder = Decoder(type=List[RatpPicto])
json_stop_area_stop_asso_decoder = Decoder(type=List[StopAreaStopAssociation])
|
|
|
|
|
|
def format_idfm_pictos(*pictos: IdfmPicto) -> Iterable[LinePicto]:
    """Convert IDFM picto records into ``LinePicto`` model instances.

    Args:
        *pictos: Decoded IDFM picto records.

    Returns:
        A list with one ``LinePicto`` per input record, in the same order.
        Each download URL is built from ``IDFM_PICTO_URL`` and the picto id.
    """
    return [
        LinePicto(
            id=picto.id_,
            mime_type=picto.mimetype,
            height_px=picto.height,
            width_px=picto.width,
            filename=picto.filename,
            url=f"{IDFM_PICTO_URL}/{picto.id_}/download",
            thumbnail=picto.thumbnail,
            format=picto.format,
        )
        for picto in pictos
    ]
|
|
|
|
|
|
def format_ratp_pictos(*pictos: RatpPicto) -> Iterable[tuple[str, LinePicto]]:
    """Convert RATP picto records into ``(indices_commerciaux, LinePicto)`` pairs.

    Records whose ``noms_des_fichiers`` field is ``None`` are left out.

    Args:
        *pictos: Decoded RATP picto records.

    Returns:
        A list of 2-tuples pairing the record's ``indices_commerciaux`` value
        with the ``LinePicto`` built from its file description.
    """
    pairs = []

    for picto in pictos:
        file_info = picto.noms_des_fichiers
        if file_info is None:
            continue

        commercial_index = picto.indices_commerciaux
        # NOTE(review): RATP_PICTOS_URL already ends with a query string
        # ("...exports/json?lang=fr"), so the URL built below looks malformed;
        # confirm the intended file download endpoint.
        line_picto = LinePicto(
            id=file_info.id_,
            mime_type=f"image/{file_info.format.lower()}",
            height_px=file_info.height,
            width_px=file_info.width,
            filename=file_info.filename,
            url=f"{RATP_PICTOS_URL}/{file_info.id_}/download",
            thumbnail=file_info.thumbnail,
            format=file_info.format,
        )
        pairs.append((commercial_index, line_picto))

    return pairs
|
|
|
|
|
|
def format_idfm_lines(*lines: IdfmLine) -> Iterable[Line]:
    """Convert IDFM line records into ``Line`` model instances.

    A line whose identifier cannot be turned into an integer (after dropping
    an optional leading ``"C"``) is logged and skipped. An operator reference
    that cannot be parsed is logged and replaced by ``0``.

    Args:
        *lines: Decoded IDFM line records.

    Returns:
        A list of ``Line`` instances for every record with a usable id.
    """
    ret = []

    optional_value = IdfmLine.optional_value

    for line in lines:
        fields = line.fields

        line_id = fields.id_line
        try:
            # removeprefix() (unlike indexing line_id[0]) copes with an empty
            # id: int("") raises ValueError and the line is skipped below
            # instead of aborting the whole import with an IndexError.
            formatted_line_id = int(line_id.removeprefix("C"))
        except ValueError:
            logger.warning("Unable to format %s line id.", line_id)
            continue

        try:
            operator_id = int(fields.operatorref)  # type: ignore
        except (ValueError, TypeError):
            logger.warning("Unable to format %s operator id.", fields.operatorref)
            operator_id = 0

        ret.append(
            Line(
                id=formatted_line_id,
                short_name=fields.shortname_line,
                name=fields.name_line,
                status=IdfmLineState(fields.status.value),
                transport_mode=TransportMode(fields.transportmode.value),
                transport_submode=optional_value(fields.transportsubmode),
                network_name=optional_value(fields.networkname),
                group_of_lines_id=optional_value(fields.id_groupoflines),
                group_of_lines_shortname=optional_value(fields.shortname_groupoflines),
                colour_web_hexa=fields.colourweb_hexa,
                text_colour_hexa=fields.textcolourprint_hexa,
                operator_id=operator_id,
                operator_name=optional_value(fields.operatorname),
                accessibility=IdfmState(fields.accessibility.value),
                visual_signs_available=IdfmState(fields.visualsigns_available.value),
                audible_signs_available=IdfmState(fields.audiblesigns_available.value),
                picto_id=fields.picto.id_ if fields.picto is not None else None,
                record_id=line.recordid,
                record_ts=int(line.record_timestamp.timestamp()),
            )
        )

    return ret
|
|
|
|
|
|
def format_idfm_stops(*stops: IdfmStop) -> Iterable[Stop]:
    """Lazily convert IDFM stop records into ``Stop`` model instances.

    Coordinates are reprojected with the module-level EPSG:2154 -> EPSG:3857
    transformer. A stop whose postal region cannot be parsed as an integer is
    logged and skipped.

    Args:
        *stops: Decoded IDFM stop records.

    Yields:
        One ``Stop`` per record that has a usable postal region.
    """
    for record in stops:
        fields = record.fields

        # An AttributeError here (e.g. a None creation date) leaves the
        # creation timestamp unset.
        try:
            created_ts = int(fields.arrcreated.timestamp())  # type: ignore
        except AttributeError:
            created_ts = None

        x_3857, y_3857 = epsg2154_epsg3857_transformer.transform(
            fields.arrxepsg2154, fields.arryepsg2154
        )

        try:
            postal_region = int(fields.arrpostalregion)
        except ValueError:
            logger.warning("Unable to format %s postal region.", fields.arrpostalregion)
            continue

        yield Stop(
            id=int(fields.arrid),
            name=fields.arrname,
            epsg3857_x=x_3857,
            epsg3857_y=y_3857,
            town_name=fields.arrtown,
            postal_region=postal_region,
            transport_mode=TransportMode(fields.arrtype.value),
            version=fields.arrversion,
            created_ts=created_ts,
            changed_ts=int(fields.arrchanged.timestamp()),
            accessibility=IdfmState(fields.arraccessibility.value),
            visual_signs_available=IdfmState(fields.arrvisualsigns.value),
            audible_signs_available=IdfmState(fields.arraudiblesignals.value),
            record_id=record.recordid,
            record_ts=int(record.record_timestamp.timestamp()),
        )
|
|
|
|
|
|
def format_idfm_stop_areas(*stop_areas: IdfmStopArea) -> Iterable[StopArea]:
    """Lazily convert IDFM stop area records into ``StopArea`` model instances.

    Coordinates are reprojected with the module-level EPSG:2154 -> EPSG:3857
    transformer.

    Args:
        *stop_areas: Decoded IDFM stop area records.

    Yields:
        One ``StopArea`` per input record.
    """
    for record in stop_areas:
        fields = record.fields

        # An AttributeError here (e.g. a None creation date) leaves the
        # creation timestamp unset.
        try:
            created_ts = int(fields.zdacreated.timestamp())  # type: ignore
        except AttributeError:
            created_ts = None

        x_3857, y_3857 = epsg2154_epsg3857_transformer.transform(
            fields.zdaxepsg2154, fields.zdayepsg2154
        )

        yield StopArea(
            id=int(fields.zdaid),
            name=fields.zdaname,
            town_name=fields.zdatown,
            postal_region=fields.zdapostalregion,
            epsg3857_x=x_3857,
            epsg3857_y=y_3857,
            type=StopAreaType(fields.zdatype.value),
            version=fields.zdaversion,
            created_ts=created_ts,
            changed_ts=int(fields.zdachanged.timestamp()),
        )
|
|
|
|
|
|
def format_idfm_connection_areas(
    *connection_areas: IdfmConnectionArea,
) -> Iterable[ConnectionArea]:
    """Lazily convert IDFM connection area records into ``ConnectionArea`` instances.

    Coordinates are reprojected with the module-level EPSG:2154 -> EPSG:3857
    transformer.

    Args:
        *connection_areas: Decoded IDFM connection area records.

    Yields:
        One ``ConnectionArea`` per input record.
    """
    for record in connection_areas:
        fields = record.fields

        x_3857, y_3857 = epsg2154_epsg3857_transformer.transform(
            fields.zdcxepsg2154, fields.zdcyepsg2154
        )

        yield ConnectionArea(
            id=int(fields.zdcid),
            name=fields.zdcname,
            town_name=fields.zdctown,
            postal_region=fields.zdcpostalregion,
            epsg3857_x=x_3857,
            epsg3857_y=y_3857,
            transport_mode=StopAreaType(fields.zdctype.value),
            version=fields.zdcversion,
            created_ts=int(fields.zdccreated.timestamp()),
            changed_ts=int(fields.zdcchanged.timestamp()),
        )
|
|
|
|
|
|
def format_idfm_stop_shapes(*shape_records: ShapeRecord) -> Iterable[StopShape]:
    """Lazily convert shapefile records into ``StopShape`` model instances.

    Every point of the shape, and every corner of its bounding box, is
    reprojected with the module-level EPSG:2154 -> EPSG:3857 transformer.

    Args:
        *shape_records: Shape records read from a shapefile.

    Yields:
        One ``StopShape`` per record; its id is the second field of the
        record's attribute row.
    """
    to_epsg3857 = epsg2154_epsg3857_transformer.transform

    for shape_record in shape_records:
        points = [to_epsg3857(*point) for point in shape_record.shape.points]

        try:
            # Consume the flat bbox sequence two values at a time so that each
            # consecutive pair is handled as one point.
            flat_bbox = iter(shape_record.shape.bbox)
            bbox = [to_epsg3857(*corner) for corner in zip(flat_bbox, flat_bbox)]
        except AttributeError:
            # Handle stop shapes for which no bbox is provided
            bbox = []

        yield StopShape(
            id=shape_record.record[1],
            type=shape_record.shape.shapeType,
            epsg3857_bbox=bbox,
            epsg3857_points=points,
        )
|
|
|
|
|
|
async def http_get(url: str) -> str | None:
    """Download ``url`` and return its body as text.

    The body is streamed in 1 MiB chunks while a tqdm progress bar is updated.

    Args:
        url: Address of the JSON resource to download.

    Returns:
        The decoded response body, or ``None`` when the server does not answer
        with HTTP 200.
    """
    headers = {"Accept": "application/json"}
    chunks: list[bytes] = []

    async with ClientSession(headers=headers) as session:
        async with session.get(url) as response:
            if response.status != 200:
                logger.warning("Unable to download %s: HTTP %s", url, response.status)
                return None

            # A missing or zero content-length leaves the bar without a total.
            size = int(response.headers.get("content-length", 0)) or None

            # The context manager closes the bar even if the download fails.
            with tqdm(desc=f"Downloading {url}", total=size) as progress_bar:
                async for chunk in response.content.iter_chunked(1024 * 1024):
                    chunks.append(chunk)
                    progress_bar.update(len(chunk))

    # Decode once, after reassembly: a multi-byte character may straddle two
    # chunks, so decoding each chunk on its own can raise UnicodeDecodeError.
    return b"".join(chunks).decode()
|
|
|
|
|
|
async def http_request(
    url: str, decode: Callable, format_method: Callable, model: Type[Base]
) -> bool:
    """Download a dataset, decode it and store it in batches through ``model``.

    Args:
        url: Address of the dataset to download with ``http_get``.
        decode: Callable turning the downloaded text into an iterable of
            elements.
        format_method: Callable receiving a batch of elements as positional
            arguments and returning what is handed to ``model.add``.
        model: Model class whose ``add`` coroutine stores each formatted batch.

    Returns:
        ``True`` on success; ``False`` when the download fails or a
        ``ValidationError`` is raised while processing the data.
    """
    payload = await http_get(url)
    if payload is None:
        return False

    batch = []

    try:
        for item in decode(payload):
            batch.append(item)
            if len(batch) != BATCH_SIZE:
                continue

            await model.add(format_method(*batch))
            batch.clear()

        # Store the trailing, incomplete batch.
        if batch:
            await model.add(format_method(*batch))
    except ValidationError as err:
        logger.warning(err)
        return False

    return True
|
|
|
|
|
|
async def load_idfm_stops() -> bool:
    """Download the IDFM stops dataset and store it as ``Stop`` rows.

    Returns:
        The result of ``http_request`` for this dataset.
    """
    return await http_request(
        url=IDFM_STOPS_URL,
        decode=json_stops_decoder.decode,
        format_method=format_idfm_stops,
        model=Stop,
    )
|
|
|
|
|
|
async def load_idfm_stop_areas() -> bool:
    """Download the IDFM stop areas dataset and store it as ``StopArea`` rows.

    Returns:
        The result of ``http_request`` for this dataset.
    """
    return await http_request(
        url=IDFM_STOP_AREAS_URL,
        decode=json_stop_areas_decoder.decode,
        format_method=format_idfm_stop_areas,
        model=StopArea,
    )
|
|
|
|
|
|
async def load_idfm_connection_areas() -> bool:
    """Download the IDFM connection areas dataset and store it as ``ConnectionArea`` rows.

    Returns:
        The result of ``http_request`` for this dataset.
    """
    return await http_request(
        url=IDFM_CONNECTION_AREAS_URL,
        decode=json_connection_areas_decoder.decode,
        format_method=format_idfm_connection_areas,
        model=ConnectionArea,
    )
|
|
|
|
|
|
async def _insert_stop_shapes(shapes) -> None:
    """Format ``shapes`` and store them as ``StopShape`` rows, batch by batch."""
    shapes_len = len(shapes)

    with tqdm(
        desc=f"Filling db with {shapes_len} StopShapes",
        total=shapes_len,
    ) as db_progress_bar:
        # range() stops exactly at the end of the data: no trailing empty batch
        # when the count is a multiple of BATCH_SIZE.
        for begin in range(0, shapes_len, BATCH_SIZE):
            batch = shapes[begin : begin + BATCH_SIZE]
            await StopShape.add(list(format_idfm_stop_shapes(*batch)))
            # Advance by the real batch size so the bar never overshoots.
            db_progress_bar.update(len(batch))


async def load_idfm_stop_shapes(url: str) -> None:
    """Download a zipped shapefile and store its records as ``StopShape`` rows.

    The archive is streamed into a temporary ``.zip`` file, opened with the
    shapefile reader, and inserted in batches of ``BATCH_SIZE`` records.
    Nothing is stored when the server does not answer with HTTP 200.

    Args:
        url: Address of the zipped shapefile.
    """
    async with ClientSession(headers={"Accept": "application/zip"}) as session:
        async with session.get(url) as response:
            if response.status != 200:
                logger.warning("Unable to download %s: HTTP %s", url, response.status)
                return

            # A missing or zero content-length leaves the bar without a total.
            size = int(response.headers.get("content-length", 0)) or None

            async with NamedTemporaryFile(suffix=".zip") as tmp_file:
                with tqdm(desc=f"Downloading {url}", total=size) as dl_progress_bar:
                    async for chunk in response.content.iter_chunked(1024 * 1024):
                        await tmp_file.write(chunk)
                        dl_progress_bar.update(len(chunk))

                # The reader reopens the file by name, so buffered bytes must
                # reach the disk first or the archive may look truncated.
                await tmp_file.flush()

                with ShapeFileReader(tmp_file.name) as reader:
                    step_begin_ts = time()

                    await _insert_stop_shapes(reader.shapeRecords())

                    logger.info(
                        f"Add {StopShape.__name__}s: {time() - step_begin_ts}s"
                    )
|
|
|
|
|
|
async def load_idfm_lines() -> None:
    """Download the IDFM lines dataset and store lines together with their pictos.

    Each distinct picto (by id) is stored once. Within a batch the pictos are
    added before the lines.
    """
    payload = await http_get(IDFM_LINES_URL)
    if payload is None:
        return None

    pending_lines = []
    pending_pictos = []
    seen_picto_ids = set()

    for line in json_lines_decoder.decode(payload):
        picto = line.fields.picto
        if picto is not None and picto.id_ not in seen_picto_ids:
            seen_picto_ids.add(picto.id_)
            pending_pictos.append(picto)

        pending_lines.append(line)
        if len(pending_lines) != BATCH_SIZE:
            continue

        await LinePicto.add(list(format_idfm_pictos(*pending_pictos)))
        await Line.add(list(format_idfm_lines(*pending_lines)))
        pending_lines.clear()
        pending_pictos.clear()

    # Store whatever remains after the last full batch.
    if pending_pictos:
        await LinePicto.add(list(format_idfm_pictos(*pending_pictos)))
    if pending_lines:
        await Line.add(list(format_idfm_lines(*pending_lines)))
|
|
|
|
|
|
async def load_ratp_pictos(batch_size: int = 5) -> None:
    """Download the RATP pictos dataset and attach the pictos to lines.

    Args:
        batch_size: Number of decoded records processed per database round.
    """
    payload = await http_get(RATP_PICTOS_URL)
    if payload is None:
        return None

    async def store(batch) -> None:
        # Store the pictos themselves, then hand the (index, picto) pairs to
        # Line.add_pictos.
        formatteds = format_ratp_pictos(*batch)
        await LinePicto.add([line_picto for _, line_picto in formatteds])
        await Line.add_pictos(formatteds)

    pending = []
    for record in json_ratp_pictos_decoder.decode(payload):
        pending.append(record)
        if len(pending) == batch_size:
            await store(pending)
            pending.clear()

    # Store the trailing, incomplete batch.
    if pending:
        await store(pending)
|
|
|
|
|
|
async def load_lines_stops_assos(batch_size: int = 5000) -> None:
    """Download the stop <-> line associations and register them on lines.

    The numeric stop id is taken from the text following the last ``":"`` of
    the record's ``stop_id``. Records whose id is not an integer are logged
    and skipped.

    Args:
        batch_size: Number of associations passed to each ``Line.add_stops`` call.
    """
    payload = await http_get(IDFM_STOPS_LINES_ASSOS_URL)
    if payload is None:
        return None

    total_assos_nb = 0
    total_found_nb = 0
    pending = []

    for record in json_stops_lines_assos_decoder.decode(payload):
        fields = record.fields

        try:
            stop_id = int(fields.stop_id.rpartition(":")[2])
        except ValueError as err:
            logger.error(err)
            logger.error(f"{fields.stop_id = }")
            continue

        pending.append((fields.route_long_name, fields.operatorname, stop_id))
        if len(pending) != batch_size:
            continue

        total_assos_nb += len(pending)
        total_found_nb += await Line.add_stops(pending)
        pending.clear()

    # Register the trailing, incomplete batch.
    if pending:
        total_assos_nb += len(pending)
        total_found_nb += await Line.add_stops(pending)

    logger.info(f"{total_found_nb} line <-> stop ({total_assos_nb = } found)")
|
|
|
|
|
|
async def load_stop_assos(batch_size: int = 5000) -> None:
    """Download stop relations and link stops to stop areas and connection areas.

    Every record contributes one ``(zdaid, arrid)`` pair for
    ``StopArea.add_stops`` and one ``(zdcid, arrid)`` pair for
    ``ConnectionArea.add_stops``. Both lists are flushed together.

    Args:
        batch_size: Number of associations passed to each ``add_stops`` call.
    """
    data = await http_get(IDFM_STOP_STOP_AREAS_ASSOS_URL)
    if data is None:
        return None

    total_assos_nb = area_stop_assos_nb = conn_stop_assos_nb = 0
    area_stop_assos: list[tuple[int, int]] = []
    connection_stop_assos: list[tuple[int, int]] = []

    async def flush() -> tuple[int, int]:
        # Store both pending lists and return how many associations each model
        # reported; a None result counts as zero.
        area_found_nb = await StopArea.add_stops(area_stop_assos)
        area_stop_assos.clear()

        conn_found_nb = await ConnectionArea.add_stops(connection_stop_assos)
        connection_stop_assos.clear()

        return area_found_nb or 0, conn_found_nb or 0

    for asso in json_stop_area_stop_asso_decoder.decode(data):
        fields = asso.fields

        stop_id = int(fields.arrid)

        area_stop_assos.append((int(fields.zdaid), stop_id))
        connection_stop_assos.append((int(fields.zdcid), stop_id))

        # Both lists always have the same length, so checking one is enough.
        if len(area_stop_assos) == batch_size:
            total_assos_nb += batch_size
            area_found_nb, conn_found_nb = await flush()
            area_stop_assos_nb += area_found_nb
            conn_stop_assos_nb += conn_found_nb

    if area_stop_assos:
        total_assos_nb += len(area_stop_assos)
        area_found_nb, conn_found_nb = await flush()
        area_stop_assos_nb += area_found_nb
        conn_stop_assos_nb += conn_found_nb

    logger.info(f"{area_stop_assos_nb} stop area <-> stop ({total_assos_nb = } found)")
    # This summary previously repeated the "stop area" label by mistake.
    logger.info(
        f"{conn_stop_assos_nb} connection area <-> stop ({total_assos_nb = } found)"
    )
|
|
|
|
|
|
async def prepare(db: Database) -> None:
    """Fill the database by running every loader in three successive stages.

    Lines are loaded first. Stops, stop areas, connection areas and RATP
    pictos are then loaded concurrently, followed by stop shapes and the
    association tables.

    Args:
        db: Database handle. It is not used in this function body.
    """
    # Stage 1: lines.
    await load_idfm_lines()

    # Stage 2: loaders started once the lines are in place.
    await gather(
        load_idfm_stops(),
        load_idfm_stop_areas(),
        load_idfm_connection_areas(),
        load_ratp_pictos(),
    )

    # Stage 3: shapes and associations.
    await gather(
        load_idfm_stop_shapes(IDFM_STOP_SHAPES_URL),
        load_idfm_stop_shapes(IDFM_STOP_AREA_SHAPES_URL),
        load_lines_stops_assos(),
        load_stop_assos(),
    )
|
|
|
|
|
|
def load_settings(path: str) -> Settings:
    """Read the YAML file at ``path`` and build a ``Settings`` object from it.

    Args:
        path: Location of the YAML configuration file.

    Returns:
        ``Settings`` built from the top-level keys of the parsed document.
    """
    with open(path) as stream:
        raw_config = safe_load(stream)

    return Settings(**raw_config)
|
|
|
|
|
|
async def main() -> None:
    """Load the settings, connect to the database and run the whole import.

    The elapsed time of the import is logged on success. The database is
    disconnected even when the import fails.
    """
    settings = load_settings(CONFIG_PATH)

    await db.connect(settings.db, True)

    try:
        begin_ts = time()
        await prepare(db)
        logger.info(f"Elapsed time: {time() - begin_ts}s")
    finally:
        # Always release the connection, including when prepare() raises.
        await db.disconnect()
|
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: run the asynchronous import to completion.
    run(main())
|