-
Notifications
You must be signed in to change notification settings - Fork 9
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* [WIP] Starting poisson model * [WIP] Need to log likelihood * [WIP] Beginning to handle log likelihood * [WIP] Tests run but some don't pass * [WIP] Add error for scale model. * Fix Poisson Log-Likelihood * Add docs. * Make small changes. * Add comment. * Fix pre-commit. * Remove bounds not needed. * Add back in scale clipping * Fixes.
- Loading branch information
Showing
15 changed files
with
489 additions
and
22 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1 @@ | ||
from . import glm_beta, glm_nb, glm_norm | ||
from . import glm_beta, glm_nb, glm_norm, glm_poisson |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
from .model import Model |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
import batchglm.utils.data as data_utils | ||
from batchglm import pkg_constants | ||
from batchglm.models.base_glm import _ModelGLM, closedform_glm_mean, closedform_glm_scale | ||
from batchglm.utils.linalg import groupwise_solve_lm |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,102 @@ | ||
import abc | ||
from typing import Any, Callable, Dict, Optional, Tuple, Union | ||
|
||
import dask.array | ||
import numpy as np | ||
|
||
from .external import ModelGLM | ||
|
||
|
||
class Model(ModelGLM, metaclass=abc.ABCMeta): | ||
""" | ||
Generalized Linear Model (GLM) with Poisson noise. | ||
""" | ||
|
||
def link_loc(self, data) -> Union[np.ndarray, dask.array.core.Array]: | ||
return np.log(data) | ||
|
||
def inverse_link_loc(self, data) -> Union[np.ndarray, dask.array.core.Array]: | ||
return np.exp(data) | ||
|
||
def link_scale(self, data) -> Union[np.ndarray, dask.array.core.Array]: | ||
return np.log(data) | ||
|
||
def inverse_link_scale(self, data) -> Union[np.ndarray, dask.array.core.Array]: | ||
return np.exp(data) | ||
|
||
@property | ||
def eta_loc(self) -> Union[np.ndarray, dask.array.core.Array]: | ||
eta = np.matmul(self.design_loc, self.theta_location_constrained) | ||
if self.size_factors is not None: | ||
eta += self.size_factors | ||
eta = self.np_clip_param(eta, "eta_loc") | ||
return eta | ||
|
||
def eta_loc_j(self, j) -> Union[np.ndarray, dask.array.core.Array]: | ||
# Make sure that dimensionality of sliced array is kept: | ||
if isinstance(j, int) or isinstance(j, np.int32) or isinstance(j, np.int64): | ||
j = [j] | ||
eta = np.matmul(self.design_loc, self.theta_location_constrained[:, j]) | ||
if self.size_factors is not None: | ||
eta += self.size_factors | ||
eta = self.np_clip_param(eta, "eta_loc") | ||
return eta | ||
|
||
# Re-parameterizations: | ||
|
||
@property | ||
def lam(self) -> Union[np.ndarray, dask.array.core.Array]: | ||
return self.location | ||
|
||
# param constraints: | ||
|
||
def bounds(self, sf, dmax, dtype) -> Tuple[Dict[str, Any], Dict[str, Any]]: | ||
|
||
bounds_min = { | ||
"theta_location": np.log(np.nextafter(0, np.inf, dtype=dtype)) / sf, | ||
"eta_loc": np.log(np.nextafter(0, np.inf, dtype=dtype)) / sf, | ||
"loc": np.nextafter(0, np.inf, dtype=dtype), | ||
"scale": np.nextafter(0, np.inf, dtype=dtype), | ||
"likelihood": dtype(0), | ||
"ll": np.log(np.nextafter(0, np.inf, dtype=dtype)), | ||
# Not used and should be removed: https://github.com/theislab/batchglm/issues/148 | ||
"theta_scale": np.log(np.nextafter(0, np.inf, dtype=dtype)) / sf, | ||
"eta_scale": np.log(np.nextafter(0, np.inf, dtype=dtype)) / sf, | ||
} | ||
bounds_max = { | ||
"theta_location": np.nextafter(np.log(dmax), -np.inf, dtype=dtype) / sf, | ||
"eta_loc": np.nextafter(np.log(dmax), -np.inf, dtype=dtype) / sf, | ||
"loc": np.nextafter(dmax, -np.inf, dtype=dtype) / sf, | ||
"scale": np.nextafter(dmax, -np.inf, dtype=dtype) / sf, | ||
"likelihood": dtype(1), | ||
"ll": dtype(10000), # poisson models can have large log likelhoods initially | ||
# Not used and should be removed: https://github.com/theislab/batchglm/issues/148 | ||
"theta_scale": np.log(dmax) / sf, | ||
"eta_scale": np.log(dmax) / sf, | ||
|
||
} | ||
return bounds_min, bounds_max | ||
|
||
# simulator: | ||
|
||
@property | ||
def rand_fn_ave(self) -> Optional[Callable]: | ||
return lambda shape: np.random.poisson(500, shape) + 1 | ||
|
||
@property | ||
def rand_fn(self) -> Optional[Callable]: | ||
return lambda shape: np.abs(np.random.uniform(2, 10, shape)) | ||
|
||
@property | ||
def rand_fn_loc(self) -> Optional[Callable]: | ||
return None | ||
|
||
@property | ||
def rand_fn_scale(self) -> Optional[Callable]: | ||
return None | ||
|
||
def generate_data(self) -> np.ndarray: | ||
""" | ||
Sample random data based on poisson distribution and parameters. | ||
""" | ||
return np.random.poisson(lam=self.lam) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,115 @@ | ||
import logging | ||
from typing import Callable, Optional, Tuple, Union | ||
|
||
import dask | ||
import numpy as np | ||
import scipy.sparse | ||
|
||
from .external import closedform_glm_mean | ||
|
||
logger = logging.getLogger("batchglm") | ||
|
||
|
||
def closedform_poisson_glm_loglam( | ||
x: Union[np.ndarray, scipy.sparse.csr_matrix, dask.array.core.Array], | ||
design_loc: Union[np.ndarray, dask.array.core.Array], | ||
constraints_loc: Union[np.ndarray, dask.array.core.Array], | ||
size_factors: Optional[np.ndarray] = None, | ||
link_fn: Callable = np.log, | ||
inv_link_fn: Callable = np.exp, | ||
): | ||
r""" | ||
Calculates a closed-form solution for the `lam` parameters of poisson GLMs. | ||
:param x: The sample data | ||
:param design_loc: design matrix for location | ||
:param constraints_loc: tensor (all parameters x dependent parameters) | ||
Tensor that encodes how complete parameter set which includes dependent | ||
parameters arises from indepedent parameters: all = <constraints, indep>. | ||
This form of constraints is used in vector generalized linear models (VGLMs). | ||
:param size_factors: size factors for X | ||
:return: tuple: (groupwise_means, mu, rmsd) | ||
""" | ||
return closedform_glm_mean( | ||
x=x, | ||
dmat=design_loc, | ||
constraints=constraints_loc, | ||
size_factors=size_factors, | ||
link_fn=link_fn, | ||
inv_link_fn=inv_link_fn, | ||
) | ||
|
||
|
||
def init_par(model, init_location: str) -> Tuple[np.ndarray, np.ndarray, bool, bool]: | ||
r""" | ||
standard: | ||
Only initialise intercept and keep other coefficients as zero. | ||
closed-form: | ||
Initialize with Maximum Likelihood / Maximum of Momentum estimators | ||
Idea: | ||
$$ | ||
\theta &= f(x) \\ | ||
\Rightarrow f^{-1}(\theta) &= x \\ | ||
&= (D \cdot D^{+}) \cdot x \\ | ||
&= D \cdot (D^{+} \cdot x) \\ | ||
&= D \cdot x' = f^{-1}(\theta) | ||
$$ | ||
""" | ||
train_loc = False | ||
|
||
def auto_loc(dmat: Union[np.ndarray, dask.array.core.Array]) -> str: | ||
""" | ||
Checks if dmat is one-hot encoded and returns 'closed_form' if so, else 'standard' | ||
:param dmat The design matrix to check. | ||
""" | ||
unique_params = np.unique(dmat) | ||
if isinstance(unique_params, dask.array.core.Array): | ||
unique_params = unique_params.compute() | ||
if len(unique_params) == 2 and unique_params[0] == 0.0 and unique_params[1] == 1.0: | ||
return "closed_form" | ||
logger.warning( | ||
( | ||
"Cannot use 'closed_form' init for loc model: " | ||
"design_loc is not one-hot encoded. Falling back to standard initialization." | ||
) | ||
) | ||
return "standard" | ||
|
||
groupwise_means = None | ||
|
||
init_location_str = init_location.lower() | ||
# Chose option if auto was chosen | ||
if init_location_str == "auto": | ||
|
||
init_location_str = auto_loc(model.design_loc) | ||
|
||
if init_location_str == "closed_form": | ||
groupwise_means, init_theta_location, rmsd_a = closedform_poisson_glm_loglam( | ||
x=model.x, | ||
design_loc=model.design_loc, | ||
constraints_loc=model.constraints_loc, | ||
size_factors=model.size_factors, | ||
link_fn=lambda lam: np.log(lam + np.nextafter(0, 1, dtype=lam.dtype)), | ||
) | ||
# train mu, if the closed-form solution is inaccurate | ||
train_loc = not (np.all(np.abs(rmsd_a) < 1e-20) or rmsd_a.size == 0) | ||
if model.size_factors is not None: | ||
if np.any(model.size_factors != 1): | ||
train_loc = True | ||
|
||
elif init_location_str == "standard": | ||
overall_means = np.mean(model.x, axis=0) # directly calculate the mean | ||
init_theta_location = np.zeros([model.num_loc_params, model.num_features]) | ||
init_theta_location[0, :] = np.log(overall_means) | ||
train_loc = True | ||
elif init_location_str == "all_zero": | ||
init_theta_location = np.zeros([model.num_loc_params, model.num_features]) | ||
train_loc = True | ||
else: | ||
raise ValueError("init_location string %s not recognized" % init_location) | ||
|
||
# Scale is not used so just return init_theta_location for what would be init_theta_scale | ||
return init_theta_location, init_theta_location, train_loc, True |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1 @@ | ||
from . import glm_nb as nb | ||
from . import glm_nb, glm_poisson |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
from .estimator import Estimator | ||
from .model_container import ModelContainer |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,58 @@ | ||
import sys | ||
from typing import Optional, Tuple, Union | ||
|
||
import numpy as np | ||
|
||
from .external import EstimatorGlm, Model, init_par | ||
from .model_container import ModelContainer | ||
|
||
|
||
class Estimator(EstimatorGlm): | ||
""" | ||
Estimator for Generalized Linear Models (GLMs) with negative binomial noise. | ||
Uses the natural logarithm as linker function. | ||
Attributes | ||
---------- | ||
model_vars : ModelVars | ||
model variables | ||
""" | ||
|
||
def __init__( | ||
self, | ||
model: Model, | ||
init_location: str = "AUTO", | ||
init_scale: str = "AUTO", | ||
# batch_size: Optional[Union[Tuple[int, int], int]] = None, | ||
quick_scale: bool = False, | ||
dtype: str = "float64", | ||
): | ||
""" | ||
Performs initialisation and creates a new estimator. | ||
:param init_location: (Optional) | ||
Low-level initial values for a. Can be: | ||
- str: | ||
* "auto": automatically choose best initialization | ||
* "standard": initialize intercept with observed mean | ||
* "init_model": initialize with another model (see `ìnit_model` parameter) | ||
* "closed_form": try to initialize with closed form | ||
- np.ndarray: direct initialization of 'a' | ||
:param dtype: Numerical precision. | ||
""" | ||
init_theta_location, _, train_loc, _ = init_par(model=model, init_location=init_location) | ||
self._train_loc = train_loc | ||
# no need to train the scale parameter for the poisson model since it only has one parameter | ||
self._train_scale = False | ||
sys.stdout.write("training location model: %s\n" % str(self._train_loc)) | ||
init_theta_location = init_theta_location.astype(dtype) | ||
|
||
_model_container = ModelContainer( | ||
model=model, | ||
init_theta_location=init_theta_location, | ||
init_theta_scale=init_theta_location, # Not used. | ||
chunk_size_genes=model.chunk_size_genes, | ||
dtype=dtype, | ||
) | ||
super(Estimator, self).__init__(model_container=_model_container, dtype=dtype) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
class NoScaleError(Exception): | ||
""" | ||
Exception raised for attempting to access the scale parameter (or one of its derived methods) of a poisson model. | ||
""" | ||
|
||
def __init__(self, method): | ||
self.message = f"Attempted to access {method}. No scale parameter is fit for poisson - please use location." | ||
super().__init__(self.message) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
import batchglm.utils.data as data_utils | ||
from batchglm import pkg_constants | ||
from batchglm.models.base_glm.utils import closedform_glm_mean, closedform_glm_scale | ||
from batchglm.models.glm_poisson.model import Model | ||
from batchglm.models.glm_poisson.utils import init_par | ||
|
||
# import necessary base_glm layers | ||
from batchglm.train.numpy.base_glm import BaseModelContainer, EstimatorGlm | ||
from batchglm.utils.linalg import groupwise_solve_lm |
Oops, something went wrong.