
Commit d791555

Add functionality to combine posteriors via evidence sampling in extract_samples (#5168)
* add functionality to combine posteriors by sampling points based on evidence values
* add extra changes
* clean up line breaks and docs
* add handles for overlapping priors and mutually exclusive priors
* bugfix evidence writing with one file; alphabetize remapped params
* remove overlapping priors option
* skip saving remapped_params when combining posteriors via sampling
* fix to last commit
* don't avoid remapped_params; try saving only the first remapped_params as long as output names are the same
* brute-force combine remapped_params between multiple files
* update docs
* only save first file's remap metadata if not equal between files
1 parent c600c9f commit d791555
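
The heart of the change is evidence-weighted resampling: stack the samples from every input file, weight each file's points by its normalized evidence, and draw an output posterior the size of the smallest input. Below is a minimal standalone sketch of that idea; the helper name combine_posteriors and all of its inputs are hypothetical, not part of the commit.

import numpy
from scipy.special import logsumexp

def combine_posteriors(posteriors, logz, seed=None):
    """Resample stacked posterior samples, weighting each input array's
    points by its normalized evidence (a sketch, not the commit's code)."""
    rng = numpy.random.default_rng(seed)
    logz = numpy.asarray(logz, dtype=float)
    # each file gets total weight exp(logz_i - logsumexp(logz)), spread
    # evenly over its samples, so the full weight vector sums to 1
    file_wt = numpy.exp(logz - logsumexp(logz))
    weights = numpy.concatenate(
        [numpy.full(len(p), w / len(p)) for p, w in zip(posteriors, file_wt)])
    stacked = numpy.concatenate(posteriors)
    # the output has as many samples as the smallest input
    out_size = min(len(p) for p in posteriors)
    idx = rng.choice(len(stacked), size=out_size, replace=True, p=weights)
    return stacked[idx]

# e.g. two hypothetical inputs whose evidences differ by a factor of two:
combined = combine_posteriors(
    [numpy.random.normal(0, 1, 1000), numpy.random.normal(5, 1, 800)],
    logz=[numpy.log(2.0), numpy.log(1.0)], seed=42)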

File tree

1 file changed: +109, -9 lines


bin/inference/pycbc_inference_extract_samples

Lines changed: 109 additions & 9 deletions
@@ -43,6 +43,8 @@ import pycbc
 from pycbc.inference.io import (ResultsArgumentParser, results_from_cli,
                                 PosteriorFile, loadfile)
 from pycbc.inference.io.base_hdf import format_attr
+from scipy.special import logsumexp
+import warnings


 def isthesame(current_val, val):
@@ -83,6 +85,35 @@ parser.add_argument("--skip-groups", default=None, nargs="+",
                          "to write all groups if only one file is provided, "
                          "and all groups from the first file except "
                          "sampler_info if multiple files are provided.")
+parser.add_argument("--combine-via-sampling", action='store_true',
+                    default=False,
+                    help="Specify whether to combine the posteriors by "
+                         "sampling. By default, extract_samples will dump all "
+                         "samples from multiple inputs into one output file. "
+                         "If this option is specified, the samples will "
+                         "be randomly sampled with weighting based on the "
+                         "evidence in each input. For example, if the "
+                         "evidence of input 1 is twice that of input 2, "
+                         "the resulting posterior file with this option "
+                         "specified will have twice as many points from input "
+                         "1 as from input 2. The output will have the same "
+                         "number of samples as the smallest input file. "
+                         "An error is raised if this option is specified and "
+                         "any of the input files does not have a log_evidence "
+                         "attribute (e.g., the file used a sampler like emcee "
+                         "that does not report evidence). This option "
+                         "assumes the priors of all files are either identical "
+                         "(the default) or mutually exclusive; partially "
+                         "overlapping priors cannot be properly combined.")
+parser.add_argument("--mutually-exclusive-priors", action='store_true',
+                    default=False,
+                    help="If specifying --combine-via-sampling, specify "
+                         "whether to treat priors as mutually exclusive. By "
+                         "default, the provided input files are assumed to "
+                         "have identical priors, and their evidences will be "
+                         "averaged. If this option is specified, the priors "
+                         "are assumed to be non-overlapping across files, and "
+                         "their evidences will be summed together.")

 opts = parser.parse_args()
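
The 2:1 example in the --combine-via-sampling help text can be checked directly: evidences are stored as logs, so each file's total sampling weight is exp(logz_i - logsumexp(logz_list)). A quick check with made-up numbers:

import numpy
from scipy.special import logsumexp

# hypothetical log evidences with Z1 = 2 * Z2
logz_list = [numpy.log(2.0), numpy.log(1.0)]
logz_net = logsumexp(logz_list)  # log(Z1 + Z2) = log(3)
file_weights = numpy.exp(numpy.array(logz_list) - logz_net)
print(file_weights)  # [0.6667 0.3333] -> a 2:1 sampling ratio, as documented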

@@ -101,10 +132,52 @@ if len(opts.input_file) == 1:
 # convert samples to a dict in which the keys are the labels
 # also stack results if multiple files were provided
 if len(opts.input_file) > 1:
-    samples = {labels[p]: numpy.concatenate([s[p] for s in samples])
-               for p in params}
+    if opts.combine_via_sampling:
+        raw_samples = {labels[p]: numpy.concatenate([s[p] for s in samples])
+                       for p in params}
+        logz_list = []
+        dlogz_list = []
+        len_list = []
+        raw_samps_list = []
+        weights_list = []
+        for file in opts.input_file:
+            fp = loadfile(file, 'r')
+            # get evidence from each file if possible
+            try:
+                logz, dlogz = fp.log_evidence
+            except KeyError:
+                raise ValueError(f"Cannot combine evidences; file {file} "
+                                 "does not have a log_evidence attr")
+            logz_list.append(logz)
+            dlogz_list.append(dlogz)
+            # get samples from each file
+            file_samps = fp.read_samples(list(fp['samples'].keys()))
+            raw_samps_list.append(file_samps)
+            # get the number of samples from each file
+            len_list.append(len(file_samps))
+        # compute sampling weights from evidences
+        logz_net = logsumexp(logz_list)
+        len_net = sum(len_list)
+        out_size = min(len_list)
+        for i in range(len(opts.input_file)):
+            # weight each file's samples according to logz
+            logwt = logz_list[i] - logz_net
+            weights_list.append([numpy.exp(logwt)/len_list[i] for j in
+                                 range(len_list[i])])
+        # randomly sample indices from all samples
+        weights = numpy.concatenate(weights_list)
+        idx = numpy.random.choice(int(len_net), size=out_size, replace=True,
+                                  p=weights)
+        samples = {param: raw_samples[param][idx] for param in
+                   raw_samples.keys()}
+    else:
+        samples = {labels[p]: numpy.concatenate([s[p] for s in samples])
+                   for p in params}
 else:
     samples = {labels[p]: samples[p] for p in params}
+    if opts.combine_via_sampling:
+        warnings.warn("Specified combine_via_sampling with only one input "
+                      "file. This option will have no effect.")

 # create the file
 outtype = PosteriorFile.name
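
One property worth noting in the block above: numpy.random.choice requires p to sum to 1 (within tolerance). The per-sample weights satisfy this because each file contributes a total of exp(logz_i - logz_net) spread evenly over its samples, and those totals sum to 1. A small check under assumed sample counts and evidences:

import numpy
from scipy.special import logsumexp

# hypothetical sample counts and log evidences for three input files
len_list = [1500, 900, 1200]
logz_list = [100.0, 99.3, 101.2]

logz_net = logsumexp(logz_list)
weights = numpy.concatenate(
    [numpy.full(n, numpy.exp(lz - logz_net) / n)
     for n, lz in zip(len_list, logz_list)])
assert numpy.isclose(weights.sum(), 1.0)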
@@ -137,10 +210,13 @@ for fp in fps:
 skip_attrs = ['filetype', 'thin_start', 'thin_interval', 'thin_end',
               'thinned_by', 'cmd', 'resume_points', 'effective_nsamples',
               'run_start_time', 'run_end_time']
-# also skip evidence if multiple files are being combined, since that will
-# not be the same
+# also skip evidence if multiple files are being combined; this will be
+# handled via sampling if specified
 if len(opts.input_file) > 1:
     skip_attrs += ['log_evidence', 'dlog_evidence']
+
+# make sure attrs are the same between files...
+cat_params = False
 for fp in fps:
     for key in map(format_attr, fp.attrs):
         if key not in skip_attrs:
@@ -151,12 +227,36 @@ for fp in fps:
                 out.attrs[key] = val
             current_val = format_attr(out.attrs[key])
             if not isthesame(current_val, val):
-                raise ValueError("cannot combine all files; file attr {} is "
-                                 "not the same across all files ({} vs {})"
-                                 .format(key, current_val, val))
+                if key == 'remapped_params':
+                    # ...unless it's remapped_params; just save the first
+                    # file's entries if they don't match between files
+                    warnings.warn("remapped_params metadata does not match "
+                                  "between files; saving metadata from the "
+                                  "first file")
+                else:
+                    raise ValueError("cannot combine all files; file attr {} is "
+                                     "not the same across all files ({} vs {})"
+                                     .format(key, current_val, val))

-# store what parameters were renamed
-out.attrs['remapped_params'] = list(labels.items())
+# store what parameters were renamed (if not already saved above)
+if not cat_params:
+    out.attrs['remapped_params'] = list(labels.items())
+
+# write the combined evidence and dlog evidence if combining via sampling
+if opts.combine_via_sampling:
+    if len(opts.input_file) == 1:
+        # there's only one file; just copy over what's in the input
+        out.attrs['log_evidence'] = fps[0].log_evidence[0]
+        out.attrs['dlog_evidence'] = fps[0].log_evidence[1]
+    elif opts.mutually_exclusive_priors:
+        # add together the evidences; quadrature-sum the uncertainties
+        out.attrs['log_evidence'] = logsumexp(logz_list)
+        out.attrs['dlog_evidence'] = numpy.sqrt(sum([i**2 for i in dlogz_list]))
+    else:
+        # average the evidences; quadrature-sum and scale the uncertainties
+        n = len(opts.input_file)
+        out.attrs['log_evidence'] = logsumexp(logz_list) - numpy.log(n)
+        out.attrs['dlog_evidence'] = numpy.sqrt(sum([i**2 for i in dlogz_list])) / n

 # write the other groups using the first file
 fp = fps[0]
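
To make the two evidence-combination rules concrete, here is a worked example with made-up numbers, mirroring the arithmetic in the block above (the straight quadrature propagation of the dlog_evidence values is the commit's approximation, carried over as-is):

import numpy
from scipy.special import logsumexp

# hypothetical per-file log evidences and their uncertainties
logz_list = [100.0, 100.7]
dlogz_list = [0.1, 0.2]
n = len(logz_list)

# --mutually-exclusive-priors: Z_total = sum of the Z_i
logz_sum = logsumexp(logz_list)                            # ~101.10
dlogz_sum = numpy.sqrt(sum(d**2 for d in dlogz_list))      # ~0.224

# default (identical priors): Z_total = mean of the Z_i
logz_avg = logsumexp(logz_list) - numpy.log(n)             # ~100.41
dlogz_avg = numpy.sqrt(sum(d**2 for d in dlogz_list)) / n  # ~0.112

print(f"summed:   {logz_sum:.2f} +/- {dlogz_sum:.3f}")
print(f"averaged: {logz_avg:.2f} +/- {dlogz_avg:.3f}")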
