Commit 3111413

Merge pull request #13 from alexisthual/feat/efficient_sparsity_mask_computatin
Compute coarse_to_fine sparsity mask from matrix product
2 parents c605800 + 8902523 commit 3111413
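
The diff below only touches the example scripts and the public helper they call; the sparsity-mask computation itself lives in fugw's coarse_to_fine module and is not shown here. As a rough, hypothetical sketch of the idea in the title (all names and shapes are illustrative assumptions, not fugw's API), sparse indicator matrices of sampled-vertex neighbourhoods let the fine-grained mask be obtained with a single matrix product instead of per-vertex loops:

import torch

# Illustrative sketch only, not fugw's implementation.
# N_source[v, s] == 1 when fine source vertex v lies within the selection
# radius of sampled source vertex s (same idea for N_target on the target mesh).
N_source = torch.tensor([[1, 0], [1, 0], [0, 1], [0, 1], [0, 0]]).float()
N_target = torch.tensor([[1, 0], [0, 1], [0, 1], [0, 0]]).float()

# P[s, t] == 1 when the coarse pair (s, t) is kept
# (e.g. a top-k entry of the coarse transport plan).
P = torch.tensor([[1, 0], [0, 1]]).float()

# One matrix product expands the kept coarse pairs into every fine-grained
# (i, j) pair covered by their neighbourhoods.
fine_mask = (N_source @ P @ N_target.T) > 0
print(fine_mask.int())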

5 files changed: +173, -136 lines

doc/themes/custom.css

Lines changed: 3 additions & 3 deletions
@@ -51,7 +51,7 @@ div[class^="highlight-"].sphx-glr-script-out {
 
 .sphx-glr-script-out .highlight {
   border-radius: 0px;
-  border-left: 4px solid var(--sd-color-primary);
+  border-left: 4px solid var(--color-background-border);
 }
 
 .sphx-glr-script-out .highlight pre {
@@ -74,11 +74,11 @@ button.copybtn {
 }
 
 .highlight-default .highlight {
-  border-left: 4px solid var(--color-background-border);
+  border-left: 4px solid var(--sd-color-primary);
 }
 
 .highlight-primary .highlight {
-  border-left: 4px solid var(--sd-color-primary);
+  border-left: 4px solid var(--color-background-border);
 }
 
 table.plotting-table {

examples/00_basics/plot_2_2_coarse_to_fine.py

Lines changed: 7 additions & 3 deletions
@@ -111,8 +111,8 @@
     # Parametrize step 2 (selection of pairs of indices present in
     # fine-grained's sparsity mask)
     coarse_pairs_selection_method="topk",
-    source_selection_radius=1 / source_d_max,
-    target_selection_radius=1 / target_d_max,
+    source_selection_radius=0.5 / source_d_max,
+    target_selection_radius=0.5 / target_d_max,
     # Parametrize step 3 (fine-grained alignment)
     fine_mapping=fine_mapping,
     fine_mapping_solver=fine_mapping_solver,
@@ -216,7 +216,11 @@
 ).permute(2, 0, 1)
 pi_normalized = pi / pi.sum(dim=1).reshape(-1, 1)
 line_segments = LineCollection(
-    segments, alpha=pi_normalized.flatten(), colors="black", lw=1, zorder=1
+    segments,
+    alpha=pi_normalized.flatten().nan_to_num(),
+    colors="black",
+    lw=1,
+    zorder=1,
 )
 ax.add_collection(line_segments)
 
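For context on the nan_to_num() change above: with a sparse coarse-to-fine plan, entire rows of pi can sum to zero, so dividing by the row sums produces NaN entries that Matplotlib rejects as alpha values. A minimal standalone sketch (hypothetical values, not part of the diff):

import torch

# A row of the plan that carries no mass normalizes to 0 / 0 = NaN.
pi = torch.tensor([[0.2, 0.8], [0.0, 0.0]])
pi_normalized = pi / pi.sum(dim=1).reshape(-1, 1)
print(pi_normalized)                          # second row is [nan, nan]
print(pi_normalized.flatten().nan_to_num())   # NaNs replaced with 0.0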

examples/01_brain_alignment/plot_1_aligning_brain_dense.py

Lines changed: 3 additions & 3 deletions
@@ -1,8 +1,8 @@
 # %%
 """
-====================================================
-Align brain surfaces of 2 individuals with fMRI data
-====================================================
+===================================================================
+Align low-resolution brain surfaces of 2 individuals with fMRI data
+===================================================================
 
 In this example, we align 2 low-resolution left hemispheres
 using 4 fMRI feature maps (z-score contrast maps).

examples/01_brain_alignment/plot_2_aligning_brain_sparse.py

Lines changed: 24 additions & 54 deletions
@@ -72,18 +72,6 @@
 source_imgs_paths = brain_data["cmaps"][0 : len(contrasts)]
 target_imgs_paths = brain_data["cmaps"][len(contrasts) : 2 * len(contrasts)]
 
-# %%
-# Here is what the first contrast map of the source subject looks like
-# (the following figure is interactive):
-
-contrast_index = 0
-plotting.view_img(
-    source_imgs_paths[contrast_index],
-    brain_data["anats"][0],
-    title=f"Contrast {contrast_index} (source subject)",
-    opacity=0.5,
-)
-
 # %%
 # Computing feature arrays
 # ------------------------
@@ -358,50 +346,28 @@ def plot_surface_map(
 # derive a scalable way to derive which vertices are selected.
 
 
-def vertices_in_sparcity_mask(embeddings, sample, radius):
-    n_vertices = embeddings.shape[0]
-    elligible_vertices = [
-        torch.tensor(
-            np.argwhere(
-                np.linalg.norm(
-                    embeddings - embeddings[i],
-                    axis=1,
-                )
-                <= radius
-            )
-        )
-        for i in sample
-    ]
-
-    rows = torch.concat(
-        [
-            i * torch.ones(len(elligible_vertices[i]))
-            for i in range(len(elligible_vertices))
-        ]
-    ).type(torch.int)
-    cols = torch.concat(elligible_vertices).flatten().type(torch.int)
-    values = torch.ones_like(rows)
-
-    vertex_within_radius = torch.sparse_coo_tensor(
-        torch.stack([rows, cols]),
-        values,
-        size=(n_vertices, n_vertices),
-    )
-    selected_vertices = torch.sparse.sum(
-        vertex_within_radius, dim=0
-    ).to_dense()
-
-    return selected_vertices
-
-
 source_selection_radius = 7
-selected_source_vertices = vertices_in_sparcity_mask(
-    source_geometry_embeddings, source_sample, source_selection_radius
+n_neighbourhoods_per_vertex_source = (
+    torch.sparse.sum(
+        coarse_to_fine.get_neighbourhood_matrix(
+            source_geometry_embeddings, source_sample, source_selection_radius
+        ),
+        dim=1,
+    )
+    .to_dense()
+    .numpy()
 )
 
 target_selection_radius = 7
-selected_target_vertices = vertices_in_sparcity_mask(
-    target_geometry_embeddings, target_sample, target_selection_radius
+n_neighbourhoods_per_vertex_target = (
+    torch.sparse.sum(
+        coarse_to_fine.get_neighbourhood_matrix(
+            target_geometry_embeddings, target_sample, target_selection_radius
+        ),
+        dim=1,
+    )
+    .to_dense()
+    .numpy()
 )
 
 # %%
@@ -410,12 +376,16 @@ def vertices_in_sparcity_mask(embeddings, sample, radius):
 # sampled, light blue for vertices which are within radius-distance
 # of a sampled vertex. Vertices which won't be selected appear in white.
 # The following figure is interactive.
+# **Note that, because embeddings are not very precise for short distances,
+# vertices that are very close to sampled vertices can actually
+# be absent from the mask**. In order to limit this effect, the radius
+# should generally be set to a high enough value.
 
 source_vertices_in_mask = np.zeros(source_features.shape[1])
-source_vertices_in_mask[selected_source_vertices > 0] = 1
+source_vertices_in_mask[n_neighbourhoods_per_vertex_source > 0] = 1
 source_vertices_in_mask[source_sample] = 2
 target_vertices_in_mask = np.zeros(target_features.shape[1])
-target_vertices_in_mask[selected_target_vertices > 0] = 1
+target_vertices_in_mask[n_neighbourhoods_per_vertex_target > 0] = 1
 target_vertices_in_mask[target_sample] = 2
 
 # Generate figure with 2 subplots
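
For context on the snippet introduced above: a plausible reading (the internals of coarse_to_fine.get_neighbourhood_matrix are not part of this diff) is that it returns a sparse indicator matrix with one row per vertex and one column per sampled vertex, so a sparse sum over dim=1 counts how many sampled neighbourhoods each vertex falls into; a count of 0 means the vertex stays out of the mask. A toy, self-contained illustration with made-up 1D embeddings:

import torch

# Made-up embeddings and sample, for illustration only.
embeddings = torch.tensor([[0.0], [1.0], [3.0], [6.0]])
sample = torch.tensor([0, 3])
radius = 1.5

# (n_vertices, n_samples) distances between every vertex and the sampled ones.
dists = torch.cdist(embeddings, embeddings[sample])

# Sparse indicator matrix: entry (v, s) is 1 when vertex v lies within
# `radius` of sampled vertex s.
neighbourhood = (dists <= radius).float().to_sparse()

# Summing over the sample dimension counts how many sampled neighbourhoods
# each vertex belongs to; vertex 2 gets a count of 0 and is outside the mask.
counts = torch.sparse.sum(neighbourhood, dim=1).to_dense()
print(counts)  # tensor([1., 1., 0., 1.])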
