diff --git a/pyphi/__tpm.py b/pyphi/__tpm.py
new file mode 100644
index 000000000..31124dd27
--- /dev/null
+++ b/pyphi/__tpm.py
@@ -0,0 +1,676 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+# __tpm.py
+
+"""
+TPM classes.
+
+TPM -> Works with all state-by-state tpms 
+(binary, nonbinary, symmetric, asymmetric -> node TPMs)
+
+SbN -> Only works with state-by-node tpms,
+which are only binary. Can work with asymmetric tpms, however.
+"""
+
+import xarray as xr
+from pandas import DataFrame
+import numpy as np
+from string import ascii_uppercase
+from itertools import product
+from math import log2
+
+from pyphi.connectivity import get_inputs_from_cm
+from pyphi.distribution import repertoire_shape
+
+from .utils import all_states
+
+# labels provided?
+
+# type checking: Union[ Iterable[Node]? ]
+# mechanisms and purviews: Union[
+#    Set[Node]
+#    Iterable[Node]
+# ]
+# Node = Union[int, str]?
+
+# Set[]
+
+
+# Pass method calls to xarray?
+
+# Should use only one underlying structure for TPM, but subclass when it's a
+# matter of space or time
+# e.g. subclass for state-by-node
+
+class TPM:
+    """General class of TPM objects in state-by-state form. 
+
+    The standard TPM class represents a State-by-State matrix where each row defines the 
+    probability that the row's state will transition to any given column state. 
+
+    Xarray was chosen for its ability to define each node as a dimension, 
+    easily accessible by its label and marginalizable.
+
+    Can accept both DataFrame objects and 2D or pre-shaped multidimensional 
+    numpy arrays.
+    """
+    def __init__(self, tpm, p_nodes=None, p_states=None, n_nodes=None, n_states=None):
+
+        if isinstance(tpm, DataFrame):
+            p_states = [len(node_states) for node_states in tpm.index.levels]
+            n_states = [len(node_states) for node_states in tpm.columns.levels]
+            p_nodes = tpm.index.names
+            n_nodes = tpm.columns.names
+            dims = ["{}_p".format(node) for node in p_nodes] + [
+                    "{}_n".format(node) for node in n_nodes]
+            coords = dict(zip(dims, [np.array(range(state)) for state in p_states + n_states]))
+            self.tpm = xr.DataArray(tpm.values.reshape(p_states + n_states, order="F"), coords=coords, dims=dims)
+            
+
+        else:    
+            
+            if isinstance(tpm, xr.DataArray):
+                tpm = tpm.data
+
+            if p_states is None: # Binary
+                if p_nodes is None:
+                    # Gen p_nodes
+                    p_nodes = ["n{}".format(i) for i in range(int(log2(tpm.shape[0])))]
+
+                if n_nodes is None: 
+                    # NOTE Specifying a Node TPM with just data and p_nodes (and perhaps p_states and n_states if nb)
+                    # seems really useful, so I want a label generation method, but this can give awkward results
+                    # Gen n_nodes
+                    n_nodes = ["n{}".format(i) for i in range(int(log2(tpm.shape[1])))]
+                
+                # Gen p_states
+                p_states = [2] * len(p_nodes)
+                # Gen n_states
+                n_states = [2] * len(n_nodes)
+
+            else: # Non-binary
+                if p_nodes is None:
+                    # Gen p_nodes
+                    p_nodes = ["n{}".format(i) for i in range(len(p_states))]
+
+                if n_nodes is None: # Nbin, Sym
+                    if n_states is None:    
+                        # Cpy p_states -> n_states
+                        n_states = p_states.copy()
+                        # Cpy p_nodes -> n_nodes
+                        n_nodes = p_nodes.copy()
+                    
+                    else: # n_states specified so assumed different from p_states?
+                        n_nodes = ["n{}".format(i) for i in range(len(n_states))]
+
+                # Else Nbin, Asym, pass as should have all specified
+
+
+            # Previous nodes are labelled with _p and next nodes with _n
+            # to differentiate between nodes with the same name but different timestep
+            dims = ["{}_p".format(node) for node in p_nodes] + [
+                    "{}_n".format(node) for node in n_nodes]
+            
+            # The given numpy array is reshaped to have the appropriate number of dimensions for labeling
+            # Fortran, or little-endian ordering is used, meaning left-most (first called) node varies fastest
+            coords = dict(zip(dims, [np.array(range(state)) for state in p_states + n_states]))
+            self.tpm = xr.DataArray(tpm.reshape(p_states + n_states, order="F"), coords=coords, dims=dims)
+        
+        self.symmetric = bool(p_nodes == n_nodes)
+        self._p_nodes = p_nodes
+        self._n_nodes = n_nodes
+
+        # So far only used in infer_cm, might be possible to rework that and remove this
+        # Could also just be a list (or implicit in sum of p_nodes and n_nodes), states could
+        # be accessed by tpm.shape[[p_nodes + n_nodes].index(node)]
+        self.all_nodes = dict(zip(self.p_nodes + self.n_nodes, self.tpm.shape))
+
+    @property
+    def p_nodes(self):
+        """Returns p_nodes formatted"""
+        return ["{}_p".format(node) for node in self._p_nodes]
+    
+    @property
+    def n_nodes(self):
+        """Returns n_nodes formatted"""
+        return ["{}_n".format(node) for node in self._n_nodes]
+
+    @property
+    def p_states(self):
+        """Returns list of p_states"""
+        return [self.all_nodes[p_node] for p_node in self.p_nodes]
+
+    @property
+    def n_states(self):
+        """Returns list of n_states"""
+        return [self.all_nodes[n_node] for n_node in self.n_nodes]
+
+    # Maybe make this more generalized? Could be based on xarray's shape?
+    # Maybe make it just one tuple? Hard to separate then...
+    @property
+    def tpm_indices(self):
+        """Returns two tuples of indices for p_nodes and n_nodes"""
+        return tuple(range(len(self.p_nodes))), tuple(range(len(self.n_nodes)))
+
+        # return tuple(range(len(self.tpm.shape)))
+        # return tuple(range(len(self.p_nodes)) + range(len(self.n_nodes))))
+
+    @property
+    def is_deterministic(self) -> bool:
+        """Return whether the TPM is deterministic."""
+        return np.all(np.logical_or(self.tpm.data == 0, self.tpm.data == 1))
+
+    # Could get overridden by State-by-Node TPM's subclass to false? Or unneeded
+    @property
+    def is_state_by_state(self):
+        return True
+
+    @property
+    def num_states(self):
+        """int: Number of possible states the set of previous nodes can be in
+        """
+        # Number of states equal to product of states of p_nodes?
+        # This doesn't really make sense in asymmetric TPMs, so maybe a different property
+        # Makes more sense here? Or could split to num_p_states and num_n_states
+        num_states = 1
+        for node in self.p_nodes:
+            num_states *= self.all_nodes[node]
+        return num_states
+
+    @property
+    def num_n_states(self):
+        """int: Number of possible states the set of next nodes can be in
+        """
+        num_states = 1
+        for node in self.n_nodes:
+            num_states *= self.all_nodes[node]
+        return num_states
+
+    @property
+    def shape(self):
+        return tuple(self.tpm.shape)
+
+    @property
+    def p_node_indices(self):
+        return tuple(range(len(self.p_nodes)))
+
+    @property
+    def n_node_indices(self):
+        return tuple(range(len(self.n_nodes)))
+
+    # TODO Maybe try using xarray's coordinates feature to keep dims and state info?
+    # As opposed to trying to just use slices, keeping the coordinate of the dimension
+    # equal to the conditioned state could make things easier down the line
+    # TODO Consider splitting into condition rows and condition columns, then could mix
+    # into conditioning for asymmetric tpms (or force square dimensionality with empty dims)
+    def condition(self, fixed_nodes, state):
+        """Return a TPM conditioned on the given fixed node indices, whose states
+        are fixed according to the given state-tuple.
+
+        The dimensions of the new TPM that correspond to the fixed nodes are
+        collapsed onto their state, making those dimensions singletons suitable for
+        broadcasting. The number of dimensions of the conditioned TPM will be the
+        same as the unconditioned TPM.
+
+        Args:
+            fixed_nodes: tuple of indicies of nodes that are fixed
+        """
+        # Only doing this for symmetric tpms at the moment
+        if self.symmetric:
+            fixed_node_labels = [self.p_nodes[node] for node in fixed_nodes]
+            indexer = dict(zip(fixed_node_labels, [state[node] for node in fixed_nodes]))
+            # Drop rows that don't fit with the conditioned nodes
+            # TODO Better way to keep dims? IndexSlicer maybe?
+            # Tried using None slices but dimensions still got removed :(
+            kept_dims = [self.tpm.dims.index(label) for label in fixed_node_labels]
+            conditioned_tpm = self.tpm.loc[indexer]
+
+            # Regrow dimensions where they got trimmed
+            for i in range(len(kept_dims)):
+                conditioned_tpm = conditioned_tpm.expand_dims(fixed_node_labels[i], axis=kept_dims[i])
+
+            # Marginalize across columns that don't fit with the conditioned nodes
+            # Because assumed symmetry, self.n_nodes is same as self.p_nodes, except for labelling
+            # Since we're summing across columns, need labels from n_nodes
+
+            column_labels = [self.n_nodes[node] for node in fixed_nodes]
+
+            for label in column_labels:
+                conditioned_tpm = conditioned_tpm.sum(dim=label, keepdims=True)
+            # At some point maybe change the dict of self.all_nodes? Perhaps ideally make all_nodes unneeeded
+
+            for i in fixed_nodes:
+                self.p_states[i] = 1 
+                self.all_nodes[self.p_nodes[i]] = 1
+                self.all_nodes[self.n_nodes[i]] = 1
+
+            return conditioned_tpm
+        
+        else:
+            print(self, "symmetric?:", self.symmetric)
+            raise NotImplementedError
+
+    def create_node_tpm(self, index, cm):
+        """Create a new TPM object based on this one, except only describing the
+        transitions of a single node in the network.
+
+        Args:
+            node (int): The index of the node whose TPM we wish to create
+        """
+        # Want to marginalize out all other column nodes, but then don't need to worry about normalization
+        node = self.n_nodes[index]
+        other_nodes = [label for label in self.n_nodes if label != node]
+        node_tpm = self.tpm.sum(dim=other_nodes)
+
+        node_tpm = TPM(tpm=node_tpm, p_nodes=self._p_nodes, n_nodes=[self._n_nodes[index]], p_states=self.p_states, n_states=[self.all_nodes[node]])
+
+        return node_tpm
+        #return node_tpm.marginalize_out(tuple(set(self.p_node_indices) - set(get_inputs_from_cm(index, cm))), rows=True)
+
+    def condition_node_tpm(node_tpm, fixed_nodes, state, col=False):
+        """Condition a node TPM object, a special case of asymmetric TPMs
+        """
+        fixed_node_labels = [node_tpm.p_nodes[node] for node in fixed_nodes]
+        indexer = dict(zip(fixed_node_labels, [state[node] for node in fixed_nodes]))
+
+        kept_dims = [node_tpm.tpm.dims.index(label) for label in fixed_node_labels]
+        conditioned_tpm = node_tpm.tpm.loc[indexer]
+
+        # Regrow dimensions where they got trimmed
+        for j in range(len(kept_dims)):
+            conditioned_tpm = conditioned_tpm.expand_dims(fixed_node_labels[j], axis=kept_dims[j])
+
+        if col:
+            conditioned_tpm = conditioned_tpm.sum(dim=conditioned_tpm.dims[-1], keepdims=True)
+        
+        return conditioned_tpm
+
+    def condition_list(tpm_list, fixed_nodes, state):
+        """Condition a list of Node TPMs.
+
+        Args:
+            tpm_list (list[TPM]): The Node TPMs to be conditioned
+            fixed_nodes (tuple(int)): The node indicies that are now fixed
+            state (tuple(int)): The state of the system when conditioning
+        """
+        # Step 0: No conditioning required if fixed_nodes empty
+        if fixed_nodes is ():
+            return tpm_list
+
+        # Step 1: Drop rows that do not fit with the conditioned state, for all tpms
+        for i in range(len(tpm_list)):
+            # Replace TPM in list's tpm with dropped row tpm.
+            tpm_list[i].tpm = TPM.condition_node_tpm(tpm_list[i], fixed_nodes, state)
+
+        # Step 2: If a particular tpm describes the transitions of a fixed node, sum the column dimension together
+        for i in fixed_nodes:
+            # NOTE: Assumes Node TPM only has 1 column dimension. Can be generalized, but needs more information about labels.
+            # Since we're already assuming a list, however, this is a valid assumption. See the more general condition method
+            # For implementing generic asymmetric conditioning
+            tpm_list[i].tpm = tpm_list[i].tpm.sum(dim=tpm_list[i].tpm.dims[-1], keepdims=True)
+
+        return tpm_list
+            
+
+    # TODO Currently only works for symmetric TPMs
+    # TODO **kwargs to determine if marginalizing out just row/column?
+    def marginalize_out(self, node_indices, rows=False):
+        """Marginalize out nodes from a TPM.
+
+        Args:
+            node_indices (list[int]): The indices of nodes to be marginalized out.
+            Index based on dimension of the node.
+            rows (bool): Whether to marginalize on only the rows
+
+
+        Returns:
+            xarray: A tpm with the same number of dimensions, with the nodes
+        marginalized out.
+        """
+        def normalize(tpm):
+            """Returns a normalized TPM after marginalization"""
+            return tpm / (np.array(self.tpm.shape)[list(node_indices)].prod())
+
+        if self.symmetric:           
+            labels = [self.p_nodes[i] for i in node_indices] + [self.n_nodes[i] for i in node_indices]
+            new_p_states = self.p_states
+            new_n_states = self.n_states
+            for i in node_indices:
+                new_p_states[i], new_n_states[i] = 1
+                
+        elif rows:
+            labels = [self.p_nodes[i] for i in node_indices]
+            new_p_states = self.p_states
+            for i in node_indices:
+                new_p_states[i] = 1
+            new_n_states = self.n_states
+
+        else:
+            raise NotImplementedError
+
+        marginalized_tpm = self.tpm
+
+        for label in labels:
+            marginalized_tpm = marginalized_tpm.sum(dim=label, keepdims=True)
+
+        finished_tpm = normalize(marginalized_tpm)
+
+        return TPM(tpm=finished_tpm, p_nodes=self._p_nodes, n_nodes=self._n_nodes, p_states=new_p_states, n_states=new_n_states)
+
+
+    def infer_edge(tpm, a, b, contexts):
+        """Infer the presence or absence of an edge from node A to node B.
+
+        Let |S| be the set of all nodes in a network. Let |A' = S - {A}|. We call
+        the state of |A'| the context |C| of |A|. There is an edge from |A| to |B|
+        if there exists any context |C(A)| such that |Pr(B | C(A), A=0) != Pr(B |
+        C(A), A=1)|.
+
+        Args:
+            tpm (np.ndarray): The TPM as an object
+            a (int): The index of the putative source node.
+            b (int): The index of the putative sink node.
+        Returns:
+            bool: ``True`` if the edge |A -> B| exists, ``False`` otherwise.
+        """
+        def a_in_context(context):
+            """Given a context C(A), return the states of the full system with A
+            in each of its possible states, in order as a list.
+            """
+            a_states = [
+                context[:a] + (i, ) + context[a:]
+                for i in range(tpm.tpm.shape[a])
+            ]
+
+            return a_states
+
+        def marginalize_b(state):
+            """Return the distribution of possible states of b at t+1"""
+            temp_n_nodes = tpm.n_nodes
+            name = temp_n_nodes[b]
+            # Instead of making a full copy, just remove and insert afterwards
+            temp_n_nodes.remove(name)
+            marginalized = tpm.tpm.groupby(name).sum(temp_n_nodes).loc[state]
+            temp_n_nodes.insert(b, name)
+            return marginalized
+
+        def a_affects_b_in_context(context):
+            """Return ``True`` if A has an effect on B, given a context."""
+            a_states = a_in_context(context)
+            comparator = marginalize_b(tuple(a_states[0])).data.round(12)
+            return any(not np.array_equal(comparator, marginalize_b(state).data.round(12)) for state in a_states[1:])
+
+        return any(a_affects_b_in_context(context) for context in contexts)
+
+    # Takes TPM object, could use self instead?
+    def infer_cm(tpm):
+        """Infer the connectivity matrix associated with a state-by-state TPM in
+        object form.
+        """
+        # Set up empty cm based on nodes
+        cm = np.empty((len(tpm._p_nodes), len(tpm._n_nodes)), dtype=int)
+        # Iterate through every node pair
+        for a, b in np.ndindex(cm.shape):
+            # Determine context states based on a
+            a_prime = tpm.p_nodes
+            a_prime.pop(a)
+            contexts = tuple(product(*tuple(tuple(range(tpm.all_nodes[node])) for node in a_prime)))
+            cm[a][b] = tpm.infer_edge(a, b, contexts)
+        return cm
+
+    def infer_node_edge(tpm, a, contexts):
+
+        def a_in_context(context):
+            """Given a context C(A), return the states of the full system with A
+            in each of its possible states, in order as a list.
+            """
+            a_states = [
+                context[:a] + (i, ) + context[a:]
+                for i in range(tpm.tpm.shape[a])
+            ]
+            return a_states
+
+        def a_affects_b_in_context(context):
+            a_states = np.array(a_in_context(context))
+            return any(not np.array_equal(tpm.tpm.data[tuple(a_states[0])], tpm.tpm.data[tuple(state)]) for state in a_states[1:])
+
+        return any(a_affects_b_in_context(context) for context in contexts)
+
+    def infer_node_cm(tpm):
+        cm = np.empty((len(tpm._p_nodes), 1), dtype=int)
+        
+        for a, b in np.ndindex(cm.shape):
+            # Determine context states based on a
+            a_prime = tpm.p_nodes
+            a_prime.pop(a)
+            contexts = tuple(product(*tuple(tuple(range(tpm.all_nodes[node])) for node in a_prime)))
+            cm[a][b] = tpm.infer_node_edge(a, contexts)
+
+        return cm
+
+    def __getitem__(self, key):
+        return self.tpm[key]
+
+    # TODO maybe not needed?
+    def expand_tpm(tpm):
+        """Broadcast a state-by-node TPM so that singleton dimensions are expanded
+        over the full network.
+        """
+        raise NotImplementedError
+
+    @classmethod
+    def from_numpy():
+        pass
+
+    @classmethod
+    def from_xarray():
+        pass
+
+    @classmethod
+    def from_pands():
+        pass
+
+    def copy(self):
+        p_states = [self.tpm.shape[self.tpm.dims.index(node)] for node in self.p_nodes]
+        n_states = [self.tpm.shape[self.tpm.dims.index(node)] for node in self.n_nodes]
+        return TPM(self.tpm, p_nodes=self._p_nodes, p_states=p_states, n_nodes=self._n_nodes, n_states=n_states)
+
+    def sum(self, dim, keepdims=True):
+        return self.tpm.sum(dim, keepdims)
+
+    def __repr__(self):
+        return self.tpm.__repr__()
+
+# TODO Better name?
+class SbN(TPM):
+    """The subclass of <SbN> represents a State-by-Node matrix, only usable for binary 
+    systems, where each row represents the probability of each column Node being ON 
+    during the timestep after the given row's state.
+    """
+    def __init__(self, tpm, p_nodes=None, p_states=None, n_nodes=None, n_states=None):
+            format = False
+
+            if isinstance(tpm, xr.DataArray):
+                tpm = tpm.data
+
+            # Not multi-dimensional SbN
+            if tpm.ndim == 2:
+                if tpm.shape[0] == tpm.shape[1]: # SbS
+                    format = True
+                elif n_nodes: # If n_nodes isn't given, we have to assume SbN unless we ask the user to specify a flag
+                    if len(n_nodes) != tpm.shape[1]: # SbS
+                        format = True
+
+            # If format is true, change from SbS to SbN
+            if format:
+                super().__init__(tpm, p_nodes, p_states, n_nodes, n_states)
+                #bin_node_tpms = [self.tpm.sel({node: 1}).sum(self.n_nodes.copy().remove(node)).expand_dims("nodes", axis=-1) 
+                #for node in self.n_nodes]
+                bin_node_tpms=[]
+                # TODO Is there a way to do this with some form of list comp? unfortunately using .copy()
+                # seems to break things if I try list comp
+                # TODO Can we use the coordinates property to name the nodes in the "nodes" dimension?
+                for node in self.n_nodes:
+                    temp = self.n_nodes.copy()
+                    temp.remove(node)
+                    bin_node_tpms.append(self.tpm.sel({node: 1}).sum(temp).expand_dims("n_nodes", axis=-1))
+                self.tpm = xr.concat(bin_node_tpms, dim="n_nodes")
+
+            # If format is false, is already in SbN form
+            else:     
+                if p_nodes is None:
+                    p_nodes = ["n{}".format(i) for i in range(int(log2(np.prod(tpm.shape[:-1]))))]
+
+                # NOTE: This will produce a different naming scheme in some instances than using
+                # the super constructor, not too big of a deal but worth noting 
+                # Fixing it would probably require unnecessary convolution, only worthwhile
+                # if it is actually an issue
+                if n_nodes is None:
+                    n_nodes = ["n{}".format(i) for i in range(tpm.shape[-1])]
+
+                self._p_nodes = p_nodes
+                # Differences: Shape is going to be (S_a, S_b, S_c... N), rows are like normal but index of last is the size of the n_nodes list
+                dims = self.p_nodes + ["n_nodes"]
+                
+                # Binary only, so num_states_per_node is tuple of 2s with length of p_nodes
+                p_states = [2] * len(p_nodes)
+
+                # Need to keep track of location of nodes in the last dimension
+                self._n_nodes = n_nodes
+                self.tpm = xr.DataArray(tpm.reshape(p_states + [len(n_nodes)], order="F"), dims=dims)
+
+            self.all_nodes = self.all_nodes = dict(zip(self.p_nodes + self.n_nodes, [2 for i in self.p_nodes + self.n_nodes]))
+
+    # TODO Only valid for symmetric tpms
+    def tpm_indices(self):
+        """Return the indices of nodes in the SbN."""
+        return tuple(np.where(np.array(self.tpm.shape[:-1]) == 2)[0])
+
+    def is_state_by_state(self):
+        return False
+
+    def get_node_transitions(self, state, external):
+        """Intended for node TPMs only; but viable for any.
+        Returns the np.array of the transition data for a given state.
+        """
+        if not external:
+            n_tpm = self.tpm
+            indexer = dict({self.tpm.dims[-1]: state})
+            return n_tpm.loc[indexer].data
+        
+        else:
+            return self.tpm.data
+
+
+    # SbN form
+    def infer_edge(tpm, a, b, contexts):
+        """Infer the presence or absence of an edge from node A to node B.
+
+        Let |S| be the set of all nodes in a network. Let |A' = S - {A}|. We call
+        the state of |A'| the context |C| of |A|. There is an edge from |A| to |B|
+        if there exists any context |C(A)| such that |Pr(B | C(A), A=0) != Pr(B |
+        C(A), A=1)|.
+
+        Args:
+            tpm (SbN): The TPM in state-by-node, multidimensional form.
+            a (int): The index of the putative source node.
+            b (int): The index of the putative sink node.
+
+        Returns:
+            bool: ``True`` if the edge |A -> B| exists, ``False`` otherwise.
+        """
+        def a_in_context(context):
+            """Given a context C(A), return the states of the full system with A
+            OFF (0) and ON (1), respectively.
+            """
+            a_off = context[:a] + (0, ) + context[a:]
+            a_on = context[:a] + (1, ) + context[a:]
+            return (a_off, a_on)
+
+        def a_affects_b_in_context(context):
+            """Return ``True`` if A has an effect on B, given a context."""
+            a_off, a_on = a_in_context(context)
+            return tpm.tpm[a_off][b] != tpm.tpm[a_on][b]
+
+        return any(a_affects_b_in_context(context) for context in contexts)
+    
+    # SbN form
+    def infer_cm(tpm):
+        """Infer the connectivity matrix associated with a SbN tpm in
+        multidimensional form.
+        """
+        network_size = tpm.tpm.shape[-1]
+        all_contexts = tuple(all_states(network_size - 1))
+        cm = np.empty((network_size, network_size), dtype=int)
+        for a, b in np.ndindex(cm.shape):
+            cm[a][b] = SbN.infer_edge(tpm, a, b, all_contexts)
+        return cm
+
+
+    def marginalize_out(self, node_indices):
+        """Marginalize out nodes from a TPM.
+
+        Args:
+            node_indices (list[int]): The indices of nodes to be marginalized out.
+            tpm (np.ndarray): The TPM to marginalize the node out of.
+
+        Returns:
+            np.ndarray: An SbN with the same number of dimensions, with the nodes
+            marginalized out.
+        """
+        def normalize(tpm):
+            return tpm / (np.array(self.tpm.shape)[list(node_indices)].prod())
+
+        marginalized_tpm = self.tpm 
+    
+        for label in [self.p_nodes[i] for i in node_indices]:
+            marginalized_tpm = marginalized_tpm.sum(dim=label, keepdims=True)
+
+        return normalize(marginalized_tpm)
+
+    def condition(self, fixed_nodes, state):
+        """Return a TPM conditioned on the given fixed node indices, whose states
+        are fixed according to the given state-tuple.
+
+        The dimensions of the new TPM that correspond to the fixed nodes are
+        collapsed onto their state, making those dimensions singletons suitable for
+        broadcasting. The number of dimensions of the conditioned TPM will be the
+        same as the unconditioned TPM.
+        """
+        fixed_node_labels = [self.p_nodes[node] for node in fixed_nodes]
+        indexer = dict(zip(fixed_node_labels, [state[node] for node in fixed_nodes]))
+        kept_dims = [self.tpm.dims.index(label) for label in fixed_node_labels]
+        # Throw out rows that don't fit with fixed node states, don't need to worry
+        # about columns as they are already essentially marginalized
+        conditioned_tpm = self.tpm.loc[indexer]
+
+        # Regrow dimensions where they got trimmed
+        for i in range(len(kept_dims)):
+            conditioned_tpm = conditioned_tpm.expand_dims(dim=fixed_node_labels[i], axis=kept_dims[i])
+        return conditioned_tpm
+
+    def copy(self):
+        return SbN(self.tpm, p_nodes=self._p_nodes, n_nodes=self._n_nodes)
+
+    def create_node_tpm(self, index, cm):
+        # Grab the column which holds the transition data for the desired node
+        indexer = dict({"n_nodes": index})
+        node_tpm = self.tpm.loc[indexer]
+
+        # Expand the dimension so that we can add more data
+        node_tpm = node_tpm.expand_dims(dim='n0_n', axis=-1)
+
+        # Pad the trimmed node_tpm, with NaN values with length 1 before
+        # The current data (as current data is when the node is at state 1)
+        mapping = dict({"n0_n": (1, 0)})
+        node_tpm = node_tpm.pad(mapping)
+
+        # Replace NaN data with data on when it will in state 0 (1 - state 1)
+        node_tpm[..., 0] = 1 - node_tpm[..., 1]
+
+        node_TPM = TPM(node_tpm, p_nodes=self._p_nodes).marginalize_out(tuple(set(self.p_node_indices) - set(get_inputs_from_cm(index, cm))), rows=True)
+        # Create new TPM object with this data
+        return node_TPM
+
+    def __repr__(self):
+        return 1
diff --git a/pyphi/compute/subsystem.py b/pyphi/compute/subsystem.py
index aaf07fb57..45a3a8133 100644
--- a/pyphi/compute/subsystem.py
+++ b/pyphi/compute/subsystem.py
@@ -231,7 +231,7 @@ def _ces(subsystem):
     return ces(subsystem, parallel=config.PARALLEL_CUT_EVALUATION)
 
 
