[Feature Request] Implement GRPO #301

@floribau

Description

🚀 Feature

Group Relative Policy Optimization (GRPO) is a reinforcement learning algorithm introduced in https://arxiv.org/pdf/2402.03300, which has gained a lot of attention following its use in fine-tuning DeepSeek-R1. GRPO is an extension of PPO with a similar clipped objective. The main difference is that GRPO eliminates the need for training a value model, reducing both memory and computational burden. Instead, for each training iteration, GRPO samples a group of trajectories sharing the same initial state, and then uses the group average reward as the baseline, replacing the value estimate baseline.
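
For reference, the objective keeps PPO's clipped form; only the advantage estimate changes. The following is my paraphrase in generic RL notation, omitting the paper's KL-penalty term:

```latex
% PPO-style clipped surrogate over a group of G trajectories \tau_1..\tau_G that share the
% same initial state; \hat{A}_{i,t} is computed from group-relative rewards, not a value net.
J(\theta) = \mathbb{E}\left[ \frac{1}{G} \sum_{i=1}^{G} \frac{1}{|\tau_i|} \sum_{t=1}^{|\tau_i|}
    \min\!\Big( \rho_{i,t}\,\hat{A}_{i,t},\;
    \operatorname{clip}\!\big(\rho_{i,t},\, 1-\epsilon,\, 1+\epsilon\big)\,\hat{A}_{i,t} \Big) \right],
\qquad
\rho_{i,t} = \frac{\pi_\theta(a_{i,t}\mid s_{i,t})}{\pi_{\theta_{\mathrm{old}}}(a_{i,t}\mid s_{i,t})}
```

The variants below differ only in how the group-relative advantage estimate is computed.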

DeepSeekMath introduced two variants of GRPO:

  • Outcome Supervision, which assigns a single scalar reward to the final output of a trajectory. The advantage of each trajectory (shared by all of its steps) is the trajectory reward normalized by the group's mean and standard deviation of trajectory rewards (see the NumPy sketch after this list).
    Note: While this variant can be directly applied to classic Gym environments, it doesn't make use of the per-step rewards returned after each action, thus leading to sparse rewards and slower learning.
  • Process Supervision, which uses per-step advantages. Each per-step reward is first normalized by the mean (and standard deviation) of all per-step rewards across the whole group. Then, for each step, the advantage is the return-to-go of these normalized per-step rewards.
    Note: The per-step normalization poses the following issue: in environments that assign a fixed reward for each step (e.g., CartPole's +1 per step), the mean reward across all steps equals every single per-step reward, so all normalized rewards, and therefore all advantages, are 0.
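
For concreteness, here is a minimal NumPy sketch of the two advantage estimators as I understand them; the array layout (one float array of per-step rewards per trajectory in the group) and the function names are illustrative, not taken from the paper or my branch:

```python
import numpy as np

def outcome_supervision_advantages(rewards: list[np.ndarray]) -> list[np.ndarray]:
    """One advantage per trajectory: its normalized total reward, shared by all of its steps."""
    totals = np.array([r.sum() for r in rewards])             # scalar reward per trajectory
    norm = (totals - totals.mean()) / (totals.std() + 1e-8)   # group-relative normalization
    return [np.full_like(r, a) for r, a in zip(rewards, norm)]

def process_supervision_advantages(rewards: list[np.ndarray]) -> list[np.ndarray]:
    """Per-step advantages: normalize step rewards across the group, then take the return-to-go."""
    flat = np.concatenate(rewards)
    mean, std = flat.mean(), flat.std() + 1e-8
    advantages = []
    for r in rewards:
        r_norm = (r - mean) / std
        advantages.append(np.cumsum(r_norm[::-1])[::-1])      # undiscounted return-to-go
    return advantages
```

With constant per-step rewards (e.g., CartPole), `r_norm` is all zeros, which is exactly the degenerate case described in the note above.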

To overcome this issue with DeepSeek's process supervision, another process-supervision variant has been proposed at https://www.kaggle.com/discussions/general/573162. This variant first computes the return-to-go for each time step from the per-step rewards. Each return-to-go is then normalized by the average return-to-go across the whole group to obtain the per-step advantage.
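
Continuing the sketch above, this alternative variant swaps the order of the two operations (here I again normalize with mean and standard deviation; the exact normalization is one of the details I'd like to discuss):

```python
def rtg_process_supervision_advantages(rewards: list[np.ndarray]) -> list[np.ndarray]:
    """Alternative variant: compute each step's return-to-go first, then normalize across the group."""
    rtgs = [np.cumsum(r[::-1])[::-1] for r in rewards]        # undiscounted returns-to-go
    flat = np.concatenate(rtgs)
    mean, std = flat.mean(), flat.std() + 1e-8
    return [(g - mean) / std for g in rtgs]
```

Because returns-to-go differ across steps even when the per-step reward is constant, this variant avoids the all-zero advantages of DeepSeek's process supervision on environments like CartPole.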

Probably not all of these variants can or should be added to SB3-Contrib; which variant(s) to contribute is open for discussion from my side.

Motivation

I implemented the GRPO algorithm out of curiosity, since it doesn't already exist in SB3. While GRPO was originally developed for LLM fine-tuning, I hypothesize that the algorithm isn't limited to this type of task and should also work on classical reinforcement learning tasks. So I wanted to apply the different GRPO variants to Gym environments and compare their performance to other RL baselines.

I know there is already a similar open issue #273, but the algorithm suggested there doesn't strictly implement GRPO; it is rather a hybrid closer to a multi-sample PPO. Therefore, I implemented a proof-of-concept version myself and opened this new issue.

Pitch

A first implementation is in my forked repository on the 'contrib-grpo' branch: https://github.com/floribau/stable-baselines3-contrib-grpo/tree/contrib-grpo. The repo contains the GRPO class in https://github.com/floribau/stable-baselines3-contrib-grpo/blob/contrib-grpo/sb3_contrib/grpo/grpo.py; the three variants described above are selected by initializing the GRPO class with different GroupBuffer classes (https://github.com/floribau/stable-baselines3-contrib-grpo/blob/contrib-grpo/sb3_contrib/common/buffers.py).
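
To give an idea of the intended usage (the parameter and buffer names below are placeholders; the actual API on my branch may differ and is exactly what I'd like feedback on):

```python
# Hypothetical usage sketch; `group_buffer_class` and the buffer name are placeholders
# for whatever the final SB3-style API ends up being.
from sb3_contrib import GRPO
from sb3_contrib.common.buffers import ProcessSupervisionGroupBuffer

model = GRPO(
    "MlpPolicy",
    "CartPole-v1",
    group_buffer_class=ProcessSupervisionGroupBuffer,  # selects the GRPO variant
    verbose=1,
)
model.learn(total_timesteps=100_000)
```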

However, this should be understood as a proof of concept only and is by no means ready for a PR (no tests, not enough evaluation on benchmarks). Before working on it further, I'd like to discuss the general algorithm and its suitability for SB3-Contrib here. Since the code structure isn't optimal either, I'm also happy to hear suggestions on how to restructure it. After the discussion, I'll implement a version that follows the SB3 standards more closely.

Alternatives

To my knowledge, SB3 doesn't include any multi-sample algorithm that constructs its baseline from a group average. I'm happy to discuss alternatives in the implementation details.

Additional context

As mentioned above, I know of the existence of issue #273, and want to address some of the questions raised there:

  • My proposed algorithm follows DeepSeek's formulation more strictly than the hybrid approach: it implements their two variants and only adds one additional variant on top.
  • I acknowledge that GRPO was originally developed for LLM fine-tuning. My implementation uses environment seeds to sample multiple trajectories from the same initial state, so it works for any Gym env whose reset is seed-deterministic: I simply reset the environment with a seed that is kept fixed for a group's rollout collection, so there's no need to cumbersomely deepcopy env states (see the sketch below).
  • Since GRPO was originally developed for LLM fine-tuning, there aren't any performance baselines on classical RL tasks in the paper. However, I ran the different variants in small experiments on CartPole and compared them to PPO. I plan to compare them on more complex benchmarks and larger experiments. An initial small comparison (with only 4 runs on CartPole) showed promising results, with the alternative process supervision variant achieving performance similar to PPO, thus motivating further work.
[Figure: learning curves of the GRPO variants compared to PPO on CartPole]
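
For illustration, the seeding trick boils down to something like the following plain-Gymnasium loop (simplified and detached from my actual rollout code):

```python
import gymnasium as gym
import numpy as np

env = gym.make("CartPole-v1")
group_size = 8
group_seed = 42  # fixed for the whole group rollout, re-drawn for the next group

group_rewards = []
for _ in range(group_size):
    # Resetting with the same seed reproduces the same initial state for every group member,
    # so no deepcopy of the environment state is needed.
    obs, _ = env.reset(seed=group_seed)
    rewards, done = [], False
    while not done:
        action = env.action_space.sample()  # stand-in for the current policy
        obs, reward, terminated, truncated, _ = env.step(action)
        rewards.append(reward)
        done = terminated or truncated
    group_rewards.append(np.array(rewards))
```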

Checklist

  • I have checked that there is no similar issue in the repo
  • If I'm requesting a new feature, I have proposed alternatives
