From f128575824a2d8812bae10254d0de0904ca036ad Mon Sep 17 00:00:00 2001 From: Joshua Bell Date: Tue, 14 Jan 2025 16:59:24 -0800 Subject: [PATCH] NNotepad: Infer argument type for dictionary members Extend the WebNNUtil.argumentType() helper to handle option dictionary members, to improve the user experience. Previously, number and array literals were serialized as JS numbers and arrays of numbers. This was great for `linear(1, {alpha: 2, beta: 3})` and `transpose(T, {permutation: [0,2,1]})`. But it meant MLOperand dict members required passing via variables or using identity(), e.g. `gemm(A, B, {c: identity(123)})` Now you can just write `gemm(A, B, {c: 123})`, but `linear()` and `transpose()` still work as before. --- nnotepad/README.md | 8 +-- nnotepad/js/nnotepad.js | 151 ++++++++++++++++++++++++++++++++++------ nnotepad/js/tests.js | 8 +++ 3 files changed, 141 insertions(+), 26 deletions(-) diff --git a/nnotepad/README.md b/nnotepad/README.md index cb8a8b44..dc768fb9 100644 --- a/nnotepad/README.md +++ b/nnotepad/README.md @@ -36,10 +36,10 @@ Functions and operators are turned into [`MLGraphBuilder`](https://webmachinelea Array literals (`[...]`) and number literals (`12.34`) are interpreted contextually: -* In assignments, they are intepreted as tensor/scalar constant [`MLOperand`](https://webmachinelearning.github.io/webnn/#mloperand)s, e.g. `alpha = 12.34` or `T = [1,2,3,4]`. -* In most function calls, they are interpreted as tensor/scalar constant [`MLOperand`](https://webmachinelearning.github.io/webnn/#mloperand)s, e.g. `neg(123)` or `neg([1,2,3])`. -* In some function calls, they are interpreted as arrays/numbers for some positional parameters, e.g. `concat([A,B,C],0)`. This includes: [`concat()`](https://webmachinelearning.github.io/webnn/#dom-mlgraphbuilder-concat), [`expand()`](https://webmachinelearning.github.io/webnn/#dom-mlgraphbuilder-expand), [`pad()`](https://webmachinelearning.github.io/webnn/#dom-mlgraphbuilder-pad), [`reshape()`](https://webmachinelearning.github.io/webnn/#dom-mlgraphbuilder-reshape), [`slice()`](https://webmachinelearning.github.io/webnn/#dom-mlgraphbuilder-slice), [`split()`](https://webmachinelearning.github.io/webnn/#dom-mlgraphbuilder-split). -* In dictionaries, they are interpreted as arrays/numbers, e.g. `linear(123, {alpha: 456, beta: 789})` or `transpose(T, {permutation: [0,2,1]})`. To pass a tensor/scalar constant in a dictionary, use a variable or wrap it in [`identity()`](https://webmachinelearning.github.io/webnn/#dom-mlgraphbuilder-identity) e.g. `gemm(A, B, {c:identity([4])})` or `gemm(A, B, {c:identity(4)})`. +* In assignments, they are intepreted as tensor/scalar constants [`MLOperand`](https://webmachinelearning.github.io/webnn/#mloperand)s, e.g. `alpha = 12.34` (scalar) or `T = [1,2,3,4]` (tensor). +* As arguments in function calls, they are interpreted depending on the argument definition, e.g. `neg(123)` (scalar), `neg([1,2,3])` (tensor), `concat([A,B,C],0)` (number). +* In options dictionaries inside function calls, they are interpreted depending on the dictionary definition. e.g. `linear(123, {alpha: 456, beta: 789})` (numbers), `transpose(T, {permutation: [0,2,1]})` (array of numbers), `gemm(A, B, {c: 123})` (scalar), `gemm(A, B, {c: [123]})` (tensor). +* In dictionaries outside of function calls, they are interpreted as arrays/numbers, e.g. `options = {alpha: 456, beta: 789})`. To pass a tensor/scalar constant in a dictionary, use a variable or wrap it in [`identity()`](https://webmachinelearning.github.io/webnn/#dom-mlgraphbuilder-identity) e.g. `options = {c:identity(4)} gemm(A, B, options)`. The default [data type](https://webmachinelearning.github.io/webnn/#enumdef-mloperanddatatype) for scalars and tensors is [`float32`](https://webmachinelearning.github.io/webnn/#dom-mloperanddatatype-float32). To specify a different data type, suffix with one of `i8`, `u8`, `i32`, `u32`, `i64`, `u64`, `f16`, `f32`, e.g. `123i8` or `[1,2,3]u32`. diff --git a/nnotepad/js/nnotepad.js b/nnotepad/js/nnotepad.js index 21906a63..ea603537 100644 --- a/nnotepad/js/nnotepad.js +++ b/nnotepad/js/nnotepad.js @@ -82,22 +82,114 @@ class WebNNUtil { throw new Error(`Unsupported dataType ${type}`); } - static argumentType(name, index) { - return ({ - concat: {0: kArgTypeOperandList, 1: kArgTypeNonOperand}, - expand: {1: kArgTypeNonOperand}, - gru: {3: kArgTypeNonOperand, 4: kArgTypeNonOperand}, - gruCell: {4: kArgTypeNonOperand}, - lstm: {3: kArgTypeNonOperand, 4: kArgTypeNonOperand}, - lstmCell: {5: kArgTypeNonOperand}, - pad: {1: kArgTypeNonOperand, 2: kArgTypeNonOperand}, - reshape: {1: kArgTypeNonOperand}, - slice: {1: kArgTypeNonOperand, 2: kArgTypeNonOperand}, - softmax: {1: kArgTypeNonOperand}, - split: {1: kArgTypeNonOperand}, - })[name] - ?.[index] || - kArgTypeOperand; + // Called to determine the type of an argument. `name` is the name of the + // `MLGraphBuilder` method. `index` is the argument index. If `key` is + // provided, this is serializing a member of an options dictionary. Returns + // one of the `kArgTypeXYZ` values. + static argumentType(name, index, key) { + const kDefaultDictMemberType = kArgTypeNonOperand; + const kDefaultArgType = kArgTypeOperand; + + // TODO: Auto-generate this from the WebIDL API definition. + const argType = ({ + batchNormalization: { + 3: { + scale: kArgTypeOperand, + bias: kArgTypeOperand, + }, + }, + concat: { + 0: kArgTypeOperandList, + 1: kArgTypeNonOperand}, + conv2d: { + 2: { + bias: kArgTypeOperand, + }, + }, + convTranspose2d: { + 2: { + bias: kArgTypeOperand, + }, + }, + expand: { + 1: kArgTypeNonOperand, + }, + gemm: { + 2: { + c: kArgTypeOperand, + }, + }, + gru: { + 3: kArgTypeNonOperand, + 4: kArgTypeNonOperand, + 5: { + bias: kArgTypeOperand, + recurrentBias: kArgTypeOperand, + initialHiddenState: kArgTypeOperand, + }, + }, + gruCell: { + 4: kArgTypeNonOperand, + 5: { + bias: kArgTypeOperand, + recurrentBias: kArgTypeOperand, + }, + }, + instanceNormalization: { + 1: { + scale: kArgTypeOperand, + bias: kArgTypeOperand, + }, + }, + layerNormalization: { + 1: { + scale: kArgTypeOperand, + bias: kArgTypeOperand, + }, + }, + lstm: { + 3: kArgTypeNonOperand, + 4: kArgTypeNonOperand, + 5: { + bias: kArgTypeOperand, + recurrentBias: kArgTypeOperand, + peepholeWeight: kArgTypeOperand, + initialHiddenState: kArgTypeOperand, + initialCellState: kArgTypeOperand, + }, + }, + lstmCell: { + 5: kArgTypeNonOperand, + 6: { + bias: kArgTypeOperand, + recurrentBias: kArgTypeOperand, + peepholeWeight: kArgTypeOperand, + }, + }, + pad: { + 1: kArgTypeNonOperand, + 2: kArgTypeNonOperand, + }, + reshape: { + 1: kArgTypeNonOperand, + }, + slice: { + 1: kArgTypeNonOperand, + 2: kArgTypeNonOperand, + }, + softmax: { + 1: kArgTypeNonOperand, + }, + split: { + 1: kArgTypeNonOperand, + }, + })[name]?.[index]; + + if (key) { + return argType?.[key] ?? kDefaultDictMemberType; + } + + return argType ?? kDefaultArgType; } } @@ -401,7 +493,18 @@ export class NNotepad { } throw new Error(`unexpected line type: ${line.type}`); } - function serializeExpr(expr, argumentType = kArgTypeOperand) { + + // Serialize an expression. If `callContext` is provided, it can either be + // an object with `name` and `index` properties which identify a method call + // and argument position, used to determine the argument type, or an + // `kArgTypeXYZ` value to explicitly specify the type. This is needed for + // numbers, arrays, and dictionary members, which are serialized + // contextually. + function serializeExpr(expr, callContext) { + const argumentType = typeof callContext === 'object' ? + WebNNUtil.argumentType(callContext.name, callContext.index) : + typeof callContext === 'number' ? callContext : + kArgTypeOperand; if (expr.op) { if (expr.lhs) { return `_.${kBinaryOperators[expr.op]}(${serializeExpr(expr.lhs)}, ${ @@ -432,7 +535,7 @@ export class NNotepad { return serializeTensor(expr.value, expr.dataType); } case 'dict': - return serializeDict(expr.dict); + return serializeDict(expr.dict, callContext); case 'identifier': return expr.value; case 'call': @@ -440,13 +543,17 @@ export class NNotepad { } throw new Error(`unexpected expr type: ${expr.type}`); } - function serializeDict(dict) { + function serializeDict(dict, callContext) { return '{' + Object.keys(dict) .map((k) => { const v = dict[k]; - k = Util.stringify(k); - return `${k}: ${serializeExpr(v, kArgTypeNonOperand)}`; + const argumentType = typeof callContext === 'object' ? + WebNNUtil.argumentType( + callContext.name, callContext.index, k) : + kArgTypeNonOperand; + return `${Util.stringify(k)}: ${ + serializeExpr(v, argumentType)}`; }) .join(', ') + '}'; @@ -545,7 +652,7 @@ export class NNotepad { return `_.${name}(${ args.map( (arg, index) => - serializeExpr(arg, WebNNUtil.argumentType(name, index))) + serializeExpr(arg, {name, index})) .join(', ')})`; } } diff --git a/nnotepad/js/tests.js b/nnotepad/js/tests.js index c789dd20..33767c42 100644 --- a/nnotepad/js/tests.js +++ b/nnotepad/js/tests.js @@ -178,6 +178,14 @@ document.addEventListener('DOMContentLoaded', async (e) => { `softmax([1], 0)`, {dataType: 'float32', shape: [1], buffer: [1]}); + Harness.section('Optional operand arguments'); + await test( + 'A = [[1,2], [3,4]] B = [[5,6], [7,8]] gemm(A, B, {c: 123})', + {dataType: 'float32', shape: [2, 2], buffer: [142, 145, 166, 173]}); + await test( + 'instanceNormalization([[[[1]]]], {scale: [123], bias: [456]})', + {dataType: 'float32', shape: [1, 1, 1, 1], buffer: [456]}); + Harness.section('Regression tests'); await test( `concat([[1,2],[3,4]], 0)`,