diff --git a/birdy/client/converters.py b/birdy/client/converters.py index 22497ce..dce5baa 100644 --- a/birdy/client/converters.py +++ b/birdy/client/converters.py @@ -7,6 +7,8 @@ from typing import Sequence, Union from birdy.utils import is_opendap_url from owslib.wps import Output +from functools import partial +from boltons.funcutils import update_wrapper from . import notebook as nb @@ -22,7 +24,7 @@ def __init__(self, output=None, path=None, verify=True): Parameters ---------- - output: owslib.wps.Output + output: owslib.wps.Output | Path | str Output object to be converted. """ self.path = path or tempfile.mkdtemp() @@ -38,6 +40,9 @@ def __init__(self, output=None, path=None, verify=True): else: raise NotImplementedError + # Create load method for converter + self.load = self._load_func() + @property def file(self): """Return output Path object. Download from server if not found.""" @@ -77,10 +82,18 @@ def convert(self): """To be subclassed.""" raise NotImplementedError + def _load_func(self): + """Return function that can open file.""" + raise NotImplementedError + class GenericConverter(BaseConverter): # noqa: D101 priority = 0 + def _load_func(self): + """Return function that can open file.""" + return lambda self: self.data + def convert(self): """Return raw bytes memory representation.""" return self.data @@ -91,6 +104,9 @@ class TextConverter(BaseConverter): # noqa: D101 extensions = ["txt", "csv", "md", "rst"] priority = 1 + def _load_func(self): + return self.file.read_text + def convert(self): """Return text content.""" return self.file.read_text(encoding="utf8") @@ -117,6 +133,12 @@ class JSONConverter(BaseConverter): # noqa: D101 extensions = ["json"] priority = 1 + def _load_func(self): + import json + + func = json.loads + return update_wrapper(partial(func, s=self.data), func, injected=["s"]) + def convert(self): # noqa: D102 import json @@ -132,6 +154,13 @@ class GeoJSONConverter(BaseConverter): # noqa: D101 def check_dependencies(self): # noqa: D102 self._check_import("geojson") + def _load_func(self): + import geojson + + func = geojson.loads + return update_wrapper(partial(func, s=self.data), func, injected=["s"]) + + def convert(self): # noqa: D102 import geojson @@ -151,6 +180,14 @@ class MetalinkConverter(BaseConverter): # noqa: D101 def check_dependencies(self): # noqa: D102 self._check_import("metalink.download") + def _load_func(self): + from metalink import download as md + + func = md.get + return update_wrapper(partial(func, src=self.url, path=self.path, segmented=False), + func, + injected=["src", "path", "segmented"]) + def convert(self): # noqa: D102 from metalink import download as md @@ -171,6 +208,13 @@ def check_dependencies(self): # noqa: D102 if version < StrictVersion("4.5"): raise ImportError("netCDF4 library must be at least version 4.5") + def _load_func(self): + import netCDF4 + + link = self.url if is_opendap_url(self.url) else self.file + func = netCDF4.Dataset.__call__ + return update_wrapper(partial(func, filename=link), func, injected=["filename"]) + def convert(self): # noqa: D102 import netCDF4 @@ -191,6 +235,13 @@ def check_dependencies(self): # noqa: D102 Netcdf4Converter.check_dependencies(self) self._check_import("xarray") + def _load_func(self): + import xarray as xr + + link = self.url if is_opendap_url(self.url) else self.file + func = xr.open_dataset + return update_wrapper(partial(func, filename_or_obj=link), func, injected=["filename_or_obj"]) + def convert(self): # noqa: D102 import xarray as xr @@ -211,11 +262,18 @@ def check_dependencies(self): # noqa: D102 ShpOgrConverter.check_dependencies(self) self._check_import("fiona") + def _load_func(self): + import io # isort: skip + import fiona # isort: skip + + func = fiona.open + return update_wrapper(partial(func, fp=self.file), func, injected=["fp"]) + def convert(self): # noqa: D102 import io # isort: skip import fiona # isort: skip - return lambda x: fiona.open(io.BytesIO(x)) + return fiona.open(fp=self.file) # TODO: Add test for this. @@ -227,10 +285,15 @@ class ShpOgrConverter(BaseConverter): # noqa: D101 def check_dependencies(self): # noqa: D102 self._check_import("ogr", package="osgeo") + def _load_func(self): + from osgeo import ogr + func = ogr.Open + return update_wrapper(partial(func, utf8_path=self.file), func, injected=["utf8_path"]) + def convert(self): # noqa: D102 from osgeo import ogr - return ogr.Open + return ogr.Open(self.file) # TODO: Add test for this. Probably can be applied to jpeg/jpg/gif but needs notebook testing @@ -242,6 +305,11 @@ class ImageConverter(BaseConverter): # noqa: D101 def check_dependencies(self): # noqa: D102 return nb.is_notebook() + def _load_func(self): + from birdy.dependencies import IPython + func = IPython.display.Image + return update_wrapper(partial(func, data=self.url), func, injected=["data"]) + def convert(self): # noqa: D102 from birdy.dependencies import IPython @@ -257,6 +325,13 @@ class GeotiffRioxarrayConverter(BaseConverter): # noqa: D101 def check_dependencies(self): # noqa: D102 self._check_import("rioxarray") + def _load_func(self): + import xarray # isort: skip + import rioxarray # noqa + + func = xarray.open_rasterio + return update_wrapper(partial(func, filename=self.file), injected=["filename"]) + def convert(self): # noqa: D102 import xarray # isort: skip import rioxarray # noqa @@ -273,6 +348,12 @@ class GeotiffRasterioConverter(BaseConverter): # noqa: D101 def check_dependencies(self): # noqa: D102 self._check_import("rasterio") + def _load_func(self): + import rasterio # isort: skip + + ds = rasterio.open(self.file) + return ds.read + def convert(self): # noqa: D102 import rasterio # isort: skip @@ -288,11 +369,16 @@ class GeotiffGdalConverter(BaseConverter): # noqa: D101 def check_dependencies(self): # noqa: D102 self._check_import("gdal", package="osgeo") + def _load_func(self): + from osgeo import gdal # isort: skip + + func = gdal.Open + return update_wrapper(partial(func, utf8_path=self.file), func, injected=["utf8_path"]) + def convert(self): # noqa: D102 - import io # isort: skip from osgeo import gdal # isort: skip - return lambda x: gdal.Open(io.BytesIO(x)) + return lambda x: gdal.Open(self.file) class ZipConverter(BaseConverter): # noqa: D101 @@ -301,7 +387,11 @@ class ZipConverter(BaseConverter): # noqa: D101 nested = True priority = 1 + def _load_func(self): + return self.convert + def convert(self): # noqa: D102 + """Return list of files in archive.""" import zipfile with zipfile.ZipFile(self.file) as z: @@ -312,6 +402,7 @@ def convert(self): # noqa: D102 def _find_converter(mimetype=None, extension=None, converters=()): """Return a list of compatible converters ordered by priority.""" select = [GenericConverter] + for obj in converters: if (mimetype in obj.mimetypes) or (extension in obj.extensions): select.append(obj) @@ -320,8 +411,13 @@ def _find_converter(mimetype=None, extension=None, converters=()): return select -def find_converter(obj, converters): +def find_converter(obj, converters=None): """Find converters for a WPS output or a file on disk.""" + + # Get all converters + if converters is None: + converters = all_subclasses(BaseConverter) + if isinstance(obj, Output): mimetype = obj.mimeType extension = Path(obj.fileName or "").suffix[1:] @@ -358,15 +454,13 @@ def convert( objs Python object or file's content as bytes. """ - # Get all converters - if converters is None: - converters = all_subclasses(BaseConverter) - # Find converters matching mime type or extension. + # Find converters matching mime type or extension. convs = find_converter(output, converters) # Try converters in order of priority for cls in convs: + print(cls) try: converter = cls(output, path=path, verify=verify) out = converter.convert() diff --git a/birdy/client/outputs.py b/birdy/client/outputs.py index 8eea5a7..28ee2fa 100644 --- a/birdy/client/outputs.py +++ b/birdy/client/outputs.py @@ -3,12 +3,35 @@ import tempfile from collections import namedtuple -from owslib.wps import WPSExecution +import owslib +from owslib.wps import WPSExecution, Output from birdy.client import utils from birdy.client.converters import convert from birdy.exceptions import ProcessFailed, ProcessIsNotComplete from birdy.utils import delist, sanitize +from .converters import find_converter + + +class BirdyOutput(Output): + """An owslib WPS output with user-friendly interface, including conversion methods.""" + + def __init__(self, output, path=None, converters=None): + # Copy owslib.wps.Output attributes + for key in ["abstract", "title", "identifier", "reference", "dataType"]: + setattr(self, key, getattr(output, key)) + self.path = path + + # List of converters + self.converters = find_converter(output, converters) + + if len(self.converters) > 0: + # Primary converter instance + self.converter = self.converters[0](output, path=path, verify=False) + + # Copy converter attributes, including `load` method + for key in ["data", "file", "path", "load"]: + setattr(self, key, getattr(self.converter, key)) class WPSResult(WPSExecution): # noqa: D101 @@ -25,6 +48,33 @@ def attach(self, wps_outputs, converters=None): self._converters = converters self._path = tempfile.mkdtemp() + def _output_namedtuple(self): + """Return namedtuple for outputs.""" + Output = namedtuple( + sanitize(self.process.identifier) + "Response", + [sanitize(o.identifier) for o in self.processOutputs], + ) + Output.__repr__ = utils.pretty_repr + return Output + + def _create_birdy_outputs(self): + Output = self._output_namedtuple() + return Output( + *[BirdyOutput(o) for o in self.processOutputs] + ) + + def load(self): + """Return BirdyOutput instances. + + TODO: Decide on function name. + """ + if not self.isComplete(): + raise ProcessIsNotComplete("Please wait ...") + if not self.isSucceded(): + # TODO: add reason for failure + raise ProcessFailed("Sorry, process failed.") + return self._create_birdy_outputs() + def get(self, asobj=False): """Return the process response outputs. @@ -41,11 +91,7 @@ def get(self, asobj=False): return self._make_output(asobj) def _make_output(self, convert_objects=False): - Output = namedtuple( - sanitize(self.process.identifier) + "Response", - [sanitize(o.identifier) for o in self.processOutputs], - ) - Output.__repr__ = utils.pretty_repr + Output = self._output_namedtuple() return Output( *[self._process_output(o, convert_objects) for o in self.processOutputs] ) diff --git a/tests/test_client.py b/tests/test_client.py index 621c75e..317435d 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -162,6 +162,11 @@ def test_wps_client_multiple_outputs(wps): # noqa: D103 assert len(files) == 2 assert len(files4) == 2 + # As augmented outputs + [files, files4] = resp.load() + len(files.load()) == 2 + len(files4.load()) == 2 + @pytest.mark.online def test_process_subset_only_one(): # noqa: D103 diff --git a/tests/test_converters.py b/tests/test_converters.py index 4540ea3..1393f88 100644 --- a/tests/test_converters.py +++ b/tests/test_converters.py @@ -9,56 +9,80 @@ from birdy.client import converters -def test_all_subclasses(): # noqa: D103 - c = converters.all_subclasses(converters.BaseConverter) - assert converters.MetalinkConverter in c +@pytest.fixture +def nc_ex(tmp_path): + """Test netCDF dataset.""" + import netCDF4 as nc + + fn = tmp_path / "a.nc" + ds = nc.Dataset(fn, "w") + ds.createDimension("time", 10) + ds.createVariable("time", "f8", ("time",)) + ds.close() + return fn, ds + + +@pytest.fixture(params=[True, False]) +def json_ex(request, tmp_path): + binary = request.param + fn = tmp_path / "a.json" + d = {"a": 1} + val = json.dumps(d) + mode = "wb" if binary else "w" + val = bytes(val, "utf8") if binary else val + + with open(fn, mode) as f: + f.write(val) + + return fn, d -def test_jsonconverter(): # noqa: D103 - d = {"a": 1} - s = json.dumps(d) - b = bytes(s, "utf8") - fs = tempfile.NamedTemporaryFile(mode="w") - fs.write(s) - fs.file.seek(0) +@pytest.fixture +def txt_ex(tmp_path): + fn = tmp_path / "a.txt" + text = "coucou" - fb = tempfile.NamedTemporaryFile(mode="w+b") - fb.write(b) - fb.file.seek(0) + with open(fn, "w") as f: + f.write(text) + return fn, text - j = converters.JSONConverter(fs.name) - assert j.convert() == d - j = converters.JSONConverter(fb.name) - assert j.convert() == d +def test_all_subclasses(): # noqa: D103 + c = converters.all_subclasses(converters.BaseConverter) + assert converters.MetalinkConverter in c - fs.close() - fb.close() +def test_jsonconverter(json_ex): # noqa: D103 + fn, d = json_ex -def test_geojsonconverter(): # noqa: D103 - pytest.importorskip("geojson") - d = {"a": 1} - s = json.dumps(d) - b = bytes(s, "utf8") + c = converters.JSONConverter(fn) + assert c.convert() == d + assert c.load() == d + + +def test_geojsonconverter(json_ex): # noqa: D103 + fn, d = json_ex - fs = tempfile.NamedTemporaryFile(mode="w") - fs.write(s) - fs.file.seek(0) + c = converters.GeoJSONConverter(fn) + assert c.convert() == d + assert c.load() == d - fb = tempfile.NamedTemporaryFile(mode="w+b") - fb.write(b) - fb.file.seek(0) - j = converters.GeoJSONConverter(fs.name) - assert j.convert() == d +def test_textconverter(txt_ex): + fn, text = txt_ex + t = converters.TextConverter(fn) + assert t.convert() == text - j = converters.GeoJSONConverter(fb.name) - assert j.convert() == d + assert t.load() == text - fs.close() - fb.close() + # As class method + class A: + def __init__(self): + self.load = t._load_func() + + a = A() + assert a.load(encoding="utf8") == text def test_zipconverter(): # noqa: D103 @@ -79,11 +103,25 @@ def test_zipconverter(): # noqa: D103 zf.write(b.name, arcname=os.path.split(b.name)[1]) zf.close() + z = converters.ZipConverter(f) + files = z.convert() + assert len(files) == 2 + files = z.load() + assert len(files) == 2 + [oa, ob] = [converters.convert(f, path="/tmp") for f in files] + assert oa == {"a": 1} + assert len(ob.splitlines()) == 2 + [oa, ob] = converters.convert(f, path="/tmp", converters=[converters.ZipConverter]) assert oa == {"a": 1} assert len(ob.splitlines()) == 2 +def test_geotiff_converter(tmp_path): + c = converters.GeotiffRasterioConverter("resources/Olympus.tif") + assert c.load().shape == (1, 99, 133) + + def test_jpeg_imageconverter(): # noqa: D103 # Note: Since the format is not supported, bytes will be returned fn = tempfile.mktemp(suffix=".jpeg") @@ -94,3 +132,23 @@ def test_jpeg_imageconverter(): # noqa: D103 b = converters.convert(fn, path="/tmp") assert isinstance(b, bytes) + + +def test_netcdf_converter(nc_ex): + pytest.importorskip("netCDF4") + + fn, ds = nc_ex + + c = converters.Netcdf4Converter(fn) + ds = c.convert() + assert "time" in ds.variables + + +def test_xarray_converter(nc_ex): + pytest.importorskip("xarray") + + fn, ds = nc_ex + + c = converters.XarrayConverter(fn) + ds = c.convert() + assert "time" in ds.variables