Abstract
This document proposes an asynchronous reinforcement learning (Async-RL) framework based on verl. The Async-RL framework aims to improve training efficiency by decoupling the tasks in the RL pipeline (separating actor-train, actor-forward-logp, ref_logp, and rollout-generate) and enabling asynchronous parameter synchronization. This allows bubble-free, non-blocking training and significantly improves the overall performance and scalability of RL training.
Motivation
Traditional synchronous RL training suffers from low efficiency due to significant idle time and resource underutilization.
The current verl platform lacks a unified Async-RL implementation, which motivates the design of an efficient Async-RL version.
Modifying the trainer's logic directly is complex, which led to the decision to decompose the pipeline into decoupled modules, separating logical tasks from physical resources to facilitate future extension and optimization.
Design Overview
1. RL State Machine
The Async-RL framework is designed around a state machine that manages the entire pipeline workflow. This state machine approach allows for flexible scheduling strategies and ensures that each task operates independently, reducing the risk of errors and improving both performance and precision.
The state machine base class defines the task-switching pattern: each logical task is both a producer and a consumer. The core idea is that each task focuses solely on its own logical inputs and outputs, while the actual processing is placed on a physical node via Ray.
When registering a new state machine, you only need to inherit the base state_machine class and implement the following three interfaces:
    @abstractmethod
    async def process_data(self, data: Any) -> Any:
        """Abstract method to process data, subclass must implement"""
        pass

    @abstractmethod
    async def get_input_data(self) -> Optional[Any]:
        """Abstract method to get input data, subclass must implement"""
        pass

    @abstractmethod
    async def send_output_data(self, data: Any) -> bool:
        """Abstract method to send output data, subclass must implement"""
        pass
The RL tasks required by GRPO can be divided into the following (a toy wiring sketch follows this list):
dataloader: Receives the epochs of the whole training run for prefetch iteration; the queue size is set so that the dataloader's own latency is overlapped.
generate: Receives the param_update signal and the prompt data from the dataloader, generates responses, and outputs them to rollout.
rollout: Logical task node that receives the results of dataloader and generate, aggregates and processes them, and outputs them to logp/ref_logp/reward/train.
reward: Logical task node (here a rule-based reward).
logp: The actor's forward pass. After decoupling, it can be given its own model-parallel (MP) layout, different from training.
ref_logp: The reference model's forward pass. After decoupling, it can likewise use its own MP layout.
actor-train: The main training task; it receives all rollout data (logp/ref_logp/reward) and triggers the asynchronous param-update process after training.
param-update: Split into gather/send/recv/register-buffer/load stages. The send and recv stages start once the actor-train end signal arrives. After register-buffer completes, the inference engine's load method is called back and the end signal is passed to generate.
critic-train: Not added yet, since the current implementation targets GRPO. In the PPO scenario the critic and actor are trained independently, and the critic part can be implemented symmetrically to the actor.
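As a toy illustration of the producer/consumer dataflow (pure asyncio, no Ray; only the dataloader, generate, and actor-train legs are shown, and all names are illustrative):

```python
import asyncio

async def dataloader(out_q: asyncio.Queue) -> None:
    for batch_id in range(4):                  # prefetch loop; the queue hides dataloader latency
        await out_q.put(f"prompts-{batch_id}")
    await out_q.put(None)                      # end-of-data signal

async def generate(in_q: asyncio.Queue, out_q: asyncio.Queue) -> None:
    while (prompts := await in_q.get()) is not None:
        await out_q.put({"prompts": prompts, "responses": f"gen({prompts})"})
    await out_q.put(None)

async def actor_train(in_q: asyncio.Queue) -> None:
    while (batch := await in_q.get()) is not None:   # logp/ref_logp/reward legs omitted
        print("training on", batch["prompts"])

async def main() -> None:
    prompt_q: asyncio.Queue = asyncio.Queue(maxsize=2)
    rollout_q: asyncio.Queue = asyncio.Queue(maxsize=2)
    await asyncio.gather(dataloader(prompt_q),
                         generate(prompt_q, rollout_q),
                         actor_train(rollout_q))

asyncio.run(main())
```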
2. Async-param-update
Design
To achieve true asynchrony, we shifted from synchronous parameter updates to an asynchronous implementation by progressively breaking down the parameter-update process. The original implementation used nccl-based parameter synchronization, but that approach cannot be made asynchronous because of nccl's thread-safety constraints and its preemption of the GPU. The process was therefore broken down into five stages: gather/send/recv/register-buffer/load. To reuse verl's existing logic, the gather implementation is kept as-is; it uses nccl for parameter aggregation and can only run serially. The subsequent asynchronous send/recv communication runs on the CPU, so it does not compete for GPU compute. This enables generate, param_update, and train to run asynchronously.
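To make the intended overlap concrete, here is a toy sketch (asyncio sleeps stand in for the real work; none of these function names come from verl): training and the serial nccl gather run in the foreground, while the CPU-side send/recv/register-buffer/load stages run as a background task, so the next training step starts immediately.

```python
import asyncio
from typing import Optional

async def cpu_side_update(version: int) -> None:
    """Stand-in for the CPU-side stages: send -> recv -> register-buffer -> load."""
    await asyncio.sleep(0.3)   # send/recv over CPU (e.g. an object transfer), no GPU involved
    await asyncio.sleep(0.1)   # register dual buffer, then the engine's load callback fires
    print(f"params v{version} are live in the inference engine")

async def actor_train_loop(num_steps: int) -> None:
    pending: Optional[asyncio.Task] = None
    for version in range(num_steps):
        await asyncio.sleep(0.5)   # stand-in for the GPU train step
        await asyncio.sleep(0.1)   # stand-in for the serial nccl gather (the only blocking part)
        if pending is not None:
            await pending          # avoid piling up more than one in-flight update
        # Everything after gather runs in the background; the next train step starts at once.
        pending = asyncio.create_task(cpu_side_update(version))
    if pending is not None:
        await pending

asyncio.run(actor_train_loop(3))
```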
Overlap effect of asynchronous parameter updates
param-update is a logical task that breaks down into five stages: gather, send, recv, register-buffer, and load.
After actor-train finishes a step, the parameter update is triggered and runs in the background.
Gather: To reuse verl's existing code path, gather keeps the per_tensor_generator implementation to aggregate parameters across the different MP ranks (this should eventually be optimized to sharded transfer to reduce transmission overhead). Gather involves synchronous torch-group communication across the MP dimensions (TP/PP/EP) and is constrained by nccl's thread-safety issues, so a bubble is unavoidable here; a per-tensor-bucket implementation reduces the per-tensor overhead.
Send: Sends all parameter buckets produced by gather. We use Ray object references here (ray.gloo ran into cross-machine support issues and has not been pursued further).
Recv: Receives all data sent by send and triggers the logical-update callback, namely the update_buffer_data_only function.
Register-buffer: The recv stage's callback serializes the current parameter objects for loading by the inference engine, saves them into the dual_buffer, and updates the buffer's version (which can be read as the number of completed model updates).
Load: Once communication completes, the callback is triggered automatically and the dual_buffer is read through sglang. The serialization overhead has already been overlapped in the register-buffer stage, and the load path also benefits from the bucket-granularity optimization. (A toy dual-buffer sketch follows this list.)
The resulting overlap effect: total_cost_time = max(train, generate, ref_logp) = max(train + gather, generate + load, ref_logp)
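For illustration, a minimal dual-buffer structure along these lines (class and field names are made up; in the real system the buffer holds serialized parameter buckets that sglang's load path reads):

```python
import threading
from typing import Any, List, Optional, Tuple

class DualBuffer:
    """Toy double buffer: the inference engine reads the 'ready' slot while the
    next parameter version is written into the other slot."""

    def __init__(self) -> None:
        self._slots: List[Any] = [None, None]
        self._ready_idx: Optional[int] = None    # slot the engine is allowed to read
        self._version: int = -1                  # number of completed model updates
        self._lock = threading.Lock()

    def register(self, version: int, serialized_params: Any) -> None:
        # Called from the recv-stage callback once all buckets for `version` have arrived.
        with self._lock:
            write_idx = 0 if self._ready_idx is None else 1 - self._ready_idx
            self._slots[write_idx] = serialized_params
            self._ready_idx, self._version = write_idx, version

    def read_latest(self) -> Tuple[int, Any]:
        # Called from the inference engine's load path.
        with self._lock:
            params = self._slots[self._ready_idx] if self._ready_idx is not None else None
            return self._version, params
```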
3. Training-bubble-free
RL workloads have long pipelines. In the synchronous setup, the various offloads and param-updates along the way are all large bubbles, and the GPU sits idle much of the time. By modularizing the pipeline as state machines and breaking parameter synchronization into stages, we can build a fully asynchronous async-RL pipeline whose training nodes are bubble-free, which significantly alleviates the low-MFU problem in RL training.
For example, set each task's share of the cluster resources appropriately: +trainer.use_nodes_ratios=[0.5,0.5,0.5,0.5]
This allows flexible workload allocation, so that the end-to-end e2e_cost_time approximates the pure actor training time, i.e. bubble-free execution.
Long RL chains have many potential bottlenecks: if any task is blocked (for example, the long-tail problem leaves the generate task pending), all other GPUs sit idle. By decoupling the training and rollout tasks in a fully asynchronous manner and allowing a certain degree of off-policy, the rollout and training tasks can overlap, giving optimal performance.
For example, rollout-generate may hit a batch of particularly short requests that finish quickly while train is not yet done; in that case rollout-generate can proceed to the next round, or even the next n rounds, of generation. Conversely, a very long generation round appears to block train, but train can keep consuming the previous round (or previous n rounds) of generated data, so the train task keeps running continuously. In terms of the final effect, async-RL turns the complex RL pipeline into something that behaves like a pure training process: it bounds the long-tail effect, preserves training MFU, and provides a near-linear speedup.
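Mechanically, this ahead-of-train behavior is just a bounded buffer between generate and train; here is a toy sketch assuming the generate_ahead_steps semantics described in the usage section below (all names and timings are illustrative):

```python
import asyncio

GENERATE_AHEAD_STEPS = 2            # generate may run at most 2 steps ahead of train

async def generate_loop(out_q: asyncio.Queue) -> None:
    for step in range(6):
        gen_time = 0.2 if step % 3 else 1.5      # occasional long-tail generation round
        await asyncio.sleep(gen_time)
        await out_q.put(step)                    # blocks only once the bounded queue is full
    await out_q.put(None)

async def train_loop(in_q: asyncio.Queue) -> None:
    while (step := await in_q.get()) is not None:
        await asyncio.sleep(0.6)                 # train keeps consuming older rollouts
        print(f"trained on rollout from generate step {step}")

async def main() -> None:
    # The queue depth bounds the off-policy gap, playing the role of generate_ahead_steps.
    q: asyncio.Queue = asyncio.Queue(maxsize=GENERATE_AHEAD_STEPS)
    await asyncio.gather(generate_loop(q), train_loop(q))

asyncio.run(main())
```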
Framework comparison
RL frameworks are evolving rapidly, and although there are various async-RL implementations, most of them still use nccl for synchronous parameter updates.
| Framework | Pros | Cons |
| --- | --- | --- |
| verl-hybrid-engine | 1. Minimal resources required for a set of experiments. 2. Lower overhead at smaller scales. | Large bubbles: training offload, inference KVCache offload, and other offloads. |
| recipe/one-step-off | Asynchronous one-step-off-policy implementation based on nccl. | 1. Synchronous parameter updates add performance overhead. 2. Training offload overhead still exists. 3. The asynchronous pieces are coupled together, making them error-prone. 4. The open-source main branch cannot run as documented (chain establishment fails). |
| slime | Based on sglang + megatron; separates rollout and training. | Synchronous parameter updates add performance overhead. |
| areal | Supports an interruptible generate mode. | Synchronous parameter updates add performance overhead. |
| verl-async-rl | 1. State-machine-based design with high scalability. 2. Asynchronous parameter synchronization without impacting GPU compute. 3. Bubble-free training for optimal performance. | 1. Introduces a certain degree of off-policy. 2. Requires cluster resources larger than one MP group, otherwise the ratios cannot be tuned. |
Performance
ref: dots.rl pr async-rl
Scalability: increasing the batch size can achieve up to a 100% performance improvement.

| Batch size / config | verl cost_time (s) | async-rl cost_time (s) | Speedup |
| --- | --- | --- | --- |
| 16 | 500 | 270 | 80% |
| 32 | 260 | 170 | 50% |
| 32 + async-ref-logp | 260 | 140 | 85% |
| 32 + tune-tp-config | 260 | 115 | 125% |
Usage Example
Just add the following parameters:
    # Async RL configuration
    +actor_rollout_ref.async_pipeline=True \
    # Resource management: train/logp/ref_logp use 0.5 of the GPUs, generate uses 0.5 of the GPUs
    +trainer.use_nodes_ratios=[0.5,0.5,0.5,0.5] \
    # Performance tuning: enable the async-param-update dual buffer
    +actor_rollout_ref.rollout.enable_dual_buffer=True \
    # Sender-side bucket granularity on the actor training node during param-update
    +actor_rollout_ref.rollout.param_update_preduce_bucket_size_mb=512 \
    # Receiver-side bucket granularity on the rollout inference node (too large a value risks GPU OOM)
    +actor_rollout_ref.rollout.param_update_consume_bucket_size_mb=128 \
    # Degree of off-policy: 2 means generate may run 2 steps ahead of the train node, i.e. one-step-off-policy
    +trainer.generate_ahead_steps=2 \
Notes: when asynchronous training applies
Necessary conditions: the cluster has sufficient resources, i.e. more than the number of slots required for one training group (the minimum number of GPUs needed for training under the hybrid engine), and the algorithm tolerates a certain degree of off-policy.
Room for optimization: when the hybrid engine's scalability falls below expectations, for example the batch size per engine is too small and the long-tail effect significantly blocks the whole training process, adding GPUs no longer improves training speed; this is the scenario where Async-RL pays off.