@@ -1345,3 +1345,38 @@ def model(x):
13451345 # Check delta distributions are fine if observed.
13461346 guide = AutoDiagonalNormal (lambda : model (3.0 ))
13471347 numpyro .handlers .seed (guide , 9 )()
1348+
1349+
1350+ @pytest .mark .parametrize (
1351+ "guide_cls" ,
1352+ [
1353+ AutoBNAFNormal ,
1354+ AutoDAIS ,
1355+ AutoDelta ,
1356+ AutoDiagonalNormal ,
1357+ AutoLaplaceApproximation ,
1358+ AutoLowRankMultivariateNormal ,
1359+ AutoMultivariateNormal ,
1360+ AutoNormal ,
1361+ ],
1362+ )
1363+ def test_subsample (guide_cls ) -> None :
1364+ def model (n : int , x : jnp .ndarray ):
1365+ mu = numpyro .sample ("mu" , dist .Normal (0 , 1 ))
1366+ sigma = numpyro .sample ("sigma" , dist .HalfNormal (1 ))
1367+ with numpyro .plate ("n" , n , subsample_size = x .size ):
1368+ numpyro .sample ("x" , dist .Normal (mu , sigma ), obs = x )
1369+
1370+ n = 20
1371+ x = 5 + jax .random .normal (jax .random .key (1 ), (20 ,))
1372+ subset = x [: n // 2 ]
1373+
1374+ svi = numpyro .infer .SVI (
1375+ model ,
1376+ guide_cls (model ),
1377+ numpyro .optim .Adam (0.1 ),
1378+ numpyro .infer .Trace_ELBO (),
1379+ n = n ,
1380+ )
1381+ state = svi .init (jax .random .key (2 ), x = x )
1382+ svi .update (state , x = subset )
0 commit comments