@@ -31,6 +31,10 @@ export class ComputeError extends Error {
31
31
// General WebNN Utilities
32
32
// ============================================================
33
33
34
+ const kArgTypeOperandList = 1 ;
35
+ const kArgTypeNonOperand = 2 ;
36
+ const kArgTypeOperand = 3 ;
37
+
34
38
class WebNNUtil {
35
39
static bufferForOperand ( operand ) {
36
40
const size = [ ...operand . shape ( ) ] . reduce ( ( a , b ) => a * b , 1 ) ;
@@ -60,21 +64,22 @@ class WebNNUtil {
60
64
throw new Error ( `Unsupported dataType ${ type } ` ) ;
61
65
}
62
66
63
- static isNonOperandArg ( name , index ) {
67
+ static argumentType ( name , index ) {
64
68
return ( {
65
- concat : [ 0 , 1 ] ,
66
- expand : [ 1 ] ,
67
- gru : [ 3 , 4 ] ,
68
- gruCell : [ 4 ] ,
69
- lstm : [ 3 , 4 ] ,
70
- lstmCell : [ 5 ] ,
71
- pad : [ 1 , 2 ] ,
72
- reshape : [ 1 ] ,
73
- slice : [ 1 , 2 ] ,
74
- softmax : [ 1 ] , // TODO: Distinguish overloads
75
- split : [ 1 ] ,
69
+ concat : { 0 : kArgTypeOperandList , 1 : kArgTypeNonOperand } ,
70
+ expand : { 1 : kArgTypeNonOperand } ,
71
+ gru : { 3 : kArgTypeNonOperand , 4 : kArgTypeNonOperand } ,
72
+ gruCell : { 4 : kArgTypeNonOperand } ,
73
+ lstm : { 3 : kArgTypeNonOperand , 4 : kArgTypeNonOperand } ,
74
+ lstmCell : { 5 : kArgTypeNonOperand } ,
75
+ pad : { 1 : kArgTypeNonOperand , 2 : kArgTypeNonOperand } ,
76
+ reshape : { 1 : kArgTypeNonOperand } ,
77
+ slice : { 1 : kArgTypeNonOperand , 2 : kArgTypeNonOperand } ,
78
+ softmax : { 1 : kArgTypeNonOperand } ,
79
+ split : { 1 : kArgTypeNonOperand } ,
76
80
} ) [ name ]
77
- ?. includes ( index ) ;
81
+ ?. [ index ] ||
82
+ kArgTypeOperand ;
78
83
}
79
84
}
80
85
@@ -379,7 +384,7 @@ export class NNotepad {
379
384
}
380
385
throw new Error ( `unexpected line type: ${ line . type } ` ) ;
381
386
}
382
- function serializeExpr ( expr , nonOperand = false ) {
387
+ function serializeExpr ( expr , argumentType = kArgTypeOperand ) {
383
388
if ( expr . op ) {
384
389
if ( expr . lhs ) {
385
390
return `_.${ kBinaryOperators [ expr . op ] } (${ serializeExpr ( expr . lhs ) } , ${
@@ -394,11 +399,21 @@ export class NNotepad {
394
399
case 'boolean' :
395
400
return String ( expr . value ) ;
396
401
case 'number' :
397
- return nonOperand ? Util . stringify ( expr . value ) :
398
- serializeScalar ( expr . value , expr . dataType ) ;
402
+ switch ( argumentType ) {
403
+ case kArgTypeNonOperand :
404
+ return Util . stringify ( expr . value ) ;
405
+ default :
406
+ return serializeScalar ( expr . value , expr . dataType ) ;
407
+ }
399
408
case 'array' :
400
- return nonOperand ? serializeArray ( expr . value ) :
401
- serializeTensor ( expr . value , expr . dataType ) ;
409
+ switch ( argumentType ) {
410
+ case kArgTypeNonOperand :
411
+ return serializeArray ( expr . value , kArgTypeNonOperand ) ;
412
+ case kArgTypeOperandList :
413
+ return serializeArray ( expr . value , kArgTypeOperand ) ;
414
+ default :
415
+ return serializeTensor ( expr . value , expr . dataType ) ;
416
+ }
402
417
case 'dict' :
403
418
return serializeDict ( expr . dict ) ;
404
419
case 'identifier' :
@@ -414,7 +429,7 @@ export class NNotepad {
414
429
. map ( ( k ) => {
415
430
const v = dict [ k ] ;
416
431
k = Util . stringify ( k ) ;
417
- return `${ k } : ${ serializeExpr ( v , true ) } ` ;
432
+ return `${ k } : ${ serializeExpr ( v , kArgTypeNonOperand ) } ` ;
418
433
} )
419
434
. join ( ', ' ) +
420
435
'}' ;
@@ -465,8 +480,10 @@ export class NNotepad {
465
480
elements . map ( ( n ) => Util . stringifyNumber ( n , dataType ) ) . join ( ',' ) } ]))`;
466
481
}
467
482
468
- function serializeArray ( array ) {
469
- return '[' + array . map ( ( expr ) => serializeExpr ( expr ) ) . join ( ', ' ) + ']' ;
483
+ function serializeArray ( array , argumentType ) {
484
+ return '[' +
485
+ array . map ( ( expr ) => serializeExpr ( expr , argumentType ) ) . join ( ', ' ) +
486
+ ']' ;
470
487
}
471
488
472
489
function serializeCall ( name , args ) {
@@ -506,8 +523,8 @@ export class NNotepad {
506
523
507
524
return `_.${ name } (${
508
525
args . map (
509
- ( arg , index ) => serializeExpr (
510
- arg , WebNNUtil . isNonOperandArg ( name , index ) ) )
526
+ ( arg , index ) =>
527
+ serializeExpr ( arg , WebNNUtil . argumentType ( name , index ) ) )
511
528
. join ( ', ' ) } )`;
512
529
}
513
530
}
0 commit comments