-@memory.cache(ignore=["subsystem"])
+#@memory.cache(ignore=["subsystem"])
 @time_annotated
 def _sia(cache_key, subsystem):
     """Return the minimal information partition of a subsystem.
diff --git a/pyphi/distribution.py b/pyphi/distribution.py
index a11724f50..14f6c7048 100644
--- a/pyphi/distribution.py
+++ b/pyphi/distribution.py
@@ -104,7 +104,7 @@ def purview_size(repertoire):
     return len(purview(repertoire))
 
 
-def repertoire_shape(purview, N):  # pylint: disable=redefined-outer-name
+def repertoire_shape(purview, N, states=None):  # pylint: disable=redefined-outer-name
     """Return the shape a repertoire.
 
     Args:
@@ -122,8 +122,12 @@ def repertoire_shape(purview, N):  # pylint: disable=redefined-outer-name
         >>> repertoire_shape(purview, N)
         [2, 1, 2]
     """
-    # TODO: extend to non-binary nodes
-    return [2 if i in purview else 1 for i in range(N)]
+    #TODO Remove once fully transitioned, as then states should never be None
+    if states is not None:
+        return [states[i] if i in purview else 1 for i in range(N)]
+    
+    else:
+        return [2 if i in purview else 1 for i in range(N)]
 
 
 def flatten(repertoire, big_endian=False):
@@ -173,7 +177,7 @@ def unflatten(repertoire, purview, N, big_endian=False):
 
 
 @cache(cache={}, maxmem=None)
-def max_entropy_distribution(node_indices, number_of_nodes):
+def max_entropy_distribution(node_indices, number_of_nodes, states=None):
     """Return the maximum entropy distribution over a set of nodes.
 
     This is different from the network's uniform distribution because nodes
