@@ -20,7 +20,9 @@
 from sleap_nn.config.utils import get_backbone_type_from_cfg, get_model_type_from_cfg
 from sleap_nn.data.identity import generate_class_maps, make_class_vectors
 from sleap_nn.data.instance_centroids import generate_centroids
-from sleap_nn.data.instance_cropping import generate_crops
+from sleap_nn.data.instance_cropping import generate_crops, find_instance_crop_size
+
+
 from sleap_nn.data.normalization import (
     apply_normalization,
     convert_to_grayscale,
@@ -34,7 +36,7 @@
 )
 from sleap_nn.data.confidence_maps import generate_confmaps, generate_multiconfmaps
 from sleap_nn.data.edge_maps import generate_pafs
-from sleap_nn.data.instance_cropping import make_centered_bboxes
+from sleap_nn.data.instance_cropping import make_centered_bboxes, get_cropped_img
 from sleap_nn.training.utils import is_distributed_initialized
 from sleap_nn.config.get_config import get_aug_config
 
@@ -738,6 +740,7 @@ def __init__(
         self.confmap_head_config = confmap_head_config
         self.instance_idx_list = self._get_instance_idx_list(labels)
         self.cache_lf = [None, None]
+        # self.max_crop_size = find_instance_crop_size(self.labels, maximum_stride=self.max_stride)
 
     def _get_instance_idx_list(self, labels: List[sio.Labels]) -> List[Tuple[int]]:
         """Return list of tuples with indices of labelled frames and instances."""
@@ -840,24 +843,41 @@ def __getitem__(self, index) -> Dict:
             scale=self.scale,
         )
 
-        # get the centroids based on the anchor idx
-        centroids = generate_centroids(instances, anchor_ind=self.anchor_ind)
+        instance = instances[0]
 
-        instance, centroid = instances[0], centroids[0]  # (n_samples=1)
+        sample = {}
 
-        crop_size = np.array([self.crop_size, self.crop_size]) * np.sqrt(
-            2
-        )  # crop extra for rotation augmentation
-        crop_size = crop_size.astype(np.int32).tolist()
+        # Get the head index
+        head_idx = self.anchor_ind
 
-        sample = generate_crops(image, instance, centroid, crop_size)
+        # Determine if the instance has enough valid points
+        valid_points = instance[~torch.isnan(instance).any(dim=1)]
+        if valid_points.shape[0] < 3:
+            return self.__getitem__((index + 1) % len(self))  # safely retry next sample
+
+        # crop image
+        sample_image, sample_instance, src_pts, dst_pts, rotated = get_cropped_img(
+            image[0], instance, head_idx
+        )
+        sample_image, sample_instance = sample_image.unsqueeze(
+            0
+        ), sample_instance.unsqueeze(0)
+
+        sample["instance_image"] = sample_image
+        sample["instance"] = sample_instance
+        sample["src_pts"] = src_pts.unsqueeze(0)
+        sample["dst_pts"] = dst_pts.unsqueeze(0)
+        sample["rotated"] = torch.tensor([rotated], dtype=torch.bool)
 
         sample["frame_idx"] = torch.tensor(lf_frame_idx, dtype=torch.int32)
         sample["video_idx"] = torch.tensor(video_idx, dtype=torch.int32)
         sample["num_instances"] = num_instances
         sample["orig_size"] = torch.Tensor([orig_img_height, orig_img_width]).unsqueeze(
             0
         )
+        height, width = sample_image.shape[-2:]
+        sample["height"] = [height]
+        sample["width"] = [width]
         sample["eff_scale"] = torch.tensor(eff_scale, dtype=torch.float32)
 
         # apply augmentation
@@ -883,27 +903,32 @@ def __getitem__(self, index) -> Dict:
             )
 
         # re-crop to original crop size
-        sample["instance_bbox"] = torch.unsqueeze(
-            make_centered_bboxes(sample["centroid"][0], self.crop_size, self.crop_size),
-            0,
-        )  # (n_samples=1, 4, 2)
-
-        sample["instance_image"] = crop_and_resize(
+        # sample["instance_bbox"] = torch.unsqueeze(
+        #     make_centered_bboxes(sample["centroid"][0], self.crop_size, self.crop_size),
+        #     0,
+        # )  # (n_samples=1, 4, 2)
+
+        # sample["instance_image"] = crop_and_resize(
+        #     sample["instance_image"],
+        #     boxes=sample["instance_bbox"],
+        #     size=(self.crop_size, self.crop_size),
+        # )
+        # size matcher
+        sample_image, eff_scale = apply_sizematcher(
             sample["instance_image"],
-            boxes=sample["instance_bbox"],
-            size=(self.crop_size, self.crop_size),
+            max_height=self.crop_size,
+            max_width=self.crop_size,
         )
-        point = sample["instance_bbox"][0][0]
-        center_instance = sample["instance"] - point
-        centered_centroid = sample["centroid"] - point
-
-        sample["instance"] = center_instance  # (n_samples=1, n_nodes, 2)
-        sample["centroid"] = centered_centroid  # (n_samples=1, 2)
+        # point = sample["instance_bbox"][0][0]
+        # center_instance = sample["instance"] - point
+        # centered_centroid = sample["centroid"] - point
 
-        # Pad the image (if needed) according max stride
-        sample["instance_image"] = apply_pad_to_stride(
-            sample["instance_image"], max_stride=self.max_stride
-        )
+        # sample["instance"] = center_instance  # (n_samples=1, n_nodes, 2)
+        # sample["centroid"] = centered_centroid  # (n_samples=1, 2)
+        sample_instance = sample["instance"] * eff_scale
+        sample["instance"] = sample_instance
+        sample["instance_image"] = sample_image
+        sample["scale"] = torch.tensor(eff_scale, dtype=torch.float32).unsqueeze(dim=0)
 
         img_hw = sample["instance_image"].shape[-2:]
 
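For orientation, the hunk above swaps the fixed bounding-box re-crop (make_centered_bboxes plus crop_and_resize) for size matching: the already-cropped instance image is scaled to fit within (crop_size, crop_size) and the keypoints are multiplied by the same effective scale. A minimal, self-contained sketch of that idea follows; it only assumes apply_sizematcher behaves roughly like this (scale to fit, pad the remainder, return the scale factor), and the helper name size_match is illustrative rather than the library's API.

import torch
import torch.nn.functional as F

def size_match(image: torch.Tensor, points: torch.Tensor, max_hw: int):
    """Scale `image` (1, C, H, W) to fit in (max_hw, max_hw), pad the rest,
    and rescale `points` (1, n_nodes, 2) by the same effective scale."""
    h, w = image.shape[-2:]
    eff_scale = min(max_hw / h, max_hw / w)
    new_h, new_w = int(round(h * eff_scale)), int(round(w * eff_scale))
    image = F.interpolate(image, size=(new_h, new_w), mode="bilinear", align_corners=False)
    # pad right/bottom so every crop in a batch ends up the same spatial size
    image = F.pad(image, (0, max_hw - new_w, 0, max_hw - new_h))
    return image, points * eff_scale, eff_scale

crop = torch.rand(1, 1, 300, 180)              # stand-in instance crop
kpts = torch.tensor([[[20.0, 35.0], [100.0, 90.0]]])
crop, kpts, s = size_match(crop, kpts, 256)
print(crop.shape, s)                           # torch.Size([1, 1, 256, 256]) ~0.853
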
@@ -1831,7 +1856,12 @@ def get_train_val_datasets(
         ),
         scale=config.data_config.preprocessing.scale,
         apply_aug=config.data_config.use_augmentations_train,
-        crop_size=config.data_config.preprocessing.crop_size,
+        crop_size=find_instance_crop_size(
+            train_labels,
+            maximum_stride=config.model_config.backbone_config[f"{backbone_type}"][
+                "max_stride"
+            ],
+        ),
         max_hw=(
             config.data_config.preprocessing.max_height,
             config.data_config.preprocessing.max_width,
@@ -1855,7 +1885,7 @@ def get_train_val_datasets(
         geometric_aug=None,
         scale=config.data_config.preprocessing.scale,
         apply_aug=False,
-        crop_size=config.data_config.preprocessing.crop_size,
+        crop_size=train_dataset.crop_size,
         max_hw=(
             config.data_config.preprocessing.max_height,
             config.data_config.preprocessing.max_width,
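The last two hunks replace the configured preprocessing.crop_size with a value derived from the training labels, and the validation dataset simply reuses the train dataset's crop_size. Below is a rough, hypothetical sketch of the kind of value a heuristic like find_instance_crop_size produces; the real function takes the labels and maximum_stride (as used above) and may expose further options, so this stand-in only illustrates the round-up-to-stride idea.

import math
from typing import List, Tuple

def crop_size_from_boxes(box_sizes: List[Tuple[int, int]], maximum_stride: int = 16, padding: int = 0) -> int:
    """Illustrative stand-in: largest instance bounding-box side plus padding,
    rounded up to a multiple of the backbone's maximum stride."""
    max_side = max(max(h, w) for h, w in box_sizes)
    return int(math.ceil((max_side + padding) / maximum_stride) * maximum_stride)

# (height, width) of each labeled instance's bounding box, in pixels
print(crop_size_from_boxes([(83, 61), (97, 74), (120, 88)], maximum_stride=32))  # 128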