Skip to content

Commit

Permalink
feat: can compress image data before sending it to over the websocket
Browse files Browse the repository at this point in the history
This is useful when the widget is being used over a non-local network.
This can reduce the network traffic by a factor of 80 (for smooth, easy
to compress images). Pure noise image (random pixels) will not compress
well but will still see a factor of 7 reduction in size,
due to using uint8 instead of float64.
  • Loading branch information
maartenbreddels committed Oct 15, 2024
1 parent ba16715 commit 57c7230
Show file tree
Hide file tree
Showing 13 changed files with 299 additions and 47 deletions.
11 changes: 11 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,17 @@ Used for https://github.com/glue-viz/glue-jupyter

(currently requires latest developer version of bqplot)

## Usage

### ImageGL

See https://py.cafe/maartenbreddels/bqplot-image-gl-demo for a demo of the ImageGL widget.

Preview image:
![preview image](https://py.cafe/preview/maartenbreddels/bqplot-image-gl-demo)



# Installation

To install use pip:
Expand Down
9 changes: 8 additions & 1 deletion bqplot_image_gl/imagegl.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import os
import ipywidgets as widgets
import bqplot
from traittypes import Array
Expand All @@ -6,10 +7,15 @@
from bqplot.marks import shape
from bqplot.traits import array_to_json, array_from_json
from bqplot_image_gl._version import __version__
from .serialize import image_data_serialization

__all__ = ['ImageGL', 'Contour']


# can be 'png', 'webp' or 'none'
DEFAULT_IMAGE_DATA_COMPRESSION = os.environ.get("BQPLOT_IMAGE_GL_IMAGE_DATA_COMPRESSION", "none")


@widgets.register
class ImageGL(bqplot.Mark):
"""An example widget."""
Expand All @@ -24,7 +30,8 @@ class ImageGL(bqplot.Mark):
scaled=True,
rtype='Color',
atype='bqplot.ColorAxis',
**array_serialization)
**image_data_serialization)
compression = Unicode(DEFAULT_IMAGE_DATA_COMPRESSION, allow_none=True).tag(sync=True)
interpolation = Unicode('nearest', allow_none=True).tag(sync=True)
opacity = Float(1.0).tag(sync=True)
x = Array(default_value=(0, 1)).tag(sync=True, scaled=True,
Expand Down
98 changes: 98 additions & 0 deletions bqplot_image_gl/serialize.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
from PIL import Image
import numpy as np
import io

from bqplot.traits import array_serialization


def array_to_image_or_array(array, widget):
if widget.compression in ["png", "webp"]:
return array_to_image(array, widget.compression)
else:
return array_serialization["to_json"](array, widget)


def not_implemented(image):
# the widget never sends the image data back to the kernel
raise NotImplementedError("deserializing is not implemented yet")


def array_to_image(array, image_format):
# convert the array to a png image with intensity values only
array = np.array(array, copy=False)
min, max = None, None
use_colormap = False
if array.ndim == 2:
use_colormap = True
min = np.nanmin(array)
max = np.nanmax(array)

array = (array - min) / (max - min)
# only convert to uint8 if the array is float
if array.dtype.kind == "f":
array_bytes = (array * 255).astype(np.uint8)
else:
array_bytes = array
intensity_image = Image.fromarray(array_bytes, mode="L")

# create a mask image with 0 for NaN values and 255 for valid values
isnan = ~np.isnan(array)
mask = (isnan * 255).astype(np.uint8)
mask_image = Image.fromarray(mask, mode="L")

# merge the intensity and mask image into a single image
image = Image.merge("LA", (intensity_image, mask_image))
else:
# if floats, convert to uint8
if array.dtype.kind == "f":
array_bytes = (array * 255).astype(np.uint8)
elif array.dtype == np.uint8:
array_bytes = array
else:
raise ValueError(
"Only float arrays or uint8 arrays are supported, your array has dtype"
"{array.dtype}"
)
if array.shape[2] == 3:
image = Image.fromarray(array_bytes, mode="RGB")
elif array.shape[2] == 4:
image = Image.fromarray(array_bytes, mode="RGBA")
else:
raise ValueError(
"Only 2D arrays or 3D arrays with 3 or 4 channels are supported, "
f"your array has shape {array.shape}"
)

# and serialize it to a PNG
png_data = io.BytesIO()
image.save(png_data, format=image_format, lossless=True)
png_bytes = png_data.getvalue()
original_byte_length = array.nbytes
uint8_byte_length = array_bytes.nbytes
compressed_byte_length = len(png_bytes)
return {
"type": "image",
"format": image_format,
"use_colormap": use_colormap,
"min": min,
"max": max,
"data": png_bytes,
# this metadata is only useful/needed for debugging
"shape": array.shape,
"info": {
"original_byte_length": original_byte_length,
"uint8_byte_length": uint8_byte_length,
"compressed_byte_length": compressed_byte_length,
"compression_ratio": original_byte_length / compressed_byte_length,
"MB": {
"original": original_byte_length / 1024 / 1024,
"uint8": uint8_byte_length / 1024 / 1024,
"compressed": compressed_byte_length / 1024 / 1024,
},
},
}


image_data_serialization = dict(
to_json=array_to_image_or_array, from_json=not_implemented
)
29 changes: 25 additions & 4 deletions js/lib/contour.js
Original file line number Diff line number Diff line change
Expand Up @@ -34,19 +34,40 @@ class ContourModel extends bqplot.MarkModel {
this.update_data();
}

update_data() {
async update_data() {
const image_widget = this.get('image');
const level = this.get('level')
// we support a single level or multiple
this.thresholds = Array.isArray(level) ? level : [level];
if(image_widget) {
const image = image_widget.get('image')
this.width = image.shape[1];
this.height = image.shape[0];
let data = null;
if(image.image) {
const imageNode = image.image;
this.width = imageNode.width;
this.height = imageNode.height;
// conver the image to a typed array using canvas
const canvas = document.createElement('canvas');
canvas.width = this.width
canvas.height = this.height
const ctx = canvas.getContext('2d');
ctx.drawImage(imageNode, 0, 0);
const imageData = ctx.getImageData(0, 0, imageNode.width, imageNode.height);
const {min, max} = image;
// use the r channel as the data, and scale to the range
data = new Float32Array(imageData.data.length / 4);
for(var i = 0; i < data.length; i++) {
data[i] = (imageData.data[i*4] / 255) * (max - min) + min;
}
} else {
this.width = image.shape[1];
this.height = image.shape[0];
data = image.data;
}
this.contours = this.thresholds.map((threshold) => d3contour
.contours()
.size([this.width, this.height])
.contour(image.data, [threshold])
.contour(data, [threshold])
)
} else {
this.width = 1; // precomputed contour_lines will have to be in normalized
Expand Down
113 changes: 82 additions & 31 deletions js/lib/imagegl.js
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,24 @@ class ImageGLModel extends bqplot.MarkModel {
super.initialize(attributes, options);
this.on_some_change(['x', 'y'], this.update_data, this);
this.on_some_change(["preserve_domain"], this.update_domains, this);
this.listenTo(this, "change:image", () => {
const previous = this.previous("image");
if(previous.image && previous.image.src) {
URL.revokeObjectURL(previous.image.src);
}
}, this);

this.update_data();
}

close(comm_closed) {
const image = this.get("image");
if(image.image && image.image.src) {
URL.revokeObjectURL(previous.image.src);
}
return super.close(comm_closed);
}

update_data() {
this.mark_data = {
x: this.get("x"), y: this.get("y")
Expand Down Expand Up @@ -79,9 +94,24 @@ ImageGLModel.serializers = Object.assign({}, bqplot.MarkModel.serializers,
{ x: serialize.array_or_json,
y: serialize.array_or_json,
image: {
deserialize: (obj, manager) => {
let state = {buffer: obj.value, dtype: obj.dtype, shape: obj.shape};
return jupyter_dataserializers.JSONToArray(state);
deserialize: async (obj, manager) => {
if(obj.type == "image") {
// the data is encoded in an image with LA format
// luminance for the intensity, alpha for the mask
let image = new Image();
const blob = new Blob([obj.data], {type: `image/${obj.format}`});
const url = URL.createObjectURL(blob);
image.src = url;
await new Promise((resolve, reject) => {
image.onload = resolve;
image.onerror = reject;
} );
return {image, min: obj.min, max: obj.max, use_colormap: obj.use_colormap};
} else {
// otherwise just a 'normal' ndarray
let state = {buffer: obj.value, dtype: obj.dtype, shape: obj.shape};
return jupyter_dataserializers.JSONToArray(state);
}
},
serialize: (ar) => {
const {buffer, dtype, shape} = jupyter_dataserializers.arrayToJSON(ar);
Expand Down Expand Up @@ -114,6 +144,10 @@ class ImageGLView extends bqplot.Mark {
// basically the corners of the image
image_domain_x : { type: "2f", value: [0.0, 1.0] },
image_domain_y : { type: "2f", value: [0.0, 1.0] },
// in the case we use an image for the values, the image is normalized, and we need to scale
// it back to a particular image range
// This needs to be set to [0, 1] for array data (which is not normalized)
range_image : { type: "2f", value: [0.0, 1.0] },
// extra opacity value
opacity: {type: 'f', value: 1.0}
},
Expand Down Expand Up @@ -280,39 +314,56 @@ class ImageGLView extends bqplot.Mark {
update_image(skip_render) {
var image = this.model.get("image");
var type = null;
var data = image.data;
if(data instanceof Uint8Array) {
type = THREE.UnsignedByteType;
} else if(data instanceof Float64Array) {
console.warn('ImageGLView.data is a Float64Array which WebGL does not support, will convert to a Float32Array (consider sending float32 data for better performance).');
data = Float32Array.from(data);
type = THREE.FloatType;
} else if(data instanceof Float32Array) {
type = THREE.FloatType;
} else {
console.error('only types uint8 and float32 are supported');
return;
}
if(this.scales.image.model.get('scheme') && image.shape.length == 2) {
if(this.texture)
if(image.image) {
// the data is encoded in an image with LA format
if(this.texture) {
this.texture.dispose();
this.texture = new THREE.DataTexture(data, image.shape[1], image.shape[0], THREE.LuminanceFormat, type);
}
this.texture = new THREE.Texture(image.image);
this.texture.needsUpdate = true;
this.texture.flipY = false;
this.image_material.uniforms.image.value = this.texture;
this.image_material.defines.USE_COLORMAP = true;
this.image_material.defines.USE_COLORMAP = image.use_colormap;
this.image_material.needsUpdate = true;
} else if(image.shape.length == 3) {
this.image_material.defines.USE_COLORMAP = false;
if(this.texture)
this.texture.dispose();
if(image.shape[2] == 3)
this.texture = new THREE.DataTexture(data, image.shape[1], image.shape[0], THREE.RGBFormat, type);
if(image.shape[2] == 4)
this.texture = new THREE.DataTexture(data, image.shape[1], image.shape[0], THREE.RGBAFormat, type);
this.texture.needsUpdate = true;
this.image_material.uniforms.image.value = this.texture;
this.image_material.uniforms.range_image.value = [image.min, image.max];
} else {
console.error('image data not understood');
// we are not dealing with an image, but with an array
// which is not normalized, so we can reset the range_image
this.image_material.uniforms.range_image.value = [0, 1];
var data = image.data;
if(data instanceof Uint8Array) {
type = THREE.UnsignedByteType;
} else if(data instanceof Float64Array) {
console.warn('ImageGLView.data is a Float64Array which WebGL does not support, will convert to a Float32Array (consider sending float32 data for better performance).');
data = Float32Array.from(data);
type = THREE.FloatType;
} else if(data instanceof Float32Array) {
type = THREE.FloatType;
} else {
console.error('only types uint8 and float32 are supported');
return;
}
if(this.scales.image.model.get('scheme') && image.shape.length == 2) {
if(this.texture)
this.texture.dispose();
this.texture = new THREE.DataTexture(data, image.shape[1], image.shape[0], THREE.LuminanceFormat, type);
this.texture.needsUpdate = true;
this.image_material.uniforms.image.value = this.texture;
this.image_material.defines.USE_COLORMAP = true;
this.image_material.needsUpdate = true;
} else if(image.shape.length == 3) {
this.image_material.defines.USE_COLORMAP = false;
if(this.texture)
this.texture.dispose();
if(image.shape[2] == 3)
this.texture = new THREE.DataTexture(data, image.shape[1], image.shape[0], THREE.RGBFormat, type);
if(image.shape[2] == 4)
this.texture = new THREE.DataTexture(data, image.shape[1], image.shape[0], THREE.RGBAFormat, type);
this.texture.needsUpdate = true;
this.image_material.uniforms.image.value = this.texture;
} else {
console.error('image data not understood');
}
}
this.texture.magFilter = interpolations[this.model.get('interpolation')];
this.texture.minFilter = interpolations[this.model.get('interpolation')];
Expand Down
11 changes: 9 additions & 2 deletions js/shaders/image-fragment.glsl
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@ uniform vec2 domain_y;
uniform vec2 image_domain_x;
uniform vec2 image_domain_y;

uniform vec2 range_image;


bool isnan(float val)
{
return (val < 0.0 || 0.0 < val || val == 0.0) ? false : true;
Expand All @@ -32,7 +35,10 @@ void main(void) {
float y_normalized = scale_transform_linear(y_domain_value, vec2(0., 1.), image_domain_y);
vec2 tex_uv = vec2(x_normalized, y_normalized);
#ifdef USE_COLORMAP
float raw_value = texture2D(image, tex_uv).r;
// r (or g or b) is used for the value, alpha for the mask (is 0 if a nan is found)
vec2 pixel_value = texture2D(image, tex_uv).ra;
float raw_value = pixel_value[0] * (range_image[1] - range_image[0]) + range_image[0];
float opacity_image = pixel_value[1];
float value = (raw_value - color_min) / (color_max - color_min);
vec4 color;
if(isnan(value)) // nan's are interpreted as missing values, and 'not shown'
Expand All @@ -41,8 +47,9 @@ void main(void) {
color = texture2D(colormap, vec2(value, 0.5));
#else
vec4 color = texture2D(image, tex_uv);
float opacity_image = 1.0;
#endif
// since we're working with pre multiplied colors (regarding blending)
// we also need to multiply rgb by opacity
gl_FragColor = color * opacity;
gl_FragColor = color * opacity * opacity_image;
}
3 changes: 2 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,8 @@
include_package_data=True,
install_requires=[
'ipywidgets>=7.0.0',
'bqplot>=0.12'
'bqplot>=0.12',
'pillow',
],
packages=find_packages(),
zip_safe=False,
Expand Down
Loading

0 comments on commit 57c7230

Please sign in to comment.