@@ -43,6 +43,8 @@ import pycbc
 from pycbc.inference.io import (ResultsArgumentParser, results_from_cli,
                                 PosteriorFile, loadfile)
 from pycbc.inference.io.base_hdf import format_attr
+from scipy.special import logsumexp
+import warnings
 
 
 def isthesame(current_val, val):
@@ -83,6 +85,35 @@ parser.add_argument("--skip-groups", default=None, nargs="+",
                          "to write all groups if only one file is provided, "
                          "and all groups from the first file except "
                          "sampler_info if multiple files are provided.")
+parser.add_argument("--combine-via-sampling", action='store_true',
+                    default=False,
+                    help="Specify whether to combine the posteriors by "
+                         "sampling. By default, extract_samples will dump all "
+                         "samples from multiple inputs into one output file. "
+                         "If this option is specified, the samples will "
+                         "be randomly sampled with weighting based on the "
+                         "evidence in each input. For example, if the "
+                         "evidence of input 1 is twice that of input 2, "
+                         "the resulting posterior file with this option "
+                         "specified will have twice as many points from input "
+                         "1 than input 2. The output will have the same "
+                         "number of samples as the smallest input file. "
+                         "An error is thrown if this option is specified and "
+                         "any of the input files does not have a log_evidence "
+                         "attribute (e.g., the file used a sampler like emcee "
+                         "that does not report evidence). This option "
+                         "assumes that the priors of all files do not "
+                         "overlap; we cannot properly combine posteriors if "
+                         "the priors do overlap.")
+parser.add_argument("--mutually-exclusive-priors", action='store_true',
+                    default=False,
+                    help="If specifying --combine-via-sampling, specify "
+                         "whether to treat priors as mutually exclusive. By "
+                         "default, the provided input files are assumed to "
+                         "have identical priors, and their evidences will be "
+                         "averaged. If this option is specified, the priors "
+                         "are assumed to be non-overlapping across files, and "
+                         "their evidences will be summed together.")
 
 opts = parser.parse_args()
 
@@ -101,10 +132,52 @@ if len(opts.input_file) == 1:
 # convert samples to a dict in which the keys are the labels
 # also stack results if multiple files were provided
 if len(opts.input_file) > 1:
-    samples = {labels[p]: numpy.concatenate([s[p] for s in samples])
-               for p in params}
+    if opts.combine_via_sampling:
+        raw_samples = {labels[p]: numpy.concatenate([s[p] for s in samples])
+                       for p in params}
+        logz_list = []
+        dlogz_list = []
+        len_list = []
+        raw_samps_list = []
+        weights_list = []
+        for file in opts.input_file:
+            fp = loadfile(file, 'r')
+            # get evidence from each file if possible
+            try:
+                logz, dlogz = fp.log_evidence
+            except KeyError:
+                raise ValueError(f"Cannot combine evidences; file {file} "
+                                 "does not have a log_evidence attr")
+            logz_list.append(logz)
+            dlogz_list.append(dlogz)
+            # get samples from each file
+            file_samps = fp.read_samples(list(fp['samples'].keys()))
+            raw_samps_list.append(file_samps)
+            # get the number of samples from each file
+            len_list.append(len(file_samps))
+        # compute sampling weights from evidences
+        logz_net = logsumexp(logz_list)
+        len_net = sum(len_list)
+        out_size = min(len_list)
+        for i in range(len(opts.input_file)):
+            # weight each file's samples according to logz
+            logwt = logz_list[i] - logz_net
+            weights_list.append([numpy.exp(logwt)/len_list[i] for j in
+                                 range(len_list[i])])
+        # randomly sample indices from all samples
+        weights = numpy.concatenate(weights_list)
+        idx = numpy.random.choice(int(len_net), size=out_size, replace=True,
+                                  p=weights)
+        samples = {param: raw_samples[param][idx] for param in
+                   raw_samples.keys()}
+    else:
+        samples = {labels[p]: numpy.concatenate([s[p] for s in samples])
+                   for p in params}
 else:
     samples = {labels[p]: samples[p] for p in params}
+    if opts.combine_via_sampling:
+        warnings.warn("Specified combine_via_sampling with only one input "
+                      "file. This option will have no effect.")
 
 # create the file
 outtype = PosteriorFile.name
@@ -137,10 +210,13 @@ for fp in fps:
 skip_attrs = ['filetype', 'thin_start', 'thin_interval', 'thin_end',
               'thinned_by', 'cmd', 'resume_points', 'effective_nsamples',
               'run_start_time', 'run_end_time']
-# also skip evidence if multiple files are being combined, since that will
-# not be the same
+# also skip evidence if multiple files are being combined; this will be handled
+# via sampling if specified
 if len(opts.input_file) > 1:
     skip_attrs += ['log_evidence', 'dlog_evidence']
+
+# make sure attrs are the same between files...
+cat_params = False
 for fp in fps:
     for key in map(format_attr, fp.attrs):
         if key not in skip_attrs:
@@ -151,12 +227,36 @@ for fp in fps:
                 out.attrs[key] = val
             current_val = format_attr(out.attrs[key])
             if not isthesame(current_val, val):
-                raise ValueError("cannot combine all files; file attr {} is "
-                                 "not the same across all files ({} vs {})"
-                                 .format(key, current_val, val))
+                if key == 'remapped_params':
+                    # ...unless it's remapped_params; just save first file's
+                    # entries if they don't match between files
+                    warnings.warn("WARNING: remapped_params metadata does not "
+                                  "match between files; saving metadata from "
+                                  "first file")
+                else:
+                    raise ValueError("cannot combine all files; file attr {} "
+                                     "is not the same across all files "
+                                     "({} vs {})"
+                                     .format(key, current_val, val))
 
-# store what parameters were renamed
-out.attrs['remapped_params'] = list(labels.items())
+# store what parameters were renamed (if not already saved above)
+if not cat_params:
+    out.attrs['remapped_params'] = list(labels.items())
+
+# write combined evidence and dlog evidence if combining via sampling
+if opts.combine_via_sampling:
+    if len(opts.input_file) == 1:
+        # there's only one file; just return what's in the input
+        out.attrs['log_evidence'] = fps[0].log_evidence[0]
+        out.attrs['dlog_evidence'] = fps[0].log_evidence[1]
+    elif opts.mutually_exclusive_priors:
+        # add together the evidences; quadrature sum the residuals
+        out.attrs['log_evidence'] = logsumexp(logz_list)
+        out.attrs['dlog_evidence'] = numpy.sqrt(sum([i**2 for i in dlogz_list]))
+    else:
+        # average the evidences; quadrature sum and scale the residuals
+        n = len(opts.input_file)
+        out.attrs['log_evidence'] = logsumexp(logz_list) - numpy.log(n)
+        out.attrs['dlog_evidence'] = numpy.sqrt(sum([i**2 for i in dlogz_list])) / n
 
 # write the other groups using the first file
 fp = fps[0]