Skip to content

Commit 58c4ff7

Browse files
committed
v0.12.0
1 parent 7c3ebfd commit 58c4ff7

File tree

6 files changed

+116
-39
lines changed

6 files changed

+116
-39
lines changed

README.md

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,11 +98,14 @@ features requests thought up.
9898

9999
### Update History
100100

101+
##### 0.12.0
102+
* Adding support for grid data.
103+
101104
##### 0.11.3
102105
* Fixing bug in Gelman-Rubin statistic
103106

104107
##### 0.11.2
105-
* Improving text labels again
108+
* Improving text labels again.
106109

107110
##### 0.11.1
108111
* Improving text labels for high value data.

chainconsumer/chain.py

Lines changed: 61 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ class ChainConsumer(object):
1616
""" A class for consuming chains produced by an MCMC walk
1717
1818
"""
19-
__version__ = "0.11.3"
19+
__version__ = "0.12.0"
2020

2121
def __init__(self):
2222
logging.basicConfig()
@@ -30,6 +30,7 @@ def __init__(self):
3030
self.names = []
3131
self.parameters = []
3232
self.all_parameters = []
33+
self.grids = []
3334
self.default_parameters = None
3435
self._configured_bar = False
3536
self._configured_contour = False
@@ -45,7 +46,7 @@ def __init__(self):
4546
"cumulative": self._get_parameter_summary_cumulative
4647
}
4748

48-
def add_chain(self, chain, parameters=None, name=None, weights=None, posterior=None, walkers=None):
49+
def add_chain(self, chain, parameters=None, name=None, weights=None, posterior=None, walkers=None, grid=False):
4950
""" Add a chain to the consumer.
5051
5152
Parameters
@@ -70,6 +71,10 @@ def add_chain(self, chain, parameters=None, name=None, weights=None, posterior=N
7071
How many walkers went into creating the chain. Each walker should
7172
contribute the same number of steps, and should appear in contiguous
7273
blocks in the final chain.
74+
grid : boolean, optional
75+
Whether the input is a flattened chain from a grid search instead of a Monte-Carlo
76+
chains. Note that when this is set, `walkers` should not be set, and `weights` should
77+
be set to the posterior evaluation for the grid point.
7378
7479
Returns
7580
-------
@@ -104,6 +109,11 @@ def add_chain(self, chain, parameters=None, name=None, weights=None, posterior=N
104109
if self.default_parameters is None and parameters is not None:
105110
self.default_parameters = parameters
106111

112+
self.grids.append(grid)
113+
if grid:
114+
assert walkers is None, "If grid is set, walkers should not be"
115+
assert weights is not None, "If grid is set, you need to supply weights"
116+
107117
if parameters is None:
108118
if self.default_parameters is not None:
109119
assert chain.shape[1] == len(self.default_parameters), \
@@ -399,11 +409,11 @@ def get_summary(self, squeeze=True):
399409
One entry per chain, parameter bounds stored in dictionary with parameter as key
400410
"""
401411
results = []
402-
for ind, (chain, parameters, weights) in enumerate(zip(self.chains,
403-
self.parameters, self.weights)):
412+
for ind, (chain, parameters, weights, g) in enumerate(zip(self.chains,
413+
self.parameters, self.weights, self.grids)):
404414
res = {}
405415
for i, p in enumerate(parameters):
406-
summary = self._get_parameter_summary(chain[:, i], weights, p, ind)
416+
summary = self._get_parameter_summary(chain[:, i], weights, p, ind, grid=g)
407417
res[p] = summary
408418
results.append(res)
409419
if squeeze and len(results) == 1:
@@ -742,13 +752,13 @@ def plot(self, figsize="GROW", parameters=None, extents=None, filename=None,
742752
do_flip = (flip and i == len(params1) - 1)
743753
if plot_hists and i == j:
744754
max_val = None
745-
for chain, weights, parameters, colour, bins, fit, ls, bs, lw in \
755+
for chain, weights, parameters, colour, bins, fit, ls, bs, lw, g in \
746756
zip(self.chains, self.weights, self.parameters, colours,
747-
num_bins, fit_values, linestyles, bar_shades, linewidths):
757+
num_bins, fit_values, linestyles, bar_shades, linewidths, self.grids):
748758
if p1 not in parameters:
749759
continue
750760
index = parameters.index(p1)
751-
m = self._plot_bars(ax, p1, chain[:, index], weights, colour, ls, bs, lw, bins=bins,
761+
m = self._plot_bars(ax, p1, chain[:, index], weights, colour, ls, bs, lw, g, bins=bins,
752762
fit_values=fit[p1], flip=do_flip, summary=summary,
753763
truth=truth, extents=extents[p1])
754764
if max_val is None or m > max_val:
@@ -759,15 +769,15 @@ def plot(self, figsize="GROW", parameters=None, extents=None, filename=None,
759769
ax.set_ylim(0, 1.1 * max_val)
760770

761771
else:
762-
for chain, parameters, bins, colour, ls, s, sa, lw, fit, weights in \
772+
for chain, parameters, bins, colour, ls, s, sa, lw, fit, weights, g in \
763773
zip(self.chains, self.parameters, num_bins, colours, linestyles, shades,
764-
shade_alphas, linewidths, fit_values, self.weights):
774+
shade_alphas, linewidths, fit_values, self.weights, self.grids):
765775
if p1 not in parameters or p2 not in parameters:
766776
continue
767777
i1 = parameters.index(p1)
768778
i2 = parameters.index(p2)
769779
self._plot_contour(ax, chain[:, i2], chain[:, i1], weights, p1, p2, colour, ls,
770-
s, sa, lw, bins=bins, truth=truth)
780+
s, sa, lw, g, bins=bins, truth=truth)
771781

