Skip to content

Commit 53ef426

Browse files
committed
Adding shift parameter to fix #72
1 parent f5da4ab commit 53ef426

File tree

5 files changed

+75
-3
lines changed

5 files changed

+75
-3
lines changed

README.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,10 @@ post, it can be solved by explicitly install the `matplotlib` dependency `dvipng
5858

5959
### Update History
6060

61+
##### 0.30.0
62+
* Bug fix for specifying numeric `loc` to `legend_kwargs`
63+
* Added `shift_params` when adding chains.
64+
6165
##### 0.29.1
6266
* Potential bug fix for `log_space` feature.
6367

chainconsumer/chain.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ def __init__(
4444
marker_size=None,
4545
marker_alpha=None,
4646
zorder=None,
47+
shift_params=None,
4748
):
4849
self.chain = chain
4950
self.parameters = parameters
@@ -63,6 +64,15 @@ def __init__(
6364
for i, p in enumerate(parameters):
6465
self.posterior_max_params[p] = chain[self.posterior_max_index, i]
6566

67+
self.shift_params = shift_params
68+
if shift_params is not None:
69+
for key in shift_params.keys():
70+
try:
71+
index = self.parameters.index(key)
72+
avg = np.average(chain[:, index], weights=weights)
73+
chain[:, index] += shift_params[key] - avg
74+
except ValueError:
75+
continue
6676
self.weights = weights
6777
self.posterior = posterior
6878
self.walkers = walkers

chainconsumer/chainconsumer.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ class ChainConsumer(object):
1919
2020
"""
2121

22-
__version__ = "0.29.1"
22+
__version__ = "0.30.0"
2323

2424
def __init__(self):
2525
logging.basicConfig(level=logging.INFO)
@@ -86,6 +86,7 @@ def add_chain(
8686
cmap=None,
8787
num_cloud=None,
8888
zorder=None,
89+
shift_params=None,
8990
):
9091
r""" Add a chain to the consumer.
9192
@@ -184,7 +185,9 @@ def add_chain(
184185
to colour scatter. Defaults to 15k per chain.
185186
zorder : int, optional
186187
The zorder to pass to `matplotlib` when plotting to determine visual order in the plot.
187-
188+
shift_params : dict|list, optional
189+
Shifts the parameters specify to the numeric values. Useful to shift contours to the same location to perform blinded
190+
uncertainty comparisons.
188191
Returns
189192
-------
190193
ChainConsumer
@@ -251,6 +254,13 @@ def add_chain(
251254
if p not in self._all_parameters:
252255
self._all_parameters.append(p)
253256

257+
if shift_params is not None:
258+
if isinstance(shift_params, list):
259+
shift_params = dict([(p, s) for p, s in zip(parameters, shift_params)])
260+
for key in shift_params.keys():
261+
if key not in parameters:
262+
self._logger.warning("Warning, shift parameter %s is not in list of parameters %s" % (key, parameters))
263+
254264
# Sorry, no KDE for you on a grid.
255265
if grid:
256266
kde = None
@@ -290,6 +300,7 @@ def add_chain(
290300
cmap=cmap,
291301
num_cloud=num_cloud,
292302
zorder=zorder,
303+
shift_params=shift_params,
293304
)
294305
self.chains.append(c)
295306
self._init_params()
@@ -443,6 +454,7 @@ def configure(
443454
watermark_text_kwargs=None,
444455
summary_area=0.6827,
445456
zorder=None,
457+
stack=False,
446458
): # pragma: no cover
447459
r""" Configure the general plotting parameters common across the bar
448460
and contour plots.

chainconsumer/plotter.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -273,7 +273,7 @@ def plot(
273273
text.set_color(c)
274274
if not outside:
275275
loc = legend_kwargs.get("loc") or ""
276-
if "right" in loc.lower():
276+
if isinstance(loc, str) and "right" in loc.lower():
277277
vp = leg._legend_box._children[-1]._children[0]
278278
vp.align = "right"
279279

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
# -*- coding: utf-8 -*-
2+
"""
3+
==============
4+
Shifting Plots
5+
==============
6+
7+
Shift all your plots to the same location for blind uncertainty comparison.
8+
9+
10+
Plots will shift to the location you tell them to, in the same format as a truth dictionary.
11+
So you can use truth dict for both! Takes a list or a dict as input for convenience.
12+
13+
"""
14+
15+
import numpy as np
16+
from numpy.random import multivariate_normal
17+
from chainconsumer import ChainConsumer
18+
19+
np.random.seed(0)
20+
data1 = multivariate_normal([1, 0], [[3, 2], [2, 3]], size=300000)
21+
data2 = multivariate_normal([0, 0.5], [[1, -0.7], [-0.7, 1]], size=300000)
22+
data3 = multivariate_normal([2, -1], [[0.5, 0], [0, 0.5]], size=300000)
23+
24+
###############################################################################
25+
# And this is how easy it is to shift them:
26+
27+
truth = {"$x$": 1, "$y$": 0}
28+
c = ChainConsumer()
29+
c.add_chain(data1, parameters=["$x$", "$y$"], name="Chain A", shift_params=truth)
30+
c.add_chain(data2, name="Chain B", shift_params=truth)
31+
c.add_chain(data3, name="Chain C", shift_params=truth)
32+
fig = c.plotter.plot(truth=truth)
33+
34+
fig.set_size_inches(2.5 + fig.get_size_inches()) # Resize fig for doco. You don't need this.
35+
36+
###############################################################################
37+
# Here's without the shift:
38+
39+
truth = {"$x$": 1, "$y$": 0}
40+
c = ChainConsumer()
41+
c.add_chain(data1, parameters=["$x$", "$y$"], name="Chain A")
42+
c.add_chain(data2, name="Chain B")
43+
c.add_chain(data3, name="Chain C")
44+
fig = c.plotter.plot(truth=truth)
45+
46+
fig.set_size_inches(2.5 + fig.get_size_inches()) # Resize fig for doco. You don't need this.

0 commit comments

Comments
 (0)