Skip to content

Commit fb7e1a6

Browse files
committed
Fix the bugs for paste align: add paste_align_preprocess back
1 parent 622a20f commit fb7e1a6

File tree

4 files changed

+95
-3
lines changed

4 files changed

+95
-3
lines changed

spateo/alignment/methods/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
# from .morpho_sparse import BA_align_sparse
33
from .backend import NumpyBackend, TorchBackend
44
from .deprecated_utils import (
5+
paste_align_preprocess,
56
align_preprocess,
67
cal_dist,
78
cal_dot,

spateo/alignment/methods/backend.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1067,7 +1067,7 @@ def data(self, a, type_as=None):
10671067
else:
10681068
return np.asarray(a, dtype=type_as.dtype)
10691069

1070-
def unique(self, a, return_index, return_inverse=False, axis=None):
1070+
def unique(self, a, return_index=False, return_inverse=False, axis=None):
    """Return the unique elements of ``a`` by delegating to ``np.unique``.

    Args:
        a: Input array-like.
        return_index: Forwarded to ``np.unique``; also return the indices of the first occurrences.
        return_inverse: Forwarded to ``np.unique``; also return the indices that reconstruct ``a``.
        axis: Forwarded to ``np.unique``; axis to operate on (``None`` flattens the input).

    Returns:
        Whatever ``np.unique`` returns for the given flags.
    """
    options = {
        "return_index": return_index,
        "return_inverse": return_inverse,
        "axis": axis,
    }
    return np.unique(a, **options)
10721072

10731073
def unsqueeze(self, a, axis=-1):

spateo/alignment/methods/deprecated_utils.py

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -728,6 +728,96 @@ def align_preprocess(
728728
)
729729

730730

731+
def paste_align_preprocess(
    samples: List[AnnData],
    genes: Optional[Union[list, np.ndarray]] = None,
    spatial_key: str = "spatial",
    layer: str = "X",
    use_rep: Optional[str] = None,
    normalize_c: bool = False,
    normalize_g: bool = False,
    select_high_exp_genes: Union[bool, float, int] = False,
    dtype: str = "float64",
    device: str = "cpu",
    verbose: bool = True,
    **kwargs,
) -> Tuple[
    Union[TorchBackend, NumpyBackend],
    Union[torch.Tensor, np.ndarray],
    list,
    list,
    list,
    Optional[float],
    Optional[list],
]:
    """
    Data preprocessing before alignment.

    Args:
        samples: A list of anndata object.
        genes: Genes used for calculation. If None, use all common genes for calculation.
        spatial_key: The key in `.obsm` that corresponds to the raw spatial coordinates.
        layer: If `'X'`, uses ``sample.X`` to calculate dissimilarity between spots, otherwise uses the representation given by ``sample.layers[layer]``.
        use_rep: Optional key in `.obsm`. When it is a string present in the `.obsm` of every sample, the
            matrices ``sample.obsm[use_rep]`` are placed first in the returned matrix list, followed by the
            expression matrices. Otherwise only the expression matrices are returned.
        normalize_c: Whether to normalize spatial coordinates.
        normalize_g: Whether to normalize gene expression. Only applied when ``use_rep`` is not used.
        select_high_exp_genes: Whether to select genes with high differences in gene expression. ``True`` uses a
            threshold of 10; a number is used as the threshold itself. Only ``False`` disables the selection.
        dtype: The floating-point number type. Only float32 and float64.
        device: Equipment used to run the program. You can also set the specified GPU for running. E.g.: '0'.
        verbose: If ``True``, print progress updates.
        **kwargs: Accepted for call compatibility; not used by this function.

    Returns:
        A tuple ``(nx, type_as, new_samples, exp_matrices, spatial_coords, normalize_scale_list,
        normalize_mean_list)``: the backend and reference array from ``check_backend``, the copied samples
        restricted to the common genes, the list of matrices described under ``use_rep``, the spatial
        coordinates of each sample, and the two outputs of ``normalize_coords`` (both ``None`` unless
        ``normalize_c`` is set).
    """

    # Determine if gpu or cpu is being used
    nx, type_as = check_backend(device=device, dtype=dtype)

    # Subset for common genes. The gene index is read from the sample directly; slicing a row first
    # (``s[0]``) only builds a throwaway view and fails on a sample without observations.
    new_samples = [s.copy() for s in samples]
    all_samples_genes = [s.var.index for s in new_samples]
    common_genes = filter_common_genes(*all_samples_genes, verbose=verbose)
    common_genes = common_genes if genes is None else intersect_lsts(common_genes, genes)
    new_samples = [s[:, common_genes] for s in new_samples]

    # The representation is only usable when every sample provides it. Computed once so the
    # matrix-building step and the normalization step below cannot disagree.
    has_rep = isinstance(use_rep, str) and all(use_rep in s.obsm.keys() for s in samples)

    # Gene expression matrix of all samples
    if has_rep:
        # NOTE(review): this list mixes representation matrices with expression matrices. If their column
        # counts differ, the gene-selection step below presumably cannot concatenate them — confirm intent.
        rep_matrices = [nx.from_numpy(s.obsm[use_rep], type_as=type_as) for s in new_samples]
        raw_matrices = [nx.from_numpy(check_exp(sample=s, layer=layer), type_as=type_as) for s in new_samples]
        exp_matrices = rep_matrices + raw_matrices
    else:
        exp_matrices = [nx.from_numpy(check_exp(sample=s, layer=layer), type_as=type_as) for s in new_samples]

    # Identity check on purpose: a numeric threshold of 0 must still enable the selection.
    if select_high_exp_genes is not False:
        # Select significance genes if select_high_exp_genes is True
        ExpressionData = _cat(nx=nx, x=exp_matrices, dim=0)

        ExpressionVar = _var(nx, ExpressionData, 0)
        exp_threshold = 10 if isinstance(select_high_exp_genes, bool) else select_high_exp_genes
        EvidenceExpression = nx.where(ExpressionVar > exp_threshold)[0]
        exp_matrices = [exp_matrix[:, EvidenceExpression] for exp_matrix in exp_matrices]
        if verbose:
            lm.main_info(message=f"Evidence expression number: {len(EvidenceExpression)}.")

    # Spatial coordinates of all samples
    spatial_coords = [
        nx.from_numpy(check_spatial_coords(sample=s, spatial_key=spatial_key), type_as=type_as) for s in new_samples
    ]
    # All samples must share one coordinate dimensionality.
    coords_dims = nx.unique(_data(nx, [c.shape[1] for c in spatial_coords], type_as))
    assert len(coords_dims) == 1, "Spatial coordinate dimensions are different, please check again."

    normalize_scale_list, normalize_mean_list = None, None
    if normalize_c:
        spatial_coords, normalize_scale_list, normalize_mean_list = normalize_coords(
            coords=spatial_coords, nx=nx, verbose=verbose
        )
    if normalize_g and not has_rep:
        exp_matrices = normalize_exps(matrices=exp_matrices, nx=nx, verbose=verbose)

    return (
        nx,
        type_as,
        new_samples,
        exp_matrices,
        spatial_coords,
        normalize_scale_list,
        normalize_mean_list,
    )
819+
820+
731821
# Finished
732822
def guidance_pair_preprocess(
733823
nx: Union[TorchBackend, NumpyBackend],

spateo/alignment/methods/paste.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from spateo.logging import logger_manager as lm
1010

1111
from .deprecated_utils import (
12+
paste_align_preprocess,
1213
align_preprocess,
1314
calc_exp_dissimilarity,
1415
check_exp,
@@ -70,7 +71,7 @@ def paste_pairwise_align(
7071
"""
7172

7273
# Preprocessing
73-
(nx, type_as, new_samples, exp_matrices, spatial_coords, normalize_scale, normalize_mean_list,) = align_preprocess(
74+
(nx, type_as, new_samples, exp_matrices, spatial_coords, normalize_scale, normalize_mean_list,) = paste_align_preprocess(
7475
samples=[sampleA, sampleB],
7576
genes=genes,
7677
spatial_key=spatial_key,
@@ -116,7 +117,7 @@ def paste_pairwise_align(
116117
except ImportError:
117118
from ot.gromov import cg
118119

119-
pi, log = ot.gromov.cg(
120+
pi, log = cg(
120121
a,
121122
b,
122123
(1 - alpha) * M,

0 commit comments

Comments
 (0)