7272source_imgs_paths = brain_data ["cmaps" ][0 : len (contrasts )]
7373target_imgs_paths = brain_data ["cmaps" ][len (contrasts ) : 2 * len (contrasts )]
7474
75- # %%
76- # Here is what the first contrast map of the source subject looks like
77- # (the following figure is interactive):
78-
79- contrast_index = 0
80- plotting .view_img (
81- source_imgs_paths [contrast_index ],
82- brain_data ["anats" ][0 ],
83- title = f"Contrast { contrast_index } (source subject)" ,
84- opacity = 0.5 ,
85- )
86-
8775# %%
8876# Computing feature arrays
8977# ------------------------
@@ -358,50 +346,28 @@ def plot_surface_map(
358346# derive a scalable way to derive which vertices are selected.
359347
360348
361- def vertices_in_sparcity_mask (embeddings , sample , radius ):
362- n_vertices = embeddings .shape [0 ]
363- elligible_vertices = [
364- torch .tensor (
365- np .argwhere (
366- np .linalg .norm (
367- embeddings - embeddings [i ],
368- axis = 1 ,
369- )
370- <= radius
371- )
372- )
373- for i in sample
374- ]
375-
376- rows = torch .concat (
377- [
378- i * torch .ones (len (elligible_vertices [i ]))
379- for i in range (len (elligible_vertices ))
380- ]
381- ).type (torch .int )
382- cols = torch .concat (elligible_vertices ).flatten ().type (torch .int )
383- values = torch .ones_like (rows )
384-
385- vertex_within_radius = torch .sparse_coo_tensor (
386- torch .stack ([rows , cols ]),
387- values ,
388- size = (n_vertices , n_vertices ),
389- )
390- selected_vertices = torch .sparse .sum (
391- vertex_within_radius , dim = 0
392- ).to_dense ()
393-
394- return selected_vertices
395-
396-
397349source_selection_radius = 7
398- selected_source_vertices = vertices_in_sparcity_mask (
399- source_geometry_embeddings , source_sample , source_selection_radius
350+ n_neighbourhoods_per_vertex_source = (
351+ torch .sparse .sum (
352+ coarse_to_fine .get_neighbourhood_matrix (
353+ source_geometry_embeddings , source_sample , source_selection_radius
354+ ),
355+ dim = 1 ,
356+ )
357+ .to_dense ()
358+ .numpy ()
400359)
401360
402361target_selection_radius = 7
403- selected_target_vertices = vertices_in_sparcity_mask (
404- target_geometry_embeddings , target_sample , target_selection_radius
362+ n_neighbourhoods_per_vertex_target = (
363+ torch .sparse .sum (
364+ coarse_to_fine .get_neighbourhood_matrix (
365+ target_geometry_embeddings , target_sample , target_selection_radius
366+ ),
367+ dim = 1 ,
368+ )
369+ .to_dense ()
370+ .numpy ()
405371)
406372
407373# %%
@@ -410,12 +376,16 @@ def vertices_in_sparcity_mask(embeddings, sample, radius):
410376# sampled, light blue for vertices which are within radius-distance
411377# of a sampled vertex. Vertices which won't be selected appear in white.
412378# The following figure is interactive.
379+ # **Note that, because embeddings are not very precise for short distances,
380+ # vertices that are very close to sampled vertices can actually
381+ # be absent from the mask**. In order to limit this effect, the radius
382+ # should generally be set to a high enough value.
413383
414384source_vertices_in_mask = np .zeros (source_features .shape [1 ])
415- source_vertices_in_mask [selected_source_vertices > 0 ] = 1
385+ source_vertices_in_mask [n_neighbourhoods_per_vertex_source > 0 ] = 1
416386source_vertices_in_mask [source_sample ] = 2
417387target_vertices_in_mask = np .zeros (target_features .shape [1 ])
418- target_vertices_in_mask [selected_target_vertices > 0 ] = 1
388+ target_vertices_in_mask [n_neighbourhoods_per_vertex_target > 0 ] = 1
419389target_vertices_in_mask [target_sample ] = 2
420390
421391# Generate figure with 2 subplots
0 commit comments