
Conversation

@hjh0119
Collaborator

@hjh0119 hjh0119 commented Dec 31, 2025

No description provided.

@gemini-code-assist
Contributor

Summary of Changes

Hello @hjh0119, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly enhances the Megatron GRPO training capabilities by introducing comprehensive support for entropy-based token masking and logging. The changes enable researchers and developers to fine-tune the training process by focusing on high-uncertainty tokens, potentially leading to more robust and efficient policy optimization. This is achieved through new configuration options, a refactored utility module for distributed computations, and integrated logic within the GRPO trainer to leverage these new features.

Highlights

  • Entropy Masking and Logging Support: Introduced new configuration arguments log_entropy and top_entropy_quantile to enable entropy-based masking and logging within the Megatron GRPO framework, allowing for more granular control over token contributions to loss.
  • Refactored Vocabulary-Parallel Utilities: Extracted common vocabulary-parallel functionality (log_softmax, KL_divergence, and entropy computation) into a new dedicated module, vocab_parallel_utils.py, for better modularity and reusability across trainers.
  • Enhanced GRPO Trainer Logic: The GRPOTrainer's forward_step and loss_func have been significantly updated to integrate the new entropy computation. This includes conditionally computing per-token entropy and applying an entropy mask to the loss, which can selectively ignore low-entropy (high-confidence) tokens during gradient calculation (see the sketch after this list).
  • Entropy Statistics Logging: Added functionality to log detailed entropy statistics (mean, max, min, and threshold) during training, providing valuable insights into the model's uncertainty and the effect of the entropy mask.
  • Documentation Update: Updated the GRPO.md documentation to reflect that entropy-related configurations are now supported features, removing them from the 'Not Supported Yet' section.
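
To make the masking idea concrete, here is a minimal sketch, assuming a per-token entropy tensor and a completion mask are already available; the function name, argument names, and exact threshold convention below are illustrative assumptions and do not mirror the trainer's internal API.

    # Hypothetical sketch of quantile-based entropy masking; names are illustrative only.
    import torch


    def entropy_mask(per_token_entropy: torch.Tensor,
                     completion_mask: torch.Tensor,
                     top_entropy_quantile: float) -> torch.Tensor:
        """Keep only completion tokens whose entropy falls in the top quantile.

        per_token_entropy:    [batch, seq] entropy of each generated token
        completion_mask:      [batch, seq] 1 for completion tokens, 0 for prompt/padding
        top_entropy_quantile: e.g. 0.2 keeps roughly the top 20% highest-entropy tokens
        """
        if top_entropy_quantile >= 1.0:
            # A quantile of 1.0 disables masking: every completion token contributes to the loss
            return completion_mask.bool()
        # Exclude prompt/padding positions from the threshold by marking them as NaN
        masked = per_token_entropy.float().masked_fill(completion_mask == 0, float('nan'))
        threshold = torch.nanquantile(masked.flatten(), 1.0 - top_entropy_quantile)
        # Tokens below the entropy threshold are dropped from the gradient
        return (per_token_entropy >= threshold) & completion_mask.bool()


    # Toy usage
    entropy = torch.rand(2, 8)
    completion_mask = torch.ones(2, 8, dtype=torch.long)
    keep = entropy_mask(entropy, completion_mask, top_entropy_quantile=0.2)

In the distributed trainer the threshold would additionally need to be agreed on across data-parallel ranks (for example, by gathering entropies before taking the quantile); that detail is omitted from this sketch.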


Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces support for entropy masking and logging in the Megatron GRPO trainer, which is a valuable feature for controlling exploration during reinforcement learning. The changes include adding new arguments (log_entropy, top_entropy_quantile), implementing the core logic for entropy calculation and masking in grpo_trainer.py, and logging relevant metrics.

A significant part of this PR is the refactoring of vocabulary-parallel utility functions into a new swift/megatron/trainers/vocab_parallel_utils.py file. This is a great improvement for code organization and reusability.

I've identified a performance issue in the new vocab_parallel_utils.py where log_softmax is computed redundantly. My review includes suggestions to refactor the entropy calculation to avoid this, which should improve performance when entropy computation is enabled. Overall, this is a solid contribution with the suggested improvements.

Comment on lines 57 to 120
def vocab_parallel_entropy(logits: torch.Tensor, chunk_size: int = 512) -> torch.Tensor:
    """Compute entropy across vocab-parallel sharded logits.

    When using Tensor Parallelism, vocab is sharded across TP ranks.
    This function correctly computes entropy by:
    1. Computing log_softmax with global normalization (all_reduce for max and sum_exp)
    2. Computing entropy = -sum(exp(log_p) * log_p) with all_reduce

    Entropy is computed in chunks to reduce memory usage.

    Args:
        logits: Logits tensor [..., partition_vocab_size] (sharded across TP)
        chunk_size: Number of tokens to process per chunk (default: 512)

    Returns:
        Entropy tensor [...] (scalar per position)
    """
    tp_group = mpu.get_tensor_model_parallel_group()
    tp_size = mpu.get_tensor_model_parallel_world_size()

    # Flatten all but the last dimension for chunked processing
    original_shape = logits.shape[:-1]
    vocab_size = logits.shape[-1]
    logits_flat = logits.view(-1, vocab_size)  # [total_tokens, partition_vocab_size]
    total_tokens = logits_flat.shape[0]

    entropies_list = []
    for start_idx in range(0, total_tokens, chunk_size):
        end_idx = min(start_idx + chunk_size, total_tokens)
        logits_chunk = logits_flat[start_idx:end_idx]  # [chunk_size, partition_vocab_size]

        if tp_size > 1:
            # Step 1: Find global max for numerical stability
            logits_max = logits_chunk.max(dim=-1, keepdim=True)[0]
            torch.distributed.all_reduce(logits_max, op=torch.distributed.ReduceOp.MAX, group=tp_group)

            # Step 2: Compute exp(logits - max) and sum across all TP ranks
            exp_logits = torch.exp(logits_chunk - logits_max)
            sum_exp = exp_logits.sum(dim=-1, keepdim=True)
            torch.distributed.all_reduce(sum_exp, op=torch.distributed.ReduceOp.SUM, group=tp_group)

            # Step 3: Compute log_softmax = logits - max - log(sum_exp)
            log_probs = logits_chunk - logits_max - torch.log(sum_exp)

            # Step 4: Compute partial entropy on this rank's vocab partition
            # entropy = -sum(p * log_p) = -sum(exp(log_p) * log_p)
            probs = torch.exp(log_probs)
            partial_entropy = -(probs * log_probs).sum(dim=-1)  # [chunk_size]

            # Step 5: All-reduce to get global entropy
            torch.distributed.all_reduce(partial_entropy, op=torch.distributed.ReduceOp.SUM, group=tp_group)
        else:
            # Non-TP case: standard entropy computation
            log_probs = torch.nn.functional.log_softmax(logits_chunk, dim=-1)
            probs = torch.exp(log_probs)
            partial_entropy = -(probs * log_probs).sum(dim=-1)

        entropies_list.append(partial_entropy)

    # Concatenate all chunks and reshape back
    entropies = torch.cat(entropies_list, dim=0)
    entropies = entropies.view(original_shape)

    return entropies
Contributor


high

There's a performance inefficiency in how entropy is calculated. The vocab_parallel_entropy function re-computes log_softmax from logits, even though it's already computed in compute_logps_and_entropy_from_logits.

To avoid this redundant computation, I suggest refactoring vocab_parallel_entropy to accept pre-computed log_probs instead of logits. This will make the function more efficient and reusable.

def vocab_parallel_entropy(log_probs: torch.Tensor, chunk_size: int = 512) -> torch.Tensor:
    """Compute entropy from pre-computed vocab-parallel sharded log probabilities.

    When using Tensor Parallelism, vocab is sharded across TP ranks.
    This function correctly computes entropy by:
    1. Computing partial entropy = -sum(exp(log_p) * log_p) on each rank's partition
    2. All-reducing the partial entropies to get the global sum.

    Entropy is computed in chunks to reduce memory usage.

    Args:
        log_probs: Pre-computed log probabilities tensor [..., partition_vocab_size]
        chunk_size: Number of tokens to process per chunk (default: 512)

    Returns:
        Entropy tensor [...] (scalar per position)
    """
    tp_group = mpu.get_tensor_model_parallel_group()
    tp_size = mpu.get_tensor_model_parallel_world_size()

    # Flatten all but the last dimension for chunked processing
    original_shape = log_probs.shape[:-1]
    vocab_size = log_probs.shape[-1]
    log_probs_flat = log_probs.view(-1, vocab_size)  # [total_tokens, partition_vocab_size]
    total_tokens = log_probs_flat.shape[0]

    entropies_list = []
    for start_idx in range(0, total_tokens, chunk_size):
        end_idx = min(start_idx + chunk_size, total_tokens)
        log_probs_chunk = log_probs_flat[start_idx:end_idx]  # [chunk_size, partition_vocab_size]

        # Compute partial entropy on this rank's vocab partition
        # entropy = -sum(p * log_p) = -sum(exp(log_p) * log_p)
        probs = torch.exp(log_probs_chunk)
        partial_entropy = -(probs * log_probs_chunk).sum(dim=-1)  # [chunk_size]

        # All-reduce to get global entropy if using TP
        if tp_size > 1:
            torch.distributed.all_reduce(partial_entropy, op=torch.distributed.ReduceOp.SUM, group=tp_group)

        entropies_list.append(partial_entropy)

    # Concatenate all chunks and reshape back
    entropies = torch.cat(entropies_list, dim=0)
    entropies = entropies.view(original_shape)

    return entropies
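
For orientation, a plain single-GPU reference for the same quantity is a convenient sanity check for either vocab-parallel variant above; this is a minimal sketch assuming the full, unsharded vocabulary fits on one device, not code from the PR.

    # Illustrative single-GPU reference entropy, assuming unsharded logits.
    import torch
    import torch.nn.functional as F


    def reference_entropy(logits: torch.Tensor) -> torch.Tensor:
        # entropy = -sum_v p(v) * log p(v), computed independently per position
        log_probs = F.log_softmax(logits.float(), dim=-1)
        return -(log_probs.exp() * log_probs).sum(dim=-1)


    # With tensor parallelism disabled (tp_size == 1), vocab_parallel_entropy(logits)
    # should match reference_entropy(logits) up to floating-point precision.
    logits = torch.randn(4, 16, 32000)   # [batch, seq, vocab]
    entropy = reference_entropy(logits)  # [batch, seq]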

Comment on lines 245 to 246
    if compute_entropy:
        per_token_entropy = vocab_parallel_entropy(logits, chunk_size=entropy_chunk_size)
Contributor


high

To leverage the refactoring of vocab_parallel_entropy and avoid redundant log_softmax computation, please pass the pre-computed log_probs to vocab_parallel_entropy instead of logits.

    if compute_entropy:
        per_token_entropy = vocab_parallel_entropy(log_probs, chunk_size=entropy_chunk_size)

@hjh0119
Collaborator Author

hjh0119 commented Dec 31, 2025

/gemini review

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request adds support for entropy masking and logging in Megatron GRPO, a feature from the paper 'Beyond the 80/20 Rule'. The implementation includes new arguments log_entropy and top_entropy_quantile, and updates the training loop in grpo_trainer.py to compute entropy, apply masks to the loss, and log relevant metrics. A significant and positive change is the refactoring of vocabulary-parallel utility functions into a new file, swift/megatron/trainers/vocab_parallel_utils.py. This improves code organization and reusability. The implementation appears correct and robust, especially in handling distributed computations and edge cases like NaN values. I have one suggestion for a minor optimization in the new utility file.
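
As an illustration of the kind of entropy statistics being logged (mean, max, min, and threshold), here is a minimal sketch; the metric keys and argument names are assumptions for illustration, not the trainer's actual logging keys.

    # Hypothetical sketch of aggregating per-token entropy statistics for logging.
    import torch


    def entropy_metrics(per_token_entropy: torch.Tensor,
                        completion_mask: torch.Tensor,
                        threshold: float) -> dict:
        # Only completion tokens contribute to the reported statistics
        ent = per_token_entropy[completion_mask.bool()]
        return {
            'entropy/mean': ent.mean().item(),
            'entropy/max': ent.max().item(),
            'entropy/min': ent.min().item(),
            'entropy/threshold': threshold,
        }


    # Toy usage
    stats = entropy_metrics(torch.rand(2, 8), torch.ones(2, 8, dtype=torch.long), threshold=0.5)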

Comment on lines 122 to 129
    tp_group = mpu.get_tensor_model_parallel_group()

    # Compute partial KL on this rank's vocab partition
    target_probs = torch.exp(target_log_probs)
    partial_kl = (target_probs * (target_log_probs - input_log_probs)).sum(dim=-1)

    # All-reduce to get global KL
    torch.distributed.all_reduce(partial_kl, op=torch.distributed.ReduceOp.SUM, group=tp_group)
Contributor


medium

For consistency with other functions in this file and for a minor performance optimization, I suggest refactoring this part to only perform the all_reduce operation when tensor parallelism is active (tp_size > 1). This also avoids fetching the tensor parallel group unnecessarily when tp_size is 1.

Suggested change

-    tp_group = mpu.get_tensor_model_parallel_group()
-
-    # Compute partial KL on this rank's vocab partition
-    target_probs = torch.exp(target_log_probs)
-    partial_kl = (target_probs * (target_log_probs - input_log_probs)).sum(dim=-1)
-
-    # All-reduce to get global KL
-    torch.distributed.all_reduce(partial_kl, op=torch.distributed.ReduceOp.SUM, group=tp_group)
+    # Compute partial KL on this rank's vocab partition
+    target_probs = torch.exp(target_log_probs)
+    partial_kl = (target_probs * (target_log_probs - input_log_probs)).sum(dim=-1)
+
+    # All-reduce to get global KL if using TP
+    if mpu.get_tensor_model_parallel_world_size() > 1:
+        tp_group = mpu.get_tensor_model_parallel_group()
+        torch.distributed.all_reduce(partial_kl, op=torch.distributed.ReduceOp.SUM, group=tp_group)
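
For reference, the all-reduce in either variant is exact because the TP ranks hold disjoint vocabulary shards $V_r$ that partition the full vocabulary $V$, so the per-rank partial sums add up to the global KL divergence:

$$
D_{\mathrm{KL}}\bigl(p_{\text{target}} \,\|\, p_{\text{input}}\bigr)
= \sum_{v \in V} p_{\text{target}}(v)\bigl[\log p_{\text{target}}(v) - \log p_{\text{input}}(v)\bigr]
= \sum_{r} \sum_{v \in V_r} p_{\text{target}}(v)\bigl[\log p_{\text{target}}(v) - \log p_{\text{input}}(v)\bigr].
$$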

@hjh0119 hjh0119 merged commit 90afa99 into modelscope:main Dec 31, 2025
2 of 3 checks passed
@hjh0119 hjh0119 deleted the mg-entropy branch December 31, 2025 08:34