Skip to content

Commit

Permalink
add typing module in api to fix docs and relayed some input data atrr…
Browse files Browse the repository at this point in the history
…ibutes to model
  • Loading branch information
davidsebfischer committed Aug 22, 2019
1 parent 261e73c commit b19257f
Show file tree
Hide file tree
Showing 8 changed files with 70 additions and 8 deletions.
1 change: 1 addition & 0 deletions batchglm/api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,5 +3,6 @@

from . import models
from . import data
from . import typing
from . import utils
from .. import pkg_constants
2 changes: 2 additions & 0 deletions batchglm/api/typing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from batchglm.models.base.estimator import EstimatorBaseTyping
from batchglm.models.base.input import InputDataBaseTyping
4 changes: 2 additions & 2 deletions batchglm/models/base/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from .input import _InputDataBase
from .estimator import _EstimatorBase
from .input import _InputDataBase, InputDataBaseTyping
from .estimator import _EstimatorBase, EstimatorBaseTyping
from .model import _ModelBase
from .simulator import _SimulatorBase
6 changes: 6 additions & 0 deletions batchglm/models/base/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,3 +291,9 @@ def _plot_deviation(
else:
return


class EstimatorBaseTyping(_EstimatorBase):
r"""
Estimator base class used for typing in other packages.
"""

5 changes: 5 additions & 0 deletions batchglm/models/base/input.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,3 +90,8 @@ def fetch_x_sparse(self, idx):
data_idx = np.squeeze(data_idx, axis=0)

return data_idx, data_val, data_shape

class InputDataBaseTyping:
"""
Input data base class used for typing in other packages.
"""
1 change: 0 additions & 1 deletion batchglm/models/base/model.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import abc
import numpy as np
from typing import Union, Any, Dict, Iterable
import logging

Expand Down
28 changes: 24 additions & 4 deletions batchglm/models/base_glm/input.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,10 @@ class InputDataGLM(_InputDataBase):
"""
Input data for Generalized Linear Models (GLMs).
"""
loc_names: list
design_loc_names: list
scale_names: list
design_scale_names: list

def __init__(
self,
Expand Down Expand Up @@ -94,8 +98,8 @@ def __init__(

self.design_loc = design_loc
self.design_scale = design_scale
self.design_loc_names = design_loc_names
self.design_scale_names = design_scale_names
self._design_loc_names = design_loc_names
self._design_scale_names = design_scale_names

constraints_loc, loc_names = parse_constraints(
dmat=design_loc,
Expand All @@ -111,11 +115,27 @@ def __init__(
)
self.constraints_loc = constraints_loc
self.constraints_scale = constraints_scale
self.loc_names = loc_names
self.scale_names = scale_names
self._loc_names = loc_names
self._scale_names = scale_names

self.size_factors = size_factors

@property
def design_loc_names(self):
return self._design_loc_names

@property
def design_scale_names(self):
return self._design_scale_names

@property
def loc_names(self):
return self._loc_names

@property
def scale_names(self):
return self._scale_names

@property
def num_design_loc_params(self):
return self.design_loc.shape[1]
Expand Down
31 changes: 30 additions & 1 deletion batchglm/models/base_glm/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
anndata = None

from .external import _ModelBase
from .input import InputDataGLM


class _ModelGLM(_ModelBase, metaclass=abc.ABCMeta):
Expand All @@ -26,7 +27,7 @@ class _ModelGLM(_ModelBase, metaclass=abc.ABCMeta):

def __init__(
self,
input_data
input_data: InputDataGLM
):
_ModelBase.__init__(
self=self,
Expand Down Expand Up @@ -63,6 +64,34 @@ def constraints_scale(self) -> np.ndarray:
else:
return self.input_data.constraints_scale

@property
def design_loc_names(self) -> list:
if self.input_data is None:
return None
else:
return self.input_data.design_loc_names

@property
def design_scale_names(self) -> list:
if self.input_data is None:
return None
else:
return self.input_data.design_scale_names

@property
def loc_names(self) -> list:
if self.input_data is None:
return None
else:
return self.input_data.loc_names

@property
def scale_names(self) -> list:
if self.input_data is None:
return None
else:
return self.input_data.scale_names

@abc.abstractmethod
def eta_loc(self) -> np.ndarray:
pass
Expand Down

0 comments on commit b19257f

Please sign in to comment.