Skip to content

Commit cbea5b3

Browse files
committed
fix bugs in GP interpolation
1 parent 7ab5bee commit cbea5b3

File tree

3 files changed

+66
-34
lines changed

3 files changed

+66
-34
lines changed

spateo/alignment/methods/morpho_class.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1508,8 +1508,8 @@ def _wrap_output(
15081508
"t": self.nx.to_numpy(self.t),
15091509
"optimal_R": self.nx.to_numpy(self.optimal_R),
15101510
"optimal_t": self.nx.to_numpy(self.optimal_t),
1511-
"init_R": self.nx.to_numpy(self.init_R) if self.nn_init else np.eye(self.Dim),
1512-
"init_t": self.nx.to_numpy(self.init_t) if self.nn_init else np.zeros(self.Dim),
1511+
"init_R": self.nx.to_numpy(self.init_R) if self.nn_init else np.eye(self.XAHat.shape[1]),
1512+
"init_t": self.nx.to_numpy(self.init_t) if self.nn_init else np.zeros(self.XAHat.shape[1]),
15131513
"beta": self.beta,
15141514
"Coff": self.nx.to_numpy(self.Coff),
15151515
"inducing_variables": self.nx.to_numpy(self.inducing_variables),

spateo/tdr/interpolations/interpolation_gaussianprocess/gp_train.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,10 @@
22
import torch
33
from tqdm import tqdm
44

5+
from spateo.alignment.utils import _iteration
56

6-
def gp_train(model, likelihood, train_loader, train_epochs, method, N, device):
7+
8+
def gp_train(model, likelihood, train_loader, train_epochs, method, N, device, keys, verbose=True):
79
if torch.cuda.is_available() and device != "cpu":
810
model = model.cuda()
911
likelihood = likelihood.cuda()
@@ -24,16 +26,15 @@ def gp_train(model, likelihood, train_loader, train_epochs, method, N, device):
2426
lr=0.01,
2527
)
2628

27-
epochs_iter = tqdm(range(train_epochs), desc="Epoch")
29+
progress_name = f"Interpolation based on Gaussian Process Regression for {keys[0]}"
30+
epochs_iter = _iteration(n=train_epochs, progress_name=progress_name, verbose=verbose)
2831
for i in epochs_iter:
2932
if method == "SVGP":
3033
# Within each iteration, we will go over each minibatch of data
31-
minibatch_iter = tqdm(train_loader, desc="Minibatch", leave=True)
32-
for x_batch, y_batch in minibatch_iter:
34+
for x_batch, y_batch in train_loader:
3335
optimizer.zero_grad()
3436
output = model(x_batch)
3537
loss = -mll(output, y_batch)
36-
minibatch_iter.set_postfix(loss=loss.item())
3738
loss.backward()
3839
optimizer.step()
3940
else:

spateo/tdr/interpolations/interpolation_gp.py

