Commit 57c7230

feat: can compress image data before sending it over the websocket

This is useful when the widget is used over a non-local network. For smooth, easy-to-compress images it can reduce network traffic by a factor of about 80. A pure-noise image (random pixels) will not compress well, but still sees a factor-of-7 reduction in size because uint8 is sent instead of float64.
1 parent ba16715 commit 57c7230

13 files changed: +299 −47 lines
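
The compression default is controlled by a new environment variable that is read when bqplot_image_gl is imported. A minimal usage sketch (the ImageGL/Figure construction below follows the usual bqplot pattern and is an assumption, not part of this diff):

    # sketch: enable compression for all ImageGL widgets via the environment variable
    # (must be set before bqplot_image_gl is imported; valid values: "png", "webp", "none")
    import os
    os.environ["BQPLOT_IMAGE_GL_IMAGE_DATA_COMPRESSION"] = "png"

    import numpy as np
    from bqplot import Figure, LinearScale, ColorScale
    from bqplot_image_gl import ImageGL

    scales = {"x": LinearScale(), "y": LinearScale(), "image": ColorScale()}
    image_mark = ImageGL(image=np.random.rand(512, 512).astype(np.float32), scales=scales)
    figure = Figure(marks=[image_mark], scale_x=scales["x"], scale_y=scales["y"])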

README.md
Lines changed: 11 additions & 0 deletions

@@ -5,6 +5,17 @@ Used for https://github.com/glue-viz/glue-jupyter
 
 (currently requires latest developer version of bqplot)
 
+## Usage
+
+### ImageGL
+
+See https://py.cafe/maartenbreddels/bqplot-image-gl-demo for a demo of the ImageGL widget.
+
+Preview image:
+![preview image](https://py.cafe/preview/maartenbreddels/bqplot-image-gl-demo)
+
+
+
 # Installation
 
 To install use pip:

bqplot_image_gl/imagegl.py
Lines changed: 8 additions & 1 deletion

@@ -1,3 +1,4 @@
+import os
 import ipywidgets as widgets
 import bqplot
 from traittypes import Array
@@ -6,10 +7,15 @@
 from bqplot.marks import shape
 from bqplot.traits import array_to_json, array_from_json
 from bqplot_image_gl._version import __version__
+from .serialize import image_data_serialization
 
 __all__ = ['ImageGL', 'Contour']
 
 
+# can be 'png', 'webp' or 'none'
+DEFAULT_IMAGE_DATA_COMPRESSION = os.environ.get("BQPLOT_IMAGE_GL_IMAGE_DATA_COMPRESSION", "none")
+
+
 @widgets.register
 class ImageGL(bqplot.Mark):
     """An example widget."""
@@ -24,7 +30,8 @@ class ImageGL(bqplot.Mark):
                  scaled=True,
                  rtype='Color',
                  atype='bqplot.ColorAxis',
-                 **array_serialization)
+                 **image_data_serialization)
+    compression = Unicode(DEFAULT_IMAGE_DATA_COMPRESSION, allow_none=True).tag(sync=True)
     interpolation = Unicode('nearest', allow_none=True).tag(sync=True)
     opacity = Float(1.0).tag(sync=True)
     x = Array(default_value=(0, 1)).tag(sync=True, scaled=True,
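
The new compression trait can also be changed per widget, overriding the environment-variable default; a short sketch, assuming image_mark is an existing ImageGL instance (hypothetical name):

    import numpy as np

    # sketch: per-widget override of the compression default
    image_mark.compression = "webp"                  # or "png"; "none" sends the raw array
    image_mark.image = np.random.rand(1024, 1024)    # this update is sent as a compressed WebP image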

bqplot_image_gl/serialize.py
Lines changed: 98 additions & 0 deletions

@@ -0,0 +1,98 @@
+from PIL import Image
+import numpy as np
+import io
+
+from bqplot.traits import array_serialization
+
+
+def array_to_image_or_array(array, widget):
+    if widget.compression in ["png", "webp"]:
+        return array_to_image(array, widget.compression)
+    else:
+        return array_serialization["to_json"](array, widget)
+
+
+def not_implemented(image):
+    # the widget never sends the image data back to the kernel
+    raise NotImplementedError("deserializing is not implemented yet")
+
+
+def array_to_image(array, image_format):
+    # convert the array to a png image with intensity values only
+    array = np.array(array, copy=False)
+    min, max = None, None
+    use_colormap = False
+    if array.ndim == 2:
+        use_colormap = True
+        min = np.nanmin(array)
+        max = np.nanmax(array)
+
+        array = (array - min) / (max - min)
+        # only convert to uint8 if the array is float
+        if array.dtype.kind == "f":
+            array_bytes = (array * 255).astype(np.uint8)
+        else:
+            array_bytes = array
+        intensity_image = Image.fromarray(array_bytes, mode="L")
+
+        # create a mask image with 0 for NaN values and 255 for valid values
+        isnan = ~np.isnan(array)
+        mask = (isnan * 255).astype(np.uint8)
+        mask_image = Image.fromarray(mask, mode="L")
+
+        # merge the intensity and mask image into a single image
+        image = Image.merge("LA", (intensity_image, mask_image))
+    else:
+        # if floats, convert to uint8
+        if array.dtype.kind == "f":
+            array_bytes = (array * 255).astype(np.uint8)
+        elif array.dtype == np.uint8:
+            array_bytes = array
+        else:
+            raise ValueError(
+                "Only float arrays or uint8 arrays are supported, your array has dtype "
+                f"{array.dtype}"
+            )
+        if array.shape[2] == 3:
+            image = Image.fromarray(array_bytes, mode="RGB")
+        elif array.shape[2] == 4:
+            image = Image.fromarray(array_bytes, mode="RGBA")
+        else:
+            raise ValueError(
+                "Only 2D arrays or 3D arrays with 3 or 4 channels are supported, "
+                f"your array has shape {array.shape}"
+            )
+
+    # and serialize it to a PNG
+    png_data = io.BytesIO()
+    image.save(png_data, format=image_format, lossless=True)
+    png_bytes = png_data.getvalue()
+    original_byte_length = array.nbytes
+    uint8_byte_length = array_bytes.nbytes
+    compressed_byte_length = len(png_bytes)
+    return {
+        "type": "image",
+        "format": image_format,
+        "use_colormap": use_colormap,
+        "min": min,
+        "max": max,
+        "data": png_bytes,
+        # this metadata is only useful/needed for debugging
+        "shape": array.shape,
+        "info": {
+            "original_byte_length": original_byte_length,
+            "uint8_byte_length": uint8_byte_length,
+            "compressed_byte_length": compressed_byte_length,
+            "compression_ratio": original_byte_length / compressed_byte_length,
+            "MB": {
+                "original": original_byte_length / 1024 / 1024,
+                "uint8": uint8_byte_length / 1024 / 1024,
+                "compressed": compressed_byte_length / 1024 / 1024,
+            },
+        },
+    }
+
+
+image_data_serialization = dict(
+    to_json=array_to_image_or_array, from_json=not_implemented
+)
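
The encoder above can be exercised directly to inspect the payload and the debugging metadata it returns; a sketch using a smooth, easy-to-compress 2D array as in the commit message:

    import numpy as np
    from bqplot_image_gl.serialize import array_to_image

    # smooth, easy-to-compress test image
    y, x = np.mgrid[0:512, 0:512]
    data = np.sin(x / 50.0) * np.cos(y / 50.0)

    payload = array_to_image(data, "png")
    print(payload["use_colormap"], payload["min"], payload["max"])   # True, ~-1.0, ~1.0
    print(payload["info"]["compression_ratio"])                      # float64 bytes / PNG bytes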

js/lib/contour.js
Lines changed: 25 additions & 4 deletions

@@ -34,19 +34,40 @@ class ContourModel extends bqplot.MarkModel {
         this.update_data();
     }
 
-    update_data() {
+    async update_data() {
         const image_widget = this.get('image');
         const level = this.get('level')
         // we support a single level or multiple
         this.thresholds = Array.isArray(level) ? level : [level];
         if(image_widget) {
             const image = image_widget.get('image')
-            this.width = image.shape[1];
-            this.height = image.shape[0];
+            let data = null;
+            if(image.image) {
+                const imageNode = image.image;
+                this.width = imageNode.width;
+                this.height = imageNode.height;
+                // convert the image to a typed array using canvas
+                const canvas = document.createElement('canvas');
+                canvas.width = this.width
+                canvas.height = this.height
+                const ctx = canvas.getContext('2d');
+                ctx.drawImage(imageNode, 0, 0);
+                const imageData = ctx.getImageData(0, 0, imageNode.width, imageNode.height);
+                const {min, max} = image;
+                // use the r channel as the data, and scale to the range
+                data = new Float32Array(imageData.data.length / 4);
+                for(var i = 0; i < data.length; i++) {
+                    data[i] = (imageData.data[i*4] / 255) * (max - min) + min;
+                }
+            } else {
+                this.width = image.shape[1];
+                this.height = image.shape[0];
+                data = image.data;
+            }
             this.contours = this.thresholds.map((threshold) => d3contour
                 .contours()
                 .size([this.width, this.height])
-                .contour(image.data, [threshold])
+                .contour(data, [threshold])
             )
         } else {
             this.width = 1; // precomputed contour_lines will have to be in normalized

js/lib/imagegl.js
Lines changed: 82 additions & 31 deletions

@@ -38,9 +38,24 @@ class ImageGLModel extends bqplot.MarkModel {
         super.initialize(attributes, options);
         this.on_some_change(['x', 'y'], this.update_data, this);
         this.on_some_change(["preserve_domain"], this.update_domains, this);
+        this.listenTo(this, "change:image", () => {
+            const previous = this.previous("image");
+            if(previous.image && previous.image.src) {
+                URL.revokeObjectURL(previous.image.src);
+            }
+        }, this);
+
         this.update_data();
     }
 
+    close(comm_closed) {
+        const image = this.get("image");
+        if(image.image && image.image.src) {
+            URL.revokeObjectURL(image.image.src);
+        }
+        return super.close(comm_closed);
+    }
+
     update_data() {
         this.mark_data = {
             x: this.get("x"), y: this.get("y")
@@ -79,9 +94,24 @@ ImageGLModel.serializers = Object.assign({}, bqplot.MarkModel.serializers,
     { x: serialize.array_or_json,
       y: serialize.array_or_json,
       image: {
-            deserialize: (obj, manager) => {
-                let state = {buffer: obj.value, dtype: obj.dtype, shape: obj.shape};
-                return jupyter_dataserializers.JSONToArray(state);
+            deserialize: async (obj, manager) => {
+                if(obj.type == "image") {
+                    // the data is encoded in an image with LA format
+                    // luminance for the intensity, alpha for the mask
+                    let image = new Image();
+                    const blob = new Blob([obj.data], {type: `image/${obj.format}`});
+                    const url = URL.createObjectURL(blob);
+                    image.src = url;
+                    await new Promise((resolve, reject) => {
+                        image.onload = resolve;
+                        image.onerror = reject;
+                    });
+                    return {image, min: obj.min, max: obj.max, use_colormap: obj.use_colormap};
+                } else {
+                    // otherwise just a 'normal' ndarray
+                    let state = {buffer: obj.value, dtype: obj.dtype, shape: obj.shape};
+                    return jupyter_dataserializers.JSONToArray(state);
+                }
             },
             serialize: (ar) => {
                 const {buffer, dtype, shape} = jupyter_dataserializers.arrayToJSON(ar);
@@ -114,6 +144,10 @@ class ImageGLView extends bqplot.Mark {
             // basically the corners of the image
             image_domain_x : { type: "2f", value: [0.0, 1.0] },
             image_domain_y : { type: "2f", value: [0.0, 1.0] },
+            // in the case we use an image for the values, the image is normalized, and we need to scale
+            // it back to a particular image range
+            // This needs to be set to [0, 1] for array data (which is not normalized)
+            range_image : { type: "2f", value: [0.0, 1.0] },
             // extra opacity value
             opacity: {type: 'f', value: 1.0}
         },
@@ -280,39 +314,56 @@ class ImageGLView extends bqplot.Mark {
     update_image(skip_render) {
         var image = this.model.get("image");
        var type = null;
-        var data = image.data;
-        if(data instanceof Uint8Array) {
-            type = THREE.UnsignedByteType;
-        } else if(data instanceof Float64Array) {
-            console.warn('ImageGLView.data is a Float64Array which WebGL does not support, will convert to a Float32Array (consider sending float32 data for better performance).');
-            data = Float32Array.from(data);
-            type = THREE.FloatType;
-        } else if(data instanceof Float32Array) {
-            type = THREE.FloatType;
-        } else {
-            console.error('only types uint8 and float32 are supported');
-            return;
-        }
-        if(this.scales.image.model.get('scheme') && image.shape.length == 2) {
-            if(this.texture)
+        if(image.image) {
+            // the data is encoded in an image with LA format
+            if(this.texture) {
                 this.texture.dispose();
-            this.texture = new THREE.DataTexture(data, image.shape[1], image.shape[0], THREE.LuminanceFormat, type);
+            }
+            this.texture = new THREE.Texture(image.image);
             this.texture.needsUpdate = true;
+            this.texture.flipY = false;
             this.image_material.uniforms.image.value = this.texture;
-            this.image_material.defines.USE_COLORMAP = true;
+            this.image_material.defines.USE_COLORMAP = image.use_colormap;
             this.image_material.needsUpdate = true;
-        } else if(image.shape.length == 3) {
-            this.image_material.defines.USE_COLORMAP = false;
-            if(this.texture)
-                this.texture.dispose();
-            if(image.shape[2] == 3)
-                this.texture = new THREE.DataTexture(data, image.shape[1], image.shape[0], THREE.RGBFormat, type);
-            if(image.shape[2] == 4)
-                this.texture = new THREE.DataTexture(data, image.shape[1], image.shape[0], THREE.RGBAFormat, type);
-            this.texture.needsUpdate = true;
-            this.image_material.uniforms.image.value = this.texture;
+            this.image_material.uniforms.range_image.value = [image.min, image.max];
         } else {
-            console.error('image data not understood');
+            // we are not dealing with an image, but with an array
+            // which is not normalized, so we can reset the range_image
+            this.image_material.uniforms.range_image.value = [0, 1];
+            var data = image.data;
+            if(data instanceof Uint8Array) {
+                type = THREE.UnsignedByteType;
+            } else if(data instanceof Float64Array) {
+                console.warn('ImageGLView.data is a Float64Array which WebGL does not support, will convert to a Float32Array (consider sending float32 data for better performance).');
+                data = Float32Array.from(data);
+                type = THREE.FloatType;
+            } else if(data instanceof Float32Array) {
+                type = THREE.FloatType;
+            } else {
+                console.error('only types uint8 and float32 are supported');
+                return;
+            }
+            if(this.scales.image.model.get('scheme') && image.shape.length == 2) {
+                if(this.texture)
+                    this.texture.dispose();
+                this.texture = new THREE.DataTexture(data, image.shape[1], image.shape[0], THREE.LuminanceFormat, type);
+                this.texture.needsUpdate = true;
+                this.image_material.uniforms.image.value = this.texture;
+                this.image_material.defines.USE_COLORMAP = true;
+                this.image_material.needsUpdate = true;
+            } else if(image.shape.length == 3) {
+                this.image_material.defines.USE_COLORMAP = false;
+                if(this.texture)
+                    this.texture.dispose();
+                if(image.shape[2] == 3)
+                    this.texture = new THREE.DataTexture(data, image.shape[1], image.shape[0], THREE.RGBFormat, type);
+                if(image.shape[2] == 4)
+                    this.texture = new THREE.DataTexture(data, image.shape[1], image.shape[0], THREE.RGBAFormat, type);
+                this.texture.needsUpdate = true;
+                this.image_material.uniforms.image.value = this.texture;
+            } else {
+                console.error('image data not understood');
+            }
         }
         this.texture.magFilter = interpolations[this.model.get('interpolation')];
        this.texture.minFilter = interpolations[this.model.get('interpolation')];
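
For reference, the decode path the browser takes above (blob URL, <img> element, canvas readback, then un-normalizing with range_image in the shader) can be mimicked on the Python side with PIL; a sketch, not part of the commit:

    import io
    import numpy as np
    from PIL import Image
    from bqplot_image_gl.serialize import array_to_image

    data = np.linspace(0.0, 10.0, 256 * 256).reshape(256, 256)
    data[10, 10] = np.nan
    payload = array_to_image(data, "png")

    # the PNG holds two channels: L = normalized values, A = NaN mask (0 where NaN)
    luminance, alpha = [np.asarray(band) for band in Image.open(io.BytesIO(payload["data"])).split()]
    values = luminance / 255.0 * (payload["max"] - payload["min"]) + payload["min"]
    valid = alpha == 255
    print(valid[10, 10])                              # False: the NaN pixel is masked
    print(np.abs(values[valid] - data[valid]).max())  # quantization error, at most (max - min) / 255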

js/shaders/image-fragment.glsl
Lines changed: 9 additions & 2 deletions

@@ -16,6 +16,9 @@ uniform vec2 domain_y;
 uniform vec2 image_domain_x;
 uniform vec2 image_domain_y;
 
+uniform vec2 range_image;
+
+
 bool isnan(float val)
 {
     return (val < 0.0 || 0.0 < val || val == 0.0) ? false : true;
@@ -32,7 +35,10 @@ void main(void) {
     float y_normalized = scale_transform_linear(y_domain_value, vec2(0., 1.), image_domain_y);
     vec2 tex_uv = vec2(x_normalized, y_normalized);
 #ifdef USE_COLORMAP
-    float raw_value = texture2D(image, tex_uv).r;
+    // r (or g or b) is used for the value, alpha for the mask (is 0 if a nan is found)
+    vec2 pixel_value = texture2D(image, tex_uv).ra;
+    float raw_value = pixel_value[0] * (range_image[1] - range_image[0]) + range_image[0];
+    float opacity_image = pixel_value[1];
     float value = (raw_value - color_min) / (color_max - color_min);
     vec4 color;
     if(isnan(value)) // nan's are interpreted as missing values, and 'not shown'
@@ -41,8 +47,9 @@ void main(void) {
         color = texture2D(colormap, vec2(value, 0.5));
 #else
     vec4 color = texture2D(image, tex_uv);
+    float opacity_image = 1.0;
 #endif
     // since we're working with pre multiplied colors (regarding blending)
     // we also need to multiply rgb by opacity
-    gl_FragColor = color * opacity;
+    gl_FragColor = color * opacity * opacity_image;
 }

setup.py
Lines changed: 2 additions & 1 deletion

@@ -51,7 +51,8 @@
     include_package_data=True,
     install_requires=[
        'ipywidgets>=7.0.0',
-       'bqplot>=0.12'
+       'bqplot>=0.12',
+       'pillow',
    ],
    packages=find_packages(),
    zip_safe=False,
