Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[testing]: mocks objectDetector functions #1045

Merged
merged 3 commits into from
Aug 22, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions src/ObjectDetector/CocoSsd/index.js
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ const DEFAULTS = {
modelUrl: undefined,
};

class CocoSsdBase {
export class CocoSsdBase {
/**
* Create CocoSsd model. Works on video and images.
* @param {function} constructorCallback - Optional. A callback function that is called once the model has loaded. If no callback is provided, it will return a promise
Expand Down Expand Up @@ -129,7 +129,7 @@ class CocoSsdBase {
}
}

const CocoSsd = (videoOr, optionsOr, cb) => {
export const CocoSsd = (videoOr, optionsOr, cb) => {
let video = null;
let options = {};
let callback = cb;
Expand All @@ -153,4 +153,4 @@ const CocoSsd = (videoOr, optionsOr, cb) => {
return new CocoSsdBase(video, options, callback);
};

export default CocoSsd;
// export default CocoSsd;
6 changes: 3 additions & 3 deletions src/ObjectDetector/YOLO/index.js
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ const DEFAULTS = {
// Size of the video
const imageSize = 416;

class YOLOBase extends Video {
export class YOLOBase extends Video {
/**
* @deprecated Please use ObjectDetector class instead
*/
Expand Down Expand Up @@ -236,7 +236,7 @@ class YOLOBase extends Video {
}
}

const YOLO = (videoOr, optionsOr, cb) => {
export const YOLO = (videoOr, optionsOr, cb) => {
let video = null;
let options = {};
let callback = cb;
Expand All @@ -260,4 +260,4 @@ const YOLO = (videoOr, optionsOr, cb) => {
return new YOLOBase(video, options, callback);
};

export default YOLO;
// export default YOLO;
4 changes: 2 additions & 2 deletions src/ObjectDetector/index.js
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@
ObjectDetection
*/

import YOLO from "./YOLO/index";
import CocoSsd from "./CocoSsd/index";
import { YOLO } from "./YOLO/index";
import {CocoSsd} from "./CocoSsd/index";

class ObjectDetector {
/**
Expand Down
61 changes: 37 additions & 24 deletions src/ObjectDetector/index_test.js
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
// This software is released under the MIT License.
// https://opensource.org/licenses/MIT

const { getImageData, getRobin } = ml5.testingUtils;

const COCOSSD_DEFAULTS = {
base: "lite_mobilenet_v2",
modelUrl: undefined,
Expand All @@ -15,30 +17,31 @@ const YOLO_DEFAULTS = {
size: 416,
};

async function getRobin() {
const img = new Image();
img.crossOrigin = "";
img.src = "https://cdn.jsdelivr.net/gh/ml5js/ml5-library@development/assets/bird.jpg";
await new Promise(resolve => {
img.onload = resolve;
});
return img;
}

async function getImageData() {
const arr = new Uint8ClampedArray(20000);
const mockYoloObject = {
IOUThreshold: YOLO_DEFAULTS.IOUThreshold,
classProbThreshold: YOLO_DEFAULTS.classProbThreshold,
filterBoxesThreshold: YOLO_DEFAULTS.filterBoxesThreshold,
size: YOLO_DEFAULTS.size,
detect: () => {
return [{ label: "bird", confidence: 0.9 }];
},
};
const mockCocoObject = {
config: { ...COCOSSD_DEFAULTS },
detect: () => {
return [{ label: "bird", confidence: 0.9 }];
},
};

// Iterate through every pixel
for (let i = 0; i < arr.length; i += 4) {
arr[i + 0] = 0; // R value
arr[i + 1] = 190; // G value
arr[i + 2] = 0; // B value
arr[i + 3] = 255; // A value
function mockObjectDetector(modelName) {
switch (modelName) {
case "yolo":
return mockYoloObject;
case "cocossd":
return mockCocoObject;
default:
return mockCocoObject;
}

// Initialize a new ImageData object
const img = new ImageData(arr, 200);
return img;
}

describe("objectDetector", () => {
Expand All @@ -47,31 +50,37 @@ describe("objectDetector", () => {
/**
* Test cocossd object detector
*/

describe("objectDetector: cocossd", () => {
let cocoDetector;

beforeAll(async () => {
spyOn(ml5, "objectDetector").and.callFake(mockObjectDetector);
cocoDetector = await ml5.objectDetector("cocossd");
});

it("Should instantiate with the following defaults", () => {
expect(cocoDetector.config.base).toBe(COCOSSD_DEFAULTS.base);
expect(cocoDetector.config.modelUrl).toBe(COCOSSD_DEFAULTS.modelUrl);
expect(ml5.objectDetector).toHaveBeenCalled();
expect(cocoDetector.config.toString()).toBe(COCOSSD_DEFAULTS.toString());
});

it("detects a robin", async () => {
spyOn(cocoDetector, "detect").and.returnValue([{ label: "bird", confidence: 0.9 }]);

const robin = await getRobin();
const detection = await cocoDetector.detect(robin);
expect(detection[0].label).toBe("bird");
});

it("detects takes ImageData", async () => {
spyOn(cocoDetector, "detect").and.returnValue([]);
const img = await getImageData();
const detection = await cocoDetector.detect(img);
expect(detection).toEqual([]);
});

it("throws error when a non image is trying to be detected", async () => {
spyOn(cocoDetector, "detect").and.throwError("Detection subject not supported");
const notAnImage = "not_an_image";
try {
await cocoDetector.detect(notAnImage);
Expand All @@ -88,23 +97,27 @@ describe("objectDetector", () => {
describe("objectDetector: yolo", () => {
let yolo;
beforeAll(async () => {
spyOn(ml5, "objectDetector").and.callFake(mockObjectDetector);
yolo = await ml5.objectDetector("yolo", { disableDeprecationNotice: true, ...YOLO_DEFAULTS });
});

it("instantiates the YOLO classifier with defaults", () => {
expect(ml5.objectDetector).toHaveBeenCalled();
expect(yolo.IOUThreshold).toBe(YOLO_DEFAULTS.IOUThreshold);
expect(yolo.classProbThreshold).toBe(YOLO_DEFAULTS.classProbThreshold);
expect(yolo.filterBoxesThreshold).toBe(YOLO_DEFAULTS.filterBoxesThreshold);
expect(yolo.size).toBe(YOLO_DEFAULTS.size);
});

it("detects a robin", async () => {
spyOn(yolo, "detect").and.returnValue([{ label: "bird", confidence: 0.9 }]);
const robin = await getRobin();
const detection = await yolo.detect(robin);
expect(detection[0].label).toBe("bird");
});

it("detects takes ImageData", async () => {
spyOn(yolo, "detect").and.returnValue([]);
const img = await getImageData();
const detection = await yolo.detect(img);
expect(detection).toEqual([]);
Expand Down
6 changes: 5 additions & 1 deletion src/index.js
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@ import soundClassifier from "./SoundClassifier/";
import KNNClassifier from "./KNNClassifier/";
import featureExtractor from "./FeatureExtractor/";
import word2vec from "./Word2vec/";
import YOLO from "./ObjectDetector/YOLO";
import {YOLO} from "./ObjectDetector/YOLO";
import {CocoSsd} from "./ObjectDetector/CocoSsd";
import objectDetector from "./ObjectDetector";
import poseNet from "./PoseNet";
import * as imageUtils from "./utils/imageUtilities";
Expand All @@ -35,6 +36,7 @@ import facemesh from "./Facemesh";
import handpose from './Handpose';
import p5Utils from "./utils/p5Utils";
import communityStatement from "./utils/community";
import * as testingUtils from "./utils/testingUtils";

const withPreload = {
charRNN,
Expand All @@ -52,6 +54,7 @@ const withPreload = {
styleTransfer,
word2vec,
YOLO,
CocoSsd,
objectDetector,
uNet,
sentiment,
Expand All @@ -75,4 +78,5 @@ module.exports = Object.assign({ p5Utils }, preloadRegister(withPreload), {
tfvis,
version,
neuralNetwork,
testingUtils
});
26 changes: 26 additions & 0 deletions src/utils/testingUtils/index.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@

export const getRobin = async () => {
const img = new Image();
img.crossOrigin = "";
img.src = "https://cdn.jsdelivr.net/gh/ml5js/ml5-library@development/assets/bird.jpg";
await new Promise(resolve => {
img.onload = resolve;
});
return img;
}

export const getImageData = async () => {
const arr = new Uint8ClampedArray(20000);

// Iterate through every pixel
for (let i = 0; i < arr.length; i += 4) {
arr[i + 0] = 0; // R value
arr[i + 1] = 190; // G value
arr[i + 2] = 0; // B value
arr[i + 3] = 255; // A value
}

// Initialize a new ImageData object
const img = new ImageData(arr, 200);
return img;
}