2424
2525public class MainActivity extends AppCompatActivity {
2626
27+ //Load the tensorflow inference library
2728 static {
2829 System .loadLibrary ("tensorflow_inference" );
2930 }
3031
32+ //PATH TO OUR MODEL FILE AND NAMES OF THE INPUT AND OUTPUT NODES
3133 private String MODEL_PATH = "file:///android_asset/squeezenet.pb" ;
3234 private String INPUT_NAME = "input_1" ;
3335 private String OUTPUT_NAME = "output_1" ;
3436 private TensorFlowInferenceInterface tf ;
3537
38+ //ARRAY TO HOLD THE PREDICTIONS AND FLOAT VALUES TO HOLD THE IMAGE DATA
3639 float [] PREDICTIONS = new float [1000 ];
3740 private float [] floatValues ;
3841 private int [] INPUT_SIZE = {224 ,224 ,3 };
@@ -46,12 +49,14 @@ protected void onCreate(Bundle savedInstanceState) {
4649 super .onCreate (savedInstanceState );
4750 setContentView (R .layout .activity_main );
4851
49- tf = new TensorFlowInferenceInterface (getAssets (),MODEL_PATH );
50-
5152
5253 Toolbar toolbar = (Toolbar ) findViewById (R .id .toolbar );
5354 setSupportActionBar (toolbar );
5455
56+
57+ //initialize tensorflow with the AssetManager and the Model
58+ tf = new TensorFlowInferenceInterface (getAssets (),MODEL_PATH );
59+
5560 imageView = (ImageView ) findViewById (R .id .imageview );
5661 resultView = (TextView ) findViewById (R .id .results );
5762
@@ -66,6 +71,7 @@ public void onClick(View view) {
6671
6772 try {
6873
74+ //READ THE IMAGE FROM ASSETS FOLDER
6975 InputStream imageStream = getAssets ().open ("testimage.jpg" );
7076
7177 Bitmap bitmap = BitmapFactory .decodeStream (imageStream );
@@ -84,6 +90,7 @@ public void onClick(View view) {
8490 });
8591 }
8692
93+ //FUNCTION TO COMPUTE THE MAXIMUM PREDICTION AND ITS CONFIDENCE
8794 public Object [] argmax (float [] array ){
8895
8996
@@ -110,20 +117,29 @@ public Object[] argmax(float[] array){
110117 public void predict (final Bitmap bitmap ){
111118
112119
120+ //Runs inference in background thread
113121 new AsyncTask <Integer ,Integer ,Integer >(){
114122
115123 @ Override
116124
117125 protected Integer doInBackground (Integer ...params ){
118126
127+ //Resize the image into 224 x 224
119128 Bitmap resized_image = ImageUtils .processBitmap (bitmap ,224 );
129+
130+ //Normalize the pixels
120131 floatValues = ImageUtils .normalizeBitmap (resized_image ,224 ,127.5f ,1.0f );
121132
133+ //Pass input into the tensorflow
122134 tf .feed (INPUT_NAME ,floatValues ,1 ,224 ,224 ,3 );
135+
136+ //compute predictions
123137 tf .run (new String []{OUTPUT_NAME });
124138
139+ //copy the output into the PREDICTIONS array
125140 tf .fetch (OUTPUT_NAME ,PREDICTIONS );
126141
142+ //Obtained highest prediction
127143 Object [] results = argmax (PREDICTIONS );
128144
129145
@@ -135,10 +151,12 @@ protected Integer doInBackground(Integer ...params){
135151
136152 final String conf = String .valueOf (confidence * 100 ).substring (0 ,5 );
137153
154+ //Convert predicted class index into actual label name
138155 final String label = ImageUtils .getLabel (getAssets ().open ("labels.json" ),class_index );
139156
140157
141158
159+ //Display result on UI
142160 runOnUiThread (new Runnable () {
143161 @ Override
144162 public void run () {
0 commit comments