
[FEATURE] Create an add_batch_dim method #1153

Open
@sash-a

Description

All over Mava we have statements like tree.map(lambda x: x[jnp.newaxis], pytree), which add a leading batch dimension of size 1 to every leaf of a pytree. I think it would be a lot clearer if we had an add_batch_dim method in utils.jax_utils.
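
A minimal sketch of what such a helper could look like, assuming plain JAX with jax.tree_util.tree_map. The name add_batch_dim comes from this proposal; the body and the example dictionary keys (agents_view, step_count) are hypothetical and not Mava's actual code. It uses jnp.expand_dims(x, 0), which gives the same result as x[jnp.newaxis] for arrays and also accepts plain Python scalars as leaves.

```python
from typing import Any

import jax
import jax.numpy as jnp


def add_batch_dim(pytree: Any) -> Any:
    """Add a leading batch dimension of size 1 to every leaf of `pytree`.

    Hypothetical sketch of the proposed utils.jax_utils.add_batch_dim;
    equivalent to tree.map(lambda x: x[jnp.newaxis], pytree) for array leaves.
    """
    return jax.tree_util.tree_map(lambda x: jnp.expand_dims(x, 0), pytree)


# Usage with a hypothetical observation pytree:
obs = {"agents_view": jnp.zeros((3, 5)), "step_count": jnp.array(7)}
batched = add_batch_dim(obs)
print(batched["agents_view"].shape)  # (1, 3, 5)
print(batched["step_count"].shape)   # (1,)
```

Call sites then read as add_batch_dim(obs) instead of repeating the lambda, which makes the intent explicit.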

Metadata

Assignees: No one assigned
Type: No type
Projects: No projects
Milestone: No milestone
Relationships: None yet
Development: No branches or pull requests