@@ -16,7 +16,7 @@ class ChainConsumer(object):
1616 """ A class for consuming chains produced by an MCMC walk
1717
1818 """
19- __version__ = "0.11.3 "
19+ __version__ = "0.12.0 "
2020
2121 def __init__ (self ):
2222 logging .basicConfig ()
@@ -30,6 +30,7 @@ def __init__(self):
3030 self .names = []
3131 self .parameters = []
3232 self .all_parameters = []
33+ self .grids = []
3334 self .default_parameters = None
3435 self ._configured_bar = False
3536 self ._configured_contour = False
@@ -45,7 +46,7 @@ def __init__(self):
4546 "cumulative" : self ._get_parameter_summary_cumulative
4647 }
4748
48- def add_chain (self , chain , parameters = None , name = None , weights = None , posterior = None , walkers = None ):
49+ def add_chain (self , chain , parameters = None , name = None , weights = None , posterior = None , walkers = None , grid = False ):
4950 """ Add a chain to the consumer.
5051
5152 Parameters
@@ -70,6 +71,10 @@ def add_chain(self, chain, parameters=None, name=None, weights=None, posterior=N
7071 How many walkers went into creating the chain. Each walker should
7172 contribute the same number of steps, and should appear in contiguous
7273 blocks in the final chain.
74+ grid : boolean, optional
75+ Whether the input is a flattened chain from a grid search instead of a Monte-Carlo
76+ chains. Note that when this is set, `walkers` should not be set, and `weights` should
77+ be set to the posterior evaluation for the grid point.
7378
7479 Returns
7580 -------
@@ -104,6 +109,11 @@ def add_chain(self, chain, parameters=None, name=None, weights=None, posterior=N
104109 if self .default_parameters is None and parameters is not None :
105110 self .default_parameters = parameters
106111
112+ self .grids .append (grid )
113+ if grid :
114+ assert walkers is None , "If grid is set, walkers should not be"
115+ assert weights is not None , "If grid is set, you need to supply weights"
116+
107117 if parameters is None :
108118 if self .default_parameters is not None :
109119 assert chain .shape [1 ] == len (self .default_parameters ), \
@@ -399,11 +409,11 @@ def get_summary(self, squeeze=True):
399409 One entry per chain, parameter bounds stored in dictionary with parameter as key
400410 """
401411 results = []
402- for ind , (chain , parameters , weights ) in enumerate (zip (self .chains ,
403- self .parameters , self .weights )):
412+ for ind , (chain , parameters , weights , g ) in enumerate (zip (self .chains ,
413+ self .parameters , self .weights , self . grids )):
404414 res = {}
405415 for i , p in enumerate (parameters ):
406- summary = self ._get_parameter_summary (chain [:, i ], weights , p , ind )
416+ summary = self ._get_parameter_summary (chain [:, i ], weights , p , ind , grid = g )
407417 res [p ] = summary
408418 results .append (res )
409419 if squeeze and len (results ) == 1 :
@@ -742,13 +752,13 @@ def plot(self, figsize="GROW", parameters=None, extents=None, filename=None,
742752 do_flip = (flip and i == len (params1 ) - 1 )
743753 if plot_hists and i == j :
744754 max_val = None
745- for chain , weights , parameters , colour , bins , fit , ls , bs , lw in \
755+ for chain , weights , parameters , colour , bins , fit , ls , bs , lw , g in \
746756 zip (self .chains , self .weights , self .parameters , colours ,
747- num_bins , fit_values , linestyles , bar_shades , linewidths ):
757+ num_bins , fit_values , linestyles , bar_shades , linewidths , self . grids ):
748758 if p1 not in parameters :
749759 continue
750760 index = parameters .index (p1 )
751- m = self ._plot_bars (ax , p1 , chain [:, index ], weights , colour , ls , bs , lw , bins = bins ,
761+ m = self ._plot_bars (ax , p1 , chain [:, index ], weights , colour , ls , bs , lw , g , bins = bins ,
752762 fit_values = fit [p1 ], flip = do_flip , summary = summary ,
753763 truth = truth , extents = extents [p1 ])
754764 if max_val is None or m > max_val :
@@ -759,15 +769,15 @@ def plot(self, figsize="GROW", parameters=None, extents=None, filename=None,
759769 ax .set_ylim (0 , 1.1 * max_val )
760770
761771 else :
762- for chain , parameters , bins , colour , ls , s , sa , lw , fit , weights in \
772+ for chain , parameters , bins , colour , ls , s , sa , lw , fit , weights , g in \
763773 zip (self .chains , self .parameters , num_bins , colours , linestyles , shades ,
764- shade_alphas , linewidths , fit_values , self .weights ):
774+ shade_alphas , linewidths , fit_values , self .weights , self . grids ):
765775 if p1 not in parameters or p2 not in parameters :
766776 continue
767777 i1 = parameters .index (p1 )
768778 i2 = parameters .index (p2 )
769779 self ._plot_contour (ax , chain [:, i2 ], chain [:, i1 ], weights , p1 , p2 , colour , ls ,
770- s , sa , lw , bins = bins , truth = truth )
780+ s , sa , lw , g , bins = bins , truth = truth )
771781
772782 if self .names is not None and legend :
773783 ax = axes [0 , - 1 ]
@@ -1084,14 +1094,16 @@ def _plot_walk(self, ax, parameter, data, truth=None, extents=None,
10841094 ax .axhline (truth , ** self .parameters_truth )
10851095
10861096 def _plot_bars (self , ax , parameter , chain_row , weights , colour , linestyle , bar_shade ,
1087- linewidth , bins = 25 , flip = False , summary = False , fit_values = None ,
1097+ linewidth , grid , bins = 25 , flip = False , summary = False , fit_values = None ,
10881098 truth = None , extents = None ): # pragma: no cover
10891099
10901100 kde = self .parameters_general ["kde" ]
10911101 smooth = self .parameters_general ["smooth" ]
10921102 bins , smooth = self ._get_smoothed_bins (smooth , bins )
1093-
1094- bins = np .linspace (extents [0 ], extents [1 ], bins )
1103+ if grid :
1104+ bins = self ._get_grid_bins (chain_row )
1105+ else :
1106+ bins = np .linspace (extents [0 ], extents [1 ], bins )
10951107 hist , edges = np .histogram (chain_row , bins = bins , normed = True , weights = weights )
10961108 edge_center = 0.5 * (edges [:- 1 ] + edges [1 :])
10971109 if smooth :
@@ -1149,17 +1161,22 @@ def _plot_bars(self, ax, parameter, chain_row, weights, colour, linestyle, bar_s
11491161 return hist .max ()
11501162
11511163 def _plot_contour (self , ax , x , y , w , px , py , colour , linestyle , shade ,
1152- shade_alpha , linewidth , bins = 25 , truth = None ): # pragma: no cover
1164+ shade_alpha , linewidth , grid , bins = 25 , truth = None ): # pragma: no cover
11531165
11541166 levels = 1.0 - np .exp (- 0.5 * self .parameters_contour ["sigmas" ] ** 2 )
11551167 smooth = self .parameters_general ["smooth" ]
1156- bins , smooth = self ._get_smoothed_bins (smooth , bins , marginalsied = False )
1168+ if grid :
1169+ binsx = self ._get_grid_bins (x )
1170+ binsy = self ._get_grid_bins (y )
1171+ hist , x_bins , y_bins = np .histogram2d (x , y , bins = [binsx , binsy ], weights = w )
1172+ else :
1173+ bins , smooth = self ._get_smoothed_bins (smooth , bins , marginalsied = False )
1174+ hist , x_bins , y_bins = np .histogram2d (x , y , bins = bins , weights = w )
11571175
11581176 colours = self ._scale_colours (colour , len (levels ))
11591177 colours2 = [self ._scale_colour (colours [0 ], 0.7 )] + \
11601178 [self ._scale_colour (c , 0.8 ) for c in colours [:- 1 ]]
11611179
1162- hist , x_bins , y_bins = np .histogram2d (x , y , bins = bins , weights = w )
11631180 x_centers = 0.5 * (x_bins [:- 1 ] + x_bins [1 :])
11641181 y_centers = 0.5 * (y_bins [:- 1 ] + y_bins [1 :])
11651182 if smooth :
@@ -1224,16 +1241,18 @@ def _get_figure(self, all_parameters, flip, figsize=(5, 5),
12241241 if external_extents is not None and p in external_extents :
12251242 min_val , max_val = external_extents [p ]
12261243 else :
1227- for chain , parameters in zip (self .chains , self .parameters ):
1244+ for i , ( chain , parameters ) in enumerate ( zip (self .chains , self .parameters ) ):
12281245 if p not in parameters :
12291246 continue
12301247 index = parameters .index (p )
1231- # min_val = chain[:, index].min()
1232- # max_val = chain[:, index].max()
1233- mean = np .mean (chain [:, index ])
1234- std = np .std (chain [:, index ])
1235- min_prop = mean - sigma_extent * std
1236- max_prop = mean + sigma_extent * std
1248+ if self .grids [i ]:
1249+ min_prop = chain [:, index ].min ()
1250+ max_prop = chain [:, index ].max ()
1251+ else :
1252+ mean = np .mean (chain [:, index ])
1253+ std = np .std (chain [:, index ])
1254+ min_prop = mean - sigma_extent * std
1255+ max_prop = mean + sigma_extent * std
12371256 if min_val is None or min_prop < min_val :
12381257 min_val = min_prop
12391258 if max_val is None or max_prop > max_val :
@@ -1333,10 +1352,19 @@ def _get_smoothed_bins(self, smooth, bins, marginalsied=True):
13331352 else :
13341353 return ((3 if marginalsied else 2 ) * smooth * bins ), smooth
13351354
1336- def _get_smoothed_histogram (self , data , weights , chain_index ):
1355+ def _get_grid_bins (self , data ):
1356+ bin_c = sorted (np .unique (data ))
1357+ delta = 0.5 * (bin_c [1 ] - bin_c [0 ])
1358+ bins = np .concatenate ((bin_c - delta , [bin_c [- 1 ] + delta ]))
1359+ return bins
1360+
1361+ def _get_smoothed_histogram (self , data , weights , chain_index , grid ):
13371362 smooth = self .parameters_general ["smooth" ]
1338- bins = self .parameters_general ['bins' ][chain_index ]
1339- bins , smooth = self ._get_smoothed_bins (smooth , bins )
1363+ if grid :
1364+ bins = self ._get_grid_bins (data )
1365+ else :
1366+ bins = self .parameters_general ['bins' ][chain_index ]
1367+ bins , smooth = self ._get_smoothed_bins (smooth , bins )
13401368 hist , edges = np .histogram (data , bins = bins , normed = True , weights = weights )
13411369 edge_centers = 0.5 * (edges [1 :] + edges [:- 1 ])
13421370 xs = np .linspace (edge_centers [0 ], edge_centers [- 1 ], 10000 )
@@ -1363,21 +1391,21 @@ def _get_parameter_summary(self, data, weights, parameter, chain_index, **kwargs
13631391 method = self .summaries [self .parameters_general ["statistics" ][chain_index ]]
13641392 return method (data , weights , parameter , chain_index , ** kwargs )
13651393
1366- def _get_parameter_summary_mean (self , data , weights , parameter , chain_index , desired_area = 0.6827 ):
1367- xs , ys , cs = self ._get_smoothed_histogram (data , weights , chain_index )
1394+ def _get_parameter_summary_mean (self , data , weights , parameter , chain_index , desired_area = 0.6827 , grid = False ):
1395+ xs , ys , cs = self ._get_smoothed_histogram (data , weights , chain_index , grid )
13681396 vals = [0.5 - desired_area / 2 , 0.5 , 0.5 + desired_area / 2 ]
13691397 bounds = interp1d (cs , xs )(vals )
13701398 bounds [1 ] = 0.5 * (bounds [0 ] + bounds [2 ])
13711399 return bounds
13721400
1373- def _get_parameter_summary_cumulative (self , data , weights , parameter , chain_index , desired_area = 0.6827 ):
1374- xs , ys , cs = self ._get_smoothed_histogram (data , weights , chain_index )
1401+ def _get_parameter_summary_cumulative (self , data , weights , parameter , chain_index , desired_area = 0.6827 , grid = False ):
1402+ xs , ys , cs = self ._get_smoothed_histogram (data , weights , chain_index , grid )
13751403 vals = [0.5 - desired_area / 2 , 0.5 , 0.5 + desired_area / 2 ]
13761404 bounds = interp1d (cs , xs )(vals )
13771405 return bounds
13781406
1379- def _get_parameter_summary_max (self , data , weights , parameter , chain_index , desired_area = 0.6827 ):
1380- xs , ys , cs = self ._get_smoothed_histogram (data , weights , chain_index )
1407+ def _get_parameter_summary_max (self , data , weights , parameter , chain_index , desired_area = 0.6827 , grid = False ):
1408+ xs , ys , cs = self ._get_smoothed_histogram (data , weights , chain_index , grid )
13811409 startIndex = ys .argmax ()
13821410 maxVal = ys [startIndex ]
13831411 minVal = 0
0 commit comments