diff --git a/README.md b/README.md index c270a689..1c32ba67 100644 --- a/README.md +++ b/README.md @@ -61,7 +61,7 @@ Key = highway Value = crossing ZoomLevel = 19 Compare = yes -Orthofoto = other +Orthofoto = wms FollowStreets = yes StepWidth = 0.66 @@ -86,16 +86,8 @@ Some hints to the config file: ### Own Orthofotos -To use your own Orthofotos you have to do the following steps: - -1. Add a new directory to `src/data/orthofoto` -2. Add a new module to the directory with the name: <your_new_directory>_api.py -3. Create a class in the module with the name: <Your_new_directory>Api (First letter needs to be uppercase) -4. Implement the function `def get_image(self, bbox):` and returns a pillow image of the bbox -5. After that you can use your api with the parameter --orthofots <your_new_directory> - -If you have problems with the implementation have a look at the wms or other example. +Provide it as a WMS from a MapProxy server. ## Dataset During this work, we have collected our own dataset with swiss crosswalks and non-crosswalks. The pictures have a size of 50x50 pixels and are available by request. diff --git a/mapproxy.yml b/mapproxy.yml new file mode 100644 index 00000000..e41712c4 --- /dev/null +++ b/mapproxy.yml @@ -0,0 +1,43 @@ +layers: + - name: osm + title: osm + sources: [cache_osm] + - name: bing + title: bing + sources: [cache_bing] + - name: ign + title: ign + sources: [cache_ign] + +caches: + cache_osm: + grids: [webmercator] + sources: [source_osm] + cache_bing: + grids: [webmercator] + sources: [source_bing] + cache_ign: + grids: [webmercator] + sources: [source_ign] + +sources: + source_osm: + type: tile + grid: GLOBAL_WEBMERCATOR + url: http://a.tile.openstreetmap.org/%(z)s/%(x)s/%(y)s.png + source_bing: + type: tile + grid: GLOBAL_WEBMERCATOR + url: https://t2.ssl.ak.tiles.virtualearth.net/tiles/a%(quadkey)s.jpeg?g=4401&n=z + source_ign: + type: tile + grid: GLOBAL_WEBMERCATOR + url: https://proxy-ign.openstreetmap.fr/94GjiyqD/bdortho/%(z)s/%(x)s/%(y)s.jpg + +grids: + webmercator: + base: GLOBAL_WEBMERCATOR + +services: + demo: + wms: diff --git a/requires.dev.txt b/requires.dev.txt index 5189fef9..cbfb1859 100644 --- a/requires.dev.txt +++ b/requires.dev.txt @@ -1,7 +1,7 @@ click==6.6 geopy==1.11.0 httplib2==0.9.2 -overpass==0.4.0 +overpass==0.6.0 Pillow==2.6.1 pytest==2.9.2 redis==2.10.5 diff --git a/src/base/configuration.py b/src/base/configuration.py index 3b6fd885..e6d1b414 100644 --- a/src/base/configuration.py +++ b/src/base/configuration.py @@ -17,7 +17,7 @@ def __init__(self, config_file_path=''): {'option': 'value', 'fallback': 'crossing'}, {'option': 'zoomlevel', 'fallback': '19'}, {'option': 'compare', 'fallback': 'yes'}, - {'option': 'orthophoto', 'fallback': 'other'}, + {'option': 'orthophoto', 'fallback': 'wms'}, {'option': 'stepwidth', 'fallback': '0.66'}, {'option': 'followstreets', 'fallback': 'yes'}]}, {'section': 'JOB', 'options': [{'option': 'bboxsize', 'fallback': '2000'}, diff --git a/src/base/tile.py b/src/base/tile.py index e5e779a1..ac4984fd 100755 --- a/src/base/tile.py +++ b/src/base/tile.py @@ -3,7 +3,8 @@ class Tile: - def __init__(self, image=None, bbox=None): + def __init__(self, image_api=None, image=None, bbox=None): + self.image_api = image_api self.image = image self.bbox = bbox @@ -14,15 +15,15 @@ def get_pixel(self, node): x = node.longitude - self.bbox.left y = node.latitude - self.bbox.bottom - pixel_x = int(self.image.size[0] * (x / image_width)) - pixel_y = self.image.size[1] - int(self.image.size[1] * (y / image_height)) + image_size = self.image_api.get_image_size(self.bbox) + pixel_x = int(image_size[0] * (x / image_width)) + pixel_y = image_size[1] - int(image_size[1] * (y / image_height)) return pixel_x, pixel_y def get_node(self, pixel=(0, 0)): x = pixel[0] y = pixel[1] - image_size_x = self.image.size[0] - image_size_y = self.image.size[1] + image_size_x, image_size_y = self.image_api.get_image_size(self.bbox) y_part = 0 x_part = 0 if image_size_x > 0 and image_size_y > 0: @@ -45,12 +46,12 @@ def get_tile_by_node(self, centre_node, side_length): y2 = centre_pixel[1] + side_length // 2 crop_box = (x1, y1, x2, y2) - img = self.image.crop(crop_box) left_down = self.get_node((x1, y1)) right_up = self.get_node((x2, y2)) bbox = Bbox.from_nodes(node_left_down=left_down, node_right_up=right_up) + img = self.image_api.get_image(bbox) - return Tile(img, bbox) + return Tile(image=img, bbox=bbox) def get_centre_node(self): diff_lat = self.bbox.node_right_up().latitude - self.bbox.node_left_down().latitude diff --git a/src/data/orthofoto/other/__init__.py b/src/data/orthofoto/other/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/src/data/orthofoto/other/fitting_bbox.py b/src/data/orthofoto/other/fitting_bbox.py deleted file mode 100644 index 3d3d6706..00000000 --- a/src/data/orthofoto/other/fitting_bbox.py +++ /dev/null @@ -1,29 +0,0 @@ -from src.base.bbox import Bbox -from src.base.globalmaptiles import GlobalMercator - - -class FittingBbox: - def __init__(self, zoom_level=19): - self._mercator = GlobalMercator() - self._zoom_level = zoom_level - - def get(self, bbox): - t_minx, t_miny, t_maxx, t_maxy = self.bbox_to_tiles(bbox) - bbox = self._bbox_from(t_minx, t_miny, t_maxx, t_maxy) - return bbox - - def bbox_to_tiles(self, bbox): - m_minx, m_miny = self._mercator.LatLonToMeters(bbox.bottom, bbox.left) - m_maxx, m_maxy = self._mercator.LatLonToMeters(bbox.top, bbox.right) - t_maxx, t_maxy = self._mercator.MetersToTile(m_maxx, m_maxy, self._zoom_level) - t_minx, t_miny = self._mercator.MetersToTile(m_minx, m_miny, self._zoom_level) - return t_minx, t_miny, t_maxx, t_maxy - - def generate_bbox(self, tx, ty): - bottom, left, top, right = self._mercator.TileLatLonBounds(tx, ty, self._zoom_level) - return Bbox(left=left, bottom=bottom, right=right, top=top) - - def _bbox_from(self, t_minx, t_miny, t_maxx, t_maxy): - bottom, left, _, _ = self._mercator.TileLatLonBounds(t_minx, t_miny, self._zoom_level) - _, _, top, right = self._mercator.TileLatLonBounds(t_maxx, t_maxy, self._zoom_level) - return Bbox(left=left, bottom=bottom, right=right, top=top) diff --git a/src/data/orthofoto/other/multi_loader.py b/src/data/orthofoto/other/multi_loader.py deleted file mode 100755 index 97d44440..00000000 --- a/src/data/orthofoto/other/multi_loader.py +++ /dev/null @@ -1,60 +0,0 @@ -import logging -import time -from io import BytesIO -from multiprocessing.dummy import Pool as ThreadPool - -import requests -from PIL import Image - -from src.data.orthofoto.other.user_agent import UserAgent - - -class MultiLoader: - def __init__(self, urls, auth=None): - self.urls = urls - self.results = [] - self.nb_threads = 10 - self.nb_tile_per_trial = 40 - self.auth = tuple() if auth is None else auth - self.user_agent = UserAgent() - self.logger = logging.getLogger(__name__) - - def download(self): - results = [] - nb_urls = len(self.urls) - for i in range(int(nb_urls / self.nb_tile_per_trial) + 1): - start = i * self.nb_tile_per_trial - end = start + self.nb_tile_per_trial - if end >= nb_urls: - end = nb_urls - url_part = self.urls[start:end] - - result = self._try_download(url_part) - results += result - - self.results = results - - def _try_download(self, urls): - for i in range(4): - try: - results = self._download_async(urls) - return results - except Exception as e: - print("Tile download failed " + str(i) + " wait " + str(i * 10) + " " + str(e)) - time.sleep(i * 10) - error_message = "Download of tiles have failed 4 times" - self.logger.error(error_message) - raise Exception(error_message) - - def _download_async(self, urls): - pool = ThreadPool(self.nb_threads) - results = pool.map(self._download_image, urls) - pool.close() - pool.join() - return results - - def _download_image(self, url): - response = requests.get(url, headers={'User-Agent': self.user_agent.random}, auth=self.auth) - img = Image.open(BytesIO(response.content)) - img.filename = url - return img diff --git a/src/data/orthofoto/other/other_api.py b/src/data/orthofoto/other/other_api.py deleted file mode 100644 index fe1b0d7d..00000000 --- a/src/data/orthofoto/other/other_api.py +++ /dev/null @@ -1,78 +0,0 @@ -from PIL import Image - -from src.base.bbox import Bbox -from src.base.tile import Tile -from src.data.orthofoto.other.multi_loader import MultiLoader -from src.data.orthofoto.other.url_builder import UrlBuilder -from src.base.globalmaptiles import GlobalMercator - - -class OtherApi: - def __init__(self, zoom_level=19): - self._mercator = GlobalMercator() - self._zoom_level = zoom_level - self.tile = None - - def get_image(self, bbox): - t_minx, t_miny, t_maxx, t_maxy = self._bbox_to_tile_indexes(bbox) - images = self._download_images(t_minx, t_miny, t_maxx, t_maxy) - image_matrix = self._to_image_matrix(images, t_minx, t_miny, t_maxx, t_maxy) - image = self._to_image(image_matrix) - big_bbox = self._generate_bbox(t_minx, t_miny, t_maxx, t_maxy) - self.tile = Tile(image, big_bbox) - return self._crop(self.tile, bbox) - - @staticmethod - def _to_image_matrix(images, t_minx, t_miny, t_maxx, t_maxy): - image_matrix = [] - row = 0 - url_number = 0 - for ty in range(t_miny, t_maxy + 1): - image_matrix.append([]) - for tx in range(t_minx, t_maxx + 1): - image = images[url_number] - image_matrix[row].append(image) - url_number += 1 - row += 1 - return image_matrix - - def _download_images(self, t_minx, t_miny, t_maxx, t_maxy): - url_builder = UrlBuilder(self._zoom_level) - urls = url_builder.get_urls_by_tiles(t_minx, t_miny, t_maxx, t_maxy) - loader = MultiLoader(urls) - loader.download() - return loader.results - - @staticmethod - def _to_image(image_matrix): - num_rows = len(image_matrix) - num_cols = len(image_matrix[0]) - width, height = image_matrix[0][0].size - - result = Image.new("RGB", (num_cols * width, num_rows * height)) - - for y in range(0, num_rows): - for x in range(0, num_cols): - result.paste(image_matrix[y][x], (x * width, (num_rows - 1 - y) * height)) - return result - - def _bbox_to_tile_indexes(self, bbox): - m_minx, m_miny = self._mercator.LatLonToMeters(bbox.bottom, bbox.left) - m_maxx, m_maxy = self._mercator.LatLonToMeters(bbox.top, bbox.right) - t_maxx, t_maxy = self._mercator.MetersToTile(m_maxx, m_maxy, self._zoom_level) - t_minx, t_miny = self._mercator.MetersToTile(m_minx, m_miny, self._zoom_level) - return t_minx, t_miny, t_maxx, t_maxy - - def _generate_bbox(self, t_minx, t_miny, t_maxx, t_maxy): - bottom, left, _, _ = self._mercator.TileLatLonBounds(t_minx, t_miny, self._zoom_level) - _, _, top, right = self._mercator.TileLatLonBounds(t_maxx, t_maxy, self._zoom_level) - return Bbox(left=left, bottom=bottom, right=right, top=top) - - @staticmethod - def _crop(tile, bbox): - left, bottom = tile.get_pixel(bbox.node_left_down()) - right, top = tile.get_pixel(bbox.node_right_up()) - box = (left, top, right, bottom) - cropped_image = tile.image.crop(box) - image = Image.frombytes(mode='RGB', data=cropped_image.tobytes(), size=cropped_image.size) - return image diff --git a/src/data/orthofoto/other/url_builder.py b/src/data/orthofoto/other/url_builder.py deleted file mode 100644 index 041a8e3a..00000000 --- a/src/data/orthofoto/other/url_builder.py +++ /dev/null @@ -1,26 +0,0 @@ -import random - -from src.base.globalmaptiles import GlobalMercator - - -class UrlBuilder: - def __init__(self, zoom_level=19): - self._url_first_part = 'https://t' - self._url_second_part = '.ssl.ak.tiles.virtualearth.net/tiles/a' - self._url_last_part = '.jpeg?g=4401&n=z' - self._zoom_level = zoom_level - self._mercator = GlobalMercator() - - def get_urls_by_tiles(self, t_minx, t_miny, t_maxx, t_maxy): - urls = [] - for ty in range(t_miny, t_maxy + 1): - for tx in range(t_minx, t_maxx + 1): - quad_tree = self._mercator.QuadTree(tx, ty, self._zoom_level) - url = self._build_url(quad_tree) - urls.append(url) - return urls - - def _build_url(self, quadtree): - server = random.randint(0, 7) - return self._url_first_part + str(server) + self._url_second_part + str( - quadtree) + self._url_last_part diff --git a/src/data/orthofoto/other/user_agent.py b/src/data/orthofoto/other/user_agent.py deleted file mode 100644 index eb616332..00000000 --- a/src/data/orthofoto/other/user_agent.py +++ /dev/null @@ -1,23 +0,0 @@ -from random import choice - - -class UserAgent: - def __init__(self): - self.user_agents = [ - 'Mozilla/5.0 (Windows; U; Windows NT 5.1; it; rv:1.8.1.11) Gecko/20071127 Firefox/2.0.0.11', - 'Opera/9.25 (Windows NT 5.1; U; en)', - 'Mozilla/4.0 (compatible; MSIE 6.0; Windows NT 5.1; SV1; .NET CLR 1.1.4322; .NET CLR 2.0.50727)', - 'Mozilla/5.0 (compatible; Konqueror/3.5; Linux) KHTML/3.5.5 (like Gecko) (Kubuntu)', - 'Mozilla/5.0 (X11; U; Linux i686; en-US; rv:1.8.0.12) Gecko/20070731 Ubuntu/dapper-security Firefox/1.5.0.12', - 'Lynx/2.8.5rel.1 libwww-FM/2.14 SSL-MM/1.4.1 GNUTLS/1.2.9' - 'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_9_3) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/35.0.1916.47 Safari/537.36' - 'Opera/9.80 (X11; Linux i686; U; ru) Presto/2.8.131 Version/11.11' - 'Dalvik/2.1.0 (Linux; U; Android 6.0.1; Nexus Player Build/MMB29T)' - 'Mozilla/3.0(DDIPOCKET;JRC/AH-J3001V,AH-J3002V/1.0/0100/c50)CNF/2.0' - 'Mozilla/1.22 (compatible; MSIE 5.01; PalmOS 3.0) EudoraWeb 2.1' - 'Mozilla/4.0 (compatible; MSIE 4.01; Windows CE; PPC; 240x320)' - ] - - @property - def random(self): - return choice(self.user_agents) diff --git a/src/data/orthofoto/tile_loader.py b/src/data/orthofoto/tile_loader.py deleted file mode 100755 index c96ac779..00000000 --- a/src/data/orthofoto/tile_loader.py +++ /dev/null @@ -1,13 +0,0 @@ -from src.base.tile import Tile - - -class TileLoader: - def __init__(self, bbox=None, image_api=None): - self.bbox = bbox - self.image_api = image_api - self.tile = None - - def load_tile(self): - image = self.image_api.get_image(self.bbox) - self.tile = Tile(image, self.bbox) - return self.tile diff --git a/src/data/orthofoto/wms/wms.ini b/src/data/orthofoto/wms/wms.ini index b408a9bc..2b3298dc 100644 --- a/src/data/orthofoto/wms/wms.ini +++ b/src/data/orthofoto/wms/wms.ini @@ -3,9 +3,7 @@ NtlmUser = hans NtlmPassword = test [WMS] -Url = http://test.com -Srs = EPSG:4326 -Version = 1.0.1 -Layer = upwms - - +Url = http://localhost:8080/service +Srs = EPSG:3857 +Version = 1.1.1 +Layer = bing diff --git a/src/data/orthofoto/wms/wms_api.py b/src/data/orthofoto/wms/wms_api.py index 1fbbac79..666127a8 100644 --- a/src/data/orthofoto/wms/wms_api.py +++ b/src/data/orthofoto/wms/wms_api.py @@ -5,6 +5,7 @@ from PIL import Image from src.base import geo_helper +from src.base.globalmaptiles import GlobalMercator from src.data.orthofoto.wms.auth_monkey_patch import AuthMonkeyPatch from requests_ntlm import HttpNtlmAuth @@ -17,6 +18,7 @@ def __init__(self, zoom_level=19): self.auth = self.set_auth() self.zoom_level = zoom_level self._auth_monkey_patch(self.auth) + self.mercator = GlobalMercator() from owslib.wms import WebMapService self.wms = WebMapService(url=self.config.get(section='WMS', option='Url'), @@ -42,15 +44,20 @@ def _auth_monkey_patch(auth): AuthMonkeyPatch(auth) def get_image(self, bbox): + bbox_ = self._box(bbox) + bbox_ = self.mercator.LatLonToMeters(bbox_[3], bbox_[0]) + self.mercator.LatLonToMeters(bbox_[1], bbox_[2]) size = self._calculate_image_size(bbox, self.zoom_level) image = self._get(layers=[self.config.get(section='WMS', option='Layer')], srs=self.config.get(section='WMS', option='Srs'), - bbox=self._box(bbox), + bbox=bbox_, size=size, format='image/jpeg', ) return image + def get_image_size(self, bbox): + return self._calculate_image_size(bbox, self.zoom_level) + @staticmethod def _calculate_image_size(bbox, zoom_level): meters_per_pixel = geo_helper.meters_per_pixel(zoom_level, bbox.bottom) diff --git a/src/data/osm/overpass_api.py b/src/data/osm/overpass_api.py index 79b98a26..8e912f10 100644 --- a/src/data/osm/overpass_api.py +++ b/src/data/osm/overpass_api.py @@ -8,28 +8,31 @@ def __init__(self): self.overpass = overpass.API(timeout=60) self.logger = logging.getLogger(__name__) - def get(self, bbox, tags): - query = self._get_query(bbox, tags) - return self._try_overpass_download(query) + def get(self, bbox, tags, nodes=True, ways=True, relations=True, responseformat='geojson'): + query = self._get_query(bbox, tags, nodes, ways, relations) + return self._try_overpass_download(query, responseformat) @staticmethod - def _get_query(bbox, tags): + def _get_query(bbox, tags, nodes=True, ways=True, relations=True): bbox_string = '(' + str(bbox) + ');' query = '(' for tag in tags: - node = 'node["' + tag.key + '"="' + tag.value + '"]' + bbox_string - way = 'way["' + tag.key + '"="' + tag.value + '"]' + bbox_string - relation = 'relation["' + tag.key + '"="' + tag.value + '"]' + bbox_string - query += node + way + relation + if nodes: + query += 'node["' + tag.key + '"="' + tag.value + '"]' + bbox_string + if ways: + query += 'way["' + tag.key + '"="' + tag.value + '"]' + bbox_string + if relations: + query += 'relation["' + tag.key + '"="' + tag.value + '"]' + bbox_string query += ');' return query - def _try_overpass_download(self, query): + def _try_overpass_download(self, query, responseformat='geojson'): for i in range(4): try: - json_data = self.overpass.Get(query) + json_data = self.overpass.get(query, responseformat=responseformat) return json_data except Exception as e: + self.logger.warning(e) self.logger.warning("Download from overpass failed " + str(i) + " wait " + str(i * 10) + ". " + str(e)) time.sleep(i * 10) error_message = "Download from overpass failed 4 times." diff --git a/src/data/osm/street_loader.py b/src/data/osm/street_loader.py index 33438f1f..7a6c20b5 100644 --- a/src/data/osm/street_loader.py +++ b/src/data/osm/street_loader.py @@ -9,8 +9,8 @@ class StreetLoader: street_categories = [ - 'road', - 'trunk', +# 'road', +# 'trunk', 'primary', 'secondary', 'tertiary', @@ -29,7 +29,7 @@ def __init__(self, categories=None): self.tags = self._generate_tags() def load_data(self, bbox): - data = self.api.get(bbox, self.tags) + data = self.api.get(bbox, self.tags, nodes=True, ways=True, relations=False) return self._generate_street(data) def _add(self, categories): diff --git a/src/detection/box_walker.py b/src/detection/box_walker.py index 6ef9e7e7..967561e6 100644 --- a/src/detection/box_walker.py +++ b/src/detection/box_walker.py @@ -1,36 +1,75 @@ import datetime import logging +import threading +import queue from importlib import import_module +from itertools import chain from src.base.globalmaptiles import GlobalMercator from src.base.configuration import Configuration from src.base.tag import Tag -from src.data.orthofoto.tile_loader import TileLoader +from src.base.tile import Tile from src.data.osm.node_merger import NodeMerger from src.data.osm.osm_comparator import OsmComparator from src.data.osm.street_loader import StreetLoader from src.detection.street_walker import StreetWalker from src.detection.tensor.detector import Detector -from src.data.orthofoto.other.other_api import OtherApi from src.detection.tile_walker import TileWalker logger = logging.getLogger(__name__) +class BackgroundGenerator(threading.Thread): + def __init__(self, generator, queue): + threading.Thread.__init__(self) + self.queue = queue + self.generator = generator + self.daemon = True + #self.start() + + def run(self): + for item in self.generator: + print("put: " + str(self.queue.qsize())) + self.queue.put(item) + self.queue.put(None) + + def __iter__(self): + return self + + def __next__(self): + print("next: " + str(self.queue.qsize())) + next_item = self.queue.get() + if next_item is None: + raise StopIteration + return next_item + + +class BackgroundGeneratorPool: + def __init__(self, generator, n=8): + q = queue.Queue(100) + self.pool = list(map(lambda _: BackgroundGenerator(generator, q), range(n))) + for b in self.pool: + b.start() + + def __iter__(self): + return self.pool[0].__iter__() + + def __next__(self): + return self.pool[0].__next__() + + class BoxWalker: def __init__(self, bbox, configuration=None): self.configuration = Configuration() if configuration is None else configuration self.bbox = bbox - self.tile = None self.streets = [] self.convnet = None - self.square_image_length = 50 + self.square_image_length = 100 self.max_distance = self._calculate_max_distance(self.configuration.DETECTION.zoomlevel, self.square_image_length) - self.image_api = OtherApi( - self.configuration.DETECTION.zoomlevel) if self.configuration.DETECTION.orthophoto is 'other' else self._get_image_api( - self.configuration.DETECTION.orthophoto) + self.image_api = self._get_image_api(self.configuration.DETECTION.orthophoto) + self.tile = Tile(image_api=self.image_api, bbox=self.bbox) @staticmethod def _get_image_api(image_api): @@ -47,13 +86,6 @@ def load_convnet(self): logger.error(error_message) raise Exception(error_message) - def load_tiles(self): - self._printer("Start image loading.") - loader = TileLoader(bbox=self.bbox, image_api=self.image_api) - loader.load_tile() - self.tile = loader.tile - self._printer("Stop image loading.") - def load_streets(self): self._printer("Start street loading.") street_loader = StreetLoader() @@ -62,7 +94,7 @@ def load_streets(self): self._printer("Stop street loading.") def walk(self): - ready_for_walk = (not self.tile is None) and (not self.convnet is None) + ready_for_walk = not self.convnet is None if not ready_for_walk: error_message = "Not ready for walk. Load tiles and convnet first" logger.error(error_message) @@ -77,15 +109,14 @@ def walk(self): else: tiles = self._get_tiles_of_box(self.tile) - self._printer("{0} images to analyze.".format(str(len(tiles)))) + #self._printer("{0} images to analyze.".format(str(len(tiles)))) - images = [tile.image for tile in tiles] - predictions = self.convnet.detect(images) + tiles = BackgroundGeneratorPool(tiles, n=32) # Fetch next images while running detect + predictions = self.convnet.detect(tiles) detected_nodes = [] - for i, _ in enumerate(tiles): - prediction = predictions[i] + for prediction in predictions: if self.hit(prediction): - detected_nodes.append(tiles[i].get_centre_node()) + detected_nodes.append(prediction['tile'].get_centre_node()) self._printer("Stop detection.") merged_nodes = self._merge_near_nodes(detected_nodes) @@ -106,11 +137,9 @@ def _get_tiles_of_box_with_streets(self, streets, tile): street_walker = StreetWalker(tile=tile, square_image_length=self.square_image_length, zoom_level=self.configuration.DETECTION.zoomlevel, step_width=self.configuration.DETECTION.stepwidth) - tiles = [] - for street in streets: - street_tiles = street_walker.get_tiles(street) - tiles += street_tiles - return tiles + + return chain.from_iterable(map(street_walker.get_tiles, streets)) + def _merge_near_nodes(self, node_list): merger = NodeMerger(node_list, self.max_distance) diff --git a/src/detection/tensor/detector.py b/src/detection/tensor/detector.py index 66026800..cc5a4e76 100644 --- a/src/detection/tensor/detector.py +++ b/src/detection/tensor/detector.py @@ -24,9 +24,7 @@ def _load_graph(self): graph_def.ParseFromString(f.read()) return graph_def - def detect(self, images): - np_images = [self._pil_to_np(image) for image in images] - + def detect(self, tiles): pool = ThreadPool() with tf.Graph().as_default() as imported_graph: tf.import_graph_def(self.graph_def, name='') @@ -34,14 +32,13 @@ def detect(self, images): with tf.Session(graph=imported_graph) as sess: with tf.device("/gpu:0"): softmax_tensor = sess.graph.get_tensor_by_name('final_result:0') - threads = [pool.apply_async(self.operation, - args=(sess, softmax_tensor, np_images[image_number], image_number,)) for - image_number in range(len(np_images))] + + prediction_image = pool.map(lambda image_number_tile: self.operation(sess, softmax_tensor, self._pil_to_np(image_number_tile[1].image), image_number_tile[1], image_number_tile[0]), enumerate(tiles), 16) + answers = [] - for thread in threads: - prediction, image_number = thread.get() + for prediction, tile in prediction_image: prediction = np.squeeze(prediction) - answer = {'image_number': image_number} + answer = {'tile': tile} for node_id, _ in enumerate(prediction): answer[self.labels[node_id]] = prediction[node_id] answers.append(answer) @@ -52,6 +49,8 @@ def _pil_to_np(image): return np.array(image)[:, :, 0:3] @staticmethod - def operation(sess, softmax, image, image_number): + def operation(sess, softmax, image, tile, image_number): + if image_number % 50 == 0: + print("operation {0}".format(image_number)) prediction = sess.run(softmax, {'DecodeJpeg:0': image}) - return prediction, image_number + return prediction, tile diff --git a/src/detection/walker.py b/src/detection/walker.py index 2689cc8c..8cc7b3eb 100644 --- a/src/detection/walker.py +++ b/src/detection/walker.py @@ -15,9 +15,4 @@ def _calculate_step_distance(self, zoom_level): return resolution * (self._square_image_length * self._step_width) def _get_squared_tiles(self, nodes): - square_tiles = [] - for node in nodes: - tile = self.tile.get_tile_by_node(node, self._square_image_length) - if self.tile.bbox.in_bbox(node): - square_tiles.append(tile) - return square_tiles + return map(lambda node: self.tile.get_tile_by_node(node, self._square_image_length), nodes) diff --git a/src/role/config.ini b/src/role/config.ini index e12d844d..9f55077f 100644 --- a/src/role/config.ini +++ b/src/role/config.ini @@ -12,7 +12,7 @@ Key = highway Value = crossing Zoomlevel = 19 Compare = no -Orthophoto = other +Orthophoto = wms FollowStreets = yes Stepwidth = 0.66 diff --git a/src/role/worker_functions.py b/src/role/worker_functions.py index 9c37734d..df8a8692 100755 --- a/src/role/worker_functions.py +++ b/src/role/worker_functions.py @@ -23,7 +23,6 @@ def get_nodes(bbox, configuration): if len(walker.streets) > 0 or not follow_streets: walker.load_convnet() - walker.load_tiles() crosswalk_nodes = walker.walk() return crosswalk_nodes diff --git a/src/train/coord_walker.py b/src/train/coord_walker.py new file mode 100644 index 00000000..1151ff8a --- /dev/null +++ b/src/train/coord_walker.py @@ -0,0 +1,11 @@ +from src.detection.walker import Walker + + +class CoordWalker(Walker): + def __init__(self, tile, nodes, square_image_length=50, zoom_level=19, step_width=0.66): + super(CoordWalker, self).__init__(tile, square_image_length, zoom_level, step_width) + self.nodes = nodes + + def get_tiles(self): + squared_tiles = self._get_squared_tiles(self.nodes) + return squared_tiles diff --git a/src/train/fetch.py b/src/train/fetch.py new file mode 100644 index 00000000..ca3d807d --- /dev/null +++ b/src/train/fetch.py @@ -0,0 +1,53 @@ +from src.base.tile import Tile +from src.base.bbox import Bbox +from src.base.node import Node +from src.base.tag import Tag + +from src.data.orthofoto.wms.wms_api import WmsApi + +from src.train.coord_walker import CoordWalker +from src.train.osm_object_walker import OsmObjectWalker + +import argparse + +def main(args): + coords = [args.coord[i:i + 2] for i in range(0, len(args.coord), 2)] + coords = list(map(lambda c: Node(*c), coords)) + bbox = Bbox.from_nodes(coords[0], coords[1]) + if args.tags: + tags = map(lambda kv: Tag(key=kv[0], value=kv[1]), map(lambda kv: kv.split('=', 1), args.tags.split(','))) + walker = OsmObjectWalker(Tile(image_api=WmsApi(), bbox=bbox), tags, square_image_length=100) + else: + walker = CoordWalker(Tile(image_api=WmsApi(), bbox=bbox), coords, square_image_length=100) + + tiles = walker.get_tiles() + for n, t in enumerate(tiles): + centre_node = t.get_centre_node() + name = "fetch/{0:02.8}_{1:02.8}.png".format(centre_node.latitude, centre_node.longitude) + t.image.save(name, "PNG") + print(name) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + '--tags', + type=str, + default=None, + help='Tag to fetch from OSM: highway=crossing.' + ) + + parser.add_argument( + 'coord', + type=float, + action='store', + nargs='+', + help='List of lon lat coord in WGS84, if --tags bbox left,bottom right,top, else list of coords to fetch.') + + args = parser.parse_args() + main(args) + +# mapproxy-util serve-develop mapproxy.yml +# python fetch.py --tags public_transport=platform 2.3681867122650146 48.87587197694874 2.3733580112457275 48.8794564363519 +# montage *.png -geometry 100x100+1+1 out.png +# python retrain.py --image_dir retrain-data --print_misclassified_test_images diff --git a/src/train/osm_object_walker.py b/src/train/osm_object_walker.py new file mode 100644 index 00000000..79de2035 --- /dev/null +++ b/src/train/osm_object_walker.py @@ -0,0 +1,25 @@ +from src.base.node import Node +from src.detection.walker import Walker +from src.data.osm.overpass_api import OverpassApi + + +class OsmObjectWalker(Walker): + def __init__(self, tile, tags, square_image_length=50, zoom_level=19, step_width=0.66): + super(OsmObjectWalker, self).__init__(tile, square_image_length, zoom_level, step_width) + self.tags = tags + + def get_tiles(self): + nodes = self._calculate_tile_centres() + squared_tiles = self._get_squared_tiles(nodes) + return squared_tiles + + def _calculate_tile_centres(self): + centers = [] + + # [out:csv(::lat,::lon)][timeout:25];node["public_transport"="platform"]({{bbox}});out; + self.api = OverpassApi() + data = self.api.get(self.tile.bbox, self.tags, nodes=True, ways=False, relations=False, responseformat='csv(::lat,::lon)') + data = list(map(lambda cc: Node(float(cc[0]), float(cc[1])), data[1:])) + print(data) + + return data diff --git a/tests/data/orthofoto/other/__init__.py b/tests/data/orthofoto/other/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/tests/data/orthofoto/other/test_muti_loader.py b/tests/data/orthofoto/other/test_muti_loader.py deleted file mode 100644 index 33746f9f..00000000 --- a/tests/data/orthofoto/other/test_muti_loader.py +++ /dev/null @@ -1,15 +0,0 @@ -import pytest - -from src.data.orthofoto.other.multi_loader import MultiLoader - - -@pytest.fixture(scope='module') -def urls(): - return ['https://t5.ssl.ak.tiles.virtualearth.net/tiles/a1202211220212011302.jpeg?g=4401&n=z', - 'https://t2.ssl.ak.tiles.virtualearth.net/tiles/a1202211220212011303.jpeg?g=4401&n=z'] - - -def test_multi_loader(urls): - multi_loader = MultiLoader(urls=urls) - multi_loader.download() - assert 2 == len(multi_loader.results) diff --git a/tests/data/orthofoto/other/test_url_builder.py b/tests/data/orthofoto/other/test_url_builder.py deleted file mode 100644 index 4c3b58ce..00000000 --- a/tests/data/orthofoto/other/test_url_builder.py +++ /dev/null @@ -1,25 +0,0 @@ -import pytest -import requests - -from src.data.orthofoto.other.fitting_bbox import FittingBbox -from src.data.orthofoto.other.url_builder import UrlBuilder - - -@pytest.fixture(scope='module') -def urls(small_bbox): - url_builder = UrlBuilder() - fitting_box = FittingBbox() - t_minx, t_miny, t_maxx, t_maxy = fitting_box.bbox_to_tiles(small_bbox) - return url_builder.get_urls_by_tiles(t_minx, t_miny, t_maxx, t_maxy) - - -def test_url_from_node(urls): - assert 'ssl.ak.tiles.virtualearth.net/tiles/a' in urls[0] - - -def test_url_reachable(urls): - try: - response = requests.get(urls[0]) - except Exception: - assert False - assert response.content is not None diff --git a/tests/data/orthofoto/other/test_user_agent.py b/tests/data/orthofoto/other/test_user_agent.py deleted file mode 100644 index 7ab97979..00000000 --- a/tests/data/orthofoto/other/test_user_agent.py +++ /dev/null @@ -1,6 +0,0 @@ -from src.data.orthofoto.other.user_agent import UserAgent - - -def test_user_agent_random(): - user_agent = UserAgent() - assert len(user_agent.random) > 0 diff --git a/tests/data/orthofoto/test_tile_loader.py b/tests/data/orthofoto/test_tile_loader.py deleted file mode 100755 index f5259aca..00000000 --- a/tests/data/orthofoto/test_tile_loader.py +++ /dev/null @@ -1,26 +0,0 @@ -import pytest - -from src.data.orthofoto.tile_loader import TileLoader -from src.data.orthofoto.other.other_api import OtherApi - - -@pytest.fixture(scope='module') -def image_api(): - return OtherApi() - - -def test_satellite_image_download(zurich_bellevue, image_api): - bbox = zurich_bellevue - tl = TileLoader(bbox, image_api=image_api) - tile = tl.load_tile() - img = tile.image - assert img.size[0] > 0 - assert img.size[1] > 0 - - -def test_new_bbox(small_bbox, image_api): - tile_loader = TileLoader(small_bbox, image_api) - tile_loader.load_tile() - tile = tile_loader.tile - tile_bbox = tile.bbox - assert tile_bbox == small_bbox diff --git a/tests/test_config.ini b/tests/test_config.ini index a7a20510..38002e92 100644 --- a/tests/test_config.ini +++ b/tests/test_config.ini @@ -12,7 +12,7 @@ Key = highway Value = crossing Zoomlevel = 19 Compare = yes -Orthophoto = other +Orthophoto = wms FollowStreets = yes Stepwidth = 0.66