772782
if self.names is not None and legend:
773783
ax = axes[0, -1]
@@ -1084,14 +1094,16 @@ def _plot_walk(self, ax, parameter, data, truth=None, extents=None,
10841094
ax.axhline(truth, **self.parameters_truth)
10851095

10861096
def _plot_bars(self, ax, parameter, chain_row, weights, colour, linestyle, bar_shade,
1087-
linewidth, bins=25, flip=False, summary=False, fit_values=None,
1097+
linewidth, grid, bins=25, flip=False, summary=False, fit_values=None,
10881098
truth=None, extents=None): # pragma: no cover
10891099

10901100
kde = self.parameters_general["kde"]
10911101
smooth = self.parameters_general["smooth"]
10921102
bins, smooth = self._get_smoothed_bins(smooth, bins)
1093-
1094-
bins = np.linspace(extents[0], extents[1], bins)
1103+
if grid:
1104+
bins = self._get_grid_bins(chain_row)
1105+
else:
1106+
bins = np.linspace(extents[0], extents[1], bins)
10951107
hist, edges = np.histogram(chain_row, bins=bins, normed=True, weights=weights)
10961108
edge_center = 0.5 * (edges[:-1] + edges[1:])
10971109
if smooth:
@@ -1149,17 +1161,22 @@ def _plot_bars(self, ax, parameter, chain_row, weights, colour, linestyle, bar_s
11491161
return hist.max()
11501162

11511163
def _plot_contour(self, ax, x, y, w, px, py, colour, linestyle, shade,
1152-
shade_alpha, linewidth, bins=25, truth=None): # pragma: no cover
1164+
shade_alpha, linewidth, grid, bins=25, truth=None): # pragma: no cover
11531165

11541166
levels = 1.0 - np.exp(-0.5 * self.parameters_contour["sigmas"] ** 2)
11551167
smooth = self.parameters_general["smooth"]
1156-
bins, smooth = self._get_smoothed_bins(smooth, bins, marginalsied=False)
1168+
if grid:
1169+
binsx = self._get_grid_bins(x)
1170+
binsy = self._get_grid_bins(y)
1171+
hist, x_bins, y_bins = np.histogram2d(x, y, bins=[binsx, binsy], weights=w)
1172+
else:
1173+
bins, smooth = self._get_smoothed_bins(smooth, bins, marginalsied=False)
1174+
hist, x_bins, y_bins = np.histogram2d(x, y, bins=bins, weights=w)
11571175

11581176
colours = self._scale_colours(colour, len(levels))
11591177
colours2 = [self._scale_colour(colours[0], 0.7)] + \
11601178
[self._scale_colour(c, 0.8) for c in colours[:-1]]
11611179

1162-
hist, x_bins, y_bins = np.histogram2d(x, y, bins=bins, weights=w)
11631180
x_centers = 0.5 * (x_bins[:-1] + x_bins[1:])
11641181
y_centers = 0.5 * (y_bins[:-1] + y_bins[1:])
11651182
if smooth:
@@ -1224,16 +1241,18 @@ def _get_figure(self, all_parameters, flip, figsize=(5, 5),
12241241
if external_extents is not None and p in external_extents:
12251242
min_val, max_val = external_extents[p]
12261243
else:
1227-
for chain, parameters in zip(self.chains, self.parameters):
1244+
for i, (chain, parameters) in enumerate(zip(self.chains, self.parameters)):
12281245
if p not in parameters:
12291246
continue
12301247
index = parameters.index(p)
1231-
# min_val = chain[:, index].min()
1232-
# max_val = chain[:, index].max()
1233-
mean = np.mean(chain[:, index])
1234-
std = np.std(chain[:, index])
1235-
min_prop = mean - sigma_extent * std
1236-
max_prop = mean + sigma_extent* std
1248+
if self.grids[i]:
1249+
min_prop = chain[:, index].min()
1250+
max_prop = chain[:, index].max()
1251+
else:
1252+
mean = np.mean(chain[:, index])
1253+
std = np.std(chain[:, index])
1254+
min_prop = mean - sigma_extent * std
1255+
max_prop = mean + sigma_extent * std
12371256
if min_val is None or min_prop < min_val:
12381257
min_val = min_prop
12391258
if max_val is None or max_prop > max_val:
@@ -1333,10 +1352,19 @@ def _get_smoothed_bins(self, smooth, bins, marginalsied=True):
13331352
else:
13341353
return ((3 if marginalsied else 2) * smooth * bins), smooth
13351354

1336-
def _get_smoothed_histogram(self, data, weights, chain_index):
1355+
def _get_grid_bins(self, data):
1356+
bin_c = sorted(np.unique(data))
1357+
delta = 0.5 * (bin_c[1] - bin_c[0])
1358+
bins = np.concatenate((bin_c - delta, [bin_c[-1] + delta]))
1359+
return bins
1360+
1361+
def _get_smoothed_histogram(self, data, weights, chain_index, grid):
13371362
smooth = self.parameters_general["smooth"]
1338-
bins = self.parameters_general['bins'][chain_index]
1339-
bins, smooth = self._get_smoothed_bins(smooth, bins)
1363+
if grid:
1364+
bins = self._get_grid_bins(data)
1365+
else:
1366+
bins = self.parameters_general['bins'][chain_index]
1367+
bins, smooth = self._get_smoothed_bins(smooth, bins)
13401368
hist, edges = np.histogram(data, bins=bins, normed=True, weights=weights)
13411369
edge_centers = 0.5 * (edges[1:] + edges[:-1])
13421370
xs = np.linspace(edge_centers[0], edge_centers[-1], 10000)
@@ -1363,21 +1391,21 @@ def _get_parameter_summary(self, data, weights, parameter, chain_index, **kwargs
13631391
method = self.summaries[self.parameters_general["statistics"][chain_index]]
13641392
return method(data, weights, parameter, chain_index, **kwargs)
13651393

1366-
def _get_parameter_summary_mean(self, data, weights, parameter, chain_index, desired_area=0.6827):
1367-
xs, ys, cs = self._get_smoothed_histogram(data, weights, chain_index)
1394+
def _get_parameter_summary_mean(self, data, weights, parameter, chain_index, desired_area=0.6827, grid=False):
1395+
xs, ys, cs = self._get_smoothed_histogram(data, weights, chain_index, grid)
13681396
vals = [0.5 - desired_area / 2, 0.5, 0.5 + desired_area / 2]
13691397
bounds = interp1d(cs, xs)(vals)
13701398
bounds[1] = 0.5 * (bounds[0] + bounds[2])
13711399
return bounds
13721400

1373-
def _get_parameter_summary_cumulative(self, data, weights, parameter, chain_index, desired_area=0.6827):
1374-
xs, ys, cs = self._get_smoothed_histogram(data, weights, chain_index)
1401+
def _get_parameter_summary_cumulative(self, data, weights, parameter, chain_index, desired_area=0.6827, grid=False):
1402+
xs, ys, cs = self._get_smoothed_histogram(data, weights, chain_index, grid)
13751403
vals = [0.5 - desired_area / 2, 0.5, 0.5 + desired_area / 2]
13761404
bounds = interp1d(cs, xs)(vals)
13771405
return bounds
13781406

1379-
def _get_parameter_summary_max(self, data, weights, parameter, chain_index, desired_area=0.6827):
1380-
xs, ys, cs = self._get_smoothed_histogram(data, weights, chain_index)
1407+
def _get_parameter_summary_max(self, data, weights, parameter, chain_index, desired_area=0.6827, grid=False):
1408+
xs, ys, cs = self._get_smoothed_histogram(data, weights, chain_index, grid)
13811409
startIndex = ys.argmax()
13821410
maxVal = ys[startIndex]
13831411
minVal = 0

doc/usage.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ The process of using ChainConsumer should be straightforward:
1313
1. Create an instance of ChainConsumer.
1414
2. Add your chains to this instance.
1515
3. Run convergence diagnostics, if desired.
16-
4. Update the configurations if needed.
16+
4. Update the configurations if needed (make sure you do this *after* loading in the data).
1717
5. Plot.
1818

1919
The main page and the examples page has code demonstrating these,

examples/Basics/plot_convergence.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,8 @@
2222
# of the chain isn't agreeing with anything else!
2323
data_good = normal(size=100000)
2424
data_bad = data_good.copy()
25-
data_bad[98000:] += 2.0
26-
data_bad[:1000] *= 2.0
25+
data_bad += np.linspace(-0.5, 0.5, 100000)
26+
data_bad[98000:] += 2
2727

2828
# Lets load it into ChainConsumer, and pretend 10 walks went into making the chain
2929
c = ChainConsumer()

examples/Basics/plot_grid.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
"""
2+
==========
3+
Grid Data!
4+
==========
5+
6+
If you don't have Monte Carlo chains, and have grid evaluations instead, that's fine too!
7+
8+
Just flatten your grid, set the weights to the grid evaluation, and set the grid flag.
9+
10+
Here is a nice diamond that you get from modifying a simple multivariate normal distribution.
11+
12+
"""
13+
import numpy as np
14+
from numpy.random import normal, multivariate_normal
15+
from chainconsumer import ChainConsumer
16+
17+
18+
if __name__ == "__main__":
19+
xx, yy = np.meshgrid(np.linspace(-3, 3, 100), np.linspace(-7, 7, 100))
20+
xs, ys = xx.flatten(), yy.flatten()
21+
data = np.vstack((xs, ys)).T
22+
pdf = (1 / (2 * np.pi)) * np.exp(-0.5 * (xs * xs + ys * ys / 4 + np.abs(xs * ys)))
23+
24+
c = ChainConsumer()
25+
c = ChainConsumer()
26+
c.add_chain(data, parameters=["$x$", "$y$"], weights=pdf, grid=True)
27+
fig = c.plot()
28+
29+
fig.set_size_inches(3.5 + fig.get_size_inches()) # Resize fig for doco. You don't need this.

test_chain.py

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import tempfile
33

44
import numpy as np
5-
from scipy.stats import skewnorm, norm
5+
from scipy.stats import skewnorm, norm, multivariate_normal
66
import pytest
77

88
from chainconsumer import ChainConsumer
@@ -547,4 +547,21 @@ def test_geweke_default_failed(self):
547547
data2 = data.copy()
548548
data2[98000:, :] += 0.3
549549
consumer.add_chain(data2, walkers=20, name="c2")
550-
assert not consumer.diagnostic_geweke()
550+
assert not consumer.diagnostic_geweke()
551+
552+
def test_grid_data(self):
553+
xx, yy = np.meshgrid(np.linspace(-3, 3, 100), np.linspace(-5, 5, 100))
554+
xs, ys = xx.flatten(), yy.flatten()
555+
chain = np.vstack((xs, ys)).T
556+
pdf = (1 / (2 * np.pi)) * np.exp(-0.5 * (xs * xs + ys * ys / 4))
557+
c = ChainConsumer()
558+
c.add_chain(chain, parameters=['x','y'], weights=pdf, grid=True)
559+
c.configure_general(smooth=1)
560+
summary = c.get_summary()
561+
x_sum = summary['x']
562+
y_sum = summary['y']
563+
expected_x = np.array([-1.0, 0.0, 1.0])
564+
expected_y = np.array([-2.0, 0.0, 2.0])
565+
threshold = 0.05
566+
assert np.all(np.abs(expected_x - x_sum) < threshold)
567+
assert np.all(np.abs(expected_y - y_sum) < threshold)

0 commit comments

Comments
 (0)