RoPE, Fourier RoPE, Fourier learnable embeddings #96

Open
wants to merge 81 commits into base: main
Changes from all commits (81 commits)
16d7bc4
create notebook for dev
shaikh58 Jul 31, 2024
6f2c7ad
test update of notebook
shaikh58 Jul 31, 2024
b82c4d4
implement rope embedding
shaikh58 Aug 2, 2024
a07ea57
minor changes - add batch job file to repo
shaikh58 Aug 5, 2024
6d135fb
add local train run script, minor changes
shaikh58 Aug 5, 2024
4714aea
Update rope.ipynb
shaikh58 Aug 5, 2024
9c64789
refactor transformer encoder
shaikh58 Aug 6, 2024
67bf6e4
further changes for rope
shaikh58 Aug 6, 2024
fa61af0
complete encoder section of rope
shaikh58 Aug 6, 2024
55f5f25
setup batch training
shaikh58 Aug 7, 2024
9cec3a2
remove batch run commands from repo
shaikh58 Aug 7, 2024
f02a173
remove batch training script
shaikh58 Aug 7, 2024
287c475
Update base.yaml
shaikh58 Aug 7, 2024
a7e3a56
Merge branch 'mustafa-rope' of https://github.com/talmolab/dreem into…
shaikh58 Aug 7, 2024
5d4bf5e
Update run_trainer.py
shaikh58 Aug 7, 2024
f23ef5c
Update .gitignore
shaikh58 Aug 7, 2024
785df8f
comments for tracker.py
shaikh58 Aug 7, 2024
3d3f2ca
embedding bug fixes for encoder
shaikh58 Aug 8, 2024
6af9e17
implement rope for decoder
shaikh58 Aug 9, 2024
6928078
final attn head supports stack embeddings
shaikh58 Aug 9, 2024
c4b1124
Update tests, add new unit tests for rope
shaikh58 Aug 10, 2024
62f2c03
rope bug fixes
shaikh58 Aug 12, 2024
9292bbc
minor update to previous commit
shaikh58 Aug 12, 2024
3751de0
fix device mismatch in mlp module
shaikh58 Aug 15, 2024
3d1a35e
support for adding embedding to instance
shaikh58 Aug 15, 2024
c4abac2
bug fixes to pass unit tests
shaikh58 Aug 16, 2024
5a7e86b
minor updates from PR review
shaikh58 Aug 16, 2024
9eddead
allow batch eval/inference flexibility rather than just different mod…
aaprasad Aug 16, 2024
d5993a9
linting
shaikh58 Aug 19, 2024
bcb661a
add cross attn for rope-stack before final asso matrix output
shaikh58 Aug 26, 2024
fd77ded
minor bug fix in rope embedding for single instance clips
shaikh58 Aug 27, 2024
41454f7
use `sleap-io` as video backend instead of imageio
aaprasad Aug 30, 2024
64c970b
lint
aaprasad Aug 30, 2024
b63f24f
create notebook for dev
shaikh58 Jul 31, 2024
c320eea
test update of notebook
shaikh58 Jul 31, 2024
21035fb
implement rope embedding
shaikh58 Aug 2, 2024
4d27914
minor changes - add batch job file to repo
shaikh58 Aug 5, 2024
be5e630
add local train run script, minor changes
shaikh58 Aug 5, 2024
dba9f08
Update rope.ipynb
shaikh58 Aug 5, 2024
0dd6a60
refactor transformer encoder
shaikh58 Aug 6, 2024
e492909
further changes for rope
shaikh58 Aug 6, 2024
4140524
complete encoder section of rope
shaikh58 Aug 6, 2024
a1ca23e
setup batch training
shaikh58 Aug 7, 2024
b5fa58d
remove batch run commands from repo
shaikh58 Aug 7, 2024
c721e90
Update base.yaml
shaikh58 Aug 7, 2024
6711697
remove batch training script
shaikh58 Aug 7, 2024
20fd4a7
Update run_trainer.py
shaikh58 Aug 7, 2024
9ac41a8
Update .gitignore
shaikh58 Aug 7, 2024
c43ee75
comments for tracker.py
shaikh58 Aug 7, 2024
fe1eeca
embedding bug fixes for encoder
shaikh58 Aug 8, 2024
2da8c09
implement rope for decoder
shaikh58 Aug 9, 2024
65a4ae0
final attn head supports stack embeddings
shaikh58 Aug 9, 2024
7c38ad4
Update tests, add new unit tests for rope
shaikh58 Aug 10, 2024
8b552ef
rope bug fixes
shaikh58 Aug 12, 2024
8fdfba1
minor update to previous commit
shaikh58 Aug 12, 2024
03df33f
fix device mismatch in mlp module
shaikh58 Aug 15, 2024
1d2f5a5
support for adding embedding to instance
shaikh58 Aug 15, 2024
5a5f75f
bug fixes to pass unit tests
shaikh58 Aug 16, 2024
3ff1ab0
minor updates from PR review
shaikh58 Aug 16, 2024
fe2c88e
linting
shaikh58 Aug 19, 2024
de2ace9
add cross attn for rope-stack before final asso matrix output
shaikh58 Aug 26, 2024
9b29171
minor bug fix in rope embedding for single instance clips
shaikh58 Aug 27, 2024
3bc9fef
Merge branch 'mustafa-rope' of https://github.com/talmolab/dreem into…
shaikh58 Sep 27, 2024
1998f6f
- Started implementation for post processing fixes; no logic changes
shaikh58 Oct 9, 2024
511161d
- Add support for learned Fourier spatial/temporal embeddings
shaikh58 Oct 15, 2024
054147d
updates to fix pos emb bugs in encoder/decoder
shaikh58 Oct 15, 2024
28d5f4f
- fixed rope concat to only use t,x,y not t,x,y,orig
shaikh58 Oct 15, 2024
38f7798
linting for readability
shaikh58 Oct 15, 2024
1c1a340
bug fixes to fourier rope implementation - working version
shaikh58 Oct 16, 2024
e703bdf
- Bug fix in pre-encoder fourier - create coeffs array based on d_mod…
shaikh58 Oct 16, 2024
77a9437
Add support for choosing num fourier components using n_components
shaikh58 Oct 16, 2024
acc4c13
- Fix in n _components for fourier embeddings
shaikh58 Oct 23, 2024
15fec60
- bug fix to fourier embedding pre-encoder; was creating a new instan…
shaikh58 Oct 25, 2024
137c96e
- update test modules for rope + fourier implementations
shaikh58 Oct 25, 2024
711028a
linting
shaikh58 Oct 25, 2024
cf9453e
- Added comments to dataset classes
shaikh58 Nov 2, 2024
a2aeea9
removed debugging scripts
shaikh58 Nov 5, 2024
7783d6b
bug in logging model.metrics in eval
shaikh58 Nov 5, 2024
632481c
add comments to loss, no functional change to attention_head; set up …
shaikh58 Nov 7, 2024
8e40ba9
- switch off embeddings between decoder self attn and cross attn
shaikh58 Nov 7, 2024
0b70f6e
undo changes to decoder layer embeddings
shaikh58 Nov 7, 2024
5 changes: 5 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -142,3 +142,8 @@ dreem/training/models/*

# docs
site/
*.xml
dreem/training/configs/base.yaml
dreem/training/configs/override.yaml
dreem/training/configs/override.yaml
dreem/training/configs/base.yaml
1 change: 1 addition & 0 deletions dreem/datasets/base_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ def create_chunks(self) -> None:
if self.chunk:
self.chunked_frame_idx, self.label_idx = [], []
for i, frame_idx in enumerate(self.frame_idx):
# splits frame indices into chunks of length clip_length
frame_idx_split = torch.split(frame_idx, self.clip_length)
self.chunked_frame_idx.extend(frame_idx_split)
self.label_idx.extend(len(frame_idx_split) * [i])
Expand Down
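For context, the chunking that the commented `torch.split` call performs can be sketched in pure Python (a stand-in for the tensor version, not the repo's code — the helper name and list-based inputs are illustrative):

```python
def create_chunks(frame_idx_per_video, clip_length):
    """Split each video's frame indices into clips of at most clip_length.

    frame_idx_per_video: one sublist of frame indices per video.
    Returns (chunked_frame_idx, label_idx), where label_idx[i] records which
    video chunk i came from, mirroring BaseDataset.create_chunks above.
    """
    chunked_frame_idx, label_idx = [], []
    for i, frame_idx in enumerate(frame_idx_per_video):
        # torch.split(t, n) yields consecutive pieces of size n (last may be shorter)
        splits = [frame_idx[j:j + clip_length]
                  for j in range(0, len(frame_idx), clip_length)]
        chunked_frame_idx.extend(splits)
        label_idx.extend(len(splits) * [i])
    return chunked_frame_idx, label_idx

chunks, labels = create_chunks([[0, 1, 2, 3, 4], [0, 1, 2]], clip_length=2)
# chunks -> [[0, 1], [2, 3], [4], [0, 1], [2]]
# labels -> [0, 0, 0, 1, 1]
```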
28 changes: 14 additions & 14 deletions dreem/datasets/sleap_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,12 +106,13 @@ def __init__(
# if self.seed is not None:
# np.random.seed(self.seed)
self.labels = [sio.load_slp(slp_file) for slp_file in self.slp_files]
self.videos = [imageio.get_reader(vid_file) for vid_file in self.vid_files]
self.vid_readers = {}
# do we need this? would need to update with sleap-io

# for label in self.labels:
# label.remove_empty_instances(keep_empty_frames=False)

# list of lists, each sublist is a list of frame indices for a given video
self.frame_idx = [torch.arange(len(labels)) for labels in self.labels]
# Method in BaseDataset. Creates label_idx and chunked_frame_idx to be
# used in call to get_instances()
Expand All @@ -123,6 +124,7 @@ def get_indices(self, idx: int) -> tuple:
Args:
idx: the index of the batch.
"""
# self.label_idx is a list of indices specifying which video each chunk belongs to
return self.label_idx[idx], self.chunked_frame_idx[idx]

def get_instances(self, label_idx: list[int], frame_idx: list[int]) -> list[Frame]:
Expand All @@ -136,19 +138,17 @@ def get_instances(self, label_idx: list[int], frame_idx: list[int]) -> list[Fram
A list of `dreem.io.Frame` objects containing metadata and instance data for the batch/clip.

"""
video = self.labels[label_idx]
# each entry in self.labels is a sleap Labels object (which is a list of LabeledFrames)
video = self.labels[label_idx] # label_idx is the

video_name = self.video_files[label_idx]

vid_reader = self.videos[label_idx]

# img = vid_reader.get_data(0)

skeleton = video.skeletons[-1]

frames = []
for i, frame_ind in enumerate(frame_idx):
(
( # frame_idx is a list of frame indices for a given video
instances,
gt_track_ids,
poses,
Expand All @@ -159,15 +159,15 @@ def get_instances(self, label_idx: list[int], frame_idx: list[int]) -> list[Fram

frame_ind = int(frame_ind)

lf = video[frame_ind]
lf = video[frame_ind] # video is a sleap Labels object for a given file

try:
img = vid_reader.get_data(int(lf.frame_idx))
except IndexError as e:
logger.warning(
f"Could not read frame {frame_ind} from {video_name} due to {e}"
)
continue
img = lf.image # a single frame from the video
except FileNotFoundError as e:
if video_name not in self.vid_readers:
self.vid_readers[video_name] = sio.load_video(video_name)
vid_reader = self.vid_readers[video_name]
img = vid_reader[lf.frame_idx]

if len(img.shape) == 2:
img = img.expand_dims(-1)
Expand Down Expand Up @@ -370,5 +370,5 @@ def get_instances(self, label_idx: list[int], frame_idx: list[int]) -> list[Fram

def __del__(self):
"""Handle file closing before garbage collection."""
for reader in self.videos:
for reader in self.vid_readers:
⚠️ Potential issue

Fix video reader cleanup in del

The current cleanup implementation has several issues:

  1. It's iterating over dict keys instead of values
  2. Missing error handling for cleanup failures
  3. No type checking before calling close()

Apply this diff to improve the cleanup:

-    for reader in self.vid_readers:
-        reader.close()
+    for video_name, reader in self.vid_readers.items():
+        try:
+            if hasattr(reader, 'close'):
+                reader.close()
+        except Exception as e:
+            logger.warning(f"Failed to close video reader for {video_name}: {str(e)}")
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
for reader in self.vid_readers:
for video_name, reader in self.vid_readers.items():
try:
if hasattr(reader, 'close'):
reader.close()
except Exception as e:
logger.warning(f"Failed to close video reader for {video_name}: {str(e)}")

reader.close()
33 changes: 22 additions & 11 deletions dreem/inference/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,31 +26,42 @@ def run(cfg: DictConfig) -> dict[int, sio.Labels]:
"""
eval_cfg = Config(cfg)

if "checkpoints" in cfg.keys():
# update with parameters for batch train job
if "batch_config" in cfg.keys():
⚠️ Potential issue

Simplify dictionary membership test by removing .keys()

In Python, you can check if a key exists in a dictionary directly without calling .keys(). Replace if "batch_config" in cfg.keys(): with if "batch_config" in cfg: for cleaner and more idiomatic code.

Apply this diff to fix the issue:

-if "batch_config" in cfg.keys():
+if "batch_config" in cfg:
📝 Committable suggestion


Suggested change
if "batch_config" in cfg.keys():
if "batch_config" in cfg:
🧰 Tools
🪛 Ruff

30-30: Use key in dict instead of key in dict.keys()

Remove .keys()

(SIM118)
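The SIM118 point above is easy to verify: for any dict (or dict-like mapping such as the Hydra `DictConfig` used here — a plain dict stands in for it in this sketch), `key in d` and `key in d.keys()` are equivalent, so the `.keys()` call is redundant:

```python
# Plain dict standing in for the OmegaConf/Hydra DictConfig in eval.py
cfg = {"batch_config": "runs.csv", "ckpt_path": "model.ckpt"}

# Both membership tests agree for present and absent keys
assert ("batch_config" in cfg) == ("batch_config" in cfg.keys())
assert ("missing" in cfg) == ("missing" in cfg.keys())
```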

try:
index = int(os.environ["POD_INDEX"])
# For testing without deploying a job on runai
except KeyError:
index = input("Pod Index Not found! Please choose a pod index: ")

logger.info(f"Pod Index: {index}")

checkpoints = pd.read_csv(cfg.checkpoints)
checkpoint = checkpoints.iloc[index]
except KeyError as e:
index = int(
input(f"{e}. Assuming single run!\nPlease input task index to run:")
)
Comment on lines +34 to +36
⚠️ Potential issue

Add input validation to handle non-integer user input

When prompting the user for the task index, if a non-integer value is entered, a ValueError will be raised. Consider adding error handling to manage invalid inputs gracefully.

Apply this diff to implement input validation:

 except KeyError as e:
-    index = int(
-        input(f"{e}. Assuming single run!\nPlease input task index to run:")
-    )
+    while True:
+        try:
+            index = int(
+                input(f"{e}. Assuming single run!\nPlease input task index to run:")
+            )
+            break
+        except ValueError:
+            print("Invalid input. Please enter an integer.")
📝 Committable suggestion


Suggested change
index = int(
input(f"{e}. Assuming single run!\nPlease input task index to run:")
)
while True:
try:
index = int(
input(f"{e}. Assuming single run!\nPlease input task index to run:")
)
break
except ValueError:
print("Invalid input. Please enter an integer.")


hparams_df = pd.read_csv(cfg.batch_config)
hparams = hparams_df.iloc[index].to_dict()
_ = hparams.pop("Unnamed: 0", None)
Comment on lines +38 to +40
⚠️ Potential issue

Add CSV file validation and error handling.

The code assumes the CSV file exists and contains the expected columns. Consider adding validation:

  1. Check if the CSV file exists
  2. Verify required columns are present
  3. Validate the index is within bounds of the DataFrame
+    if not os.path.exists(cfg.batch_config):
+        raise FileNotFoundError(f"Batch config file {cfg.batch_config} not found")
     hparams_df = pd.read_csv(cfg.batch_config)
+    if len(hparams_df) <= index:
+        raise IndexError(f"Task index {index} exceeds number of configurations {len(hparams_df)}")
     hparams = hparams_df.iloc[index].to_dict()
     _ = hparams.pop("Unnamed: 0", None)
📝 Committable suggestion


Suggested change
hparams_df = pd.read_csv(cfg.batch_config)
hparams = hparams_df.iloc[index].to_dict()
_ = hparams.pop("Unnamed: 0", None)
if not os.path.exists(cfg.batch_config):
raise FileNotFoundError(f"Batch config file {cfg.batch_config} not found")
hparams_df = pd.read_csv(cfg.batch_config)
if len(hparams_df) <= index:
raise IndexError(f"Task index {index} exceeds number of configurations {len(hparams_df)}")
hparams = hparams_df.iloc[index].to_dict()
_ = hparams.pop("Unnamed: 0", None)


if eval_cfg.set_hparams(hparams):
logger.info("Updated the following hparams to the following values")
logger.info(hparams)
else:
checkpoint = eval_cfg.cfg.ckpt_path
hparams = {}

checkpoint = eval_cfg.cfg.ckpt_path

logger.info(f"Testing model saved at {checkpoint}")
model = GTRRunner.load_from_checkpoint(checkpoint)

model.tracker_cfg = eval_cfg.cfg.tracker
model.tracker = Tracker(**model.tracker_cfg)

logger.info(f"Using the following tracker:")
⚠️ Potential issue

Remove unnecessary f prefix in string

The log message does not contain any placeholders, so the f prefix is unnecessary. Replace logger.info(f"Using the following tracker:") with logger.info("Using the following tracker:").

Apply this diff to fix the issue:

-logger.info(f"Using the following tracker:")
+logger.info("Using the following tracker:")
📝 Committable suggestion


Suggested change
logger.info(f"Using the following tracker:")
logger.info("Using the following tracker:")
🧰 Tools
🪛 Ruff

56-56: f-string without any placeholders

Remove extraneous f prefix

(F541)


print(model.tracker)
⚠️ Potential issue

Replace print with logger.info for consistent logging

Using logger.info instead of print ensures that all output is managed by the logging system, providing better control over log levels and formatting.

Apply this diff to replace print with logging:

-print(model.tracker)
+logger.info(model.tracker)

Committable suggestion was skipped due to low confidence.

model.metrics["test"] = eval_cfg.cfg.runner.metrics.test
model.persistent_tracking["test"] = eval_cfg.cfg.tracker.get(
"persistent_tracking", False
)
logger.info(f"Computing the following metrics:")
logger.info(model.metrics.test)
logger.info(model.metrics['test'])
model.test_results["save_path"] = eval_cfg.cfg.runner.save_path
logger.info(f"Saving results to {model.test_results['save_path']}")

Expand Down
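The batch-config step in this diff — select one hyperparameter row by task index and drop the `"Unnamed: 0"` index column that `pandas.to_csv` writes by default — can be sketched with only the stdlib `csv` module (the CSV contents and `index` value here are made up for illustration; the real script gets the index from `POD_INDEX` and uses pandas):

```python
import csv
import io

# In-memory stand-in for the cfg.batch_config CSV file
batch_config = io.StringIO(
    "Unnamed: 0,lr,batch_size\n"
    "0,0.001,8\n"
    "1,0.0005,16\n"
)

index = 1  # would come from the POD_INDEX environment variable
rows = list(csv.DictReader(batch_config))
hparams = dict(rows[index])
hparams.pop("Unnamed: 0", None)  # discard the pandas index column
# hparams -> {"lr": "0.0005", "batch_size": "16"}
```

Note `csv.DictReader` leaves values as strings; pandas would infer numeric dtypes, which is one reason the real code uses `pd.read_csv`.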
12 changes: 9 additions & 3 deletions dreem/inference/post_processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,8 @@ def filter_max_center_dist(
k_boxes: torch.Tensor | None = None,
nonk_boxes: torch.Tensor | None = None,
id_inds: torch.Tensor | None = None,
h: int = None,
w: int = None,
Comment on lines +129 to +130
⚠️ Potential issue

Implement the usage of newly added parameters.

The h and w parameters have been added but are not utilized in the implementation. These parameters are intended to scale the distances from fractions to pixel units as noted in the TODO comments.

Consider implementing the scaling by applying this change:

-        norm_dist = dist.mean(axis=-1)  # n_k x Np
+        # Scale distances from fractions to pixels
+        if h is not None and w is not None:
+            scale = torch.tensor([w, h], device=dist.device)
+            dist = dist * (scale ** 2).mean()  # Use mean of w^2 and h^2 for scaling
+        norm_dist = dist.mean(axis=-1)  # n_k x Np

Committable suggestion was skipped due to low confidence.

) -> torch.Tensor:
"""Filter trajectory score by distances between objects across frames.

Expand All @@ -135,6 +137,8 @@ def filter_max_center_dist(
k_boxes: The bounding boxes in the current frame
nonk_boxes: the boxes not in the current frame
id_inds: track ids
h: height of image
w: width of image

Returns:
An N_t x N association matrix
Expand All @@ -147,13 +151,15 @@ def filter_max_center_dist(
k_s = ((k_boxes[:, :, 2:] - k_boxes[:, :, :2]) ** 2).sum(dim=2) # n_k

nonk_ct = (nonk_boxes[:, :, :2] + nonk_boxes[:, :, 2:]) / 2

# TODO: nonk_boxes should be only from previous frame rather than entire window
dist = ((k_ct[:, None, :, :] - nonk_ct[None, :, :, :]) ** 2).sum(
dim=-1
) # n_k x Np

norm_dist = dist / (k_s[:, None, :] + 1e-8)
# TODO: note that dist is in units of fraction of the height and width of the image;
# TODO: need to scale it by the original image size so that its in units of pixels
# norm_dist = dist / (k_s[:, None, :] + 1e-8)
norm_dist = dist.mean(axis=-1) # n_k x Np
# norm_dist =

valid = norm_dist < max_center_dist # n_k x Np
valid_assn = (
Expand Down
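The gating that `filter_max_center_dist` implements — zero out association scores for instance pairs whose box centers are too far apart — can be illustrated with a pure-Python stand-in for the torch version (function and variable names here are illustrative; boxes are `(x1, y1, x2, y2)` and the threshold is in squared units, sidestepping the fraction-vs-pixel scaling flagged in the TODOs above):

```python
def center(box):
    x1, y1, x2, y2 = box
    return ((x1 + x2) / 2, (y1 + y2) / 2)

def filter_by_center_dist(scores, k_boxes, nonk_boxes, max_center_dist):
    """Zero scores[i][j] when the squared center distance between
    k_boxes[i] and nonk_boxes[j] is at least max_center_dist."""
    filtered = []
    for i, kb in enumerate(k_boxes):
        cx, cy = center(kb)
        row = []
        for j, nb in enumerate(nonk_boxes):
            nx, ny = center(nb)
            dist2 = (cx - nx) ** 2 + (cy - ny) ** 2
            row.append(scores[i][j] if dist2 < max_center_dist else 0.0)
        filtered.append(row)
    return filtered

scores = [[0.9, 0.8]]
k_boxes = [(0, 0, 2, 2)]          # center (1, 1)
nonk_boxes = [(0, 0, 2, 2),       # center (1, 1), dist^2 = 0   -> kept
              (10, 10, 12, 12)]   # center (11, 11), dist^2 = 200 -> zeroed
print(filter_by_center_dist(scores, k_boxes, nonk_boxes, max_center_dist=4.0))
# -> [[0.9, 0.0]]
```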
34 changes: 23 additions & 11 deletions dreem/inference/track.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,25 +96,35 @@ def run(cfg: DictConfig) -> dict[int, sio.Labels]:
"""
pred_cfg = Config(cfg)

if "checkpoints" in cfg.keys():
# update with parameters for batch train job
if "batch_config" in cfg.keys():
⚠️ Potential issue

Simplify dictionary key check by removing .keys()

In Python, checking for a key in a dictionary does not require the .keys() method. Removing it makes the code more readable and efficient.

Apply this diff to fix the issue:

-if "batch_config" in cfg.keys():
+if "batch_config" in cfg:
📝 Committable suggestion


Suggested change
if "batch_config" in cfg.keys():
if "batch_config" in cfg:
🧰 Tools
🪛 Ruff

100-100: Use key in dict instead of key in dict.keys()

Remove .keys()

(SIM118)

try:
index = int(os.environ["POD_INDEX"])
# For testing without deploying a job on runai
except KeyError:
index = input("Pod Index Not found! Please choose a pod index: ")

logger.info(f"Pod Index: {index}")

checkpoints = pd.read_csv(cfg.checkpoints)
checkpoint = checkpoints.iloc[index]
except KeyError as e:
index = int(
input(f"{e}. Assuming single run!\nPlease input task index to run:")
)
Comment on lines +103 to +106
🛠️ Refactor suggestion

Avoid using input() for robustness in non-interactive environments

Using input() can cause the script to hang in non-interactive or automated environments. Consider providing a default index or handling the absence of POD_INDEX differently.

Modify the exception handling to set a default index and log a warning:

-except KeyError as e:
-    index = int(
-        input(f"{e}. Assuming single run!\nPlease input task index to run:")
-    )
+except KeyError:
+    logger.warning("Environment variable 'POD_INDEX' not found. Using default index 0.")
+    index = 0
📝 Committable suggestion


Suggested change
except KeyError as e:
index = int(
input(f"{e}. Assuming single run!\nPlease input task index to run:")
)
except KeyError:
logger.warning("Environment variable 'POD_INDEX' not found. Using default index 0.")
index = 0


hparams_df = pd.read_csv(cfg.batch_config)
hparams = hparams_df.iloc[index].to_dict()
_ = hparams.pop("Unnamed: 0", None)

if pred_cfg.set_hparams(hparams):
logger.info("Updated the following hparams to the following values")
logger.info(hparams)
else:
checkpoint = pred_cfg.cfg.ckpt_path
hparams = {}

checkpoint = pred_cfg.cfg.ckpt_path

logger.info(f"Running inference with model from {checkpoint}")
model = GTRRunner.load_from_checkpoint(checkpoint)

tracker_cfg = pred_cfg.get_tracker_cfg()
logger.info("Updating tracker hparams")

model.tracker_cfg = tracker_cfg
model.tracker = Tracker(**model.tracker_cfg)

logger.info(f"Using the following tracker:")
logger.info(model.tracker)

Expand All @@ -124,12 +134,14 @@ def run(cfg: DictConfig) -> dict[int, sio.Labels]:
os.makedirs(outdir, exist_ok=True)

for label_file, vid_file in zip(labels_files, vid_files):
logger.info(f"Tracking {label_file} - {vid_file}...")
dataset = pred_cfg.get_dataset(
label_files=[label_file], vid_files=[vid_file], mode="test"
)
dataloader = pred_cfg.get_dataloader(dataset, mode="test")
preds = track(model, trainer, dataloader)
outpath = os.path.join(outdir, f"{Path(label_file).stem}.dreem_inference.slp")
logger.info(f"Saving results to {outpath}...")
preds.save(outpath)

return preds
Expand Down
29 changes: 22 additions & 7 deletions dreem/inference/tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,8 +138,10 @@ def track(
# asso_preds, pred_boxes, pred_time, embeddings = self.model(
# instances, reid_features
# )
# get reference and query instances from TrackQueue and calls _run_global_tracker()
instances_pred = self.sliding_inference(model, frames)

# e.g. during train/val, don't track across batches so persistent_tracking is switched off
if not self.persistent_tracking:
logger.debug(f"Clearing Queue after tracking")
self.track_queue.end_tracks()
Expand All @@ -164,7 +166,9 @@ def sliding_inference(
# H: height.
# W: width.

# frames is untracked clip for inference
for batch_idx, frame_to_track in enumerate(frames):
# tracked_frames is a list of reference frames that have been tracked (associated)
tracked_frames = self.track_queue.collate_tracks(
device=frame_to_track.frame_id.device
)
Expand All @@ -188,19 +192,21 @@ def sliding_inference(
)

curr_track_id = 0
# if track ids exist from another tracking program i.e. sleap, init with those
for i, instance in enumerate(frames[batch_idx].instances):
instance.pred_track_id = instance.gt_track_id
curr_track_id = max(curr_track_id, instance.pred_track_id)

# if no track ids, then assign new ones
for i, instance in enumerate(frames[batch_idx].instances):
if instance.pred_track_id == -1:
curr_track += 1
curr_track_id += 1
Comment on lines +195 to +202
⚠️ Potential issue

Unused loop variable i in loops.

In both loops on lines 196 and 200, the loop variable i is not used within the loop body. Renaming i to _ improves code readability by indicating that the variable is intentionally unused.

Apply this diff to rename i to _ in both loops:

- for i, instance in enumerate(frames[batch_idx].instances):
+ for _, instance in enumerate(frames[batch_idx].instances):
    instance.pred_track_id = instance.gt_track_id
    curr_track_id = max(curr_track_id, instance.pred_track_id)
- for i, instance in enumerate(frames[batch_idx].instances):
+ for _, instance in enumerate(frames[batch_idx].instances):
    if instance.pred_track_id == -1:
        curr_track_id += 1
        instance.pred_track_id = curr_track_id
📝 Committable suggestion


Suggested change
# if track ids exist from another tracking program i.e. sleap, init with those
for i, instance in enumerate(frames[batch_idx].instances):
instance.pred_track_id = instance.gt_track_id
curr_track_id = max(curr_track_id, instance.pred_track_id)
# if no track ids, then assign new ones
for i, instance in enumerate(frames[batch_idx].instances):
if instance.pred_track_id == -1:
curr_track += 1
curr_track_id += 1
# if track ids exist from another tracking program i.e. sleap, init with those
for _, instance in enumerate(frames[batch_idx].instances):
instance.pred_track_id = instance.gt_track_id
curr_track_id = max(curr_track_id, instance.pred_track_id)
# if no track ids, then assign new ones
for _, instance in enumerate(frames[batch_idx].instances):
if instance.pred_track_id == -1:
curr_track_id += 1
🧰 Tools
🪛 Ruff

196-196: Loop control variable i not used within loop body

Rename unused i to _i

(B007)


200-200: Loop control variable i not used within loop body

Rename unused i to _i

(B007)

instance.pred_track_id = curr_track_id

else:
if (
frame_to_track.has_instances()
): # Check if there are detections. If there are skip and increment gap count
# combine the tracked frames with the latest frame; inference pipeline uses latest frame as pred
frames_to_track = tracked_frames + [
frame_to_track
] # better var name?
Expand All @@ -217,7 +223,7 @@ def sliding_inference(
self.track_queue.add_frame(frame_to_track)
else:
self.track_queue.increment_gaps([])

# update the frame object from the input inference untracked clip
frames[batch_idx] = frame_to_track
return frames

Expand Down Expand Up @@ -252,7 +258,7 @@ def _run_global_tracker(
# E.g.: instances_per_frame: [4, 5, 6, 7]; window of length 4 with 4 detected instances in the first frame of the window.

_ = model.eval()

# get the last frame in the clip to perform inference on
query_frame = frames[query_ind]

query_instances = query_frame.instances
Expand All @@ -279,8 +285,10 @@ def _run_global_tracker(

# (L=1, n_query, total_instances)
with torch.no_grad():
# GTR knows this is for inference since query_instances is not None
asso_matrix = model(all_instances, query_instances)

# GTR output is n_query x n_instances - split this into per-frame to softmax each frame separately
asso_output = asso_matrix[-1].matrix.split(
instances_per_frame, dim=1
) # (window_size, n_query, N_i)
Expand All @@ -296,7 +304,7 @@ def _run_global_tracker(

asso_output_df.index.name = "Instances"
asso_output_df.columns.name = "Instances"

# save the association matrix to the Frame object
query_frame.add_traj_score("asso_output", asso_output_df)
query_frame.asso_output = asso_matrix[-1]

Expand Down Expand Up @@ -343,6 +351,8 @@ def _run_global_tracker(

query_frame.add_traj_score("asso_nonquery", asso_nonquery_df)

# need frame height and width to scale boxes during post-processing
_, h, w = query_frame.img_shape.flatten()
pred_boxes = model_utils.get_boxes(all_instances)
query_boxes = pred_boxes[query_inds] # n_k x 4
nonquery_boxes = pred_boxes[nonquery_inds] # n_nonquery x 4
Expand Down Expand Up @@ -374,7 +384,7 @@ def _run_global_tracker(

query_frame.add_traj_score("decay_time", decay_time_traj_score)
################################################################################

# reduce association matrix - aggregating reference instance association scores by tracks
# (n_query x n_nonquery) x (n_nonquery x n_traj) --> n_query x n_traj
traj_score = torch.mm(traj_score, id_inds.cpu()) # (n_query, n_traj)

Expand All @@ -387,6 +397,7 @@ def _run_global_tracker(

query_frame.add_traj_score("traj_score", traj_score_df)
################################################################################
# IOU-based post-processing; add a weighted IOU across successive frames to association scores

# with iou -> combining with location in tracker, they set to True
# todo -> should also work without pos_embed
Expand Down Expand Up @@ -421,11 +432,12 @@ def _run_global_tracker(

query_frame.add_traj_score("weight_iou", iou_traj_score)
################################################################################
# filters association matrix such that instances too far from each other get scores=0

# threshold for continuing a tracking or starting a new track -> they use 1.0
# todo -> should also work without pos_embed
traj_score = post_processing.filter_max_center_dist(
traj_score, self.max_center_dist, query_boxes, nonquery_boxes, id_inds
traj_score, self.max_center_dist, query_boxes, nonquery_boxes, id_inds, h, w
)

if self.max_center_dist is not None and self.max_center_dist > 0:
Expand All @@ -439,6 +451,7 @@ def _run_global_tracker(
query_frame.add_traj_score("max_center_dist", max_center_dist_traj_score)

################################################################################
# softmax along tracks for each instance, for interpretability
scaled_traj_score = torch.softmax(traj_score, dim=1)
scaled_traj_score_df = pd.DataFrame(
scaled_traj_score.numpy(), columns=unique_ids.cpu().numpy()
Expand All @@ -449,6 +462,7 @@ def _run_global_tracker(
query_frame.add_traj_score("scaled", scaled_traj_score_df)
################################################################################

# hungarian matching
match_i, match_j = linear_sum_assignment((-traj_score))

track_ids = instance_ids.new_full((n_query,), -1)
Expand All @@ -462,6 +476,7 @@ def _run_global_tracker(
thresh = (
overlap_thresh * id_inds[:, j].sum() if mult_thresh else overlap_thresh
)
# if the association score for a query instance is lower than the threshold, create a new track for it
if n_traj >= self.max_tracks or traj_score[i, j] > thresh:
logger.debug(
f"Assigning instance {i} to track {j} with id {unique_ids[j]}"
Expand Down
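The `linear_sum_assignment((-traj_score))` call in this diff picks the one-to-one instance-to-track pairing that maximizes total trajectory score (scipy minimizes cost, hence the negation). A brute-force equivalent for tiny matrices makes the objective explicit (names here are illustrative; this assumes at least as many tracks as query instances):

```python
from itertools import permutations

def best_assignment(traj_score):
    """Return (rows, cols) maximizing sum of traj_score[i][cols[i]]
    over all one-to-one column choices."""
    n = len(traj_score)
    best, best_cols = float("-inf"), None
    for cols in permutations(range(len(traj_score[0])), n):
        total = sum(traj_score[i][c] for i, c in enumerate(cols))
        if total > best:
            best, best_cols = total, cols
    return list(range(n)), list(best_cols)

traj_score = [[0.9, 0.1],
              [0.2, 0.8]]
# scipy.optimize.linear_sum_assignment(-traj_score) yields the same pairs
print(best_assignment(traj_score))  # -> ([0, 1], [0, 1])
```

The brute force is O(n!) and only useful for intuition; the Hungarian algorithm behind `linear_sum_assignment` solves the same problem in polynomial time.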