Skip to content

Commit b77086e

Browse files
authored
Merge pull request #82 from alexisthual/feat/parcellations
Parcellations as inputs to FUGWSparse
2 parents bede443 + d4a9910 commit b77086e

File tree

2 files changed

+238
-0
lines changed

2 files changed

+238
-0
lines changed

src/fugw/scripts/piecewise.py

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
import torch
2+
from sklearn.preprocessing import OneHotEncoder
3+
4+
5+
def check_labels(labels: torch.Tensor) -> None:
    """
    Validate that labels form a 1D tensor of integers.

    Parameters
    ----------
    labels: torch.Tensor
        Candidate labels to validate.

    Raises
    ------
    ValueError
        If labels is not a tensor, or is not one-dimensional.
    TypeError
        If the dtype of labels is not one of uint8, int8, int16,
        int32 or int64.
    """
    accepted_dtypes = (
        torch.uint8,
        torch.int8,
        torch.int16,
        torch.int32,
        torch.int64,
    )

    # Guard clauses, ordered from the coarsest check to the finest one
    if not torch.is_tensor(labels):
        raise ValueError(f"labels must be a tensor, got {type(labels)}.")

    n_dims = labels.dim()
    if n_dims != 1:
        raise ValueError(f"labels must be a 1D tensor, got {n_dims}D.")

    if not any(labels.dtype == dtype for dtype in accepted_dtypes):
        raise TypeError(
            f"labels must be an integer tensor, got {labels.dtype}."
        )
33+
34+
35+
def one_hot_encoding(labels: torch.Tensor) -> torch.Tensor:
    """
    Compute a sparse one-hot encoding of the labels.

    Parameters
    ----------
    labels: torch.Tensor of size (n,)
        Cluster labels for each voxel.
        Must be a 1D tensor taking c distinct integer values.

    Returns
    -------
    one_hot: torch.sparse_coo_tensor of size (n, c), dtype float64
        One-hot encoding of the labels, located on the same device
        as labels. Column j corresponds to the j-th smallest distinct
        label (numeric order). An empty input yields a (0, 0) tensor.
    """
    flat_labels = labels.reshape(-1)
    n_samples = flat_labels.shape[0]

    # torch.unique returns the distinct labels sorted in ascending numeric
    # order, together with the index of each sample's label in that sorted
    # list: this index is exactly the one-hot column of the sample.
    unique_labels, columns = torch.unique(flat_labels, return_inverse=True)

    # Build the COO tensor directly (one non-zero entry per row) instead of
    # materialising a dense (n, c) matrix first.
    rows = torch.arange(n_samples, device=flat_labels.device)
    indices = torch.stack((rows, columns))
    values = torch.ones(
        n_samples, dtype=torch.float64, device=flat_labels.device
    )
    one_hot = torch.sparse_coo_tensor(
        indices,
        values,
        size=(n_samples, unique_labels.shape[0]),
    )
    return one_hot.coalesce()
57+
58+
59+
def compute_sparsity_mask(
    labels: torch.Tensor,
    device: str = "auto",
) -> torch.Tensor:
    """
    Compute a sparsity mask from cluster labels.

    Parameters
    ----------
    labels: torch.Tensor of size (n,)
        Cluster labels for each voxel.
        Validated with check_labels before use.
    device: "auto" or torch.device
        Device on which the mask is computed.
        if "auto": use the first gpu if cuda is available,
        cpu otherwise.

    Returns
    -------
    sparsity_mask: torch.sparse_coo_tensor of size (n, n), dtype float32
        Product of the one-hot encoding of labels with its transpose:
        entry (i, j) is non-zero when voxels i and j share a label.
    """
    check_labels(labels)

    # Resolve the "auto" placeholder into a concrete device
    if device == "auto":
        device = (
            torch.device("cuda", 0)
            if torch.cuda.is_available()
            else torch.device("cpu")
        )

    # One row per voxel, one column per distinct label
    encoded = one_hot_encoding(labels.to(device))

    # Rows of voxels sharing a label have a non-zero inner product
    mask = encoded @ encoded.T
    return mask.coalesce().to(torch.float32)

tests/scripts/test_piecewise.py

Lines changed: 148 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,148 @@
1+
from itertools import product
2+
3+
import numpy as np
4+
import pytest
5+
import torch
6+
7+
from fugw.scripts import piecewise
8+
from fugw.mappings import FUGWSparse
9+
from fugw.utils import _init_mock_distribution
10+
11+
# Seed both random number generators used by the tests below
torch.manual_seed(0)
np.random.seed(0)