@@ -184,10 +188,12 @@ def max_entropy_distribution(node_indices, number_of_nodes):
         node_indices (tuple[int]): The set of node indices over which to take
             the distribution.
         number_of_nodes (int): The total number of nodes in the network.
+        states (tuple): The states of each node in the network
 
     Returns:
         np.ndarray: The maximum entropy distribution over the set of nodes.
     """
-    distribution = np.ones(repertoire_shape(node_indices, number_of_nodes))
+    # TODO Remove once fully transitioned
+    distribution = np.ones(repertoire_shape(node_indices, number_of_nodes, states))
 
     return distribution / distribution.size
diff --git a/pyphi/network.py b/pyphi/network.py
index f7d2e4271..2cb921fb8 100644
--- a/pyphi/network.py
+++ b/pyphi/network.py
@@ -9,9 +9,10 @@
 
 import numpy as np
 
-from . import cache, config, connectivity, convert, jsonify, utils, validate
+from . import cache, config, connectivity, convert, jsonify, utils, validate, node
 from .labels import NodeLabels
 from .tpm import is_state_by_state
+from .__tpm import TPM, SbN
 
 
 class Network:
@@ -47,8 +48,10 @@ class Network:
             that node |i| is connected to node |j| (see :ref:`cm-conventions`).
             **If no connectivity matrix is given, PyPhi assumes that every node
             is connected to every node (including itself)**.
-        node_labels (tuple[str] or |NodeLabels|): Human-readable labels for
-            each node in the network.
+        p_nodes (list[str]): Human-readable list of names of nodes at time |t-1|
+        p_states (list[int]): List of the number of states of each node at time |t-1| (necessary for defining a multi-valued tpm)
+        n_nodes (list[str]): Human-readable list of names of nodes at time |t|
+        n_states (list[int]): List of the number of states of each node at time |t| (necessary for defining a multi-valued tpm)
 
     Example:
         In a 3-node network, ``the_network.tpm[(0, 0, 1)]`` gives the
@@ -57,13 +60,40 @@ class Network:
     """
 
     # TODO make tpm also optional when implementing logical network definition
-    def __init__(self, tpm, cm=None, node_labels=None, purview_cache=None):
-        self._tpm, self._tpm_hash = self._build_tpm(tpm)
-        self._cm, self._cm_hash = self._build_cm(cm)
-        self._node_indices = tuple(range(self.size))
-        self._node_labels = NodeLabels(node_labels, self._node_indices)
+    # TODO node_labels attribute deprecated, but many tests use it so currently keeping it an option
+    def __init__(self, tpm, cm=None, p_nodes=None, p_states=None, n_nodes=None, n_states=None, purview_cache=None, node_labels=None):
+        if isinstance(tpm, list): 
+            self._is_list = True
+            if node_labels:
+                p_nodes = node_labels
+            self._tpm = tpm
+            self._cm = self._build_cm_from_list(cm)
+            self._node_indices = tuple(range(len(tpm)))
+            # User can specify names of each node to be lined up with the list of node TPMs, else default labels are generated
+            self._node_labels = NodeLabels(p_nodes, self._node_indices)
+            
+        else:
+            self._is_list = False
+            if node_labels:
+                p_nodes = node_labels
+            if p_states or n_states: # Requires NB state-by-state, could just do only TPM and convert to SbN later?
+                self._tpm = TPM(tpm, p_nodes, p_states, n_nodes, n_states)
+                #validate.tpm(tpm, check_independence=config.VALIDATE_CONDITIONAL_INDEPENDENCE)
+            else:
+                self._tpm = SbN(tpm, p_nodes, p_states, n_nodes, n_states)
+
+            self._cm = self._tpm.infer_cm()
+            # Convert to list
+            tpm_list = [self._tpm.create_node_tpm(index, self._cm) for index in range(len(self._tpm.n_nodes))]
+
+            self._node_indices = tuple(range(self.size))
+            self._node_labels = NodeLabels(self._tpm._p_nodes, self.node_indices) # TODO consider using self._tpm._p_nodes instead for more readability?
+
+            self._tpm = tpm_list
+            self._is_list = True
+        
         self.purview_cache = purview_cache or cache.PurviewCache()
-
+         
         validate.network(self)
 
     @property
@@ -73,6 +103,7 @@ def tpm(self):
         """
         return self._tpm
 
+    # TODO Deprecated?
     @staticmethod
     def _build_tpm(tpm):
         """Validate the TPM passed by the user and convert to multidimensional
@@ -92,6 +123,20 @@ def _build_tpm(tpm):
 
         return (tpm, utils.np_hash(tpm))
 
