Skip to content

fix: be explicit about squeeze dim in prioritised sampling to avoid flattening (1,1) arrays#27

Merged
EdanToledo merged 2 commits intomainfrom
fix/prioritised-sampling-with-batch-of-1
Jul 4, 2024
Merged

fix: be explicit about squeeze dim in prioritised sampling to avoid flattening (1,1) arrays#27
EdanToledo merged 2 commits intomainfrom
fix/prioritised-sampling-with-batch-of-1

Conversation

@callumtilbury
Copy link
Contributor

Currently, the following snippet will fail:

from flashbax import make_prioritised_flat_buffer
import jax
import jax.numpy as jnp

buffer = make_prioritised_flat_buffer(
    max_length=100,
    min_length=1,
    sample_batch_size=1,  # NB
    add_sequences=False,
)

timestep = {"obs": jnp.zeros(shape=(3)),}

state = buffer.init(timestep)

for i in range(5):
    timestep = {
        "obs": jnp.ones(shape=(3)) * i,
    }
    state = buffer.add(state, timestep)

buffer.sample(state, jax.random.PRNGKey(0))

because of this line:

state, None, query_values.squeeze()

If the sample_batch_size is 1, query_values is shape (1,1), which squeezes to ().

Instead we must be explicit about the squeeze dim.

@EdanToledo EdanToledo merged commit 3c74aa8 into main Jul 4, 2024
@EdanToledo EdanToledo deleted the fix/prioritised-sampling-with-batch-of-1 branch July 4, 2024 15:41
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants