
Flax BNN is several times slower in JAX 0.4.33 compared to JAX 0.4.31 #1867

@ziatdinovmax

Description

JAX 0.4.31: runtime 27.06 seconds
https://colab.research.google.com/drive/1EsFY1St8Y2ZNBZ9UXTa9FDWrjPDdTU4U?usp=sharing

JAX 0.4.33: runtime 84.91 seconds (about 3.1× slower)
https://colab.research.google.com/drive/1g7GkuK4-GloO6cywvDUf5BVU9qO2jf1W?usp=sharing

I'm not sure whether this issue is specific to `random_flax_module` or is a broader problem. I mainly use NumPyro for HMC BNNs, and the slowdown with the latest JAX release is quite dramatic.

Code:

```python
import time
import numpy as np
import jax
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from numpyro.infer import NUTS, MCMC
from numpyro.contrib.module import random_flax_module
import flax.linen as nn


# Set random seeds for reproducibility (JAX key for MCMC, NumPy generator for the data noise)
rng_key = jax.random.PRNGKey(0)
np_rng = np.random.default_rng(0)

# Generate some dummy data: y = 3x + 2 plus Gaussian noise
def generate_data(n=100, noise_std=0.1):
    X = jnp.linspace(-1, 1, n)
    y = 3 * X + 2 + np_rng.normal(0, noise_std, size=X.shape)
    return X[:, None], y

# Define a simple neural network: 1 -> 10 (ReLU) -> 1
class SimpleNN(nn.Module):
    @nn.compact
    def __call__(self, x):
        x = nn.Dense(10)(x)
        x = nn.relu(x)
        x = nn.Dense(1)(x)
        return x.squeeze()

# Define the model: N(0, 1) prior over all network parameters
def model(X, y):
    module = SimpleNN()
    # Named `net` to avoid shadowing the `flax.linen as nn` import
    net = random_flax_module("nn", module, input_shape=(1, X.shape[-1]), prior=dist.Normal(0, 1))

    with numpyro.plate("data", X.shape[0]):
        mean = net(X)
        numpyro.sample("obs", dist.Normal(mean, 0.1), obs=y)

# Generate data
X, y = generate_data()

# Initialize the NUTS sampler
nuts_kernel = NUTS(model)

# Run inference
num_warmup, num_samples = 500, 1000

start_time = time.time()

mcmc = MCMC(nuts_kernel, num_warmup=num_warmup, num_samples=num_samples)
mcmc.run(rng_key, X, y)

end_time = time.time()

# Print runtime (includes JIT compilation as well as warmup and sampling)
print(f"Runtime: {end_time - start_time:.2f} seconds")

# Print summary statistics (print_summary() prints directly and returns None)
mcmc.print_summary()
```
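The runtime printed by the script mixes JIT compilation with warmup and sampling, so a 3× regression could sit in either. A minimal pure-JAX sketch, assuming the ahead-of-time `jax.jit(...).lower(...).compile()` API available in the 0.4.x releases, times compilation and execution of a plain-function version of `SimpleNN`'s forward pass separately. `mlp` and `time_compile_and_run` are illustrative names, and a forward pass this small is only a narrowing step, not a reproduction of the NUTS workload.

```python
import time

import jax
import jax.numpy as jnp


def mlp(params, x):
    # Plain-function version of SimpleNN's forward pass: 1 -> 10 (ReLU) -> 1
    w1, b1, w2, b2 = params
    h = jax.nn.relu(x @ w1 + b1)
    return (h @ w2 + b2).squeeze(-1)


def time_compile_and_run(fn, *args, repeats=100):
    jitted = jax.jit(fn)
    # Ahead-of-time path: trace, lower to XLA, and compile, timed on its own
    t0 = time.perf_counter()
    compiled = jitted.lower(*args).compile()
    compile_s = time.perf_counter() - t0
    # One untimed call so any first-execution overhead is excluded
    compiled(*args).block_until_ready()
    t0 = time.perf_counter()
    for _ in range(repeats):
        out = compiled(*args)
    out.block_until_ready()
    run_s = (time.perf_counter() - t0) / repeats
    return compile_s, run_s, out


k1, k2 = jax.random.split(jax.random.PRNGKey(0))
params = (
    jax.random.normal(k1, (1, 10)),
    jnp.zeros(10),
    jax.random.normal(k2, (10, 1)),
    jnp.zeros(1),
)
x = jnp.linspace(-1, 1, 100)[:, None]
compile_s, run_s, out = time_compile_and_run(mlp, params, x)
print(f"jax {jax.__version__}: compile {compile_s:.4f} s, per call {run_s * 1e6:.1f} us")
```

Running it once under each JAX version and comparing the two numbers would show whether compile time, per-call time, or neither moved for a network of this size.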

Labels: jax (This issue is specific to JAX), performance