
Commit 68828f3

danielpalen and araffin authored
Implemented CrossQ (#243)
* Implemented CrossQ
* Fixed code style
* Clean up, comments and refactored to sbx variable names
* 1024 neuron Q function (sbx default)
* Batch norm parameters as function arguments
* Clean up, reshape instead of split
* Added policy delay
* Fixed commit-checks
* Fix f-string
* Update documentation
* Rename to torch layers
* Fix for policy delay and minor edits
* Update tests
* Update documentation
* Update doc
* Add more tests for CrossQ
* Improve doc and expose batchnorm params
* Add some comments and todos and fix type check
* Use torch module for BN
* Re-organize losses
* Add set_bn_training_mode
* Simplify network creation with new SB3 version, and fix default momentum
* Use different b1 for Adam as in original implementation
* Reformat TOML file
* Update CI workflow, skip mypy for 3.8
* Update CrossQ doc
* Use uv to download packages on GitHub CI
* System install for GitHub CI
* Fix for PyTorch install
* Use +cpu version
* PyTorch 2.5.0 doesn't support Python 3.8
* Update comments

---------

Co-authored-by: Antonin Raffin <[email protected]>
1 parent 3d9a975 commit 68828f3

File tree

20 files changed (+1221 / -28 lines)


.github/workflows/ci.yml

Lines changed: 10 additions & 5 deletions
@@ -30,21 +30,24 @@ jobs:
       - name: Install dependencies
         run: |
           python -m pip install --upgrade pip
+          # Use uv for faster downloads
+          pip install uv
           # cpu version of pytorch
-          pip install torch==2.1.1 --index-url https://download.pytorch.org/whl/cpu
+          # See https://github.com/astral-sh/uv/issues/1497
+          uv pip install --system torch==2.4.1+cpu --index https://download.pytorch.org/whl/cpu

           # Install Atari Roms
-          pip install autorom
+          uv pip install --system autorom
           wget https://gist.githubusercontent.com/jjshoots/61b22aefce4456920ba99f2c36906eda/raw/00046ac3403768bfe45857610a3d333b8e35e026/Roms.tar.gz.b64
           base64 Roms.tar.gz.b64 --decode &> Roms.tar.gz
           AutoROM --accept-license --source-file Roms.tar.gz

           # Install master version
           # and dependencies for docs and tests
-          pip install "stable_baselines3[extra_no_roms,tests,docs] @ git+https://github.com/DLR-RM/stable-baselines3"
-          pip install .
+          uv pip install --system "stable_baselines3[extra_no_roms,tests,docs] @ git+https://github.com/DLR-RM/stable-baselines3"
+          uv pip install --system .
           # Use headless version
-          pip install opencv-python-headless
+          uv pip install --system opencv-python-headless

       - name: Lint with ruff
         run: |
@@ -58,6 +61,8 @@ jobs:
       - name: Type check
         run: |
           make type
+        # Do not run for python 3.8 (mypy internal error)
+        if: matrix.python-version != '3.8'
       - name: Test with pytest
         run: |
           make pytest

README.md

Lines changed: 1 addition & 0 deletions
@@ -31,6 +31,7 @@ See documentation for the full list of included features.
 - [PPO with recurrent policy (RecurrentPPO aka PPO LSTM)](https://ppo-details.cleanrl.dev//2021/11/05/ppo-implementation-details/)
 - [Truncated Quantile Critics (TQC)](https://arxiv.org/abs/2005.04269)
 - [Trust Region Policy Optimization (TRPO)](https://arxiv.org/abs/1502.05477)
+- [Batch Normalization in Deep Reinforcement Learning (CrossQ)](https://openreview.net/forum?id=PczQtTsTIX)

 **Gym Wrappers**:
 - [Time Feature Wrapper](https://arxiv.org/abs/1712.00378)

docs/common/torch_layers.rst

Lines changed: 7 additions & 0 deletions
@@ -0,0 +1,7 @@
+.. _th_layers:
+
+Torch Layers
+============
+
+.. automodule:: sb3_contrib.common.torch_layers
+    :members:
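
The new module documents, among others, the ``BatchRenorm`` layer that this commit adds for CrossQ. For orientation, here is a minimal PyTorch sketch of the Batch Renormalization idea (Ioffe, 2017). It is illustrative only: the class name ``BatchRenorm1d``, the default ``momentum``, and the fixed ``r_max``/``d_max`` bounds are assumptions, and the actual ``sb3_contrib.common.torch_layers`` implementation (e.g. a warm-up schedule for the bounds) may differ.

import torch
import torch.nn as nn

class BatchRenorm1d(nn.Module):
    """Illustrative Batch Renormalization layer (hypothetical simplification)."""

    def __init__(self, num_features: int, momentum: float = 0.01,
                 eps: float = 1e-5, r_max: float = 3.0, d_max: float = 5.0):
        super().__init__()
        self.momentum, self.eps = momentum, eps
        self.r_max, self.d_max = r_max, d_max
        self.weight = nn.Parameter(torch.ones(num_features))
        self.bias = nn.Parameter(torch.zeros(num_features))
        self.register_buffer("running_mean", torch.zeros(num_features))
        self.register_buffer("running_var", torch.ones(num_features))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.training:
            batch_mean = x.mean(dim=0)
            batch_var = x.var(dim=0, unbiased=False)
            batch_std = (batch_var + self.eps).sqrt()
            running_std = (self.running_var + self.eps).sqrt()
            # r and d pull the batch statistics towards the running
            # statistics; both are treated as constants (no gradient).
            r = (batch_std / running_std).detach().clamp(1.0 / self.r_max, self.r_max)
            d = ((batch_mean - self.running_mean) / running_std).detach()
            d = d.clamp(-self.d_max, self.d_max)
            x_hat = (x - batch_mean) / batch_std * r + d
            # Exponential moving average of the batch statistics.
            self.running_mean += self.momentum * (batch_mean.detach() - self.running_mean)
            self.running_var += self.momentum * (batch_var.detach() - self.running_var)
        else:
            # Inference: plain normalization with the running statistics.
            x_hat = (x - self.running_mean) / (self.running_var + self.eps).sqrt()
        return self.weight * x_hat + self.bias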

docs/guide/algos.rst

Lines changed: 1 addition & 0 deletions
@@ -10,6 +10,7 @@ Name ``Box`` ``Discrete`` ``MultiDiscrete`` ``MultiBinary`` Multi Pr
 ============ =========== ============ ================= =============== ================
 ARS          ✔️          ❌️           ❌                ❌              ✔️
 MaskablePPO  ❌          ✔️           ✔️                ✔️              ✔️
+CrossQ       ✔️          ❌           ❌                ❌              ✔️
 QR-DQN       ️❌          ️✔️           ❌                ❌              ✔️
 RecurrentPPO ✔️          ✔️           ✔️                ✔️              ✔️
 TQC          ✔️          ❌           ❌                ❌              ✔️

docs/guide/examples.rst

Lines changed: 23 additions & 0 deletions
@@ -113,3 +113,26 @@ Train a PPO agent with a recurrent policy on the CartPole environment.
         obs, rewards, dones, info = vec_env.step(action)
         episode_starts = dones
         vec_env.render("human")
+
+CrossQ
+------
+
+Train a CrossQ agent on the Pendulum environment.
+
+.. code-block:: python
+
+    from sb3_contrib import CrossQ
+
+    model = CrossQ(
+        "MlpPolicy",
+        "Pendulum-v1",
+        verbose=1,
+        policy_kwargs=dict(
+            net_arch=dict(
+                pi=[256, 256],
+                qf=[1024, 1024],
+            )
+        ),
+    )
+    model.learn(total_timesteps=5_000, log_interval=4)
+    model.save("crossq_pendulum")
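
Since ``CrossQ`` follows the standard SB3 off-policy API, the saved agent should reload and run like any other SB3 model. A short usage sketch (the ``gymnasium`` rollout loop is an assumption, not part of the diff):

import gymnasium as gym
from sb3_contrib import CrossQ

# Reload the agent saved above and run it deterministically.
model = CrossQ.load("crossq_pendulum")
env = gym.make("Pendulum-v1")

obs, _ = env.reset()
for _ in range(200):
    action, _ = model.predict(obs, deterministic=True)
    obs, reward, terminated, truncated, _ = env.step(action)
    if terminated or truncated:
        obs, _ = env.reset()
env.close()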

docs/images/crossQ_performance.png

Binary file added (260 KB)

docs/index.rst

Lines changed: 2 additions & 0 deletions
@@ -32,6 +32,7 @@ RL Baselines3 Zoo also offers a simple interface to train, evaluate agents and d
   :caption: RL Algorithms

   modules/ars
+  modules/crossq
   modules/ppo_mask
   modules/ppo_recurrent
   modules/qrdqn
@@ -42,6 +43,7 @@ RL Baselines3 Zoo also offers a simple interface to train, evaluate agents and d
   :maxdepth: 1
   :caption: Common

+  common/torch_layers
   common/utils
   common/wrappers

docs/misc/changelog.rst

Lines changed: 7 additions & 3 deletions
@@ -3,16 +3,19 @@
 Changelog
 ==========

-
-Release 2.4.0a9 (WIP)
+Release 2.4.0a10 (WIP)
 --------------------------

+**New algorithm: added CrossQ**
+
 Breaking Changes:
 ^^^^^^^^^^^^^^^^^
 - Upgraded to Stable-Baselines3 >= 2.4.0

 New Features:
 ^^^^^^^^^^^^^
+- Added ``CrossQ`` algorithm, from "Batch Normalization in Deep Reinforcement Learning" paper (@danielpalen)
+- Added ``BatchRenorm`` PyTorch layer used in ``CrossQ`` (@danielpalen)

 Bug Fixes:
 ^^^^^^^^^^
@@ -28,6 +31,7 @@ Others:
 ^^^^^^^
 - Updated PyTorch version on CI to 2.3.1
 - Remove unnecessary SDE noise resampling in PPO/TRPO update
+- Switched to uv to download packages on GitHub CI

 Documentation:
 ^^^^^^^^^^^^^^
@@ -584,4 +588,4 @@ Contributors:
 -------------

 @ku2482 @guyk1971 @minhlong94 @ayeright @kronion @glmcdona @cyprienc @sgillen @Gregwar @rnederstigt @qgallouedec
-@mlodel @CppMaster @burakdmb @honglu2875 @ZikangXiong @AlexPasqua @jonasreiher @icheered @Armandpl @corentinlger
+@mlodel @CppMaster @burakdmb @honglu2875 @ZikangXiong @AlexPasqua @jonasreiher @icheered @Armandpl @danielpalen @corentinlger

docs/modules/crossq.rst

Lines changed: 134 additions & 0 deletions
@@ -0,0 +1,134 @@
+.. _crossq:
+
+.. automodule:: sb3_contrib.crossq
+
+
+CrossQ
+======
+
+Implementation of CrossQ proposed in:
+
+`Bhatt A.* & Palenicek D.* et al. Batch Normalization in Deep Reinforcement Learning for Greater Sample Efficiency and Simplicity. ICLR 2024.`
+
+CrossQ is an algorithm that uses batch normalization to improve the sample efficiency of off-policy deep reinforcement learning algorithms.
+It is based on the idea of carefully introducing batch normalization layers in the critic network and dropping target networks.
+This results in a simpler and more sample-efficient algorithm without requiring high update-to-data ratios.
+
+.. rubric:: Available Policies
+
+.. autosummary::
+    :nosignatures:
+
+    MlpPolicy
+
+.. note::
+
+    Compared to the original implementation, the default network architecture for the Q-value function is ``[1024, 1024]``
+    instead of ``[2048, 2048]``, as it provides a good compromise between speed and performance.
+
+.. note::
+
+    There is currently no ``CnnPolicy`` for using CrossQ with images. We welcome help from contributors to add this feature.
+
+
+Notes
+-----
+
+- Original paper: https://openreview.net/pdf?id=PczQtTsTIX
+- Original implementation: https://github.com/adityab/CrossQ
+- SBX (SB3 Jax) implementation: https://github.com/araffin/sbx
+
+
+Can I use?
+----------
+
+- Recurrent policies: ❌
+- Multi processing: ✔️
+- Gym spaces:
+
+
+============= ====== ===========
+Space         Action Observation
+============= ====== ===========
+Discrete      ❌      ✔️
+Box           ✔️      ✔️
+MultiDiscrete ❌      ✔️
+MultiBinary   ❌      ✔️
+Dict          ❌      ❌
+============= ====== ===========
+
+
+Example
+-------
+
+.. code-block:: python
+
+    from sb3_contrib import CrossQ
+
+    model = CrossQ("MlpPolicy", "Walker2d-v4")
+    model.learn(total_timesteps=1_000_000)
+    model.save("crossq_walker")
+
+
+Results
+-------
+
+Performance evaluation of CrossQ on six MuJoCo environments (see `PR #243 <https://github.com/Stable-Baselines-Team/stable-baselines3-contrib/pull/243>`_),
+compared to results from the original paper as well as a version from `SBX <https://github.com/araffin/sbx>`_.
+
+.. image:: ../images/crossQ_performance.png
+
+
+Open RL benchmark report: https://wandb.ai/openrlbenchmark/sb3-contrib/reports/SB3-Contrib-CrossQ--Vmlldzo4NTE2MTEx
+
+
+How to replicate the results?
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+Clone RL-Zoo:
+
+.. code-block:: bash
+
+    git clone https://github.com/DLR-RM/rl-baselines3-zoo
+    cd rl-baselines3-zoo/
+
+Run the benchmark (replace ``$ENV_ID`` by the envs mentioned above):
+
+.. code-block:: bash
+
+    python train.py --algo crossq --env $ENV_ID --n-eval-envs 5 --eval-episodes 20 --eval-freq 25000
+
+
+Plot the results:
+
+.. code-block:: bash
+
+    python scripts/all_plots.py -a crossq -e HalfCheetah Ant Hopper Walker2D -f logs/ -o logs/crossq_results
+    python scripts/plot_from_file.py -i logs/crossq_results.pkl -latex -l CrossQ
+
+
+Comments
+--------
+
+This implementation is based on the SB3 SAC implementation.
+
+
+Parameters
+----------
+
+.. autoclass:: CrossQ
+    :members:
+    :inherited-members:
+
+.. _crossq_policies:
+
+CrossQ Policies
+---------------
+
+.. autoclass:: MlpPolicy
+    :members:
+    :inherited-members:
+
+.. autoclass:: sb3_contrib.crossq.policies.CrossQPolicy
+    :members:
+    :noindex:
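
The core mechanism the new docs describe (batch normalization in the critic plus removal of target networks) boils down to one joint forward pass of current and next state-action pairs. A hedged PyTorch sketch follows; this is not the actual sb3_contrib code: ``critic``, ``actor``, and the ``batch`` attributes are illustrative assumptions, and the SAC-style entropy term is omitted for brevity.

import torch
import torch.nn.functional as F

def crossq_critic_loss(critic, actor, batch, gamma: float = 0.99):
    # Next actions come from the CURRENT policy; there is no target network.
    with torch.no_grad():
        next_actions = actor(batch.next_observations)

    # The CrossQ trick: run current and next state-action pairs through
    # the critic as ONE concatenated batch, so its (batch re)normalization
    # layers see both distributions jointly.
    all_obs = torch.cat([batch.observations, batch.next_observations], dim=0)
    all_actions = torch.cat([batch.actions, next_actions], dim=0)
    all_q_values = critic(all_obs, all_actions)
    q_values, next_q_values = torch.chunk(all_q_values, 2, dim=0)

    # Only the current half receives gradients; detaching the next-state
    # half plays the stabilizing role the target network used to play.
    target_q = batch.rewards + gamma * (1.0 - batch.dones) * next_q_values.detach()
    return F.mse_loss(q_values, target_q)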

docs/modules/ppo_recurrent.rst

Lines changed: 0 additions & 1 deletion
@@ -109,7 +109,6 @@ Clone the repo for the experiment:

   git clone https://github.com/DLR-RM/rl-baselines3-zoo
   cd rl-baselines3-zoo
-  git checkout feat/recurrent-ppo


 Run the benchmark (replace ``$ENV_ID`` by the envs mentioned above):
