[feat] support megatron grpo entropy mask & log #7263
Conversation
Summary of Changes
Hello @hjh0119, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed. This pull request enhances Megatron GRPO training by adding support for entropy-based token masking and logging. The changes let researchers and developers focus the training signal on high-uncertainty tokens, potentially leading to more robust and efficient policy optimization. This is achieved through new configuration options, a refactored utility module for distributed computations, and integrated logic within the GRPO trainer that ties these pieces together.
Code Review
This pull request introduces support for entropy masking and logging in the Megatron GRPO trainer, which is a valuable feature for controlling exploration during reinforcement learning. The changes include adding new arguments (log_entropy, top_entropy_quantile), implementing the core logic for entropy calculation and masking in grpo_trainer.py, and logging relevant metrics.
A significant part of this PR is the refactoring of vocabulary-parallel utility functions into a new swift/megatron/trainers/vocab_parallel_utils.py file. This is a great improvement for code organization and reusability.
I've identified a performance issue in the new vocab_parallel_utils.py where log_softmax is computed redundantly. My review includes suggestions to refactor the entropy calculation to avoid this, which should improve performance when entropy computation is enabled. Overall, this is a solid contribution, and the suggested refactor should make it even better.
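For readers unfamiliar with the technique, the idea behind top_entropy_quantile (following the 'Beyond the 80/20 Rule' paper, not necessarily this PR's exact code) is to keep only the highest-entropy completion tokens in the policy loss. Below is a minimal PyTorch sketch with a hypothetical helper name, under the assumption that a value of 0.2 means "keep the top 20% of completion tokens by entropy":

import torch

def entropy_quantile_mask(per_token_entropy: torch.Tensor,
                          completion_mask: torch.Tensor,
                          top_entropy_quantile: float) -> torch.Tensor:
    """Keep only completion tokens whose entropy falls in the top fraction (illustrative helper).

    per_token_entropy: [batch, seq_len] entropy of the policy distribution at each position
    completion_mask:   [batch, seq_len] 1 for completion tokens, 0 for prompt/padding
    """
    keep = completion_mask.bool()
    if top_entropy_quantile >= 1.0 or not keep.any():
        return keep  # masking disabled, or nothing to mask
    # Threshold at the (1 - quantile) point of the entropy distribution over completion tokens.
    threshold = torch.quantile(per_token_entropy[keep].float(), 1.0 - top_entropy_quantile)
    return keep & (per_token_entropy >= threshold)

# Usage sketch: zero out low-entropy tokens before reducing the per-token loss.
# mask = entropy_quantile_mask(per_token_entropy, completion_mask, top_entropy_quantile=0.2)
# loss = (per_token_loss * mask).sum() / mask.sum().clamp(min=1)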
def vocab_parallel_entropy(logits: torch.Tensor, chunk_size: int = 512) -> torch.Tensor:
    """Compute entropy across vocab-parallel sharded logits.

    When using Tensor Parallelism, vocab is sharded across TP ranks.
    This function correctly computes entropy by:
    1. Computing log_softmax with global normalization (all_reduce for max and sum_exp)
    2. Computing entropy = -sum(exp(log_p) * log_p) with all_reduce

    Entropy is computed in chunks to reduce memory usage.

    Args:
        logits: Logits tensor [..., partition_vocab_size] (sharded across TP)
        chunk_size: Number of tokens to process per chunk (default: 512)

    Returns:
        Entropy tensor [...] (scalar per position)
    """
    tp_group = mpu.get_tensor_model_parallel_group()
    tp_size = mpu.get_tensor_model_parallel_world_size()

    # Flatten all but the last dimension for chunked processing
    original_shape = logits.shape[:-1]
    vocab_size = logits.shape[-1]
    logits_flat = logits.view(-1, vocab_size)  # [total_tokens, partition_vocab_size]
    total_tokens = logits_flat.shape[0]

    entropies_list = []
    for start_idx in range(0, total_tokens, chunk_size):
        end_idx = min(start_idx + chunk_size, total_tokens)
        logits_chunk = logits_flat[start_idx:end_idx]  # [chunk_size, partition_vocab_size]

        if tp_size > 1:
            # Step 1: Find global max for numerical stability
            logits_max = logits_chunk.max(dim=-1, keepdim=True)[0]
            torch.distributed.all_reduce(logits_max, op=torch.distributed.ReduceOp.MAX, group=tp_group)

            # Step 2: Compute exp(logits - max) and sum across all TP ranks
            exp_logits = torch.exp(logits_chunk - logits_max)
            sum_exp = exp_logits.sum(dim=-1, keepdim=True)
            torch.distributed.all_reduce(sum_exp, op=torch.distributed.ReduceOp.SUM, group=tp_group)

            # Step 3: Compute log_softmax = logits - max - log(sum_exp)
            log_probs = logits_chunk - logits_max - torch.log(sum_exp)

            # Step 4: Compute partial entropy on this rank's vocab partition
            # entropy = -sum(p * log_p) = -sum(exp(log_p) * log_p)
            probs = torch.exp(log_probs)
            partial_entropy = -(probs * log_probs).sum(dim=-1)  # [chunk_size]

            # Step 5: All-reduce to get global entropy
            torch.distributed.all_reduce(partial_entropy, op=torch.distributed.ReduceOp.SUM, group=tp_group)
        else:
            # Non-TP case: standard entropy computation
            log_probs = torch.nn.functional.log_softmax(logits_chunk, dim=-1)
            probs = torch.exp(log_probs)
            partial_entropy = -(probs * log_probs).sum(dim=-1)

        entropies_list.append(partial_entropy)

    # Concatenate all chunks and reshape back
    entropies = torch.cat(entropies_list, dim=0)
    entropies = entropies.view(original_shape)

    return entropies
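As a quick way to see what the non-TP branch computes: it is simply the standard categorical entropy, so on a single rank it can be cross-checked against torch.distributions without initializing Megatron's mpu. The shapes below are illustrative only:

import torch

logits = torch.randn(2, 5, 32000)  # [batch, seq_len, vocab] -- illustrative shape
log_probs = torch.log_softmax(logits, dim=-1)
entropy_manual = -(log_probs.exp() * log_probs).sum(dim=-1)              # what the else-branch computes
entropy_ref = torch.distributions.Categorical(logits=logits).entropy()   # reference implementation
assert torch.allclose(entropy_manual, entropy_ref, atol=1e-5)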
There's a performance inefficiency in how entropy is calculated. The vocab_parallel_entropy function re-computes log_softmax from logits, even though it's already computed in compute_logps_and_entropy_from_logits.
To avoid this redundant computation, I suggest refactoring vocab_parallel_entropy to accept pre-computed log_probs instead of logits. This will make the function more efficient and reusable.
def vocab_parallel_entropy(log_probs: torch.Tensor, chunk_size: int = 512) -> torch.Tensor:
    """Compute entropy from pre-computed vocab-parallel sharded log probabilities.

    When using Tensor Parallelism, vocab is sharded across TP ranks.
    This function correctly computes entropy by:
    1. Computing partial entropy = -sum(exp(log_p) * log_p) on each rank's partition
    2. All-reducing the partial entropies to get the global sum.

    Entropy is computed in chunks to reduce memory usage.

    Args:
        log_probs: Pre-computed log probabilities tensor [..., partition_vocab_size]
        chunk_size: Number of tokens to process per chunk (default: 512)

    Returns:
        Entropy tensor [...] (scalar per position)
    """
    tp_group = mpu.get_tensor_model_parallel_group()
    tp_size = mpu.get_tensor_model_parallel_world_size()

    # Flatten all but the last dimension for chunked processing
    original_shape = log_probs.shape[:-1]
    vocab_size = log_probs.shape[-1]
    log_probs_flat = log_probs.view(-1, vocab_size)  # [total_tokens, partition_vocab_size]
    total_tokens = log_probs_flat.shape[0]

    entropies_list = []
    for start_idx in range(0, total_tokens, chunk_size):
        end_idx = min(start_idx + chunk_size, total_tokens)
        log_probs_chunk = log_probs_flat[start_idx:end_idx]  # [chunk_size, partition_vocab_size]

        # Compute partial entropy on this rank's vocab partition
        # entropy = -sum(p * log_p) = -sum(exp(log_p) * log_p)
        probs = torch.exp(log_probs_chunk)
        partial_entropy = -(probs * log_probs_chunk).sum(dim=-1)  # [chunk_size]

        # All-reduce to get global entropy if using TP
        if tp_size > 1:
            torch.distributed.all_reduce(partial_entropy, op=torch.distributed.ReduceOp.SUM, group=tp_group)

        entropies_list.append(partial_entropy)

    # Concatenate all chunks and reshape back
    entropies = torch.cat(entropies_list, dim=0)
    entropies = entropies.view(original_shape)

    return entropies

if compute_entropy:
    per_token_entropy = vocab_parallel_entropy(logits, chunk_size=entropy_chunk_size)
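The hunk quoted just above is the call site in grpo_trainer.py that currently feeds raw logits into vocab_parallel_entropy. If the suggested refactor is adopted, the call would instead reuse whatever log probabilities compute_logps_and_entropy_from_logits already materializes; a hypothetical sketch (variable names are placeholders, not the PR's actual ones):

# log_probs: vocab-parallel log_softmax computed once earlier in
# compute_logps_and_entropy_from_logits (placeholder variable name)
if compute_entropy:
    per_token_entropy = vocab_parallel_entropy(log_probs, chunk_size=entropy_chunk_size)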
/gemini review
Code Review
This pull request adds support for entropy masking and logging in Megatron GRPO, a feature from the paper 'Beyond the 80/20 Rule'. The implementation includes new arguments log_entropy and top_entropy_quantile, and updates the training loop in grpo_trainer.py to compute entropy, apply masks to the loss, and log relevant metrics. A significant and positive change is the refactoring of vocabulary-parallel utility functions into a new file, swift/megatron/trainers/vocab_parallel_utils.py. This improves code organization and reusability. The implementation appears correct and robust, especially in handling distributed computations and edge cases like NaN values. I have one suggestion for a minor optimization in the new utility file.
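To give a concrete picture of what log_entropy-style logging can look like, here is a rough sketch of aggregating per-token entropy statistics over the completion mask; the helper and metric names are illustrative, not the PR's actual code:

import torch

def entropy_metrics(per_token_entropy: torch.Tensor, completion_mask: torch.Tensor) -> dict:
    """Summarize per-token entropy over completion tokens for logging (illustrative)."""
    masked = per_token_entropy[completion_mask.bool()]
    if masked.numel() == 0:
        return {}
    return {
        'entropy/mean': masked.mean().item(),
        'entropy/min': masked.min().item(),
        'entropy/max': masked.max().item(),
    }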
tp_group = mpu.get_tensor_model_parallel_group()

# Compute partial KL on this rank's vocab partition
target_probs = torch.exp(target_log_probs)
partial_kl = (target_probs * (target_log_probs - input_log_probs)).sum(dim=-1)

# All-reduce to get global KL
torch.distributed.all_reduce(partial_kl, op=torch.distributed.ReduceOp.SUM, group=tp_group)
For consistency with other functions in this file and for a minor performance optimization, I suggest refactoring this part to only perform the all_reduce operation when tensor parallelism is active (tp_size > 1). This also avoids fetching the tensor parallel group unnecessarily when tp_size is 1.
Suggested change:

-tp_group = mpu.get_tensor_model_parallel_group()
-# Compute partial KL on this rank's vocab partition
-target_probs = torch.exp(target_log_probs)
-partial_kl = (target_probs * (target_log_probs - input_log_probs)).sum(dim=-1)
-# All-reduce to get global KL
-torch.distributed.all_reduce(partial_kl, op=torch.distributed.ReduceOp.SUM, group=tp_group)
+# Compute partial KL on this rank's vocab partition
+target_probs = torch.exp(target_log_probs)
+partial_kl = (target_probs * (target_log_probs - input_log_probs)).sum(dim=-1)
+# All-reduce to get global KL if using TP
+if mpu.get_tensor_model_parallel_world_size() > 1:
+    tp_group = mpu.get_tensor_model_parallel_group()
+    torch.distributed.all_reduce(partial_kl, op=torch.distributed.ReduceOp.SUM, group=tp_group)
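If this guard ends up being needed in several helpers in vocab_parallel_utils.py, one option is a small shared wrapper. A sketch, assuming mpu resolves to megatron.core.parallel_state (the wrapper itself is hypothetical, not part of the PR):

import torch
import torch.distributed
from megatron.core import parallel_state as mpu  # assumed import path

def tp_all_reduce_sum_(tensor: torch.Tensor) -> torch.Tensor:
    """In-place SUM all-reduce across the tensor-parallel group; skipped when TP is disabled."""
    if mpu.get_tensor_model_parallel_world_size() > 1:
        torch.distributed.all_reduce(
            tensor, op=torch.distributed.ReduceOp.SUM,
            group=mpu.get_tensor_model_parallel_group())
    return tensor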