Skip to content

Commit 919aa9f

Browse files
NNotepad: fix load(), add zeros(), improve tensor output display (#247)
- The load() helper was broken when the scripts were converted to modules, and also when array parsing was updated. Oops! Fix it. - Add zeros() helper that creates a zero-initized tensor of the given shape. Useful for testing shape display. - Ensure tensor output spaces n-dimensions tensors the same as numpy - more spaces between each layer. - Ensure tensor output handles 0-sized dimensions.
1 parent 1e0ad50 commit 919aa9f

File tree

4 files changed

+25
-9
lines changed

4 files changed

+25
-9
lines changed

nnotepad/README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ The default [data type](https://webmachinelearning.github.io/webnn/#enumdef-mlop
4949
In addition to WebNN [`MLGraphBuilder`](https://webmachinelearning.github.io/webnn/#mlgraphbuilder) methods, you can use these helpers:
5050

5151
* **load(_url_, _shape_, _dataType_)** - fetch a tensor resource. Must be served with appropriate [CORS](https://developer.mozilla.org/en-US/docs/Web/HTTP/CORS) headers. Example: `load('https://www.random.org/cgi-bin/randbyte?nbytes=256', [16, 16], 'uint8')`
52+
* **zeros(_shape_, _dataType_)** - constant zero-filled tensor of the given shape. Example: `zeros([2,2,2,2], 'int8')`
5253

5354

5455
# Details & Gotchas

nnotepad/js/index.js

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@ function explain(outputs) {
102102

103103
const width = [...buffer]
104104
.map((n) => String(n).length)
105-
.reduce((a, b) => Math.max(a, b));
105+
.reduce((a, b) => Math.max(a, b), 0);
106106

107107
const out = [];
108108
let bufferIndex = 0;
@@ -118,10 +118,7 @@ function explain(outputs) {
118118
if (i !== shape[dim] - 1) {
119119
out.push(', ');
120120
if (dim + 1 !== shape.length) {
121-
if (dim + 2 !== shape.length) {
122-
out.push('\n');
123-
}
124-
out.push('\n');
121+
out.push('\n'.repeat(shape.length - dim - 1));
125122
out.push(' '.repeat(indent + dim + 1));
126123
}
127124
}

nnotepad/js/nnotepad.js

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -368,7 +368,7 @@ export class NNotepad {
368368
.join('');
369369

370370
const AsyncFunction = async function() {}.constructor;
371-
return [new AsyncFunction(['_'], src), src];
371+
return [new AsyncFunction(['_', 'Util'], src), src];
372372

373373
function serializeLine(line, last) {
374374
const expr = serializeExpr(line.expr);
@@ -476,18 +476,35 @@ export class NNotepad {
476476
if (url.type !== 'string') {
477477
throw new TypeError('load(): expected string');
478478
}
479-
if (shape.type !== 'tensor') {
479+
if (shape.type !== 'array') {
480480
throw new TypeError('load(): expected array');
481481
}
482482
if (dataType.type !== 'string') {
483483
throw new TypeError('load(): expected string');
484484
}
485+
const dims = shape.value.map((expr) => expr.value);
485486
const ctor = WebNNUtil.dataTypeToBufferType(dataType.value);
486487
return `_.constant({dataType: "${dataType.value}", dimensions: ${
487-
Util.stringify(shape.value)}}, new ${
488+
Util.stringify(dims)}}, new ${
488489
ctor.name}(await Util.loadBuffer(${Util.stringify(url.value)})))`;
489490
}
490491

492+
if (name === 'zeros') {
493+
const [shape, dataType] = args;
494+
if (shape.type !== 'array') {
495+
throw new TypeError('zeros(): expected array');
496+
}
497+
if (dataType.type !== 'string') {
498+
throw new TypeError('zeros(): expected string');
499+
}
500+
const dims = shape.value.map((expr) => expr.value);
501+
const ctor = WebNNUtil.dataTypeToBufferType(dataType.value);
502+
const len = dims.reduce((a, b) => a * b, 1);
503+
return `_.constant({dataType: "${dataType.value}", dimensions: ${
504+
Util.stringify(dims)}}, new ${
505+
ctor.name}(${len}))`;
506+
}
507+
491508
return `_.${name}(${
492509
args.map(
493510
(arg, index) => serializeExpr(
@@ -509,7 +526,7 @@ export class NNotepad {
509526
const builder = new self.MLGraphBuilder(context);
510527

511528
const outputOperands = [];
512-
let output = await builderFunc(builder);
529+
let output = await builderFunc(builder, Util);
513530
if (output instanceof self.MLOperand) {
514531
// TODO: remove try/catch once all back-ends support `identity()`.
515532
try {

nnotepad/res/docs.html

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,7 @@ <h1>Helpers</h1>
7777
<p>In addition to WebNN <a href="https://webmachinelearning.github.io/webnn/#mlgraphbuilder"><code>MLGraphBuilder</code></a> methods, you can use these helpers:</p>
7878
<ul>
7979
<li><strong>load(<em>url</em>, <em>shape</em>, <em>dataType</em>)</strong> - fetch a tensor resource. Must be served with appropriate <a href="https://developer.mozilla.org/en-US/docs/Web/HTTP/CORS">CORS</a> headers. Example: <code>load('https://www.random.org/cgi-bin/randbyte?nbytes=256', [16, 16], 'uint8')</code></li>
80+
<li><strong>zeros(<em>shape</em>, <em>dataType</em>)</strong> - constant zero-filled tensor of the given shape. Example: <code>zeros([2,2,2,2], 'int8')</code></li>
8081
</ul>
8182
<h1>Details &amp; Gotchas</h1>
8283
<ul>

0 commit comments

Comments
 (0)