# Sizes of the mock problem
n_voxels = 100
n_samples_source = 50
n_samples_target = 45
n_features_train = 10
n_features_test = 5
n_pieces = 10

# Always test on cpu; also test on the first gpu when cuda is available
devices = [torch.device("cpu")] + (
    [torch.device("cuda:0")] if torch.cuda.is_available() else []
)

return_numpys = [False, True]
26+
27+
28+
@pytest.mark.skip_if_no_mkl
def test_one_hot_encoding():
    """Check shape, sparsity and content of the one-hot encoding."""
    # Prepend one voxel per label so that all n_pieces labels are
    # guaranteed to be present: the expected number of columns then
    # no longer depends on the random draw.
    labels = torch.cat(
        [
            torch.arange(n_pieces),
            torch.randint(0, n_pieces, (n_voxels - n_pieces,)),
        ]
    )
    one_hot = piecewise.one_hot_encoding(labels)

    assert one_hot.is_sparse
    assert one_hot.shape == (n_voxels, n_pieces)

    dense = one_hot.to_dense()
    # Each voxel is assigned to exactly one column
    assert torch.all(dense.sum(dim=1) == 1)
    # Two voxels share a column if and only if they share a label
    assert torch.equal(
        (dense @ dense.T) > 0,
        labels[:, None] == labels[None, :],
    )
33+
34+
35+
@pytest.mark.skip_if_no_mkl
@pytest.mark.parametrize("device", devices)
def test_compute_sparsity_mask(device):
    """Check the sparsity mask on a tiny example and on random labels."""
    # Tiny example: voxel 0 is alone, voxels 1 and 2 share a label
    expected_dense = torch.tensor(
        [[1.0, 0, 0], [0, 1.0, 1.0], [0, 1.0, 1.0]], device=device
    )
    small_mask = piecewise.compute_sparsity_mask(
        torch.tensor([0, 1, 1], device=device), device=device
    )
    assert small_mask.shape == (3, 3)
    assert small_mask.is_sparse
    assert torch.allclose(small_mask.to_dense(), expected_dense)

    # Random labels with the default device: only the shape is checked
    random_labels = torch.randint(0, n_pieces, (n_voxels,))
    random_mask = piecewise.compute_sparsity_mask(random_labels)
    assert random_mask.shape == (n_voxels, n_voxels)
55+
56+
57+
@pytest.mark.skip_if_no_mkl
@pytest.mark.parametrize(
    "device,return_numpy",
    product(devices, return_numpys),
)
def test_piecewise(device, return_numpy):
    """Fit FUGWSparse from a label-based init plan, then transport features."""
    (
        source_weights,
        source_features,
        _source_geometry,
        source_embeddings,
    ) = _init_mock_distribution(
        n_features_train, n_voxels, return_numpy=return_numpy
    )
    (
        target_weights,
        target_features,
        _target_geometry,
        target_embeddings,
    ) = _init_mock_distribution(
        n_features_train, n_voxels, return_numpy=return_numpy
    )

    # Initial plan derived from random cluster labels
    labels = torch.randint(0, n_pieces, (n_voxels,))
    init_plan = piecewise.compute_sparsity_mask(labels=labels, device=device)

    mapping = FUGWSparse()
    mapping.fit(
        source_features,
        target_features,
        source_geometry_embedding=source_embeddings,
        target_geometry_embedding=target_embeddings,
        source_weights=source_weights,
        target_weights=target_weights,
        init_plan=init_plan,
        device=device,
        verbose=True,
    )

    assert mapping.pi.shape == (n_voxels, n_voxels)

    # Use trained model to transport new features.
    # Each case is (sampler, shape of the features, expected output type):
    # 1. numpy arrays, 2D then 1D
    # 2. torch tensors, 2D then 1D
    cases = (
        (np.random.rand, (n_features_test, n_voxels), np.ndarray),
        (np.random.rand, (n_voxels,), np.ndarray),
        (torch.rand, (n_features_test, n_voxels), torch.Tensor),
        (torch.rand, (n_voxels,), torch.Tensor),
    )
    for sample, shape, expected_type in cases:
        source_features_test = sample(*shape)
        target_features_test = sample(*shape)

        source_features_on_target = mapping.transform(source_features_test)
        assert source_features_on_target.shape == target_features_test.shape
        assert isinstance(source_features_on_target, expected_type)

        target_features_on_source = mapping.inverse_transform(
            target_features_test
        )
        assert target_features_on_source.shape == source_features_test.shape
        assert isinstance(target_features_on_source, expected_type)

0 commit comments

Comments
 (0)