Skip to content

Commit

Permalink
Merge pull request #302 from inexorabletash/nnotepad-dicttypes
Browse files Browse the repository at this point in the history
NNotepad: Infer argument type for dictionary members
  • Loading branch information
inexorabletash authored Jan 16, 2025
2 parents 5db65aa + f128575 commit fae1916
Show file tree
Hide file tree
Showing 3 changed files with 141 additions and 26 deletions.
8 changes: 4 additions & 4 deletions nnotepad/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,10 @@ Functions and operators are turned into [`MLGraphBuilder`](https://webmachinelea

Array literals (`[...]`) and number literals (`12.34`) are interpreted contextually:

* In assignments, they are intepreted as tensor/scalar constant [`MLOperand`](https://webmachinelearning.github.io/webnn/#mloperand)s, e.g. `alpha = 12.34` or `T = [1,2,3,4]`.
* In most function calls, they are interpreted as tensor/scalar constant [`MLOperand`](https://webmachinelearning.github.io/webnn/#mloperand)s, e.g. `neg(123)` or `neg([1,2,3])`.
* In some function calls, they are interpreted as arrays/numbers for some positional parameters, e.g. `concat([A,B,C],0)`. This includes: [`concat()`](https://webmachinelearning.github.io/webnn/#dom-mlgraphbuilder-concat), [`expand()`](https://webmachinelearning.github.io/webnn/#dom-mlgraphbuilder-expand), [`pad()`](https://webmachinelearning.github.io/webnn/#dom-mlgraphbuilder-pad), [`reshape()`](https://webmachinelearning.github.io/webnn/#dom-mlgraphbuilder-reshape), [`slice()`](https://webmachinelearning.github.io/webnn/#dom-mlgraphbuilder-slice), [`split()`](https://webmachinelearning.github.io/webnn/#dom-mlgraphbuilder-split).
* In dictionaries, they are interpreted as arrays/numbers, e.g. `linear(123, {alpha: 456, beta: 789})` or `transpose(T, {permutation: [0,2,1]})`. To pass a tensor/scalar constant in a dictionary, use a variable or wrap it in [`identity()`](https://webmachinelearning.github.io/webnn/#dom-mlgraphbuilder-identity) e.g. `gemm(A, B, {c:identity([4])})` or `gemm(A, B, {c:identity(4)})`.
* In assignments, they are intepreted as tensor/scalar constants [`MLOperand`](https://webmachinelearning.github.io/webnn/#mloperand)s, e.g. `alpha = 12.34` (scalar) or `T = [1,2,3,4]` (tensor).
* As arguments in function calls, they are interpreted depending on the argument definition, e.g. `neg(123)` (scalar), `neg([1,2,3])` (tensor), `concat([A,B,C],0)` (number).
* In options dictionaries inside function calls, they are interpreted depending on the dictionary definition. e.g. `linear(123, {alpha: 456, beta: 789})` (numbers), `transpose(T, {permutation: [0,2,1]})` (array of numbers), `gemm(A, B, {c: 123})` (scalar), `gemm(A, B, {c: [123]})` (tensor).
* In dictionaries outside of function calls, they are interpreted as arrays/numbers, e.g. `options = {alpha: 456, beta: 789})`. To pass a tensor/scalar constant in a dictionary, use a variable or wrap it in [`identity()`](https://webmachinelearning.github.io/webnn/#dom-mlgraphbuilder-identity) e.g. `options = {c:identity(4)} gemm(A, B, options)`.

The default [data type](https://webmachinelearning.github.io/webnn/#enumdef-mloperanddatatype) for scalars and tensors is [`float32`](https://webmachinelearning.github.io/webnn/#dom-mloperanddatatype-float32). To specify a different data type, suffix with one of `i8`, `u8`, `i32`, `u32`, `i64`, `u64`, `f16`, `f32`, e.g. `123i8` or `[1,2,3]u32`.

Expand Down
151 changes: 129 additions & 22 deletions nnotepad/js/nnotepad.js
Original file line number Diff line number Diff line change
Expand Up @@ -82,22 +82,114 @@ class WebNNUtil {
throw new Error(`Unsupported dataType ${type}`);
}

static argumentType(name, index) {
return ({
concat: {0: kArgTypeOperandList, 1: kArgTypeNonOperand},
expand: {1: kArgTypeNonOperand},
gru: {3: kArgTypeNonOperand, 4: kArgTypeNonOperand},
gruCell: {4: kArgTypeNonOperand},
lstm: {3: kArgTypeNonOperand, 4: kArgTypeNonOperand},
lstmCell: {5: kArgTypeNonOperand},
pad: {1: kArgTypeNonOperand, 2: kArgTypeNonOperand},
reshape: {1: kArgTypeNonOperand},
slice: {1: kArgTypeNonOperand, 2: kArgTypeNonOperand},
softmax: {1: kArgTypeNonOperand},
split: {1: kArgTypeNonOperand},
})[name]
?.[index] ||
kArgTypeOperand;
// Called to determine the type of an argument. `name` is the name of the
// `MLGraphBuilder` method. `index` is the argument index. If `key` is
// provided, this is serializing a member of an options dictionary. Returns
// one of the `kArgTypeXYZ` values.
static argumentType(name, index, key) {
const kDefaultDictMemberType = kArgTypeNonOperand;
const kDefaultArgType = kArgTypeOperand;

// TODO: Auto-generate this from the WebIDL API definition.
const argType = ({
batchNormalization: {
3: {
scale: kArgTypeOperand,
bias: kArgTypeOperand,
},
},
concat: {
0: kArgTypeOperandList,
1: kArgTypeNonOperand},
conv2d: {
2: {
bias: kArgTypeOperand,
},
},
convTranspose2d: {
2: {
bias: kArgTypeOperand,
},
},
expand: {
1: kArgTypeNonOperand,
},
gemm: {
2: {
c: kArgTypeOperand,
},
},
gru: {
3: kArgTypeNonOperand,
4: kArgTypeNonOperand,
5: {
bias: kArgTypeOperand,
recurrentBias: kArgTypeOperand,
initialHiddenState: kArgTypeOperand,
},
},
gruCell: {
4: kArgTypeNonOperand,
5: {
bias: kArgTypeOperand,
recurrentBias: kArgTypeOperand,
},
},
instanceNormalization: {
1: {
scale: kArgTypeOperand,
bias: kArgTypeOperand,
},
},
layerNormalization: {
1: {
scale: kArgTypeOperand,
bias: kArgTypeOperand,
},
},
lstm: {
3: kArgTypeNonOperand,
4: kArgTypeNonOperand,
5: {
bias: kArgTypeOperand,
recurrentBias: kArgTypeOperand,
peepholeWeight: kArgTypeOperand,
initialHiddenState: kArgTypeOperand,
initialCellState: kArgTypeOperand,
},
},
lstmCell: {
5: kArgTypeNonOperand,
6: {
bias: kArgTypeOperand,
recurrentBias: kArgTypeOperand,
peepholeWeight: kArgTypeOperand,
},
},
pad: {
1: kArgTypeNonOperand,
2: kArgTypeNonOperand,
},
reshape: {
1: kArgTypeNonOperand,
},
slice: {
1: kArgTypeNonOperand,
2: kArgTypeNonOperand,
},
softmax: {
1: kArgTypeNonOperand,
},
split: {
1: kArgTypeNonOperand,
},
})[name]?.[index];

if (key) {
return argType?.[key] ?? kDefaultDictMemberType;
}

return argType ?? kDefaultArgType;
}
}

Expand Down Expand Up @@ -401,7 +493,18 @@ export class NNotepad {
}
throw new Error(`unexpected line type: ${line.type}`);
}
function serializeExpr(expr, argumentType = kArgTypeOperand) {

// Serialize an expression. If `callContext` is provided, it can either be
// an object with `name` and `index` properties which identify a method call
// and argument position, used to determine the argument type, or an
// `kArgTypeXYZ` value to explicitly specify the type. This is needed for
// numbers, arrays, and dictionary members, which are serialized
// contextually.
function serializeExpr(expr, callContext) {
const argumentType = typeof callContext === 'object' ?
WebNNUtil.argumentType(callContext.name, callContext.index) :
typeof callContext === 'number' ? callContext :
kArgTypeOperand;
if (expr.op) {
if (expr.lhs) {
return `_.${kBinaryOperators[expr.op]}(${serializeExpr(expr.lhs)}, ${
Expand Down Expand Up @@ -432,21 +535,25 @@ export class NNotepad {
return serializeTensor(expr.value, expr.dataType);
}
case 'dict':
return serializeDict(expr.dict);
return serializeDict(expr.dict, callContext);
case 'identifier':
return expr.value;
case 'call':
return serializeCall(expr.identifier, expr.args);
}
throw new Error(`unexpected expr type: ${expr.type}`);
}
function serializeDict(dict) {
function serializeDict(dict, callContext) {
return '{' +
Object.keys(dict)
.map((k) => {
const v = dict[k];
k = Util.stringify(k);
return `${k}: ${serializeExpr(v, kArgTypeNonOperand)}`;
const argumentType = typeof callContext === 'object' ?
WebNNUtil.argumentType(
callContext.name, callContext.index, k) :
kArgTypeNonOperand;
return `${Util.stringify(k)}: ${
serializeExpr(v, argumentType)}`;
})
.join(', ') +
'}';
Expand Down Expand Up @@ -545,7 +652,7 @@ export class NNotepad {
return `_.${name}(${
args.map(
(arg, index) =>
serializeExpr(arg, WebNNUtil.argumentType(name, index)))
serializeExpr(arg, {name, index}))
.join(', ')})`;
}
}
Expand Down
8 changes: 8 additions & 0 deletions nnotepad/js/tests.js
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,14 @@ document.addEventListener('DOMContentLoaded', async (e) => {
`softmax([1], 0)`,
{dataType: 'float32', shape: [1], buffer: [1]});

Harness.section('Optional operand arguments');
await test(
'A = [[1,2], [3,4]] B = [[5,6], [7,8]] gemm(A, B, {c: 123})',
{dataType: 'float32', shape: [2, 2], buffer: [142, 145, 166, 173]});
await test(
'instanceNormalization([[[[1]]]], {scale: [123], bias: [456]})',
{dataType: 'float32', shape: [1, 1, 1, 1], buffer: [456]});

Harness.section('Regression tests');
await test(
`concat([[1,2],[3,4]], 0)`,
Expand Down

0 comments on commit fae1916

Please sign in to comment.