Skip to content

Commit 8e85db5

Browse files
committed
Fixes #13
Adding ability to remove chains
1 parent 812190b commit 8e85db5

File tree

4 files changed

+141
-2
lines changed

4 files changed

+141
-2
lines changed

README.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,8 @@ post, it can be solved by explicitly install the `matplotlib` dependency `dvipng
3131

3232
### Update History
3333

34+
##### 0.15.4
35+
* Adding ability to remove chains.
3436

3537
##### 0.15.3
3638
* Adding ability to plot the walks of multiple chains together.

chainconsumer/chain.py

Lines changed: 59 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ class ChainConsumer(object):
1616
""" A class for consuming chains produced by an MCMC walk
1717
1818
"""
19-
__version__ = "0.15.3"
19+
__version__ = "0.15.4"
2020

2121
def __init__(self):
2222
logging.basicConfig()
@@ -163,6 +163,59 @@ def add_chain(self, chain, parameters=None, name=None, weights=None, posterior=N
163163
self._init_params()
164164
return self
165165

166+
def remove_chain(self, chain=-1):
167+
"""
168+
Removes a chain from ChainConsumer. Calling this will require any configurations set to be redone!
169+
170+
Parameters
171+
----------
172+
chain : int|str, list[str]
173+
The chain(s) to remove. You can pass in either the chain index, or the chain name, to remove it.
174+
By default removes the last chain added.
175+
176+
Returns
177+
-------
178+
ChainConsumer
179+
Itself, to allow chaining calls.
180+
"""
181+
if isinstance(chain, str) or isinstance(chain, int):
182+
chain = [chain]
183+
elif isinstance(chain, list):
184+
for c in chain:
185+
assert isinstance(c, str), "If you specify a list, " \
186+
"you must specify chain names, not indexes." \
187+
"This is to avoid confusion when specifying," \
188+
"for example, [0,0]. As this might be an error," \
189+
"or a request to remove the first two chains."
190+
for c in chain:
191+
index = self._get_chain(c)
192+
parameters = self._parameters[index]
193+
194+
del self._chains[index]
195+
del self._names[index]
196+
del self._weights[index]
197+
del self._posteriors[index]
198+
del self._parameters[index]
199+
del self._grids[index]
200+
del self._num_free[index]
201+
del self._num_data[index]
202+
203+
# Recompute all_parameters
204+
for p in parameters:
205+
has = False
206+
for ps in self._parameters:
207+
if p in ps:
208+
has = True
209+
break
210+
if not has:
211+
i = self._all_parameters.index(p)
212+
del self._all_parameters[i]
213+
214+
# Need to reconfigure
215+
self._init_params()
216+
217+
return self
218+
166219
def configure(self, statistics="max", max_ticks=5, plot_hists=True, flip=True,
167220
serif=True, sigmas=None, summary=None, bins=None, rainbow=None,
168221
colors=None, linestyles=None, linewidths=None, kde=False, smooth=None,
@@ -991,7 +1044,11 @@ def plot_walks(self, parameters=None, truth=None, extents=None, display=False,
9911044
chains = [chains]
9921045
chains = [self._get_chain(c) for c in chains]
9931046

994-
all_parameters = list(set([p for i in chains for p in self._parameters[i]]))
1047+
all_parameters2 = [p for i in chains for p in self._parameters[i]]
1048+
all_parameters = []
1049+
for p in all_parameters2:
1050+
if p not in all_parameters:
1051+
all_parameters.append(p)
9951052

9961053
if parameters is None:
9971054
parameters = all_parameters

doc/chain_api.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ General Methods
1212
---------------
1313
* :func:`chainconsumer.ChainConsumer.add_chain` - Add a chain!
1414
* :func:`chainconsumer.ChainConsumer.divide_chain` - Split a chain into multiple chains to inspect each walk.
15+
* :func:`chainconsumer.ChainConsumer.remove_chain` - Remove a chain.
1516

1617
Plotting Methods
1718
----------------

test_chain.py

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -847,3 +847,82 @@ def test_dic_posterior_dependence(self):
847847
dic2 = 2 * np.mean(-2 * p2) + 2 * norm.logpdf(0, scale=2)
848848
assert np.isclose(bics[0], dic1 - dic2, atol=1e-3)
849849

850+
def test_remove_last_chain(self):
851+
tolerance = 5e-2
852+
consumer = ChainConsumer()
853+
consumer.add_chain(self.data)
854+
consumer.add_chain(self.data * 2)
855+
consumer.remove_chain()
856+
consumer.configure(bins=1.6)
857+
summary = consumer.get_summary()
858+
assert isinstance(summary, dict)
859+
actual = np.array(list(summary.values())[0])
860+
expected = np.array([3.5, 5.0, 6.5])
861+
diff = np.abs(expected - actual)
862+
assert np.all(diff < tolerance)
863+
864+
def test_remove_first_chain(self):
865+
tolerance = 5e-2
866+
consumer = ChainConsumer()
867+
consumer.add_chain(self.data * 2)
868+
consumer.add_chain(self.data)
869+
consumer.remove_chain(chain=0)
870+
consumer.configure(bins=1.6)
871+
summary = consumer.get_summary()
872+
assert isinstance(summary, dict)
873+
actual = np.array(list(summary.values())[0])
874+
expected = np.array([3.5, 5.0, 6.5])
875+
diff = np.abs(expected - actual)
876+
assert np.all(diff < tolerance)
877+
878+
def test_remove_chain_by_name(self):
879+
tolerance = 5e-2
880+
consumer = ChainConsumer()
881+
consumer.add_chain(self.data * 2, name="a")
882+
consumer.add_chain(self.data, name="b")
883+
consumer.remove_chain(chain="a")
884+
consumer.configure(bins=1.6)
885+
summary = consumer.get_summary()
886+
assert isinstance(summary, dict)
887+
actual = np.array(list(summary.values())[0])
888+
expected = np.array([3.5, 5.0, 6.5])
889+
diff = np.abs(expected - actual)
890+
assert np.all(diff < tolerance)
891+
892+
def test_remove_chain_recompute_params(self):
893+
tolerance = 5e-2
894+
consumer = ChainConsumer()
895+
consumer.add_chain(self.data * 2, parameters=["p1"], name="a")
896+
consumer.add_chain(self.data, parameters=["p2"], name="b")
897+
consumer.remove_chain(chain="a")
898+
consumer.configure(bins=1.6)
899+
summary = consumer.get_summary()
900+
assert isinstance(summary, dict)
901+
assert "p2" in summary
902+
assert "p1" not in summary
903+
actual = np.array(list(summary.values())[0])
904+
expected = np.array([3.5, 5.0, 6.5])
905+
diff = np.abs(expected - actual)
906+
assert np.all(diff < tolerance)
907+
908+
def test_remove_multiple_chains(self):
909+
tolerance = 5e-2
910+
consumer = ChainConsumer()
911+
consumer.add_chain(self.data * 2, parameters=["p1"], name="a")
912+
consumer.add_chain(self.data, parameters=["p2"], name="b")
913+
consumer.add_chain(self.data * 3, parameters=["p3"], name="c")
914+
consumer.remove_chain(chain=["a", "c"])
915+
consumer.configure(bins=1.6)
916+
summary = consumer.get_summary()
917+
assert isinstance(summary, dict)
918+
assert "p2" in summary
919+
assert "p1" not in summary
920+
assert "p3" not in summary
921+
actual = np.array(list(summary.values())[0])
922+
expected = np.array([3.5, 5.0, 6.5])
923+
diff = np.abs(expected - actual)
924+
assert np.all(diff < tolerance)
925+
926+
def test_remove_multiple_chains_fails(self):
927+
with pytest.raises(AssertionError):
928+
ChainConsumer().add_chain(self.data).remove_chain(chain=[0, 0])

0 commit comments

Comments
 (0)