Skip to content

Commit e771701

Browse files
authored
Fix empty instance handling (#364)
This PR improves empty-instance handling in instance cropping method and the `CenteredInstanceDataset` class. Previously, if an instance contained only NaN keypoints, `find_instance_crop_size` would trigger a “NaN values encountered” warning when attempting to compute its bounding box. We now fixthis by only computing crop sizes for non-empty instances (`instance.is_empty` is False). We also remove redundant empty-instance filtering in `__getitem__()` for centered-instance models, since this is already enforced when creating `instance_idx_list`.
1 parent d946ec3 commit e771701

File tree

2 files changed

+16
-13
lines changed

2 files changed

+16
-13
lines changed

sleap_nn/data/custom_datasets.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -799,8 +799,9 @@ def __getitem__(self, index) -> Dict:
799799

800800
instances = []
801801
for inst in instances_list:
802-
if not inst.is_empty:
803-
instances.append(inst.numpy())
802+
instances.append(
803+
inst.numpy()
804+
) # no need to filter empty instances; handled while creating instance_idx_list
804805
instances = np.stack(instances, axis=0)
805806

806807
# Add singleton time dimension for single frames.
@@ -1046,8 +1047,9 @@ def __getitem__(self, index) -> Dict:
10461047

10471048
instances = []
10481049
for inst in instances_list:
1049-
if not inst.is_empty:
1050-
instances.append(inst.numpy())
1050+
instances.append(
1051+
inst.numpy()
1052+
) # no need to filter empty instance (handled while creating instance_idx)
10511053
instances = np.stack(instances, axis=0)
10521054

10531055
# Add singleton time dimension for single frames.

sleap_nn/data/instance_cropping.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -44,15 +44,16 @@ def find_instance_crop_size(
4444
max_length = 0.0
4545
for lf in labels:
4646
for inst in lf.instances:
47-
pts = inst.numpy()
48-
pts *= input_scaling
49-
diff_x = np.nanmax(pts[:, 0]) - np.nanmin(pts[:, 0])
50-
diff_x = 0 if np.isnan(diff_x) else diff_x
51-
max_length = np.maximum(max_length, diff_x)
52-
diff_y = np.nanmax(pts[:, 1]) - np.nanmin(pts[:, 1])
53-
diff_y = 0 if np.isnan(diff_y) else diff_y
54-
max_length = np.maximum(max_length, diff_y)
55-
max_length = np.maximum(max_length, min_crop_size_no_pad)
47+
if not inst.is_empty: # only if at least one point is not nan
48+
pts = inst.numpy()
49+
pts *= input_scaling
50+
diff_x = np.nanmax(pts[:, 0]) - np.nanmin(pts[:, 0])
51+
diff_x = 0 if np.isnan(diff_x) else diff_x
52+
max_length = np.maximum(max_length, diff_x)
53+
diff_y = np.nanmax(pts[:, 1]) - np.nanmin(pts[:, 1])
54+
diff_y = 0 if np.isnan(diff_y) else diff_y
55+
max_length = np.maximum(max_length, diff_y)
56+
max_length = np.maximum(max_length, min_crop_size_no_pad)
5657

5758
max_length += float(padding)
5859
crop_size = math.ceil(max_length / float(maximum_stride)) * maximum_stride

0 commit comments

Comments
 (0)