Skip to content

Commit 4f39687

Browse files
authored
Merge pull request #751 from stan-dev/fix/multichain-profile-file
Fix profile file output when running multiple chains in one process
2 parents fad7a69 + bf39084 commit 4f39687

File tree

5 files changed

+28
-18
lines changed

5 files changed

+28
-18
lines changed

cmdstanpy/install_cmdstan.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -523,6 +523,8 @@ def retrieve_version(version: str, progress: bool = True) -> None:
523523
first = tar.next()
524524
if first is not None:
525525
top_dir = first.name
526+
else:
527+
top_dir = ''
526528
cmdstan_dir = f'cmdstan-{version}'
527529
if top_dir != cmdstan_dir:
528530
raise CmdStanInstallError(

cmdstanpy/install_cxx_toolchain.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -235,8 +235,10 @@ def get_toolchain_name() -> str:
235235
return ''
236236

237237

238+
# TODO(2.0): drop 3.5 support
238239
def get_url(version: str) -> str:
239240
"""Return URL for toolchain."""
241+
url = ''
240242
if platform.system() == 'Windows':
241243
if version == '4.0':
242244
# pylint: disable=line-too-long
@@ -277,6 +279,8 @@ def run_rtools_install(args: Dict[str, Any]) -> None:
277279

278280
if 'verbose' in args:
279281
verbose = args['verbose']
282+
else:
283+
verbose = False
280284

281285
install_dir = args['dir']
282286
if install_dir is None:

cmdstanpy/stanfit/runset.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -61,40 +61,40 @@ def __init__(
6161
self._base_outfile = (
6262
f'{args.model_name}-{datetime.now().strftime(time_fmt)}'
6363
)
64-
# per-process console messages
64+
# per-process outputs
6565
self._stdout_files = [''] * self._num_procs
66+
self._profile_files = [''] * self._num_procs # optional
6667
if one_process_per_chain:
6768
for i in range(chains):
6869
self._stdout_files[i] = self.file_path("-stdout.txt", id=i)
70+
if args.save_profile:
71+
self._profile_files[i] = self.file_path(
72+
".csv", extra="-profile", id=chain_ids[i]
73+
)
6974
else:
7075
self._stdout_files[0] = self.file_path("-stdout.txt")
76+
if args.save_profile:
77+
self._profile_files[0] = self.file_path(
78+
".csv", extra="-profile"
79+
)
7180

7281
# per-chain output files
7382
self._csv_files: List[str] = [''] * chains
7483
self._diagnostic_files = [''] * chains # optional
75-
self._profile_files = [''] * chains # optional
7684

7785
if chains == 1:
7886
self._csv_files[0] = self.file_path(".csv")
7987
if args.save_latent_dynamics:
8088
self._diagnostic_files[0] = self.file_path(
8189
".csv", extra="-diagnostic"
8290
)
83-
if args.save_profile:
84-
self._profile_files[0] = self.file_path(
85-
".csv", extra="-profile"
86-
)
8791
else:
8892
for i in range(chains):
8993
self._csv_files[i] = self.file_path(".csv", id=chain_ids[i])
9094
if args.save_latent_dynamics:
9195
self._diagnostic_files[i] = self.file_path(
9296
".csv", extra="-diagnostic", id=chain_ids[i]
9397
)
94-
if args.save_profile:
95-
self._profile_files[i] = self.file_path(
96-
".csv", extra="-profile", id=chain_ids[i]
97-
)
9898

9999
def __repr__(self) -> str:
100100
repr = 'RunSet: chains={}, chain_ids={}, num_processes={}'.format(

cmdstanpy/utils/stancsv.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,7 @@ def scan_optimize_csv(path: str, save_iters: bool = False) -> Dict[str, Any]:
116116
all_iters[i, :] = [float(x) for x in xs]
117117
if i == iters - 1:
118118
mle: np.ndarray = np.array(xs, dtype=float)
119+
# pylint: disable=possibly-used-before-assignment
119120
dict['mle'] = mle
120121
if save_iters:
121122
dict['all_iters'] = all_iters

test/test_sample.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1748,33 +1748,36 @@ def test_save_latent_dynamics() -> None:
17481748

17491749
def test_save_profile() -> None:
17501750
stan = os.path.join(DATAFILES_PATH, 'profile_likelihood.stan')
1751-
profile_model = CmdStanModel(stan_file=stan)
1751+
profile_model = CmdStanModel(
1752+
stan_file=stan, cpp_options={"STAN_THREADS": '1'}, force_compile=True
1753+
)
1754+
17521755
profile_fit = profile_model.sample(
17531756
chains=2,
17541757
parallel_chains=2,
1758+
force_one_process_per_chain=True,
17551759
seed=12345,
17561760
iter_warmup=100,
17571761
iter_sampling=200,
17581762
save_profile=True,
17591763
)
1760-
for i in range(profile_fit.runset.chains):
1761-
profile_file = profile_fit.runset.profile_files[i]
1764+
assert len(profile_fit.runset.profile_files) == 2
1765+
for profile_file in profile_fit.runset.profile_files:
17621766
assert os.path.exists(profile_file)
17631767

17641768
profile_fit = profile_model.sample(
17651769
chains=2,
17661770
parallel_chains=2,
1771+
force_one_process_per_chain=False,
17671772
seed=12345,
1773+
iter_warmup=100,
17681774
iter_sampling=200,
1769-
save_latent_dynamics=True,
17701775
save_profile=True,
17711776
)
17721777

1773-
for i in range(profile_fit.runset.chains):
1774-
profile_file = profile_fit.runset.profile_files[i]
1778+
assert len(profile_fit.runset.profile_files) == 1
1779+
for profile_file in profile_fit.runset.profile_files:
17751780
assert os.path.exists(profile_file)
1776-
diagnostics_file = profile_fit.runset.diagnostic_files[i]
1777-
assert os.path.exists(diagnostics_file)
17781781

17791782

17801783
def test_xarray_draws() -> None:

0 commit comments

Comments
 (0)