diff --git a/src/BodyPix/index.js b/src/BodyPix/index.js
index 28fef3997..091585893 100644
--- a/src/BodyPix/index.js
+++ b/src/BodyPix/index.js
@@ -19,6 +19,7 @@ import generatedImageResult from '../utils/generatedImageResult';
 import handleArguments from '../utils/handleArguments';
 import p5Utils from '../utils/p5Utils';
 import BODYPIX_PALETTE from './BODYPIX_PALETTE';
+import { mediaReady } from '../utils/imageUtilities';
 /**
  * @typedef {Record} BodyPixPalette
  */
@@ -135,13 +136,7 @@ class BodyPix {
   async segmentWithPartsInternal(imgToSegment, segmentationOptions) {
     // estimatePartSegmentation
     await this.ready;
-    await tf.nextFrame();
-
-    if (this.video && this.video.readyState === 0) {
-      await new Promise(resolve => {
-        this.video.onloadeddata = () => resolve();
-      });
-    }
+    await mediaReady(imgToSegment, true);

     this.config.palette = segmentationOptions.palette || this.config.palette;
     this.config.outputStride = segmentationOptions.outputStride || this.config.outputStride;
@@ -253,13 +248,7 @@ class BodyPix {

   async segmentInternal(imgToSegment, segmentationOptions) {
     await this.ready;
-    await tf.nextFrame();
-
-    if (this.video && this.video.readyState === 0) {
-      await new Promise(resolve => {
-        this.video.onloadeddata = () => resolve();
-      });
-    }
+    await mediaReady(imgToSegment, true);

     this.config.outputStride = segmentationOptions.outputStride || this.config.outputStride;
     this.config.segmentationThreshold = segmentationOptions.segmentationThreshold || this.config.segmentationThreshold;
diff --git a/src/FaceApi/index.js b/src/FaceApi/index.js
index 2f1e67711..f6a69e373 100644
--- a/src/FaceApi/index.js
+++ b/src/FaceApi/index.js
@@ -12,10 +12,10 @@
  * Ported and integrated from all the hard work by: https://github.com/justadudewhohacks/face-api.js?files=1
  */

-import * as tf from "@tensorflow/tfjs";
 import * as faceapi from "face-api.js";
 import callCallback from "../utils/callcallback";
 import handleArguments from "../utils/handleArguments";
+import { mediaReady } from "../utils/imageUtilities";
 import { getModelPath } from "../utils/modelLoader";

 const DEFAULTS = {
@@ -158,13 +158,7 @@
    */
   async detectInternal(imgToClassify, faceApiOptions) {
     await this.ready;
-    await tf.nextFrame();
-
-    if (this.video && this.video.readyState === 0) {
-      await new Promise(resolve => {
-        this.video.onloadeddata = () => resolve();
-      });
-    }
+    await mediaReady(imgToClassify, true);

     // sets the return options if any are passed in during .detect() or .detectSingle()
     this.config = this.setReturnOptions(faceApiOptions);
@@ -223,13 +217,7 @@
    */
   async detectSingleInternal(imgToClassify, faceApiOptions) {
     await this.ready;
-    await tf.nextFrame();
-
-    if (this.video && this.video.readyState === 0) {
-      await new Promise(resolve => {
-        this.video.onloadeddata = () => resolve();
-      });
-    }
+    await mediaReady(imgToClassify, true);

     // sets the return options if any are passed in during .detect() or .detectSingle()
     this.config = this.setReturnOptions(faceApiOptions);
diff --git a/src/Facemesh/index.js b/src/Facemesh/index.js
index ae27a3803..bf7db5538 100644
--- a/src/Facemesh/index.js
+++ b/src/Facemesh/index.js
@@ -13,6 +13,7 @@ import * as facemeshCore from "@tensorflow-models/facemesh";
 import { EventEmitter } from "events";
 import callCallback from "../utils/callcallback";
 import handleArguments from "../utils/handleArguments";
+import { mediaReady } from '../utils/imageUtilities';

 class Facemesh extends EventEmitter {
   /**
@@ -43,13 +44,6 @@ class Facemesh extends EventEmitter {
     this.model = await facemeshCore.load(this.config);
     this.modelReady = true;

-    if (this.video && this.video.readyState === 0) {
-      await new Promise(resolve => {
-        this.video.onloadeddata = () => {
-          resolve();
-        };
-      });
-    }
     if (this.video) {
       this.predict();
     }
@@ -65,6 +59,7 @@ class Facemesh extends EventEmitter {
     if (!image) {
       throw new Error("No input image found.");
     }
+    await mediaReady(image, false);
     const { flipHorizontal } = this.config;
     const predictions = await this.model.estimateFaces(image, flipHorizontal);
     const result = predictions;
diff --git a/src/Handpose/index.js b/src/Handpose/index.js
index 6af5c0bf1..b52c3f92d 100644
--- a/src/Handpose/index.js
+++ b/src/Handpose/index.js
@@ -13,6 +13,7 @@ import * as handposeCore from "@tensorflow-models/handpose";
 import { EventEmitter } from "events";
 import callCallback from "../utils/callcallback";
 import handleArguments from "../utils/handleArguments";
+import { mediaReady } from '../utils/imageUtilities';

 class Handpose extends EventEmitter {
   /**
@@ -43,14 +44,6 @@ class Handpose extends EventEmitter {
     this.model = await handposeCore.load(this.config);
     this.modelReady = true;

-    if (this.video && this.video.readyState === 0) {
-      await new Promise(resolve => {
-        this.video.onloadeddata = () => {
-          resolve();
-        };
-      });
-    }
-
     if (this.video) {
       this.predict();
     }
@@ -66,6 +59,7 @@ class Handpose extends EventEmitter {
     if (!image) {
       throw new Error("No input image found.");
     }
+    await mediaReady(image, false);
     const { flipHorizontal } = this.config;
     const predictions = await this.model.estimateHands(image, flipHorizontal);
     const result = predictions;
diff --git a/src/ImageClassifier/index.js b/src/ImageClassifier/index.js
index bbbd17a6b..7dca25298 100644
--- a/src/ImageClassifier/index.js
+++ b/src/ImageClassifier/index.js
@@ -15,7 +15,7 @@ import handleArguments from "../utils/handleArguments";
 import * as darknet from "./darknet";
 import * as doodlenet from "./doodlenet";
 import callCallback from "../utils/callcallback";
-import { imgToTensor } from "../utils/imageUtilities";
+import { imgToTensor, mediaReady } from "../utils/imageUtilities";

 const DEFAULTS = {
   mobilenet: {
@@ -134,21 +134,7 @@ class ImageClassifier {
   async classifyInternal(imgToPredict, numberOfClasses) {
     // Wait for the model to be ready
     await this.ready;
-    await tf.nextFrame();
-
-    if (imgToPredict instanceof HTMLVideoElement && imgToPredict.readyState === 0) {
-      const video = imgToPredict;
-      // Wait for the video to be ready
-      await new Promise(resolve => {
-        video.onloadeddata = () => resolve();
-      });
-    }
-
-    if (this.video && this.video.readyState === 0) {
-      await new Promise(resolve => {
-        this.video.onloadeddata = () => resolve();
-      });
-    }
+    await mediaReady(imgToPredict, true);

     // Process the images
     const imageResize = [IMAGE_SIZE, IMAGE_SIZE];
diff --git a/src/UNET/index.js b/src/UNET/index.js
index 484f5100a..4cb392ec1 100644
--- a/src/UNET/index.js
+++ b/src/UNET/index.js
@@ -11,6 +11,7 @@ import * as tf from '@tensorflow/tfjs';
 import callCallback from '../utils/callcallback';
 import generatedImageResult from '../utils/generatedImageResult';
 import handleArguments from "../utils/handleArguments";
+import { mediaReady } from '../utils/imageUtilities';

 const DEFAULTS = {
   modelPath: 'https://raw.githubusercontent.com/zaidalyafeai/HostedModels/master/unet-128/model.json',
@@ -33,8 +34,8 @@ class UNET {
       modelPath: typeof options.modelPath !== 'undefined' ? options.modelPath : DEFAULTS.modelPath,
       imageSize: typeof options.imageSize !== 'undefined' ? options.imageSize : DEFAULTS.imageSize,
       returnTensors: typeof options.returnTensors !== 'undefined' ? options.returnTensors : DEFAULTS.returnTensors,
-    };
-
+    };
+    this.video = video;
     this.ready = callCallback(this.loadModel(), callback);
   }

@@ -46,17 +47,13 @@ class UNET {

   async segment(inputOrCallback, cb) {
     const { image, callback } = handleArguments(this.video, inputOrCallback, cb);
-    await this.ready;
     return callCallback(this.segmentInternal(image), callback);
   }

   async segmentInternal(imgToPredict) {
-    // Wait for the model to be ready
+    // Wait for the model to be ready and video input to be loaded
     await this.ready;
-    // skip asking for next frame if it's not video
-    if (imgToPredict instanceof HTMLVideoElement) {
-      await tf.nextFrame();
-    }
+    await mediaReady(imgToPredict, true);

     this.isPredicting = true;
     const {
diff --git a/src/utils/handleArguments.js b/src/utils/handleArguments.js
index 412eaa937..b0e417a78 100644
--- a/src/utils/handleArguments.js
+++ b/src/utils/handleArguments.js
@@ -27,6 +27,16 @@ export const isVideo = (img) => {
     img instanceof HTMLVideoElement;
 }

+/**
+ * Check if a variable is an HTMLAudioElement.
+ * @param {any} img
+ * @returns {img is HTMLAudioElement}
+ */
+export const isAudio = (img) => {
+  return typeof (HTMLAudioElement) !== 'undefined' &&
+    img instanceof HTMLAudioElement;
+}
+
 /**
  * Check if a variable is an HTMLCanvasElement.
  * @param {any} img
@@ -203,7 +213,10 @@ class ArgHelper {
         });
       }
     }
-    // TODO: handle audio elements and p5.sound
+    // TODO: handle p5.sound
+    if (isAudio(arg)) {
+      this.set({ audio: arg });
+    }
     // Check for arrays
     else if (Array.isArray(arg)) {
       this.set({ array: arg });
diff --git a/src/utils/imageUtilities.js b/src/utils/imageUtilities.js
index fa1dfa555..a7a26fe87 100644
--- a/src/utils/imageUtilities.js
+++ b/src/utils/imageUtilities.js
@@ -4,7 +4,16 @@
 // https://opensource.org/licenses/MIT

 import * as tf from '@tensorflow/tfjs';
-import { getImageElement, isCanvas, isImageData, isImageElement, isP5Image } from "./handleArguments";
+import {
+  getImageElement,
+  isAudio,
+  isCanvas,
+  isImageData,
+  isImageElement,
+  isImg,
+  isP5Image,
+  isVideo
+} from "./handleArguments";
 import p5Utils from './p5Utils';

 // Resize video elements
@@ -162,6 +171,36 @@ function imgToPixelArray(img) {
   return Array.from(imgData.data);
 }

+/**
+ * Extract common logic from models accepting video input.
+ * Makes sure that the video/audio/image data has loaded.
+ * Optionally can wait for the next frame every time the function is called.
+ * Will resolve immediately if the input is undefined or a different element type.
+ * @param {InputImage | undefined} input
+ * @param {boolean} nextFrame
+ * @returns {Promise<void>}
+ */
+async function mediaReady(input, nextFrame) {
+  if (input && (isVideo(input) || isAudio(input))) {
+    if (nextFrame) {
+      await tf.nextFrame();
+    }
+    if (input.readyState === 0) {
+      await new Promise((resolve, reject) => {
+        input.addEventListener('error', () => reject(input.error));
+        input.addEventListener('loadeddata', resolve);
+      });
+    }
+  } else if (input && isImg(input)) {
+    if (!input.complete) {
+      await new Promise((resolve, reject) => {
+        input.addEventListener('error', reject);
+        input.addEventListener('load', resolve);
+      });
+    }
+  }
+}
+
 export {
   array3DToImage,
   processVideo,
@@ -169,5 +208,6 @@ export {
   imgToTensor,
   isInstanceOfSupportedElement,
   flipImage,
-  imgToPixelArray
+  imgToPixelArray,
+  mediaReady
 };
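
For context, a minimal usage sketch of the new helper (not part of the patch). `predictInternal` and `model.predict` are hypothetical stand-ins for the per-model inference calls touched above; `mediaReady` itself is the utility this changeset adds:

import { mediaReady } from './utils/imageUtilities';

// The pattern this changeset applies everywhere: await the shared readiness
// helper before inference. The frame-looping models (BodyPix, FaceApi,
// ImageClassifier, UNET) pass `true` so each call also awaits tf.nextFrame();
// the Facemesh and Handpose predict calls pass `false`.
async function predictInternal(model, input) {
  // Waits for 'loadeddata' on an unstarted <video>/<audio> (readyState === 0)
  // and for 'load' on an incomplete <img>; otherwise resolves immediately
  // (canvas, ImageData, undefined), rejecting if the element errors instead.
  await mediaReady(input, true);
  return model.predict(input); // hypothetical inference call
}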