Quentin Delfosse, Paul Seitz, Sebastian Wette, Daniel Kirn, Dominik Mandok, Lars Teubner
Machine Learning Lab, TU Darmstadt
A GPU-accelerated, object-centric Atari environment suite built with JAX for fast, scalable reinforcement learning research.
JAXAtari is a GPU-accelerated, object-centric Atari environment framework powered by JAX. Inspired by OCAtari, it enables up to 16,000x faster training through just-in-time (JIT) compilation, vectorization, and massive parallelization on GPU.
- Object-centric extraction of Atari game states with structured observations
- JAX-based vectorized execution with full GPU support and JIT compilation
- Comprehensive wrapper system for different observation types (pixel, object-centric, combined)
```bash
python3 -m venv .venv
source .venv/bin/activate
python3 -m pip install -U pip
pip3 install -e .
```
Note: this installs JAX without GPU acceleration. CUDA users should run the following to add GPU support:
```bash
pip install -U "jax[cuda12]"
```
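To verify that JAX can see your GPU (a standard JAX check, not specific to JAXAtari):

```bash
python3 -c "import jax; print(jax.devices())"
```

If the install succeeded, the printed list should contain a CUDA device rather than only CPU devices.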
For other accelerator types, please follow the official JAX installation instructions.
The main entry point is the `make()` function:
```python
import jax
import jaxatari

# Create an environment
env = jaxatari.make("pong")  # or "seaquest", "kangaroo", "freeway", etc.

# Get available games
available_games = jaxatari.list_available_games()
print(f"Available games: {available_games}")
```
JAXAtari provides a comprehensive wrapper system for different use cases:
```python
import jax
import jaxatari
from jaxatari.wrappers import (
    AtariWrapper,
    ObjectCentricWrapper,
    PixelObsWrapper,
    PixelAndObjectCentricWrapper,
    FlattenObservationWrapper,
    LogWrapper,
)

# Create base environment
base_env = jaxatari.make("pong")

# Apply wrappers for different observation types
env = AtariWrapper(base_env, frame_stack_size=4, frame_skip=4)
env = ObjectCentricWrapper(env)  # Returns flattened object features
# OR
env = PixelObsWrapper(AtariWrapper(base_env))  # Returns pixel observations
# OR
env = PixelAndObjectCentricWrapper(AtariWrapper(base_env))  # Returns both
# OR
env = FlattenObservationWrapper(ObjectCentricWrapper(AtariWrapper(base_env)))  # Returns flattened observations

# Add logging wrapper for training
env = LogWrapper(env)
```
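The wrapper stacks differ mainly in the observations they return. A quick way to compare them is to reset each stack and inspect the result; a sketch, with the expected shapes taken from the wrapper reference below:

```python
import jax
import jaxatari
from jaxatari.wrappers import AtariWrapper, ObjectCentricWrapper, PixelObsWrapper

rng = jax.random.PRNGKey(0)

oc_env = ObjectCentricWrapper(AtariWrapper(jaxatari.make("pong")))
obs, _ = oc_env.reset(rng)
print(obs.shape)  # 2D: [frame_stack, features]

px_env = PixelObsWrapper(AtariWrapper(jaxatari.make("pong")))
obs, _ = px_env.reset(rng)
print(obs.shape)  # 4D: [frame_stack, height, width, channels]
```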
```python
import jax
import jaxatari
from jaxatari.wrappers import (
    AtariWrapper,
    FlattenObservationWrapper,
    ObjectCentricWrapper,
)

# Create environment with wrappers
base_env = jaxatari.make("pong")
env = FlattenObservationWrapper(ObjectCentricWrapper(AtariWrapper(base_env)))
rng = jax.random.PRNGKey(0)

# Vectorized reset and step functions
vmap_reset = lambda n_envs: lambda rng: jax.vmap(env.reset)(
    jax.random.split(rng, n_envs)
)
vmap_step = lambda n_envs: lambda env_state, action: jax.vmap(env.step)(
    env_state, action
)

# Initialize 128 parallel environments
init_obs, env_state = vmap_reset(128)(rng)
action = jax.random.randint(rng, (128,), 0, env.action_space().n)

# Take one step
new_obs, new_env_state, reward, done, info = vmap_step(128)(env_state, action)

# Take 100 steps with scan (reusing the same action batch at every step)
def step_fn(carry, unused):
    _, env_state = carry
    new_obs, new_env_state, reward, done, info = vmap_step(128)(env_state, action)
    return (new_obs, new_env_state), (reward, done, info)

carry = (init_obs, env_state)
_, (rewards, dones, infos) = jax.lax.scan(step_fn, carry, None, length=100)
```
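Because the rollout is pure JAX, it can itself be JIT-compiled. A minimal sketch (not library API), reusing `env` and `vmap_step` from above and drawing a fresh batch of random actions each step:

```python
@jax.jit
def rollout(rng, env_state):
    def step_fn(carry, _):
        rng, env_state = carry
        rng, action_key = jax.random.split(rng)
        # Sample one random action per parallel environment
        actions = jax.random.randint(action_key, (128,), 0, env.action_space().n)
        _, env_state, reward, done, _ = vmap_step(128)(env_state, actions)
        return (rng, env_state), (reward, done)

    (_, env_state), (rewards, dones) = jax.lax.scan(
        step_fn, (rng, env_state), None, length=100
    )
    return env_state, rewards, dones

final_state, rewards, dones = rollout(rng, env_state)
```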
Run a game manually with human input (e.g. on Pong):
```bash
pip install pygame
python3 scripts/play.py -g Pong
```
| Game | Supported |
|---|---|
| Freeway | ✅ |
| Kangaroo | ✅ |
| Pong | ✅ |
| Seaquest | ✅ |
More games can be added via the uniform wrapper system.
JAXAtari provides several wrappers to customize environment behavior:
- `AtariWrapper`: Base wrapper with frame stacking, frame skipping, and sticky actions
- `ObjectCentricWrapper`: Returns flattened object-centric features (2D array: `[frame_stack, features]`)
- `PixelObsWrapper`: Returns pixel observations (4D array: `[frame_stack, height, width, channels]`)
- `PixelAndObjectCentricWrapper`: Returns both pixel and object-centric observations
- `FlattenObservationWrapper`: Flattens any observation structure to a single 1D array
- `LogWrapper`: Tracks episode returns and lengths for training
- `MultiRewardLogWrapper`: Tracks multiple reward components separately
```python
# For pure RL with object-centric features (recommended)
env = ObjectCentricWrapper(AtariWrapper(jaxatari.make("pong")))

# For computer vision approaches
env = PixelObsWrapper(AtariWrapper(jaxatari.make("pong")))

# For multi-modal approaches
env = PixelAndObjectCentricWrapper(AtariWrapper(jaxatari.make("pong")))

# For training with logging
env = LogWrapper(ObjectCentricWrapper(AtariWrapper(jaxatari.make("pong"))))

# All wrapper combinations can be flattened using the FlattenObservationWrapper
```
Contributions are welcome!
- Fork this repository
- Create your feature branch: `git checkout -b feature/my-feature`
- Commit your changes: `git commit -m 'Add some feature'`
- Push to the branch: `git push origin feature/my-feature`
- Open a pull request
This project is licensed under the MIT License.
See the LICENSE file for details.