diff --git a/pyphi/__tpm.py b/pyphi/__tpm.py new file mode 100644 index 000000000..31124dd27 --- /dev/null +++ b/pyphi/__tpm.py @@ -0,0 +1,676 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +# __tpm.py + +""" +TPM classes. + +TPM -> Works with all state-by-state tpms +(binary, nonbinary, symmetric, asymmetric -> node TPMs) + +SbN -> Only works with state-by-node tpms, +which are only binary. Can work with asymmetric tpms, however. +""" + +import xarray as xr +from pandas import DataFrame +import numpy as np +from string import ascii_uppercase +from itertools import product +from math import log2 + +from pyphi.connectivity import get_inputs_from_cm +from pyphi.distribution import repertoire_shape + +from .utils import all_states + +# labels provided? + +# type checking: Union[ Iterable[Node]? ] +# mechanisms and purviews: Union[ +# Set[Node] +# Iterable[Node] +# ] +# Node = Union[int, str]? + +# Set[] + + +# Pass method calls to xarray? + +# Should use only one underlying structure for TPM, but subclass when it's a +# matter of space or time +# e.g. subclass for state-by-node + +class TPM: + """General class of TPM objects in state-by-state form. + + The standard TPM class represents a State-by-State matrix where each row defines the + probability that the row's state will transition to any given column state. + + Xarray was chosen for its ability to define each node as a dimension, + easily accessible by its label and marginalizable. + + Can accept both DataFrame objects and 2D or pre-shaped multidimensional + numpy arrays. + """ + def __init__(self, tpm, p_nodes=None, p_states=None, n_nodes=None, n_states=None): + + if isinstance(tpm, DataFrame): + p_states = [len(node_states) for node_states in tpm.index.levels] + n_states = [len(node_states) for node_states in tpm.columns.levels] + p_nodes = tpm.index.names + n_nodes = tpm.columns.names + dims = ["{}_p".format(node) for node in p_nodes] + [ + "{}_n".format(node) for node in n_nodes] + coords = dict(zip(dims, [np.array(range(state)) for state in p_states + n_states])) + self.tpm = xr.DataArray(tpm.values.reshape(p_states + n_states, order="F"), coords=coords, dims=dims) + + + else: + + if isinstance(tpm, xr.DataArray): + tpm = tpm.data + + if p_states is None: # Binary + if p_nodes is None: + # Gen p_nodes + p_nodes = ["n{}".format(i) for i in range(int(log2(tpm.shape[0])))] + + if n_nodes is None: + # NOTE Specifying a Node TPM with just data and p_nodes (and perhaps p_states and n_states if nb) + # seems really useful, so I want a label generation method, but this can give awkward results + # Gen n_nodes + n_nodes = ["n{}".format(i) for i in range(int(log2(tpm.shape[1])))] + + # Gen p_states + p_states = [2] * len(p_nodes) + # Gen n_states + n_states = [2] * len(n_nodes) + + else: # Non-binary + if p_nodes is None: + # Gen p_nodes + p_nodes = ["n{}".format(i) for i in range(len(p_states))] + + if n_nodes is None: # Nbin, Sym + if n_states is None: + # Cpy p_states -> n_states + n_states = p_states.copy() + # Cpy p_nodes -> n_nodes + n_nodes = p_nodes.copy() + + else: # n_states specified so assumed different from p_states? 
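+                    # Generate labels for the next nodes to match the length
+                    # of the given n_states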
+ n_nodes = ["n{}".format(i) for i in range(len(n_states))] + + # Else Nbin, Asym, pass as should have all specified + + + # Previous nodes are labelled with _p and next nodes with _n + # to differentiate between nodes with the same name but different timestep + dims = ["{}_p".format(node) for node in p_nodes] + [ + "{}_n".format(node) for node in n_nodes] + + # The given numpy array is reshaped to have the appropriate number of dimensions for labeling + # Fortran, or little-endian ordering is used, meaning left-most (first called) node varies fastest + coords = dict(zip(dims, [np.array(range(state)) for state in p_states + n_states])) + self.tpm = xr.DataArray(tpm.reshape(p_states + n_states, order="F"), coords=coords, dims=dims) + + self.symmetric = bool(p_nodes == n_nodes) + self._p_nodes = p_nodes + self._n_nodes = n_nodes + + # So far only used in infer_cm, might be possible to rework that and remove this + # Could also just be a list (or implicit in sum of p_nodes and n_nodes), states could + # be accessed by tpm.shape[[p_nodes + n_nodes].index(node)] + self.all_nodes = dict(zip(self.p_nodes + self.n_nodes, self.tpm.shape)) + + @property + def p_nodes(self): + """Returns p_nodes formatted""" + return ["{}_p".format(node) for node in self._p_nodes] + + @property + def n_nodes(self): + """Returns n_nodes formatted""" + return ["{}_n".format(node) for node in self._n_nodes] + + @property + def p_states(self): + """Returns list of p_states""" + return [self.all_nodes[p_node] for p_node in self.p_nodes] + + @property + def n_states(self): + """Returns list of n_states""" + return [self.all_nodes[n_node] for n_node in self.n_nodes] + + # Maybe make this more generalized? Could be based on xarray's shape? + # Maybe make it just one tuple? Hard to separate then... + @property + def tpm_indices(self): + """Returns two tuples of indices for p_nodes and n_nodes""" + return tuple(range(len(self.p_nodes))), tuple(range(len(self.n_nodes))) + + # return tuple(range(len(self.tpm.shape))) + # return tuple(range(len(self.p_nodes)) + range(len(self.n_nodes)))) + + @property + def is_deterministic(self) -> bool: + """Return whether the TPM is deterministic.""" + return np.all(np.logical_or(self.tpm.data == 0, self.tpm.data == 1)) + + # Could get overridden by State-by-Node TPM's subclass to false? Or unneeded + @property + def is_state_by_state(self): + return True + + @property + def num_states(self): + """int: Number of possible states the set of previous nodes can be in + """ + # Number of states equal to product of states of p_nodes? + # This doesn't really make sense in asymmetric TPMs, so maybe a different property + # Makes more sense here? Or could split to num_p_states and num_n_states + num_states = 1 + for node in self.p_nodes: + num_states *= self.all_nodes[node] + return num_states + + @property + def num_n_states(self): + """int: Number of possible states the set of next nodes can be in + """ + num_states = 1 + for node in self.n_nodes: + num_states *= self.all_nodes[node] + return num_states + + @property + def shape(self): + return tuple(self.tpm.shape) + + @property + def p_node_indices(self): + return tuple(range(len(self.p_nodes))) + + @property + def n_node_indices(self): + return tuple(range(len(self.n_nodes))) + + # TODO Maybe try using xarray's coordinates feature to keep dims and state info? 
+ # As opposed to trying to just use slices, keeping the coordinate of the dimension + # equal to the conditioned state could make things easier down the line + # TODO Consider splitting into condition rows and condition columns, then could mix + # into conditioning for asymmetric tpms (or force square dimensionality with empty dims) + def condition(self, fixed_nodes, state): + """Return a TPM conditioned on the given fixed node indices, whose states + are fixed according to the given state-tuple. + + The dimensions of the new TPM that correspond to the fixed nodes are + collapsed onto their state, making those dimensions singletons suitable for + broadcasting. The number of dimensions of the conditioned TPM will be the + same as the unconditioned TPM. + + Args: + fixed_nodes: tuple of indicies of nodes that are fixed + """ + # Only doing this for symmetric tpms at the moment + if self.symmetric: + fixed_node_labels = [self.p_nodes[node] for node in fixed_nodes] + indexer = dict(zip(fixed_node_labels, [state[node] for node in fixed_nodes])) + # Drop rows that don't fit with the conditioned nodes + # TODO Better way to keep dims? IndexSlicer maybe? + # Tried using None slices but dimensions still got removed :( + kept_dims = [self.tpm.dims.index(label) for label in fixed_node_labels] + conditioned_tpm = self.tpm.loc[indexer] + + # Regrow dimensions where they got trimmed + for i in range(len(kept_dims)): + conditioned_tpm = conditioned_tpm.expand_dims(fixed_node_labels[i], axis=kept_dims[i]) + + # Marginalize across columns that don't fit with the conditioned nodes + # Because assumed symmetry, self.n_nodes is same as self.p_nodes, except for labelling + # Since we're summing across columns, need labels from n_nodes + + column_labels = [self.n_nodes[node] for node in fixed_nodes] + + for label in column_labels: + conditioned_tpm = conditioned_tpm.sum(dim=label, keepdims=True) + # At some point maybe change the dict of self.all_nodes? Perhaps ideally make all_nodes unneeeded + + for i in fixed_nodes: + self.p_states[i] = 1 + self.all_nodes[self.p_nodes[i]] = 1 + self.all_nodes[self.n_nodes[i]] = 1 + + return conditioned_tpm + + else: + print(self, "symmetric?:", self.symmetric) + raise NotImplementedError + + def create_node_tpm(self, index, cm): + """Create a new TPM object based on this one, except only describing the + transitions of a single node in the network. 
+ + Args: + node (int): The index of the node whose TPM we wish to create + """ + # Want to marginalize out all other column nodes, but then don't need to worry about normalization + node = self.n_nodes[index] + other_nodes = [label for label in self.n_nodes if label != node] + node_tpm = self.tpm.sum(dim=other_nodes) + + node_tpm = TPM(tpm=node_tpm, p_nodes=self._p_nodes, n_nodes=[self._n_nodes[index]], p_states=self.p_states, n_states=[self.all_nodes[node]]) + + return node_tpm + #return node_tpm.marginalize_out(tuple(set(self.p_node_indices) - set(get_inputs_from_cm(index, cm))), rows=True) + + def condition_node_tpm(node_tpm, fixed_nodes, state, col=False): + """Condition a node TPM object, a special case of asymmetric TPMs + """ + fixed_node_labels = [node_tpm.p_nodes[node] for node in fixed_nodes] + indexer = dict(zip(fixed_node_labels, [state[node] for node in fixed_nodes])) + + kept_dims = [node_tpm.tpm.dims.index(label) for label in fixed_node_labels] + conditioned_tpm = node_tpm.tpm.loc[indexer] + + # Regrow dimensions where they got trimmed + for j in range(len(kept_dims)): + conditioned_tpm = conditioned_tpm.expand_dims(fixed_node_labels[j], axis=kept_dims[j]) + + if col: + conditioned_tpm = conditioned_tpm.sum(dim=conditioned_tpm.dims[-1], keepdims=True) + + return conditioned_tpm + + def condition_list(tpm_list, fixed_nodes, state): + """Condition a list of Node TPMs. + + Args: + tpm_list (list[TPM]): The Node TPMs to be conditioned + fixed_nodes (tuple(int)): The node indicies that are now fixed + state (tuple(int)): The state of the system when conditioning + """ + # Step 0: No conditioning required if fixed_nodes empty + if fixed_nodes is (): + return tpm_list + + # Step 1: Drop rows that do not fit with the conditioned state, for all tpms + for i in range(len(tpm_list)): + # Replace TPM in list's tpm with dropped row tpm. + tpm_list[i].tpm = TPM.condition_node_tpm(tpm_list[i], fixed_nodes, state) + + # Step 2: If a particular tpm describes the transitions of a fixed node, sum the column dimension together + for i in fixed_nodes: + # NOTE: Assumes Node TPM only has 1 column dimension. Can be generalized, but needs more information about labels. + # Since we're already assuming a list, however, this is a valid assumption. See the more general condition method + # For implementing generic asymmetric conditioning + tpm_list[i].tpm = tpm_list[i].tpm.sum(dim=tpm_list[i].tpm.dims[-1], keepdims=True) + + return tpm_list + + + # TODO Currently only works for symmetric TPMs + # TODO **kwargs to determine if marginalizing out just row/column? + def marginalize_out(self, node_indices, rows=False): + """Marginalize out nodes from a TPM. + + Args: + node_indices (list[int]): The indices of nodes to be marginalized out. + Index based on dimension of the node. + rows (bool): Whether to marginalize on only the rows + + + Returns: + xarray: A tpm with the same number of dimensions, with the nodes + marginalized out. 
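+
+        Example:
+            For a symmetric binary TPM over nodes (A, B), marginalizing out
+            node 0 sums over the ``A_p`` and ``A_n`` dimensions (keeping them
+            as singleton dimensions) and divides by 2, the number of states
+            of node A.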
+ """ + def normalize(tpm): + """Returns a normalized TPM after marginalization""" + return tpm / (np.array(self.tpm.shape)[list(node_indices)].prod()) + + if self.symmetric: + labels = [self.p_nodes[i] for i in node_indices] + [self.n_nodes[i] for i in node_indices] + new_p_states = self.p_states + new_n_states = self.n_states + for i in node_indices: + new_p_states[i], new_n_states[i] = 1 + + elif rows: + labels = [self.p_nodes[i] for i in node_indices] + new_p_states = self.p_states + for i in node_indices: + new_p_states[i] = 1 + new_n_states = self.n_states + + else: + raise NotImplementedError + + marginalized_tpm = self.tpm + + for label in labels: + marginalized_tpm = marginalized_tpm.sum(dim=label, keepdims=True) + + finished_tpm = normalize(marginalized_tpm) + + return TPM(tpm=finished_tpm, p_nodes=self._p_nodes, n_nodes=self._n_nodes, p_states=new_p_states, n_states=new_n_states) + + + def infer_edge(tpm, a, b, contexts): + """Infer the presence or absence of an edge from node A to node B. + + Let |S| be the set of all nodes in a network. Let |A' = S - {A}|. We call + the state of |A'| the context |C| of |A|. There is an edge from |A| to |B| + if there exists any context |C(A)| such that |Pr(B | C(A), A=0) != Pr(B | + C(A), A=1)|. + + Args: + tpm (np.ndarray): The TPM as an object + a (int): The index of the putative source node. + b (int): The index of the putative sink node. + Returns: + bool: ``True`` if the edge |A -> B| exists, ``False`` otherwise. + """ + def a_in_context(context): + """Given a context C(A), return the states of the full system with A + in each of its possible states, in order as a list. + """ + a_states = [ + context[:a] + (i, ) + context[a:] + for i in range(tpm.tpm.shape[a]) + ] + + return a_states + + def marginalize_b(state): + """Return the distribution of possible states of b at t+1""" + temp_n_nodes = tpm.n_nodes + name = temp_n_nodes[b] + # Instead of making a full copy, just remove and insert afterwards + temp_n_nodes.remove(name) + marginalized = tpm.tpm.groupby(name).sum(temp_n_nodes).loc[state] + temp_n_nodes.insert(b, name) + return marginalized + + def a_affects_b_in_context(context): + """Return ``True`` if A has an effect on B, given a context.""" + a_states = a_in_context(context) + comparator = marginalize_b(tuple(a_states[0])).data.round(12) + return any(not np.array_equal(comparator, marginalize_b(state).data.round(12)) for state in a_states[1:]) + + return any(a_affects_b_in_context(context) for context in contexts) + + # Takes TPM object, could use self instead? + def infer_cm(tpm): + """Infer the connectivity matrix associated with a state-by-state TPM in + object form. + """ + # Set up empty cm based on nodes + cm = np.empty((len(tpm._p_nodes), len(tpm._n_nodes)), dtype=int) + # Iterate through every node pair + for a, b in np.ndindex(cm.shape): + # Determine context states based on a + a_prime = tpm.p_nodes + a_prime.pop(a) + contexts = tuple(product(*tuple(tuple(range(tpm.all_nodes[node])) for node in a_prime))) + cm[a][b] = tpm.infer_edge(a, b, contexts) + return cm + + def infer_node_edge(tpm, a, contexts): + + def a_in_context(context): + """Given a context C(A), return the states of the full system with A + in each of its possible states, in order as a list. 
+ """ + a_states = [ + context[:a] + (i, ) + context[a:] + for i in range(tpm.tpm.shape[a]) + ] + return a_states + + def a_affects_b_in_context(context): + a_states = np.array(a_in_context(context)) + return any(not np.array_equal(tpm.tpm.data[tuple(a_states[0])], tpm.tpm.data[tuple(state)]) for state in a_states[1:]) + + return any(a_affects_b_in_context(context) for context in contexts) + + def infer_node_cm(tpm): + cm = np.empty((len(tpm._p_nodes), 1), dtype=int) + + for a, b in np.ndindex(cm.shape): + # Determine context states based on a + a_prime = tpm.p_nodes + a_prime.pop(a) + contexts = tuple(product(*tuple(tuple(range(tpm.all_nodes[node])) for node in a_prime))) + cm[a][b] = tpm.infer_node_edge(a, contexts) + + return cm + + def __getitem__(self, key): + return self.tpm[key] + + # TODO maybe not needed? + def expand_tpm(tpm): + """Broadcast a state-by-node TPM so that singleton dimensions are expanded + over the full network. + """ + raise NotImplementedError + + @classmethod + def from_numpy(): + pass + + @classmethod + def from_xarray(): + pass + + @classmethod + def from_pands(): + pass + + def copy(self): + p_states = [self.tpm.shape[self.tpm.dims.index(node)] for node in self.p_nodes] + n_states = [self.tpm.shape[self.tpm.dims.index(node)] for node in self.n_nodes] + return TPM(self.tpm, p_nodes=self._p_nodes, p_states=p_states, n_nodes=self._n_nodes, n_states=n_states) + + def sum(self, dim, keepdims=True): + return self.tpm.sum(dim, keepdims) + + def __repr__(self): + return self.tpm.__repr__() + +# TODO Better name? +class SbN(TPM): + """The subclass of <SbN> represents a State-by-Node matrix, only usable for binary + systems, where each row represents the probability of each column Node being ON + during the timestep after the given row's state. + """ + def __init__(self, tpm, p_nodes=None, p_states=None, n_nodes=None, n_states=None): + format = False + + if isinstance(tpm, xr.DataArray): + tpm = tpm.data + + # Not multi-dimensional SbN + if tpm.ndim == 2: + if tpm.shape[0] == tpm.shape[1]: # SbS + format = True + elif n_nodes: # If n_nodes isn't given, we have to assume SbN unless we ask the user to specify a flag + if len(n_nodes) != tpm.shape[1]: # SbS + format = True + + # If format is true, change from SbS to SbN + if format: + super().__init__(tpm, p_nodes, p_states, n_nodes, n_states) + #bin_node_tpms = [self.tpm.sel({node: 1}).sum(self.n_nodes.copy().remove(node)).expand_dims("nodes", axis=-1) + #for node in self.n_nodes] + bin_node_tpms=[] + # TODO Is there a way to do this with some form of list comp? unfortunately using .copy() + # seems to break things if I try list comp + # TODO Can we use the coordinates property to name the nodes in the "nodes" dimension? + for node in self.n_nodes: + temp = self.n_nodes.copy() + temp.remove(node) + bin_node_tpms.append(self.tpm.sel({node: 1}).sum(temp).expand_dims("n_nodes", axis=-1)) + self.tpm = xr.concat(bin_node_tpms, dim="n_nodes") + + # If format is false, is already in SbN form + else: + if p_nodes is None: + p_nodes = ["n{}".format(i) for i in range(int(log2(np.prod(tpm.shape[:-1]))))] + + # NOTE: This will produce a different naming scheme in some instances than using + # the super constructor, not too big of a deal but worth noting + # Fixing it would probably require unnecessary convolution, only worthwhile + # if it is actually an issue + if n_nodes is None: + n_nodes = ["n{}".format(i) for i in range(tpm.shape[-1])] + + self._p_nodes = p_nodes + # Differences: Shape is going to be (S_a, S_b, S_c... 
N), rows are like normal but index of last is the size of the n_nodes list + dims = self.p_nodes + ["n_nodes"] + + # Binary only, so num_states_per_node is tuple of 2s with length of p_nodes + p_states = [2] * len(p_nodes) + + # Need to keep track of location of nodes in the last dimension + self._n_nodes = n_nodes + self.tpm = xr.DataArray(tpm.reshape(p_states + [len(n_nodes)], order="F"), dims=dims) + + self.all_nodes = self.all_nodes = dict(zip(self.p_nodes + self.n_nodes, [2 for i in self.p_nodes + self.n_nodes])) + + # TODO Only valid for symmetric tpms + def tpm_indices(self): + """Return the indices of nodes in the SbN.""" + return tuple(np.where(np.array(self.tpm.shape[:-1]) == 2)[0]) + + def is_state_by_state(self): + return False + + def get_node_transitions(self, state, external): + """Intended for node TPMs only; but viable for any. + Returns the np.array of the transition data for a given state. + """ + if not external: + n_tpm = self.tpm + indexer = dict({self.tpm.dims[-1]: state}) + return n_tpm.loc[indexer].data + + else: + return self.tpm.data + + + # SbN form + def infer_edge(tpm, a, b, contexts): + """Infer the presence or absence of an edge from node A to node B. + + Let |S| be the set of all nodes in a network. Let |A' = S - {A}|. We call + the state of |A'| the context |C| of |A|. There is an edge from |A| to |B| + if there exists any context |C(A)| such that |Pr(B | C(A), A=0) != Pr(B | + C(A), A=1)|. + + Args: + tpm (SbN): The TPM in state-by-node, multidimensional form. + a (int): The index of the putative source node. + b (int): The index of the putative sink node. + + Returns: + bool: ``True`` if the edge |A -> B| exists, ``False`` otherwise. + """ + def a_in_context(context): + """Given a context C(A), return the states of the full system with A + OFF (0) and ON (1), respectively. + """ + a_off = context[:a] + (0, ) + context[a:] + a_on = context[:a] + (1, ) + context[a:] + return (a_off, a_on) + + def a_affects_b_in_context(context): + """Return ``True`` if A has an effect on B, given a context.""" + a_off, a_on = a_in_context(context) + return tpm.tpm[a_off][b] != tpm.tpm[a_on][b] + + return any(a_affects_b_in_context(context) for context in contexts) + + # SbN form + def infer_cm(tpm): + """Infer the connectivity matrix associated with a SbN tpm in + multidimensional form. + """ + network_size = tpm.tpm.shape[-1] + all_contexts = tuple(all_states(network_size - 1)) + cm = np.empty((network_size, network_size), dtype=int) + for a, b in np.ndindex(cm.shape): + cm[a][b] = SbN.infer_edge(tpm, a, b, all_contexts) + return cm + + + def marginalize_out(self, node_indices): + """Marginalize out nodes from a TPM. + + Args: + node_indices (list[int]): The indices of nodes to be marginalized out. + tpm (np.ndarray): The TPM to marginalize the node out of. + + Returns: + np.ndarray: An SbN with the same number of dimensions, with the nodes + marginalized out. + """ + def normalize(tpm): + return tpm / (np.array(self.tpm.shape)[list(node_indices)].prod()) + + marginalized_tpm = self.tpm + + for label in [self.p_nodes[i] for i in node_indices]: + marginalized_tpm = marginalized_tpm.sum(dim=label, keepdims=True) + + return normalize(marginalized_tpm) + + def condition(self, fixed_nodes, state): + """Return a TPM conditioned on the given fixed node indices, whose states + are fixed according to the given state-tuple. 
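+
+        Only the row dimensions need to be collapsed here: the last axis of an
+        SbN TPM indexes nodes rather than joint next-states, so the columns
+        are already effectively marginalized.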
+ + The dimensions of the new TPM that correspond to the fixed nodes are + collapsed onto their state, making those dimensions singletons suitable for + broadcasting. The number of dimensions of the conditioned TPM will be the + same as the unconditioned TPM. + """ + fixed_node_labels = [self.p_nodes[node] for node in fixed_nodes] + indexer = dict(zip(fixed_node_labels, [state[node] for node in fixed_nodes])) + kept_dims = [self.tpm.dims.index(label) for label in fixed_node_labels] + # Throw out rows that don't fit with fixed node states, don't need to worry + # about columns as they are already essentially marginalized + conditioned_tpm = self.tpm.loc[indexer] + + # Regrow dimensions where they got trimmed + for i in range(len(kept_dims)): + conditioned_tpm = conditioned_tpm.expand_dims(dim=fixed_node_labels[i], axis=kept_dims[i]) + return conditioned_tpm + + def copy(self): + return SbN(self.tpm, p_nodes=self._p_nodes, n_nodes=self._n_nodes) + + def create_node_tpm(self, index, cm): + # Grab the column which holds the transition data for the desired node + indexer = dict({"n_nodes": index}) + node_tpm = self.tpm.loc[indexer] + + # Expand the dimension so that we can add more data + node_tpm = node_tpm.expand_dims(dim='n0_n', axis=-1) + + # Pad the trimmed node_tpm, with NaN values with length 1 before + # The current data (as current data is when the node is at state 1) + mapping = dict({"n0_n": (1, 0)}) + node_tpm = node_tpm.pad(mapping) + + # Replace NaN data with data on when it will in state 0 (1 - state 1) + node_tpm[..., 0] = 1 - node_tpm[..., 1] + + node_TPM = TPM(node_tpm, p_nodes=self._p_nodes).marginalize_out(tuple(set(self.p_node_indices) - set(get_inputs_from_cm(index, cm))), rows=True) + # Create new TPM object with this data + return node_TPM + + def __repr__(self): + return 1 diff --git a/pyphi/compute/subsystem.py b/pyphi/compute/subsystem.py index aaf07fb57..45a3a8133 100644 --- a/pyphi/compute/subsystem.py +++ b/pyphi/compute/subsystem.py @@ -231,7 +231,7 @@ def _ces(subsystem): return ces(subsystem, parallel=config.PARALLEL_CUT_EVALUATION) -@memory.cache(ignore=["subsystem"]) +#@memory.cache(ignore=["subsystem"]) @time_annotated def _sia(cache_key, subsystem): """Return the minimal information partition of a subsystem. diff --git a/pyphi/distribution.py b/pyphi/distribution.py index a11724f50..14f6c7048 100644 --- a/pyphi/distribution.py +++ b/pyphi/distribution.py @@ -104,7 +104,7 @@ def purview_size(repertoire): return len(purview(repertoire)) -def repertoire_shape(purview, N): # pylint: disable=redefined-outer-name +def repertoire_shape(purview, N, states=None): # pylint: disable=redefined-outer-name """Return the shape a repertoire. Args: @@ -122,8 +122,12 @@ def repertoire_shape(purview, N): # pylint: disable=redefined-outer-name >>> repertoire_shape(purview, N) [2, 1, 2] """ - # TODO: extend to non-binary nodes - return [2 if i in purview else 1 for i in range(N)] + #TODO Remove once fully transitioned, as then states should never be None + if states is not None: + return [states[i] if i in purview else 1 for i in range(N)] + + else: + return [2 if i in purview else 1 for i in range(N)] def flatten(repertoire, big_endian=False): @@ -173,7 +177,7 @@ def unflatten(repertoire, purview, N, big_endian=False): @cache(cache={}, maxmem=None) -def max_entropy_distribution(node_indices, number_of_nodes): +def max_entropy_distribution(node_indices, number_of_nodes, states=None): """Return the maximum entropy distribution over a set of nodes. 
This is different from the network's uniform distribution because nodes @@ -184,10 +188,12 @@ def max_entropy_distribution(node_indices, number_of_nodes): node_indices (tuple[int]): The set of node indices over which to take the distribution. number_of_nodes (int): The total number of nodes in the network. + states (tuple): The states of each node in the network Returns: np.ndarray: The maximum entropy distribution over the set of nodes. """ - distribution = np.ones(repertoire_shape(node_indices, number_of_nodes)) + # TODO Remove once fully transitioned + distribution = np.ones(repertoire_shape(node_indices, number_of_nodes, states)) return distribution / distribution.size diff --git a/pyphi/network.py b/pyphi/network.py index f7d2e4271..2cb921fb8 100644 --- a/pyphi/network.py +++ b/pyphi/network.py @@ -9,9 +9,10 @@ import numpy as np -from . import cache, config, connectivity, convert, jsonify, utils, validate +from . import cache, config, connectivity, convert, jsonify, utils, validate, node from .labels import NodeLabels from .tpm import is_state_by_state +from .__tpm import TPM, SbN class Network: @@ -47,8 +48,10 @@ class Network: that node |i| is connected to node |j| (see :ref:`cm-conventions`). **If no connectivity matrix is given, PyPhi assumes that every node is connected to every node (including itself)**. - node_labels (tuple[str] or |NodeLabels|): Human-readable labels for - each node in the network. + p_nodes (list[str]): Human-readable list of names of nodes at time |t-1| + p_states (list[int]): List of the number of states of each node at time |t-1| (necessary for defining a multi-valued tpm) + n_nodes (list[str]): Human-readable list of names of nodes at time |t| + n_states (list[int]): List of the number of states of each node at time |t| (necessary for defining a multi-valued tpm) Example: In a 3-node network, ``the_network.tpm[(0, 0, 1)]`` gives the @@ -57,13 +60,40 @@ class Network: """ # TODO make tpm also optional when implementing logical network definition - def __init__(self, tpm, cm=None, node_labels=None, purview_cache=None): - self._tpm, self._tpm_hash = self._build_tpm(tpm) - self._cm, self._cm_hash = self._build_cm(cm) - self._node_indices = tuple(range(self.size)) - self._node_labels = NodeLabels(node_labels, self._node_indices) + # TODO node_labels attribute deprecated, but many tests use it so currently keeping it an option + def __init__(self, tpm, cm=None, p_nodes=None, p_states=None, n_nodes=None, n_states=None, purview_cache=None, node_labels=None): + if isinstance(tpm, list): + self._is_list = True + if node_labels: + p_nodes = node_labels + self._tpm = tpm + self._cm = self._build_cm_from_list(cm) + self._node_indices = tuple(range(len(tpm))) + # User can specify names of each node to be lined up with the list of node TPMs, else default labels are generated + self._node_labels = NodeLabels(p_nodes, self._node_indices) + + else: + self._is_list = False + if node_labels: + p_nodes = node_labels + if p_states or n_states: # Requires NB state-by-state, could just do only TPM and convert to SbN later? 
+ self._tpm = TPM(tpm, p_nodes, p_states, n_nodes, n_states) + #validate.tpm(tpm, check_independence=config.VALIDATE_CONDITIONAL_INDEPENDENCE) + else: + self._tpm = SbN(tpm, p_nodes, p_states, n_nodes, n_states) + + self._cm = self._tpm.infer_cm() + # Convert to list + tpm_list = [self._tpm.create_node_tpm(index, self._cm) for index in range(len(self._tpm.n_nodes))] + + self._node_indices = tuple(range(self.size)) + self._node_labels = NodeLabels(self._tpm._p_nodes, self.node_indices) # TODO consider using self._tpm._p_nodes instead for more readability? + + self._tpm = tpm_list + self._is_list = True + self.purview_cache = purview_cache or cache.PurviewCache() - + validate.network(self) @property @@ -73,6 +103,7 @@ def tpm(self): """ return self._tpm + # TODO Deprecated? @staticmethod def _build_tpm(tpm): """Validate the TPM passed by the user and convert to multidimensional @@ -92,6 +123,20 @@ def _build_tpm(tpm): return (tpm, utils.np_hash(tpm)) + def _build_cm_from_list(self, cm): + """Generate the connectivity matrix for a network whose tpm is defined as a + list of Node TPMs, by concatenating the individual Node TPM cms. + + Args: + tpm_list (list[TPM]): List of Node TPMs that define how the network transitions. + """ + if cm is None: + cm_list = [TPM.infer_node_cm(node_tpm) for node_tpm in self._tpm] + return np.concatenate(cm_list, axis=1) + else: + return np.array(cm) + # return np.ones((self.size, self.size)) + @property def cm(self): """np.ndarray: The network's connectivity matrix. @@ -106,8 +151,8 @@ def _build_cm(self, cm): unitary CM if none was provided. """ if cm is None: - # Assume all are connected. - cm = np.ones((self.size, self.size)) + # Build cm from TPM method + cm = self._tpm.infer_cm() else: cm = np.array(cm) @@ -134,7 +179,15 @@ def size(self): @property def num_states(self): """int: The number of possible states of the network.""" - return 2 ** self.size + # If list, states can be counted as product of possible transitions of each node + # If not, use TPM.num_states? + # if self._is_list: + num = 1 + for tpm in self._tpm: + num *= tpm.shape[-1] + return num + # else: + # return self._tpm.num_states @property def node_indices(self): @@ -169,10 +222,19 @@ def potential_purviews(self, direction, mechanism): def __len__(self): """int: The number of nodes in the network.""" - return self.tpm.shape[-1] + if self._is_list: + return len(self._tpm) + elif isinstance(self.tpm, TPM): + # TODO Assumes symmetry for now + return len(self.tpm.p_nodes) + else: + return self.tpm.shape[-1] def __repr__(self): - return "Network({}, cm={})".format(self.tpm, self.cm) + if self._is_list: + return str([tpm for tpm in self._tpm]) #TODO Consider representations + else: + return "Network({}, cm={})".format(self.tpm, self.cm) def __str__(self): return self.__repr__() @@ -182,6 +244,8 @@ def __eq__(self, other): Networks are equal if they have the same TPM and CM. """ + if self._is_list: + return np.all([node_tpm == node_tpm for node_tpm in self._tpm]) return ( isinstance(other, Network) and np.array_equal(self.tpm, other.tpm) diff --git a/pyphi/node.py b/pyphi/node.py index c495e5e85..a80a82d2f 100644 --- a/pyphi/node.py +++ b/pyphi/node.py @@ -15,6 +15,7 @@ from .connectivity import get_inputs_from_cm, get_outputs_from_cm from .labels import NodeLabels from .tpm import marginalize_out, tpm_indices +from .__tpm import TPM, SbN # TODO extend to nonbinary nodes @@ -56,41 +57,10 @@ def __init__(self, tpm, cm, index, state, node_labels): # Get indices of the inputs. 
self._inputs = frozenset(get_inputs_from_cm(self.index, cm)) self._outputs = frozenset(get_outputs_from_cm(self.index, cm)) + + # Node TPM has already been generated, take from subsystem list + self.tpm = tpm[self.index] - # Generate the node's TPM. - # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - # We begin by getting the part of the subsystem's TPM that gives just - # the state of this node. This part is still indexed by network state, - # but its last dimension will be gone, since now there's just a single - # scalar value (this node's state) rather than a state-vector for all - # the network nodes. - tpm_on = tpm[..., self.index] - - # TODO extend to nonbinary nodes - # Marginalize out non-input nodes that are in the subsystem, since the - # external nodes have already been dealt with as boundary conditions in - # the subsystem's TPM. - non_inputs = set(tpm_indices(tpm)) - self._inputs - tpm_on = marginalize_out(non_inputs, tpm_on) - - # Get the TPM that gives the probability of the node being off, rather - # than on. - tpm_off = 1 - tpm_on - - # Combine the on- and off-TPM so that the first dimension is indexed by - # the state of the node's inputs at t, and the last dimension is - # indexed by the node's state at t+1. This representation makes it easy - # to condition on the node state. - self.tpm = np.stack([tpm_off, tpm_on], axis=-1) - # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - - # Make the TPM immutable (for hashing). - utils.np_immutable(self.tpm) - - # Only compute the hash once. - self._hash = hash( - (index, utils.np_hash(self.tpm), self.state, self._inputs, self._outputs) - ) @property def tpm_off(self): @@ -160,7 +130,7 @@ def generate_nodes(tpm, cm, network_state, indices, node_labels=None): """Generate |Node| objects for a subsystem. Args: - tpm (np.ndarray): The system's TPM + tpm (pyphi.TPM): The system's TPM cm (np.ndarray): The corresponding CM. network_state (tuple): The state of the network. indices (tuple[int]): Indices to generate nodes for. diff --git a/pyphi/subsystem.py b/pyphi/subsystem.py index 8b2c5862f..741dafa2f 100644 --- a/pyphi/subsystem.py +++ b/pyphi/subsystem.py @@ -24,6 +24,7 @@ from .node import generate_nodes from .partition import mip_partitions from .tpm import condition_tpm, marginalize_out +from .__tpm import TPM, SbN from .utils import time_annotated log = logging.getLogger(__name__) @@ -68,8 +69,9 @@ def __init__( # The network this subsystem belongs to. validate.is_network(network) self.network = network + - self.node_labels = network.node_labels + self.node_labels = network._node_labels # Remove duplicates, sort, and ensure native Python `int`s # (for JSON serialization). self.node_indices = self.node_labels.coerce_to_indices(nodes) @@ -87,9 +89,9 @@ def __init__( ) else: self.external_indices = _external_indices - - # The TPM conditioned on the state of the external nodes. 
- self.tpm = condition_tpm(self.network.tpm, self.external_indices, self.state) + # Remove external nodes from subsystem + self.node_indices = tuple([node for node in self.node_indices if node not in self.external_indices]) + self.node_labels = tuple([self.node_labels[i] for i in self.node_indices]) # The unidirectional cut applied for phi evaluation self.cut = ( @@ -99,9 +101,25 @@ def __init__( # The network's connectivity matrix with cut applied self.cm = self.cut.apply_cut(network.cm) + + # Condition TPM by conditioning every tpm in the list + # Thought: Could generalize this by having Network always use a list, just a list with one object if + # Full tpm is described... except atm condition_list assumes node TPMs because condition doesn't work for asym + # Here we need to make a copy of the TPM objects from the network, or else only one Subsystem can be generated + # Could also be done further downstream, but both function the same. + # TODO Can we make copies unnecessary? + if nodes is not (): + self.tpm = [tpm.copy() for tpm in network.tpm] + self.tpm = TPM.condition_list(self.tpm, self.external_indices, self.state) # Returns a list of conditioned Node TPMs + else: + # Empty subsystem, no nodes, 100% chance to stay in the state selected always + self.tpm = TPM(tpm = np.array([[1]])) + + self.states = tuple([tpm.tpm.shape[-1] for tpm in self.tpm]) # Tuple of number of states each node can transition to + # Reusable cache for maximally-irreducible causes and effects self._mice_cache = cache.MICECache(self, mice_cache) - + # Cause & effect repertoire caches # TODO: if repertoire caches are never reused, there's no reason to # have an accesible object-level cache. Just use a simple memoizer @@ -110,11 +128,12 @@ def __init__( ) self._repertoire_cache = repertoire_cache or cache.DictCache() + # Node objects self.nodes = generate_nodes( self.tpm, self.cm, self.state, self.node_indices, self.node_labels ) - validate.subsystem(self) + #validate.subsystem(self) @property def nodes(self): @@ -182,7 +201,7 @@ def cut_node_labels(self): @property def tpm_size(self): """int: The number of nodes in the TPM.""" - return self.tpm.shape[-1] + return len(self.tpm) def cache_info(self): """Report repertoire cache statistics.""" @@ -292,19 +311,29 @@ def indices2nodes(self, indices): raise ValueError("`indices` must be a subset of the Subsystem's indices.") return tuple(self._index2node[n] for n in indices) - # TODO extend to nonbinary nodes @cache.method("_single_node_repertoire_cache", Direction.CAUSE) def _single_node_cause_repertoire(self, mechanism_node_index, purview): # pylint: disable=missing-docstring mechanism_node = self._index2node[mechanism_node_index] - # We're conditioning on this node's state, so take the TPM for the node - # being in that state. - tpm = mechanism_node.tpm[..., mechanism_node.state] - # Marginalize-out all parents of this mechanism node that aren't in the - # purview. 
- return marginalize_out((mechanism_node.inputs - purview), tpm) - - # TODO extend to nonbinary nodes + # Set up some different paths while transitioning + if isinstance(mechanism_node.tpm, TPM): + # Valid to use n_nodes[0] as these are guaranteed node_tpms, therefore only have 1 n_node, itself + indexer = dict({mechanism_node.tpm.n_nodes[0]: mechanism_node.state}) + #TODO DEBUGGING + print(indexer) + print("purview:", purview) + print('inputs:', mechanism_node.inputs) + print("to_marginalize:", mechanism_node.inputs - purview) + tpm = mechanism_node.tpm.marginalize_out((mechanism_node.inputs - purview), rows=True).tpm.loc[indexer] + return tpm.data + else: + # We're conditioning on this node's state, so take the TPM for the node + # being in that state. + tpm = mechanism_node.tpm[..., mechanism_node.state] + # Marginalize-out all parents of this mechanism node that aren't in the + # purview. + return marginalize_out((mechanism_node.inputs - purview), tpm) + @cache.method("_repertoire_cache", Direction.CAUSE) def cause_repertoire(self, mechanism, purview): """Return the cause repertoire of a mechanism over a purview. @@ -330,15 +359,21 @@ def cause_repertoire(self, mechanism, purview): # state of the purview; return the purview's maximum entropy # distribution. if not mechanism: - return max_entropy_distribution(purview, self.tpm_size) + return max_entropy_distribution(purview, self.tpm_size, self.states) # Use a frozenset so the arguments to `_single_node_cause_repertoire` # can be hashed and cached. purview = frozenset(purview) # Preallocate the repertoire with the proper shape, so that # probabilities are broadcasted appropriately. - joint = np.ones(repertoire_shape(purview, self.tpm_size)) + joint = np.ones(repertoire_shape(purview, self.tpm_size, self.states)) # The cause repertoire is the product of the cause repertoires of the # individual nodes. + #TODO DEBUGGING + print("joint shape before:", joint.shape) + print("mechanism:", mechanism) + print("product shape:", functools.reduce(np.multiply, [self._single_node_cause_repertoire(m, purview) for m in mechanism]).shape) + print([self._single_node_cause_repertoire(m, purview) for m in mechanism]) + print("purview:", purview) joint *= functools.reduce( np.multiply, [self._single_node_cause_repertoire(m, purview) for m in mechanism], @@ -348,19 +383,29 @@ def cause_repertoire(self, mechanism, purview): # TPM don't necessarily sum to 1, so we normalize. return distribution.normalize(joint) - # TODO extend to nonbinary nodes @cache.method("_single_node_repertoire_cache", Direction.EFFECT) def _single_node_effect_repertoire(self, mechanism, purview_node_index): # pylint: disable=missing-docstring purview_node = self._index2node[purview_node_index] # Condition on the state of the inputs that are in the mechanism. mechanism_inputs = purview_node.inputs & mechanism - tpm = condition_tpm(purview_node.tpm, mechanism_inputs, self.state) # Marginalize-out the inputs that aren't in the mechanism. nonmechanism_inputs = purview_node.inputs - mechanism - tpm = marginalize_out(nonmechanism_inputs, tpm) - # Reshape so that the distribution is over next states. 
- return tpm.reshape(repertoire_shape([purview_node.index], self.tpm_size)) + + # Set up different path for TPM objects while transitioning + if isinstance(purview_node.tpm, TPM): + tpm = purview_node.tpm.condition_node_tpm(mechanism_inputs, self.state) + # Doesn't do cols so gotta do that manually again + tpm = purview_node.tpm.marginalize_out(nonmechanism_inputs, rows=True) + indexer = dict({purview_node.tpm.n_nodes[0]: purview_node.state}) + #tpm.reshape(repertoire_shape([purview_node.index], self.tpm_size, self.states)) + return tpm.tpm.loc[indexer].data + + else: + tpm = condition_tpm(purview_node.tpm, mechanism_inputs, self.state) + tpm = marginalize_out(nonmechanism_inputs, tpm) + # Reshape so that the distribution is over next states. + return tpm.reshape(repertoire_shape([purview_node.index], self.tpm_size)) @cache.method("_repertoire_cache", Direction.EFFECT) def effect_repertoire(self, mechanism, purview): @@ -389,7 +434,7 @@ def effect_repertoire(self, mechanism, purview): mechanism = frozenset(mechanism) # Preallocate the repertoire with the proper shape, so that # probabilities are broadcasted appropriately. - joint = np.ones(repertoire_shape(purview, self.tpm_size)) + joint = np.ones(repertoire_shape(purview, self.tpm_size, self.states)) # The effect repertoire is the product of the effect repertoires of the # individual nodes. return joint * functools.reduce( diff --git a/pyphi/validate.py b/pyphi/validate.py index 125e90b34..bef826811 100644 --- a/pyphi/validate.py +++ b/pyphi/validate.py @@ -10,6 +10,8 @@ from . import Direction, config, convert, exceptions from .tpm import is_state_by_state +from itertools import product +from pyphi.__tpm import TPM # pylint: disable=redefined-outer-name @@ -29,7 +31,7 @@ def direction(direction, allow_bi=False): return True - +# TODO eliminate non-object route? def tpm(tpm, check_independence=True): """Validate a TPM. @@ -61,6 +63,8 @@ def tpm(tpm, check_independence=True): ) if tpm.shape[0] == tpm.shape[1] and check_independence: conditionally_independent(tpm) + elif isinstance(tpm, TPM) and check_independence: + conditionally_independent_obj(tpm) elif tpm.ndim == (N + 1): if tpm.shape != tuple([2] * N + [N]): raise ValueError( @@ -76,6 +80,37 @@ def tpm(tpm, check_independence=True): ) return True +def conditionally_independent_obj(tpm): + # Step 0: Cartesian products of potential state combos of the p_nodes and n_nodes for later access + p_states = tuple(product(*(tuple(range(tpm.all_nodes[node])) for node in tpm.p_nodes))) + n_states = tuple(product(*(tuple(range(tpm.all_nodes[node])) for node in tpm.n_nodes))) + + # Step 1: For each p combo, calculate the state probabilities of each node in the columns + for state in p_states: + node_distributions = [] + # ... 
would rather not have to use this copy method again (as in infer_cm) but python refuses to let me work with + # the list like I need to + temp_n_nodes = tpm.n_nodes.copy() + for n_node in tpm.n_nodes: + index = temp_n_nodes.index(n_node) + temp_n_nodes.remove(n_node) + node_distributions.append(tpm.tpm.groupby(n_node).sum(temp_n_nodes).loc[state]) + temp_n_nodes.insert(index, n_node) + + # Step 2: Generate the conditionally independent element, then compare with actual element + + for n_state in n_states: + independent_probability = 1 + for node_index in range(len(node_distributions)): + independent_probability *= node_distributions[node_index][n_state[node_index]] + + if tpm[state + n_state] != independent_probability: + raise exceptions.ConditionallyDependentError( + "TPM is not conditionally independent.\n" + "See the conditional independence example in the documentation " + "for more info." + ) + return True def conditionally_independent(tpm): """Validate that the TPM is conditionally independent.""" @@ -104,8 +139,6 @@ def connectivity_matrix(cm): return True if cm.ndim != 2: raise ValueError("Connectivity matrix must be 2-dimensional.") - if cm.shape[0] != cm.shape[1]: - raise ValueError("Connectivity matrix must be square.") if not np.all(np.logical_or(cm == 1, cm == 0)): raise ValueError("Connectivity matrix must contain only binary " "values.") return True @@ -127,7 +160,7 @@ def network(n): Checks the TPM and connectivity matrix. """ - tpm(n.tpm) + # tpm(n.tpm) TODO: Generating a valid TPM list implies the tpm is valid, only 'needed' if network given a non-list connectivity_matrix(n.cm) if n.cm.shape[0] != n.size: raise ValueError( @@ -171,12 +204,38 @@ def state_reachable(subsystem): # reached from some state. # First we take the submatrix of the conditioned TPM that corresponds to # the nodes that are actually in the subsystem... - tpm = subsystem.tpm[..., subsystem.node_indices] + #tpm = subsystem.tpm[..., subsystem.node_indices] # Then we do the subtraction and test. - test = tpm - np.array(subsystem.proper_state) - if not np.any(np.logical_and(-1 < test, test < 1).all(-1)): - raise exceptions.StateUnreachableError(subsystem.state) - + #test = tpm - np.array(subsystem.proper_state) + #if not np.any(np.logical_and(-1 < test, test < 1).all(-1)): + # raise exceptions.StateUnreachableError(subsystem.state) + + # If there exists a row in which every node in the subsystem + # Has a positive probability of being in that state, there is + # a non-zero chance of reaching that state. + + # Create state tuples to look thru + # If the node is in the subsystem, the tuple will include all its potential states + # If it is an external node, however, the tpms have been conditioned to a single state + # So access that state from the subsystem and use it in the tuple. 
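+    # For example, in a 3-node binary network where node 2 is external and
+    # fixed to state 1, this yields (0, 0, 1), (0, 1, 1), (1, 0, 1), (1, 1, 1)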
+ p_states = tuple(product( + *(tuple(range(subsystem.states[node])) if node in subsystem.node_indices else (subsystem.state[node], ) for node in subsystem.network.node_indices) + )) + + reachable = False + for p_state in p_states: + # Reset reachable so that each state can be tried + reachable = True + for node in subsystem.node_indicies: + if subsystem.tpm[node][p_state + (subsystem.state[node], )] == 0: + # If there's any 0 probability, state can't be reached here so break + reachable = False + break + # If there is ever a point where reachable remains True, then it is reachable + if reachable: + return True + # If no state led to it being reachable, raise the exception + raise exceptions.StateUnreachableError(subsystem.state) def cut(cut, node_indices): """Check that the cut is for only the given nodes.""" diff --git a/pyphi_config.yml b/pyphi_config.yml index 654086c34..4212dac8f 100644 --- a/pyphi_config.yml +++ b/pyphi_config.yml @@ -18,7 +18,7 @@ ASSUME_CUTS_CANNOT_CREATE_NEW_CONCEPTS: false # modular, sparsely-connected, or homogeneous networks. CUT_ONE_APPROXIMATION: false # The measure to use when computing phi ("EMD", "KLD", "L1", ...) -MEASURE: "EMD" +MEASURE: "EMD" #KLM for nonbinary calculations # Controls the number of parts in a partition. PARTITION_TYPE: "BI" # Controls how to resolve phi-ties when computing MICE. diff --git a/test/mice_test.ipynb b/test/mice_test.ipynb new file mode 100644 index 000000000..1df917511 --- /dev/null +++ b/test/mice_test.ipynb @@ -0,0 +1,601 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "Welcome to PyPhi!\n", + "\n", + "If you use PyPhi in your research, please cite the paper:\n", + "\n", + " Mayner WGP, Marshall W, Albantakis L, Findlay G, Marchman R, Tononi G.\n", + " (2018). 
PyPhi: A toolbox for integrated information theory.\n", + " PLOS Computational Biology 14(7): e1006343.\n", + " https://doi.org/10.1371/journal.pcbi.1006343\n", + "\n", + "Documentation is available online (or with the built-in `help()` function):\n", + " https://pyphi.readthedocs.io\n", + "\n", + "To report issues, please use the issue tracker on the GitHub repository:\n", + " https://github.com/wmayner/pyphi\n", + "\n", + "For general discussion, you are welcome to join the pyphi-users group:\n", + " https://groups.google.com/forum/#!forum/pyphi-users\n", + "\n", + "To suppress this message, either:\n", + " - Set `WELCOME_OFF: true` in your `pyphi_config.yml` file, or\n", + " - Set the environment variable PYPHI_WELCOME_OFF to any value in your shell:\n", + " export PYPHI_WELCOME_OFF='yes'\n", + "\n" + ] + } + ], + "source": [ + "import pyphi\n", + "import test_subsystem_phi_max\n" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{None: {<Direction.CAUSE: 0>: [Maximally-irreducible cause\n", + " φ = 1/2\n", + " Mechanism: [B]\n", + " Purview = [C]\n", + " Direction: CAUSE\n", + " MIP:\n", + " ∅ B \n", + " ─── ✕ ───\n", + " C ∅ \n", + " Repertoire:\n", + " ┌─────────────┐\n", + " │ S Pr(S) │\n", + " │ ╴╴╴╴╴╴╴╴╴╴╴ │\n", + " │ 0 1 │\n", + " │ 1 0 │\n", + " └─────────────┘\n", + " Partitioned repertoire:\n", + " ┌─────────────┐\n", + " │ S Pr(S) │\n", + " │ ╴╴╴╴╴╴╴╴╴╴╴ │\n", + " │ 0 1/2 │\n", + " │ 1 1/2 │\n", + " └─────────────┘,\n", + " Maximally-irreducible cause\n", + " φ = 1/2\n", + " Mechanism: [C]\n", + " Purview = [A, B]\n", + " Direction: CAUSE\n", + " MIP:\n", + " ∅ C \n", + " ─── ✕ ───\n", + " A B \n", + " Repertoire:\n", + " ┌──────────────┐\n", + " │ S Pr(S) │\n", + " │ ╴╴╴╴╴╴╴╴╴╴╴╴ │\n", + " │ 00 1/2 │\n", + " │ 10 0 │\n", + " │ 01 0 │\n", + " │ 11 1/2 │\n", + " └──────────────┘\n", + " Partitioned repertoire:\n", + " ┌──────────────┐\n", + " │ S Pr(S) │\n", + " │ ╴╴╴╴╴╴╴╴╴╴╴╴ │\n", + " │ 00 1/4 │\n", + " │ 10 1/4 │\n", + " │ 01 1/4 │\n", + " │ 11 1/4 │\n", + " └──────────────┘,\n", + " Maximally-irreducible cause\n", + " φ = 1/3\n", + " Mechanism: [A, B]\n", + " Purview = [B, C]\n", + " Direction: CAUSE\n", + " MIP:\n", + " A B \n", + " ─── ✕ ───\n", + " B C \n", + " Repertoire:\n", + " ┌──────────────┐\n", + " │ S Pr(S) │\n", + " │ ╴╴╴╴╴╴╴╴╴╴╴╴ │\n", + " │ 00 0 │\n", + " │ 10 1 │\n", + " │ 01 0 │\n", + " │ 11 0 │\n", + " └──────────────┘\n", + " Partitioned repertoire:\n", + " ┌──────────────┐\n", + " │ S Pr(S) │\n", + " │ ╴╴╴╴╴╴╴╴╴╴╴╴ │\n", + " │ 00 1/3 │\n", + " │ 10 2/3 │\n", + " │ 01 0 │\n", + " │ 11 0 │\n", + " └──────────────┘,\n", + " Maximally-irreducible cause\n", + " φ = 1/2\n", + " Mechanism: [A, B, C]\n", + " Purview = [A, B, C]\n", + " Direction: CAUSE\n", + " MIP:\n", + " ∅ A,B,C\n", + " ─── ✕ ─────\n", + " A B,C \n", + " Repertoire:\n", + " ┌───────────────┐\n", + " │ S Pr(S) │\n", + " │ ╴╴╴╴╴╴╴╴╴╴╴╴╴ │\n", + " │ 000 0 │\n", + " │ 100 0 │\n", + " │ 010 0 │\n", + " │ 110 1 │\n", + " │ 001 0 │\n", + " │ 101 0 │\n", + " │ 011 0 │\n", + " │ 111 0 │\n", + " └───────────────┘\n", + " Partitioned repertoire:\n", + " ┌───────────────┐\n", + " │ S Pr(S) │\n", + " │ ╴╴╴╴╴╴╴╴╴╴╴╴╴ │\n", + " │ 000 0 │\n", + " │ 100 0 │\n", + " │ 010 1/2 │\n", + " │ 110 1/2 │\n", + " │ 001 0 │\n", + " │ 101 0 │\n", + " │ 011 0 │\n", + " │ 111 0 │\n", + " └───────────────┘],\n", + " <Direction.EFFECT: 1>: [Maximally-irreducible effect\n", + " φ = 1/4\n", + " Mechanism: [B]\n", + " Purview = [A]\n", + " 
Direction: EFFECT\n", + " MIP:\n", + " ∅ B \n", + " ─── ✕ ───\n", + " A ∅ \n", + " Repertoire:\n", + " ┌─────────────┐\n", + " │ S Pr(S) │\n", + " │ ╴╴╴╴╴╴╴╴╴╴╴ │\n", + " │ 0 1/2 │\n", + " │ 1 1/2 │\n", + " └─────────────┘\n", + " Partitioned repertoire:\n", + " ┌─────────────┐\n", + " │ S Pr(S) │\n", + " │ ╴╴╴╴╴╴╴╴╴╴╴ │\n", + " │ 0 1/4 │\n", + " │ 1 3/4 │\n", + " └─────────────┘,\n", + " Maximally-irreducible effect\n", + " φ = 1/2\n", + " Mechanism: [C]\n", + " Purview = [B]\n", + " Direction: EFFECT\n", + " MIP:\n", + " ∅ C \n", + " ─── ✕ ───\n", + " B ∅ \n", + " Repertoire:\n", + " ┌─────────────┐\n", + " │ S Pr(S) │\n", + " │ ╴╴╴╴╴╴╴╴╴╴╴ │\n", + " │ 0 1 │\n", + " │ 1 0 │\n", + " └─────────────┘\n", + " Partitioned repertoire:\n", + " ┌─────────────┐\n", + " │ S Pr(S) │\n", + " │ ╴╴╴╴╴╴╴╴╴╴╴ │\n", + " │ 0 1/2 │\n", + " │ 1 1/2 │\n", + " └─────────────┘,\n", + " Maximally-irreducible effect\n", + " φ = 1/2\n", + " Mechanism: [A, B]\n", + " Purview = [C]\n", + " Direction: EFFECT\n", + " MIP:\n", + " ∅ A,B\n", + " ─── ✕ ───\n", + " C ∅ \n", + " Repertoire:\n", + " ┌─────────────┐\n", + " │ S Pr(S) │\n", + " │ ╴╴╴╴╴╴╴╴╴╴╴ │\n", + " │ 0 0 │\n", + " │ 1 1 │\n", + " └─────────────┘\n", + " Partitioned repertoire:\n", + " ┌─────────────┐\n", + " │ S Pr(S) │\n", + " │ ╴╴╴╴╴╴╴╴╴╴╴ │\n", + " │ 0 1/2 │\n", + " │ 1 1/2 │\n", + " └─────────────┘,\n", + " Maximally-irreducible effect\n", + " φ = 1/2\n", + " Mechanism: [A, B, C]\n", + " Purview = [A, B, C]\n", + " Direction: EFFECT\n", + " MIP:\n", + " ∅ A,B,C\n", + " ─── ✕ ─────\n", + " B A,C \n", + " Repertoire:\n", + " ┌───────────────┐\n", + " │ S Pr(S) │\n", + " │ ╴╴╴╴╴╴╴╴╴╴╴╴╴ │\n", + " │ 000 0 │\n", + " │ 100 0 │\n", + " │ 010 0 │\n", + " │ 110 0 │\n", + " │ 001 1 │\n", + " │ 101 0 │\n", + " │ 011 0 │\n", + " │ 111 0 │\n", + " └───────────────┘\n", + " Partitioned repertoire:\n", + " ┌───────────────┐\n", + " │ S Pr(S) │\n", + " │ ╴╴╴╴╴╴╴╴╴╴╴╴╴ │\n", + " │ 000 0 │\n", + " │ 100 0 │\n", + " │ 010 0 │\n", + " │ 110 0 │\n", + " │ 001 1/2 │\n", + " │ 101 0 │\n", + " │ 011 1/2 │\n", + " │ 111 0 │\n", + " └───────────────┘]},\n", + " Cut [1, 2] ━━/ /━━➤ [0]: {<Direction.CAUSE: 0>: [Maximally-irreducible cause\n", + " φ = 1/2\n", + " Mechanism: [B]\n", + " Purview = [C]\n", + " Direction: CAUSE\n", + " MIP:\n", + " ∅ B \n", + " ─── ✕ ───\n", + " C ∅ \n", + " Repertoire:\n", + " ┌─────────────┐\n", + " │ S Pr(S) │\n", + " │ ╴╴╴╴╴╴╴╴╴╴╴ │\n", + " │ 0 1 │\n", + " │ 1 0 │\n", + " └─────────────┘\n", + " Partitioned repertoire:\n", + " ┌─────────────┐\n", + " │ S Pr(S) │\n", + " │ ╴╴╴╴╴╴╴╴╴╴╴ │\n", + " │ 0 1/2 │\n", + " │ 1 1/2 │\n", + " └─────────────┘,\n", + " Maximally-irreducible cause\n", + " φ = 1/2\n", + " Mechanism: [C]\n", + " Purview = [A, B]\n", + " Direction: CAUSE\n", + " MIP:\n", + " ∅ C \n", + " ─── ✕ ───\n", + " A B \n", + " Repertoire:\n", + " ┌──────────────┐\n", + " │ S Pr(S) │\n", + " │ ╴╴╴╴╴╴╴╴╴╴╴╴ │\n", + " │ 00 1/2 │\n", + " │ 10 0 │\n", + " │ 01 0 │\n", + " │ 11 1/2 │\n", + " └──────────────┘\n", + " Partitioned repertoire:\n", + " ┌──────────────┐\n", + " │ S Pr(S) │\n", + " │ ╴╴╴╴╴╴╴╴╴╴╴╴ │\n", + " │ 00 1/4 │\n", + " │ 10 1/4 │\n", + " │ 01 1/4 │\n", + " │ 11 1/4 │\n", + " └──────────────┘,\n", + " Maximally-irreducible cause\n", + " φ = 0\n", + " Mechanism: [0, 1]\n", + " Purview = []\n", + " Direction: CAUSE\n", + " MIP:\n", + " \n", + " Repertoire:\n", + " \n", + " Partitioned repertoire:\n", + " ,\n", + " Maximally-irreducible cause\n", + " φ = 0\n", + " Mechanism: [0, 1, 2]\n", + " Purview = []\n", + " Direction: CAUSE\n", + 
" MIP:\n", + " \n", + " Repertoire:\n", + " \n", + " Partitioned repertoire:\n", + " ],\n", + " <Direction.EFFECT: 1>: [Maximally-irreducible effect\n", + " φ = 0\n", + " Mechanism: [B]\n", + " Purview = [C]\n", + " Direction: EFFECT\n", + " MIP:\n", + " ∅ B \n", + " ─── ✕ ───\n", + " C ∅ \n", + " Repertoire:\n", + " ┌─────────────┐\n", + " │ S Pr(S) │\n", + " │ ╴╴╴╴╴╴╴╴╴╴╴ │\n", + " │ 0 1/2 │\n", + " │ 1 1/2 │\n", + " └─────────────┘\n", + " Partitioned repertoire:\n", + " ┌─────────────┐\n", + " │ S Pr(S) │\n", + " │ ╴╴╴╴╴╴╴╴╴╴╴ │\n", + " │ 0 1/2 │\n", + " │ 1 1/2 │\n", + " └─────────────┘,\n", + " Maximally-irreducible effect\n", + " φ = 1/2\n", + " Mechanism: [C]\n", + " Purview = [B]\n", + " Direction: EFFECT\n", + " MIP:\n", + " ∅ C \n", + " ─── ✕ ───\n", + " B ∅ \n", + " Repertoire:\n", + " ┌─────────────┐\n", + " │ S Pr(S) │\n", + " │ ╴╴╴╴╴╴╴╴╴╴╴ │\n", + " │ 0 1 │\n", + " │ 1 0 │\n", + " └─────────────┘\n", + " Partitioned repertoire:\n", + " ┌─────────────┐\n", + " │ S Pr(S) │\n", + " │ ╴╴╴╴╴╴╴╴╴╴╴ │\n", + " │ 0 1/2 │\n", + " │ 1 1/2 │\n", + " └─────────────┘,\n", + " Maximally-irreducible effect\n", + " φ = 1/2\n", + " Mechanism: [A, B]\n", + " Purview = [C]\n", + " Direction: EFFECT\n", + " MIP:\n", + " ∅ A,B\n", + " ─── ✕ ───\n", + " C ∅ \n", + " Repertoire:\n", + " ┌─────────────┐\n", + " │ S Pr(S) │\n", + " │ ╴╴╴╴╴╴╴╴╴╴╴ │\n", + " │ 0 0 │\n", + " │ 1 1 │\n", + " └─────────────┘\n", + " Partitioned repertoire:\n", + " ┌─────────────┐\n", + " │ S Pr(S) │\n", + " │ ╴╴╴╴╴╴╴╴╴╴╴ │\n", + " │ 0 1/2 │\n", + " │ 1 1/2 │\n", + " └─────────────┘,\n", + " Maximally-irreducible effect\n", + " φ = 0\n", + " Mechanism: [0, 1, 2]\n", + " Purview = []\n", + " Direction: EFFECT\n", + " MIP:\n", + " \n", + " Repertoire:\n", + " \n", + " Partitioned repertoire:\n", + " ]}}" + ] + }, + "execution_count": 2, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "test_subsystem_phi_max.expected_mice" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "cut, direction, expected = mice_scenarios[5]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "Welcome to PyPhi!\n", + "\n", + "If you use PyPhi in your research, please cite the paper:\n", + "\n", + " Mayner WGP, Marshall W, Albantakis L, Findlay G, Marchman R, Tononi G.\n", + " (2018). 
PyPhi: A toolbox for integrated information theory.\n", + " PLOS Computational Biology 14(7): e1006343.\n", + " https://doi.org/10.1371/journal.pcbi.1006343\n", + "\n", + "Documentation is available online (or with the built-in `help()` function):\n", + " https://pyphi.readthedocs.io\n", + "\n", + "To report issues, please use the issue tracker on the GitHub repository:\n", + " https://github.com/wmayner/pyphi\n", + "\n", + "For general discussion, you are welcome to join the pyphi-users group:\n", + " https://groups.google.com/forum/#!forum/pyphi-users\n", + "\n", + "To suppress this message, either:\n", + " - Set `WELCOME_OFF: true` in your `pyphi_config.yml` file, or\n", + " - Set the environment variable PYPHI_WELCOME_OFF to any value in your shell:\n", + " export PYPHI_WELCOME_OFF='yes'\n", + "\n" + ] + } + ], + "source": [ + "%run test_subsystem_phi_max.py" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Expected:\n", + " Maximally-irreducible effect\n", + " φ = 1/2\n", + " Mechanism: [A, B, C]\n", + " Purview = [A, B, C]\n", + " Direction: EFFECT\n", + " MIP:\n", + " ∅ A,B,C\n", + " ─── ✕ ─────\n", + " B A,C \n", + " Repertoire:\n", + " ┌───────────────┐\n", + " │ S Pr(S) │\n", + " │ ╴╴╴╴╴╴╴╴╴╴╴╴╴ │\n", + " │ 000 0 │\n", + " │ 100 0 │\n", + " │ 010 0 │\n", + " │ 110 0 │\n", + " │ 001 1 │\n", + " │ 101 0 │\n", + " │ 011 0 │\n", + " │ 111 0 │\n", + " └───────────────┘\n", + " Partitioned repertoire:\n", + " ┌───────────────┐\n", + " │ S Pr(S) │\n", + " │ ╴╴╴╴╴╴╴╴╴╴╴╴╴ │\n", + " │ 000 0 │\n", + " │ 100 0 │\n", + " │ 010 0 │\n", + " │ 110 0 │\n", + " │ 001 1/2 │\n", + " │ 101 0 │\n", + " │ 011 1/2 │\n", + " │ 111 0 │\n", + " └───────────────┘\n", + "Result:\n", + " Maximally-irreducible effect\n", + " φ = 1/2\n", + " Mechanism: [A, B, C]\n", + " Purview = [A, B, C]\n", + " Direction: EFFECT\n", + " MIP:\n", + " ∅ A,B,C\n", + " ─── ✕ ─────\n", + " B A,C \n", + " Repertoire:\n", + " ┌───────────────┐\n", + " │ S Pr(S) │\n", + " │ ╴╴╴╴╴╴╴╴╴╴╴╴╴ │\n", + " │ 000 0 │\n", + " │ 100 0 │\n", + " │ 010 0 │\n", + " │ 110 0 │\n", + " │ 001 1 │\n", + " │ 101 0 │\n", + " │ 011 0 │\n", + " │ 111 0 │\n", + " └───────────────┘\n", + " Partitioned repertoire:\n", + " ┌───────────────┐\n", + " │ S Pr(S) │\n", + " │ ╴╴╴╴╴╴╴╴╴╴╴╴╴ │\n", + " │ 000 0 │\n", + " │ 100 0 │\n", + " │ 010 0 │\n", + " │ 110 0 │\n", + " │ 001 1/2 │\n", + " │ 101 0 │\n", + " │ 011 1/2 │\n", + " │ 111 0 │\n", + " └───────────────┘\n" + ] + } + ], + "source": [ + "test_find_mice(cut, direction, expected)" + ] + } + ], + "metadata": { + "interpreter": { + "hash": "4d7e67e8959bd4cf30a4f9a3d906518b5030bde8ca90c929b80828e7ae4ed2e6" + }, + "kernelspec": { + "display_name": "Python 3.7.9 64-bit", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.9" + }, + "orig_nbformat": 4 + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/test/test_network.py b/test/test_network.py index cb255e080..9765445e2 100644 --- a/test/test_network.py +++ b/test/test_network.py @@ -17,15 +17,15 @@ def network(): def test_network_init_validation(network): - with pytest.raises(ValueError): + with pytest.raises(IndexError): # TODO Do we want to change this error? Is this error guaranteed? 
# Totally wrong shape tpm = np.arange(3).astype(float) Network(tpm) - with pytest.raises(ValueError): + # with pytest.raises(ValueError): # This should no longer raise an error, tpm generalized to non-binary nodes # Non-binary nodes (4 states) - tpm = np.ones((4, 4, 4, 3)).astype(float) - Network(tpm) - + #tpm = np.ones((4, 4, 4, 3)).astype(float) + #Network(tpm) + # Conditionally dependent # fmt: off tpm = np.array([ @@ -42,7 +42,7 @@ def test_network_init_validation(network): Network(tpm) -def test_network_creates_fully_connected_cm_by_default(): +def test_network_creates_fully_connected_cm_by_default(): # Deprecated? Now have infer_cm method tpm = np.zeros((2 * 2 * 2, 3)) network = Network(tpm, cm=None) target_cm = np.ones((3, 3)) diff --git a/test/test_subsystem.py b/test/test_subsystem.py index 0908f6bd9..0d7c40ebd 100644 --- a/test/test_subsystem.py +++ b/test/test_subsystem.py @@ -124,7 +124,7 @@ def test_apply_cut(s): assert s.network == cut_s.network assert s.state == cut_s.state assert s.node_indices == cut_s.node_indices - assert np.array_equal(cut_s.tpm, s.tpm) + assert np.all([np.array_equal(cut_s.tpm[i].tpm.data, s.tpm[i].tpm.data) for i in range(len(s.tpm))]) assert np.array_equal(cut_s.cm, cut.apply_cut(s.cm)) diff --git a/test/test_tpm_obj.py b/test/test_tpm_obj.py new file mode 100644 index 000000000..71e8ff5fc --- /dev/null +++ b/test/test_tpm_obj.py @@ -0,0 +1,257 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +# test/test_tpm_obj.py + +import numpy as np +import pytest +import pandas as pd +from itertools import product + +from pyphi.__tpm import TPM +from pyphi.__tpm import SbN + + +# TODO for harsher tests, may want to ensure that node labels properly map to the correct positions +# But only necessary if significant changes are made to the shaping logic, since results from notebook testing +# indicate that it currently works +def test_init(): + tpm = np.array([[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]]) + + # Nothing but data given + test = TPM(tpm) + assert test.tpm.shape == (2, 2, 2, 2) + assert test.tpm.dims == ("n0_p", "n1_p", "n0_n", "n1_n") + + # Symmetric, binary + test = TPM(tpm, p_nodes=["A", "B"]) + + assert test.tpm.shape == (2, 2, 2, 2) + assert test.tpm.dims == ("A_p", "B_p", "n0_n", "n1_n") + + # Symmetric, nonbinary, wrong shape + with pytest.raises(ValueError): + TPM(tpm, p_nodes=["A", "B"], p_states=[2, 3]) + + # Asymmetric, binary + tpm = np.array([[0, 1], [0, 1], [1, 0], [1, 0]]) + test = TPM(tpm, p_nodes=["A", "B"], n_nodes=["C"]) + + assert test.tpm.shape == (2, 2, 2) + assert test.tpm.dims == ("A_p", "B_p", "C_n") + + # Asymmetric, nonbinary + tpm = np.array([[0, 1, 0], [0, 0, 1], [1, 0, 0], [1, 0, 0]]) + test = TPM(tpm, p_nodes=["A", "B"], n_nodes=["C"], p_states=[2, 2], n_states=[3]) + + assert test.tpm.shape == (2, 2, 3) + assert test.tpm.dims == ("A_p", "B_p", "C_n") + + # Symmetric, nonbinary, right shape + tpm = np.array([[1, 0, 0, 0, 0, 0], [0, 1, 0, 0, 0, 0], [0, 0, 1, 0, 0, 0], [0, 0, 0, 1, 0, 0], [0, 0, 0, 0, 1, 0], [0, 0, 0, 0, 0, 1]]) + test = TPM(tpm, p_nodes=["A", "B"], p_states=[2, 3]) + + assert test.tpm.shape == (2, 3, 2, 3) + assert test.tpm.dims == ("A_p", "B_p", "A_n", "B_n") + + # DataFrame + tpm = np.array([[0, 0, 0, 1], [0, 0, 1, 0], [0, 1, 0, 0], [1, 0, 0, 0]]) + base = [2, 2] + nam = ["A", "B"] + + states_node=[list(range(b)) for b in base] + + sates_all_nodes=[list(x[::-1]) for x in list(product(*states_node[::-1])) ] + sates_all_nodes=np.transpose(sates_all_nodes).tolist() + index = pd.MultiIndex.from_arrays(sates_all_nodes, names=nam) + 
columns = pd.MultiIndex.from_arrays(sates_all_nodes, names=nam) + + df = pd.DataFrame(tpm,columns=columns, index=index) + test = TPM(df, ["C", "D"]) #TODO should only need one argument, test ensures second arg doesn't influence naming + + assert test.tpm.shape == (2, 2, 2, 2) + assert test.tpm.dims == ("A_p", "B_p", "A_n", "B_n") + + # SbN Network, check both init paths give same results + can = np.array([[0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 0, 1, 0, 0, 0], + [1, 0, 0, 0, 0, 0, 0, 0], + [1, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 1, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 1], + [0, 1, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 1, 0, 0, 0, 0]]) + + can_sbn = np.array([[0, 0, 1], + [0, 0, 1], + [0, 0, 0], + [0, 0, 0], + [1, 0, 1], + [1, 1, 1], + [1, 0, 0], + [1, 1, 0]]) + + can_obj = SbN(can, p_nodes=["A", "B", "C"], p_states=[2, 2, 2]) + + can_sbn_obj = SbN(can_sbn, p_nodes=["A", "B", "C"], p_states=[2, 2, 2], n_nodes=["A", "B", "C"]) + + assert np.array_equal(can_obj.tpm.data, can_sbn_obj.tpm.data) + assert can_obj.p_nodes == can_sbn_obj.p_nodes + assert can_obj.n_nodes == can_sbn_obj.n_nodes + + # Test only data given + can_obj = SbN(can) + can_sbn_obj = SbN(can_sbn) + + assert np.array_equal(can_obj.tpm.data, can_sbn_obj.tpm.data) + assert can_obj.p_nodes == can_sbn_obj.p_nodes + assert can_obj.n_nodes == can_sbn_obj.n_nodes + +def test_marginalize_out_obj(): + p53 = np.array([[0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.], + [0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.], + [0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0.], + [0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.], + [0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.], + [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1.], + [0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.], + [1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], + [0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.], + [0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.], + [0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.], + [0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0.]]) + + p53_obj = TPM(p53, p_nodes=["P", "Mc", "Mn"], p_states=[3, 2, 2]) + + marginalized = p53_obj.marginalize_out((1, )).data + + solution = np.array([[0, 0, 0, 0, 0, 1], + [0, 0, 0.5, 0, 0, 0.5], + [0, 0, 0.5, 0, 0, 0.5], + [0, 0, 0, 1, 0, 0], + [0.5, 0, 0, 0.5, 0, 0], + [0.5, 0, 0, 0.5, 0, 0]]).reshape(3, 1, 2, 3, 1, 2, order="F") + + assert np.array_equal(marginalized, solution) + +def test_marginalize_out_sbn(s): + sbn = SbN(s.tpm, p_nodes=["A", "B", "C"], n_nodes=["A", "B", "C"], p_states=[2,2,2], n_states=[2,2,2]) + + marginalized_distribution = sbn.marginalize_out((0, )).data + + answer = np.array([ + [[[0.0, 0.0, 0.5], + [1.0, 1.0, 0.5]], + [[1.0, 0.0, 0.5], + [1.0, 1.0, 0.5]]], + ]) + assert np.array_equal(marginalized_distribution, answer) + + marginalized_distribution = sbn.marginalize_out((0, 1)).data + + answer = np.array([ + [[[0.5, 0.0, 0.5], + [1.0, 1.0, 0.5]]], + ]) + + assert np.array_equal(marginalized_distribution, answer) + +def test_infer_cm_obj(): + # Check infer_cm functions on both TPM (state-by-state) + # and SbN (state-by-node) objects + p53 = np.array([[0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.], + [0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.], + [0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0.], + [0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.], + [0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.], + [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1.], + [0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.], + [1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], + [0., 0., 0., 1., 0., 0., 0., 0., 0., 
0., 0., 0.], + [0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.], + [0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.], + [0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0.]]) + + test_p53 = TPM(p53, p_nodes=["P", "Mc", "Mn"], n_nodes=["P", "Mc", "Mn"], p_states=[3,2,2], n_states=[3,2,2]) + + solution = np.array([[0, 1, 1], + [0, 0, 1], + [1, 0, 0]]) + + assert np.array_equal(test_p53.infer_cm(), solution) + +def test_infer_cm_sbn(): + # SbN Binary network: Node A (c)opies C, node B takes A (a)nd C, and C (n)ots B + can = np.array([[0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 0, 1, 0, 0, 0], + [1, 0, 0, 0, 0, 0, 0, 0], + [1, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 1, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 1], + [0, 1, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 1, 0, 0, 0, 0]]) + + can_obj = SbN(can, p_nodes=["A", "B", "C"], n_nodes=["A", "B", "C"], + p_states=[2,2,2], n_states=[2,2,2]) + solution = np.array([[0, 1, 0], + [0, 0, 1], + [1, 1, 0]]) + + assert np.array_equal(can_obj.infer_cm(), solution) + +def test_condition_obj(): + p53 = np.array([[0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.], + [0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.], + [0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0.], + [0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.], + [0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.], + [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1.], + [0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.], + [1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], + [0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.], + [0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.], + [0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.], + [0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0.]]) + + test_p53 = TPM(p53, p_nodes=["P", "Mc", "Mn"], n_nodes=["P", "Mc", "Mn"], p_states=[3,2,2], n_states=[3,2,2]) + + conditioned = test_p53.condition((1, ), (1, 0, 1)) + + solution = np.array([[0, 0, 0, 0, 0, 1], + [0, 0, 1, 0, 0, 0], + [0, 0, 1, 0, 0, 0], + [0, 0, 0, 1, 0, 0], + [1, 0, 0, 0, 0, 0], + [1, 0, 0, 0, 0, 0]]).reshape(3, 1, 2, 3, 1, 2, order="F") + + # Dimensions should remain the same, but Mc has been conditioned on making its + # dimension one unit in size + assert conditioned.shape == (3, 1, 2, 3, 1, 2) + + assert np.array_equal(conditioned.data, solution) + +def test_condition_sbn(): + # SbN form testing + can_sbn = np.array([[0, 0, 1], + [0, 0, 1], + [0, 0, 0], + [0, 0, 0], + [1, 0, 1], + [1, 1, 1], + [1, 0, 0], + [1, 1, 0]]) + + can_obj = SbN(can_sbn, p_nodes=["A", "B", "C"], n_nodes=["A", "B", "C"], + p_states=[2,2,2]) + + conditioned = can_obj.condition((1,), (1, 0, 1)) + + # At present we don't drop unneeded columns, though could if needed the space + solution = np.array([[0, 0, 1], + [0, 0, 1], + [1, 0, 1], + [1, 1, 1]]).reshape(2, 1, 2, 3, order="F") + + assert conditioned.shape == (2, 1, 2, 3) + + assert np.array_equal(conditioned.data, solution) + \ No newline at end of file