Skip to content

Commit fae1916

Browse files
Merge pull request #302 from inexorabletash/nnotepad-dicttypes
NNotepad: Infer argument type for dictionary members
2 parents 5db65aa + f128575 commit fae1916

File tree

3 files changed

+141
-26
lines changed

3 files changed

+141
-26
lines changed

nnotepad/README.md

+4-4
Original file line numberDiff line numberDiff line change
@@ -36,10 +36,10 @@ Functions and operators are turned into [`MLGraphBuilder`](https://webmachinelea
3636

3737
Array literals (`[...]`) and number literals (`12.34`) are interpreted contextually:
3838

39-
* In assignments, they are intepreted as tensor/scalar constant [`MLOperand`](https://webmachinelearning.github.io/webnn/#mloperand)s, e.g. `alpha = 12.34` or `T = [1,2,3,4]`.
40-
* In most function calls, they are interpreted as tensor/scalar constant [`MLOperand`](https://webmachinelearning.github.io/webnn/#mloperand)s, e.g. `neg(123)` or `neg([1,2,3])`.
41-
* In some function calls, they are interpreted as arrays/numbers for some positional parameters, e.g. `concat([A,B,C],0)`. This includes: [`concat()`](https://webmachinelearning.github.io/webnn/#dom-mlgraphbuilder-concat), [`expand()`](https://webmachinelearning.github.io/webnn/#dom-mlgraphbuilder-expand), [`pad()`](https://webmachinelearning.github.io/webnn/#dom-mlgraphbuilder-pad), [`reshape()`](https://webmachinelearning.github.io/webnn/#dom-mlgraphbuilder-reshape), [`slice()`](https://webmachinelearning.github.io/webnn/#dom-mlgraphbuilder-slice), [`split()`](https://webmachinelearning.github.io/webnn/#dom-mlgraphbuilder-split).
42-
* In dictionaries, they are interpreted as arrays/numbers, e.g. `linear(123, {alpha: 456, beta: 789})` or `transpose(T, {permutation: [0,2,1]})`. To pass a tensor/scalar constant in a dictionary, use a variable or wrap it in [`identity()`](https://webmachinelearning.github.io/webnn/#dom-mlgraphbuilder-identity) e.g. `gemm(A, B, {c:identity([4])})` or `gemm(A, B, {c:identity(4)})`.
39+
* In assignments, they are intepreted as tensor/scalar constants [`MLOperand`](https://webmachinelearning.github.io/webnn/#mloperand)s, e.g. `alpha = 12.34` (scalar) or `T = [1,2,3,4]` (tensor).
40+
* As arguments in function calls, they are interpreted depending on the argument definition, e.g. `neg(123)` (scalar), `neg([1,2,3])` (tensor), `concat([A,B,C],0)` (number).
41+
* In options dictionaries inside function calls, they are interpreted depending on the dictionary definition. e.g. `linear(123, {alpha: 456, beta: 789})` (numbers), `transpose(T, {permutation: [0,2,1]})` (array of numbers), `gemm(A, B, {c: 123})` (scalar), `gemm(A, B, {c: [123]})` (tensor).
42+
* In dictionaries outside of function calls, they are interpreted as arrays/numbers, e.g. `options = {alpha: 456, beta: 789})`. To pass a tensor/scalar constant in a dictionary, use a variable or wrap it in [`identity()`](https://webmachinelearning.github.io/webnn/#dom-mlgraphbuilder-identity) e.g. `options = {c:identity(4)} gemm(A, B, options)`.
4343

4444
The default [data type](https://webmachinelearning.github.io/webnn/#enumdef-mloperanddatatype) for scalars and tensors is [`float32`](https://webmachinelearning.github.io/webnn/#dom-mloperanddatatype-float32). To specify a different data type, suffix with one of `i8`, `u8`, `i32`, `u32`, `i64`, `u64`, `f16`, `f32`, e.g. `123i8` or `[1,2,3]u32`.
4545

nnotepad/js/nnotepad.js

+129-22
Original file line numberDiff line numberDiff line change
@@ -82,22 +82,114 @@ class WebNNUtil {
8282
throw new Error(`Unsupported dataType ${type}`);
8383
}
8484

85-
static argumentType(name, index) {
86-
return ({
87-
concat: {0: kArgTypeOperandList, 1: kArgTypeNonOperand},
88-
expand: {1: kArgTypeNonOperand},
89-
gru: {3: kArgTypeNonOperand, 4: kArgTypeNonOperand},
90-
gruCell: {4: kArgTypeNonOperand},
91-
lstm: {3: kArgTypeNonOperand, 4: kArgTypeNonOperand},
92-
lstmCell: {5: kArgTypeNonOperand},
93-
pad: {1: kArgTypeNonOperand, 2: kArgTypeNonOperand},
94-
reshape: {1: kArgTypeNonOperand},
95-
slice: {1: kArgTypeNonOperand, 2: kArgTypeNonOperand},
96-
softmax: {1: kArgTypeNonOperand},
97-
split: {1: kArgTypeNonOperand},
98-
})[name]
99-
?.[index] ||
100-
kArgTypeOperand;
85+
// Called to determine the type of an argument. `name` is the name of the
86+
// `MLGraphBuilder` method. `index` is the argument index. If `key` is
87+
// provided, this is serializing a member of an options dictionary. Returns
88+
// one of the `kArgTypeXYZ` values.
89+
static argumentType(name, index, key) {
90+
const kDefaultDictMemberType = kArgTypeNonOperand;
91+
const kDefaultArgType = kArgTypeOperand;
92+
93+
// TODO: Auto-generate this from the WebIDL API definition.
94+
const argType = ({
95+
batchNormalization: {
96+
3: {
97+
scale: kArgTypeOperand,
98+
bias: kArgTypeOperand,
99+
},
100+
},
101+
concat: {
102+
0: kArgTypeOperandList,
103+
1: kArgTypeNonOperand},
104+
conv2d: {
105+
2: {
106+
bias: kArgTypeOperand,
107+
},
108+
},
109+
convTranspose2d: {
110+
2: {
111+
bias: kArgTypeOperand,
112+
},
113+
},
114+
expand: {
115+
1: kArgTypeNonOperand,
116+
},
117+
gemm: {
118+
2: {
119+
c: kArgTypeOperand,
120+
},
121+
},
122+
gru: {
123+
3: kArgTypeNonOperand,
124+
4: kArgTypeNonOperand,
125+
5: {
126+
bias: kArgTypeOperand,
127+
recurrentBias: kArgTypeOperand,
128+
initialHiddenState: kArgTypeOperand,
129+
},
130+
},
131+
gruCell: {
132+
4: kArgTypeNonOperand,
133+
5: {
134+
bias: kArgTypeOperand,
135+
recurrentBias: kArgTypeOperand,
136+
},
137+
},
138+
instanceNormalization: {
139+
1: {
140+
scale: kArgTypeOperand,
141+
bias: kArgTypeOperand,
142+
},
143+
},
144+
layerNormalization: {
145+
1: {
146+
scale: kArgTypeOperand,
147+
bias: kArgTypeOperand,
148+
},
149+
},
150+
lstm: {
151+
3: kArgTypeNonOperand,
152+
4: kArgTypeNonOperand,
153+
5: {
154+
bias: kArgTypeOperand,
155+
recurrentBias: kArgTypeOperand,
156+
peepholeWeight: kArgTypeOperand,
157+
initialHiddenState: kArgTypeOperand,
158+
initialCellState: kArgTypeOperand,
159+
},
160+
},
161+
lstmCell: {
162+
5: kArgTypeNonOperand,
163+
6: {
164+
bias: kArgTypeOperand,
165+
recurrentBias: kArgTypeOperand,
166+
peepholeWeight: kArgTypeOperand,
167+
},
168+
},
169+
pad: {
170+
1: kArgTypeNonOperand,
171+
2: kArgTypeNonOperand,
172+
},
173+
reshape: {
174+
1: kArgTypeNonOperand,
175+
},
176+
slice: {
177+
1: kArgTypeNonOperand,
178+
2: kArgTypeNonOperand,
179+
},
180+
softmax: {
181+
1: kArgTypeNonOperand,
182+
},
183+
split: {
184+
1: kArgTypeNonOperand,
185+
},
186+
})[name]?.[index];
187+
188+
if (key) {
189+
return argType?.[key] ?? kDefaultDictMemberType;
190+
}
191+
192+
return argType ?? kDefaultArgType;
101193
}
102194
}
103195

@@ -401,7 +493,18 @@ export class NNotepad {
401493
}
402494
throw new Error(`unexpected line type: ${line.type}`);
403495
}
404-
function serializeExpr(expr, argumentType = kArgTypeOperand) {
496+
497+
// Serialize an expression. If `callContext` is provided, it can either be
498+
// an object with `name` and `index` properties which identify a method call
499+
// and argument position, used to determine the argument type, or an
500+
// `kArgTypeXYZ` value to explicitly specify the type. This is needed for
501+
// numbers, arrays, and dictionary members, which are serialized
502+
// contextually.
503+
function serializeExpr(expr, callContext) {
504+
const argumentType = typeof callContext === 'object' ?
505+
WebNNUtil.argumentType(callContext.name, callContext.index) :
506+
typeof callContext === 'number' ? callContext :
507+
kArgTypeOperand;
405508
if (expr.op) {
406509
if (expr.lhs) {
407510
return `_.${kBinaryOperators[expr.op]}(${serializeExpr(expr.lhs)}, ${
@@ -432,21 +535,25 @@ export class NNotepad {
432535
return serializeTensor(expr.value, expr.dataType);
433536
}
434537
case 'dict':
435-
return serializeDict(expr.dict);
538+
return serializeDict(expr.dict, callContext);
436539
case 'identifier':
437540
return expr.value;
438541
case 'call':
439542
return serializeCall(expr.identifier, expr.args);
440543
}
441544
throw new Error(`unexpected expr type: ${expr.type}`);
442545
}
443-
function serializeDict(dict) {
546+
function serializeDict(dict, callContext) {
444547
return '{' +
445548
Object.keys(dict)
446549
.map((k) => {
447550
const v = dict[k];
448-
k = Util.stringify(k);
449-
return `${k}: ${serializeExpr(v, kArgTypeNonOperand)}`;
551+
const argumentType = typeof callContext === 'object' ?
552+
WebNNUtil.argumentType(
553+
callContext.name, callContext.index, k) :
554+
kArgTypeNonOperand;
555+
return `${Util.stringify(k)}: ${
556+
serializeExpr(v, argumentType)}`;
450557
})
451558
.join(', ') +
452559
'}';
@@ -545,7 +652,7 @@ export class NNotepad {
545652
return `_.${name}(${
546653
args.map(
547654
(arg, index) =>
548-
serializeExpr(arg, WebNNUtil.argumentType(name, index)))
655+
serializeExpr(arg, {name, index}))
549656
.join(', ')})`;
550657
}
551658
}

nnotepad/js/tests.js

+8
Original file line numberDiff line numberDiff line change
@@ -178,6 +178,14 @@ document.addEventListener('DOMContentLoaded', async (e) => {
178178
`softmax([1], 0)`,
179179
{dataType: 'float32', shape: [1], buffer: [1]});
180180

181+
Harness.section('Optional operand arguments');
182+
await test(
183+
'A = [[1,2], [3,4]] B = [[5,6], [7,8]] gemm(A, B, {c: 123})',
184+
{dataType: 'float32', shape: [2, 2], buffer: [142, 145, 166, 173]});
185+
await test(
186+
'instanceNormalization([[[[1]]]], {scale: [123], bias: [456]})',
187+
{dataType: 'float32', shape: [1, 1, 1, 1], buffer: [456]});
188+
181189
Harness.section('Regression tests');
182190
await test(
183191
`concat([[1,2],[3,4]], 0)`,

0 commit comments

Comments
 (0)