@@ -258,34 +258,56 @@ def _prepare_embedding_metadata(
258258 self ,
259259 anchor_metadata : dict ,
260260 positive_metadata : dict | None = None ,
261+ negative_metadata : dict | None = None ,
261262 include_positive : bool = False ,
263+ include_negative : bool = False ,
262264 ) -> tuple [list [list [str ]], list [str ]]:
263265 """Prepare metadata for embedding visualization.
264266
265267 Args:
266268 anchor_metadata: Metadata for anchor samples
267269 positive_metadata: Metadata for positive samples (optional)
270+ negative_metadata: Metadata for negative samples (optional)
268271 include_positive: Whether to include positive samples in metadata
272+ include_negative: Whether to include negative samples in metadata
269273
270274 Returns:
271275 tuple containing:
272276 - metadata: List of lists containing metadata values
273277 - metadata_header: List of metadata field names
274278 """
275- # NOTE: temporarily hardcoded to the following:
276- metadata_header = ["fov_name" , "t" , "id" ]
279+ metadata_header = ["fov_name" , "t" , "id" , "type" ]
280+
281+ def process_field (x , field ):
282+ if field in ["t" , "id" ] and isinstance (x , torch .Tensor ):
283+ return str (x .detach ().cpu ().item ())
284+ return str (x )
277285
278286 # Create lists for each metadata field
279287 metadata = [
280288 [str (x ) for x in anchor_metadata ["fov_name" ]],
281- [str (x .detach ().cpu ().numpy ()) for x in anchor_metadata ["t" ]],
282- [str (x .detach ().cpu ().numpy ()) for x in anchor_metadata ["id" ]],
289+ [process_field (x , "t" ) for x in anchor_metadata ["t" ]],
290+ [process_field (x , "id" ) for x in anchor_metadata ["id" ]],
291+ ["anchor" ] * len (anchor_metadata ["fov_name" ]), # type field for anchors
283292 ]
284293
285294 # If including positive samples, extend metadata
286295 if include_positive and positive_metadata is not None :
287- for i , field in enumerate (metadata_header ):
288- metadata [i ].extend ([str (x ) for x in positive_metadata [field ]])
296+ for i , field in enumerate (metadata_header [:- 1 ]): # Exclude 'type' field
297+ metadata [i ].extend (
298+ [process_field (x , field ) for x in positive_metadata [field ]]
299+ )
300+ # Add 'positive' type for positive samples
301+ metadata [- 1 ].extend (["positive" ] * len (positive_metadata ["fov_name" ]))
302+
303+ # If including negative samples, extend metadata
304+ if include_negative and negative_metadata is not None :
305+ for i , field in enumerate (metadata_header [:- 1 ]): # Exclude 'type' field
306+ metadata [i ].extend (
307+ [process_field (x , field ) for x in negative_metadata [field ]]
308+ )
309+ # Add 'negative' type for negative samples
310+ metadata [- 1 ].extend (["negative" ] * len (negative_metadata ["fov_name" ]))
289311
290312 return metadata , metadata_header
291313
@@ -308,10 +330,11 @@ def validation_step(self, batch: TripletSample, batch_idx: int) -> Tensor:
308330
309331 # Store embeddings for visualization
310332 if self .current_epoch % self .embedding_log_interval == 0 and batch_idx == 0 :
333+ # Must include positive samples since we're concatenating embeddings
311334 metadata , metadata_header = self ._prepare_embedding_metadata (
312335 batch ["anchor_metadata" ],
313336 batch ["positive_metadata" ],
314- include_positive = True ,
337+ include_positive = True , # Required since we concatenate embeddings
315338 )
316339 self .val_embedding_outputs = {
317340 "embeddings" : embeddings .detach (),
@@ -330,11 +353,17 @@ def validation_step(self, batch: TripletSample, batch_idx: int) -> Tensor:
330353 # Store embeddings for visualization
331354 if self .current_epoch % self .embedding_log_interval == 0 and batch_idx == 0 :
332355 metadata , metadata_header = self ._prepare_embedding_metadata (
333- batch ["anchor_metadata" ], include_positive = False
356+ batch ["anchor_metadata" ],
357+ batch ["positive_metadata" ],
358+ batch ["negative_metadata" ],
359+ include_positive = True , # Required since we concatenate embeddings
360+ include_negative = True ,
334361 )
335362 self .val_embedding_outputs = {
336- "embeddings" : anchor_projection .detach (),
337- "images" : anchor .detach (),
363+ "embeddings" : torch .cat (
364+ (anchor_projection , positive_projection , negative_projection )
365+ ).detach (),
366+ "images" : torch .cat ((anchor , pos_img , neg_img )).detach (),
338367 "metadata" : list (zip (* metadata )),
339368 "metadata_header" : metadata_header ,
340369 }
0 commit comments