Skip to content

Commit 76a2428

Browse files
committed
Deploying to gh-pages from @ ed1ece2 🚀
1 parent 43ac827 commit 76a2428

File tree

2 files changed

+54
-24
lines changed

2 files changed

+54
-24
lines changed

nnotepad/js/nnotepad.js

+40-23
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,10 @@ export class ComputeError extends Error {
3131
// General WebNN Utilities
3232
// ============================================================
3333

34+
const kArgTypeOperandList = 1;
35+
const kArgTypeNonOperand = 2;
36+
const kArgTypeOperand = 3;
37+
3438
class WebNNUtil {
3539
static bufferForOperand(operand) {
3640
const size = [...operand.shape()].reduce((a, b) => a * b, 1);
@@ -60,21 +64,22 @@ class WebNNUtil {
6064
throw new Error(`Unsupported dataType ${type}`);
6165
}
6266

63-
static isNonOperandArg(name, index) {
67+
static argumentType(name, index) {
6468
return ({
65-
concat: [0, 1],
66-
expand: [1],
67-
gru: [3, 4],
68-
gruCell: [4],
69-
lstm: [3, 4],
70-
lstmCell: [5],
71-
pad: [1, 2],
72-
reshape: [1],
73-
slice: [1, 2],
74-
softmax: [1], // TODO: Distinguish overloads
75-
split: [1],
69+
concat: {0: kArgTypeOperandList, 1: kArgTypeNonOperand},
70+
expand: {1: kArgTypeNonOperand},
71+
gru: {3: kArgTypeNonOperand, 4: kArgTypeNonOperand},
72+
gruCell: {4: kArgTypeNonOperand},
73+
lstm: {3: kArgTypeNonOperand, 4: kArgTypeNonOperand},
74+
lstmCell: {5: kArgTypeNonOperand},
75+
pad: {1: kArgTypeNonOperand, 2: kArgTypeNonOperand},
76+
reshape: {1: kArgTypeNonOperand},
77+
slice: {1: kArgTypeNonOperand, 2: kArgTypeNonOperand},
78+
softmax: {1: kArgTypeNonOperand},
79+
split: {1: kArgTypeNonOperand},
7680
})[name]
77-
?.includes(index);
81+
?.[index] ||
82+
kArgTypeOperand;
7883
}
7984
}
8085

@@ -379,7 +384,7 @@ export class NNotepad {
379384
}
380385
throw new Error(`unexpected line type: ${line.type}`);
381386
}
382-
function serializeExpr(expr, nonOperand = false) {
387+
function serializeExpr(expr, argumentType = kArgTypeOperand) {
383388
if (expr.op) {
384389
if (expr.lhs) {
385390
return `_.${kBinaryOperators[expr.op]}(${serializeExpr(expr.lhs)}, ${
@@ -394,11 +399,21 @@ export class NNotepad {
394399
case 'boolean':
395400
return String(expr.value);
396401
case 'number':
397-
return nonOperand ? Util.stringify(expr.value) :
398-
serializeScalar(expr.value, expr.dataType);
402+
switch (argumentType) {
403+
case kArgTypeNonOperand:
404+
return Util.stringify(expr.value);
405+
default:
406+
return serializeScalar(expr.value, expr.dataType);
407+
}
399408
case 'array':
400-
return nonOperand ? serializeArray(expr.value) :
401-
serializeTensor(expr.value, expr.dataType);
409+
switch (argumentType) {
410+
case kArgTypeNonOperand:
411+
return serializeArray(expr.value, kArgTypeNonOperand);
412+
case kArgTypeOperandList:
413+
return serializeArray(expr.value, kArgTypeOperand);
414+
default:
415+
return serializeTensor(expr.value, expr.dataType);
416+
}
402417
case 'dict':
403418
return serializeDict(expr.dict);
404419
case 'identifier':
@@ -414,7 +429,7 @@ export class NNotepad {
414429
.map((k) => {
415430
const v = dict[k];
416431
k = Util.stringify(k);
417-
return `${k}: ${serializeExpr(v, true)}`;
432+
return `${k}: ${serializeExpr(v, kArgTypeNonOperand)}`;
418433
})
419434
.join(', ') +
420435
'}';
@@ -465,8 +480,10 @@ export class NNotepad {
465480
elements.map((n) => Util.stringifyNumber(n, dataType)).join(',')}]))`;
466481
}
467482

468-
function serializeArray(array) {
469-
return '[' + array.map((expr) => serializeExpr(expr)).join(', ') + ']';
483+
function serializeArray(array, argumentType) {
484+
return '[' +
485+
array.map((expr) => serializeExpr(expr, argumentType)).join(', ') +
486+
']';
470487
}
471488

472489
function serializeCall(name, args) {
@@ -506,8 +523,8 @@ export class NNotepad {
506523

507524
return `_.${name}(${
508525
args.map(
509-
(arg, index) => serializeExpr(
510-
arg, WebNNUtil.isNonOperandArg(name, index)))
526+
(arg, index) =>
527+
serializeExpr(arg, WebNNUtil.argumentType(name, index)))
511528
.join(', ')})`;
512529
}
513530
}

nnotepad/js/tests.js

+14-1
Original file line numberDiff line numberDiff line change
@@ -157,14 +157,27 @@ document.addEventListener('DOMContentLoaded', async (e) => {
157157
{dataType: 'float32', shape: [2], buffer: [3, 4]},
158158
]);
159159

160-
Harness.section('Multiple input tensors');
160+
Harness.section('Non-operand arguments: array of operands');
161161
await test(
162162
`A = [1,2] B = [3,4] concat([A,B], 0)`,
163163
{dataType: 'float32', shape: [4], buffer: [1, 2, 3, 4]});
164164
await test(
165165
`concat([identity([1,2]),identity([3,4])], 0)`,
166166
{dataType: 'float32', shape: [4], buffer: [1, 2, 3, 4]});
167167

168+
Harness.section('Non-operand arguments: array of numbers');
169+
await test(
170+
`T = [[1,2,3],[4,5,6]] reshape(T, [1, 3, 2, 1])`,
171+
{dataType: 'float32', shape: [1, 3, 2, 1], buffer: [1, 2, 3, 4, 5, 6]});
172+
await test(
173+
`expand([1], [2, 2])`,
174+
{dataType: 'float32', shape: [2, 2], buffer: [1, 1, 1, 1]});
175+
176+
Harness.section('Non-operand arguments: simple numbers');
177+
await test(
178+
`softmax([1], 0)`,
179+
{dataType: 'float32', shape: [1], buffer: [1]});
180+
168181
Harness.section('Regression tests');
169182
await test(
170183
`concat([[1,2],[3,4]], 0)`,

0 commit comments

Comments
 (0)