Skip to content

Commit 2a88bc3

Browse files
committed
Add functionality to get just NMS results (without WBC).
1 parent 90bb0af commit 2a88bc3

File tree

1 file changed

+67
-2
lines changed

1 file changed

+67
-2
lines changed

Diff for: predictor.py

+67-2
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from multiprocessing import Pool
2323
import pickle
2424
import pandas as pd
25+
from cuda_functions.nms_3D.pth_nms import nms_gpu as nms_3D
2526

2627

2728
class Predictor:
@@ -272,7 +273,22 @@ def load_saved_predictions(self, apply_wbc=True):
272273
pool.close()
273274
pool.join()
274275
else:
275-
list_of_results_per_patient = list_of_results_per_patient
276+
apply_nms = True
277+
if apply_nms:
278+
self.logger.info('applying NMS to test set predictions with iou = {} and n_ens = {}.'.format(
279+
self.cf.wcs_iou, n_ens))
280+
281+
mp_inputs = [[ii[0], ii[1], self.cf.class_dict, self.cf.wcs_iou, n_ens] for ii in list_of_results_per_patient]
282+
# pool = Pool(processes=6)
283+
# list_of_results_per_patient = pool.map(apply_nms_to_patient, mp_inputs, chunksize=1)
284+
# pool.close()
285+
# pool.join()
286+
list_of_results_per_patient = []
287+
for mp_input in mp_inputs:
288+
list_of_results_per_patient.append(apply_nms_to_patient(mp_input))
289+
# list_of_results_per_patient = apply_nms_to_patient(mp_inputs[0])
290+
else:
291+
list_of_results_per_patient = list_of_results_per_patient #returns all the predictions without NMS or WBC
276292

277293
# merge 2D box predictions to 3D cubes (if model predicts 2D but evaluation is run in 3D)
278294
if self.cf.merge_2D_to_3D_preds:
@@ -501,7 +517,7 @@ def batch_tiling_forward(self, batch):
501517
if self.mode == 'val':
502518
chunk_dicts += [self.net.train_forward(b, is_validation=True)]
503519
else:
504-
# print('check net.test_forward')
520+
# print('check net.test_forward')
505521
chunk_dicts += [self.net.test_forward(b, return_masks=self.cf.return_masks_in_test)]
506522

507523

@@ -520,7 +536,43 @@ def batch_tiling_forward(self, batch):
520536

521537
return results_dict
522538

539+
def apply_nms_to_patient(inputs):
540+
'''
541+
Apply 3D NMS instead of WBC
542+
'''
543+
in_patient_results_list, pid, class_dict, wcs_iou, n_ens = inputs
544+
out_patient_results_list = [[] for _ in range(len(in_patient_results_list))]
545+
for bix, b in enumerate(in_patient_results_list):
546+
547+
for cl in list(class_dict.keys()):
548+
549+
boxes = [(ix, box) for ix, box in enumerate(b) if (box['box_type'] == 'det' and box['box_pred_class_id'] == cl)]
550+
box_coords = np.array([b[1]['box_coords'] for b in boxes])
551+
box_scores = np.array([b[1]['box_score'] for b in boxes])
552+
box_center_factor = np.array([b[1]['box_patch_center_factor'] for b in boxes])
553+
box_n_overlaps = np.array([b[1]['box_n_overlaps'] for b in boxes])
554+
box_patch_id = np.array([b[1]['patch_id'] for b in boxes])
555+
556+
# print(box_coords.shape, box_scores.shape)
557+
# print(box_coords.dtype, box_scores.dtype)
558+
box_coords = torch.from_numpy(box_coords).type(torch.float32).cuda()
559+
box_scores = torch.from_numpy(box_scores).type(torch.float32).cuda()
560+
# print(box_coords.dtype, box_scores.dtype)
561+
if 0 not in box_scores.shape:
562+
keep_scores, keep_coords = non_max_suppression_3D(
563+
torch.cat((box_coords, box_scores[:, None]), dim=1), box_patch_id, wcs_iou, n_ens)
564+
# keep_scores, keep_coords = non_max_suppression_3D(
565+
# np.concatenate((box_coords, box_scores[:, None], box_center_factor[:, None],
566+
# box_n_overlaps[:, None]), axis=1), box_patch_id, wcs_iou, n_ens)
567+
568+
for boxix in range(len(keep_scores)):
569+
out_patient_results_list[bix].append({'box_type': 'det', 'box_coords': keep_coords[boxix],
570+
'box_score': keep_scores[boxix], 'box_pred_class_id': cl})
571+
572+
# add gt boxes back to new output list.
573+
out_patient_results_list[bix].extend([box for box in b if box['box_type'] == 'gt'])
523574

575+
return [out_patient_results_list, pid]
524576

525577
def apply_wbc_to_patient(inputs):
526578
"""
@@ -603,7 +655,20 @@ def merge_2D_to_3D_preds_per_patient(inputs):
603655

604656
return [out_patient_results_list, pid]
605657

658+
def non_max_suppression_3D(dets, box_patch_id, thresh, n_ens):
659+
'''
660+
Non max suppression
661+
'''
662+
# print(dets.shape, box_patch_id.shape)
663+
keep = nms_3D(dets, thresh)
664+
# print(keep.shape)
665+
dets_keep = dets[keep]
666+
# print(dets_keep.shape)
667+
keep_coords = dets_keep[:,:6].cpu().detach().numpy()
668+
keep_scores = dets_keep[:,-1].cpu().detach().numpy()
669+
# print(keep_scores.shape, keep_coords.shape)
606670

671+
return keep_scores, keep_coords
607672

608673
def weighted_box_clustering(dets, box_patch_id, thresh, n_ens):
609674
"""

0 commit comments

Comments
 (0)