Skip to content

Commit 15e4c11

Browse files
authored
Merge pull request #7 from raphael-group/development
Command line support for initial alignments
2 parents 32c303f + 352b59e commit 15e4c11

File tree

3 files changed

+18
-14
lines changed

3 files changed

+18
-14
lines changed

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,7 @@ Note: `pairwise` will return pairwise alignment between each consecutive pair of
6868
| -t | threshold | Convergence threshold for `center_align` | (float) `0.001` |
6969
| -x | coordinates | Output new coordinates (toggle to turn on) | `Flase` |
7070
| -w | weights | Weights files of spots in each slice (.csv) | None |
71+
| -s | start | Initial alignments for OT. If not given uses uniform (.csv structure similar to alignment output) | None |
7172

7273
`pairwise_align` outputs a (.csv) file containing mapping of spots between each consecutive pair of slices. The rows correspond to spots of the first slice, and cols the second.
7374

paste-cmd-line.py

Lines changed: 17 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -11,26 +11,21 @@ def main(args):
1111
n_slices = int(len(args.filename)/2)
1212
# Error check arguments
1313
if args.mode != 'pairwise' and args.mode != 'center':
14-
print("Error: please select either 'pairwise' or 'center' mode.")
15-
return
14+
raise(ValueError("Please select either 'pairwise' or 'center' mode."))
1615

1716
if args.alpha < 0 or args.alpha > 1:
18-
print("Error: alpha specified outside [0, 1].")
19-
return
17+
raise(ValueError("alpha specified outside [0, 1]"))
2018

2119
if args.initial_slice < 1 or args.initial_slice > n_slices:
22-
print("Error: initial slice specified outside [1, n].")
23-
return
20+
raise(ValueError("Initial slice specified outside [1, n]"))
2421

2522
if len(args.lmbda) == 0:
2623
lmbda = n_slices*[1./n_slices]
2724
elif len(args.lmbda) != n_slices:
28-
print("Error: length of lambda does not equal number of files.")
29-
return
25+
raise(ValueError("Length of lambda does not equal number of files"))
3026
else:
3127
if not all(i >= 0 for i in args.lmbda):
32-
print("Error: lambda includes negative weights.")
33-
return
28+
raise(ValueError("lambda includes negative weights"))
3429
else:
3530
print("Normalizing lambda weights into probability vector.")
3631
lmbda = args.lmbda
@@ -53,6 +48,13 @@ def main(args):
5348
slices[i].obsm['weights'] = np.genfromtxt(args.weights[i], delimiter = ',')
5449
slices[i].obsm['weights'] = slices[i].obsm['weights']/np.sum(slices[i].obsm['weights'])
5550

51+
if len(args.start)==0:
52+
pis_init = n_slices*[None]
53+
elif (args.mode == 'pairwise' and len(args.start)!=n_slices-1) or (args.mode == 'center' and len(args.start)!=n_slices):
54+
raise(ValueError("Number of slices {0} != number of start pi files {1}".format(n_slices,len(args.start))))
55+
else:
56+
pis_init = [pd.read_csv(args.start[i],index_col=0).to_numpy() for i in range(len(args.start))]
57+
5658
# create output folder
5759
output_path = os.path.join(args.direc, "paste_output")
5860
if not os.path.exists(output_path):
@@ -63,7 +65,7 @@ def main(args):
6365
# compute pairwise align
6466
pis = []
6567
for i in range(n_slices - 1):
66-
pi = pairwise_align(slices[i], slices[i+1], args.alpha, dissimilarity=args.cost, a_distribution=slices[i].obsm['weights'], b_distribution=slices[i+1].obsm['weights'])
68+
pi = pairwise_align(slices[i], slices[i+1], args.alpha, dissimilarity=args.cost, a_distribution=slices[i].obsm['weights'], b_distribution=slices[i+1].obsm['weights'], G_init=pis_init[i])
6769
pis.append(pi)
6870
pi = pd.DataFrame(pi, index = slices[i].obs.index, columns = slices[i+1].obs.index)
6971
output_filename = "paste_output/slice" + str(i+1) + "_slice" + str(i+2) + "_pairwise.csv"
@@ -77,13 +79,13 @@ def main(args):
7779
print("Computing center alignment.")
7880
initial_slice = slices[args.initial_slice - 1].copy()
7981
# compute center align
80-
center_slice, pis = center_align(initial_slice, slices, lmbda, args.alpha, args.n_components, args.threshold, dissimilarity=args.cost, distributions=[slices[i].obsm['weights'] for i in range(n_slices)])
82+
center_slice, pis = center_align(initial_slice, slices, lmbda, args.alpha, args.n_components, args.threshold, dissimilarity=args.cost, distributions=[slices[i].obsm['weights'] for i in range(n_slices)], pis_init=pis_init)
8183
W = pd.DataFrame(center_slice.uns['paste_W'], index = center_slice.obs.index)
8284
H = pd.DataFrame(center_slice.uns['paste_H'], columns = center_slice.var.index)
8385
W.to_csv(os.path.join(args.direc,"paste_output/W_center"))
8486
H.to_csv(os.path.join(args.direc,"paste_output/H_center"))
8587
for i in range(len(pis)):
86-
output_filename = "paste_output/slice_center_slice" + str(i+1) + "_pairwise"
88+
output_filename = "paste_output/slice_center_slice" + str(i+1) + "_pairwise.csv"
8789
pi = pd.DataFrame(pis[i], index = center_slice.obs.index, columns = slices[i].obs.index)
8890
pi.to_csv(os.path.join(args.direc, output_filename))
8991
if args.coordinates:
@@ -107,6 +109,7 @@ def main(args):
107109
parser.add_argument("-i", "--initial_slice", help="specify which slice is the intial slice for center_align (int from 1-n)",type=int, default = 1)
108110
parser.add_argument("-t","--threshold", help="convergence threshold for center_align",type=float, default = 0.001)
109111
parser.add_argument("-x","--coordinates", help="output new coordinates", action='store_true', default = False)
110-
parser.add_argument("-w","--weights", help="path to files containing weights of spots in each slice",type=str, default=[], nargs='+')
112+
parser.add_argument("-w","--weights", help="path to files containing weights of spots in each slice. The format of the files is the same as the coordinate files used as input",type=str, default=[], nargs='+')
113+
parser.add_argument("-s","--start", help="path to files containing initial starting alignmnets. If not given the OT starts the search with uniform alignments. The format of the files is the same as the alignments files output by PASTE",type=str, default=[], nargs='+')
111114
args = parser.parse_args()
112115
main(args)
121 Bytes
Binary file not shown.

0 commit comments

Comments
 (0)