Skip to content

Commit ecd6255

Browse files
authored
Bump to 0.6.0 (#959)
* sketch the plan * add a running implementation (not working yet * remove unnecessary changes * sequential update * temp save * add a working implementation * add 24d example * fix lint * test for the order * merge master * fix bug at reset momentum * add various mh functions * add various discrete gibbs function method * change stay_prob to modified to avoid confusing users * expose more information for mixed hmc * sketch an implementation * temp save * temp save * finish the implementation * keep kinetic energy * add temperature experiment * add dual averaging * add various debug statements * fix bugs * clean up and separating out clock adapter; but target distribution is wrong due to a bug somewhere * clean up * add comments and an example * make sure forward mode work * add docs for new HMC fields * add tests for mixedhmc * fix step_size bug * use modified=False * tests pass with the fix * skip print summary * adjust trajectory length * port update_version script from Pyro * pin jax/jaxlib versions * run isort * fix some issues during collection notes * use result_type instead of canonicalize_dtype * fix lint * change get_dtype to jnp.result_type * add print summary * fix compiling issue for mcmc * also try to avoid compiling issue in other samplers * also fix compiling issue in barkermh * convert init params types to strong types * address comments * fix wrong docs * run isort
1 parent af06eda commit ecd6255

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

41 files changed

+109
-79
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -182,7 +182,7 @@ Pyro users will note that the API for model specification and inference is large
182182

183183
## Installation
184184

185-
> **Limited Windows Support:** Note that NumPyro is untested on Windows, and might require building jaxlib from source. See this [JAX issue](https://github.com/google/jax/issues/438) for more details. Alternatively, you can install [Windows Subsystem for Linux](https://docs.microsoft.com/en-us/windows/wsl/) and use NumPyro on it as on a Linux system. See also [CUDA on Windows Subsystem for Linux](https://developer.nvidia.com/cuda/wsl) if you want to use GPUs on Windows.
185+
> **Limited Windows Support:** Note that NumPyro is untested on Windows, and might require building jaxlib from source. See this [JAX issue](https://github.com/google/jax/issues/438) for more details. Alternatively, you can install [Windows Subsystem for Linux](https://docs.microsoft.com/en-us/windows/wsl/) and use NumPyro on it as on a Linux system. See also [CUDA on Windows Subsystem for Linux](https://developer.nvidia.com/cuda/wsl) and [this forum post](https://forum.pyro.ai/t/numpyro-with-gpu-works-on-windows/2690) if you want to use GPUs on Windows.
186186
187187
To install NumPyro with a CPU version of JAX, you can use pip:
188188

docs/source/distributions.rst

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -574,6 +574,14 @@ real_vector
574574
-----------
575575
.. autodata:: numpyro.distributions.constraints.real_vector
576576

577+
softplus_positive
578+
-----------------
579+
.. autodata:: numpyro.distributions.constraints.softplus_positive
580+
581+
softplus_lower_cholesky
582+
-----------------------
583+
.. autodata:: numpyro.distributions.constraints.softplus_lower_cholesky
584+
577585
simplex
578586
-------
579587
.. autodata:: numpyro.distributions.constraints.simplex

docs/source/mcmc.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,8 @@ MCMC Kernels
7070

7171
.. autofunction:: numpyro.infer.hmc.hmc.sample_kernel
7272

73+
.. autofunction:: numpyro.infer.hmc_gibbs.taylor_proxy
74+
7375
.. autodata:: numpyro.infer.barker.BarkerMHState
7476

7577
.. autodata:: numpyro.infer.hmc.HMCState

examples/annotation.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -266,7 +266,7 @@ def main(args):
266266

267267

268268
if __name__ == "__main__":
269-
assert numpyro.__version__.startswith("0.5.0")
269+
assert numpyro.__version__.startswith('0.6.0')
270270
parser = argparse.ArgumentParser(description="Bayesian Models of Annotation")
271271
parser.add_argument("-n", "--num-samples", nargs="?", default=1000, type=int)
272272
parser.add_argument("--num-warmup", nargs="?", default=1000, type=int)

examples/baseball.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -196,7 +196,7 @@ def main(args):
196196

197197

198198
if __name__ == "__main__":
199-
assert numpyro.__version__.startswith('0.5.0')
199+
assert numpyro.__version__.startswith('0.6.0')
200200
parser = argparse.ArgumentParser(description="Baseball batting average using MCMC")
201201
parser.add_argument("-n", "--num-samples", nargs="?", default=3000, type=int)
202202
parser.add_argument("--num-warmup", nargs='?', default=1500, type=int)

examples/bnn.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,7 @@ def main(args):
138138

139139

140140
if __name__ == "__main__":
141-
assert numpyro.__version__.startswith('0.5.0')
141+
assert numpyro.__version__.startswith('0.6.0')
142142
parser = argparse.ArgumentParser(description="Bayesian neural network example")
143143
parser.add_argument("-n", "--num-samples", nargs="?", default=2000, type=int)
144144
parser.add_argument("--num-warmup", nargs='?', default=1000, type=int)

examples/covtype.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,7 @@ def main(args):
139139

140140

141141
if __name__ == '__main__':
142-
assert numpyro.__version__.startswith('0.5.0')
142+
assert numpyro.__version__.startswith('0.6.0')
143143
parser = argparse.ArgumentParser(description="parse args")
144144
parser.add_argument('-n', '--num-samples', default=1000, type=int, help='number of samples')
145145
parser.add_argument('--num-warmup', default=1000, type=int, help='number of warmup steps')

examples/funnel.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ def main(args):
8787

8888

8989
if __name__ == "__main__":
90-
assert numpyro.__version__.startswith('0.5.0')
90+
assert numpyro.__version__.startswith('0.6.0')
9191
parser = argparse.ArgumentParser(description="Non-centered reparameterization example")
9292
parser.add_argument("-n", "--num-samples", nargs="?", default=1000, type=int)
9393
parser.add_argument("--num-warmup", nargs='?', default=1000, type=int)

examples/gp.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -142,7 +142,7 @@ def main(args):
142142

143143

144144
if __name__ == "__main__":
145-
assert numpyro.__version__.startswith('0.5.0')
145+
assert numpyro.__version__.startswith('0.6.0')
146146
parser = argparse.ArgumentParser(description="Gaussian Process example")
147147
parser.add_argument("-n", "--num-samples", nargs="?", default=1000, type=int)
148148
parser.add_argument("--num-warmup", nargs='?', default=1000, type=int)

examples/hmm.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -191,7 +191,7 @@ def main(args):
191191

192192

193193
if __name__ == '__main__':
194-
assert numpyro.__version__.startswith('0.5.0')
194+
assert numpyro.__version__.startswith('0.6.0')
195195
parser = argparse.ArgumentParser(description='Semi-supervised Hidden Markov Model')
196196
parser.add_argument('--num-categories', default=3, type=int)
197197
parser.add_argument('--num-words', default=10, type=int)

0 commit comments

Comments
 (0)