Commit 8842f6b

Merge pull request #107 from theislab/singular_fix
Singular fix
2 parents ed2f5ba + d3287b3 commit 8842f6b

10 files changed: +262 −366 lines changed

batchglm/api/data.py

Lines changed: 1 addition & 1 deletion

@@ -1,4 +1,4 @@
 from batchglm.data import design_matrix
 from batchglm.data import constraint_matrix_from_dict, constraint_matrix_from_string, string_constraints_from_dict, \
     constraint_system_from_star
-from batchglm.data import view_coef_names, preview_coef_names
+from batchglm.data import view_coef_names, preview_coef_names, bin_continuous_covariate

batchglm/data.py

Lines changed: 28 additions & 0 deletions

@@ -453,3 +453,31 @@ def constraint_matrix_from_string(
     )

     return constraint_mat
+
+
+def bin_continuous_covariate(
+        sample_description: pd.DataFrame,
+        factor_to_bin: str,
+        bins: Union[int, list, np.ndarray, Tuple]
+):
+    r"""
+    Bin a continuous covariate.
+
+    Adds the binned covariate to the table. Binning is performed on quantiles of the distribution.
+
+    :param sample_description: Sample description table.
+    :param factor_to_bin: Name of the column of the factor to bin.
+    :param bins: Number of bins or iterable with bin borders. If given as an integer, the bins are defined on
+        quantiles of the covariate, i.e. the bottom 20% of observations are in the first bin if bins == 5.
+    :return: Sample description table with binned covariate added.
+    """
+    if isinstance(bins, (list, np.ndarray, tuple)):
+        bins = np.asarray(bins)
+    else:
+        bins = np.arange(0, 1, 1 / bins)
+
+    sample_description[factor_to_bin + "_binned"] = np.digitize(
+        np.argsort(np.argsort(sample_description[factor_to_bin].values)) / sample_description.shape[0],
+        bins
+    )
+    return sample_description
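
For orientation, a minimal usage sketch of the new helper; the table and its `age` column are hypothetical and not part of this commit:

```python
import numpy as np
import pandas as pd

from batchglm.data import bin_continuous_covariate

# Hypothetical sample table with a continuous covariate.
sample_description = pd.DataFrame({"age": np.random.uniform(20, 80, size=100)})

# With bins=5, observations are assigned to quintiles of "age": the bottom
# 20% land in the first bin, and a new column "age_binned" is added.
sample_description = bin_continuous_covariate(
    sample_description=sample_description,
    factor_to_bin="age",
    bins=5
)
print(sample_description["age_binned"].value_counts())
```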

batchglm/models/base/estimator.py

Lines changed: 19 additions & 0 deletions

@@ -5,6 +5,7 @@
 import numpy as np
 import pandas as pd
 import pprint
+import sys

 try:
     import anndata
@@ -112,6 +113,14 @@ def train_sequence(
         logger.debug("training strategy:\n%s", pprint.pformat(training_strategy))
         for idx, d in enumerate(training_strategy):
             logger.debug("Beginning with training sequence #%d", idx + 1)
+            # Override duplicate arguments with the user's choice:
+            if np.any([x in list(d.keys()) for x in list(kwargs.keys())]):
+                for x in [xx for xx in list(d.keys()) if xx in list(kwargs.keys())]:
+                    sys.stdout.write(
+                        "overriding %s from training strategy with value %s with new value %s\n" %
+                        (x, str(d[x]), str(kwargs[x]))
+                    )
+                d = dict([(x, y) for x, y in d.items() if x not in list(kwargs.keys())])
             self.train(**d, **kwargs)
             logger.debug("Training sequence #%d complete", idx + 1)
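
The effect of this block is that a keyword passed directly to `train_sequence` wins over the same key in a stored training strategy. A standalone sketch of the rule; the strategy contents here are illustrative, not a real batchglm strategy:

```python
# The strategy requests 1000 scale-model iterations; the caller overrides to 10.
training_strategy = [{"max_iter_b": 1000, "ftol_b": 1e-8}]
kwargs = {"max_iter_b": 10}

for d in training_strategy:
    # Report which stored values get overridden ...
    for x in [xx for xx in d if xx in kwargs]:
        print("overriding %s from training strategy with value %s with new value %s"
              % (x, d[x], kwargs[x]))
    # ... then drop them so train(**d, **kwargs) sees no duplicate keywords.
    d = {k: v for k, v in d.items() if k not in kwargs}
    # train(**d, **kwargs) would now receive max_iter_b=10 and ftol_b=1e-8.
```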

@@ -165,6 +174,11 @@ def _plot_coef_vs_ref(
         from matplotlib import gridspec
         from matplotlib import rcParams

+        if isinstance(true_values, dask.array.core.Array):
+            true_values = true_values.compute()
+        if isinstance(estim_values, dask.array.core.Array):
+            estim_values = estim_values.compute()
+
         plt.ioff()

         n_par = true_values.shape[0]
@@ -258,6 +272,11 @@ def _plot_deviation(
         import seaborn as sns
         import matplotlib.pyplot as plt

+        if isinstance(true_values, dask.array.core.Array):
+            true_values = true_values.compute()
+        if isinstance(estim_values, dask.array.core.Array):
+            estim_values = estim_values.compute()
+
         plt.ioff()

         n_par = true_values.shape[0]
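
Both plotting hunks apply the same guard: dask arrays are lazy, so they must be materialized before matplotlib/seaborn can consume them. A minimal sketch of the pattern, assuming dask is installed:

```python
import dask.array
import numpy as np

values = dask.array.from_array(np.arange(6.0).reshape(2, 3), chunks=(2, 3))
if isinstance(values, dask.array.core.Array):
    values = values.compute()  # materialize the lazy array into numpy
assert isinstance(values, np.ndarray)
```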

batchglm/models/glm_nb/utils.py

Lines changed: 144 additions & 0 deletions

@@ -1,3 +1,5 @@
+import dask
+import logging
 import numpy as np
 import scipy.sparse
 from typing import Union
@@ -71,3 +73,145 @@ def compute_scales_fun(variance, mean):
         inv_link_fn=invlink_fn,
         compute_scales_fun=compute_scales_fun
     )
+
+
+def init_par(
+        input_data,
+        init_a,
+        init_b,
+        init_model
+):
+    r"""
+    standard:
+    Only initialise the intercept and keep all other coefficients at zero.
+
+    closed-form:
+    Initialize with maximum likelihood / method of moments estimators.
+
+    Idea:
+    $$
+        \theta &= f(x) \\
+        \Rightarrow f^{-1}(\theta) &= x \\
+        &= (D \cdot D^{+}) \cdot x \\
+        &= D \cdot (D^{+} \cdot x) \\
+        &= D \cdot x' = f^{-1}(\theta)
+    $$
+    """
+    train_loc = True
+    train_scale = True
+
+    if init_model is None:
+        groupwise_means = None
+        init_a_str = None
+        if isinstance(init_a, str):
+            init_a_str = init_a.lower()
+            # Choose an option if "auto" was given:
+            if init_a.lower() == "auto":
+                if isinstance(input_data.design_loc, dask.array.core.Array):
+                    dloc = input_data.design_loc.compute()
+                else:
+                    dloc = input_data.design_loc
+                one_hot = len(np.unique(dloc)) == 2 and \
+                    np.abs(np.min(dloc) - 0.) == 0. and \
+                    np.abs(np.max(dloc) - 1.) == 0.
+                init_a = "standard" if not one_hot else "closed_form"
+
+            if init_a.lower() == "closed_form":
+                groupwise_means, init_a, rmsd_a = closedform_nb_glm_logmu(
+                    x=input_data.x,
+                    design_loc=input_data.design_loc,
+                    constraints_loc=input_data.constraints_loc,
+                    size_factors=input_data.size_factors,
+                    link_fn=lambda mu: np.log(mu + np.nextafter(0, 1, dtype=mu.dtype))
+                )
+
+                # Train mu if the closed-form solution is inaccurate.
+                train_loc = not (np.all(np.abs(rmsd_a) < 1e-20) or rmsd_a.size == 0)
+
+                if input_data.size_factors is not None:
+                    if np.any(input_data.size_factors != 1):
+                        train_loc = True
+            elif init_a.lower() == "standard":
+                overall_means = np.mean(input_data.x, axis=0)  # directly calculate the mean
+                init_a = np.zeros([input_data.num_loc_params, input_data.num_features])
+                init_a[0, :] = np.log(overall_means)
+                train_loc = True
+            elif init_a.lower() == "all_zero":
+                init_a = np.zeros([input_data.num_loc_params, input_data.num_features])
+                train_loc = True
+            else:
+                raise ValueError("init_a string %s not recognized" % init_a)
+
+        if isinstance(init_b, str):
+            if init_b.lower() == "auto":
+                init_b = "standard"
+
+            if init_b.lower() == "standard":
+                groupwise_scales, init_b_intercept, rmsd_b = closedform_nb_glm_logphi(
+                    x=input_data.x,
+                    design_scale=input_data.design_scale[:, [0]],
+                    constraints=input_data.constraints_scale[[0], :][:, [0]],
+                    size_factors=input_data.size_factors,
+                    groupwise_means=None,
+                    link_fn=lambda r: np.log(r + np.nextafter(0, 1, dtype=r.dtype))
+                )
+                init_b = np.zeros([input_data.num_scale_params, input_data.num_features])
+                init_b[0, :] = init_b_intercept
+            elif init_b.lower() == "closed_form":
+                dmats_unequal = False
+                if input_data.design_loc.shape[1] == input_data.design_scale.shape[1]:
+                    if np.any(input_data.design_loc != input_data.design_scale):
+                        dmats_unequal = True
+
+                inits_unequal = False
+                if init_a_str is not None:
+                    if init_a_str != init_b:
+                        inits_unequal = True
+
+                if inits_unequal or dmats_unequal:
+                    raise ValueError("cannot use closed_form init for scale model " +
+                                     "if scale model differs from loc model")
+
+                groupwise_scales, init_b, rmsd_b = closedform_nb_glm_logphi(
+                    x=input_data.x,
+                    design_scale=input_data.design_scale,
+                    constraints=input_data.constraints_scale,
+                    size_factors=input_data.size_factors,
+                    groupwise_means=groupwise_means,
+                    link_fn=lambda r: np.log(r)
+                )
+            elif init_b.lower() == "all_zero":
+                init_b = np.zeros([input_data.num_scale_params, input_data.x.shape[1]])
+            else:
+                raise ValueError("init_b string %s not recognized" % init_b)
+    else:
+        # Location model:
+        if isinstance(init_a, str) and (init_a.lower() == "auto" or init_a.lower() == "init_model"):
+            my_loc_names = set(input_data.loc_names)
+            my_loc_names = my_loc_names.intersection(set(init_model.input_data.loc_names))
+
+            init_loc = np.zeros([input_data.num_loc_params, input_data.num_features])
+            for parm in my_loc_names:
+                init_idx = np.where(init_model.input_data.loc_names == parm)[0]
+                my_idx = np.where(input_data.loc_names == parm)[0]
+                init_loc[my_idx] = init_model.a_var[init_idx]
+
+            init_a = init_loc
+            logging.getLogger("batchglm").debug("Using initialization based on input model for mean")
+
+        # Scale model:
+        if isinstance(init_b, str) and (init_b.lower() == "auto" or init_b.lower() == "init_model"):
+            my_scale_names = set(input_data.scale_names)
+            my_scale_names = my_scale_names.intersection(init_model.input_data.scale_names)
+
+            init_scale = np.zeros([input_data.num_scale_params, input_data.num_features])
+            for parm in my_scale_names:
+                init_idx = np.where(init_model.input_data.scale_names == parm)[0]
+                my_idx = np.where(input_data.scale_names == parm)[0]
+                init_scale[my_idx] = init_model.b_var[init_idx]
+
+            init_b = init_scale
+            logging.getLogger("batchglm").debug("Using initialization based on input model for dispersion")
+
+    return init_a, init_b, train_loc, train_scale
+
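
The pseudo-inverse idea in the docstring can be made concrete with a toy design. This sketch is illustrative only (a hand-built two-group one-hot design, not batchglm code): the closed-form initializer maps observed means through the log link and the pseudo-inverse of the design matrix onto coefficients.

```python
import numpy as np

# Design matrix D: intercept plus treatment indicator, 3 observations per group.
D = np.array([[1.0, 0.0]] * 3 + [[1.0, 1.0]] * 3)
mu = np.array([2.0] * 3 + [8.0] * 3)  # groupwise means per observation

# Solve link(mu) = D @ theta in least squares via the pseudo-inverse.
theta = np.linalg.pinv(D) @ np.log(mu)
print(theta)  # ~[log 2, log 8 - log 2]: intercept and treatment effect
```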

batchglm/train/numpy/base_glm/estimator.py

Lines changed: 53 additions & 45 deletions

@@ -82,6 +82,11 @@ def train(
         """
         # Iterate until conditions are fulfilled.
         train_step = 0
+        if self._train_scale:
+            if not self._train_loc:
+                update_b_freq = 1
+        else:
+            update_b_freq = np.inf
         epochs_until_b_update = update_b_freq
         fully_converged = np.tile(False, self.model.model_vars.n_features)

@@ -97,25 +102,29 @@
             if epochs_until_b_update == 0:
                 # Compute update.
                 idx_update = np.where(np.logical_not(fully_converged))[0]
-                b_step = self.b_step(
-                    idx_update=idx_update,
-                    method=method_b,
-                    ftol=ftol_b,
-                    lr=lr_b,
-                    max_iter=max_iter_b,
-                    nproc=nproc
-                )
-                # Perform trial update.
-                self.model.b_var = self.model.b_var + b_step
-                # Reverse update by feature if update leads to worse loss:
-                ll_proposal = - self.model.ll_byfeature_j(j=idx_update).compute()
-                idx_bad_step = idx_update[np.where(ll_proposal > ll_current[idx_update])[0]]
-                if isinstance(self.model.b_var, dask.array.core.Array):
-                    b_var_new = self.model.b_var.compute()
+                if self._train_scale:
+                    b_step = self.b_step(
+                        idx_update=idx_update,
+                        method=method_b,
+                        ftol=ftol_b,
+                        lr=lr_b,
+                        max_iter=max_iter_b,
+                        nproc=nproc
+                    )
+                    # Perform trial update.
+                    self.model.b_var = self.model.b_var + b_step
+                    # Reverse update by feature if update leads to worse loss:
+                    ll_proposal = - self.model.ll_byfeature_j(j=idx_update).compute()
+                    idx_bad_step = idx_update[np.where(ll_proposal > ll_current[idx_update])[0]]
+                    if isinstance(self.model.b_var, dask.array.core.Array):
+                        b_var_new = self.model.b_var.compute()
+                    else:
+                        b_var_new = self.model.b_var.copy()
+                    b_var_new[:, idx_bad_step] = b_var_new[:, idx_bad_step] - b_step[:, idx_bad_step]
+                    self.model.b_var = b_var_new
                 else:
-                    b_var_new = self.model.b_var.copy()
-                b_var_new[:, idx_bad_step] = b_var_new[:, idx_bad_step] - b_step[:, idx_bad_step]
-                self.model.b_var = b_var_new
+                    ll_proposal = ll_current[idx_update]
+                    idx_bad_step = np.array([], dtype=np.int32)
                 # Update likelihood vector with updated genes based on already evaluated proposal likelihood.
                 ll_new = ll_current.copy()
                 ll_new[idx_update] = ll_proposal
@@ -126,18 +135,22 @@
                 # IWLS step for location model:
                 # Compute update.
                 idx_update = self.model.idx_not_converged
-                a_step = self.iwls_step(idx_update=idx_update)
-                # Perform trial update.
-                self.model.a_var = self.model.a_var + a_step
-                # Reverse update by feature if update leads to worse loss:
-                ll_proposal = - self.model.ll_byfeature_j(j=idx_update).compute()
-                idx_bad_step = idx_update[np.where(ll_proposal > ll_current[idx_update])[0]]
-                if isinstance(self.model.b_var, dask.array.core.Array):
-                    a_var_new = self.model.a_var.compute()
+                if self._train_loc:
+                    a_step = self.iwls_step(idx_update=idx_update)
+                    # Perform trial update.
+                    self.model.a_var = self.model.a_var + a_step
+                    # Reverse update by feature if update leads to worse loss:
+                    ll_proposal = - self.model.ll_byfeature_j(j=idx_update).compute()
+                    idx_bad_step = idx_update[np.where(ll_proposal > ll_current[idx_update])[0]]
+                    if isinstance(self.model.a_var, dask.array.core.Array):
+                        a_var_new = self.model.a_var.compute()
+                    else:
+                        a_var_new = self.model.a_var.copy()
+                    a_var_new[:, idx_bad_step] = a_var_new[:, idx_bad_step] - a_step[:, idx_bad_step]
+                    self.model.a_var = a_var_new
                 else:
-                    a_var_new = self.model.a_var.copy()
-                a_var_new[:, idx_bad_step] = a_var_new[:, idx_bad_step] - a_step[:, idx_bad_step]
-                self.model.a_var = a_var_new
+                    ll_proposal = ll_current[idx_update]
+                    idx_bad_step = np.array([], dtype=np.int32)
                 # Update likelihood vector with updated genes based on already evaluated proposal likelihood.
                 ll_new = ll_current.copy()
                 ll_new[idx_update] = ll_proposal
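
Both branches share the same trial-update pattern: apply a step to all features, re-evaluate the per-feature loss, and revert only the features whose loss got worse. A minimal numpy sketch of that pattern (values are illustrative):

```python
import numpy as np

ll_current = np.array([10.0, 12.0, 9.0])  # per-feature negative log-likelihood
b_var = np.zeros((2, 3))                  # 2 scale parameters x 3 features
b_step = np.full((2, 3), 0.5)

b_var = b_var + b_step                            # trial update
ll_proposal = np.array([9.0, 13.0, 8.5])          # loss re-evaluated after the step
idx_bad_step = np.where(ll_proposal > ll_current)[0]
b_var[:, idx_bad_step] = b_var[:, idx_bad_step] - b_step[:, idx_bad_step]
# Feature 1 made the loss worse (13.0 > 12.0) and is reverted; the rest keep the step.
```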
@@ -273,10 +286,16 @@ def iwls_step(
             invertible = np.where(dask.array.map_blocks(
                 get_cond_number, a, chunks=a.shape
             ).squeeze().compute() < 1 / sys.float_info.epsilon)[0]
-            delta_theta[:, idx_update[invertible]] = dask.array.map_blocks(
-                np.linalg.solve, a[invertible], b[invertible, :, None],
-                chunks=b[invertible, :, None].shape
-            ).squeeze().T.compute()
+            if len(idx_update[invertible]) > 1:
+                delta_theta[:, idx_update[invertible]] = dask.array.map_blocks(
+                    np.linalg.solve, a[invertible], b[invertible, :, None],
+                    chunks=b[invertible, :, None].shape
+                ).squeeze().T.compute()
+            elif len(idx_update[invertible]) == 1:
+                delta_theta[:, idx_update[invertible]] = np.expand_dims(
+                    np.linalg.solve(a[invertible[0]], b[invertible[0]]).compute(),
+                    axis=-1
+                )
         else:
             if np.linalg.cond(a.compute(), p=None) < 1 / sys.float_info.epsilon:
                 delta_theta[:, idx_update] = np.expand_dims(
@@ -290,7 +309,7 @@
             invertible = np.where(np.linalg.cond(a, p=None) < 1 / sys.float_info.epsilon)[0]
             delta_theta[:, idx_update[invertible]] = np.linalg.solve(a[invertible], b[invertible]).T
             if invertible.shape[0] < len(idx_update):
-                print("caught %i linalg singular matrix errors" % (len(idx_update) - invertible.shape[0]))
+                sys.stdout.write("caught %i linalg singular matrix errors\n" % (len(idx_update) - invertible.shape[0]))
             # Via np.linalg.lstsq:
             #delta_theta[:, idx_update] = np.concatenate([
             #    np.expand_dims(np.linalg.lstsq(a[i, :, :], b[i, :])[0], axis=-1)
@@ -537,14 +556,3 @@ def get_model_container(
             input_data
     ):
         pass
-
-    @abc.abstractmethod
-    def init_par(
-            self,
-            input_data,
-            init_a,
-            init_b,
-            init_model
-    ):
-        pass
-
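
The guard that gives this PR its name can be reproduced with plain numpy. A minimal sketch of the dask-free code path: batched normal-equation systems are solved only where the condition number says the matrix is numerically invertible, and the singular ones are counted instead of raising.

```python
import sys
import numpy as np

# Batch of two 2x2 systems; the second matrix is exactly singular.
a = np.stack([np.eye(2), np.ones((2, 2))])
b = np.ones((2, 2))
delta_theta = np.zeros((2, 2))
idx_update = np.arange(a.shape[0])

with np.errstate(divide="ignore"):  # cond() of a singular matrix is inf
    invertible = np.where(np.linalg.cond(a, p=None) < 1 / sys.float_info.epsilon)[0]
delta_theta[:, idx_update[invertible]] = np.linalg.solve(a[invertible], b[invertible]).T
if invertible.shape[0] < len(idx_update):
    sys.stdout.write("caught %i linalg singular matrix errors\n"
                     % (len(idx_update) - invertible.shape[0]))
```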