Skip to content

Initial TPM constructor #49

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 7 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
676 changes: 676 additions & 0 deletions pyphi/__tpm.py

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion pyphi/compute/subsystem.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,7 @@ def _ces(subsystem):
return ces(subsystem, parallel=config.PARALLEL_CUT_EVALUATION)


@memory.cache(ignore=["subsystem"])
#@memory.cache(ignore=["subsystem"])
@time_annotated
def _sia(cache_key, subsystem):
"""Return the minimal information partition of a subsystem.
Expand Down
16 changes: 11 additions & 5 deletions pyphi/distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ def purview_size(repertoire):
return len(purview(repertoire))


def repertoire_shape(purview, N): # pylint: disable=redefined-outer-name
def repertoire_shape(purview, N, states=None): # pylint: disable=redefined-outer-name
"""Return the shape a repertoire.

Args:
Expand All @@ -122,8 +122,12 @@ def repertoire_shape(purview, N): # pylint: disable=redefined-outer-name
>>> repertoire_shape(purview, N)
[2, 1, 2]
"""
# TODO: extend to non-binary nodes
return [2 if i in purview else 1 for i in range(N)]
#TODO Remove once fully transitioned, as then states should never be None
if states is not None:
return [states[i] if i in purview else 1 for i in range(N)]

else:
return [2 if i in purview else 1 for i in range(N)]


def flatten(repertoire, big_endian=False):
Expand Down Expand Up @@ -173,7 +177,7 @@ def unflatten(repertoire, purview, N, big_endian=False):


@cache(cache={}, maxmem=None)
def max_entropy_distribution(node_indices, number_of_nodes):
def max_entropy_distribution(node_indices, number_of_nodes, states=None):
"""Return the maximum entropy distribution over a set of nodes.

This is different from the network's uniform distribution because nodes
Expand All @@ -184,10 +188,12 @@ def max_entropy_distribution(node_indices, number_of_nodes):
node_indices (tuple[int]): The set of node indices over which to take
the distribution.
number_of_nodes (int): The total number of nodes in the network.
states (tuple): The states of each node in the network

Returns:
np.ndarray: The maximum entropy distribution over the set of nodes.
"""
distribution = np.ones(repertoire_shape(node_indices, number_of_nodes))
# TODO Remove once fully transitioned
distribution = np.ones(repertoire_shape(node_indices, number_of_nodes, states))

return distribution / distribution.size
92 changes: 78 additions & 14 deletions pyphi/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,10 @@

import numpy as np

from . import cache, config, connectivity, convert, jsonify, utils, validate
from . import cache, config, connectivity, convert, jsonify, utils, validate, node
from .labels import NodeLabels
from .tpm import is_state_by_state
from .__tpm import TPM, SbN


class Network:
Expand Down Expand Up @@ -47,8 +48,10 @@ class Network:
that node |i| is connected to node |j| (see :ref:`cm-conventions`).
**If no connectivity matrix is given, PyPhi assumes that every node
is connected to every node (including itself)**.
node_labels (tuple[str] or |NodeLabels|): Human-readable labels for
each node in the network.
p_nodes (list[str]): Human-readable list of names of nodes at time |t-1|
p_states (list[int]): List of the number of states of each node at time |t-1| (necessary for defining a multi-valued tpm)
n_nodes (list[str]): Human-readable list of names of nodes at time |t|
n_states (list[int]): List of the number of states of each node at time |t| (necessary for defining a multi-valued tpm)

Example:
In a 3-node network, ``the_network.tpm[(0, 0, 1)]`` gives the
Expand All @@ -57,13 +60,40 @@ class Network:
"""

# TODO make tpm also optional when implementing logical network definition
def __init__(self, tpm, cm=None, node_labels=None, purview_cache=None):
self._tpm, self._tpm_hash = self._build_tpm(tpm)
self._cm, self._cm_hash = self._build_cm(cm)
self._node_indices = tuple(range(self.size))
self._node_labels = NodeLabels(node_labels, self._node_indices)
# TODO node_labels attribute deprecated, but many tests use it so currently keeping it an option
def __init__(self, tpm, cm=None, p_nodes=None, p_states=None, n_nodes=None, n_states=None, purview_cache=None, node_labels=None):
if isinstance(tpm, list):
self._is_list = True
if node_labels:
p_nodes = node_labels
self._tpm = tpm
self._cm = self._build_cm_from_list(cm)
self._node_indices = tuple(range(len(tpm)))
# User can specify names of each node to be lined up with the list of node TPMs, else default labels are generated
self._node_labels = NodeLabels(p_nodes, self._node_indices)

else:
self._is_list = False
if node_labels:
p_nodes = node_labels
if p_states or n_states: # Requires NB state-by-state, could just do only TPM and convert to SbN later?
self._tpm = TPM(tpm, p_nodes, p_states, n_nodes, n_states)
#validate.tpm(tpm, check_independence=config.VALIDATE_CONDITIONAL_INDEPENDENCE)
else:
self._tpm = SbN(tpm, p_nodes, p_states, n_nodes, n_states)

self._cm = self._tpm.infer_cm()
# Convert to list
tpm_list = [self._tpm.create_node_tpm(index, self._cm) for index in range(len(self._tpm.n_nodes))]

self._node_indices = tuple(range(self.size))
self._node_labels = NodeLabels(self._tpm._p_nodes, self.node_indices) # TODO consider using self._tpm._p_nodes instead for more readability?