+    def _build_cm_from_list(self, cm):
+        """Generate the connectivity matrix for a network whose tpm is defined as a 
+        list of Node TPMs, by concatenating the individual Node TPM cms.
+
+        Args:
+            tpm_list (list[TPM]): List of Node TPMs that define how the network transitions.
+        """
+        if cm is None:
+            cm_list = [TPM.infer_node_cm(node_tpm) for node_tpm in self._tpm] 
+            return np.concatenate(cm_list, axis=1)
+        else:
+            return np.array(cm)
+        # return np.ones((self.size, self.size))
+
     @property
     def cm(self):
         """np.ndarray: The network's connectivity matrix.
@@ -106,8 +151,8 @@ def _build_cm(self, cm):
         unitary CM if none was provided.
         """
         if cm is None:
-            # Assume all are connected.
-            cm = np.ones((self.size, self.size))
+            # Build cm from TPM method
+            cm = self._tpm.infer_cm()
         else:
             cm = np.array(cm)
 
@@ -134,7 +179,15 @@ def size(self):
     @property
     def num_states(self):
         """int: The number of possible states of the network."""
-        return 2 ** self.size
+        # If list, states can be counted as product of possible transitions of each node
+        # If not, use TPM.num_states?
+        # if self._is_list:
+        num = 1
+        for tpm in self._tpm:
+            num *= tpm.shape[-1]
+        return num
+        # else:
+        #    return self._tpm.num_states
 
     @property
     def node_indices(self):
@@ -169,10 +222,19 @@ def potential_purviews(self, direction, mechanism):
 
     def __len__(self):
         """int: The number of nodes in the network."""
-        return self.tpm.shape[-1]
+        if self._is_list:
+            return len(self._tpm)
+        elif isinstance(self.tpm, TPM):
+            # TODO Assumes symmetry for now
+            return len(self.tpm.p_nodes)
+        else:
+            return self.tpm.shape[-1]
 
     def __repr__(self):
-        return "Network({}, cm={})".format(self.tpm, self.cm)
+        if self._is_list:
+            return str([tpm for tpm in self._tpm]) #TODO Consider representations
+        else:
+            return "Network({}, cm={})".format(self.tpm, self.cm)
 
     def __str__(self):
         return self.__repr__()
@@ -182,6 +244,8 @@ def __eq__(self, other):
 
         Networks are equal if they have the same TPM and CM.
         """
+        if self._is_list:
+            return np.all([node_tpm == node_tpm for node_tpm in self._tpm])
         return (
             isinstance(other, Network)
             and np.array_equal(self.tpm, other.tpm)
diff --git a/pyphi/node.py b/pyphi/node.py
index c495e5e85..a80a82d2f 100644
--- a/pyphi/node.py
+++ b/pyphi/node.py
@@ -15,6 +15,7 @@
 from .connectivity import get_inputs_from_cm, get_outputs_from_cm
 from .labels import NodeLabels
 from .tpm import marginalize_out, tpm_indices
+from .__tpm import TPM, SbN
 
 
 # TODO extend to nonbinary nodes
@@ -56,41 +57,10 @@ def __init__(self, tpm, cm, index, state, node_labels):
         # Get indices of the inputs.
         self._inputs = frozenset(get_inputs_from_cm(self.index, cm))
         self._outputs = frozenset(get_outputs_from_cm(self.index, cm))
+        
+        # Node TPM has already been generated, take from subsystem list
+        self.tpm = tpm[self.index]
 
-        # Generate the node's TPM.
-        # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
-        # We begin by getting the part of the subsystem's TPM that gives just
-        # the state of this node. This part is still indexed by network state,
-        # but its last dimension will be gone, since now there's just a single
-        # scalar value (this node's state) rather than a state-vector for all
-        # the network nodes.
-        tpm_on = tpm[..., self.index]
-
-        # TODO extend to nonbinary nodes
-        # Marginalize out non-input nodes that are in the subsystem, since the
-        # external nodes have already been dealt with as boundary conditions in
-        # the subsystem's TPM.
-        non_inputs = set(tpm_indices(tpm)) - self._inputs
-        tpm_on = marginalize_out(non_inputs, tpm_on)
-
-        # Get the TPM that gives the probability of the node being off, rather
-        # than on.
-        tpm_off = 1 - tpm_on
-
-        # Combine the on- and off-TPM so that the first dimension is indexed by
-        # the state of the node's inputs at t, and the last dimension is
-        # indexed by the node's state at t+1. This representation makes it easy
-        # to condition on the node state.
-        self.tpm = np.stack([tpm_off, tpm_on], axis=-1)
-        # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
-
-        # Make the TPM immutable (for hashing).
-        utils.np_immutable(self.tpm)
-
-        # Only compute the hash once.
-        self._hash = hash(
-            (index, utils.np_hash(self.tpm), self.state, self._inputs, self._outputs)
-        )
 
     @property
     def tpm_off(self):
@@ -160,7 +130,7 @@ def generate_nodes(tpm, cm, network_state, indices, node_labels=None):
     """Generate |Node| objects for a subsystem.
 
     Args:
-        tpm (np.ndarray): The system's TPM
+        tpm (pyphi.TPM): The system's TPM
         cm (np.ndarray): The corresponding CM.
         network_state (tuple): The state of the network.
         indices (tuple[int]): Indices to generate nodes for.
diff --git a/pyphi/subsystem.py b/pyphi/subsystem.py
index 8b2c5862f..741dafa2f 100644
--- a/pyphi/subsystem.py
+++ b/pyphi/subsystem.py
@@ -24,6 +24,7 @@
 from .node import generate_nodes
 from .partition import mip_partitions
 from .tpm import condition_tpm, marginalize_out
+from .__tpm import TPM, SbN
 from .utils import time_annotated
 
 log = logging.getLogger(__name__)
