99 Visualizer ,
1010 read_video_from_path ,
1111)
12+ from movement .io .load_poses import from_numpy
1213
1314LIST_OF_SUPPORTED_TAP_MODELS = ["cotracker" ]
1415
@@ -50,6 +51,32 @@ def __init__(
5051 "facebookresearch/co-tracker" , "cotracker3_offline"
5152 )
5253
54+ def convert_to_movement_dataset (self , pred_tracks ):
55+ """Convert the predicted tracks to movement dataset.
56+
57+ Parameters
58+ ----------
59+ pred_tracks: torch.Tensor
60+ Tensor containing tracks of each query
61+ point across the frames of size
62+ [batch, frame, query_points, X, Y].
63+
64+ Returns
65+ -------
66+ momentum.ValidPosesDataset
67+ Dataset containing the predicted tracks.
68+
69+ """
70+ # pred_tracks.shape = [batch, frames, querypoints, 2d points]
71+ # goes into positional array with shape
72+ # [frames, 2d points, 1 keypoint per query point, no. of query points]
73+ pred_tracks = pred_tracks .cpu ().numpy ().squeeze (0 )
74+ ds = from_numpy (
75+ position_array = pred_tracks .transpose (0 , 2 , 1 )[:, :, None , :],
76+ source_software = "cotracker" ,
77+ )
78+ return ds
79+
5380 def track (
5481 self ,
5582 video_path : str ,
@@ -66,7 +93,7 @@ def track(
6693 query_points: list
6794 2D List of shape (Q,3) where Q is the
6895 number of query points and each Q containing
69- X, Y, Frame-Number to start tracking from
96+ Frame-Number to start tracking from, X, Y
7097 save_dir: str
7198 Directory path to save the processed video
7299 save_results: bool
@@ -87,8 +114,6 @@ def track(
87114 raise FileNotFoundError (
88115 f"Video source file { video_path } not found"
89116 )
90- else :
91- print (f"Video source file { video_path } found" )
92117
93118 # load video
94119 video = read_video_from_path (video_path )
@@ -110,7 +135,7 @@ def track(
110135 self .model = self .model .to (device )
111136 video = video .to (device )
112137 query_tensor = query_tensor .to (device )
113- print ( "video shape" , video . shape )
138+
114139 pred_tracks , pred_visibility = self .model (
115140 video , queries = query_tensor [None ]
116141 )
@@ -134,5 +159,5 @@ def track(
134159 visibility = pred_visibility ,
135160 filename = file_name ,
136161 )
137-
138- return pred_tracks
162+ ds = self . convert_to_movement_dataset ( pred_tracks )
163+ return ds
0 commit comments