Skip to content

Commit

Permalink
feat: can compress image data before sending it over the websocket
Browse files Browse the repository at this point in the history
This is useful when the widget is being used over a non-local network.
This can reduce the network traffic by a factor of 80 (for smooth, easy
to compress images). Pure noise image (random pixels) will not compress
well but will still see a factor of 7 reduction in size,
due to using uint8 instead of float64.
  • Loading branch information
maartenbreddels committed Oct 11, 2024
1 parent ba16715 commit 57383a8
Show file tree
Hide file tree
Showing 13 changed files with 283 additions and 46 deletions.
9 changes: 8 additions & 1 deletion bqplot_image_gl/imagegl.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import os
import ipywidgets as widgets
import bqplot
from traittypes import Array
Expand All @@ -6,10 +7,15 @@
from bqplot.marks import shape
from bqplot.traits import array_to_json, array_from_json
from bqplot_image_gl._version import __version__
from .serialize import image_data_serialization

__all__ = ['ImageGL', 'Contour']


# can be 'png', 'webp' or 'none'
DEFAULT_IMAGE_DATA_COMPRESSION = os.environ.get("BQPLOT_IMAGE_GL_IMAGE_DATA_COMPRESSION", "png")


@widgets.register
class ImageGL(bqplot.Mark):
"""An example widget."""
Expand All @@ -24,7 +30,8 @@ class ImageGL(bqplot.Mark):
scaled=True,
rtype='Color',
atype='bqplot.ColorAxis',
**array_serialization)
**image_data_serialization)
compression = Unicode(DEFAULT_IMAGE_DATA_COMPRESSION, allow_none=True).tag(sync=True)
interpolation = Unicode('nearest', allow_none=True).tag(sync=True)
opacity = Float(1.0).tag(sync=True)
x = Array(default_value=(0, 1)).tag(sync=True, scaled=True,
Expand Down
94 changes: 94 additions & 0 deletions bqplot_image_gl/serialize.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
from PIL import Image
import numpy as np
import io

from bqplot.traits import array_serialization


def array_to_image_or_array(array, widget):
    """to_json hook for the image trait: compress to PNG/WebP when the
    widget requests it, otherwise fall back to bqplot's plain ndarray
    serialization."""
    compression = widget.compression
    if compression not in ("png", "webp"):
        return array_serialization["to_json"](array, widget)
    return array_to_image(array, compression)

Check warning on line 12 in bqplot_image_gl/serialize.py

View check run for this annotation

Codecov / codecov/patch

bqplot_image_gl/serialize.py#L12

Added line #L12 was not covered by tests


def not_implemented(image):
    """from_json placeholder: the frontend never sends image data back to
    the kernel, so deserialization is intentionally unsupported."""
    raise NotImplementedError("deserializing is not implemented yet")

Check warning on line 17 in bqplot_image_gl/serialize.py

View check run for this annotation

Codecov / codecov/patch

bqplot_image_gl/serialize.py#L17

Added line #L17 was not covered by tests


def array_to_image(array, image_format):
    """Compress a numpy array into a PNG/WebP payload for the frontend.

    A 2-D array is min/max-normalized to uint8 and packed as an "LA" image:
    luminance carries the normalized value, alpha carries the NaN mask
    (0 = NaN, 255 = valid).  A 3-D array with 3 or 4 channels is sent as
    RGB/RGBA; float input is assumed to be in [0, 1] — TODO confirm with
    callers.

    Returns a dict with the compressed bytes, the min/max needed to undo
    the normalization on the frontend, and size statistics for debugging.

    Raises ValueError for unsupported dtypes or channel counts.
    """
    # use vmin/vmax so we do not shadow the min/max builtins
    vmin, vmax = None, None
    use_colormap = False
    if array.ndim == 2:
        use_colormap = True
        vmin = np.nanmin(array)
        vmax = np.nanmax(array)

        # guard against a constant image, which would otherwise divide by zero
        scale = vmax - vmin
        if scale != 0:
            normalized = (array - vmin) / scale
        else:
            normalized = np.zeros_like(array, dtype=float)

        # NaNs are encoded in the alpha channel below; zero them here so the
        # uint8 cast is well defined (casting NaN to uint8 is undefined)
        array_bytes = (np.nan_to_num(normalized) * 255).astype(np.uint8)
        intensity_image = Image.fromarray(array_bytes, mode="L")

        # mask image: 0 for NaN values, 255 for valid values
        mask = (~np.isnan(array) * 255).astype(np.uint8)
        mask_image = Image.fromarray(mask, mode="L")

        # merge the intensity and mask images into a single two-band image
        image = Image.merge("LA", (intensity_image, mask_image))
    else:
        if array.dtype.kind == "f":
            # floats are assumed to be in [0, 1] — TODO confirm
            array_bytes = (array * 255).astype(np.uint8)
        elif array.dtype == np.uint8:
            array_bytes = array
        else:
            # the original message lacked its f-prefix, so the dtype was
            # never interpolated into the error text
            raise ValueError(
                "Only float arrays or uint8 arrays are supported, "
                f"your array has dtype {array.dtype}"
            )
        if array.shape[2] == 3:
            image = Image.fromarray(array_bytes, mode="RGB")
        elif array.shape[2] == 4:
            image = Image.fromarray(array_bytes, mode="RGBA")
        else:
            raise ValueError(
                "Only 2D arrays or 3D arrays with 3 or 4 channels are supported, "
                f"your array has shape {array.shape}"
            )

    # serialize to the requested format; `lossless` only affects WebP
    # (the PNG encoder ignores it)
    buffer = io.BytesIO()
    image.save(buffer, format=image_format, lossless=True)
    compressed = buffer.getvalue()
    original_byte_length = array.nbytes
    uint8_byte_length = array_bytes.nbytes
    compressed_byte_length = len(compressed)
    return {
        "type": "image",
        "format": image_format,
        "use_colormap": use_colormap,
        "min": vmin,
        "max": vmax,
        "data": compressed,
        # this metadata is only useful/needed for debugging
        "shape": array.shape,
        "info": {
            "original_byte_length": original_byte_length,
            "uint8_byte_length": uint8_byte_length,
            "compressed_byte_length": compressed_byte_length,
            "compression_ratio": original_byte_length / compressed_byte_length,
            "MB": {
                "original": original_byte_length / 1024 / 1024,
                "uint8": uint8_byte_length / 1024 / 1024,
                "compressed": compressed_byte_length / 1024 / 1024,
            },
        },
    }


# Serializer pair wired into the ImageGL `image` trait: compressed (or raw
# ndarray) on the way to the frontend, never deserialized back into the kernel.
image_data_serialization = {
    "to_json": array_to_image_or_array,
    "from_json": not_implemented,
}
29 changes: 25 additions & 4 deletions js/lib/contour.js
Original file line number Diff line number Diff line change
Expand Up @@ -34,19 +34,40 @@ class ContourModel extends bqplot.MarkModel {
this.update_data();
}

update_data() {
async update_data() {
const image_widget = this.get('image');
const level = this.get('level')
// we support a single level or multiple
this.thresholds = Array.isArray(level) ? level : [level];
if(image_widget) {
const image = image_widget.get('image')
this.width = image.shape[1];
this.height = image.shape[0];
let data = null;
if(image.image) {
const imageNode = image.image;
this.width = imageNode.width;
this.height = imageNode.height;
// convert the image to a typed array using canvas
const canvas = document.createElement('canvas');
canvas.width = this.width
canvas.height = this.height
const ctx = canvas.getContext('2d');
ctx.drawImage(imageNode, 0, 0);
const imageData = ctx.getImageData(0, 0, imageNode.width, imageNode.height);
const {min, max} = image;
// use the r channel as the data, and scale to the range
data = new Float32Array(imageData.data.length / 4);
for(var i = 0; i < data.length; i++) {
data[i] = (imageData.data[i*4] / 255) * (max - min) + min;
}
} else {
this.width = image.shape[1];
this.height = image.shape[0];
data = image.data;
}
this.contours = this.thresholds.map((threshold) => d3contour
.contours()
.size([this.width, this.height])
.contour(image.data, [threshold])
.contour(data, [threshold])
)
} else {
this.width = 1; // precomputed contour_lines will have to be in normalized
Expand Down
113 changes: 82 additions & 31 deletions js/lib/imagegl.js
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,24 @@ class ImageGLModel extends bqplot.MarkModel {
super.initialize(attributes, options);
this.on_some_change(['x', 'y'], this.update_data, this);
this.on_some_change(["preserve_domain"], this.update_domains, this);
this.listenTo(this, "change:image", () => {
const previous = this.previous("image");
if(previous.image && previous.image.src) {
URL.revokeObjectURL(previous.image.src);
}
}, this);

this.update_data();
}

close(comm_closed) {
const image = this.get("image");
if(image.image && image.image.src) {
URL.revokeObjectURL(previous.image.src);
}
return super.close(comm_closed);
}

update_data() {
this.mark_data = {
x: this.get("x"), y: this.get("y")
Expand Down Expand Up @@ -79,9 +94,24 @@ ImageGLModel.serializers = Object.assign({}, bqplot.MarkModel.serializers,
{ x: serialize.array_or_json,
y: serialize.array_or_json,
image: {
deserialize: (obj, manager) => {
let state = {buffer: obj.value, dtype: obj.dtype, shape: obj.shape};
return jupyter_dataserializers.JSONToArray(state);
deserialize: async (obj, manager) => {
if(obj.type == "image") {
// the data is encoded in an image with LA format
// luminance for the intensity, alpha for the mask
let image = new Image();
const blob = new Blob([obj.data], {type: `image/${obj.format}`});
const url = URL.createObjectURL(blob);
image.src = url;
await new Promise((resolve, reject) => {
image.onload = resolve;
image.onerror = reject;
} );
return {image, min: obj.min, max: obj.max, use_colormap: obj.use_colormap};
} else {
// otherwise just a 'normal' ndarray
let state = {buffer: obj.value, dtype: obj.dtype, shape: obj.shape};
return jupyter_dataserializers.JSONToArray(state);
}
},
serialize: (ar) => {
const {buffer, dtype, shape} = jupyter_dataserializers.arrayToJSON(ar);
Expand Down Expand Up @@ -114,6 +144,10 @@ class ImageGLView extends bqplot.Mark {
// basically the corners of the image
image_domain_x : { type: "2f", value: [0.0, 1.0] },
image_domain_y : { type: "2f", value: [0.0, 1.0] },
// in the case we use an image for the values, the image is normalized, and we need to scale
// it back to a particular image range
// This needs to be set to [0, 1] for array data (which is not normalized)
range_image : { type: "2f", value: [0.0, 1.0] },
// extra opacity value
opacity: {type: 'f', value: 1.0}
},
Expand Down Expand Up @@ -280,39 +314,56 @@ class ImageGLView extends bqplot.Mark {
update_image(skip_render) {
var image = this.model.get("image");
var type = null;
var data = image.data;
if(data instanceof Uint8Array) {
type = THREE.UnsignedByteType;
} else if(data instanceof Float64Array) {
console.warn('ImageGLView.data is a Float64Array which WebGL does not support, will convert to a Float32Array (consider sending float32 data for better performance).');
data = Float32Array.from(data);
type = THREE.FloatType;
} else if(data instanceof Float32Array) {
type = THREE.FloatType;
} else {
console.error('only types uint8 and float32 are supported');
return;
}
if(this.scales.image.model.get('scheme') && image.shape.length == 2) {
if(this.texture)
if(image.image) {
// the data is encoded in an image with LA format
if(this.texture) {
this.texture.dispose();
this.texture = new THREE.DataTexture(data, image.shape[1], image.shape[0], THREE.LuminanceFormat, type);
}
this.texture = new THREE.Texture(image.image);
this.texture.needsUpdate = true;
this.texture.flipY = false;
this.image_material.uniforms.image.value = this.texture;
this.image_material.defines.USE_COLORMAP = true;
this.image_material.defines.USE_COLORMAP = image.use_colormap;
this.image_material.needsUpdate = true;
} else if(image.shape.length == 3) {
this.image_material.defines.USE_COLORMAP = false;
if(this.texture)
this.texture.dispose();
if(image.shape[2] == 3)
this.texture = new THREE.DataTexture(data, image.shape[1], image.shape[0], THREE.RGBFormat, type);
if(image.shape[2] == 4)
this.texture = new THREE.DataTexture(data, image.shape[1], image.shape[0], THREE.RGBAFormat, type);
this.texture.needsUpdate = true;
this.image_material.uniforms.image.value = this.texture;
this.image_material.uniforms.range_image.value = [image.min, image.max];
} else {
console.error('image data not understood');
// we are not dealing with an image, but with an array
// which is not normalized, so we can reset the range_image
this.image_material.uniforms.range_image.value = [0, 1];
var data = image.data;
if(data instanceof Uint8Array) {
type = THREE.UnsignedByteType;
} else if(data instanceof Float64Array) {
console.warn('ImageGLView.data is a Float64Array which WebGL does not support, will convert to a Float32Array (consider sending float32 data for better performance).');
data = Float32Array.from(data);
type = THREE.FloatType;
} else if(data instanceof Float32Array) {
type = THREE.FloatType;
} else {
console.error('only types uint8 and float32 are supported');
return;
}
if(this.scales.image.model.get('scheme') && image.shape.length == 2) {
if(this.texture)
this.texture.dispose();
this.texture = new THREE.DataTexture(data, image.shape[1], image.shape[0], THREE.LuminanceFormat, type);
this.texture.needsUpdate = true;
this.image_material.uniforms.image.value = this.texture;
this.image_material.defines.USE_COLORMAP = true;
this.image_material.needsUpdate = true;
} else if(image.shape.length == 3) {
this.image_material.defines.USE_COLORMAP = false;
if(this.texture)
this.texture.dispose();
if(image.shape[2] == 3)
this.texture = new THREE.DataTexture(data, image.shape[1], image.shape[0], THREE.RGBFormat, type);
if(image.shape[2] == 4)
this.texture = new THREE.DataTexture(data, image.shape[1], image.shape[0], THREE.RGBAFormat, type);
this.texture.needsUpdate = true;
this.image_material.uniforms.image.value = this.texture;
} else {
console.error('image data not understood');
}
}
this.texture.magFilter = interpolations[this.model.get('interpolation')];
this.texture.minFilter = interpolations[this.model.get('interpolation')];
Expand Down
11 changes: 9 additions & 2 deletions js/shaders/image-fragment.glsl
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@ uniform vec2 domain_y;
uniform vec2 image_domain_x;
uniform vec2 image_domain_y;

uniform vec2 range_image;


bool isnan(float val)
{
return (val < 0.0 || 0.0 < val || val == 0.0) ? false : true;
Expand All @@ -32,7 +35,10 @@ void main(void) {
float y_normalized = scale_transform_linear(y_domain_value, vec2(0., 1.), image_domain_y);
vec2 tex_uv = vec2(x_normalized, y_normalized);
#ifdef USE_COLORMAP
float raw_value = texture2D(image, tex_uv).r;
// r (or g or b) is used for the value, alpha for the mask (is 0 if a nan is found)
vec2 pixel_value = texture2D(image, tex_uv).ra;
float raw_value = pixel_value[0] * (range_image[1] - range_image[0]) + range_image[0];
float opacity_image = pixel_value[1];
float value = (raw_value - color_min) / (color_max - color_min);
vec4 color;
if(isnan(value)) // nan's are interpreted as missing values, and 'not shown'
Expand All @@ -41,8 +47,9 @@ void main(void) {
color = texture2D(colormap, vec2(value, 0.5));
#else
vec4 color = texture2D(image, tex_uv);
float opacity_image = 1.0;
#endif
// since we're working with pre multiplied colors (regarding blending)
// we also need to multiply rgb by opacity
gl_FragColor = color * opacity;
gl_FragColor = color * opacity * opacity_image;
}
3 changes: 2 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,8 @@
include_package_data=True,
install_requires=[
'ipywidgets>=7.0.0',
'bqplot>=0.12'
'bqplot>=0.12',
'pillow',
],
packages=find_packages(),
zip_safe=False,
Expand Down
10 changes: 8 additions & 2 deletions tests/ui/contour_test.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from pathlib import Path
import pytest
import ipywidgets as widgets
import playwright.sync_api
from IPython.display import display
Expand All @@ -6,7 +8,8 @@
from bqplot_image_gl import ImageGL, Contour


def test_widget_image(solara_test, page_session: playwright.sync_api.Page, assert_solara_snapshot):
@pytest.mark.parametrize("compression", ["png", "none"])
def test_widget_image(solara_test, page_session: playwright.sync_api.Page, assert_solara_snapshot, compression, request):

scale_x = LinearScale(min=0, max=1)
scale_y = LinearScale(min=0, max=1)
Expand Down Expand Up @@ -34,5 +37,8 @@ def test_widget_image(solara_test, page_session: playwright.sync_api.Page, asser

svg = page_session.locator(".bqplot")
svg.wait_for()
# page_session.wait_for_timeout(1000)
page_session.wait_for_timeout(100)
# although the contour is almost the same, due to precision issues, the image is slightly different
# therefore unlike the image_test, we use a different testname/image name based on the fixture value
# for compression
assert_solara_snapshot(svg.screenshot())
Loading

0 comments on commit 57383a8

Please sign in to comment.