@@ -68,8 +69,9 @@ def __init__(
         # The network this subsystem belongs to.
         validate.is_network(network)
         self.network = network
+        
 
-        self.node_labels = network.node_labels
+        self.node_labels = network._node_labels
         # Remove duplicates, sort, and ensure native Python `int`s
         # (for JSON serialization).
         self.node_indices = self.node_labels.coerce_to_indices(nodes)
@@ -87,9 +89,9 @@ def __init__(
             )
         else:
             self.external_indices = _external_indices
-
-        # The TPM conditioned on the state of the external nodes.
-        self.tpm = condition_tpm(self.network.tpm, self.external_indices, self.state)
+            # Remove external nodes from subsystem
+            self.node_indices = tuple([node for node in self.node_indices if node not in self.external_indices])
+            self.node_labels = tuple([self.node_labels[i] for i in self.node_indices])
 
         # The unidirectional cut applied for phi evaluation
         self.cut = (
@@ -99,9 +101,25 @@ def __init__(
         # The network's connectivity matrix with cut applied
         self.cm = self.cut.apply_cut(network.cm)
 
+            
+        # Condition TPM by conditioning every tpm in the list
+        # Thought: Could generalize this by having Network always use a list, just a list with one object if 
+        # Full tpm is described... except atm condition_list assumes node TPMs because condition doesn't work for asym
+        # Here we need to make a copy of the TPM objects from the network, or else only one Subsystem can be generated
+        # Could also be done further downstream, but both function the same.
+        # TODO Can we make copies unnecessary?
+        if nodes is not ():
+            self.tpm = [tpm.copy() for tpm in network.tpm]
+            self.tpm = TPM.condition_list(self.tpm, self.external_indices, self.state) # Returns a list of conditioned Node TPMs
+        else:
+            # Empty subsystem, no nodes, 100% chance to stay in the state selected always
+            self.tpm = TPM(tpm = np.array([[1]]))
+
+        self.states = tuple([tpm.tpm.shape[-1] for tpm in self.tpm]) # Tuple of number of states each node can transition to
+
         # Reusable cache for maximally-irreducible causes and effects
         self._mice_cache = cache.MICECache(self, mice_cache)
-
+        
         # Cause & effect repertoire caches
         # TODO: if repertoire caches are never reused, there's no reason to
         # have an accesible object-level cache. Just use a simple memoizer
@@ -110,11 +128,12 @@ def __init__(
         )
         self._repertoire_cache = repertoire_cache or cache.DictCache()
 
+        # Node objects
         self.nodes = generate_nodes(
             self.tpm, self.cm, self.state, self.node_indices, self.node_labels
         )
 
-        validate.subsystem(self)
+        #validate.subsystem(self)
 
     @property
     def nodes(self):
@@ -182,7 +201,7 @@ def cut_node_labels(self):
     @property
     def tpm_size(self):
         """int: The number of nodes in the TPM."""
-        return self.tpm.shape[-1]
+        return len(self.tpm)
 
     def cache_info(self):
         """Report repertoire cache statistics."""
@@ -292,19 +311,29 @@ def indices2nodes(self, indices):
             raise ValueError("`indices` must be a subset of the Subsystem's indices.")
         return tuple(self._index2node[n] for n in indices)
 
-    # TODO extend to nonbinary nodes
     @cache.method("_single_node_repertoire_cache", Direction.CAUSE)
     def _single_node_cause_repertoire(self, mechanism_node_index, purview):
         # pylint: disable=missing-docstring
         mechanism_node = self._index2node[mechanism_node_index]
-        # We're conditioning on this node's state, so take the TPM for the node
-        # being in that state.
-        tpm = mechanism_node.tpm[..., mechanism_node.state]
-        # Marginalize-out all parents of this mechanism node that aren't in the
-        # purview.
-        return marginalize_out((mechanism_node.inputs - purview), tpm)
-
-    # TODO extend to nonbinary nodes
+        # Set up some different paths while transitioning
+        if isinstance(mechanism_node.tpm, TPM):
+            # Valid to use n_nodes[0] as these are guaranteed node_tpms, therefore only have 1 n_node, itself
+            indexer = dict({mechanism_node.tpm.n_nodes[0]: mechanism_node.state})
+            #TODO DEBUGGING
+            print(indexer)
+            print("purview:", purview)
+            print('inputs:', mechanism_node.inputs)
+            print("to_marginalize:", mechanism_node.inputs - purview)
+            tpm = mechanism_node.tpm.marginalize_out((mechanism_node.inputs - purview), rows=True).tpm.loc[indexer]
+            return tpm.data
+        else:
+            # We're conditioning on this node's state, so take the TPM for the node
+            # being in that state.
+            tpm = mechanism_node.tpm[..., mechanism_node.state]
+            # Marginalize-out all parents of this mechanism node that aren't in the
+            # purview.
+            return marginalize_out((mechanism_node.inputs - purview), tpm)
+
     @cache.method("_repertoire_cache", Direction.CAUSE)
     def cause_repertoire(self, mechanism, purview):
         """Return the cause repertoire of a mechanism over a purview.
@@ -330,15 +359,21 @@ def cause_repertoire(self, mechanism, purview):
         # state of the purview; return the purview's maximum entropy
         # distribution.
         if not mechanism:
-            return max_entropy_distribution(purview, self.tpm_size)
+            return max_entropy_distribution(purview, self.tpm_size, self.states)
         # Use a frozenset so the arguments to `_single_node_cause_repertoire`
         # can be hashed and cached.
         purview = frozenset(purview)
         # Preallocate the repertoire with the proper shape, so that
         # probabilities are broadcasted appropriately.
-        joint = np.ones(repertoire_shape(purview, self.tpm_size))
+        joint = np.ones(repertoire_shape(purview, self.tpm_size, self.states))
         # The cause repertoire is the product of the cause repertoires of the
         # individual nodes.
+        #TODO DEBUGGING
+        print("joint shape before:", joint.shape)
+        print("mechanism:", mechanism)
+        print("product shape:", functools.reduce(np.multiply, [self._single_node_cause_repertoire(m, purview) for m in mechanism]).shape)
+        print([self._single_node_cause_repertoire(m, purview) for m in mechanism])
+        print("purview:", purview)
         joint *= functools.reduce(
             np.multiply,
             [self._single_node_cause_repertoire(m, purview) for m in mechanism],
@@ -348,19 +383,29 @@ def cause_repertoire(self, mechanism, purview):
         # TPM don't necessarily sum to 1, so we normalize.
         return distribution.normalize(joint)
 
-    # TODO extend to nonbinary nodes
     @cache.method("_single_node_repertoire_cache", Direction.EFFECT)
     def _single_node_effect_repertoire(self, mechanism, purview_node_index):
         # pylint: disable=missing-docstring
         purview_node = self._index2node[purview_node_index]
         # Condition on the state of the inputs that are in the mechanism.
         mechanism_inputs = purview_node.inputs & mechanism
-        tpm = condition_tpm(purview_node.tpm, mechanism_inputs, self.state)
         # Marginalize-out the inputs that aren't in the mechanism.
         nonmechanism_inputs = purview_node.inputs - mechanism
-        tpm = marginalize_out(nonmechanism_inputs, tpm)
-        # Reshape so that the distribution is over next states.
-        return tpm.reshape(repertoire_shape([purview_node.index], self.tpm_size))
+
+        # Set up different path for TPM objects while transitioning
+        if isinstance(purview_node.tpm, TPM):
+            tpm = purview_node.tpm.condition_node_tpm(mechanism_inputs, self.state)
+            # Doesn't do cols so gotta do that manually again
+            tpm = purview_node.tpm.marginalize_out(nonmechanism_inputs, rows=True)
+            indexer = dict({purview_node.tpm.n_nodes[0]: purview_node.state})
+            #tpm.reshape(repertoire_shape([purview_node.index], self.tpm_size, self.states))
+            return tpm.tpm.loc[indexer].data
+
+        else:
+            tpm = condition_tpm(purview_node.tpm, mechanism_inputs, self.state)    
+            tpm = marginalize_out(nonmechanism_inputs, tpm)
+            # Reshape so that the distribution is over next states.
+            return tpm.reshape(repertoire_shape([purview_node.index], self.tpm_size))
 
     @cache.method("_repertoire_cache", Direction.EFFECT)
     def effect_repertoire(self, mechanism, purview):
@@ -389,7 +434,7 @@ def effect_repertoire(self, mechanism, purview):
         mechanism = frozenset(mechanism)
         # Preallocate the repertoire with the proper shape, so that
         # probabilities are broadcasted appropriately.
-        joint = np.ones(repertoire_shape(purview, self.tpm_size))
+        joint = np.ones(repertoire_shape(purview, self.tpm_size, self.states))
         # The effect repertoire is the product of the effect repertoires of the
         # individual nodes.
         return joint * functools.reduce(
diff --git a/pyphi/validate.py b/pyphi/validate.py
index 125e90b34..bef826811 100644
--- a/pyphi/validate.py
+++ b/pyphi/validate.py
@@ -10,6 +10,8 @@
 
 from . import Direction, config, convert, exceptions
 from .tpm import is_state_by_state
+from itertools import product
+from pyphi.__tpm import TPM
 
 # pylint: disable=redefined-outer-name
 
@@ -29,7 +31,7 @@ def direction(direction, allow_bi=False):
 
     return True
 
-
+# TODO eliminate non-object route? 
 def tpm(tpm, check_independence=True):
     """Validate a TPM.
 
@@ -61,6 +63,8 @@ def tpm(tpm, check_independence=True):
             )
         if tpm.shape[0] == tpm.shape[1] and check_independence:
             conditionally_independent(tpm)
+        elif isinstance(tpm, TPM) and check_independence:
+            conditionally_independent_obj(tpm)
     elif tpm.ndim == (N + 1):
         if tpm.shape != tuple([2] * N + [N]):
             raise ValueError(
@@ -76,6 +80,37 @@ def tpm(tpm, check_independence=True):
         )
     return True
 
+def conditionally_independent_obj(tpm):
+    # Step 0: Cartesian products of potential state combos of the p_nodes and n_nodes for later access
+    p_states = tuple(product(*(tuple(range(tpm.all_nodes[node])) for node in tpm.p_nodes)))
+    n_states = tuple(product(*(tuple(range(tpm.all_nodes[node])) for node in tpm.n_nodes)))
+
+    # Step 1: For each p combo, calculate the state probabilities of each node in the columns
+    for state in p_states:
+        node_distributions = []
+        # ... would rather not have to use this copy method again (as in infer_cm) but python refuses to let me work with
+        # the list like I need to
+        temp_n_nodes = tpm.n_nodes.copy()
+        for n_node in tpm.n_nodes:
+            index = temp_n_nodes.index(n_node)
+            temp_n_nodes.remove(n_node)
+            node_distributions.append(tpm.tpm.groupby(n_node).sum(temp_n_nodes).loc[state])
+            temp_n_nodes.insert(index, n_node)
+
+        # Step 2: Generate the conditionally independent element, then compare with actual element
+
+        for n_state in n_states:
+            independent_probability = 1
+            for node_index in range(len(node_distributions)):
+                independent_probability *= node_distributions[node_index][n_state[node_index]]
+            
+            if tpm[state + n_state] != independent_probability:
+                raise exceptions.ConditionallyDependentError(
+                    "TPM is not conditionally independent.\n"
+                    "See the conditional independence example in the documentation "
+                    "for more info."
+                )
+    return True
 
 def conditionally_independent(tpm):
     """Validate that the TPM is conditionally independent."""
@@ -104,8 +139,6 @@ def connectivity_matrix(cm):
         return True
     if cm.ndim != 2:
         raise ValueError("Connectivity matrix must be 2-dimensional.")
-    if cm.shape[0] != cm.shape[1]:
-        raise ValueError("Connectivity matrix must be square.")
     if not np.all(np.logical_or(cm == 1, cm == 0)):
         raise ValueError("Connectivity matrix must contain only binary " "values.")
     return True
@@ -127,7 +160,7 @@ def network(n):
 
     Checks the TPM and connectivity matrix.
     """
-    tpm(n.tpm)
+    # tpm(n.tpm) TODO: Generating a valid TPM list implies the tpm is valid, only 'needed' if network given a non-list
     connectivity_matrix(n.cm)
     if n.cm.shape[0] != n.size:
         raise ValueError(
@@ -171,12 +204,38 @@ def state_reachable(subsystem):
     # reached from some state.
     # First we take the submatrix of the conditioned TPM that corresponds to
     # the nodes that are actually in the subsystem...
-    tpm = subsystem.tpm[..., subsystem.node_indices]
+    #tpm = subsystem.tpm[..., subsystem.node_indices]
     # Then we do the subtraction and test.
-    test = tpm - np.array(subsystem.proper_state)
-    if not np.any(np.logical_and(-1 < test, test < 1).all(-1)):
-        raise exceptions.StateUnreachableError(subsystem.state)
-
+    #test = tpm - np.array(subsystem.proper_state)
+    #if not np.any(np.logical_and(-1 < test, test < 1).all(-1)):
+    #    raise exceptions.StateUnreachableError(subsystem.state)
+
+    # If there exists a row in which every node in the subsystem
+    # Has a positive probability of being in that state, there is
+    # a non-zero chance of reaching that state.
+    
+    # Create state tuples to look thru
+    # If the node is in the subsystem, the tuple will include all its potential states
+    # If it is an external node, however, the tpms have been conditioned to a single state
+    # So access that state from the subsystem and use it in the tuple.
+    p_states = tuple(product(
+        *(tuple(range(subsystem.states[node])) if node in subsystem.node_indices else (subsystem.state[node], ) for node in subsystem.network.node_indices)
+        ))
+
+    reachable = False
+    for p_state in p_states:
+        # Reset reachable so that each state can be tried
+        reachable = True
+        for node in subsystem.node_indicies:
+            if subsystem.tpm[node][p_state + (subsystem.state[node], )] == 0:
+                # If there's any 0 probability, state can't be reached here so break
+                reachable = False
+                break
+        # If there is ever a point where reachable remains True, then it is reachable
+        if reachable:
+            return True    
+    # If no state led to it being reachable, raise the exception  
+    raise exceptions.StateUnreachableError(subsystem.state)
 
 def cut(cut, node_indices):
     """Check that the cut is for only the given nodes."""
diff --git a/pyphi_config.yml b/pyphi_config.yml
index 654086c34..4212dac8f 100644
--- a/pyphi_config.yml
+++ b/pyphi_config.yml
@@ -18,7 +18,7 @@ ASSUME_CUTS_CANNOT_CREATE_NEW_CONCEPTS: false
 # modular, sparsely-connected, or homogeneous networks.
 CUT_ONE_APPROXIMATION: false
 # The measure to use when computing phi ("EMD", "KLD", "L1", ...)
-MEASURE: "EMD"
+MEASURE: "EMD" #KLM for nonbinary calculations
 # Controls the number of parts in a partition.
 PARTITION_TYPE: "BI"
 # Controls how to resolve phi-ties when computing MICE.
diff --git a/test/mice_test.ipynb b/test/mice_test.ipynb
new file mode 100644
index 000000000..1df917511
--- /dev/null
+++ b/test/mice_test.ipynb
@@ -0,0 +1,601 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "\n",
+      "Welcome to PyPhi!\n",
+      "\n",
+      "If you use PyPhi in your research, please cite the paper:\n",
+      "\n",
+      "  Mayner WGP, Marshall W, Albantakis L, Findlay G, Marchman R, Tononi G.\n",
+      "  (2018). PyPhi: A toolbox for integrated information theory.\n",
+      "  PLOS Computational Biology 14(7): e1006343.\n",
+      "  https://doi.org/10.1371/journal.pcbi.1006343\n",
+      "\n",
+      "Documentation is available online (or with the built-in `help()` function):\n",
+      "  https://pyphi.readthedocs.io\n",
+      "\n",
+      "To report issues, please use the issue tracker on the GitHub repository:\n",
+      "  https://github.com/wmayner/pyphi\n",
+      "\n",
+      "For general discussion, you are welcome to join the pyphi-users group:\n",
+      "  https://groups.google.com/forum/#!forum/pyphi-users\n",
+      "\n",
+      "To suppress this message, either:\n",
+      "  - Set `WELCOME_OFF: true` in your `pyphi_config.yml` file, or\n",
+      "  - Set the environment variable PYPHI_WELCOME_OFF to any value in your shell:\n",
+      "        export PYPHI_WELCOME_OFF='yes'\n",
+      "\n"
+     ]
+    }
+   ],
+   "source": [
+    "import pyphi\n",
+    "import test_subsystem_phi_max\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 2,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "{None: {<Direction.CAUSE: 0>: [Maximally-irreducible cause\n",
+       "     φ = 1/2\n",
+       "     Mechanism: [B]\n",
+       "     Purview = [C]\n",
+       "     Direction: CAUSE\n",
+       "     MIP:\n",
+       "        ∅     B \n",
+       "       ─── ✕ ───\n",
+       "        C     ∅ \n",
+       "     Repertoire:\n",
+       "       ┌─────────────┐\n",
+       "       │ S    Pr(S)  │\n",
+       "       │ ╴╴╴╴╴╴╴╴╴╴╴ │\n",
+       "       │ 0    1      │\n",
+       "       │ 1    0      │\n",
+       "       └─────────────┘\n",
+       "     Partitioned repertoire:\n",
+       "       ┌─────────────┐\n",
+       "       │ S    Pr(S)  │\n",
+       "       │ ╴╴╴╴╴╴╴╴╴╴╴ │\n",
+       "       │ 0    1/2    │\n",
+       "       │ 1    1/2    │\n",
+       "       └─────────────┘,\n",
+       "   Maximally-irreducible cause\n",
+       "     φ = 1/2\n",
+       "     Mechanism: [C]\n",
+       "     Purview = [A, B]\n",
+       "     Direction: CAUSE\n",
+       "     MIP:\n",
+       "        ∅     C \n",
+       "       ─── ✕ ───\n",
+       "        A     B \n",
+       "     Repertoire:\n",
+       "       ┌──────────────┐\n",
+       "       │ S     Pr(S)  │\n",
+       "       │ ╴╴╴╴╴╴╴╴╴╴╴╴ │\n",
+       "       │ 00    1/2    │\n",
+       "       │ 10    0      │\n",
+       "       │ 01    0      │\n",
+       "       │ 11    1/2    │\n",
+       "       └──────────────┘\n",
+       "     Partitioned repertoire:\n",
+       "       ┌──────────────┐\n",
+       "       │ S     Pr(S)  │\n",
+       "       │ ╴╴╴╴╴╴╴╴╴╴╴╴ │\n",
+       "       │ 00    1/4    │\n",
+       "       │ 10    1/4    │\n",
+       "       │ 01    1/4    │\n",
+       "       │ 11    1/4    │\n",
+       "       └──────────────┘,\n",
+       "   Maximally-irreducible cause\n",
+       "     φ = 1/3\n",
+       "     Mechanism: [A, B]\n",
+       "     Purview = [B, C]\n",
+       "     Direction: CAUSE\n",
+       "     MIP:\n",
+       "        A     B \n",
+       "       ─── ✕ ───\n",
+       "        B     C \n",
+       "     Repertoire:\n",
+       "       ┌──────────────┐\n",
+       "       │ S     Pr(S)  │\n",
+       "       │ ╴╴╴╴╴╴╴╴╴╴╴╴ │\n",
+       "       │ 00    0      │\n",
+       "       │ 10    1      │\n",
+       "       │ 01    0      │\n",
+       "       │ 11    0      │\n",
+       "       └──────────────┘\n",
+       "     Partitioned repertoire:\n",
+       "       ┌──────────────┐\n",
+       "       │ S     Pr(S)  │\n",
+       "       │ ╴╴╴╴╴╴╴╴╴╴╴╴ │\n",
+       "       │ 00    1/3    │\n",
+       "       │ 10    2/3    │\n",
+       "       │ 01    0      │\n",
+       "       │ 11    0      │\n",
+       "       └──────────────┘,\n",
+       "   Maximally-irreducible cause\n",
+       "     φ = 1/2\n",
+       "     Mechanism: [A, B, C]\n",
+       "     Purview = [A, B, C]\n",
+       "     Direction: CAUSE\n",
+       "     MIP:\n",
+       "        ∅    A,B,C\n",
+       "       ─── ✕ ─────\n",
+       "        A     B,C \n",
+       "     Repertoire:\n",
+       "       ┌───────────────┐\n",
+       "       │  S     Pr(S)  │\n",
+       "       │ ╴╴╴╴╴╴╴╴╴╴╴╴╴ │\n",
+       "       │ 000    0      │\n",
+       "       │ 100    0      │\n",
+       "       │ 010    0      │\n",
+       "       │ 110    1      │\n",
+       "       │ 001    0      │\n",
+       "       │ 101    0      │\n",
+       "       │ 011    0      │\n",
+       "       │ 111    0      │\n",
+       "       └───────────────┘\n",
+       "     Partitioned repertoire:\n",
+       "       ┌───────────────┐\n",
+       "       │  S     Pr(S)  │\n",
+       "       │ ╴╴╴╴╴╴╴╴╴╴╴╴╴ │\n",
+       "       │ 000    0      │\n",
+       "       │ 100    0      │\n",
+       "       │ 010    1/2    │\n",
+       "       │ 110    1/2    │\n",
+       "       │ 001    0      │\n",
+       "       │ 101    0      │\n",
+       "       │ 011    0      │\n",
+       "       │ 111    0      │\n",
+       "       └───────────────┘],\n",
+       "  <Direction.EFFECT: 1>: [Maximally-irreducible effect\n",
+       "     φ = 1/4\n",
+       "     Mechanism: [B]\n",
+       "     Purview = [A]\n",
+       "     Direction: EFFECT\n",
+       "     MIP:\n",
+       "        ∅     B \n",
+       "       ─── ✕ ───\n",
+       "        A     ∅ \n",
+       "     Repertoire:\n",
+       "       ┌─────────────┐\n",
+       "       │ S    Pr(S)  │\n",
+       "       │ ╴╴╴╴╴╴╴╴╴╴╴ │\n",
+       "       │ 0    1/2    │\n",
+       "       │ 1    1/2    │\n",
+       "       └─────────────┘\n",
+       "     Partitioned repertoire:\n",
+       "       ┌─────────────┐\n",
+       "       │ S    Pr(S)  │\n",
+       "       │ ╴╴╴╴╴╴╴╴╴╴╴ │\n",
+       "       │ 0    1/4    │\n",
+       "       │ 1    3/4    │\n",
+       "       └─────────────┘,\n",
+       "   Maximally-irreducible effect\n",
+       "     φ = 1/2\n",
+       "     Mechanism: [C]\n",
+       "     Purview = [B]\n",
+       "     Direction: EFFECT\n",
+       "     MIP:\n",
+       "        ∅     C \n",
+       "       ─── ✕ ───\n",
+       "        B     ∅ \n",
+       "     Repertoire:\n",
+       "       ┌─────────────┐\n",
+       "       │ S    Pr(S)  │\n",
+       "       │ ╴╴╴╴╴╴╴╴╴╴╴ │\n",
+       "       │ 0    1      │\n",
+       "       │ 1    0      │\n",
+       "       └─────────────┘\n",
+       "     Partitioned repertoire:\n",
+       "       ┌─────────────┐\n",
+       "       │ S    Pr(S)  │\n",
+       "       │ ╴╴╴╴╴╴╴╴╴╴╴ │\n",
+       "       │ 0    1/2    │\n",
+       "       │ 1    1/2    │\n",
+       "       └─────────────┘,\n",
+       "   Maximally-irreducible effect\n",
+       "     φ = 1/2\n",
+       "     Mechanism: [A, B]\n",
+       "     Purview = [C]\n",
+       "     Direction: EFFECT\n",
+       "     MIP:\n",
+       "        ∅    A,B\n",
+       "       ─── ✕ ───\n",
+       "        C     ∅ \n",
+       "     Repertoire:\n",
+       "       ┌─────────────┐\n",
+       "       │ S    Pr(S)  │\n",
+       "       │ ╴╴╴╴╴╴╴╴╴╴╴ │\n",
+       "       │ 0    0      │\n",
+       "       │ 1    1      │\n",
+       "       └─────────────┘\n",
+       "     Partitioned repertoire:\n",
+       "       ┌─────────────┐\n",
+       "       │ S    Pr(S)  │\n",
+       "       │ ╴╴╴╴╴╴╴╴╴╴╴ │\n",
+       "       │ 0    1/2    │\n",
+       "       │ 1    1/2    │\n",
+       "       └─────────────┘,\n",
+       "   Maximally-irreducible effect\n",
+       "     φ = 1/2\n",
+       "     Mechanism: [A, B, C]\n",
+       "     Purview = [A, B, C]\n",
+       "     Direction: EFFECT\n",
+       "     MIP:\n",
+       "        ∅    A,B,C\n",
+       "       ─── ✕ ─────\n",
+       "        B     A,C \n",
+       "     Repertoire:\n",
+       "       ┌───────────────┐\n",
+       "       │  S     Pr(S)  │\n",
+       "       │ ╴╴╴╴╴╴╴╴╴╴╴╴╴ │\n",
+       "       │ 000    0      │\n",
+       "       │ 100    0      │\n",
+       "       │ 010    0      │\n",
+       "       │ 110    0      │\n",
+       "       │ 001    1      │\n",
+       "       │ 101    0      │\n",
+       "       │ 011    0      │\n",
+       "       │ 111    0      │\n",
+       "       └───────────────┘\n",
+       "     Partitioned repertoire:\n",
+       "       ┌───────────────┐\n",
+       "       │  S     Pr(S)  │\n",
+       "       │ ╴╴╴╴╴╴╴╴╴╴╴╴╴ │\n",
+       "       │ 000    0      │\n",
+       "       │ 100    0      │\n",
+       "       │ 010    0      │\n",
+       "       │ 110    0      │\n",
+       "       │ 001    1/2    │\n",
+       "       │ 101    0      │\n",
+       "       │ 011    1/2    │\n",
+       "       │ 111    0      │\n",
+       "       └───────────────┘]},\n",
+       " Cut [1, 2] ━━/ /━━➤ [0]: {<Direction.CAUSE: 0>: [Maximally-irreducible cause\n",
+       "     φ = 1/2\n",
+       "     Mechanism: [B]\n",
+       "     Purview = [C]\n",
+       "     Direction: CAUSE\n",
+       "     MIP:\n",
+       "        ∅     B \n",
+       "       ─── ✕ ───\n",
+       "        C     ∅ \n",
+       "     Repertoire:\n",
+       "       ┌─────────────┐\n",
+       "       │ S    Pr(S)  │\n",
+       "       │ ╴╴╴╴╴╴╴╴╴╴╴ │\n",
+       "       │ 0    1      │\n",
+       "       │ 1    0      │\n",
+       "       └─────────────┘\n",
+       "     Partitioned repertoire:\n",
+       "       ┌─────────────┐\n",
+       "       │ S    Pr(S)  │\n",
+       "       │ ╴╴╴╴╴╴╴╴╴╴╴ │\n",
+       "       │ 0    1/2    │\n",
+       "       │ 1    1/2    │\n",
+       "       └─────────────┘,\n",
+       "   Maximally-irreducible cause\n",
+       "     φ = 1/2\n",
+       "     Mechanism: [C]\n",
+       "     Purview = [A, B]\n",
+       "     Direction: CAUSE\n",
+       "     MIP:\n",
+       "        ∅     C \n",
+       "       ─── ✕ ───\n",
+       "        A     B \n",
+       "     Repertoire:\n",
+       "       ┌──────────────┐\n",
+       "       │ S     Pr(S)  │\n",
+       "       │ ╴╴╴╴╴╴╴╴╴╴╴╴ │\n",
+       "       │ 00    1/2    │\n",
+       "       │ 10    0      │\n",
+       "       │ 01    0      │\n",
+       "       │ 11    1/2    │\n",
+       "       └──────────────┘\n",
+       "     Partitioned repertoire:\n",
+       "       ┌──────────────┐\n",
+       "       │ S     Pr(S)  │\n",
+       "       │ ╴╴╴╴╴╴╴╴╴╴╴╴ │\n",
+       "       │ 00    1/4    │\n",
+       "       │ 10    1/4    │\n",
+       "       │ 01    1/4    │\n",
+       "       │ 11    1/4    │\n",
+       "       └──────────────┘,\n",
+       "   Maximally-irreducible cause\n",
+       "     φ = 0\n",
+       "     Mechanism: [0, 1]\n",
+       "     Purview = []\n",
+       "     Direction: CAUSE\n",
+       "     MIP:\n",
+       "       \n",
+       "     Repertoire:\n",
+       "       \n",
+       "     Partitioned repertoire:\n",
+       "       ,\n",
+       "   Maximally-irreducible cause\n",
+       "     φ = 0\n",
+       "     Mechanism: [0, 1, 2]\n",
+       "     Purview = []\n",
+       "     Direction: CAUSE\n",
+       "     MIP:\n",
+       "       \n",
+       "     Repertoire:\n",
+       "       \n",
+       "     Partitioned repertoire:\n",
+       "       ],\n",
+       "  <Direction.EFFECT: 1>: [Maximally-irreducible effect\n",
+       "     φ = 0\n",
+       "     Mechanism: [B]\n",
+       "     Purview = [C]\n",
+       "     Direction: EFFECT\n",
+       "     MIP:\n",
+       "        ∅     B \n",
+       "       ─── ✕ ───\n",
+       "        C     ∅ \n",
+       "     Repertoire:\n",
+       "       ┌─────────────┐\n",
+       "       │ S    Pr(S)  │\n",
+       "       │ ╴╴╴╴╴╴╴╴╴╴╴ │\n",
+       "       │ 0    1/2    │\n",
+       "       │ 1    1/2    │\n",
+       "       └─────────────┘\n",
+       "     Partitioned repertoire:\n",
+       "       ┌─────────────┐\n",
+       "       │ S    Pr(S)  │\n",
+       "       │ ╴╴╴╴╴╴╴╴╴╴╴ │\n",
+       "       │ 0    1/2    │\n",
+       "       │ 1    1/2    │\n",
+       "       └─────────────┘,\n",
+       "   Maximally-irreducible effect\n",
+       "     φ = 1/2\n",
+       "     Mechanism: [C]\n",
+       "     Purview = [B]\n",
+       "     Direction: EFFECT\n",
+       "     MIP:\n",
+       "        ∅     C \n",
+       "       ─── ✕ ───\n",
+       "        B     ∅ \n",
+       "     Repertoire:\n",
+       "       ┌─────────────┐\n",
+       "       │ S    Pr(S)  │\n",
+       "       │ ╴╴╴╴╴╴╴╴╴╴╴ │\n",
+       "       │ 0    1      │\n",
+       "       │ 1    0      │\n",
+       "       └─────────────┘\n",
+       "     Partitioned repertoire:\n",
+       "       ┌─────────────┐\n",
+       "       │ S    Pr(S)  │\n",
+       "       │ ╴╴╴╴╴╴╴╴╴╴╴ │\n",
+       "       │ 0    1/2    │\n",
+       "       │ 1    1/2    │\n",
+       "       └─────────────┘,\n",
+       "   Maximally-irreducible effect\n",
+       "     φ = 1/2\n",
+       "     Mechanism: [A, B]\n",
+       "     Purview = [C]\n",
+       "     Direction: EFFECT\n",
+       "     MIP:\n",
+       "        ∅    A,B\n",
+       "       ─── ✕ ───\n",
+       "        C     ∅ \n",
+       "     Repertoire:\n",
+       "       ┌─────────────┐\n",
+       "       │ S    Pr(S)  │\n",
+       "       │ ╴╴╴╴╴╴╴╴╴╴╴ │\n",
+       "       │ 0    0      │\n",
+       "       │ 1    1      │\n",
+       "       └─────────────┘\n",
+       "     Partitioned repertoire:\n",
+       "       ┌─────────────┐\n",
+       "       │ S    Pr(S)  │\n",
+       "       │ ╴╴╴╴╴╴╴╴╴╴╴ │\n",
+       "       │ 0    1/2    │\n",
+       "       │ 1    1/2    │\n",
+       "       └─────────────┘,\n",
+       "   Maximally-irreducible effect\n",
+       "     φ = 0\n",
+       "     Mechanism: [0, 1, 2]\n",
+       "     Purview = []\n",
+       "     Direction: EFFECT\n",
+       "     MIP:\n",
+       "       \n",
+       "     Repertoire:\n",
+       "       \n",
+       "     Partitioned repertoire:\n",
+       "       ]}}"
+      ]
+     },
+     "execution_count": 2,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "test_subsystem_phi_max.expected_mice"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 4,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "cut, direction, expected = mice_scenarios[5]"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "\n",
+      "Welcome to PyPhi!\n",
+      "\n",
+      "If you use PyPhi in your research, please cite the paper:\n",
+      "\n",
+      "  Mayner WGP, Marshall W, Albantakis L, Findlay G, Marchman R, Tononi G.\n",
+      "  (2018). PyPhi: A toolbox for integrated information theory.\n",
+      "  PLOS Computational Biology 14(7): e1006343.\n",
+      "  https://doi.org/10.1371/journal.pcbi.1006343\n",
+      "\n",
+      "Documentation is available online (or with the built-in `help()` function):\n",
+      "  https://pyphi.readthedocs.io\n",
+      "\n",
+      "To report issues, please use the issue tracker on the GitHub repository:\n",
+      "  https://github.com/wmayner/pyphi\n",
+      "\n",
+      "For general discussion, you are welcome to join the pyphi-users group:\n",
+      "  https://groups.google.com/forum/#!forum/pyphi-users\n",
+      "\n",
+      "To suppress this message, either:\n",
+      "  - Set `WELCOME_OFF: true` in your `pyphi_config.yml` file, or\n",
+      "  - Set the environment variable PYPHI_WELCOME_OFF to any value in your shell:\n",
+      "        export PYPHI_WELCOME_OFF='yes'\n",
+      "\n"
+     ]
+    }
+   ],
+   "source": [
+    "%run test_subsystem_phi_max.py"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 5,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Expected:\n",
+      " Maximally-irreducible effect\n",
+      "  φ = 1/2\n",
+      "  Mechanism: [A, B, C]\n",
+      "  Purview = [A, B, C]\n",
+      "  Direction: EFFECT\n",
+      "  MIP:\n",
+      "     ∅    A,B,C\n",
+      "    ─── ✕ ─────\n",
+      "     B     A,C \n",
+      "  Repertoire:\n",
+      "    ┌───────────────┐\n",
+      "    │  S     Pr(S)  │\n",
+      "    │ ╴╴╴╴╴╴╴╴╴╴╴╴╴ │\n",
+      "    │ 000    0      │\n",
+      "    │ 100    0      │\n",
+      "    │ 010    0      │\n",
+      "    │ 110    0      │\n",
+      "    │ 001    1      │\n",
+      "    │ 101    0      │\n",
+      "    │ 011    0      │\n",
+      "    │ 111    0      │\n",
+      "    └───────────────┘\n",
+      "  Partitioned repertoire:\n",
+      "    ┌───────────────┐\n",
+      "    │  S     Pr(S)  │\n",
+      "    │ ╴╴╴╴╴╴╴╴╴╴╴╴╴ │\n",
+      "    │ 000    0      │\n",
+      "    │ 100    0      │\n",
+      "    │ 010    0      │\n",
+      "    │ 110    0      │\n",
+      "    │ 001    1/2    │\n",
+      "    │ 101    0      │\n",
+      "    │ 011    1/2    │\n",
+      "    │ 111    0      │\n",
+      "    └───────────────┘\n",
+      "Result:\n",
+      " Maximally-irreducible effect\n",
+      "  φ = 1/2\n",
+      "  Mechanism: [A, B, C]\n",
+      "  Purview = [A, B, C]\n",
+      "  Direction: EFFECT\n",
+      "  MIP:\n",
+      "     ∅    A,B,C\n",
+      "    ─── ✕ ─────\n",
+      "     B     A,C \n",
+      "  Repertoire:\n",
+      "    ┌───────────────┐\n",
+      "    │  S     Pr(S)  │\n",
+      "    │ ╴╴╴╴╴╴╴╴╴╴╴╴╴ │\n",
+      "    │ 000    0      │\n",
+      "    │ 100    0      │\n",
+      "    │ 010    0      │\n",
+      "    │ 110    0      │\n",
+      "    │ 001    1      │\n",
+      "    │ 101    0      │\n",
+      "    │ 011    0      │\n",
+      "    │ 111    0      │\n",
+      "    └───────────────┘\n",
+      "  Partitioned repertoire:\n",
+      "    ┌───────────────┐\n",
+      "    │  S     Pr(S)  │\n",
+      "    │ ╴╴╴╴╴╴╴╴╴╴╴╴╴ │\n",
+      "    │ 000    0      │\n",
+      "    │ 100    0      │\n",
+      "    │ 010    0      │\n",
+      "    │ 110    0      │\n",
+      "    │ 001    1/2    │\n",
+      "    │ 101    0      │\n",
+      "    │ 011    1/2    │\n",
+      "    │ 111    0      │\n",
+      "    └───────────────┘\n"
+     ]
+    }
+   ],
+   "source": [
+    "test_find_mice(cut, direction, expected)"
+   ]
+  }
+ ],
+ "metadata": {
+  "interpreter": {
+   "hash": "4d7e67e8959bd4cf30a4f9a3d906518b5030bde8ca90c929b80828e7ae4ed2e6"
+  },
+  "kernelspec": {
+   "display_name": "Python 3.7.9 64-bit",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.7.9"
+  },
+  "orig_nbformat": 4
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/test/test_network.py b/test/test_network.py
index cb255e080..9765445e2 100644
--- a/test/test_network.py
+++ b/test/test_network.py
@@ -17,15 +17,15 @@ def network():
 
 
 def test_network_init_validation(network):
-    with pytest.raises(ValueError):
+    with pytest.raises(IndexError): # TODO Do we want to change this error? Is this error guaranteed?
         # Totally wrong shape
         tpm = np.arange(3).astype(float)
         Network(tpm)
-    with pytest.raises(ValueError):
+    # with pytest.raises(ValueError): # This should no longer raise an error, tpm generalized to nb nodes
         # Non-binary nodes (4 states)
-        tpm = np.ones((4, 4, 4, 3)).astype(float)
-        Network(tpm)
-
+    #tpm = np.ones((4, 4, 4, 3)).astype(float)
+    #Network(tpm)
+    
     # Conditionally dependent
     # fmt: off
     tpm = np.array([
@@ -42,7 +42,7 @@ def test_network_init_validation(network):
             Network(tpm)
 
 
-def test_network_creates_fully_connected_cm_by_default():
+def test_network_creates_fully_connected_cm_by_default(): # Deprecated? Now have infer_cm method
     tpm = np.zeros((2 * 2 * 2, 3))
     network = Network(tpm, cm=None)
     target_cm = np.ones((3, 3))
diff --git a/test/test_subsystem.py b/test/test_subsystem.py
index 0908f6bd9..0d7c40ebd 100644
--- a/test/test_subsystem.py
+++ b/test/test_subsystem.py
@@ -124,7 +124,7 @@ def test_apply_cut(s):
     assert s.network == cut_s.network
     assert s.state == cut_s.state
     assert s.node_indices == cut_s.node_indices
-    assert np.array_equal(cut_s.tpm, s.tpm)
+    assert np.all([np.array_equal(cut_s.tpm[i].tpm.data, s.tpm[i].tpm.data) for i in range(len(s.tpm))])
     assert np.array_equal(cut_s.cm, cut.apply_cut(s.cm))
 
 
diff --git a/test/test_tpm_obj.py b/test/test_tpm_obj.py
new file mode 100644
index 000000000..71e8ff5fc
--- /dev/null
+++ b/test/test_tpm_obj.py
@@ -0,0 +1,257 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+# test/test_tpm_obj.py
+
+import numpy as np
+import pytest
+import pandas as pd
+from itertools import product
+
+from pyphi.__tpm import TPM
+from pyphi.__tpm import SbN
+
+
+# TODO for harsher tests, may want to ensure that node labels properly direct you to the correct position
+# But only necessary if signficant changes are made to shaping as results in notebook testing 
+# indicate it works
+def test_init():
+    tpm = np.array([[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]])
+
+    # Nothing but data given
+    test = TPM(tpm)
+    assert test.tpm.shape == (2,2,2,2)
+    assert test.tpm.dims == ("n0_p", "n1_p", "n0_n", "n1_n")
+
+    # Symmetric, binary
+    test = TPM(tpm, p_nodes=["A", "B"])
+
+    assert test.tpm.shape == (2, 2, 2, 2)
+    assert test.tpm.dims == ("A_p", "B_p", "n0_n", "n1_n")
+
+    # Symmetric, nonbinary, wrong shape
+    with pytest.raises(ValueError):
+        TPM(tpm, p_nodes=["A", "B"], p_states=[2, 3])
+
+    # Asymmetric, binary
+    tpm = np.array([[0, 1], [0, 1], [1, 0], [1, 0]])
+    test = TPM(tpm, p_nodes=["A", "B"], n_nodes=["C"])
+
+    assert test.tpm.shape == (2, 2, 2)
+    assert test.tpm.dims == ("A_p", "B_p", "C_n")
+
+    # Asymmetric, nonbinary
+    tpm = np.array([[0, 1, 0], [0, 0, 1], [1, 0, 0], [1, 0, 0]])
+    test = TPM(tpm, p_nodes=["A", "B"], n_nodes=["C"], p_states=[2, 2], n_states=[3])
+
+    assert test.tpm.shape == (2, 2, 3)
+    assert test.tpm.dims == ("A_p", "B_p", "C_n")
+
+    # Symmetric, nonbinary, right shape
+    tpm = np.array([[1, 0, 0, 0, 0, 0], [0, 1, 0, 0, 0, 0], [0, 0, 1, 0, 0, 0], [0, 0, 0, 1, 0, 0], [0, 0, 0, 0, 1, 0], [0, 0, 0, 0, 0, 1]])
+    test = TPM(tpm, p_nodes=["A", "B"], p_states=[2, 3])
+
+    assert test.tpm.shape == (2, 3, 2, 3)
+    assert test.tpm.dims == ("A_p", "B_p", "A_n", "B_n")
+
+    # DataFrame
+    tpm = np.array([[0, 0, 0, 1], [0, 0, 1, 0], [0, 1, 0, 0], [1, 0, 0, 0]])
+    base = [2, 2]
+    nam = ["A", "B"]
+
+    states_node=[list(range(b)) for b in base]
+
+    sates_all_nodes=[list(x[::-1])  for x in  list(product(*states_node[::-1])) ]
+    sates_all_nodes=np.transpose(sates_all_nodes).tolist()
+    index = pd.MultiIndex.from_arrays(sates_all_nodes, names=nam)
+    columns = pd.MultiIndex.from_arrays(sates_all_nodes, names=nam)
+
+    df = pd.DataFrame(tpm,columns=columns, index=index)
+    test = TPM(df, ["C", "D"]) #TODO should only need one argument, test ensures second arg doesn't influence naming
+    
+    assert test.tpm.shape == (2, 2, 2, 2)
+    assert test.tpm.dims == ("A_p", "B_p", "A_n", "B_n")
+
+    # SbN Network, check both init paths give same results
+    can = np.array([[0, 0, 0, 0, 1, 0, 0, 0],
+                    [0, 0, 0, 0, 1, 0, 0, 0],
+                    [1, 0, 0, 0, 0, 0, 0, 0],
+                    [1, 0, 0, 0, 0, 0, 0, 0],
+                    [0, 0, 0, 0, 0, 1, 0, 0],
+                    [0, 0, 0, 0, 0, 0, 0, 1],
+                    [0, 1, 0, 0, 0, 0, 0, 0],
+                    [0, 0, 0, 1, 0, 0, 0, 0]])
+
+    can_sbn = np.array([[0, 0, 1],
+                        [0, 0, 1],
+                        [0, 0, 0],
+                        [0, 0, 0],
+                        [1, 0, 1],
+                        [1, 1, 1],
+                        [1, 0, 0],
+                        [1, 1, 0]])
+
+    can_obj = SbN(can, p_nodes=["A", "B", "C"], p_states=[2, 2, 2])
+
+    can_sbn_obj = SbN(can_sbn, p_nodes=["A", "B", "C"], p_states=[2, 2, 2], n_nodes=["A", "B", "C"])
+    
+    assert np.array_equal(can_obj.tpm.data, can_sbn_obj.tpm.data)
+    assert can_obj.p_nodes == can_sbn_obj.p_nodes
+    assert can_obj.n_nodes == can_sbn_obj.n_nodes
+
+    # Test only data given
+    can_obj = SbN(can)
+    can_sbn_obj = SbN(can_sbn)
+
+    assert np.array_equal(can_obj.tpm.data, can_sbn_obj.tpm.data)
+    assert can_obj.p_nodes == can_sbn_obj.p_nodes
+    assert can_obj.n_nodes == can_sbn_obj.n_nodes
+
+def test_marginalize_out_obj():
+    p53 = np.array([[0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.],
+       [0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
+       [0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0.],
+       [0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.],
+       [0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.],
+       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1.],
+       [0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.],
+       [1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
+       [0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
+       [0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.],
+       [0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.],
+       [0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0.]])
+
+    p53_obj = TPM(p53, p_nodes=["P", "Mc", "Mn"], p_states=[3, 2, 2])
+
+    marginalized = p53_obj.marginalize_out((1, )).data
+
+    solution = np.array([[0, 0, 0, 0, 0, 1],
+                         [0, 0, 0.5, 0, 0, 0.5],
+                         [0, 0, 0.5, 0, 0, 0.5],
+                         [0, 0, 0, 1, 0, 0],
+                         [0.5, 0, 0, 0.5, 0, 0],
+                         [0.5, 0, 0, 0.5, 0, 0]]).reshape(3, 1, 2, 3, 1, 2, order="F")
+
+    assert np.array_equal(marginalized, solution)
+
+def test_marginalize_out_sbn(s):
+    sbn = SbN(s.tpm, p_nodes=["A", "B", "C"], n_nodes=["A", "B", "C"], p_states=[2,2,2], n_states=[2,2,2])
+
+    marginalized_distribution = sbn.marginalize_out((0, )).data
+
+    answer = np.array([
+        [[[0.0, 0.0, 0.5],
+          [1.0, 1.0, 0.5]],
+         [[1.0, 0.0, 0.5],
+          [1.0, 1.0, 0.5]]],
+    ])
+    assert np.array_equal(marginalized_distribution, answer)
+
+    marginalized_distribution = sbn.marginalize_out((0, 1)).data
+
+    answer = np.array([
+        [[[0.5, 0.0, 0.5],
+          [1.0, 1.0, 0.5]]],
+    ])
+
+    assert np.array_equal(marginalized_distribution, answer)
+
+def test_infer_cm_obj():
+    # Check infer_cm functions on both TPM (state-by-state)
+    # and SbN (state-by-node) objects 
+    p53 = np.array([[0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.],
+       [0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
+       [0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0.],
+       [0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.],
+       [0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.],
+       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1.],
+       [0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.],
+       [1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
+       [0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
+       [0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.],
+       [0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.],
+       [0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0.]])
+
+    test_p53 = TPM(p53, p_nodes=["P", "Mc", "Mn"], n_nodes=["P", "Mc", "Mn"], p_states=[3,2,2], n_states=[3,2,2])
+
+    solution = np.array([[0, 1, 1],
+                         [0, 0, 1],
+                         [1, 0, 0]])
+
+    assert np.array_equal(test_p53.infer_cm(), solution)
+
+def test_infer_cm_sbn():
+    # SbN Binary network: Node A (c)opies C, node B takes A (a)nd C, and C (n)ots B
+    can = np.array([[0, 0, 0, 0, 1, 0, 0, 0],
+                    [0, 0, 0, 0, 1, 0, 0, 0],
+                    [1, 0, 0, 0, 0, 0, 0, 0],
+                    [1, 0, 0, 0, 0, 0, 0, 0],
+                    [0, 0, 0, 0, 0, 1, 0, 0],
+                    [0, 0, 0, 0, 0, 0, 0, 1],
+                    [0, 1, 0, 0, 0, 0, 0, 0],
+                    [0, 0, 0, 1, 0, 0, 0, 0]])
+
+    can_obj = SbN(can, p_nodes=["A", "B", "C"], n_nodes=["A", "B", "C"],
+                        p_states=[2,2,2], n_states=[2,2,2])
+    solution = np.array([[0, 1, 0],
+                         [0, 0, 1],
+                         [1, 1, 0]])
+
+    assert np.array_equal(can_obj.infer_cm(), solution)
+
+def test_condition_obj():
+    p53 = np.array([[0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.],
+       [0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
+       [0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0.],
+       [0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.],
+       [0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.],
+       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1.],
+       [0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.],
+       [1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
+       [0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
+       [0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.],
+       [0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.],
+       [0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0.]])
+
+    test_p53 = TPM(p53, p_nodes=["P", "Mc", "Mn"], n_nodes=["P", "Mc", "Mn"], p_states=[3,2,2], n_states=[3,2,2])
+
+    conditioned = test_p53.condition((1, ), (1, 0, 1))
+
+    solution = np.array([[0, 0, 0, 0, 0, 1],
+                         [0, 0, 1, 0, 0, 0],
+                         [0, 0, 1, 0, 0, 0],
+                         [0, 0, 0, 1, 0, 0],
+                         [1, 0, 0, 0, 0, 0],
+                         [1, 0, 0, 0, 0, 0]]).reshape(3, 1, 2, 3, 1, 2, order="F")
+    
+    # Dimensions should remain the same, but Mc has been conditioned on making its
+    # dimension one unit in size
+    assert conditioned.shape == (3, 1, 2, 3, 1, 2)
+
+    assert np.array_equal(conditioned.data, solution)
+
+def test_condition_sbn():
+    # SbN form testing
+    can_sbn = np.array([[0, 0, 1],
+                        [0, 0, 1],
+                        [0, 0, 0],
+                        [0, 0, 0],
+                        [1, 0, 1],
+                        [1, 1, 1],
+                        [1, 0, 0],
+                        [1, 1, 0]])
+    
+    can_obj = SbN(can_sbn, p_nodes=["A", "B", "C"], n_nodes=["A", "B", "C"], 
+                   p_states=[2,2,2])
+    
+    conditioned = can_obj.condition((1,), (1, 0, 1))
+
+    # At present we don't drop unneeded columns, though could if needed the space
+    solution = np.array([[0, 0, 1],
+                         [0, 0, 1],
+                         [1, 0, 1],
+                         [1, 1, 1]]).reshape(2, 1, 2, 3, order="F")
+
+    assert conditioned.shape == (2, 1, 2, 3)
+
+    assert np.array_equal(conditioned.data, solution)
+    
\ No newline at end of file