Skip to content

Commit 9844ff7

Browse files
authored
Merge pull request #196 from talmolab/divya/fix-sleap-io-v0.3
Update instance creation for sleap-io v0.3.0 compatibility
2 parents ee03f9e + ef43858 commit 9844ff7

File tree

2 files changed

+24
-24
lines changed

2 files changed

+24
-24
lines changed

sleap_nn/inference/predictors.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -869,10 +869,10 @@ def _make_labeled_frames_from_generator(
869869
pred_instances = pred_instances + bbox.squeeze(axis=0)[0, :]
870870
preds[(int(video_idx), int(frame_idx))].append(
871871
sio.PredictedInstance.from_numpy(
872-
points=pred_instances,
872+
points_data=pred_instances,
873873
skeleton=self.skeletons[skeleton_idx],
874874
point_scores=pred_values,
875-
instance_score=instance_score,
875+
score=instance_score,
876876
)
877877
)
878878
for key, inst in preds.items():
@@ -1205,9 +1205,9 @@ def _make_labeled_frames_from_generator(
12051205
):
12061206

12071207
inst = sio.PredictedInstance.from_numpy(
1208-
points=pred_instances,
1208+
points_data=pred_instances,
12091209
skeleton=self.skeletons[skeleton_idx],
1210-
instance_score=np.nansum(pred_values),
1210+
score=np.nansum(pred_values),
12111211
point_scores=pred_values,
12121212
)
12131213
predicted_frames.append(
@@ -1589,9 +1589,9 @@ def _make_labeled_frames_from_generator(
15891589

15901590
predicted_instances.append(
15911591
sio.PredictedInstance.from_numpy(
1592-
points=pts,
1592+
points_data=pts,
15931593
point_scores=confs,
1594-
instance_score=score,
1594+
score=score,
15951595
skeleton=self.skeletons[skeleton_idx],
15961596
)
15971597
)

tests/test_evaluation.py

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ def create_labels_two_match_one_missed_inst(minimal_instance):
8989

9090
# Create user labelled instance.
9191
user_inst_1 = sio.Instance.from_numpy(
92-
points=np.array(
92+
points_data=np.array(
9393
[
9494
[11.4, 13.4],
9595
[13.6, 15.1],
@@ -101,7 +101,7 @@ def create_labels_two_match_one_missed_inst(minimal_instance):
101101

102102
# Create Predicted Instance.
103103
pred_inst_1 = sio.PredictedInstance.from_numpy(
104-
points=np.array(
104+
points_data=np.array(
105105
[
106106
[11.2, 17.4],
107107
[12.8, 15.1],
@@ -110,12 +110,12 @@ def create_labels_two_match_one_missed_inst(minimal_instance):
110110
),
111111
skeleton=skeleton,
112112
point_scores=np.array([0.7, 0.6, 0.8]),
113-
instance_score=0.7,
113+
score=0.7,
114114
)
115115

116116
# create second user instance
117117
user_inst_2 = sio.Instance.from_numpy(
118-
points=np.array(
118+
points_data=np.array(
119119
[
120120
[1.4, 2.9],
121121
[30.6, 9.5],
@@ -126,7 +126,7 @@ def create_labels_two_match_one_missed_inst(minimal_instance):
126126
)
127127

128128
pred_inst_2 = sio.PredictedInstance.from_numpy(
129-
points=np.array(
129+
points_data=np.array(
130130
[
131131
[2.3, 2.2],
132132
[25.6, 10.0],
@@ -135,12 +135,12 @@ def create_labels_two_match_one_missed_inst(minimal_instance):
135135
),
136136
skeleton=skeleton,
137137
point_scores=np.array([0.7, 0.6, 0.6]),
138-
instance_score=0.6,
138+
score=0.6,
139139
)
140140

141141
# create a user instance which shouldn't be matched with other predicted instances
142142
user_inst_3 = sio.Instance.from_numpy(
143-
points=np.array(
143+
points_data=np.array(
144144
[
145145
[55.6, 30.2],
146146
[10.1, 18.5],
@@ -236,7 +236,7 @@ def create_labels_no_match_frame_pairs(minimal_instance):
236236

237237
# Create user labelled instance.
238238
user_inst_1 = sio.Instance.from_numpy(
239-
points=np.array(
239+
points_data=np.array(
240240
[
241241
[11.4, 13.4],
242242
[13.6, 15.1],
@@ -248,7 +248,7 @@ def create_labels_no_match_frame_pairs(minimal_instance):
248248

249249
# Create Predicted Instance.
250250
pred_inst_1 = sio.PredictedInstance.from_numpy(
251-
points=np.array(
251+
points_data=np.array(
252252
[
253253
[11.2, 17.4],
254254
[12.8, 15.1],
@@ -257,7 +257,7 @@ def create_labels_no_match_frame_pairs(minimal_instance):
257257
),
258258
skeleton=skeleton,
259259
point_scores=np.array([0.7, 0.6, 0.8]),
260-
instance_score=0.7,
260+
score=0.7,
261261
)
262262

263263
user_lf = sio.LabeledFrame(
@@ -310,7 +310,7 @@ def create_labels_more_predicted_instances(minimal_instance):
310310

311311
# Create user labelled instance.
312312
user_inst_1 = sio.Instance.from_numpy(
313-
points=np.array(
313+
points_data=np.array(
314314
[
315315
[11.4, 13.4],
316316
[13.6, 15.1],
@@ -322,7 +322,7 @@ def create_labels_more_predicted_instances(minimal_instance):
322322

323323
# create predicted instance
324324
pred_inst_1 = sio.PredictedInstance.from_numpy(
325-
points=np.array(
325+
points_data=np.array(
326326
[
327327
[11.2, 17.4],
328328
[12.8, 13.1],
@@ -331,12 +331,12 @@ def create_labels_more_predicted_instances(minimal_instance):
331331
),
332332
skeleton=skeleton,
333333
point_scores=np.array([0.7, 0.6, 0.8]),
334-
instance_score=0.8,
334+
score=0.8,
335335
)
336336

337337
# create second user instance
338338
user_inst_2 = sio.Instance.from_numpy(
339-
points=np.array(
339+
points_data=np.array(
340340
[
341341
[1.4, 2.9],
342342
[30.6, 9.5],
@@ -348,7 +348,7 @@ def create_labels_more_predicted_instances(minimal_instance):
348348

349349
# create second predicted instance
350350
pred_inst_2 = sio.PredictedInstance.from_numpy(
351-
points=np.array(
351+
points_data=np.array(
352352
[
353353
[1.3, 2.9],
354354
[29.6, 9.2],
@@ -357,12 +357,12 @@ def create_labels_more_predicted_instances(minimal_instance):
357357
),
358358
skeleton=skeleton,
359359
point_scores=np.array([0.7, 0.6, 0.6]),
360-
instance_score=0.7,
360+
score=0.7,
361361
)
362362

363363
# create a predicted instance with nan values
364364
pred_inst_3 = sio.PredictedInstance.from_numpy(
365-
points=np.array(
365+
points_data=np.array(
366366
[
367367
[np.nan, np.nan],
368368
[np.nan, np.nan],
@@ -371,7 +371,7 @@ def create_labels_more_predicted_instances(minimal_instance):
371371
),
372372
skeleton=skeleton,
373373
point_scores=np.array([0.7, 0.6, 0.6]),
374-
instance_score=0.7,
374+
score=0.7,
375375
)
376376

377377
# create labeled frame with the instances

0 commit comments

Comments
 (0)