Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

NNotepad: Infer argument type for dictionary members #302

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions nnotepad/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,10 @@ Functions and operators are turned into [`MLGraphBuilder`](https://webmachinelea

Array literals (`[...]`) and number literals (`12.34`) are interpreted contextually:

* In assignments, they are intepreted as tensor/scalar constant [`MLOperand`](https://webmachinelearning.github.io/webnn/#mloperand)s, e.g. `alpha = 12.34` or `T = [1,2,3,4]`.
* In most function calls, they are interpreted as tensor/scalar constant [`MLOperand`](https://webmachinelearning.github.io/webnn/#mloperand)s, e.g. `neg(123)` or `neg([1,2,3])`.
* In some function calls, they are interpreted as arrays/numbers for some positional parameters, e.g. `concat([A,B,C],0)`. This includes: [`concat()`](https://webmachinelearning.github.io/webnn/#dom-mlgraphbuilder-concat), [`expand()`](https://webmachinelearning.github.io/webnn/#dom-mlgraphbuilder-expand), [`pad()`](https://webmachinelearning.github.io/webnn/#dom-mlgraphbuilder-pad), [`reshape()`](https://webmachinelearning.github.io/webnn/#dom-mlgraphbuilder-reshape), [`slice()`](https://webmachinelearning.github.io/webnn/#dom-mlgraphbuilder-slice), [`split()`](https://webmachinelearning.github.io/webnn/#dom-mlgraphbuilder-split).
* In dictionaries, they are interpreted as arrays/numbers, e.g. `linear(123, {alpha: 456, beta: 789})` or `transpose(T, {permutation: [0,2,1]})`. To pass a tensor/scalar constant in a dictionary, use a variable or wrap it in [`identity()`](https://webmachinelearning.github.io/webnn/#dom-mlgraphbuilder-identity) e.g. `gemm(A, B, {c:identity([4])})` or `gemm(A, B, {c:identity(4)})`.
* In assignments, they are intepreted as tensor/scalar constants [`MLOperand`](https://webmachinelearning.github.io/webnn/#mloperand)s, e.g. `alpha = 12.34` (scalar) or `T = [1,2,3,4]` (tensor).
* As arguments in function calls, they are interpreted depending on the argument definition, e.g. `neg(123)` (scalar), `neg([1,2,3])` (tensor), `concat([A,B,C],0)` (number).
* In options dictionaries inside function calls, they are interpreted depending on the dictionary definition. e.g. `linear(123, {alpha: 456, beta: 789})` (numbers), `transpose(T, {permutation: [0,2,1]})` (array of numbers), `gemm(A, B, {c: 123})` (scalar), `gemm(A, B, {c: [123]})` (tensor).
* In dictionaries outside of function calls, they are interpreted as arrays/numbers, e.g. `options = {alpha: 456, beta: 789})`. To pass a tensor/scalar constant in a dictionary, use a variable or wrap it in [`identity()`](https://webmachinelearning.github.io/webnn/#dom-mlgraphbuilder-identity) e.g. `options = {c:identity(4)} gemm(A, B, options)`.

The default [data type](https://webmachinelearning.github.io/webnn/#enumdef-mloperanddatatype) for scalars and tensors is [`float32`](https://webmachinelearning.github.io/webnn/#dom-mloperanddatatype-float32). To specify a different data type, suffix with one of `i8`, `u8`, `i32`, `u32`, `i64`, `u64`, `f16`, `f32`, e.g. `123i8` or `[1,2,3]u32`.

Expand Down
151 changes: 129 additions & 22 deletions nnotepad/js/nnotepad.js
Original file line number Diff line number Diff line change
Expand Up @@ -82,22 +82,114 @@ class WebNNUtil {
throw new Error(`Unsupported dataType ${type}`);
}

static argumentType(name, index) {
return ({
concat: {0: kArgTypeOperandList, 1: kArgTypeNonOperand},
expand: {1: kArgTypeNonOperand},
gru: {3: kArgTypeNonOperand, 4: kArgTypeNonOperand},
gruCell: {4: kArgTypeNonOperand},
lstm: {3: kArgTypeNonOperand, 4: kArgTypeNonOperand},
lstmCell: {5: kArgTypeNonOperand},
pad: {1: kArgTypeNonOperand, 2: kArgTypeNonOperand},
reshape: {1: kArgTypeNonOperand},
slice: {1: kArgTypeNonOperand, 2: kArgTypeNonOperand},
softmax: {1: kArgTypeNonOperand},
split: {1: kArgTypeNonOperand},
})[name]
?.[index] ||
kArgTypeOperand;
// Called to determine the type of an argument. `name` is the name of the
// `MLGraphBuilder` method. `index` is the argument index. If `key` is
// provided, this is serializing a member of an options dictionary. Returns
// one of the `kArgTypeXYZ` values.
static argumentType(name, index, key) {
const kDefaultDictMemberType = kArgTypeNonOperand;
const kDefaultArgType = kArgTypeOperand;

// TODO: Auto-generate this from the WebIDL API definition.
const argType = ({
batchNormalization: {
3: {
scale: kArgTypeOperand,
bias: kArgTypeOperand,
},
},
concat: {
0: kArgTypeOperandList,
1: kArgTypeNonOperand},
conv2d: {
2: {
bias: kArgTypeOperand,
},
},
convTranspose2d: {
2: {
bias: kArgTypeOperand,
},
},
expand: {
1: kArgTypeNonOperand,
},
gemm: {
2: {
c: kArgTypeOperand,
},
},
gru: {
3: kArgTypeNonOperand,
4: kArgTypeNonOperand,
5: {
bias: kArgTypeOperand,
recurrentBias: kArgTypeOperand,
initialHiddenState: kArgTypeOperand,
},
},
gruCell: {
4: kArgTypeNonOperand,
5: {
bias: kArgTypeOperand,
recurrentBias: kArgTypeOperand,
},
},
instanceNormalization: {
1: {
scale: kArgTypeOperand,
bias: kArgTypeOperand,
},
},
layerNormalization: {
1: {
scale: kArgTypeOperand,
bias: kArgTypeOperand,
},
},
lstm: {
3: kArgTypeNonOperand,
4: kArgTypeNonOperand,
5: {
bias: kArgTypeOperand,
recurrentBias: kArgTypeOperand,
peepholeWeight: kArgTypeOperand,
initialHiddenState: kArgTypeOperand,
initialCellState: kArgTypeOperand,
},
},
lstmCell: {
5: kArgTypeNonOperand,
6: {
bias: kArgTypeOperand,
recurrentBias: kArgTypeOperand,
peepholeWeight: kArgTypeOperand,
},
},
pad: {
1: kArgTypeNonOperand,
2: kArgTypeNonOperand,
},
reshape: {
1: kArgTypeNonOperand,
},
slice: {
1: kArgTypeNonOperand,
2: kArgTypeNonOperand,
},
softmax: {
1: kArgTypeNonOperand,
},
split: {
1: kArgTypeNonOperand,
},
})[name]?.[index];

if (key) {
return argType?.[key] ?? kDefaultDictMemberType;
}

return argType ?? kDefaultArgType;
}
}

Expand Down Expand Up @@ -401,7 +493,18 @@ export class NNotepad {
}
throw new Error(`unexpected line type: ${line.type}`);
}
function serializeExpr(expr, argumentType = kArgTypeOperand) {

// Serialize an expression. If `callContext` is provided, it can either be
// an object with `name` and `index` properties which identify a method call
// and argument position, used to determine the argument type, or an
// `kArgTypeXYZ` value to explicitly specify the type. This is needed for
// numbers, arrays, and dictionary members, which are serialized
// contextually.
function serializeExpr(expr, callContext) {
const argumentType = typeof callContext === 'object' ?
WebNNUtil.argumentType(callContext.name, callContext.index) :
typeof callContext === 'number' ? callContext :
kArgTypeOperand;
if (expr.op) {
if (expr.lhs) {
return `_.${kBinaryOperators[expr.op]}(${serializeExpr(expr.lhs)}, ${
Expand Down Expand Up @@ -432,21 +535,25 @@ export class NNotepad {
return serializeTensor(expr.value, expr.dataType);
}
case 'dict':
return serializeDict(expr.dict);
return serializeDict(expr.dict, callContext);
case 'identifier':
return expr.value;
case 'call':
return serializeCall(expr.identifier, expr.args);
}
throw new Error(`unexpected expr type: ${expr.type}`);
}
function serializeDict(dict) {
function serializeDict(dict, callContext) {
return '{' +
Object.keys(dict)
.map((k) => {
const v = dict[k];
k = Util.stringify(k);
return `${k}: ${serializeExpr(v, kArgTypeNonOperand)}`;
const argumentType = typeof callContext === 'object' ?
WebNNUtil.argumentType(
callContext.name, callContext.index, k) :
kArgTypeNonOperand;
return `${Util.stringify(k)}: ${
serializeExpr(v, argumentType)}`;
})
.join(', ') +
'}';
Expand Down Expand Up @@ -545,7 +652,7 @@ export class NNotepad {
return `_.${name}(${
args.map(
(arg, index) =>
serializeExpr(arg, WebNNUtil.argumentType(name, index)))
serializeExpr(arg, {name, index}))
.join(', ')})`;
}
}
Expand Down
8 changes: 8 additions & 0 deletions nnotepad/js/tests.js
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,14 @@ document.addEventListener('DOMContentLoaded', async (e) => {
`softmax([1], 0)`,
{dataType: 'float32', shape: [1], buffer: [1]});

Harness.section('Optional operand arguments');
await test(
'A = [[1,2], [3,4]] B = [[5,6], [7,8]] gemm(A, B, {c: 123})',
{dataType: 'float32', shape: [2, 2], buffer: [142, 145, 166, 173]});
await test(
'instanceNormalization([[[[1]]]], {scale: [123], bias: [456]})',
{dataType: 'float32', shape: [1, 1, 1, 1], buffer: [456]});

Harness.section('Regression tests');
await test(
`concat([[1,2],[3,4]], 0)`,
Expand Down
Loading