Lines changed: 58 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,8 @@ def __init__(
3535
inducing_num: int = 512,
3636
normalize_spatial: bool = True,
3737
):
38+
39+
self.keys = keys
3840
# Source data
3941
source_adata = source_adata.copy()
4042
source_adata.X = source_adata.X if layer == "X" else source_adata.layers[layer]
@@ -70,7 +72,6 @@ def __init__(
7072
self.train_y = self.train_y.squeeze()
7173

7274
self.nx = ot.backend.get_backend(self.train_x, self.train_y)
73-
7475
self.normalize_spatial = normalize_spatial
7576
if self.normalize_spatial:
7677
self.train_x = self.normalize_coords(self.train_x)
@@ -96,15 +97,14 @@ def __init__(
9697

9798
self.PCA_reduction = False
9899
self.info_keys = {"obs_keys": obs_keys, "var_keys": var_keys}
99-
print(self.info_keys)
100100

101101
# Target data
102102
self.target_points = torch.from_numpy(target_points).float()
103103
self.target_points = self.target_points.cpu() if self.device == "cpu" else self.target_points.cuda()
104104

105105
def normalize_coords(self, data: Union[np.ndarray, torch.Tensor], given_normalize: bool = False):
106106
if not given_normalize:
107-
self.mean_data = _unsqueeze(self.nx)(self.nx.mean(data, axis=0), 0)
107+
self.mean_data = self.nx.mean(data, axis=0)[None, :]
108108
data = data - self.mean_data
109109
if not given_normalize:
110110
self.variance = self.nx.sqrt(self.nx.sum(data**2) / data.shape[0])
@@ -114,12 +114,16 @@ def normalize_coords(self, data: Union[np.ndarray, torch.Tensor], given_normaliz
114114
def inference(
115115
self,
116116
training_iter: int = 50,
117+
verbose: bool = True,
117118
):
118119
self.likelihood = GaussianLikelihood()
119120
if self.method == "SVGP":
120121
self.GPR_model = Approx_GPModel(inducing_points=self.inducing_points)
121122
elif self.method == "ExactGP":
122-
self.GPR_model = Exact_GPModel(self.train_x, self.train_y, self.likelihood)
123+
self.GPR_models = [
124+
Exact_GPModel(self.train_x, self.train_y[:, i], self.likelihoods[i])
125+
for i in range(self.train_y.shape[1])
126+
]
123127
# if to convert to GPU
124128
if self.device != "cpu":
125129
self.GPR_model = self.GPR_model.cuda()
@@ -134,6 +138,8 @@ def inference(
134138
method=self.method,
135139
N=self.N,
136140
device=self.device,
141+
verbose=verbose,
142+
keys=self.keys,
137143
)
138144

139145
self.GPR_model.eval()
@@ -181,6 +187,7 @@ def gp_interpolation(
181187
batch_size: int = 1024,
182188
shuffle: bool = True,
183189
inducing_num: int = 512,
190+
verbose: bool = True,
184191
) -> AnnData:
185192
"""
186193
Learn a continuous mapping from space to gene expression pattern with the Gaussian Process method.
@@ -197,36 +204,60 @@ def gp_interpolation(
197204
Returns:
198205
interp_adata: an anndata object that has interpolated expression.
199206
"""
207+
assert keys != None, "`keys` cannot be None."
208+
keys = [keys] if isinstance(keys, str) else keys
209+
obs_keys = [key for key in keys if key in source_adata.obs.keys()]
210+
var_keys = [key for key in keys if key in source_adata.var_names.tolist()]
211+
info_keys = {"obs_keys": obs_keys, "var_keys": var_keys}
212+
print(info_keys)
213+
obs_data = []
214+
var_data = []
215+
if len(obs_keys) != 0:
216+
for key in obs_keys:
217+
GPR = Imputation_GPR(
218+
source_adata=source_adata,
219+
target_points=target_points,
220+
keys=[key],
221+
spatial_key=spatial_key,
222+
layer=layer,
223+
device=device,
224+
method=method,
225+
batch_size=batch_size,
226+
shuffle=shuffle,
227+
inducing_num=inducing_num,
228+
)
229+
GPR.inference(training_iter=training_iter, verbose=verbose)
200230

201-
# Inference
202-
GPR = Imputation_GPR(
203-
source_adata=source_adata,
204-
target_points=target_points,
205-
keys=keys,
206-
spatial_key=spatial_key,
207-
layer=layer,
208-
device=device,
209-
method=method,
210-
batch_size=batch_size,
211-
shuffle=shuffle,
212-
inducing_num=inducing_num,
213-
)
214-
GPR.inference(training_iter=training_iter)
231+
# Interpolation
232+
target_info_data = GPR.interpolate(use_chunk=True)
233+
obs_data.append(target_info_data[:, None])
234+
if len(var_keys) != 0:
235+
for key in var_keys:
236+
GPR = Imputation_GPR(
237+
source_adata=source_adata,
238+
target_points=target_points,
239+
keys=[key],
240+
spatial_key=spatial_key,
241+
layer=layer,
242+
device=device,
243+
method=method,
244+
batch_size=batch_size,
245+
shuffle=shuffle,
246+
inducing_num=inducing_num,
247+
)
248+
GPR.inference(training_iter=training_iter, verbose=verbose)
249+
250+
# Interpolation
251+
target_info_data = GPR.interpolate(use_chunk=True)
252+
var_data.append(target_info_data[:, None])
215253

216-
# Interpolation
217-
target_info_data = GPR.interpolate(use_chunk=True)
218-
target_info_data = target_info_data[:, None]
219254
# Output interpolated anndata
220255
lm.main_info("Creating an adata object with the interpolated expression...")
221-
222-
obs_keys = GPR.info_keys["obs_keys"]
223256
if len(obs_keys) != 0:
224-
obs_data = target_info_data[:, : len(obs_keys)]
257+
obs_data = np.concatenate(obs_data, axis=1)
225258
obs_data = pd.DataFrame(obs_data, columns=obs_keys)
226-
227-
var_keys = GPR.info_keys["var_keys"]
228259
if len(var_keys) != 0:
229-
X = target_info_data[:, len(obs_keys) :]
260+
X = np.concatenate(var_data, axis=1)
230261
var_data = pd.DataFrame(index=var_keys)
231262

232263
interp_adata = AnnData(

0 commit comments

Comments (0)