Commit 20ed3a7

update the comments in external/MERFISHVI
1 parent 6e5610e commit 20ed3a7

File tree

1 file changed, +106 -103 lines


spateo/external/MERFISHVI/_module.py

Lines changed: 106 additions & 103 deletions
@@ -285,16 +285,16 @@ def _get_inference_input(
         tensors: dict[str, torch.Tensor | None],
         full_forward_pass: bool = False,
     ) -> dict[str, torch.Tensor | None]:
-        """获取推断过程所需的输入张量。
+        """Get input tensors required for the inference process.

         Args:
-            tensors: 输入数据张量
-            full_forward_pass: 是否执行完整的前向传播
+            tensors: Input data tensors
+            full_forward_pass: Whether to execute a full forward pass

         Returns:
-            用于推断过程的输入字典
+            Input dictionary for the inference process
         """
-        # 根据数据类型选择加载方式
+        # Choose loading method based on data type
         if full_forward_pass or self.minified_data_type is None:
             loader = "full_data"
         elif self.minified_data_type in [
@@ -303,83 +303,85 @@ def _get_inference_input(
         ]:
             loader = "minified_data"
         else:
-            raise NotImplementedError(f"未知的简化数据类型: {self.minified_data_type}")
+            raise NotImplementedError(f"Unknown minified data type: {self.minified_data_type}")

-        # 完整数据情况:提供表达数据和批次信息
+        # Full data case: provide expression data and batch information
         if loader == "full_data":
             return {
-                "x": tensors["X"],  # 基因表达数据
-                "batch_index": tensors["batch"],  # 批次索引
-                "cont_covariates": tensors.get("continuous_covariates", None),  # 连续型协变量
-                "cat_covariates": tensors.get("categorical_covariates", None),  # 分类型协变量
+                "x": tensors["X"],  # Gene expression data
+                "batch_index": tensors["batch"],  # Batch indices
+                "cont_covariates": tensors.get("continuous_covariates", None),  # Continuous covariates
+                "cat_covariates": tensors.get("categorical_covariates", None),  # Categorical covariates
             }
-        # 简化数据情况:提供已计算的潜在变量分布参数
+        # Simplified data case: provide pre-computed latent variable distribution parameters
         else:
             return {
-                "qzm": tensors["scvi_latent_qzm"],  # 潜变量均值
-                "qzv": tensors["scvi_latent_qzv"],  # 潜变量方差
-                "observed_lib_size": tensors["observed_lib_size"],  # 观测到的文库大小
+                "qzm": tensors["scvi_latent_qzm"],  # Latent variable means
+                "qzv": tensors["scvi_latent_qzv"],  # Latent variable variances
+                "observed_lib_size": tensors["observed_lib_size"],  # Observed library size
             }

     def _get_generative_input(
         self,
         tensors: dict[str, torch.Tensor],
         inference_outputs: dict[str, torch.Tensor | Distribution | None],
     ) -> dict[str, torch.Tensor | None]:
-        """获取生成过程的输入张量。
+        """Get input tensors for the generative process.

-        将推断步骤的输出与原始数据结合,准备生成网络的输入。
+        Combine outputs from the inference step with original data to prepare
+        inputs for the generative network.

         Args:
-            tensors: 原始数据张量
-            inference_outputs: 推断过程的输出
+            tensors: Original data tensors
+            inference_outputs: Outputs from the inference process

         Returns:
-            生成过程所需的输入字典
+            Input dictionary required for the generative process
         """
-        # 获取size_factor(如果提供)
+        # Get size_factor (if provided)
         size_factor = tensors.get("size_factor", None)
         if size_factor is not None:
             size_factor = torch.log(size_factor)

         return {
-            "z": inference_outputs["z"],  # 潜在空间表示
-            "library": inference_outputs["library"],  # 文库大小
-            "batch_index": tensors["batch"],  # 批次索引
-            "y": tensors["labels"],  # 细胞类型标签
-            "cont_covariates": tensors.get("continuous_covariates", None),  # 连续型协变量
-            "cat_covariates": tensors.get("categorical_covariates", None),  # 分类型协变量
-            "size_factor": size_factor,  # 大小因子
+            "z": inference_outputs["z"],  # Latent space representation
+            "library": inference_outputs["library"],  # Library size
+            "batch_index": tensors["batch"],  # Batch indices
+            "y": tensors["labels"],  # Cell type labels
+            "cont_covariates": tensors.get("continuous_covariates", None),  # Continuous covariates
+            "cat_covariates": tensors.get("categorical_covariates", None),  # Categorical covariates
+            "size_factor": size_factor,  # Size factor
         }

     def _compute_local_library_params(
         self,
         batch_index: torch.Tensor,
     ) -> tuple[torch.Tensor, torch.Tensor]:
-        """计算局部文库参数。
+        """Compute local library parameters.

-        为每个细胞计算文库大小的均值和方差参数,这些参数取决于细胞所属的批次。
+        Calculate mean and variance parameters for library size for each cell,
+        which depend on the batch the cell belongs to.

         Args:
-            batch_index: 形状为 (batch_size, 1) 的批次索引张量
+            batch_index: Batch index tensor of shape (batch_size, 1)

         Returns:
-            tuple: 包含两个张量,分别是文库大小对数的均值和方差
+            tuple: Contains two tensors, log library size means and variances respectively
         """
         from torch.nn.functional import linear

-        # 批次数量
+        # Number of batches
         num_batches = self.library_log_means.shape[1]

-        # 将批次索引转换为独热编码
+        # Convert batch indices to one-hot encoding
         batch_one_hot = one_hot(batch_index.squeeze(-1), num_batches).float()

-        # 计算每个细胞的文库大小对数均值
-        # 相当于从全局文库均值表中查找对应批次的值
+        # Compute log library size means for each cell
+        # Equivalent to looking up the corresponding batch value from the global library means table
         library_log_means = linear(batch_one_hot, self.library_log_means)

-        # 计算每个细胞的文库大小对数方差
-        # 相当于从全局文库方差表中查找对应批次的值
+        # Compute log library size variances for each cell
+        # Equivalent to looking up the corresponding batch value from the global library variances table
         library_log_vars = linear(batch_one_hot, self.library_log_vars)

         return library_log_means, library_log_vars
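
The one-hot/linear combination in _compute_local_library_params is just a per-batch table lookup. A minimal, self-contained sketch of that equivalence, using made-up batch counts and prior values rather than the module's registered buffers:

import torch
from torch.nn.functional import linear, one_hot

n_batches = 3
# Hypothetical global per-batch tables, shaped (1, n_batches) like the module's buffers.
library_log_means = torch.tensor([[6.9, 7.2, 6.5]])
library_log_vars = torch.tensor([[0.30, 0.25, 0.40]])

# One batch index per cell, shape (batch_size, 1).
batch_index = torch.tensor([[0], [2], [1], [2]])

# one_hot -> (batch_size, n_batches); linear(x, W) computes x @ W.T, i.e. a row lookup here.
b = one_hot(batch_index.squeeze(-1), n_batches).float()
per_cell_log_means = linear(b, library_log_means)  # (batch_size, 1)
per_cell_log_vars = linear(b, library_log_vars)    # (batch_size, 1)

# Same result as indexing the table directly with the batch index.
direct = library_log_means[0, batch_index.squeeze(-1)].unsqueeze(-1)
assert torch.allclose(per_cell_log_means, direct)
print(per_cell_log_means.squeeze(-1))  # tensor([6.9000, 6.5000, 7.2000, 6.5000])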
@@ -393,59 +395,59 @@ def _regular_inference(
         cat_covs: torch.Tensor | None = None,
         n_samples: int = 1,
     ) -> dict[str, torch.Tensor | Distribution | None]:
-        """运行常规推断过程,获取数据的潜在表示。
+        """Run regular inference process to obtain latent representations of data.

         Args:
-            x: 基因表达数据
-            batch_index: 批次索引
-            cont_covs: 连续型协变量
-            cat_covs: 分类型协变量
-            n_samples: 采样数量
+            x: Gene expression data
+            batch_index: Batch indices
+            cont_covs: Continuous covariates
+            cat_covs: Categorical covariates
+            n_samples: Number of samples

         Returns:
-            包含潜在变量和分布的字典
+            Dictionary containing latent variables and distributions
         """
         x_ = x
         if self.use_observed_lib_size:
-            # 计算观测到的文库大小(基因表达总和)
+            # Compute observed library size (sum of gene expression)
             library = torch.log(x.sum(1)).unsqueeze(1)
         if self.log_variational:
-            # 对数据取对数,增加数值稳定性
+            # Take log of data for numerical stability
             x_ = torch.log1p(x_)

         if cont_covs is not None and self.encode_covariates:
-            # 将基因表达和连续协变量拼接起来
+            # Concatenate gene expression and continuous covariates
             encoder_input = torch.cat((x_, cont_covs), dim=-1)
         else:
             encoder_input = x_
         if cat_covs is not None and self.encode_covariates:
-            # 处理分类协变量
+            # Process categorical covariates
             categorical_input = torch.split(cat_covs, 1, dim=1)
         else:
             categorical_input = ()

         if self.batch_representation == "embedding" and self.encode_covariates:
-            # 使用嵌入表示批次
+            # Use embedding representation for batch
             batch_embedding = self.compute_embedding("batch", batch_index)
             encoder_input = torch.cat([encoder_input, batch_embedding], dim=-1)
-            # 获取潜在变量分布和样本
+            # Get latent variable distribution and samples
             qz, z = self.z_encoder(encoder_input, *categorical_input)
         else:
-            # 使用独热编码表示批次
+            # Use one-hot encoding representation for batch
             qz, z = self.z_encoder(encoder_input, batch_index, *categorical_input)

         ql = None
         if not self.use_observed_lib_size:
             if self.batch_representation == "embedding":
-                # 使用嵌入表示批次来编码文库大小
+                # Use embedding representation for batch to encode library size
                 ql, library_encoded = self.l_encoder(encoder_input, *categorical_input)
             else:
-                # 使用独热编码表示批次来编码文库大小
+                # Use one-hot encoding representation for batch to encode library size
                 ql, library_encoded = self.l_encoder(encoder_input, batch_index, *categorical_input)
             library = library_encoded

         if n_samples > 1:
-            # 多个样本情况下的处理
+            # Handle multiple samples case
             untran_z = qz.sample((n_samples,))
             z = self.z_encoder.z_transformation(untran_z)
             if self.use_observed_lib_size:
@@ -454,10 +456,10 @@ def _regular_inference(
                 library = ql.sample((n_samples,))

         return {
-            "z": z,  # 潜在空间表示
-            "qz": qz,  # 潜在空间分布
-            "ql": ql,  # 文库大小分布
-            "library": library,  # 文库大小
+            "z": z,  # Latent space representation
+            "qz": qz,  # Latent space distribution
+            "ql": ql,  # Library size distribution
+            "library": library,  # Library size
         }

     @auto_move_data
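
For context on how _regular_inference assembles its encoder input when the observed library size is used, here is a small sketch with toy tensors; the shapes and covariate values are hypothetical, not the module's actual data registry:

import torch

# Toy data: 4 cells x 5 genes of raw counts, plus 2 continuous covariates per cell.
x = torch.tensor([[0., 3., 1., 0., 7.],
                  [2., 0., 0., 5., 1.],
                  [1., 1., 4., 0., 0.],
                  [0., 0., 2., 2., 3.]])
cont_covs = torch.randn(4, 2)

# Observed library size: log of each cell's total counts, kept as a (cells, 1) column.
library = torch.log(x.sum(1)).unsqueeze(1)

# log1p of the counts stabilizes the encoder input numerically.
x_ = torch.log1p(x)

# Continuous covariates are concatenated onto the transformed expression.
encoder_input = torch.cat((x_, cont_covs), dim=-1)
print(encoder_input.shape, library.shape)  # torch.Size([4, 7]) torch.Size([4, 1])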
@@ -468,47 +470,48 @@ def _cached_inference(
         observed_lib_size: torch.Tensor,
         n_samples: int = 1,
     ) -> dict[str, torch.Tensor | None]:
-        """使用缓存的潜在变量分布参数进行推断。
+        """Perform inference using cached latent variable distribution parameters.

-        这种方法主要用于已经计算并存储了潜在变量分布的情况,
-        可以加快推断速度,无需重新运行编码器网络。
+        This method is mainly used when latent variable distributions have already
+        been computed and stored, which can speed up inference without re-running
+        the encoder network.

         Args:
-            qzm: 潜在变量均值
-            qzv: 潜在变量方差
-            observed_lib_size: 观测到的文库大小
-            n_samples: 采样数量
+            qzm: Latent variable means
+            qzv: Latent variable variances
+            observed_lib_size: Observed library size
+            n_samples: Number of samples

         Returns:
-            包含潜在变量和分布的字典
+            Dictionary containing latent variables and distributions
         """
         from torch.distributions import Normal

-        # 创建潜在变量的正态分布
+        # Create normal distribution for latent variables
         latent_dist = Normal(qzm, qzv.sqrt())

-        # 从分布中采样
-        # 使用sample()而不是rsample(),因为我们不需要对z进行优化
+        # Sample from distribution
+        # Use sample() instead of rsample() because we don't need to optimize z
         if n_samples == 1:
             untransformed_z = latent_dist.sample()
         else:
             untransformed_z = latent_dist.sample((n_samples,))

-        # 变换潜在变量(如果使用的是logistic normal分布)
+        # Transform latent variables (if using logistic normal distribution)
         z = self.z_encoder.z_transformation(untransformed_z)

-        # 计算文库大小(取对数)
+        # Compute library size (take log)
         library = torch.log(observed_lib_size)

-        # 多样本情况下扩展文库大小
+        # Expand library size for multiple samples case
         if n_samples > 1:
             library = library.unsqueeze(0).expand((n_samples, library.size(0), library.size(1)))

         return {
-            "z": z,  # 潜在空间表示
-            "qz": latent_dist,  # 潜在空间分布
-            "ql": None,  # 没有文库大小分布(使用观测值)
-            "library": library,  # 文库大小
+            "z": z,  # Latent space representation
+            "qz": latent_dist,  # Latent space distribution
+            "ql": None,  # No library size distribution (using observed values)
+            "library": library,  # Library size
         }

     @auto_move_data
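
A brief sketch of what the cached path above does, assuming hypothetical stored posterior parameters: it rebuilds the Normal posterior from qzm/qzv, draws non-differentiable samples, and broadcasts the log library size across the extra sample dimension:

import torch
from torch.distributions import Normal

n_cells, n_latent, n_samples = 4, 10, 3

# Hypothetical cached posterior parameters and observed library sizes.
qzm = torch.zeros(n_cells, n_latent)
qzv = torch.ones(n_cells, n_latent)
observed_lib_size = torch.full((n_cells, 1), 500.0)

# Rebuild the per-cell posterior from its stored mean and variance.
latent_dist = Normal(qzm, qzv.sqrt())

# sample() breaks the gradient path; that is fine because z is not optimized here.
z = latent_dist.sample((n_samples,))  # (n_samples, n_cells, n_latent)

# Log library size, broadcast to match the extra sample dimension.
library = torch.log(observed_lib_size)
library = library.unsqueeze(0).expand((n_samples, library.size(0), library.size(1)))
print(z.shape, library.shape)  # torch.Size([3, 4, 10]) torch.Size([3, 4, 1])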
@@ -668,66 +671,66 @@ def loss(
         generative_outputs: dict[str, Distribution | None],
         kl_weight: torch.tensor | float = 1.0,
     ) -> LossOutput:
-        """计算变分自编码器的损失函数。
+        """Compute the loss function for the variational autoencoder.

-        损失函数由重构损失和KL散度两部分组成:
-        1. 重构损失:衡量生成的数据与原始数据的匹配程度
-        2. KL散度:衡量后验分布与先验分布的差异,起到正则化作用
+        The loss function consists of two parts: reconstruction loss and KL divergence:
+        1. Reconstruction loss: measures how well the generated data matches the original data
+        2. KL divergence: measures the difference between posterior and prior distributions, acting as regularization

         Args:
-            tensors: 原始数据张量
-            inference_outputs: 推断过程的输出
-            generative_outputs: 生成过程的输出
-            kl_weight: KL散度项的权重系数(用于KL退火)
+            tensors: Original data tensors
+            inference_outputs: Outputs from the inference process
+            generative_outputs: Outputs from the generative process
+            kl_weight: Weight coefficient for KL divergence term (used for KL annealing)

         Returns:
-            包含总损失和各部分损失的对象
+            Object containing total loss and individual loss components
         """
         from torch.distributions import kl_divergence

-        # 获取原始基因表达数据
-        x = tensors["X"]  # 使用直观的X代替REGISTRY_KEYS.X_KEY
+        # Get original gene expression data
+        x = tensors["X"]  # Use intuitive X instead of REGISTRY_KEYS.X_KEY

-        # 计算潜在变量的KL散度:后验分布q(z|x)与先验分布p(z)之间的差异
+        # Compute KL divergence for latent variables: difference between posterior q(z|x) and prior p(z)
         kl_divergence_z = kl_divergence(inference_outputs["qz"], generative_outputs["latent_space"]).sum(dim=-1)

-        # 计算文库大小的KL散度(如果使用学习的文库大小)
+        # Compute KL divergence for library size (if using learned library size)
         if not self.use_observed_lib_size:
             kl_divergence_l = kl_divergence(inference_outputs["ql"], generative_outputs["library_size"]).sum(dim=1)
         else:
-            # 如果使用观测的文库大小,则KL散度为0
+            # If using observed library size, KL divergence is 0
             kl_divergence_l = torch.zeros_like(kl_divergence_z)

-        # 计算重构损失:负对数似然
+        # Compute reconstruction loss: negative log likelihood
         reconstruction_loss = -generative_outputs["gene_expression"].log_prob(x).sum(-1)

-        # 区分需要进行权重调整的KL散度(用于KL退火)
-        kl_for_warmup = kl_divergence_z  # 潜在变量z的KL散度会参与退火
-        kl_no_warmup = kl_divergence_l  # 文库大小l的KL散度不参与退火
+        # Distinguish KL divergences that need weight adjustment (for KL annealing)
+        kl_for_warmup = kl_divergence_z  # KL divergence for latent variable z participates in annealing
+        kl_no_warmup = kl_divergence_l  # KL divergence for library size l does not participate in annealing

-        # 应用权重后的KL散度
+        # Apply weighted KL divergence
         weighted_kl = kl_weight * kl_for_warmup + kl_no_warmup

-        # 总损失 = 重构损失 + 加权KL散度
+        # Total loss = reconstruction loss + weighted KL divergence
         total_loss = torch.mean(reconstruction_loss + weighted_kl)

-        # 为自动调优准备额外的指标(如果需要)
+        # Prepare additional metrics for auto-tuning (if needed)
         if self.extra_payload_autotune:
             extra_metrics = {
                 "z": inference_outputs["z"],
-                "batch": tensors["batch"],  # 使用直观的batch代替REGISTRY_KEYS.BATCH_KEY
-                "labels": tensors["labels"],  # 使用直观的labels代替REGISTRY_KEYS.LABELS_KEY
+                "batch": tensors["batch"],  # Use intuitive batch instead of REGISTRY_KEYS.BATCH_KEY
+                "labels": tensors["labels"],  # Use intuitive labels instead of REGISTRY_KEYS.LABELS_KEY
             }
         else:
             extra_metrics = {}

-        # 返回损失对象
+        # Return loss object
         return LossOutput(
             loss=total_loss,
             reconstruction_loss=reconstruction_loss,
             kl_local={
-                "kl_divergence_l": kl_divergence_l,  # 文库大小的KL散度
-                "kl_divergence_z": kl_divergence_z,  # 潜在变量的KL散度
+                "kl_divergence_l": kl_divergence_l,  # KL divergence for library size
+                "kl_divergence_z": kl_divergence_z,  # KL divergence for latent variables
             },
             extra_metrics=extra_metrics,
         )
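
The loss assembles a standard ELBO: a reconstruction negative log-likelihood plus KL terms, with only the z-term scaled during KL warm-up. A minimal numeric sketch of that composition, using Normal distributions as stand-ins for the posterior, prior, and decoder (the actual module uses a count likelihood for gene expression):

import torch
from torch.distributions import Normal, kl_divergence

torch.manual_seed(0)
n_cells, n_genes, n_latent = 4, 5, 3

# Stand-ins: Normal posterior q(z|x), standard-Normal prior p(z), and a Normal
# "decoder" over expression; all shapes and values here are illustrative only.
qz = Normal(torch.randn(n_cells, n_latent), torch.ones(n_cells, n_latent))
pz = Normal(torch.zeros(n_cells, n_latent), torch.ones(n_cells, n_latent))
px = Normal(torch.zeros(n_cells, n_genes), torch.ones(n_cells, n_genes))
x = torch.randn(n_cells, n_genes)

# Reconstruction: negative log-likelihood of the data, summed over genes.
reconstruction_loss = -px.log_prob(x).sum(-1)   # (n_cells,)

# Regularization: KL(q(z|x) || p(z)), summed over latent dimensions.
kl_z = kl_divergence(qz, pz).sum(dim=-1)        # (n_cells,)
kl_l = torch.zeros_like(kl_z)                   # zero when the library size is observed

# KL annealing: only the z term is scaled by kl_weight during warm-up.
kl_weight = 0.5
total_loss = torch.mean(reconstruction_loss + kl_weight * kl_z + kl_l)
print(float(total_loss))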