self._tpm = tpm_list
self._is_list = True

self.purview_cache = purview_cache or cache.PurviewCache()

validate.network(self)

@property
Expand All @@ -73,6 +103,7 @@ def tpm(self):
"""
return self._tpm

# TODO Deprecated?
@staticmethod
def _build_tpm(tpm):
"""Validate the TPM passed by the user and convert to multidimensional
Expand All @@ -92,6 +123,20 @@ def _build_tpm(tpm):

return (tpm, utils.np_hash(tpm))

def _build_cm_from_list(self, cm):
"""Generate the connectivity matrix for a network whose tpm is defined as a
list of Node TPMs, by concatenating the individual Node TPM cms.

Args:
tpm_list (list[TPM]): List of Node TPMs that define how the network transitions.
"""
if cm is None:
cm_list = [TPM.infer_node_cm(node_tpm) for node_tpm in self._tpm]
return np.concatenate(cm_list, axis=1)
else:
return np.array(cm)
# return np.ones((self.size, self.size))

@property
def cm(self):
"""np.ndarray: The network's connectivity matrix.
Expand All @@ -106,8 +151,8 @@ def _build_cm(self, cm):
unitary CM if none was provided.
"""
if cm is None:
# Assume all are connected.
cm = np.ones((self.size, self.size))
# Build cm from TPM method
cm = self._tpm.infer_cm()
else:
cm = np.array(cm)

Expand All @@ -134,7 +179,15 @@ def size(self):
@property
def num_states(self):
"""int: The number of possible states of the network."""
return 2 ** self.size
# If list, states can be counted as product of possible transitions of each node
# If not, use TPM.num_states?
# if self._is_list:
num = 1
for tpm in self._tpm:
num *= tpm.shape[-1]
return num
# else:
# return self._tpm.num_states

@property
def node_indices(self):
Expand Down Expand Up @@ -169,10 +222,19 @@ def potential_purviews(self, direction, mechanism):

def __len__(self):
"""int: The number of nodes in the network."""
return self.tpm.shape[-1]
if self._is_list:
return len(self._tpm)
elif isinstance(self.tpm, TPM):
# TODO Assumes symmetry for now
return len(self.tpm.p_nodes)
else:
return self.tpm.shape[-1]

def __repr__(self):
return "Network({}, cm={})".format(self.tpm, self.cm)
if self._is_list:
return str([tpm for tpm in self._tpm]) #TODO Consider representations
else:
return "Network({}, cm={})".format(self.tpm, self.cm)

def __str__(self):
return self.__repr__()
Expand All @@ -182,6 +244,8 @@ def __eq__(self, other):

Networks are equal if they have the same TPM and CM.
"""
if self._is_list:
return np.all([node_tpm == node_tpm for node_tpm in self._tpm])
return (
isinstance(other, Network)
and np.array_equal(self.tpm, other.tpm)
Expand Down
40 changes: 5 additions & 35 deletions pyphi/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from .connectivity import get_inputs_from_cm, get_outputs_from_cm
from .labels import NodeLabels
from .tpm import marginalize_out, tpm_indices
from .__tpm import TPM, SbN


# TODO extend to nonbinary nodes
Expand Down Expand Up @@ -56,41 +57,10 @@ def __init__(self, tpm, cm, index, state, node_labels):
# Get indices of the inputs.
self._inputs = frozenset(get_inputs_from_cm(self.index, cm))
self._outputs = frozenset(get_outputs_from_cm(self.index, cm))

# Node TPM has already been generated, take from subsystem list
self.tpm = tpm[self.index]

# Generate the node's TPM.
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# We begin by getting the part of the subsystem's TPM that gives just
# the state of this node. This part is still indexed by network state,
# but its last dimension will be gone, since now there's just a single
# scalar value (this node's state) rather than a state-vector for all
# the network nodes.
tpm_on = tpm[..., self.index]

# TODO extend to nonbinary nodes
# Marginalize out non-input nodes that are in the subsystem, since the
# external nodes have already been dealt with as boundary conditions in
# the subsystem's TPM.
non_inputs = set(tpm_indices(tpm)) - self._inputs
tpm_on = marginalize_out(non_inputs, tpm_on)

# Get the TPM that gives the probability of the node being off, rather
# than on.
tpm_off = 1 - tpm_on

# Combine the on- and off-TPM so that the first dimension is indexed by
# the state of the node's inputs at t, and the last dimension is
# indexed by the node's state at t+1. This representation makes it easy
# to condition on the node state.
self.tpm = np.stack([tpm_off, tpm_on], axis=-1)
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

# Make the TPM immutable (for hashing).
utils.np_immutable(self.tpm)

# Only compute the hash once.
self._hash = hash(
(index, utils.np_hash(self.tpm), self.state, self._inputs, self._outputs)
)

@property
def tpm_off(self):
Expand Down Expand Up @@ -160,7 +130,7 @@ def generate_nodes(tpm, cm, network_state, indices, node_labels=None):
"""Generate |Node| objects for a subsystem.

Args:
tpm (np.ndarray): The system's TPM
tpm (pyphi.TPM): The system's TPM
cm (np.ndarray): The corresponding CM.
network_state (tuple): The state of the network.
indices (tuple[int]): Indices to generate nodes for.
Expand Down
Loading