Skip to content

Commit 5d9f04a

Browse files
authored
Fix instance color (#2434)
* Fix instance color * Fix cli cmd * Minor bug fix
1 parent b2f782e commit 5d9f04a

File tree

5 files changed

+38
-18
lines changed

5 files changed

+38
-18
lines changed

sleap/gui/color.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from sleap.util import get_config_file
2020
from sleap_io.model.instance import Instance, Track
2121
from sleap_io.model.skeleton import Node, Edge
22-
from sleap_io import Labels
22+
from sleap_io import Labels, LabeledFrame
2323
from sleap_io.model.skeleton import Skeleton
2424
from sleap.prefs import prefs
2525
from sleap.sleap_io_adaptors.skeleton_utils import node_to_index, edge_to_index
@@ -248,6 +248,7 @@ def get_item_color(
248248
item: Any,
249249
parent_instance: Optional[Instance] = None,
250250
parent_skeleton: Optional[Skeleton] = None,
251+
frame: Optional[LabeledFrame] = None,
251252
) -> ColorTupleType:
252253
"""Gets (r, g, b) tuple of color to use for drawing item."""
253254

@@ -276,7 +277,7 @@ def get_item_color(
276277

277278
if track is None and parent_instance:
278279
# Get an index for items without track
279-
track = self.get_pseudo_track_index(parent_instance)
280+
track = self.get_pseudo_track_index(parent_instance, frame=frame)
280281

281282
return self.get_track_color(track=track)
282283

sleap/gui/learning/runners.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -500,11 +500,13 @@ def write_pipeline_files(
500500
# Add a line to the script for training this model
501501
train_script += (
502502
f"sleap-nn-train --config-name {new_cfg_filename} "
503-
f"--config-dir {''} "
503+
f"--config-dir . "
504504
f"trainer_config.ckpt_dir={Path(ckpt_path).parent.as_posix()} "
505-
f"trainer_config.run_name={Path(ckpt_path).name}"
506-
f"trainer_config.zmq.controller_port={cfg_info.config.trainer_config.zmq.controller_port}"
507-
f"trainer_config.zmq.publish_port={cfg_info.config.trainer_config.zmq.publish_port}"
505+
f"trainer_config.run_name={Path(ckpt_path).name} "
506+
f"trainer_config.zmq.controller_port="
507+
f"{cfg_info.config.trainer_config.zmq.controller_port} "
508+
f"trainer_config.zmq.publish_port="
509+
f"{cfg_info.config.trainer_config.zmq.publish_port} "
508510
"\n"
509511
)
510512

sleap/gui/overlays/instance.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ def add_to_scene(self, video, frame_idx):
4545
for instance in instances:
4646
self.player.addInstance(
4747
instance=instance,
48+
frame=lf,
4849
markerRadius=self.state.get("marker size", 4),
4950
nodeLabelSize=self.state.get("node label size", 12),
5051
show_non_visible=self.state.get("show non-visible nodes", default=True),

sleap/gui/widgets/video.py

Lines changed: 19 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@
6060
from sleap_io.model.instance import Instance, PredictedInstance
6161
from sleap.sleap_io_adaptors.instance_utils import fill_missing, node_points
6262
from sleap.sleap_io_adaptors.video_utils import get_last_frame_idx
63-
from sleap_io import Video
63+
from sleap_io import Video, LabeledFrame
6464
from sleap.prefs import prefs
6565
from sleap_io import Node
6666

@@ -484,7 +484,7 @@ def scene(self):
484484
"""Returns `QGraphicsScene` for viewer."""
485485
return self.view.scene
486486

487-
def addInstance(self, instance, **kwargs):
487+
def addInstance(self, instance, frame: Optional[LabeledFrame] = None, **kwargs):
488488
"""Add a skeleton instance to the video.
489489
490490
Args:
@@ -494,7 +494,7 @@ def addInstance(self, instance, **kwargs):
494494
"""
495495
# Check if instance is an Instance (or subclass of Instance)
496496
if issubclass(type(instance), Instance):
497-
instance = QtInstance(instance=instance, player=self, **kwargs)
497+
instance = QtInstance(instance=instance, frame=frame, player=self, **kwargs)
498498
if type(instance) != QtInstance:
499499
return
500500
if instance.instance.n_visible > 0 or not isinstance(
@@ -1415,6 +1415,7 @@ def __init__(
14151415
predicted=False,
14161416
show_non_visible=True,
14171417
callbacks=None,
1418+
frame: Optional[LabeledFrame] = None,
14181419
*args,
14191420
**kwargs,
14201421
):
@@ -1425,7 +1426,7 @@ def __init__(
14251426
self.radius = radius
14261427
self.color_manager = self.player.color_manager
14271428
self.color = self.color_manager.get_item_color(
1428-
self.node, self._parent_instance.instance
1429+
self.node, self._parent_instance.instance, frame=frame
14291430
)
14301431
self.edges = []
14311432
self.name = node.name
@@ -1676,6 +1677,7 @@ def __init__(
16761677
src: QtNode,
16771678
dst: QtNode,
16781679
show_non_visible: bool = True,
1680+
frame: Optional[LabeledFrame] = None,
16791681
*args,
16801682
**kwargs,
16811683
):
@@ -1702,7 +1704,9 @@ def __init__(
17021704
)
17031705

17041706
edge_pair = (src.node, dst.node)
1705-
color = player.color_manager.get_item_color(edge_pair, parent.instance)
1707+
color = player.color_manager.get_item_color(
1708+
edge_pair, parent.instance, frame=frame
1709+
)
17061710
pen_width = player.color_manager.get_item_pen_width(edge_pair, parent.instance)
17071711
pen = QPen(QColor(*color), pen_width)
17081712
pen.setCosmetic(True)
@@ -1834,6 +1838,7 @@ def __init__(
18341838
markerRadius=4,
18351839
nodeLabelSize=12,
18361840
show_non_visible=True,
1841+
frame: Optional[LabeledFrame] = None,
18371842
*args,
18381843
**kwargs,
18391844
):
@@ -1844,7 +1849,7 @@ def __init__(
18441849
self.predicted = hasattr(instance, "score")
18451850

18461851
color_manager = self.player.color_manager
1847-
color = color_manager.get_item_color(self.instance)
1852+
color = color_manager.get_item_color(self.instance, frame=frame)
18481853

18491854
self.show_non_visible = show_non_visible
18501855
self.selectable = not self.predicted or color_manager.color_predicted
@@ -1875,7 +1880,9 @@ def __init__(
18751880
if self.predicted:
18761881
self.box = QGraphicsRectItem(parent=self)
18771882
else:
1878-
self.box = VisibleBoundingBox(rect=self._bounding_rect, parent=self)
1883+
self.box = VisibleBoundingBox(
1884+
rect=self._bounding_rect, parent=self, frame=frame
1885+
)
18791886
box_pen_width = color_manager.get_item_pen_width(self.instance)
18801887
box_pen = QPen(QColor(*color), box_pen_width)
18811888
box_pen.setStyle(Qt.DashLine)
@@ -1924,6 +1931,7 @@ def __init__(
19241931
predicted=self.predicted,
19251932
radius=self.markerRadius,
19261933
show_non_visible=self.show_non_visible,
1934+
frame=frame,
19271935
)
19281936

19291937
self.nodes[node.name] = node_item
@@ -1938,6 +1946,7 @@ def __init__(
19381946
src=self.nodes[src],
19391947
dst=self.nodes[dst],
19401948
show_non_visible=self.show_non_visible,
1949+
frame=frame,
19411950
)
19421951
self.nodes[src].edges.append(edge_item)
19431952
self.nodes[dst].edges.append(edge_item)
@@ -2254,11 +2263,12 @@ def __init__(
22542263
parent: QtInstance,
22552264
opacity: float = 0.8,
22562265
scaling_padding: float = 10.0,
2266+
frame: Optional[LabeledFrame] = None,
22572267
):
22582268
super().__init__(rect, parent)
22592269
self.box_width = parent.markerRadius
22602270
color_manager = parent.player.color_manager
2261-
int_color = color_manager.get_item_color(parent.instance)
2271+
int_color = color_manager.get_item_color(parent.instance, frame=frame)
22622272
self.int_color = QColor(*int_color)
22632273
self.corner_opacity = opacity
22642274
self.scaling_padding = scaling_padding
@@ -2512,6 +2522,7 @@ def plot_instances(scene, frame_idx, labels, video=None, fixed=True):
25122522
color=color_manager.get_track_color(pseudo_track),
25132523
predicted=fixed,
25142524
color_predicted=True,
2525+
frame=labeled_frame,
25152526
show_non_visible=False,
25162527
)
25172528
inst.showLabels(False)

sleap/io/visuals.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717

1818
from sleap.gui.color import ColorManager
1919
from sleap_io.model.instance import Instance
20-
from sleap_io import Video, Labels
20+
from sleap_io import Video, Labels, LabeledFrame
2121
from sleap.sleap_io_adaptors.video_utils import _sentinel
2222
from sleap.sleap_io_adaptors.lf_labels_utils import (
2323
load_labels_video_search,
@@ -204,7 +204,7 @@ def _plot_instances_cv(
204204
img, offset = self._crop_frame(img, instances)
205205

206206
for instance in instances:
207-
self._plot_instance_cv(img, instance, offset)
207+
self._plot_instance_cv(img, instance, offset, frame=lfs[0])
208208

209209
return img
210210

@@ -270,6 +270,7 @@ def _plot_instance_cv(
270270
instance: "Instance",
271271
offset: Optional[Tuple[int, int]] = None,
272272
fill: bool = True,
273+
frame: Optional[LabeledFrame] = None,
273274
):
274275
"""
275276
Add visual annotations for single instance.
@@ -299,7 +300,9 @@ def _plot_instance_cv(
299300

300301
for node_idx, (x, y) in enumerate(points_array):
301302
node = nodes[node_idx]
302-
node_color_bgr = self.color_manager.get_item_color(node, instance)[::-1]
303+
node_color_bgr = self.color_manager.get_item_color(
304+
node, instance, frame=frame
305+
)[::-1]
303306

304307
# Make sure this is a valid and visible point
305308
if not has_nans(x, y):
@@ -323,7 +326,9 @@ def _plot_instance_cv(
323326
dst_x, dst_y = points_array[dst]
324327

325328
edge = (nodes[src], nodes[dst])
326-
edge_color_bgr = self.color_manager.get_item_color(edge, instance)[::-1]
329+
edge_color_bgr = self.color_manager.get_item_color(
330+
edge, instance, frame=frame
331+
)[::-1]
327332

328333
# Make sure that both nodes are present in this instance before
329334
# drawing edge

0 commit comments

Comments
 (0)