Skip to content

Commit

Permalink
Merge pull request #1045 from ml5js/joeyklee.637-mock-functions-in-tests
Browse files Browse the repository at this point in the history
[testing]: mocks objectDetector functions
  • Loading branch information
joeyklee authored Aug 22, 2020
2 parents 4dd7168 + e61631f commit 752b23f
Show file tree
Hide file tree
Showing 6 changed files with 76 additions and 33 deletions.
6 changes: 3 additions & 3 deletions src/ObjectDetector/CocoSsd/index.js
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ const DEFAULTS = {
modelUrl: undefined,
};

class CocoSsdBase {
export class CocoSsdBase {
/**
* Create CocoSsd model. Works on video and images.
* @param {function} constructorCallback - Optional. A callback function that is called once the model has loaded. If no callback is provided, it will return a promise
Expand Down Expand Up @@ -129,7 +129,7 @@ class CocoSsdBase {
}
}

const CocoSsd = (videoOr, optionsOr, cb) => {
export const CocoSsd = (videoOr, optionsOr, cb) => {
let video = null;
let options = {};
let callback = cb;
Expand All @@ -153,4 +153,4 @@ const CocoSsd = (videoOr, optionsOr, cb) => {
return new CocoSsdBase(video, options, callback);
};

export default CocoSsd;
// export default CocoSsd;
6 changes: 3 additions & 3 deletions src/ObjectDetector/YOLO/index.js
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ const DEFAULTS = {
// Size of the video
const imageSize = 416;

class YOLOBase extends Video {
export class YOLOBase extends Video {
/**
* @deprecated Please use ObjectDetector class instead
*/
Expand Down Expand Up @@ -236,7 +236,7 @@ class YOLOBase extends Video {
}
}

const YOLO = (videoOr, optionsOr, cb) => {
export const YOLO = (videoOr, optionsOr, cb) => {
let video = null;
let options = {};
let callback = cb;
Expand All @@ -260,4 +260,4 @@ const YOLO = (videoOr, optionsOr, cb) => {
return new YOLOBase(video, options, callback);
};

export default YOLO;
// export default YOLO;
4 changes: 2 additions & 2 deletions src/ObjectDetector/index.js
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@
ObjectDetection
*/

import YOLO from "./YOLO/index";
import CocoSsd from "./CocoSsd/index";
import { YOLO } from "./YOLO/index";
import {CocoSsd} from "./CocoSsd/index";

class ObjectDetector {
/**
Expand Down
61 changes: 37 additions & 24 deletions src/ObjectDetector/index_test.js
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
// This software is released under the MIT License.
// https://opensource.org/licenses/MIT

const { getImageData, getRobin } = ml5.testingUtils;

const COCOSSD_DEFAULTS = {
base: "lite_mobilenet_v2",
modelUrl: undefined,
Expand All @@ -15,30 +17,31 @@ const YOLO_DEFAULTS = {
size: 416,
};

async function getRobin() {
const img = new Image();
img.crossOrigin = "";
img.src = "https://cdn.jsdelivr.net/gh/ml5js/ml5-library@development/assets/bird.jpg";
await new Promise(resolve => {
img.onload = resolve;
});
return img;
}

async function getImageData() {
const arr = new Uint8ClampedArray(20000);
const mockYoloObject = {
IOUThreshold: YOLO_DEFAULTS.IOUThreshold,
classProbThreshold: YOLO_DEFAULTS.classProbThreshold,
filterBoxesThreshold: YOLO_DEFAULTS.filterBoxesThreshold,
size: YOLO_DEFAULTS.size,
detect: () => {
return [{ label: "bird", confidence: 0.9 }];
},
};
const mockCocoObject = {
config: { ...COCOSSD_DEFAULTS },
detect: () => {
return [{ label: "bird", confidence: 0.9 }];
},
};

// Iterate through every pixel
for (let i = 0; i < arr.length; i += 4) {
arr[i + 0] = 0; // R value
arr[i + 1] = 190; // G value
arr[i + 2] = 0; // B value
arr[i + 3] = 255; // A value
function mockObjectDetector(modelName) {
switch (modelName) {
case "yolo":
return mockYoloObject;
case "cocossd":
return mockCocoObject;
default:
return mockCocoObject;
}

// Initialize a new ImageData object
const img = new ImageData(arr, 200);
return img;
}

describe("objectDetector", () => {
Expand All @@ -47,31 +50,37 @@ describe("objectDetector", () => {
/**
* Test cocossd object detector
*/

describe("objectDetector: cocossd", () => {
let cocoDetector;

beforeAll(async () => {
spyOn(ml5, "objectDetector").and.callFake(mockObjectDetector);
cocoDetector = await ml5.objectDetector("cocossd");
});

it("Should instantiate with the following defaults", () => {
expect(cocoDetector.config.base).toBe(COCOSSD_DEFAULTS.base);
expect(cocoDetector.config.modelUrl).toBe(COCOSSD_DEFAULTS.modelUrl);
expect(ml5.objectDetector).toHaveBeenCalled();
expect(cocoDetector.config.toString()).toBe(COCOSSD_DEFAULTS.toString());
});

it("detects a robin", async () => {
spyOn(cocoDetector, "detect").and.returnValue([{ label: "bird", confidence: 0.9 }]);

const robin = await getRobin();
const detection = await cocoDetector.detect(robin);
expect(detection[0].label).toBe("bird");
});

it("detects takes ImageData", async () => {
spyOn(cocoDetector, "detect").and.returnValue([]);
const img = await getImageData();
const detection = await cocoDetector.detect(img);
expect(detection).toEqual([]);
});

it("throws error when a non image is trying to be detected", async () => {
spyOn(cocoDetector, "detect").and.throwError("Detection subject not supported");
const notAnImage = "not_an_image";
try {
await cocoDetector.detect(notAnImage);
Expand All @@ -88,23 +97,27 @@ describe("objectDetector", () => {
describe("objectDetector: yolo", () => {
let yolo;
beforeAll(async () => {
spyOn(ml5, "objectDetector").and.callFake(mockObjectDetector);
yolo = await ml5.objectDetector("yolo", { disableDeprecationNotice: true, ...YOLO_DEFAULTS });
});

it("instantiates the YOLO classifier with defaults", () => {
expect(ml5.objectDetector).toHaveBeenCalled();
expect(yolo.IOUThreshold).toBe(YOLO_DEFAULTS.IOUThreshold);
expect(yolo.classProbThreshold).toBe(YOLO_DEFAULTS.classProbThreshold);
expect(yolo.filterBoxesThreshold).toBe(YOLO_DEFAULTS.filterBoxesThreshold);
expect(yolo.size).toBe(YOLO_DEFAULTS.size);
});

it("detects a robin", async () => {
spyOn(yolo, "detect").and.returnValue([{ label: "bird", confidence: 0.9 }]);
const robin = await getRobin();
const detection = await yolo.detect(robin);
expect(detection[0].label).toBe("bird");
});

it("detects takes ImageData", async () => {
spyOn(yolo, "detect").and.returnValue([]);
const img = await getImageData();
const detection = await yolo.detect(img);
expect(detection).toEqual([]);
Expand Down
6 changes: 5 additions & 1 deletion src/index.js
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@ import soundClassifier from "./SoundClassifier/";
import KNNClassifier from "./KNNClassifier/";
import featureExtractor from "./FeatureExtractor/";
import word2vec from "./Word2vec/";
import YOLO from "./ObjectDetector/YOLO";
import {YOLO} from "./ObjectDetector/YOLO";
import {CocoSsd} from "./ObjectDetector/CocoSsd";
import objectDetector from "./ObjectDetector";
import poseNet from "./PoseNet";
import * as imageUtils from "./utils/imageUtilities";
Expand All @@ -35,6 +36,7 @@ import facemesh from "./Facemesh";
import handpose from './Handpose';
import p5Utils from "./utils/p5Utils";
import communityStatement from "./utils/community";
import * as testingUtils from "./utils/testingUtils";

const withPreload = {
charRNN,
Expand All @@ -52,6 +54,7 @@ const withPreload = {
styleTransfer,
word2vec,
YOLO,
CocoSsd,
objectDetector,
uNet,
sentiment,
Expand All @@ -75,4 +78,5 @@ module.exports = Object.assign({ p5Utils }, preloadRegister(withPreload), {
tfvis,
version,
neuralNetwork,
testingUtils
});
26 changes: 26 additions & 0 deletions src/utils/testingUtils/index.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@

export const getRobin = async () => {
const img = new Image();
img.crossOrigin = "";
img.src = "https://cdn.jsdelivr.net/gh/ml5js/ml5-library@development/assets/bird.jpg";
await new Promise(resolve => {
img.onload = resolve;
});
return img;
}

export const getImageData = async () => {
const arr = new Uint8ClampedArray(20000);

// Iterate through every pixel
for (let i = 0; i < arr.length; i += 4) {
arr[i + 0] = 0; // R value
arr[i + 1] = 190; // G value
arr[i + 2] = 0; // B value
arr[i + 3] = 255; // A value
}

// Initialize a new ImageData object
const img = new ImageData(arr, 200);
return img;
}

0 comments on commit 752b23f

Please sign in to comment.