.. automodule:: sb3_contrib.crossq
Implementation of CrossQ proposed in:
Bhatt A.* & Palenicek D.* et al. Batch Normalization in Deep Reinforcement Learning for Greater Sample Efficiency and Simplicity. ICLR 2024.
CrossQ is an algorithm that uses batch normalization to improve the sample efficiency of off-policy deep reinforcement learning algorithms. It is based on the idea of carefully introducing batch normalization layers in the critic network and dropping target networks. This results in a simpler and more sample-efficient algorithm without requiring high update-to-data ratios.
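
The core idea above can be sketched in PyTorch. This is a simplified, hypothetical illustration, not the sb3_contrib implementation. ``BatchNormCritic`` and ``crossq_critic_loss`` are made-up names. The sketch uses a single critic with plain ``nn.BatchNorm1d`` layers and omits SAC's entropy term and twin-critic minimum. It shows two points. First, the current and next state-action pairs go through the critic in one joint forward pass, so the batch normalization statistics cover both distributions. Second, the bootstrap target comes from that same pass with gradients blocked, not from a separate target network.

.. code-block:: python

    import torch
    import torch.nn as nn
    import torch.nn.functional as F


    class BatchNormCritic(nn.Module):
        """Hypothetical Q-network with BatchNorm layers (illustrative only)."""

        def __init__(self, obs_dim: int, act_dim: int, hidden: int = 64):
            super().__init__()
            self.net = nn.Sequential(
                nn.BatchNorm1d(obs_dim + act_dim),
                nn.Linear(obs_dim + act_dim, hidden),
                nn.BatchNorm1d(hidden),
                nn.ReLU(),
                nn.Linear(hidden, hidden),
                nn.BatchNorm1d(hidden),
                nn.ReLU(),
                nn.Linear(hidden, 1),
            )

        def forward(self, obs: torch.Tensor, act: torch.Tensor) -> torch.Tensor:
            return self.net(torch.cat([obs, act], dim=1))


    def crossq_critic_loss(critic, obs, act, reward, next_obs, next_act, done, gamma=0.99):
        """Simplified CrossQ-style TD loss: one critic, no entropy term, no target network."""
        batch_size = obs.shape[0]
        # Joint forward pass: (s, a) and (s', a') are concatenated so that the
        # BatchNorm statistics are computed over both batches at once.
        all_q = critic(torch.cat([obs, next_obs], dim=0), torch.cat([act, next_act], dim=0))
        q, next_q = torch.split(all_q, batch_size, dim=0)
        # No target network: bootstrap from the same pass, with gradients blocked.
        target = reward + gamma * (1.0 - done) * next_q.detach()
        return F.mse_loss(q, target)


    # Usage with random data (only the shapes are meaningful here).
    torch.manual_seed(0)
    batch, obs_dim, act_dim = 32, 3, 1
    critic = BatchNormCritic(obs_dim, act_dim)
    critic.train()  # BatchNorm normalizes with batch statistics in training mode
    loss = crossq_critic_loss(
        critic,
        obs=torch.randn(batch, obs_dim),
        act=torch.rand(batch, act_dim) * 2 - 1,
        reward=torch.randn(batch, 1),
        next_obs=torch.randn(batch, obs_dim),
        next_act=torch.rand(batch, act_dim) * 2 - 1,  # in SAC/CrossQ this is sampled from the current policy
        done=torch.zeros(batch, 1),
    )
    loss.backward()

Because there is a single forward pass, each BatchNorm layer updates its running statistics once per training step, using the mixture of current and next state-action pairs.
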

.. rubric:: Available Policies

.. autosummary::
    :nosignatures:

    MlpPolicy


.. note::

    Compared to the original implementation, the default network architecture for the Q-value function is ``[1024, 1024]``
    instead of ``[2048, 2048]``, as it provides a good compromise between speed and performance.

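
To use the wider critic from the paper, the default can presumably be overridden through ``policy_kwargs``. The sketch below assumes that ``net_arch`` accepts a dict with separate ``pi`` (actor) and ``qf`` (critic) entries, as it does for SB3's SAC, and that ``[256, 256]`` is the actor's default width. ``Pendulum-v1`` is used only because it does not require a MuJoCo install.

.. code-block:: python

    from sb3_contrib import CrossQ

    # Assumed dict form: "pi" sets the actor layers, "qf" the critic layers.
    # [2048, 2048] is the critic width used in the original paper.
    model = CrossQ(
        "MlpPolicy",
        "Pendulum-v1",
        policy_kwargs=dict(net_arch=dict(pi=[256, 256], qf=[2048, 2048])),
        verbose=0,
    )

A wider critic increases the cost of every gradient step, which is the speed/performance trade-off behind the smaller default.
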

.. note::

    There is currently no ``CnnPolicy`` for using CrossQ with images. We welcome help from contributors to add this feature.

- Original paper: https://openreview.net/pdf?id=PczQtTsTIX
- Original Implementation: https://github.com/adityab/CrossQ
- SBX (SB3 Jax) Implementation: https://github.com/araffin/sbx
- Recurrent policies: ❌
- Multi processing: ✔️
- Gym spaces:

============= ====== ===========
Space         Action Observation
============= ====== ===========
Discrete      ❌      ✔️
Box           ✔️      ✔️
MultiDiscrete ❌      ✔️
MultiBinary   ❌      ✔️
Dict          ❌      ❌
============= ====== ===========

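
.. code-block:: python

  from sb3_contrib import CrossQ

  # Walker2d-v4 requires the MuJoCo extra of Gymnasium (gymnasium[mujoco])
  model = CrossQ("MlpPolicy", "Walker2d-v4")
  model.learn(total_timesteps=1_000_000)
  # Writes the trained model to crossq_walker.zip
  model.save("crossq_walker")
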
Performance evaluation of CrossQ on six MuJoCo environments (see PR #243), compared to results from the original paper and to the SBX version.
Open RL benchmark report: https://wandb.ai/openrlbenchmark/sb3-contrib/reports/SB3-Contrib-CrossQ--Vmlldzo4NTE2MTEx
Clone RL-Zoo:

.. code-block:: bash

    git clone https://github.com/DLR-RM/rl-baselines3-zoo
    cd rl-baselines3-zoo/

Run the benchmark (replace ``$ENV_ID`` with the ID of each environment to evaluate, e.g. ``Walker2d-v4``):

.. code-block:: bash

    python train.py --algo crossq --env $ENV_ID --n-eval-envs 5 --eval-episodes 20 --eval-freq 25000

Plot the results:

.. code-block:: bash

    python scripts/all_plots.py -a crossq -e HalfCheetah Ant Hopper Walker2D -f logs/ -o logs/crossq_results
    python scripts/plot_from_file.py -i logs/crossq_results.pkl -latex -l CrossQ

This implementation is based on the SB3 SAC implementation.

.. autoclass:: CrossQ
  :members:
  :inherited-members:

.. autoclass:: MlpPolicy
  :members:
  :inherited-members:

.. autoclass:: sb3_contrib.crossq.policies.CrossQPolicy
  :members:
  :noindex: