
Commit 9b66c70

adding labels for embedding type (i.e., anchor, positive, or negative)
1 parent ecf08c6 commit 9b66c70

File tree

1 file changed (+39 −10 lines changed)

viscy/representation/engine.py

Lines changed: 39 additions & 10 deletions
@@ -258,34 +258,56 @@ def _prepare_embedding_metadata(
         self,
         anchor_metadata: dict,
         positive_metadata: dict | None = None,
+        negative_metadata: dict | None = None,
         include_positive: bool = False,
+        include_negative: bool = False,
     ) -> tuple[list[list[str]], list[str]]:
         """Prepare metadata for embedding visualization.

         Args:
             anchor_metadata: Metadata for anchor samples
             positive_metadata: Metadata for positive samples (optional)
+            negative_metadata: Metadata for negative samples (optional)
             include_positive: Whether to include positive samples in metadata
+            include_negative: Whether to include negative samples in metadata

         Returns:
             tuple containing:
                 - metadata: List of lists containing metadata values
                 - metadata_header: List of metadata field names
         """
-        # NOTE: temporarily hardcoded to the following:
-        metadata_header = ["fov_name", "t", "id"]
+        metadata_header = ["fov_name", "t", "id", "type"]
+
+        def process_field(x, field):
+            if field in ["t", "id"] and isinstance(x, torch.Tensor):
+                return str(x.detach().cpu().item())
+            return str(x)

         # Create lists for each metadata field
         metadata = [
             [str(x) for x in anchor_metadata["fov_name"]],
-            [str(x.detach().cpu().numpy()) for x in anchor_metadata["t"]],
-            [str(x.detach().cpu().numpy()) for x in anchor_metadata["id"]],
+            [process_field(x, "t") for x in anchor_metadata["t"]],
+            [process_field(x, "id") for x in anchor_metadata["id"]],
+            ["anchor"] * len(anchor_metadata["fov_name"]),  # type field for anchors
         ]

         # If including positive samples, extend metadata
         if include_positive and positive_metadata is not None:
-            for i, field in enumerate(metadata_header):
-                metadata[i].extend([str(x) for x in positive_metadata[field]])
+            for i, field in enumerate(metadata_header[:-1]):  # Exclude 'type' field
+                metadata[i].extend(
+                    [process_field(x, field) for x in positive_metadata[field]]
+                )
+            # Add 'positive' type for positive samples
+            metadata[-1].extend(["positive"] * len(positive_metadata["fov_name"]))
+
+        # If including negative samples, extend metadata
+        if include_negative and negative_metadata is not None:
+            for i, field in enumerate(metadata_header[:-1]):  # Exclude 'type' field
+                metadata[i].extend(
+                    [process_field(x, field) for x in negative_metadata[field]]
+                )
+            # Add 'negative' type for negative samples
+            metadata[-1].extend(["negative"] * len(negative_metadata["fov_name"]))

         return metadata, metadata_header
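To make the new layout concrete, here is a minimal standalone sketch (toy data and values, not from the repository) that mirrors the updated logic and shows the per-sample rows once the 'type' column is appended:

# Sketch only: toy metadata mimicking the updated _prepare_embedding_metadata.
import torch

anchor_metadata = {
    "fov_name": ["A/1", "A/2"],
    "t": [torch.tensor(0), torch.tensor(1)],
    "id": [torch.tensor(7), torch.tensor(8)],
}
positive_metadata = {
    "fov_name": ["B/1", "B/2"],
    "t": [torch.tensor(0), torch.tensor(1)],
    "id": [torch.tensor(9), torch.tensor(10)],
}

metadata_header = ["fov_name", "t", "id", "type"]

def process_field(x, field):
    # Scalar tensors are unwrapped with .item() so rows read "0", not "tensor(0)"
    if field in ["t", "id"] and isinstance(x, torch.Tensor):
        return str(x.detach().cpu().item())
    return str(x)

metadata = [
    [str(x) for x in anchor_metadata["fov_name"]],
    [process_field(x, "t") for x in anchor_metadata["t"]],
    [process_field(x, "id") for x in anchor_metadata["id"]],
    ["anchor"] * len(anchor_metadata["fov_name"]),
]
for i, field in enumerate(metadata_header[:-1]):  # 'type' is filled separately
    metadata[i].extend(process_field(x, field) for x in positive_metadata[field])
metadata[-1].extend(["positive"] * len(positive_metadata["fov_name"]))

print(list(zip(*metadata)))
# [('A/1', '0', '7', 'anchor'), ('A/2', '1', '8', 'anchor'),
#  ('B/1', '0', '9', 'positive'), ('B/2', '1', '10', 'positive')]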

@@ -308,10 +330,11 @@ def validation_step(self, batch: TripletSample, batch_idx: int) -> Tensor:

         # Store embeddings for visualization
         if self.current_epoch % self.embedding_log_interval == 0 and batch_idx == 0:
+            # Must include positive samples since we're concatenating embeddings
             metadata, metadata_header = self._prepare_embedding_metadata(
                 batch["anchor_metadata"],
                 batch["positive_metadata"],
-                include_positive=True,
+                include_positive=True,  # Required since we concatenate embeddings
             )
             self.val_embedding_outputs = {
                 "embeddings": embeddings.detach(),
@@ -330,11 +353,17 @@ def validation_step(self, batch: TripletSample, batch_idx: int) -> Tensor:
         # Store embeddings for visualization
         if self.current_epoch % self.embedding_log_interval == 0 and batch_idx == 0:
             metadata, metadata_header = self._prepare_embedding_metadata(
-                batch["anchor_metadata"], include_positive=False
+                batch["anchor_metadata"],
+                batch["positive_metadata"],
+                batch["negative_metadata"],
+                include_positive=True,  # Required since we concatenate embeddings
+                include_negative=True,
             )
             self.val_embedding_outputs = {
-                "embeddings": anchor_projection.detach(),
-                "images": anchor.detach(),
+                "embeddings": torch.cat(
+                    (anchor_projection, positive_projection, negative_projection)
+                ).detach(),
+                "images": torch.cat((anchor, pos_img, neg_img)).detach(),
                 "metadata": list(zip(*metadata)),
                 "metadata_header": metadata_header,
             }
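The commit only stores val_embedding_outputs; how it is logged is not shown in this diff. One plausible consumer (an assumption on my part, not confirmed by this commit) is TensorBoard's embedding projector, whose add_embedding call accepts exactly this metadata/metadata_header/label_img layout:

# Sketch only: assumed downstream logging of the stored outputs with
# torch.utils.tensorboard; viscy's actual callback may differ.
import torch
from torch.utils.tensorboard import SummaryWriter

outputs = {
    "embeddings": torch.randn(6, 32),      # anchor + positive + negative
    "images": torch.randn(6, 1, 28, 28),   # hypothetical NCHW thumbnails
    "metadata": [
        ("A/1", "0", "7", "anchor"), ("A/2", "1", "8", "anchor"),
        ("B/1", "0", "9", "positive"), ("B/2", "1", "10", "positive"),
        ("C/1", "0", "11", "negative"), ("C/2", "1", "12", "negative"),
    ],
    "metadata_header": ["fov_name", "t", "id", "type"],
}

writer = SummaryWriter()
writer.add_embedding(
    outputs["embeddings"],
    metadata=outputs["metadata"],
    metadata_header=outputs["metadata_header"],
    label_img=outputs["images"],
    tag="val_embeddings",
)
writer.close()

With the new 'type' column, the projector can color each point as anchor, positive, or negative, which is the point of this commit.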
