🚀 Feature
Group Relative Policy Optimization (GRPO) is a reinforcement learning algorithm introduced in the DeepSeekMath paper (https://arxiv.org/pdf/2402.03300) and has gained a lot of attention since its use in fine-tuning DeepSeek-R1. GRPO is an extension of PPO with a similar clipped objective. The main difference is that GRPO eliminates the need to train a value model, reducing both memory and compute requirements. Instead, in each training iteration, GRPO samples a group of trajectories sharing the same initial state and uses the group's average reward as the baseline in place of the value estimate.
DeepSeekMath introduced two variants of GRPO:
- Outcome Supervision, which assigns a single scalar reward to the final output of a trajectory. The advantage of each trajectory is its trajectory reward normalized by the group's mean trajectory reward.
  Note: While this variant can be applied directly to classic Gym environments, it doesn't make use of the per-step rewards returned after each action, leading to sparse rewards and slower learning.
- Process Supervision, which uses per-step advantages. Each per-step reward is first normalized by the mean per-step reward across the whole group; then, for each step, the advantage is the return-to-go of these normalized per-step rewards. (Both advantage computations are sketched below.)
  Note: The per-step normalization poses the following issue: in environments that assign a fixed reward for every step (e.g., CartPole's +1 per step), the mean reward across all steps equals every single per-step reward, so all normalized rewards, and thus all advantages, are 0.
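For concreteness, here is a minimal NumPy sketch of how I understand the two advantage computations above. The function and variable names are illustrative and do not correspond to the linked implementation; following the DeepSeekMath paper, I also divide by the group standard deviation when normalizing.

```python
import numpy as np

# Illustrative sketch of the two DeepSeek-style advantage computations described above.
# `group_rewards` is a list of 1-D per-step reward arrays, one per trajectory in the group;
# names and shapes are illustrative, not taken from the linked implementation.

def outcome_supervision_advantages(group_rewards, eps=1e-8):
    """One scalar reward per trajectory (here its undiscounted return),
    normalized by the group mean (and std, as in the DeepSeekMath paper)."""
    returns = np.array([r.sum() for r in group_rewards])
    normalized = (returns - returns.mean()) / (returns.std() + eps)
    # Every step of a trajectory shares the same advantage.
    return [np.full(len(r), adv) for r, adv in zip(group_rewards, normalized)]

def process_supervision_advantages(group_rewards, eps=1e-8):
    """Per-step rewards normalized over all steps in the group, then summed
    into a return-to-go per step (undiscounted, as in the paper)."""
    all_steps = np.concatenate(group_rewards)
    mean, std = all_steps.mean(), all_steps.std() + eps
    advantages = []
    for r in group_rewards:
        normalized = (r - mean) / std
        # Return-to-go = reversed cumulative sum of the normalized per-step rewards.
        advantages.append(np.cumsum(normalized[::-1])[::-1])
    return advantages
```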
To overcome this issue with DeepSeek's process supervision, another process supervision variant has been proposed at https://www.kaggle.com/discussions/general/573162. This variant first computes the return-to-go for each time step from the raw per-step rewards. Then, for each step, the return-to-go is normalized by the average return-to-go across the whole group to obtain the per-step advantage.
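A minimal sketch of one possible reading of this alternative variant, normalizing against the return-to-go statistics pooled over all steps of the group (names again illustrative):

```python
import numpy as np

def rtg_normalized_advantages(group_rewards, eps=1e-8):
    """Alternative process supervision (one possible reading): compute the
    return-to-go from the raw per-step rewards first, then normalize it by
    the average return-to-go across the whole group."""
    rtgs = [np.cumsum(r[::-1])[::-1] for r in group_rewards]
    pooled = np.concatenate(rtgs)
    mean, std = pooled.mean(), pooled.std() + eps
    return [(rtg - mean) / std for rtg in rtgs]
```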
Probably not all of these variants can or should be added to SB3-contrib; which ones to contribute is open for discussion.
Motivation
I implemented the GRPO algorithm out of curiosity, since it didn't already exist in SB3. While originally developed for LLM fine-tuning, I hypothesize that the algorithm isn't limited to this type of task and should work on classical reinforcement learning tasks as well. So I wanted to apply the different variants of GRPO to Gym environments and compare their performance to other RL baselines.
I know that there is already a similar open issue, #273, but the algorithm suggested there doesn't strictly implement GRPO; it is rather a hybrid closer to a multi-sample PPO approach. Therefore, I implemented a proof-of-concept version myself and opened this new issue.
Pitch
A first implementation is in my forked repository on the 'contrib-grpo' branch: https://github.com/floribau/stable-baselines3-contrib-grpo/tree/contrib-grpo. The repo contains the GRPO class in https://github.com/floribau/stable-baselines3-contrib-grpo/blob/contrib-grpo/sb3_contrib/grpo/grpo.py; the three variants described above are realized by initializing the GRPO class with different GroupBuffer classes (https://github.com/floribau/stable-baselines3-contrib-grpo/blob/contrib-grpo/sb3_contrib/common/buffers.py).
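To illustrate the intended interface, here is a purely hypothetical usage sketch; the class names, the buffer-class argument, and the group_size parameter are illustrative and may not match the actual API in the fork.

```python
# Hypothetical usage sketch only: class names and constructor arguments are
# illustrative and may not match the fork's actual API.
from sb3_contrib import GRPO                                # assuming GRPO is exported at package level
from sb3_contrib.common.buffers import ProcessGroupBuffer   # illustrative buffer class name

model = GRPO(
    "MlpPolicy",
    "CartPole-v1",
    group_size=8,                           # number of trajectories sampled per group (illustrative)
    group_buffer_class=ProcessGroupBuffer,  # swap the buffer class to switch GRPO variants
    verbose=1,
)
model.learn(total_timesteps=100_000)
```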
However, this should be understood as a proof of concept only and is by no means ready for a PR (no tests, not enough evaluation on benchmarks). Before working on it further to make it PR-ready, I want to discuss the general algorithm and its suitability for SB3 here. Since my code structure isn't optimal either, I'm also open to suggestions for restructuring it. After the discussion, I'll go ahead and implement a version that follows the SB3 standards more closely.
Alternatives
To my knowledge, SB3 doesn't include any multi-sample algorithm that constructs its baseline from a group average. I'm happy to discuss alternative implementation approaches.
Additional context
As mentioned above, I'm aware of issue #273 and want to address some of the questions raised there:
- My proposed algorithm follows DeepSeek's formulation more closely than the hybrid approach; besides implementing their two variants, it only adds one additional variant on top.
- I acknowledge that GRPO was originally developed for LLM fine-tuning. My implementation uses seeds to sample multiple trajectories from the same initial state, so it works for all Gym envs: since the trajectories of a group should all start from the same initial state, I simply reset the environment with a seed that is fixed for one group rollout collection (see the sketch after this list). This avoids having to cumbersomely deepcopy env states.
- Since GRPO was originally developed for LLM fine-tuning, there aren't any performance baselines on classical RL tasks in the paper. However, I ran the different variants in small experiments on CartPole and compared them to PPO. I plan to compare them on more complex benchmarks and larger experiments. An initial small comparison (with only 4 runs on CartPole) showed promising results, with the alternative process supervision variant achieving performance similar to PPO, thus motivating further work.
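Here is the seed-based rollout trick as a minimal sketch in plain Gymnasium; it only illustrates the idea and is not the fork's actual rollout code (random actions stand in for the current policy).

```python
import gymnasium as gym

# Every trajectory in a group is started from the same initial state by
# resetting with the same fixed seed.
env = gym.make("CartPole-v1")
group_seed = 42   # fixed for the duration of one group rollout collection
group_size = 8    # illustrative group size

for _ in range(group_size):
    obs, info = env.reset(seed=group_seed)  # identical initial state for every trajectory
    done = False
    while not done:
        action = env.action_space.sample()  # placeholder for sampling from the policy
        obs, reward, terminated, truncated, info = env.step(action)
        done = terminated or truncated
```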
Checklist
- I have checked that there is no similar issue in the repo
- If I'm requesting a new feature, I have proposed alternatives