-
Notifications
You must be signed in to change notification settings - Fork 1
RoPE, Fourier RoPE, Fourier learnable embeddings #96
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
16d7bc4
6f2c7ad
b82c4d4
a07ea57
6d135fb
4714aea
9c64789
67bf6e4
fa61af0
55f5f25
9cec3a2
f02a173
287c475
a7e3a56
5d4bf5e
f23ef5c
785df8f
3d3f2ca
6af9e17
6928078
c4b1124
62f2c03
9292bbc
3751de0
3d1a35e
c4abac2
5a7e86b
9eddead
d5993a9
bcb661a
fd77ded
41454f7
64c970b
b63f24f
c320eea
21035fb
4d27914
be5e630
dba9f08
0dd6a60
e492909
4140524
a1ca23e
b5fa58d
c721e90
6711697
20fd4a7
9ac41a8
c43ee75
fe1eeca
2da8c09
65a4ae0
7c38ad4
8b552ef
8fdfba1
03df33f
1d2f5a5
5a5f75f
3ff1ab0
fe2c88e
de2ace9
9b29171
3bc9fef
1998f6f
511161d
054147d
28d5f4f
38f7798
1c1a340
e703bdf
77a9437
acc4c13
15fec60
137c96e
711028a
cf9453e
a2aeea9
7783d6b
632481c
8e40ba9
0b70f6e
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -26,31 +26,42 @@ def run(cfg: DictConfig) -> dict[int, sio.Labels]: | |||||||||||||||||||||||
""" | ||||||||||||||||||||||||
eval_cfg = Config(cfg) | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
if "checkpoints" in cfg.keys(): | ||||||||||||||||||||||||
# update with parameters for batch train job | ||||||||||||||||||||||||
if "batch_config" in cfg.keys(): | ||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Simplify dictionary membership test by removing In Python, you can check if a key exists in a dictionary directly without calling Apply this diff to fix the issue: -if "batch_config" in cfg.keys():
+if "batch_config" in cfg: 📝 Committable suggestion
Suggested change
🧰 Tools🪛 Ruff30-30: Use Remove (SIM118) |
||||||||||||||||||||||||
try: | ||||||||||||||||||||||||
index = int(os.environ["POD_INDEX"]) | ||||||||||||||||||||||||
# For testing without deploying a job on runai | ||||||||||||||||||||||||
except KeyError: | ||||||||||||||||||||||||
index = input("Pod Index Not found! Please choose a pod index: ") | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
logger.info(f"Pod Index: {index}") | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
checkpoints = pd.read_csv(cfg.checkpoints) | ||||||||||||||||||||||||
checkpoint = checkpoints.iloc[index] | ||||||||||||||||||||||||
except KeyError as e: | ||||||||||||||||||||||||
index = int( | ||||||||||||||||||||||||
input(f"{e}. Assuming single run!\nPlease input task index to run:") | ||||||||||||||||||||||||
) | ||||||||||||||||||||||||
Comment on lines
+34
to
+36
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Add input validation to handle non-integer user input When prompting the user for the task index, if a non-integer value is entered, a Apply this diff to implement input validation: except KeyError as e:
- index = int(
- input(f"{e}. Assuming single run!\nPlease input task index to run:")
- )
+ while True:
+ try:
+ index = int(
+ input(f"{e}. Assuming single run!\nPlease input task index to run:")
+ )
+ break
+ except ValueError:
+ print("Invalid input. Please enter an integer.") 📝 Committable suggestion
Suggested change
|
||||||||||||||||||||||||
|
||||||||||||||||||||||||
hparams_df = pd.read_csv(cfg.batch_config) | ||||||||||||||||||||||||
hparams = hparams_df.iloc[index].to_dict() | ||||||||||||||||||||||||
_ = hparams.pop("Unnamed: 0", None) | ||||||||||||||||||||||||
Comment on lines
+38
to
+40
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Add CSV file validation and error handling. The code assumes the CSV file exists and contains the expected columns. Consider adding validation:
+ if not os.path.exists(cfg.batch_config):
+ raise FileNotFoundError(f"Batch config file {cfg.batch_config} not found")
hparams_df = pd.read_csv(cfg.batch_config)
+ if len(hparams_df) <= index:
+ raise IndexError(f"Task index {index} exceeds number of configurations {len(hparams_df)}")
hparams = hparams_df.iloc[index].to_dict()
_ = hparams.pop("Unnamed: 0", None) 📝 Committable suggestion
Suggested change
|
||||||||||||||||||||||||
|
||||||||||||||||||||||||
if eval_cfg.set_hparams(hparams): | ||||||||||||||||||||||||
logger.info("Updated the following hparams to the following values") | ||||||||||||||||||||||||
logger.info(hparams) | ||||||||||||||||||||||||
else: | ||||||||||||||||||||||||
checkpoint = eval_cfg.cfg.ckpt_path | ||||||||||||||||||||||||
hparams = {} | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
checkpoint = eval_cfg.cfg.ckpt_path | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
logger.info(f"Testing model saved at {checkpoint}") | ||||||||||||||||||||||||
model = GTRRunner.load_from_checkpoint(checkpoint) | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
model.tracker_cfg = eval_cfg.cfg.tracker | ||||||||||||||||||||||||
model.tracker = Tracker(**model.tracker_cfg) | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
logger.info(f"Using the following tracker:") | ||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Remove unnecessary The log message does not contain any placeholders, so the Apply this diff to fix the issue: -logger.info(f"Using the following tracker:")
+logger.info("Using the following tracker:") 📝 Committable suggestion
Suggested change
🧰 Tools🪛 Ruff56-56: f-string without any placeholders Remove extraneous (F541) |
||||||||||||||||||||||||
|
||||||||||||||||||||||||
print(model.tracker) | ||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Replace Using Apply this diff to replace -print(model.tracker)
+logger.info(model.tracker)
|
||||||||||||||||||||||||
model.metrics["test"] = eval_cfg.cfg.runner.metrics.test | ||||||||||||||||||||||||
model.persistent_tracking["test"] = eval_cfg.cfg.tracker.get( | ||||||||||||||||||||||||
"persistent_tracking", False | ||||||||||||||||||||||||
) | ||||||||||||||||||||||||
logger.info(f"Computing the following metrics:") | ||||||||||||||||||||||||
logger.info(model.metrics.test) | ||||||||||||||||||||||||
logger.info(model.metrics['test']) | ||||||||||||||||||||||||
model.test_results["save_path"] = eval_cfg.cfg.runner.save_path | ||||||||||||||||||||||||
logger.info(f"Saving results to {model.test_results['save_path']}") | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -126,6 +126,8 @@ def filter_max_center_dist( | |
k_boxes: torch.Tensor | None = None, | ||
nonk_boxes: torch.Tensor | None = None, | ||
id_inds: torch.Tensor | None = None, | ||
h: int = None, | ||
w: int = None, | ||
Comment on lines
+129
to
+130
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Implement the usage of newly added parameters. The Consider implementing the scaling by applying this change: - norm_dist = dist.mean(axis=-1) # n_k x Np
+ # Scale distances from fractions to pixels
+ if h is not None and w is not None:
+ scale = torch.tensor([w, h], device=dist.device)
+ dist = dist * (scale ** 2).mean() # Use mean of w^2 and h^2 for scaling
+ norm_dist = dist.mean(axis=-1) # n_k x Np
|
||
) -> torch.Tensor: | ||
"""Filter trajectory score by distances between objects across frames. | ||
|
||
|
@@ -135,6 +137,8 @@ def filter_max_center_dist( | |
k_boxes: The bounding boxes in the current frame | ||
nonk_boxes: the boxes not in the current frame | ||
id_inds: track ids | ||
h: height of image | ||
w: width of image | ||
|
||
Returns: | ||
An N_t x N association matrix | ||
|
@@ -147,13 +151,15 @@ def filter_max_center_dist( | |
k_s = ((k_boxes[:, :, 2:] - k_boxes[:, :, :2]) ** 2).sum(dim=2) # n_k | ||
|
||
nonk_ct = (nonk_boxes[:, :, :2] + nonk_boxes[:, :, 2:]) / 2 | ||
|
||
# TODO: nonk_boxes should be only from previous frame rather than entire window | ||
dist = ((k_ct[:, None, :, :] - nonk_ct[None, :, :, :]) ** 2).sum( | ||
dim=-1 | ||
) # n_k x Np | ||
|
||
norm_dist = dist / (k_s[:, None, :] + 1e-8) | ||
# TODO: note that dist is in units of fraction of the height and width of the image; | ||
# TODO: need to scale it by the original image size so that its in units of pixels | ||
# norm_dist = dist / (k_s[:, None, :] + 1e-8) | ||
norm_dist = dist.mean(axis=-1) # n_k x Np | ||
# norm_dist = | ||
|
||
valid = norm_dist < max_center_dist # n_k x Np | ||
valid_assn = ( | ||
|
Original file line number | Diff line number | Diff line change | ||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -96,25 +96,35 @@ def run(cfg: DictConfig) -> dict[int, sio.Labels]: | |||||||||||||||
""" | ||||||||||||||||
pred_cfg = Config(cfg) | ||||||||||||||||
|
||||||||||||||||
if "checkpoints" in cfg.keys(): | ||||||||||||||||
# update with parameters for batch train job | ||||||||||||||||
if "batch_config" in cfg.keys(): | ||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Simplify dictionary key check by removing In Python, checking for a key in a dictionary does not require the Apply this diff to fix the issue: -if "batch_config" in cfg.keys():
+if "batch_config" in cfg: 📝 Committable suggestion
Suggested change
🧰 Tools🪛 Ruff100-100: Use Remove (SIM118) |
||||||||||||||||
try: | ||||||||||||||||
index = int(os.environ["POD_INDEX"]) | ||||||||||||||||
# For testing without deploying a job on runai | ||||||||||||||||
except KeyError: | ||||||||||||||||
index = input("Pod Index Not found! Please choose a pod index: ") | ||||||||||||||||
|
||||||||||||||||
logger.info(f"Pod Index: {index}") | ||||||||||||||||
|
||||||||||||||||
checkpoints = pd.read_csv(cfg.checkpoints) | ||||||||||||||||
checkpoint = checkpoints.iloc[index] | ||||||||||||||||
except KeyError as e: | ||||||||||||||||
index = int( | ||||||||||||||||
input(f"{e}. Assuming single run!\nPlease input task index to run:") | ||||||||||||||||
) | ||||||||||||||||
Comment on lines
+103
to
+106
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 🛠️ Refactor suggestion Avoid using Using Modify the exception handling to set a default index and log a warning: -except KeyError as e:
- index = int(
- input(f"{e}. Assuming single run!\nPlease input task index to run:")
- )
+except KeyError:
+ logger.warning("Environment variable 'POD_INDEX' not found. Using default index 0.")
+ index = 0 📝 Committable suggestion
Suggested change
|
||||||||||||||||
|
||||||||||||||||
hparams_df = pd.read_csv(cfg.batch_config) | ||||||||||||||||
hparams = hparams_df.iloc[index].to_dict() | ||||||||||||||||
_ = hparams.pop("Unnamed: 0", None) | ||||||||||||||||
|
||||||||||||||||
if pred_cfg.set_hparams(hparams): | ||||||||||||||||
logger.info("Updated the following hparams to the following values") | ||||||||||||||||
logger.info(hparams) | ||||||||||||||||
else: | ||||||||||||||||
checkpoint = pred_cfg.cfg.ckpt_path | ||||||||||||||||
hparams = {} | ||||||||||||||||
|
||||||||||||||||
checkpoint = pred_cfg.cfg.ckpt_path | ||||||||||||||||
|
||||||||||||||||
logger.info(f"Running inference with model from {checkpoint}") | ||||||||||||||||
model = GTRRunner.load_from_checkpoint(checkpoint) | ||||||||||||||||
|
||||||||||||||||
tracker_cfg = pred_cfg.get_tracker_cfg() | ||||||||||||||||
logger.info("Updating tracker hparams") | ||||||||||||||||
|
||||||||||||||||
model.tracker_cfg = tracker_cfg | ||||||||||||||||
model.tracker = Tracker(**model.tracker_cfg) | ||||||||||||||||
|
||||||||||||||||
logger.info(f"Using the following tracker:") | ||||||||||||||||
logger.info(model.tracker) | ||||||||||||||||
|
||||||||||||||||
|
@@ -124,12 +134,14 @@ def run(cfg: DictConfig) -> dict[int, sio.Labels]: | |||||||||||||||
os.makedirs(outdir, exist_ok=True) | ||||||||||||||||
|
||||||||||||||||
for label_file, vid_file in zip(labels_files, vid_files): | ||||||||||||||||
logger.info(f"Tracking {label_file} - {vid_file}...") | ||||||||||||||||
dataset = pred_cfg.get_dataset( | ||||||||||||||||
label_files=[label_file], vid_files=[vid_file], mode="test" | ||||||||||||||||
) | ||||||||||||||||
dataloader = pred_cfg.get_dataloader(dataset, mode="test") | ||||||||||||||||
preds = track(model, trainer, dataloader) | ||||||||||||||||
outpath = os.path.join(outdir, f"{Path(label_file).stem}.dreem_inference.slp") | ||||||||||||||||
logger.info(f"Saving results to {outpath}...") | ||||||||||||||||
preds.save(outpath) | ||||||||||||||||
|
||||||||||||||||
return preds | ||||||||||||||||
|
Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -138,8 +138,10 @@ def track( | |||||||||||||||||||||||||||||||||||||
# asso_preds, pred_boxes, pred_time, embeddings = self.model( | ||||||||||||||||||||||||||||||||||||||
# instances, reid_features | ||||||||||||||||||||||||||||||||||||||
# ) | ||||||||||||||||||||||||||||||||||||||
# get reference and query instances from TrackQueue and calls _run_global_tracker() | ||||||||||||||||||||||||||||||||||||||
instances_pred = self.sliding_inference(model, frames) | ||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||
# e.g. during train/val, don't track across batches so persistent_tracking is switched off | ||||||||||||||||||||||||||||||||||||||
if not self.persistent_tracking: | ||||||||||||||||||||||||||||||||||||||
logger.debug(f"Clearing Queue after tracking") | ||||||||||||||||||||||||||||||||||||||
self.track_queue.end_tracks() | ||||||||||||||||||||||||||||||||||||||
|
@@ -164,7 +166,9 @@ def sliding_inference( | |||||||||||||||||||||||||||||||||||||
# H: height. | ||||||||||||||||||||||||||||||||||||||
# W: width. | ||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||
# frames is untracked clip for inference | ||||||||||||||||||||||||||||||||||||||
for batch_idx, frame_to_track in enumerate(frames): | ||||||||||||||||||||||||||||||||||||||
# tracked_frames is a list of reference frames that have been tracked (associated) | ||||||||||||||||||||||||||||||||||||||
tracked_frames = self.track_queue.collate_tracks( | ||||||||||||||||||||||||||||||||||||||
device=frame_to_track.frame_id.device | ||||||||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||||||||
|
@@ -188,19 +192,21 @@ def sliding_inference( | |||||||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||
curr_track_id = 0 | ||||||||||||||||||||||||||||||||||||||
# if track ids exist from another tracking program i.e. sleap, init with those | ||||||||||||||||||||||||||||||||||||||
for i, instance in enumerate(frames[batch_idx].instances): | ||||||||||||||||||||||||||||||||||||||
instance.pred_track_id = instance.gt_track_id | ||||||||||||||||||||||||||||||||||||||
curr_track_id = max(curr_track_id, instance.pred_track_id) | ||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||
# if no track ids, then assign new ones | ||||||||||||||||||||||||||||||||||||||
for i, instance in enumerate(frames[batch_idx].instances): | ||||||||||||||||||||||||||||||||||||||
if instance.pred_track_id == -1: | ||||||||||||||||||||||||||||||||||||||
curr_track += 1 | ||||||||||||||||||||||||||||||||||||||
curr_track_id += 1 | ||||||||||||||||||||||||||||||||||||||
Comment on lines
+195
to
+202
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Unused loop variable In both loops on lines 196 and 200, the loop variable Apply this diff to rename - for i, instance in enumerate(frames[batch_idx].instances):
+ for _, instance in enumerate(frames[batch_idx].instances):
instance.pred_track_id = instance.gt_track_id
curr_track_id = max(curr_track_id, instance.pred_track_id)
- for i, instance in enumerate(frames[batch_idx].instances):
+ for _, instance in enumerate(frames[batch_idx].instances):
if instance.pred_track_id == -1:
curr_track_id += 1
instance.pred_track_id = curr_track_id 📝 Committable suggestion
Suggested change
🧰 Tools🪛 Ruff196-196: Loop control variable Rename unused (B007) 200-200: Loop control variable Rename unused (B007) |
||||||||||||||||||||||||||||||||||||||
instance.pred_track_id = curr_track_id | ||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||
else: | ||||||||||||||||||||||||||||||||||||||
if ( | ||||||||||||||||||||||||||||||||||||||
frame_to_track.has_instances() | ||||||||||||||||||||||||||||||||||||||
): # Check if there are detections. If there are skip and increment gap count | ||||||||||||||||||||||||||||||||||||||
# combine the tracked frames with the latest frame; inference pipeline uses latest frame as pred | ||||||||||||||||||||||||||||||||||||||
frames_to_track = tracked_frames + [ | ||||||||||||||||||||||||||||||||||||||
frame_to_track | ||||||||||||||||||||||||||||||||||||||
] # better var name? | ||||||||||||||||||||||||||||||||||||||
|
@@ -217,7 +223,7 @@ def sliding_inference( | |||||||||||||||||||||||||||||||||||||
self.track_queue.add_frame(frame_to_track) | ||||||||||||||||||||||||||||||||||||||
else: | ||||||||||||||||||||||||||||||||||||||
self.track_queue.increment_gaps([]) | ||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||
# update the frame object from the input inference untracked clip | ||||||||||||||||||||||||||||||||||||||
frames[batch_idx] = frame_to_track | ||||||||||||||||||||||||||||||||||||||
return frames | ||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||
|
@@ -252,7 +258,7 @@ def _run_global_tracker( | |||||||||||||||||||||||||||||||||||||
# E.g.: instances_per_frame: [4, 5, 6, 7]; window of length 4 with 4 detected instances in the first frame of the window. | ||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||
_ = model.eval() | ||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||
# get the last frame in the clip to perform inference on | ||||||||||||||||||||||||||||||||||||||
query_frame = frames[query_ind] | ||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||
query_instances = query_frame.instances | ||||||||||||||||||||||||||||||||||||||
|
@@ -279,8 +285,10 @@ def _run_global_tracker( | |||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||
# (L=1, n_query, total_instances) | ||||||||||||||||||||||||||||||||||||||
with torch.no_grad(): | ||||||||||||||||||||||||||||||||||||||
# GTR knows this is for inference since query_instances is not None | ||||||||||||||||||||||||||||||||||||||
asso_matrix = model(all_instances, query_instances) | ||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||
# GTR output is n_query x n_instances - split this into per-frame to softmax each frame separately | ||||||||||||||||||||||||||||||||||||||
asso_output = asso_matrix[-1].matrix.split( | ||||||||||||||||||||||||||||||||||||||
instances_per_frame, dim=1 | ||||||||||||||||||||||||||||||||||||||
) # (window_size, n_query, N_i) | ||||||||||||||||||||||||||||||||||||||
|
@@ -296,7 +304,7 @@ def _run_global_tracker( | |||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||
asso_output_df.index.name = "Instances" | ||||||||||||||||||||||||||||||||||||||
asso_output_df.columns.name = "Instances" | ||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||
# save the association matrix to the Frame object | ||||||||||||||||||||||||||||||||||||||
query_frame.add_traj_score("asso_output", asso_output_df) | ||||||||||||||||||||||||||||||||||||||
query_frame.asso_output = asso_matrix[-1] | ||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||
|
@@ -343,6 +351,8 @@ def _run_global_tracker( | |||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||
query_frame.add_traj_score("asso_nonquery", asso_nonquery_df) | ||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||
# need frame height and width to scale boxes during post-processing | ||||||||||||||||||||||||||||||||||||||
_, h, w = query_frame.img_shape.flatten() | ||||||||||||||||||||||||||||||||||||||
pred_boxes = model_utils.get_boxes(all_instances) | ||||||||||||||||||||||||||||||||||||||
query_boxes = pred_boxes[query_inds] # n_k x 4 | ||||||||||||||||||||||||||||||||||||||
nonquery_boxes = pred_boxes[nonquery_inds] # n_nonquery x 4 | ||||||||||||||||||||||||||||||||||||||
|
@@ -374,7 +384,7 @@ def _run_global_tracker( | |||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||
query_frame.add_traj_score("decay_time", decay_time_traj_score) | ||||||||||||||||||||||||||||||||||||||
################################################################################ | ||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||
# reduce association matrix - aggregating reference instance association scores by tracks | ||||||||||||||||||||||||||||||||||||||
# (n_query x n_nonquery) x (n_nonquery x n_traj) --> n_query x n_traj | ||||||||||||||||||||||||||||||||||||||
traj_score = torch.mm(traj_score, id_inds.cpu()) # (n_query, n_traj) | ||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||
|
@@ -387,6 +397,7 @@ def _run_global_tracker( | |||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||
query_frame.add_traj_score("traj_score", traj_score_df) | ||||||||||||||||||||||||||||||||||||||
################################################################################ | ||||||||||||||||||||||||||||||||||||||
# IOU-based post-processing; add a weighted IOU across successive frames to association scores | ||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||
# with iou -> combining with location in tracker, they set to True | ||||||||||||||||||||||||||||||||||||||
# todo -> should also work without pos_embed | ||||||||||||||||||||||||||||||||||||||
|
@@ -421,11 +432,12 @@ def _run_global_tracker( | |||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||
query_frame.add_traj_score("weight_iou", iou_traj_score) | ||||||||||||||||||||||||||||||||||||||
################################################################################ | ||||||||||||||||||||||||||||||||||||||
# filters association matrix such that instances too far from each other get scores=0 | ||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||
# threshold for continuing a tracking or starting a new track -> they use 1.0 | ||||||||||||||||||||||||||||||||||||||
# todo -> should also work without pos_embed | ||||||||||||||||||||||||||||||||||||||
traj_score = post_processing.filter_max_center_dist( | ||||||||||||||||||||||||||||||||||||||
traj_score, self.max_center_dist, query_boxes, nonquery_boxes, id_inds | ||||||||||||||||||||||||||||||||||||||
traj_score, self.max_center_dist, query_boxes, nonquery_boxes, id_inds, h, w | ||||||||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||
if self.max_center_dist is not None and self.max_center_dist > 0: | ||||||||||||||||||||||||||||||||||||||
|
@@ -439,6 +451,7 @@ def _run_global_tracker( | |||||||||||||||||||||||||||||||||||||
query_frame.add_traj_score("max_center_dist", max_center_dist_traj_score) | ||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||
################################################################################ | ||||||||||||||||||||||||||||||||||||||
# softmax along tracks for each instance, for interpretability | ||||||||||||||||||||||||||||||||||||||
scaled_traj_score = torch.softmax(traj_score, dim=1) | ||||||||||||||||||||||||||||||||||||||
scaled_traj_score_df = pd.DataFrame( | ||||||||||||||||||||||||||||||||||||||
scaled_traj_score.numpy(), columns=unique_ids.cpu().numpy() | ||||||||||||||||||||||||||||||||||||||
|
@@ -449,6 +462,7 @@ def _run_global_tracker( | |||||||||||||||||||||||||||||||||||||
query_frame.add_traj_score("scaled", scaled_traj_score_df) | ||||||||||||||||||||||||||||||||||||||
################################################################################ | ||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||
# hungarian matching | ||||||||||||||||||||||||||||||||||||||
match_i, match_j = linear_sum_assignment((-traj_score)) | ||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||
track_ids = instance_ids.new_full((n_query,), -1) | ||||||||||||||||||||||||||||||||||||||
|
@@ -462,6 +476,7 @@ def _run_global_tracker( | |||||||||||||||||||||||||||||||||||||
thresh = ( | ||||||||||||||||||||||||||||||||||||||
overlap_thresh * id_inds[:, j].sum() if mult_thresh else overlap_thresh | ||||||||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||||||||
# if the association score for a query instance is lower than the threshold, create a new track for it | ||||||||||||||||||||||||||||||||||||||
if n_traj >= self.max_tracks or traj_score[i, j] > thresh: | ||||||||||||||||||||||||||||||||||||||
logger.debug( | ||||||||||||||||||||||||||||||||||||||
f"Assigning instance {i} to track {j} with id {unique_ids[j]}" | ||||||||||||||||||||||||||||||||||||||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fix video reader cleanup in del
The current cleanup implementation has several issues:
Apply this diff to improve the cleanup:
📝 Committable suggestion