From 9c44ab88428731676b612c1a23f0e5875cc2255b Mon Sep 17 00:00:00 2001 From: Cristobal Valenzuela Date: Sat, 2 Jun 2018 20:06:50 -0400 Subject: [PATCH] fix diff between image and video in ImageClassifier --- src/ImageAndVideo.js | 2 ++ src/ImageClassifier/index.js | 15 ++++++--------- 2 files changed, 8 insertions(+), 9 deletions(-) diff --git a/src/ImageAndVideo.js b/src/ImageAndVideo.js index 5f99e1957..31a636dc3 100644 --- a/src/ImageAndVideo.js +++ b/src/ImageAndVideo.js @@ -23,6 +23,8 @@ class ImageAndVideo { } else if (typeof video === 'object' && video.elt instanceof HTMLVideoElement) { // Handle p5.js video element this.video = processVideo(video.elt, this.imageSize, this.onVideoReady); + } else { + this.videoReady = true; } } } diff --git a/src/ImageClassifier/index.js b/src/ImageClassifier/index.js index 8a099968c..04a2164ad 100644 --- a/src/ImageClassifier/index.js +++ b/src/ImageClassifier/index.js @@ -65,7 +65,7 @@ class ImageClassifier extends ImageAndVideo { this.mobilenet = await tf.loadModel(this.modelPath); const layer = this.mobilenet.getLayer('conv_pw_13_relu'); - if (this.videoReady) { + if (this.videoReady && this.video) { tf.tidy(() => this.mobilenet.predict(imgToTensor(this.video))); // Warm up } @@ -169,15 +169,13 @@ class ImageClassifier extends ImageAndVideo { /* eslint consistent-return: 0 */ async predict(inputNumOrCallback, numOrCallback = null, cb = null) { - let imgToPredict; + let imgToPredict = this.video; let numberOfClasses = 10; let callback; if (typeof inputNumOrCallback === 'function') { - if (this.video) { - imgToPredict = this.video; - callback = inputNumOrCallback; - } + imgToPredict = this.video; + callback = inputNumOrCallback; } else if (inputNumOrCallback instanceof HTMLImageElement) { imgToPredict = inputNumOrCallback; } else if (inputNumOrCallback instanceof HTMLVideoElement) { @@ -186,9 +184,8 @@ class ImageClassifier extends ImageAndVideo { } imgToPredict = this.video; } else if (typeof numOrCallback === 'number') { - if (this.video) { - numberOfClasses = inputNumOrCallback; - } + imgToPredict = this.video; + numberOfClasses = inputNumOrCallback; } if (typeof numOrCallback === 'function') {