Skip to content

Commit b3971a8

Browse files
committed
[Android] class for ml-service
Add new class and interface for ml-service api. Signed-off-by: Jaeyun Jung <[email protected]>
1 parent 7f8530c commit b3971a8

File tree

9 files changed

+1088
-0
lines changed

9 files changed

+1088
-0
lines changed

java/android/nnstreamer/src/androidTest/assets/README.txt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,9 @@ If you want to do unit tests, put the following files in this directory.
33
$ tree .
44
.
55
└── nnstreamer
6+
├── config
7+
│   ├── config_pipeline_imgclf.conf
8+
│   └── config_single_imgclf.conf
69
├── pytorch_data
710
│   ├── mobilenetv2-quant_core-nnapi.pt
811
│   └── orange_float.raw

java/android/nnstreamer/src/androidTest/java/org/nnsuite/nnstreamer/APITestCommon.java

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,20 @@ public static Context getContext() {
119119
return InstrumentationRegistry.getTargetContext();
120120
}
121121

122+
/**
123+
* Gets the path string of configurations.
124+
*/
125+
public static String getConfigPath() {
126+
String root = getRootDirectory();
127+
File config = new File(root + "/nnstreamer/config");
128+
129+
if (!config.exists()) {
130+
fail();
131+
}
132+
133+
return config.getAbsolutePath();
134+
}
135+
122136
/**
123137
* Gets the File object of tensorflow-lite model.
124138
* Note that, to invoke model in the storage, the permission READ_EXTERNAL_STORAGE is required.
Lines changed: 346 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,346 @@
1+
package org.nnsuite.nnstreamer;
2+
3+
import android.support.test.runner.AndroidJUnit4;
4+
5+
import org.junit.Before;
6+
import org.junit.Test;
7+
import org.junit.runner.RunWith;
8+
9+
import java.nio.ByteBuffer;
10+
11+
import static org.junit.Assert.*;
12+
13+
/**
14+
* Testcases for MLService.
15+
*/
16+
@RunWith(AndroidJUnit4.class)
17+
public class APITestMLService {
18+
private int mReceived = 0;
19+
private boolean mInvalidState = false;
20+
private boolean mIsPipeline = false;
21+
22+
/**
23+
* The event callback for image classification model.
24+
*/
25+
private MLService.NewEventCallback mEventCb = new MLService.NewEventCallback() {
26+
@Override
27+
public void onNewDataReceived(String name, TensorsData data) {
28+
if (mIsPipeline) {
29+
if (name == null || !name.equals("result_clf")) {
30+
mInvalidState = true;
31+
return;
32+
}
33+
}
34+
35+
if (data == null || data.getTensorsCount() != 1) {
36+
mInvalidState = true;
37+
return;
38+
}
39+
40+
ByteBuffer buffer = data.getTensorData(0);
41+
int labelIndex = APITestCommon.getMaxScore(buffer);
42+
43+
/* check label index (orange) */
44+
if (labelIndex != 951) {
45+
mInvalidState = true;
46+
}
47+
48+
mReceived++;
49+
}
50+
};
51+
52+
@Before
53+
public void setUp() {
54+
APITestCommon.initNNStreamer();
55+
}
56+
57+
@Test
58+
public void testNullConfig_n() {
59+
try {
60+
new MLService(null, mEventCb);
61+
fail();
62+
} catch (Exception e) {
63+
/* expected */
64+
}
65+
}
66+
67+
@Test
68+
public void testEmptyConfig_n() {
69+
try {
70+
new MLService("", mEventCb);
71+
fail();
72+
} catch (Exception e) {
73+
/* expected */
74+
}
75+
}
76+
77+
@Test
78+
public void testInvalidConfig_n() {
79+
try {
80+
String config = APITestCommon.getConfigPath() + "/config_invalid.conf";
81+
82+
new MLService(config, mEventCb);
83+
fail();
84+
} catch (Exception e) {
85+
/* expected */
86+
}
87+
}
88+
89+
@Test
90+
public void testNullCallback_n() {
91+
try {
92+
String config = APITestCommon.getConfigPath() + "/config_single_imgclf.conf";
93+
94+
new MLService(config, null);
95+
fail();
96+
} catch (Exception e) {
97+
/* expected */
98+
}
99+
}
100+
101+
@Test
102+
public void testInputNullData_n() {
103+
try {
104+
String config = APITestCommon.getConfigPath() + "/config_pipeline_imgclf.conf";
105+
MLService service = new MLService(config, mEventCb);
106+
107+
service.start();
108+
109+
service.inputData("input_img", null);
110+
fail();
111+
} catch (Exception e) {
112+
/* expected */
113+
}
114+
}
115+
116+
@Test
117+
public void testInputInvalidNode_n() {
118+
try {
119+
String config = APITestCommon.getConfigPath() + "/config_pipeline_imgclf.conf";
120+
MLService service = new MLService(config, mEventCb);
121+
122+
service.start();
123+
124+
service.inputData("invalid_node", APITestCommon.readRawImageData());
125+
fail();
126+
} catch (Exception e) {
127+
/* expected */
128+
}
129+
}
130+
131+
@Test
132+
public void testGetInputInfo() {
133+
try {
134+
String config = APITestCommon.getConfigPath() + "/config_pipeline_imgclf.conf";
135+
MLService service = new MLService(config, mEventCb);
136+
137+
TensorsInfo info = service.getInputInformation("input_img");
138+
139+
assertEquals(1, info.getTensorsCount());
140+
assertEquals(NNStreamer.TensorType.UINT8, info.getTensorType(0));
141+
assertArrayEquals(new int[]{3,224,224,1}, info.getTensorDimension(0));
142+
143+
service.close();
144+
} catch (Exception e) {
145+
fail();
146+
}
147+
}
148+
149+
@Test
150+
public void testGetInputInfoInvalidNode_n() {
151+
try {
152+
String config = APITestCommon.getConfigPath() + "/config_pipeline_imgclf.conf";
153+
MLService service = new MLService(config, mEventCb);
154+
155+
service.getInputInformation("invalid_node");
156+
fail();
157+
} catch (Exception e) {
158+
/* expected */
159+
}
160+
}
161+
162+
@Test
163+
public void testGetOutputInfo() {
164+
try {
165+
String config = APITestCommon.getConfigPath() + "/config_pipeline_imgclf.conf";
166+
MLService service = new MLService(config, mEventCb);
167+
168+
TensorsInfo info = service.getOutputInformation("result_clf");
169+
170+
assertEquals(1, info.getTensorsCount());
171+
assertEquals(NNStreamer.TensorType.UINT8, info.getTensorType(0));
172+
assertArrayEquals(new int[]{1001,1}, info.getTensorDimension(0));
173+
174+
service.close();
175+
} catch (Exception e) {
176+
fail();
177+
}
178+
}
179+
180+
@Test
181+
public void testGetOutputInfoInvalidNode_n() {
182+
try {
183+
String config = APITestCommon.getConfigPath() + "/config_pipeline_imgclf.conf";
184+
MLService service = new MLService(config, mEventCb);
185+
186+
service.getOutputInformation("invalid_node");
187+
fail();
188+
} catch (Exception e) {
189+
/* expected */
190+
}
191+
}
192+
193+
@Test
194+
public void testSetInfoNullName_n() {
195+
try {
196+
String config = APITestCommon.getConfigPath() + "/config_pipeline_imgclf.conf";
197+
MLService service = new MLService(config, mEventCb);
198+
199+
service.setInformation(null, "test_value");
200+
fail();
201+
} catch (Exception e) {
202+
/* expected */
203+
}
204+
}
205+
206+
@Test
207+
public void testSetInfoEmptyName_n() {
208+
try {
209+
String config = APITestCommon.getConfigPath() + "/config_pipeline_imgclf.conf";
210+
MLService service = new MLService(config, mEventCb);
211+
212+
service.setInformation("", "test_value");
213+
fail();
214+
} catch (Exception e) {
215+
/* expected */
216+
}
217+
}
218+
219+
@Test
220+
public void testSetInfoNullValue_n() {
221+
try {
222+
String config = APITestCommon.getConfigPath() + "/config_pipeline_imgclf.conf";
223+
MLService service = new MLService(config, mEventCb);
224+
225+
service.setInformation("test_info", null);
226+
fail();
227+
} catch (Exception e) {
228+
/* expected */
229+
}
230+
}
231+
232+
@Test
233+
public void testSetInfoEmptyValue_n() {
234+
try {
235+
String config = APITestCommon.getConfigPath() + "/config_pipeline_imgclf.conf";
236+
MLService service = new MLService(config, mEventCb);
237+
238+
service.setInformation("test_info", "");
239+
fail();
240+
} catch (Exception e) {
241+
/* expected */
242+
}
243+
}
244+
245+
@Test
246+
public void testGetInfoNullName_n() {
247+
try {
248+
String config = APITestCommon.getConfigPath() + "/config_pipeline_imgclf.conf";
249+
MLService service = new MLService(config, mEventCb);
250+
251+
service.getInformation(null);
252+
fail();
253+
} catch (Exception e) {
254+
/* expected */
255+
}
256+
}
257+
258+
@Test
259+
public void testGetInfoEmptyName_n() {
260+
try {
261+
String config = APITestCommon.getConfigPath() + "/config_pipeline_imgclf.conf";
262+
MLService service = new MLService(config, mEventCb);
263+
264+
service.getInformation("");
265+
fail();
266+
} catch (Exception e) {
267+
/* expected */
268+
}
269+
}
270+
271+
@Test
272+
public void testGetInfoInvalidName_n() {
273+
try {
274+
String config = APITestCommon.getConfigPath() + "/config_pipeline_imgclf.conf";
275+
MLService service = new MLService(config, mEventCb);
276+
277+
service.getInformation("invalid_name");
278+
fail();
279+
} catch (Exception e) {
280+
/* expected */
281+
}
282+
}
283+
284+
@Test
285+
public void testGetInfo() {
286+
try {
287+
String config = APITestCommon.getConfigPath() + "/config_single_imgclf.conf";
288+
MLService service = new MLService(config, mEventCb);
289+
290+
service.setInformation("test_info", "test_value");
291+
292+
assertEquals("0.5", service.getInformation("threshold"));
293+
assertEquals("test_value", service.getInformation("test_info"));
294+
295+
service.close();
296+
} catch (Exception e) {
297+
fail();
298+
}
299+
}
300+
301+
/**
302+
* Runs image classification with configuration.
303+
*/
304+
private void runImageClassification(String config, boolean isPipeline) {
305+
mIsPipeline = isPipeline;
306+
307+
try {
308+
MLService service = new MLService(config, mEventCb);
309+
310+
service.start();
311+
312+
/* push input buffer */
313+
TensorsData input = APITestCommon.readRawImageData();
314+
315+
for (int i = 0; i < 5; i++) {
316+
service.inputData("input_img", input);
317+
Thread.sleep(100);
318+
}
319+
320+
/* sleep 200 to invoke */
321+
Thread.sleep(200);
322+
323+
/* check received data from output node */
324+
assertFalse(mInvalidState);
325+
assertTrue(mReceived > 0);
326+
327+
service.close();
328+
} catch (Exception e) {
329+
fail();
330+
}
331+
}
332+
333+
@Test
334+
public void testRunPipeline() {
335+
String config = APITestCommon.getConfigPath() + "/config_pipeline_imgclf.conf";
336+
337+
runImageClassification(config, true);
338+
}
339+
340+
@Test
341+
public void testRunSingleShot() {
342+
String config = APITestCommon.getConfigPath() + "/config_single_imgclf.conf";
343+
344+
runImageClassification(config, false);
345+
}
346+
}

0 commit comments

Comments
 (0)