Skip to content

Commit 7084aaa

Browse files
authored
Bump to 0.9.0 (#1310)
* Add loose strategy for MCMC * merge svi and mcmc plate warning strategies * fix failing tests * validate model accross ELBOs * update vae example * fix typos * Bump to version 0.9.0 * Fix failing tests * Fix warnings in tests/examples * relax funsor requirement * Move optax_to_numpyro to optim * skip prodlda test on CI
1 parent 5d9e033 commit 7084aaa

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

42 files changed

+178
-241
lines changed

.github/workflows/ci.yml

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,10 +26,10 @@ jobs:
2626
run: |
2727
sudo apt install -y pandoc gsfonts
2828
python -m pip install --upgrade pip
29-
pip install https://github.com/pyro-ppl/funsor/archive/master.zip
3029
pip install jaxlib
3130
pip install jax
3231
pip install .[doc,test]
32+
pip install https://github.com/pyro-ppl/funsor/archive/master.zip
3333
pip install -r docs/requirements.txt
3434
pip freeze
3535
- name: Lint with flake8
@@ -64,10 +64,10 @@ jobs:
6464
python -m pip install --upgrade pip
6565
# Keep track of pyro-api master branch
6666
pip install https://github.com/pyro-ppl/pyro-api/archive/master.zip
67-
pip install https://github.com/pyro-ppl/funsor/archive/master.zip
6867
pip install jaxlib
6968
pip install jax
7069
pip install .[dev,test]
70+
pip install https://github.com/pyro-ppl/funsor/archive/master.zip
7171
pip freeze
7272
- name: Test with pytest
7373
run: |
@@ -93,10 +93,10 @@ jobs:
9393
python -m pip install --upgrade pip
9494
# Keep track of pyro-api master branch
9595
pip install https://github.com/pyro-ppl/pyro-api/archive/master.zip
96-
pip install https://github.com/pyro-ppl/funsor/archive/master.zip
9796
pip install jaxlib
9897
pip install jax
9998
pip install .[dev,test]
99+
pip install https://github.com/pyro-ppl/funsor/archive/master.zip
100100
pip freeze
101101
- name: Test with pytest
102102
run: |
@@ -129,10 +129,10 @@ jobs:
129129
- name: Install dependencies
130130
run: |
131131
python -m pip install --upgrade pip
132-
pip install https://github.com/pyro-ppl/funsor/archive/master.zip
133132
pip install jaxlib
134133
pip install jax
135134
pip install .[dev,examples,test]
135+
pip install https://github.com/pyro-ppl/funsor/archive/master.zip
136136
pip freeze
137137
- name: Test with pytest
138138
run: |

docs/source/optimizers.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,4 +68,4 @@ SM3
6868

6969
Optax support
7070
-------------
71-
.. autofunction:: numpyro.contrib.optim.optax_to_numpyro
71+
.. autofunction:: numpyro.optim.optax_to_numpyro

examples/annotation.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -351,7 +351,7 @@ def main(args):
351351

352352

353353
if __name__ == "__main__":
354-
assert numpyro.__version__.startswith("0.8.0")
354+
assert numpyro.__version__.startswith("0.9.0")
355355
parser = argparse.ArgumentParser(description="Bayesian Models of Annotation")
356356
parser.add_argument("-n", "--num-samples", nargs="?", default=1000, type=int)
357357
parser.add_argument("--num-warmup", nargs="?", default=1000, type=int)

examples/ar2.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,7 @@ def main(args):
116116

117117

118118
if __name__ == "__main__":
119-
assert numpyro.__version__.startswith("0.8.0")
119+
assert numpyro.__version__.startswith("0.9.0")
120120
parser = argparse.ArgumentParser(description="AR2 example")
121121
parser.add_argument("--num-data", nargs="?", default=142, type=int)
122122
parser.add_argument("-n", "--num-samples", nargs="?", default=1000, type=int)

examples/baseball.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -210,7 +210,7 @@ def main(args):
210210

211211

212212
if __name__ == "__main__":
213-
assert numpyro.__version__.startswith("0.8.0")
213+
assert numpyro.__version__.startswith("0.9.0")
214214
parser = argparse.ArgumentParser(description="Baseball batting average using MCMC")
215215
parser.add_argument("-n", "--num-samples", nargs="?", default=3000, type=int)
216216
parser.add_argument("--num-warmup", nargs="?", default=1500, type=int)

examples/bnn.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -160,7 +160,7 @@ def main(args):
160160

161161

162162
if __name__ == "__main__":
163-
assert numpyro.__version__.startswith("0.8.0")
163+
assert numpyro.__version__.startswith("0.9.0")
164164
parser = argparse.ArgumentParser(description="Bayesian neural network example")
165165
parser.add_argument("-n", "--num-samples", nargs="?", default=2000, type=int)
166166
parser.add_argument("--num-warmup", nargs="?", default=1000, type=int)

examples/capture_recapture.py

Lines changed: 25 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,11 @@ def transition_fn(carry, y):
7272
with handlers.mask(mask=first_capture_mask):
7373
mu_z_t = first_capture_mask * phi * z + (1 - first_capture_mask)
7474
# NumPyro exactly sums out the discrete states z_t.
75-
z = numpyro.sample("z", dist.Bernoulli(dist.util.clamp_probs(mu_z_t)))
75+
z = numpyro.sample(
76+
"z",
77+
dist.Bernoulli(dist.util.clamp_probs(mu_z_t)),
78+
infer={"enumerate": "parallel"},
79+
)
7680
mu_y_t = rho * z
7781
numpyro.sample(
7882
"y", dist.Bernoulli(dist.util.clamp_probs(mu_y_t)), obs=y
@@ -112,7 +116,11 @@ def transition_fn(carry, y):
112116
with handlers.mask(mask=first_capture_mask):
113117
mu_z_t = first_capture_mask * phi_t * z + (1 - first_capture_mask)
114118
# NumPyro exactly sums out the discrete states z_t.
115-
z = numpyro.sample("z", dist.Bernoulli(dist.util.clamp_probs(mu_z_t)))
119+
z = numpyro.sample(
120+
"z",
121+
dist.Bernoulli(dist.util.clamp_probs(mu_z_t)),
122+
infer={"enumerate": "parallel"},
123+
)
116124
mu_y_t = rho * z
117125
numpyro.sample(
118126
"y", dist.Bernoulli(dist.util.clamp_probs(mu_y_t)), obs=y
@@ -160,7 +168,11 @@ def transition_fn(carry, y):
160168
with handlers.mask(mask=first_capture_mask):
161169
mu_z_t = first_capture_mask * phi_t * z + (1 - first_capture_mask)
162170
# NumPyro exactly sums out the discrete states z_t.
163-
z = numpyro.sample("z", dist.Bernoulli(dist.util.clamp_probs(mu_z_t)))
171+
z = numpyro.sample(
172+
"z",
173+
dist.Bernoulli(dist.util.clamp_probs(mu_z_t)),
174+
infer={"enumerate": "parallel"},
175+
)
164176
mu_y_t = rho * z
165177
numpyro.sample(
166178
"y", dist.Bernoulli(dist.util.clamp_probs(mu_y_t)), obs=y
@@ -202,7 +214,11 @@ def transition_fn(carry, y):
202214
with handlers.mask(mask=first_capture_mask):
203215
mu_z_t = first_capture_mask * phi * z + (1 - first_capture_mask)
204216
# NumPyro exactly sums out the discrete states z_t.
205-
z = numpyro.sample("z", dist.Bernoulli(dist.util.clamp_probs(mu_z_t)))
217+
z = numpyro.sample(
218+
"z",
219+
dist.Bernoulli(dist.util.clamp_probs(mu_z_t)),
220+
infer={"enumerate": "parallel"},
221+
)
206222
mu_y_t = rho * z
207223
numpyro.sample(
208224
"y", dist.Bernoulli(dist.util.clamp_probs(mu_y_t)), obs=y
@@ -249,7 +265,11 @@ def transition_fn(carry, y):
249265
with handlers.mask(mask=first_capture_mask):
250266
mu_z_t = first_capture_mask * phi_t * z + (1 - first_capture_mask)
251267
# NumPyro exactly sums out the discrete states z_t.
252-
z = numpyro.sample("z", dist.Bernoulli(dist.util.clamp_probs(mu_z_t)))
268+
z = numpyro.sample(
269+
"z",
270+
dist.Bernoulli(dist.util.clamp_probs(mu_z_t)),
271+
infer={"enumerate": "parallel"},
272+
)
253273
mu_y_t = rho * z
254274
numpyro.sample(
255275
"y", dist.Bernoulli(dist.util.clamp_probs(mu_y_t)), obs=y

examples/covtype.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -206,7 +206,7 @@ def main(args):
206206

207207

208208
if __name__ == "__main__":
209-
assert numpyro.__version__.startswith("0.8.0")
209+
assert numpyro.__version__.startswith("0.9.0")
210210
parser = argparse.ArgumentParser(description="parse args")
211211
parser.add_argument(
212212
"-n", "--num-samples", default=1000, type=int, help="number of samples"

examples/funnel.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,7 @@ def main(args):
139139

140140

141141
if __name__ == "__main__":
142-
assert numpyro.__version__.startswith("0.8.0")
142+
assert numpyro.__version__.startswith("0.9.0")
143143
parser = argparse.ArgumentParser(
144144
description="Non-centered reparameterization example"
145145
)

examples/gaussian_shells.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ def run_inference(args, data):
8181
num_warmup=args.num_warmup,
8282
num_samples=args.num_samples,
8383
)
84-
mcmc.run(random.PRNGKey(2), **data)
84+
mcmc.run(random.PRNGKey(2), **data, enum=args.enum)
8585
mcmc.print_summary()
8686
mcmc_samples = mcmc.get_samples()
8787

@@ -123,7 +123,7 @@ def main(args):
123123

124124

125125
if __name__ == "__main__":
126-
assert numpyro.__version__.startswith("0.8.0")
126+
assert numpyro.__version__.startswith("0.9.0")
127127
parser = argparse.ArgumentParser(description="Nested sampler for Gaussian shells")
128128
parser.add_argument("-n", "--num-samples", nargs="?", default=10000, type=int)
129129
parser.add_argument("--num-warmup", nargs="?", default=1000, type=int)

0 commit comments

Comments
 (0)