From 06e9a6689a30d8a995a5d9706cc0ecd75bbb7382 Mon Sep 17 00:00:00 2001 From: Paul-Edouard Sarlin Date: Mon, 19 Feb 2024 22:54:29 +0100 Subject: [PATCH] Apply black and isort --- demo.ipynb | 10 +++++++--- maploc/__init__.py | 3 +-- maploc/data/image.py | 4 ++-- maploc/data/kitti/dataset.py | 4 ++-- maploc/data/kitti/prepare.py | 6 +++--- maploc/data/mapillary/dataset.py | 4 ++-- maploc/data/mapillary/download.py | 12 ++++++----- maploc/data/mapillary/prepare.py | 33 +++++++++++++++---------------- maploc/data/mapillary/utils.py | 2 +- maploc/data/torch.py | 6 +++--- maploc/demo.py | 14 ++++++------- maploc/evaluation/kitti.py | 3 +-- maploc/evaluation/mapillary.py | 3 +-- maploc/evaluation/run.py | 13 ++++++------ maploc/evaluation/viz.py | 10 +++++----- maploc/models/orienternet.py | 6 +++--- maploc/models/sequential.py | 2 +- maploc/osm/analysis.py | 2 +- maploc/osm/data.py | 3 +-- maploc/osm/download.py | 2 +- maploc/osm/reader.py | 2 +- maploc/osm/tiling.py | 2 +- maploc/osm/viz.py | 2 +- maploc/train.py | 4 ++-- maploc/utils/exif.py | 7 ++++--- maploc/utils/geo_opensfm.py | 4 +++- maploc/utils/io.py | 2 +- setup.py | 10 +++++----- 28 files changed, 89 insertions(+), 86 deletions(-) diff --git a/demo.ipynb b/demo.ipynb index 7ba7398..98d3ada 100644 --- a/demo.ipynb +++ b/demo.ipynb @@ -14,7 +14,7 @@ "# The highest accuracy is achieved with num_rotations=360\n", "# but num_rotations=64~128 is often sufficient.\n", "# To reduce the memory usage, we can reduce the tile size in the next cell.\n", - "demo = Demo(num_rotations=256, device='cpu')" + "demo = Demo(num_rotations=256, device=\"cpu\")" ] }, { @@ -135,6 +135,7 @@ "\n", "# Show the query area in an interactive map\n", "from maploc.osm.viz import GeoPlotter\n", + "\n", "plot = GeoPlotter(zoom=16)\n", "plot.points(prior_latlon[:2], \"red\", name=\"location prior\", size=10)\n", "plot.bbox(proj.unproject(bbox), \"blue\", name=\"map tile\")\n", @@ -142,12 +143,14 @@ "\n", "# Query OpenStreetMap for this area\n", "from maploc.osm.tiling import TileManager\n", + "\n", "tiler = TileManager.from_bbox(proj, bbox + 10, demo.config.data.pixel_per_meter)\n", "canvas = tiler.query(bbox)\n", "\n", "# Show the inputs to the model: image and raster map\n", "from maploc.osm.viz import Colormap, plot_nodes\n", "from maploc.utils.viz_2d import plot_images\n", + "\n", "map_viz = Colormap.apply(canvas.raster)\n", "plot_images([image, map_viz], titles=[\"input image\", \"OpenStreetMap raster\"])\n", "plot_nodes(1, canvas.raster[2], fontsize=6, size=10)" @@ -1186,7 +1189,8 @@ "\n", "# Run the inference\n", "uv, yaw, prob, neural_map, image_rectified = demo.localize(\n", - " image, camera, canvas, roll_pitch=gravity)\n", + " image, camera, canvas, roll_pitch=gravity\n", + ")\n", "\n", "# Visualize the predictions\n", "overlay = likelihood_overlay(prob.numpy().max(-1), map_viz.mean(-1, keepdims=True))\n", @@ -1194,7 +1198,7 @@ "plot_images([overlay, neural_map_rgb], titles=[\"prediction\", \"neural map\"])\n", "ax = plt.gcf().axes[0]\n", "ax.scatter(*canvas.to_uv(bbox.center), s=5, c=\"red\")\n", - "plot_dense_rotations(ax, prob, w=0.005, s=1/25)\n", + "plot_dense_rotations(ax, prob, w=0.005, s=1 / 25)\n", "add_circle_inset(ax, uv)\n", "plt.show(\"notebook\")\n", "\n", diff --git a/maploc/__init__.py b/maploc/__init__.py index 8d1fff3..8824223 100644 --- a/maploc/__init__.py +++ b/maploc/__init__.py @@ -1,11 +1,10 @@ # Copyright (c) Meta Platforms, Inc. and affiliates. -from pathlib import Path import logging +from pathlib import Path import pytorch_lightning # noqa: F401 - formatter = logging.Formatter( fmt="[%(asctime)s %(name)s %(levelname)s] %(message)s", datefmt="%Y-%m-%d %H:%M:%S", diff --git a/maploc/data/image.py b/maploc/data/image.py index 4e66b26..7bc0bdc 100644 --- a/maploc/data/image.py +++ b/maploc/data/image.py @@ -1,11 +1,11 @@ # Copyright (c) Meta Platforms, Inc. and affiliates. -from typing import Callable, Optional, Union, Sequence +import collections +from typing import Callable, Optional, Sequence, Union import numpy as np import torch import torchvision.transforms.functional as tvf -import collections from scipy.spatial.transform import Rotation from ..utils.geometry import from_homogeneous, to_homogeneous diff --git a/maploc/data/kitti/dataset.py b/maploc/data/kitti/dataset.py index 5b0153c..905057f 100644 --- a/maploc/data/kitti/dataset.py +++ b/maploc/data/kitti/dataset.py @@ -13,12 +13,12 @@ from omegaconf import OmegaConf from scipy.spatial.transform import Rotation -from ... import logger, DATASETS_PATH +from ... import DATASETS_PATH, logger from ...osm.tiling import TileManager from ..dataset import MapLocDataset from ..sequential import chunk_sequence from ..torch import collate, worker_init_fn -from .utils import parse_split_file, parse_gps_file, get_camera_calibration +from .utils import get_camera_calibration, parse_gps_file, parse_split_file class KittiDataModule(pl.LightningDataModule): diff --git a/maploc/data/kitti/prepare.py b/maploc/data/kitti/prepare.py index 5fb059d..a3f7978 100644 --- a/maploc/data/kitti/prepare.py +++ b/maploc/data/kitti/prepare.py @@ -1,9 +1,9 @@ # Copyright (c) Meta Platforms, Inc. and affiliates. import argparse -from pathlib import Path import shutil import zipfile +from pathlib import Path import numpy as np from tqdm.auto import tqdm @@ -12,9 +12,9 @@ from ...osm.tiling import TileManager from ...osm.viz import GeoPlotter from ...utils.geo import BoundaryBox, Projection -from ...utils.io import download_file, DATA_URL -from .utils import parse_gps_file +from ...utils.io import DATA_URL, download_file from .dataset import KittiDataModule +from .utils import parse_gps_file split_files = ["test1_files.txt", "test2_files.txt", "train_files.txt"] diff --git a/maploc/data/mapillary/dataset.py b/maploc/data/mapillary/dataset.py index f15874f..7dc0f7f 100644 --- a/maploc/data/mapillary/dataset.py +++ b/maploc/data/mapillary/dataset.py @@ -1,10 +1,10 @@ # Copyright (c) Meta Platforms, Inc. and affiliates. import json -from collections import defaultdict import os import shutil import tarfile +from collections import defaultdict from pathlib import Path from typing import Any, Dict, Optional @@ -14,7 +14,7 @@ import torch.utils.data as torchdata from omegaconf import DictConfig, OmegaConf -from ... import logger, DATASETS_PATH +from ... import DATASETS_PATH, logger from ...osm.tiling import TileManager from ..dataset import MapLocDataset from ..sequential import chunk_sequence diff --git a/maploc/data/mapillary/download.py b/maploc/data/mapillary/download.py index 4042237..83876c1 100644 --- a/maploc/data/mapillary/download.py +++ b/maploc/data/mapillary/download.py @@ -1,25 +1,24 @@ # Copyright (c) Meta Platforms, Inc. and affiliates. +import asyncio import json from pathlib import Path -import numpy as np import httpx -import asyncio -from aiolimiter import AsyncLimiter +import numpy as np import tqdm - +from aiolimiter import AsyncLimiter from opensfm.pygeometry import Camera, Pose from opensfm.pymap import Shot from ... import logger from ...utils.geo import Projection - semaphore = asyncio.Semaphore(100) # number of parallel threads. image_filename = "{image_id}.jpg" info_filename = "{image_id}.json" + def retry(times, exceptions): def decorator(func): async def wrapper(*args, **kwargs): @@ -30,9 +29,12 @@ async def wrapper(*args, **kwargs): except exceptions: attempt += 1 return await func(*args, **kwargs) + return wrapper + return decorator + class MapillaryDownloader: image_fields = ( "id", diff --git a/maploc/data/mapillary/prepare.py b/maploc/data/mapillary/prepare.py index 5c4a7ba..af8d96e 100644 --- a/maploc/data/mapillary/prepare.py +++ b/maploc/data/mapillary/prepare.py @@ -1,17 +1,15 @@ # Copyright (c) Meta Platforms, Inc. and affiliates. -import asyncio import argparse -from collections import defaultdict +import asyncio import json import shutil +from collections import defaultdict from pathlib import Path from typing import List -import numpy as np import cv2 -from tqdm import tqdm -from tqdm.contrib.concurrent import thread_map +import numpy as np from omegaconf import DictConfig, OmegaConf from opensfm.pygeometry import Camera from opensfm.pymap import Shot @@ -19,30 +17,31 @@ perspective_camera_from_fisheye, perspective_camera_from_perspective, ) +from tqdm import tqdm +from tqdm.contrib.concurrent import thread_map from ... import logger from ...osm.tiling import TileManager from ...osm.viz import GeoPlotter from ...utils.geo import BoundaryBox, Projection -from ...utils.io import write_json, download_file, DATA_URL +from ...utils.io import DATA_URL, download_file, write_json from ..utils import decompose_rotmat +from .dataset import MapillaryDataModule +from .download import ( + MapillaryDownloader, + fetch_image_infos, + fetch_images_pixels, + image_filename, + opensfm_shot_from_info, +) from .utils import ( + CameraUndistorter, + PanoramaUndistorter, keyframe_selection, perspective_camera_from_pano, scale_camera, - CameraUndistorter, - PanoramaUndistorter, undistort_shot, ) -from .download import ( - MapillaryDownloader, - opensfm_shot_from_info, - image_filename, - fetch_image_infos, - fetch_images_pixels, -) -from .dataset import MapillaryDataModule - location_to_params = { "sanfrancisco_soma": { diff --git a/maploc/data/mapillary/utils.py b/maploc/data/mapillary/utils.py index b7a991e..ddfe181 100644 --- a/maploc/data/mapillary/utils.py +++ b/maploc/data/mapillary/utils.py @@ -6,7 +6,7 @@ import cv2 import numpy as np from opensfm import features -from opensfm.pygeometry import Camera, compute_camera_mapping, Pose +from opensfm.pygeometry import Camera, Pose, compute_camera_mapping from opensfm.pymap import Shot from scipy.spatial.transform import Rotation diff --git a/maploc/data/torch.py b/maploc/data/torch.py index 9547ca1..1be33e6 100644 --- a/maploc/data/torch.py +++ b/maploc/data/torch.py @@ -4,14 +4,14 @@ import os import torch +from lightning_fabric.utilities.apply_func import move_data_to_device +from lightning_fabric.utilities.seed import pl_worker_init_function +from lightning_utilities.core.apply_func import apply_to_collection from torch.utils.data import get_worker_info from torch.utils.data._utils.collate import ( default_collate_err_msg_format, np_str_obj_array_pattern, ) -from lightning_fabric.utilities.seed import pl_worker_init_function -from lightning_utilities.core.apply_func import apply_to_collection -from lightning_fabric.utilities.apply_func import move_data_to_device def collate(batch): diff --git a/maploc/demo.py b/maploc/demo.py index 7abbb41..985b6fa 100644 --- a/maploc/demo.py +++ b/maploc/demo.py @@ -2,19 +2,19 @@ from typing import Optional, Tuple -import torch import numpy as np +import torch from . import logger -from .evaluation.run import resolve_checkpoint_path, pretrained_models +from .data.image import pad_image, rectify_image, resize_image +from .evaluation.run import pretrained_models, resolve_checkpoint_path from .models.orienternet import OrienterNet -from .models.voting import fuse_gps, argmax_xyr -from .data.image import resize_image, pad_image, rectify_image +from .models.voting import argmax_xyr, fuse_gps from .osm.raster import Canvas -from .utils.wrappers import Camera -from .utils.io import read_image -from .utils.geo import BoundaryBox, Projection from .utils.exif import EXIF +from .utils.geo import BoundaryBox, Projection +from .utils.io import read_image +from .utils.wrappers import Camera try: from geopy.geocoders import Nominatim diff --git a/maploc/evaluation/kitti.py b/maploc/evaluation/kitti.py index e91da06..13b0668 100644 --- a/maploc/evaluation/kitti.py +++ b/maploc/evaluation/kitti.py @@ -4,13 +4,12 @@ from pathlib import Path from typing import Optional, Tuple -from omegaconf import OmegaConf, DictConfig +from omegaconf import DictConfig, OmegaConf from .. import logger from ..data import KittiDataModule from .run import evaluate - default_cfg_single = OmegaConf.create({}) # For the sequential evaluation, we need to center the map around the GT location, # since random offsets would accumulate and leave only the GT location with a valid mask. diff --git a/maploc/evaluation/mapillary.py b/maploc/evaluation/mapillary.py index 4a867d6..f9f117b 100644 --- a/maploc/evaluation/mapillary.py +++ b/maploc/evaluation/mapillary.py @@ -4,14 +4,13 @@ from pathlib import Path from typing import Optional, Tuple -from omegaconf import OmegaConf, DictConfig +from omegaconf import DictConfig, OmegaConf from .. import logger from ..conf import data as conf_data_dir from ..data import MapillaryDataModule from .run import evaluate - split_overrides = { "val": { "scenes": [ diff --git a/maploc/evaluation/run.py b/maploc/evaluation/run.py index a93b201..90ec2df 100644 --- a/maploc/evaluation/run.py +++ b/maploc/evaluation/run.py @@ -2,26 +2,25 @@ import functools from itertools import islice -from typing import Callable, Dict, Optional, Tuple from pathlib import Path +from typing import Callable, Dict, Optional, Tuple import numpy as np import torch from omegaconf import DictConfig, OmegaConf -from torchmetrics import MetricCollection from pytorch_lightning import seed_everything +from torchmetrics import MetricCollection from tqdm import tqdm -from .. import logger, EXPERIMENTS_PATH +from .. import EXPERIMENTS_PATH, logger from ..data.torch import collate, unbatch_to_device -from ..models.voting import argmax_xyr, fuse_gps from ..models.metrics import AngleError, LateralLongitudinalError, Location2DError from ..models.sequential import GPSAligner, RigidAligner +from ..models.voting import argmax_xyr, fuse_gps from ..module import GenericModule -from ..utils.io import download_file, DATA_URL -from .viz import plot_example_single, plot_example_sequential +from ..utils.io import DATA_URL, download_file from .utils import write_dump - +from .viz import plot_example_sequential, plot_example_single pretrained_models = dict( OrienterNet_MGL=("orienternet_mgl.ckpt", dict(num_rotations=256)), diff --git a/maploc/evaluation/viz.py b/maploc/evaluation/viz.py index 5519fa4..37abc2a 100644 --- a/maploc/evaluation/viz.py +++ b/maploc/evaluation/viz.py @@ -1,18 +1,18 @@ # Copyright (c) Meta Platforms, Inc. and affiliates. +import matplotlib.pyplot as plt import numpy as np import torch -import matplotlib.pyplot as plt +from ..osm.viz import Colormap, plot_nodes from ..utils.io import write_torch_image -from ..utils.viz_2d import plot_images, features_to_RGB, save_plot +from ..utils.viz_2d import features_to_RGB, plot_images, save_plot from ..utils.viz_localization import ( + add_circle_inset, likelihood_overlay, - plot_pose, plot_dense_rotations, - add_circle_inset, + plot_pose, ) -from ..osm.viz import Colormap, plot_nodes def plot_example_single( diff --git a/maploc/models/orienternet.py b/maploc/models/orienternet.py index 1e3c7fc..7be9f6f 100644 --- a/maploc/models/orienternet.py +++ b/maploc/models/orienternet.py @@ -8,7 +8,10 @@ from .base import BaseModel from .bev_net import BEVNet from .bev_projection import CartesianProjection, PolarProjectionDepth +from .map_encoder import MapEncoder +from .metrics import AngleError, AngleRecall, Location2DError, Location2DRecall from .voting import ( + TemplateSampler, argmax_xyr, conv2d_fft_batchwise, expectation_xyr, @@ -16,10 +19,7 @@ mask_yaw_prior, nll_loss_xyr, nll_loss_xyr_smoothed, - TemplateSampler, ) -from .map_encoder import MapEncoder -from .metrics import AngleError, AngleRecall, Location2DError, Location2DRecall class OrienterNet(BaseModel): diff --git a/maploc/models/sequential.py b/maploc/models/sequential.py index 606be3f..1de372c 100644 --- a/maploc/models/sequential.py +++ b/maploc/models/sequential.py @@ -3,8 +3,8 @@ import numpy as np import torch -from .voting import argmax_xyr, log_softmax_spatial, sample_xyr from .utils import deg2rad, make_grid, rotmat2d +from .voting import argmax_xyr, log_softmax_spatial, sample_xyr def log_gaussian(points, mean, sigma): diff --git a/maploc/osm/analysis.py b/maploc/osm/analysis.py index a667c21..669826b 100644 --- a/maploc/osm/analysis.py +++ b/maploc/osm/analysis.py @@ -8,6 +8,7 @@ import plotly.graph_objects as go from .parser import ( + Patterns, filter_area, filter_node, filter_way, @@ -15,7 +16,6 @@ parse_area, parse_node, parse_way, - Patterns, ) from .reader import OSMData diff --git a/maploc/osm/data.py b/maploc/osm/data.py index dafc568..066ac79 100644 --- a/maploc/osm/data.py +++ b/maploc/osm/data.py @@ -7,6 +7,7 @@ import numpy as np from .parser import ( + Patterns, filter_area, filter_node, filter_way, @@ -14,11 +15,9 @@ parse_area, parse_node, parse_way, - Patterns, ) from .reader import OSMData, OSMNode, OSMRelation, OSMWay - logger = logging.getLogger(__name__) diff --git a/maploc/osm/download.py b/maploc/osm/download.py index 307ec87..2c1f7a2 100644 --- a/maploc/osm/download.py +++ b/maploc/osm/download.py @@ -1,9 +1,9 @@ # Copyright (c) Meta Platforms, Inc. and affiliates. import json +from http.client import responses from pathlib import Path from typing import Any, Dict, Optional -from http.client import responses import urllib3 diff --git a/maploc/osm/reader.py b/maploc/osm/reader.py index 78d44ad..1b0bd36 100644 --- a/maploc/osm/reader.py +++ b/maploc/osm/reader.py @@ -6,8 +6,8 @@ from pathlib import Path from typing import Any, Dict, List, Optional -from lxml import etree import numpy as np +from lxml import etree from ..utils.geo import BoundaryBox, Projection diff --git a/maploc/osm/tiling.py b/maploc/osm/tiling.py index 0363b8f..6ebb10d 100644 --- a/maploc/osm/tiling.py +++ b/maploc/osm/tiling.py @@ -6,8 +6,8 @@ from typing import Dict, List, Optional, Tuple import numpy as np -from PIL import Image import rtree +from PIL import Image from ..utils.geo import BoundaryBox, Projection from .data import MapData diff --git a/maploc/osm/viz.py b/maploc/osm/viz.py index 2925618..70c97eb 100644 --- a/maploc/osm/viz.py +++ b/maploc/osm/viz.py @@ -3,8 +3,8 @@ import matplotlib as mpl import matplotlib.pyplot as plt import numpy as np -import plotly.graph_objects as go import PIL.Image +import plotly.graph_objects as go from ..utils.viz_2d import add_text from .parser import Groups diff --git a/maploc/train.py b/maploc/train.py index ee2906c..2834ddc 100644 --- a/maploc/train.py +++ b/maploc/train.py @@ -1,8 +1,8 @@ # Copyright (c) Meta Platforms, Inc. and affiliates. import os.path as osp -from typing import Optional from pathlib import Path +from typing import Optional import hydra import pytorch_lightning as pl @@ -10,7 +10,7 @@ from omegaconf import DictConfig, OmegaConf from pytorch_lightning.utilities import rank_zero_only -from . import logger, pl_logger, EXPERIMENTS_PATH +from . import EXPERIMENTS_PATH, logger, pl_logger from .data import modules as data_modules from .module import GenericModule diff --git a/maploc/utils/exif.py b/maploc/utils/exif.py index 4c419ea..a061c39 100644 --- a/maploc/utils/exif.py +++ b/maploc/utils/exif.py @@ -1,9 +1,10 @@ """Copied from opensfm.exif to minimize hard dependencies.""" -from pathlib import Path -import json + import datetime +import json import logging -from codecs import encode, decode +from codecs import decode, encode +from pathlib import Path from typing import Any, Dict, Optional, Tuple import exifread diff --git a/maploc/utils/geo_opensfm.py b/maploc/utils/geo_opensfm.py index d421452..ad08ff2 100644 --- a/maploc/utils/geo_opensfm.py +++ b/maploc/utils/geo_opensfm.py @@ -1,7 +1,9 @@ """Copied from opensfm.geo to minimize hard dependencies.""" + +from typing import Tuple + import numpy as np from numpy import ndarray -from typing import Tuple WGS84_a = 6378137.0 WGS84_b = 6356752.314245 diff --git a/maploc/utils/io.py b/maploc/utils/io.py index bae91c5..dbe30a3 100644 --- a/maploc/utils/io.py +++ b/maploc/utils/io.py @@ -1,12 +1,12 @@ # Copyright (c) Meta Platforms, Inc. and affiliates. import json -import requests import shutil from pathlib import Path import cv2 import numpy as np +import requests import torch from tqdm.auto import tqdm diff --git a/setup.py b/setup.py index 43deeb0..fb80132 100644 --- a/setup.py +++ b/setup.py @@ -1,11 +1,11 @@ -from setuptools import setup, find_packages +from setuptools import find_packages, setup setup( - name='maploc', - version='0.0.0', + name="maploc", + version="0.0.0", packages=find_packages(), - python_requires='>=3.8', - author='Paul-Edouard Sarlin', + python_requires=">=3.8", + author="Paul-Edouard Sarlin", long_description_content_type="text/markdown", classifiers=[ "Programming Language :: Python :: 3",