22
22
from multiprocessing import Pool
23
23
import pickle
24
24
import pandas as pd
25
+ from cuda_functions .nms_3D .pth_nms import nms_gpu as nms_3D
25
26
26
27
27
28
class Predictor :
@@ -272,7 +273,22 @@ def load_saved_predictions(self, apply_wbc=True):
272
273
pool .close ()
273
274
pool .join ()
274
275
else :
275
- list_of_results_per_patient = list_of_results_per_patient
276
+ apply_nms = True
277
+ if apply_nms :
278
+ self .logger .info ('applying NMS to test set predictions with iou = {} and n_ens = {}.' .format (
279
+ self .cf .wcs_iou , n_ens ))
280
+
281
+ mp_inputs = [[ii [0 ], ii [1 ], self .cf .class_dict , self .cf .wcs_iou , n_ens ] for ii in list_of_results_per_patient ]
282
+ # pool = Pool(processes=6)
283
+ # list_of_results_per_patient = pool.map(apply_nms_to_patient, mp_inputs, chunksize=1)
284
+ # pool.close()
285
+ # pool.join()
286
+ list_of_results_per_patient = []
287
+ for mp_input in mp_inputs :
288
+ list_of_results_per_patient .append (apply_nms_to_patient (mp_input ))
289
+ # list_of_results_per_patient = apply_nms_to_patient(mp_inputs[0])
290
+ else :
291
+ list_of_results_per_patient = list_of_results_per_patient #returns all the predictions without NMS or WBC
276
292
277
293
# merge 2D box predictions to 3D cubes (if model predicts 2D but evaluation is run in 3D)
278
294
if self .cf .merge_2D_to_3D_preds :
@@ -501,7 +517,7 @@ def batch_tiling_forward(self, batch):
501
517
if self .mode == 'val' :
502
518
chunk_dicts += [self .net .train_forward (b , is_validation = True )]
503
519
else :
504
- # print('check net.test_forward')
520
+ # print('check net.test_forward')
505
521
chunk_dicts += [self .net .test_forward (b , return_masks = self .cf .return_masks_in_test )]
506
522
507
523
@@ -520,7 +536,43 @@ def batch_tiling_forward(self, batch):
520
536
521
537
return results_dict
522
538
539
+ def apply_nms_to_patient (inputs ):
540
+ '''
541
+ Apply 3D NMS instead of WBC
542
+ '''
543
+ in_patient_results_list , pid , class_dict , wcs_iou , n_ens = inputs
544
+ out_patient_results_list = [[] for _ in range (len (in_patient_results_list ))]
545
+ for bix , b in enumerate (in_patient_results_list ):
546
+
547
+ for cl in list (class_dict .keys ()):
548
+
549
+ boxes = [(ix , box ) for ix , box in enumerate (b ) if (box ['box_type' ] == 'det' and box ['box_pred_class_id' ] == cl )]
550
+ box_coords = np .array ([b [1 ]['box_coords' ] for b in boxes ])
551
+ box_scores = np .array ([b [1 ]['box_score' ] for b in boxes ])
552
+ box_center_factor = np .array ([b [1 ]['box_patch_center_factor' ] for b in boxes ])
553
+ box_n_overlaps = np .array ([b [1 ]['box_n_overlaps' ] for b in boxes ])
554
+ box_patch_id = np .array ([b [1 ]['patch_id' ] for b in boxes ])
555
+
556
+ # print(box_coords.shape, box_scores.shape)
557
+ # print(box_coords.dtype, box_scores.dtype)
558
+ box_coords = torch .from_numpy (box_coords ).type (torch .float32 ).cuda ()
559
+ box_scores = torch .from_numpy (box_scores ).type (torch .float32 ).cuda ()
560
+ # print(box_coords.dtype, box_scores.dtype)
561
+ if 0 not in box_scores .shape :
562
+ keep_scores , keep_coords = non_max_suppression_3D (
563
+ torch .cat ((box_coords , box_scores [:, None ]), dim = 1 ), box_patch_id , wcs_iou , n_ens )
564
+ # keep_scores, keep_coords = non_max_suppression_3D(
565
+ # np.concatenate((box_coords, box_scores[:, None], box_center_factor[:, None],
566
+ # box_n_overlaps[:, None]), axis=1), box_patch_id, wcs_iou, n_ens)
567
+
568
+ for boxix in range (len (keep_scores )):
569
+ out_patient_results_list [bix ].append ({'box_type' : 'det' , 'box_coords' : keep_coords [boxix ],
570
+ 'box_score' : keep_scores [boxix ], 'box_pred_class_id' : cl })
571
+
572
+ # add gt boxes back to new output list.
573
+ out_patient_results_list [bix ].extend ([box for box in b if box ['box_type' ] == 'gt' ])
523
574
575
+ return [out_patient_results_list , pid ]
524
576
525
577
def apply_wbc_to_patient (inputs ):
526
578
"""
@@ -603,7 +655,20 @@ def merge_2D_to_3D_preds_per_patient(inputs):
603
655
604
656
return [out_patient_results_list , pid ]
605
657
658
+ def non_max_suppression_3D (dets , box_patch_id , thresh , n_ens ):
659
+ '''
660
+ Non max suppression
661
+ '''
662
+ # print(dets.shape, box_patch_id.shape)
663
+ keep = nms_3D (dets , thresh )
664
+ # print(keep.shape)
665
+ dets_keep = dets [keep ]
666
+ # print(dets_keep.shape)
667
+ keep_coords = dets_keep [:,:6 ].cpu ().detach ().numpy ()
668
+ keep_scores = dets_keep [:,- 1 ].cpu ().detach ().numpy ()
669
+ # print(keep_scores.shape, keep_coords.shape)
606
670
671
+ return keep_scores , keep_coords
607
672
608
673
def weighted_box_clustering (dets , box_patch_id , thresh , n_ens ):
609
674
"""
0 commit comments