Poisson Model #140

Merged · 14 commits · Sep 21, 2022
2 changes: 1 addition & 1 deletion batchglm/models/__init__.py
@@ -1 +1 @@
from . import glm_beta, glm_nb, glm_norm
from . import glm_beta, glm_nb, glm_norm, glm_poisson
1 change: 1 addition & 0 deletions batchglm/models/glm_poisson/__init__.py
@@ -0,0 +1 @@
from .model import Model
4 changes: 4 additions & 0 deletions batchglm/models/glm_poisson/external.py
@@ -0,0 +1,4 @@
import batchglm.utils.data as data_utils
from batchglm import pkg_constants
from batchglm.models.base_glm import _ModelGLM, closedform_glm_mean, closedform_glm_scale
from batchglm.utils.linalg import groupwise_solve_lm
102 changes: 102 additions & 0 deletions batchglm/models/glm_poisson/model.py
@@ -0,0 +1,102 @@
import abc
from typing import Any, Callable, Dict, Optional, Tuple, Union

import dask.array
import numpy as np

from .external import ModelGLM


class Model(ModelGLM, metaclass=abc.ABCMeta):
    """
    Generalized Linear Model (GLM) with Poisson noise.
    """

    def link_loc(self, data) -> Union[np.ndarray, dask.array.core.Array]:
        return np.log(data)

    def inverse_link_loc(self, data) -> Union[np.ndarray, dask.array.core.Array]:
        return np.exp(data)

    def link_scale(self, data) -> Union[np.ndarray, dask.array.core.Array]:
        return np.log(data)

    def inverse_link_scale(self, data) -> Union[np.ndarray, dask.array.core.Array]:
        return np.exp(data)

    @property
    def eta_loc(self) -> Union[np.ndarray, dask.array.core.Array]:
        eta = np.matmul(self.design_loc, self.theta_location_constrained)
        if self.size_factors is not None:
            eta += self.size_factors
        eta = self.np_clip_param(eta, "eta_loc")
        return eta

    def eta_loc_j(self, j) -> Union[np.ndarray, dask.array.core.Array]:
        # Make sure that the dimensionality of the sliced array is kept:
        if isinstance(j, (int, np.int32, np.int64)):
            j = [j]
        eta = np.matmul(self.design_loc, self.theta_location_constrained[:, j])
        if self.size_factors is not None:
            eta += self.size_factors
        eta = self.np_clip_param(eta, "eta_loc")
        return eta

    # Re-parameterizations:

    @property
    def lam(self) -> Union[np.ndarray, dask.array.core.Array]:
        return self.location

    # param constraints:

    def bounds(self, sf, dmax, dtype) -> Tuple[Dict[str, Any], Dict[str, Any]]:
        bounds_min = {
            "theta_location": np.log(np.nextafter(0, np.inf, dtype=dtype)) / sf,
            "eta_loc": np.log(np.nextafter(0, np.inf, dtype=dtype)) / sf,
            "loc": np.nextafter(0, np.inf, dtype=dtype),
            "scale": np.nextafter(0, np.inf, dtype=dtype),
            "likelihood": dtype(0),
            "ll": np.log(np.nextafter(0, np.inf, dtype=dtype)),
            # Not used and should be removed: https://github.com/theislab/batchglm/issues/148
            "theta_scale": np.log(np.nextafter(0, np.inf, dtype=dtype)) / sf,
            "eta_scale": np.log(np.nextafter(0, np.inf, dtype=dtype)) / sf,
        }
        bounds_max = {
            "theta_location": np.nextafter(np.log(dmax), -np.inf, dtype=dtype) / sf,
            "eta_loc": np.nextafter(np.log(dmax), -np.inf, dtype=dtype) / sf,
            "loc": np.nextafter(dmax, -np.inf, dtype=dtype) / sf,
            "scale": np.nextafter(dmax, -np.inf, dtype=dtype) / sf,
            "likelihood": dtype(1),
            "ll": dtype(10000),  # Poisson models can have large log likelihoods initially
            # Not used and should be removed: https://github.com/theislab/batchglm/issues/148
            "theta_scale": np.log(dmax) / sf,
            "eta_scale": np.log(dmax) / sf,
        }
        return bounds_min, bounds_max

    # simulator:

    @property
    def rand_fn_ave(self) -> Optional[Callable]:
        return lambda shape: np.random.poisson(500, shape) + 1

    @property
    def rand_fn(self) -> Optional[Callable]:
        return lambda shape: np.abs(np.random.uniform(2, 10, shape))

    @property
    def rand_fn_loc(self) -> Optional[Callable]:
        return None

    @property
    def rand_fn_scale(self) -> Optional[Callable]:
        return None

    def generate_data(self) -> np.ndarray:
        """
        Sample random data based on the Poisson distribution and its parameters.
        """
        return np.random.poisson(lam=self.lam)
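For orientation, the chain this class encodes — log-scale coefficients through the inverse link to Poisson counts — can be written out in a few lines of plain numpy. A minimal sketch with made-up names (`design` and `theta` stand in for `design_loc` and `theta_location`), not batchglm API:

```python
import numpy as np

rng = np.random.default_rng(0)
design = np.repeat(np.eye(2), 50, axis=0)  # (100, 2) one-hot design matrix
theta = np.log(np.array([[5.0, 20.0, 3.0],
                         [8.0, 2.0, 30.0]]))  # (2, 3) log-scale coefficients

eta = design @ theta       # linear predictor, cf. eta_loc (size factors omitted)
lam = np.exp(eta)          # inverse log link, cf. inverse_link_loc and lam
x = rng.poisson(lam=lam)   # sampled counts, cf. generate_data()
```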
115 changes: 115 additions & 0 deletions batchglm/models/glm_poisson/utils.py
@@ -0,0 +1,115 @@
import logging
from typing import Callable, Optional, Tuple, Union

import dask
import numpy as np
import scipy.sparse

from .external import closedform_glm_mean

logger = logging.getLogger("batchglm")


def closedform_poisson_glm_loglam(
    x: Union[np.ndarray, scipy.sparse.csr_matrix, dask.array.core.Array],
    design_loc: Union[np.ndarray, dask.array.core.Array],
    constraints_loc: Union[np.ndarray, dask.array.core.Array],
    size_factors: Optional[np.ndarray] = None,
    link_fn: Callable = np.log,
    inv_link_fn: Callable = np.exp,
):
    r"""
    Calculates a closed-form solution for the `lam` parameters of Poisson GLMs.

    :param x: The sample data
    :param design_loc: design matrix for location
    :param constraints_loc: tensor (all parameters x dependent parameters)
        Tensor that encodes how the complete parameter set, which includes dependent
        parameters, arises from the independent parameters: all = <constraints, indep>.
        This form of constraints is used in vector generalized linear models (VGLMs).
    :param size_factors: size factors for X
    :return: tuple: (groupwise_means, mu, rmsd)
    """
    return closedform_glm_mean(
        x=x,
        dmat=design_loc,
        constraints=constraints_loc,
        size_factors=size_factors,
        link_fn=link_fn,
        inv_link_fn=inv_link_fn,
    )
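When design_loc is one-hot encoded, the closed form reduces to the log of the groupwise means. A self-contained sketch of that idea (not the closedform_glm_mean internals; constraints and size factors omitted):

```python
import numpy as np

rng = np.random.default_rng(1)
design = np.repeat(np.eye(2), 50, axis=0)       # one-hot design: 2 groups
lam_true = np.array([[5.0, 20.0], [8.0, 2.0]])  # per-group rates, 2 features
x = rng.poisson(lam=design @ lam_true)          # simulated counts

# D+ averages within groups because D is one-hot; the log link then
# maps groupwise means to coefficient estimates:
groupwise_means = np.linalg.pinv(design) @ x
theta_hat = np.log(groupwise_means)             # close to np.log(lam_true)
```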


def init_par(model, init_location: str) -> Tuple[np.ndarray, np.ndarray, bool, bool]:
    r"""
    standard:
        Only initialise the intercept and keep the other coefficients at zero.
    closed-form:
        Initialize with Maximum Likelihood / Method of Moments estimators.

    Idea:
    $$
    \begin{aligned}
        \theta &= f(x) \\
        \Rightarrow f^{-1}(\theta) &= x \\
        &= (D \cdot D^{+}) \cdot x \\
        &= D \cdot (D^{+} \cdot x) \\
        &= D \cdot x' = f^{-1}(\theta)
    \end{aligned}
    $$
    """
    train_loc = False

    def auto_loc(dmat: Union[np.ndarray, dask.array.core.Array]) -> str:
        """
        Checks if dmat is one-hot encoded and returns 'closed_form' if so, else 'standard'.

        :param dmat: The design matrix to check.
        """
        unique_params = np.unique(dmat)
        if isinstance(unique_params, dask.array.core.Array):
            unique_params = unique_params.compute()
        if len(unique_params) == 2 and unique_params[0] == 0.0 and unique_params[1] == 1.0:
            return "closed_form"
        logger.warning(
            "Cannot use 'closed_form' init for loc model: "
            "design_loc is not one-hot encoded. Falling back to standard initialization."
        )
        return "standard"

    groupwise_means = None

    init_location_str = init_location.lower()
    # Resolve the actual option if "auto" was requested:
    if init_location_str == "auto":
        init_location_str = auto_loc(model.design_loc)

    if init_location_str == "closed_form":
        groupwise_means, init_theta_location, rmsd_a = closedform_poisson_glm_loglam(
            x=model.x,
            design_loc=model.design_loc,
            constraints_loc=model.constraints_loc,
            size_factors=model.size_factors,
            link_fn=lambda lam: np.log(lam + np.nextafter(0, 1, dtype=lam.dtype)),
        )
        # Train mu if the closed-form solution is inaccurate:
        train_loc = not (np.all(np.abs(rmsd_a) < 1e-20) or rmsd_a.size == 0)
        if model.size_factors is not None:
            if np.any(model.size_factors != 1):
                train_loc = True
    elif init_location_str == "standard":
        overall_means = np.mean(model.x, axis=0)  # directly calculate the mean
        init_theta_location = np.zeros([model.num_loc_params, model.num_features])
        init_theta_location[0, :] = np.log(overall_means)
        train_loc = True
    elif init_location_str == "all_zero":
        init_theta_location = np.zeros([model.num_loc_params, model.num_features])
        train_loc = True
    else:
        raise ValueError("init_location string %s not recognized" % init_location)

    # Scale is not used, so return init_theta_location in place of init_theta_scale as well.
    return init_theta_location, init_theta_location, train_loc, True
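The "standard" branch above amounts to an intercept-only start: row 0 of theta is the log of the per-feature mean and every other coefficient stays zero, to be fit by the optimizer. A standalone illustration with made-up shapes:

```python
import numpy as np

x = np.random.default_rng(2).poisson(7.0, size=(100, 4))  # counts: 100 obs, 4 features
num_loc_params = 3  # e.g. intercept + 2 covariates

init_theta_location = np.zeros([num_loc_params, x.shape[1]])
init_theta_location[0, :] = np.log(x.mean(axis=0))  # intercept = log of overall mean
# All other rows stay zero; train_loc=True so they are learned during training.
```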
2 changes: 1 addition & 1 deletion batchglm/train/numpy/__init__.py
@@ -1 +1 @@
from . import glm_nb as nb
from . import glm_nb, glm_poisson
2 changes: 2 additions & 0 deletions batchglm/train/numpy/glm_poisson/__init__.py
@@ -0,0 +1,2 @@
from .estimator import Estimator
from .model_container import ModelContainer
58 changes: 58 additions & 0 deletions batchglm/train/numpy/glm_poisson/estimator.py
@@ -0,0 +1,58 @@
import sys
from typing import Optional, Tuple, Union

import numpy as np

from .external import EstimatorGlm, Model, init_par
from .model_container import ModelContainer


class Estimator(EstimatorGlm):
    """
    Estimator for Generalized Linear Models (GLMs) with Poisson noise.
    Uses the natural logarithm as link function.

    Attributes
    ----------
    model_vars : ModelVars
        model variables
    """

    def __init__(
        self,
        model: Model,
        init_location: str = "AUTO",
        init_scale: str = "AUTO",
        # batch_size: Optional[Union[Tuple[int, int], int]] = None,
        quick_scale: bool = False,
        dtype: str = "float64",
    ):
        """
        Performs initialisation and creates a new estimator.

        :param init_location: (Optional)
            Low-level initial values for a. Can be:

            - str:
                * "auto": automatically choose best initialization
                * "standard": initialize intercept with observed mean
                * "init_model": initialize with another model (see `init_model` parameter)
                * "closed_form": try to initialize with closed form
            - np.ndarray: direct initialization of 'a'
        :param dtype: Numerical precision.
        """
        init_theta_location, _, train_loc, _ = init_par(model=model, init_location=init_location)
        self._train_loc = train_loc
        # No need to train the scale parameter for the Poisson model:
        # the distribution has no separate scale (its variance equals its mean).
        self._train_scale = False
        sys.stdout.write("training location model: %s\n" % str(self._train_loc))
        init_theta_location = init_theta_location.astype(dtype)

        _model_container = ModelContainer(
            model=model,
            init_theta_location=init_theta_location,
            init_theta_scale=init_theta_location,  # Not used.
            chunk_size_genes=model.chunk_size_genes,
            dtype=dtype,
        )
        super(Estimator, self).__init__(model_container=_model_container, dtype=dtype)
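A hypothetical usage sketch; constructing the underlying Model (with data and design matrices attached) is outside this diff, so `model` is assumed to already exist:

```python
from batchglm.train.numpy.glm_poisson import Estimator

# `model` is assumed to be a batchglm.models.glm_poisson.Model with data attached.
estimator = Estimator(model=model, init_location="AUTO", dtype="float64")
# Only the location coefficients are trained; _train_scale is always False here.
```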
8 changes: 8 additions & 0 deletions batchglm/train/numpy/glm_poisson/exceptions.py
@@ -0,0 +1,8 @@
class NoScaleError(Exception):
    """
    Exception raised on attempts to access the scale parameter (or one of its
    derived methods) of a Poisson model.
    """

    def __init__(self, method):
        self.message = f"Attempted to access {method}. No scale parameter is fit for Poisson - please use location."
        super().__init__(self.message)
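For illustration, the kind of guard this exception supports; a sketch only, since the real scale properties live in ModelContainer elsewhere in this PR:

```python
# Hypothetical guard using NoScaleError; not the PR's actual ModelContainer.
class _PoissonContainerSketch:
    @property
    def theta_scale(self):
        raise NoScaleError("theta_scale")
```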
9 changes: 9 additions & 0 deletions batchglm/train/numpy/glm_poisson/external.py
@@ -0,0 +1,9 @@
import batchglm.utils.data as data_utils
from batchglm import pkg_constants
from batchglm.models.base_glm.utils import closedform_glm_mean, closedform_glm_scale
from batchglm.models.glm_poisson.model import Model
from batchglm.models.glm_poisson.utils import init_par

# import necessary base_glm layers
from batchglm.train.numpy.base_glm import BaseModelContainer, EstimatorGlm
from batchglm.utils.linalg import groupwise_solve_lm