Skip to content

Commit 82c7ec2

Browse files
parikshit14sfmig
authored and committed
converted predicted path to movement's ValidPosesDataset
1 parent 394f763 commit 82c7ec2

File tree

2 files changed

+38
-8
lines changed

2 files changed

+38
-8
lines changed

ethology/tap_models/track_any_point.py

Lines changed: 31 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
Visualizer,
1010
read_video_from_path,
1111
)
12+
from movement.io.load_poses import from_numpy
1213

1314
LIST_OF_SUPPORTED_TAP_MODELS = ["cotracker"]
1415

@@ -50,6 +51,32 @@ def __init__(
5051
"facebookresearch/co-tracker", "cotracker3_offline"
5152
)
5253

54+
def convert_to_movement_dataset(self, pred_tracks):
55+
"""Convert the predicted tracks to movement dataset.
56+
57+
Parameters
58+
----------
59+
pred_tracks: torch.Tensor
60+
Tensor containing tracks of each query
61+
point across the frames of size
62+
[batch, frame, query_points, X, Y].
63+
64+
Returns
65+
-------
66+
momentum.ValidPosesDataset
67+
Dataset containing the predicted tracks.
68+
69+
"""
70+
# pred_tracks.shape = [batch, frames, querypoints, 2d points]
71+
# goes into positional array with shape
72+
# [frames, 2d points, 1 keypoint per query point, no. of query points]
73+
pred_tracks = pred_tracks.cpu().numpy().squeeze(0)
74+
ds = from_numpy(
75+
position_array=pred_tracks.transpose(0, 2, 1)[:, :, None, :],
76+
source_software="cotracker",
77+
)
78+
return ds
79+
5380
def track(
5481
self,
5582
video_path: str,
@@ -66,7 +93,7 @@ def track(
6693
query_points: list
6794
2D List of shape (Q,3) where Q is the
6895
number of query points and each Q containing
69-
X, Y, Frame-Number to start tracking from
96+
Frame-Number to start tracking from, X, Y
7097
save_dir: str
7198
Directory path to save the processed video
7299
save_results: bool
@@ -87,8 +114,6 @@ def track(
87114
raise FileNotFoundError(
88115
f"Video source file {video_path} not found"
89116
)
90-
else:
91-
print(f"Video source file {video_path} found")
92117

93118
# load video
94119
video = read_video_from_path(video_path)
@@ -110,7 +135,7 @@ def track(
110135
self.model = self.model.to(device)
111136
video = video.to(device)
112137
query_tensor = query_tensor.to(device)
113-
print("video shape", video.shape)
138+
114139
pred_tracks, pred_visibility = self.model(
115140
video, queries=query_tensor[None]
116141
)
@@ -134,5 +159,5 @@ def track(
134159
visibility=pred_visibility,
135160
filename=file_name,
136161
)
137-
138-
return pred_tracks
162+
ds = self.convert_to_movement_dataset(pred_tracks)
163+
return ds

tests/test_unit/test_tap_models/test_tap_models.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,12 @@ def test_track_with_valid_parameters():
6060
),
6161
patch("torch.from_numpy", return_value=torch.from_numpy(mock_video)),
6262
):
63-
path = tracker.track(
63+
ds = tracker.track(
6464
video_path="fake_video_path.mp4", query_points=query_points
6565
)
66-
assert list(path.shape) == [1, 50, 4, 2]
66+
assert ds.position.shape == (50, 2, 1, 4)
67+
assert ds.source_software == "cotracker"
68+
assert len(ds.individuals) == 4
69+
assert len(ds.keypoints) == 1
70+
assert len(ds.time) == 50
71+
assert len(ds.space) == 2

0 commit comments

Comments
 (0)