diff --git a/digits/model/images/classification/views.py b/digits/model/images/classification/views.py index 9d5d12012..d843d8912 100644 --- a/digits/model/images/classification/views.py +++ b/digits/model/images/classification/views.py @@ -529,6 +529,11 @@ def classify_many(): 'Unable to classify any image from the file') scores = last_output_data + # force correct 2D shape squeezing scores + for i in reversed(range(2, len(scores.shape))): + if scores.shape[i] == 1: + scores = np.squeeze(scores, axis=(i,)) + # take top 5 indices = (-scores).argsort()[:, :5] @@ -665,6 +670,11 @@ def top_n(): if scores is None: raise RuntimeError('An error occurred while processing the images') + # force correct 2D shape squeezing scores + for i in reversed(range(2, len(scores.shape))): + if scores.shape[i] == 1: + scores = np.squeeze(scores, axis=(i,)) + labels = model_job.train_task().get_labels() images = inputs['data'] indices = (-scores).argsort(axis=0)[:top_n]