def _get_inference_input(
    self,
    tensors: dict[str, torch.Tensor | None],
    full_forward_pass: bool = False,
) -> dict[str, torch.Tensor | None]:
    """Get the input tensors required for the inference process.

    Args:
        tensors: Input data tensors
        full_forward_pass: Whether to execute a full forward pass

    Returns:
        Input dictionary for the inference process
    """
    # Choose the loading path based on the data type
    if full_forward_pass or self.minified_data_type is None:
        loader = "full_data"
    elif self.minified_data_type in [
        # ... supported minified data types (elided) ...
    ]:
        loader = "minified_data"
    else:
        raise NotImplementedError(f"Unknown minified data type: {self.minified_data_type}")

    # Full data case: provide expression data and batch information
    if loader == "full_data":
        return {
            "x": tensors["X"],  # Gene expression data
            "batch_index": tensors["batch"],  # Batch indices
            "cont_covariates": tensors.get("continuous_covariates", None),  # Continuous covariates
            "cat_covariates": tensors.get("categorical_covariates", None),  # Categorical covariates
        }
    # Minified data case: provide precomputed latent-posterior parameters
    else:
        return {
            "qzm": tensors["scvi_latent_qzm"],  # Latent posterior means
            "qzv": tensors["scvi_latent_qzv"],  # Latent posterior variances
            "observed_lib_size": tensors["observed_lib_size"],  # Observed library size
        }
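
# A minimal standalone sketch (not module code) of the dispatch above: a full
# forward pass, or the absence of minification, forces the raw-count path,
# while a recognized minified type uses the cached-posterior path. The helper
# name and the type string "latent_posterior" are hypothetical placeholders.
def _pick_loader(minified_data_type, full_forward_pass, known_types=("latent_posterior",)):
    if full_forward_pass or minified_data_type is None:
        return "full_data"
    if minified_data_type in known_types:
        return "minified_data"
    raise NotImplementedError(f"Unknown minified data type: {minified_data_type}")

assert _pick_loader(None, full_forward_pass=False) == "full_data"
assert _pick_loader("latent_posterior", full_forward_pass=False) == "minified_data"
assert _pick_loader("latent_posterior", full_forward_pass=True) == "full_data"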

def _get_generative_input(
    self,
    tensors: dict[str, torch.Tensor],
    inference_outputs: dict[str, torch.Tensor | Distribution | None],
) -> dict[str, torch.Tensor | None]:
329- """获取生成过程的输入张量。
329+ """Get input tensors for the generative process.
330330
331- 将推断步骤的输出与原始数据结合,准备生成网络的输入。
331+ Combine outputs from the inference step with original data to prepare
332+ inputs for the generative network.
332333
333334 Args:
334- tensors: 原始数据张量
335- inference_outputs: 推断过程的输出
335+ tensors: Original data tensors
336+ inference_outputs: Outputs from the inference process
336337
337338 Returns:
338- 生成过程所需的输入字典
339+ Input dictionary required for the generative process
339340 """
    # Get the size factor (if provided) and move it to log space
    size_factor = tensors.get("size_factor", None)
    if size_factor is not None:
        size_factor = torch.log(size_factor)

    return {
        "z": inference_outputs["z"],  # Latent space representation
        "library": inference_outputs["library"],  # Library size
        "batch_index": tensors["batch"],  # Batch indices
        "y": tensors["labels"],  # Cell type labels
        "cont_covariates": tensors.get("continuous_covariates", None),  # Continuous covariates
        "cat_covariates": tensors.get("categorical_covariates", None),  # Categorical covariates
        "size_factor": size_factor,  # Log-space size factor
    }
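
# A small sketch (not module code) of the size-factor handling above: when a
# per-cell size factor is provided, it is moved to log space before decoding,
# matching the log-library values used elsewhere. The toy input is hypothetical.
import torch

_tensors = {"size_factor": torch.tensor([[1000.0], [2500.0]])}  # hypothetical toy input
_sf = _tensors.get("size_factor", None)
if _sf is not None:
    _sf = torch.log(_sf)  # shape (2, 1), now comparable to log-library sizes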

def _compute_local_library_params(
    self,
    batch_index: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
359- """计算局部文库参数。
360+ """Compute local library parameters.
360361
361- 为每个细胞计算文库大小的均值和方差参数,这些参数取决于细胞所属的批次。
362+ Calculate mean and variance parameters for library size for each cell,
363+ which depend on the batch the cell belongs to.
362364
363365 Args:
364- batch_index: 形状为 (batch_size, 1) 的批次索引张量
366+ batch_index: Batch index tensor of shape (batch_size, 1)
365367
366368 Returns:
367- tuple: 包含两个张量,分别是文库大小对数的均值和方差
369+ tuple: Contains two tensors, log library size means and variances respectively
368370 """
    from torch.nn.functional import linear

    # Number of batches
    num_batches = self.library_log_means.shape[1]

    # Convert batch indices to one-hot encoding
    batch_one_hot = one_hot(batch_index.squeeze(-1), num_batches).float()

    # Compute the log library-size mean for each cell; this is equivalent to
    # looking up the entry for the cell's batch in the global means table
    library_log_means = linear(batch_one_hot, self.library_log_means)

    # Compute the log library-size variance for each cell, via the same
    # lookup into the global variances table
    library_log_vars = linear(batch_one_hot, self.library_log_vars)

    return library_log_means, library_log_vars
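
# A standalone sketch (not module code) of why linear(one_hot, table) acts as a
# per-batch lookup: linear(x, W) computes x @ W.T, so each one-hot row selects
# exactly one column of the (1, n_batches) parameter table. Toy values below.
import torch
from torch.nn.functional import linear, one_hot

_table = torch.tensor([[1.0, 2.0, 3.0]])      # (1, 3): one log-mean per batch
_batch_index = torch.tensor([[2], [0], [2]])  # (batch_size, 1)
_oh = one_hot(_batch_index.squeeze(-1), 3).float()
assert torch.equal(linear(_oh, _table), torch.tensor([[3.0], [1.0], [3.0]]))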

@auto_move_data
def _regular_inference(
    self,
    x: torch.Tensor,
    batch_index: torch.Tensor,
    cont_covs: torch.Tensor | None = None,
    cat_covs: torch.Tensor | None = None,
    n_samples: int = 1,
) -> dict[str, torch.Tensor | Distribution | None]:
396- """运行常规推断过程,获取数据的潜在表示。
398+ """Run regular inference process to obtain latent representations of data.
397399
398400 Args:
399- x: 基因表达数据
400- batch_index: 批次索引
401- cont_covs: 连续型协变量
402- cat_covs: 分类型协变量
403- n_samples: 采样数量
401+ x: Gene expression data
402+ batch_index: Batch indices
403+ cont_covs: Continuous covariates
404+ cat_covs: Categorical covariates
405+ n_samples: Number of samples
404406
405407 Returns:
406- 包含潜在变量和分布的字典
408+ Dictionary containing latent variables and distributions
407409 """
    x_ = x
    if self.use_observed_lib_size:
        # Compute the observed library size (total gene expression per cell)
        library = torch.log(x.sum(1)).unsqueeze(1)
    if self.log_variational:
        # Take log1p of the data for numerical stability
        x_ = torch.log1p(x_)

    if cont_covs is not None and self.encode_covariates:
        # Concatenate gene expression and continuous covariates
        encoder_input = torch.cat((x_, cont_covs), dim=-1)
    else:
        encoder_input = x_
    if cat_covs is not None and self.encode_covariates:
        # Split categorical covariates into one column per covariate
        categorical_input = torch.split(cat_covs, 1, dim=1)
    else:
        categorical_input = ()

    if self.batch_representation == "embedding" and self.encode_covariates:
        # Represent the batch with a learned embedding
        batch_embedding = self.compute_embedding("batch", batch_index)
        encoder_input = torch.cat([encoder_input, batch_embedding], dim=-1)
        # Get the latent posterior distribution and a sample from it
        qz, z = self.z_encoder(encoder_input, *categorical_input)
    else:
        # Represent the batch with a one-hot encoding
        qz, z = self.z_encoder(encoder_input, batch_index, *categorical_input)

    ql = None
    if not self.use_observed_lib_size:
        if self.batch_representation == "embedding":
            # Encode the library size with the embedding batch representation
            ql, library_encoded = self.l_encoder(encoder_input, *categorical_input)
        else:
            # Encode the library size with the one-hot batch representation
            ql, library_encoded = self.l_encoder(encoder_input, batch_index, *categorical_input)
        library = library_encoded

    if n_samples > 1:
        # Draw multiple Monte Carlo samples from the posterior
        untran_z = qz.sample((n_samples,))
        z = self.z_encoder.z_transformation(untran_z)
        if self.use_observed_lib_size:
            library = library.unsqueeze(0).expand((n_samples, library.size(0), library.size(1)))
        else:
            library = ql.sample((n_samples,))

    return {
        "z": z,  # Latent space representation
        "qz": qz,  # Latent space distribution
        "ql": ql,  # Library size distribution
        "library": library,  # Library size
    }
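
# A quick sketch (not module code) of the sample shapes in the n_samples > 1
# branch above: passing a sample_shape to .sample() prepends a Monte Carlo
# dimension, so downstream tensors such as the library must be expanded to match.
import torch
from torch.distributions import Normal

_qz = Normal(torch.zeros(4, 10), torch.ones(4, 10))  # 4 cells, 10 latent dims
assert _qz.sample().shape == (4, 10)
assert _qz.sample((3,)).shape == (3, 4, 10)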

@auto_move_data
def _cached_inference(
    self,
    qzm: torch.Tensor,
    qzv: torch.Tensor,
    observed_lib_size: torch.Tensor,
    n_samples: int = 1,
) -> dict[str, torch.Tensor | None]:
471- """使用缓存的潜在变量分布参数进行推断。
473+ """Perform inference using cached latent variable distribution parameters.
472474
473- 这种方法主要用于已经计算并存储了潜在变量分布的情况,
474- 可以加快推断速度,无需重新运行编码器网络。
475+ This method is mainly used when latent variable distributions have already
476+ been computed and stored, which can speed up inference without re-running
477+ the encoder network.
475478
476479 Args:
477- qzm: 潜在变量均值
478- qzv: 潜在变量方差
479- observed_lib_size: 观测到的文库大小
480- n_samples: 采样数量
480+ qzm: Latent variable means
481+ qzv: Latent variable variances
482+ observed_lib_size: Observed library size
483+ n_samples: Number of samples
481484
482485 Returns:
483- 包含潜在变量和分布的字典
486+ Dictionary containing latent variables and distributions
484487 """
    from torch.distributions import Normal

    # Build the latent posterior as a normal distribution (qzv is a variance,
    # so its square root gives the standard deviation)
    latent_dist = Normal(qzm, qzv.sqrt())

    # Sample from the distribution.
    # sample() is used instead of rsample() because z is not optimized here,
    # so no gradient path through the draw is needed
    if n_samples == 1:
        untransformed_z = latent_dist.sample()
    else:
        untransformed_z = latent_dist.sample((n_samples,))

    # Transform the latent variables (e.g. when the encoder uses a
    # logistic-normal parameterization)
    z = self.z_encoder.z_transformation(untransformed_z)

    # Move the observed library size to log space
    library = torch.log(observed_lib_size)

    # Broadcast the library size across Monte Carlo samples
    if n_samples > 1:
        library = library.unsqueeze(0).expand((n_samples, library.size(0), library.size(1)))

    return {
        "z": z,  # Latent space representation
        "qz": latent_dist,  # Latent space distribution
        "ql": None,  # No library size distribution (observed values are used)
        "library": library,  # Library size
    }
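
# A brief sketch (not module code) of the sample()-vs-rsample() note above:
# sample() detaches the draw from the autograd graph, which is appropriate here
# because the cached posterior parameters are fixed and z is not optimized.
import torch
from torch.distributions import Normal

_qzm = torch.zeros(2, 3, requires_grad=True)
_dist = Normal(_qzm, torch.ones(2, 3))
assert not _dist.sample().requires_grad   # no gradient path back to _qzm
assert _dist.rsample().requires_grad      # reparameterized draw keeps the graph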

# ... (an @auto_move_data-decorated method elided here) ...

def loss(
    self,
    tensors: dict[str, torch.Tensor],
    inference_outputs: dict[str, torch.Tensor | Distribution | None],
    generative_outputs: dict[str, Distribution | None],
    kl_weight: torch.Tensor | float = 1.0,
) -> LossOutput:
671- """计算变分自编码器的损失函数。
674+ """Compute the loss function for the variational autoencoder.
672675
673- 损失函数由重构损失和KL散度两部分组成:
674- 1. 重构损失:衡量生成的数据与原始数据的匹配程度
675- 2. KL散度:衡量后验分布与先验分布的差异,起到正则化作用
676+ The loss function consists of two parts: reconstruction loss and KL divergence:
677+ 1. Reconstruction loss: measures how well the generated data matches the original data
678+ 2. KL divergence: measures the difference between posterior and prior distributions, acting as regularization
676679
677680 Args:
678- tensors: 原始数据张量
679- inference_outputs: 推断过程的输出
680- generative_outputs: 生成过程的输出
681- kl_weight: KL散度项的权重系数(用于KL退火)
681+ tensors: Original data tensors
682+ inference_outputs: Outputs from the inference process
683+ generative_outputs: Outputs from the generative process
684+ kl_weight: Weight coefficient for KL divergence term (used for KL annealing)
682685
683686 Returns:
684- 包含总损失和各部分损失的对象
687+ Object containing total loss and individual loss components
685688 """
    from torch.distributions import kl_divergence

    # Get the original gene expression data
    x = tensors["X"]  # intuitive "X" in place of REGISTRY_KEYS.X_KEY

    # KL divergence for the latent variables: posterior q(z|x) vs. prior p(z)
    kl_divergence_z = kl_divergence(inference_outputs["qz"], generative_outputs["latent_space"]).sum(dim=-1)

    # KL divergence for the library size (only when it is learned)
    if not self.use_observed_lib_size:
        kl_divergence_l = kl_divergence(inference_outputs["ql"], generative_outputs["library_size"]).sum(dim=1)
    else:
        # With an observed library size the KL divergence is 0
        kl_divergence_l = torch.zeros_like(kl_divergence_z)

    # Reconstruction loss: negative log-likelihood of the data
    reconstruction_loss = -generative_outputs["gene_expression"].log_prob(x).sum(-1)

    # Separate the KL terms by whether they are annealed
    kl_for_warmup = kl_divergence_z  # the KL for z participates in annealing
    kl_no_warmup = kl_divergence_l  # the KL for the library size does not

    # Apply the annealing weight to the warmed-up term
    weighted_kl = kl_weight * kl_for_warmup + kl_no_warmup

    # Total loss = reconstruction loss + weighted KL divergence
    total_loss = torch.mean(reconstruction_loss + weighted_kl)

    # Prepare extra metrics for auto-tuning (if enabled)
    if self.extra_payload_autotune:
        extra_metrics = {
            "z": inference_outputs["z"],
            "batch": tensors["batch"],  # intuitive "batch" in place of REGISTRY_KEYS.BATCH_KEY
            "labels": tensors["labels"],  # intuitive "labels" in place of REGISTRY_KEYS.LABELS_KEY
        }
    else:
        extra_metrics = {}

    # Package the loss components
    return LossOutput(
        loss=total_loss,
        reconstruction_loss=reconstruction_loss,
        kl_local={
            "kl_divergence_l": kl_divergence_l,  # KL divergence for the library size
            "kl_divergence_z": kl_divergence_z,  # KL divergence for the latent variables
        },
        extra_metrics=extra_metrics,
    )
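
# A standalone sketch (not module code) of the loss assembly above, using toy
# distributions: a per-cell reconstruction NLL plus an annealed KL term, averaged
# into a scalar. The Poisson likelihood here is only an illustrative stand-in.
import torch
from torch.distributions import Normal, Poisson, kl_divergence

_x = torch.poisson(torch.full((8, 5), 4.0))        # toy counts: 8 cells, 5 genes
_qz = Normal(torch.randn(8, 2), torch.ones(8, 2))  # posterior q(z|x)
_pz = Normal(torch.zeros(8, 2), torch.ones(8, 2))  # prior p(z)
_px = Poisson(torch.full((8, 5), 4.0))             # toy likelihood p(x|z)

_recon = -_px.log_prob(_x).sum(-1)                 # per-cell reconstruction loss
_kl_z = kl_divergence(_qz, _pz).sum(dim=-1)        # per-cell KL for z
_kl_weight = 0.5                                   # annealing weight in [0, 1]
_loss = torch.mean(_recon + _kl_weight * _kl_z)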