
[feat] add grad clipping for dense params#424

Open
tiankongdeguiji wants to merge 2 commits into alibaba:master from tiankongdeguiji:features/grad_clip

Conversation

@tiankongdeguiji
Collaborator

No description provided.

# Apply gradient clipping for dense parameters if configured
if train_config.HasField("grad_clipping"):
    gc_config = train_config.grad_clipping
    clipping_type = GradientClipping[gc_config.clipping_type.upper()]


Missing Input Validation: The clipping_type string is accessed directly via enum lookup without validation. If an invalid value is provided (e.g., "invalid_type"), this will raise a KeyError with an unhelpful error message.

The codebase has an established pattern for handling similar cases in tzrec/features/feature.py:566-568:

pooling_type = self.config.pooling.upper()
assert pooling_type in {"SUM", "MEAN"}, "available pooling type is SUM | MEAN"
return getattr(PoolingType, pooling_type)

Consider adding validation before the enum lookup:

valid_clipping_types = {"NORM", "VALUE", "NONE"}
clipping_type_str = gc_config.clipping_type.upper()
if clipping_type_str not in valid_clipping_types:
    raise ValueError(
        f"Invalid clipping_type '{gc_config.clipping_type}'. "
        f"Valid values are: {', '.join(t.lower() for t in valid_clipping_types)}"
    )
clipping_type = GradientClipping[clipping_type_str]
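For illustration, the failure mode looks like this (the GradientClipping enum below is a local stand-in for demonstration, not the one imported in the PR):

```python
from enum import Enum

# Stand-in enum for illustration; in the PR, GradientClipping comes from
# the optimizer utilities the code already imports.
class GradientClipping(Enum):
    NONE = 0
    VALUE = 1
    NORM = 2

try:
    GradientClipping["INVALID_TYPE"]  # bare enum lookup on user input
except KeyError as err:
    # The KeyError names only the bad key, with no hint at valid choices.
    print(err)
```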

combined_optimizer = GradientClippingOptimizer(
    optimizer=combined_optimizer,
    clipping=clipping_type,
    max_gradient=gc_config.max_gradient,


Missing Validation for max_gradient: There's no validation that max_gradient is positive. Setting max_gradient to 0 or a negative value would result in nonsensical behavior:

  • With clipping_type: "value", a max_gradient of 0 would zero all gradients
  • With clipping_type: "norm", a max_gradient of 0 would likewise zero all gradients, and a negative value would flip their signs

Consider adding validation:

if gc_config.max_gradient <= 0:
    raise ValueError(
        f"max_gradient must be positive, got {gc_config.max_gradient}"
    )
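A pure-Python sketch of the two clipping modes makes the failure concrete (this is a simplified model for illustration, not the actual GradientClippingOptimizer implementation):

```python
def clip_by_value(grads, max_gradient):
    # Clamp each gradient element into [-max_gradient, max_gradient].
    return [max(-max_gradient, min(max_gradient, g)) for g in grads]

def clip_by_norm(grads, max_gradient, norm_type=2.0):
    # Scale the whole gradient so its norm is at most max_gradient.
    norm = sum(abs(g) ** norm_type for g in grads) ** (1.0 / norm_type)
    scale = min(1.0, max_gradient / norm)
    return [g * scale for g in grads]

clip_by_value([0.5, -1.2, 3.0], 0.0)  # zeroes every gradient
clip_by_norm([3.0, 4.0], 0.0)         # zeroes every gradient
clip_by_norm([3.0, 4.0], -5.0)        # flips every sign
```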

    optimizer=combined_optimizer,
    clipping=clipping_type,
    max_gradient=gc_config.max_gradient,
    norm_type=gc_config.norm_type,


Missing Validation for norm_type: The norm_type parameter should be validated. According to the documentation, valid values are 2.0 (L2 norm) and inf (infinity norm for max norm). However, the code passes any float value to GradientClippingOptimizer without validation, which could lead to unexpected behavior or hard-to-debug errors if an unsupported norm_type is provided.

Consider adding validation:

# Validate norm_type for norm clipping
if clipping_type == GradientClipping.NORM:
    valid_norm_types = {1.0, 2.0, float('inf')}
    if gc_config.norm_type not in valid_norm_types:
        raise ValueError(
            f"norm_type must be one of {valid_norm_types}, got {gc_config.norm_type}"
        )

# limitations under the License.

-__version__ = "1.0.14"
+__version__ = "1.0.15"


Version Bump in Feature PR: Version bump from 1.0.14 to 1.0.15 is typically done during a release process, not in a feature PR. This could cause conflicts if multiple PRs are merged before a release, and it conflates feature development with release management.

Consider removing the version change from this PR and handling it during the actual release process.

}
}
num_epochs: 1
grad_clipping {


Test Config Missing enable_global_grad_clip Coverage: The test configuration omits enable_global_grad_clip, which defaults to false. That exercises the basic clipping path, but the suite should also verify that enable_global_grad_clip: true works correctly, since global clipping is specifically aimed at distributed training scenarios.

Consider adding a test case that explicitly tests the global gradient clipping feature, or document why it's not included.
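A second test case could extend the existing config to exercise that path (field names assumed from this review; values are illustrative):

```
grad_clipping {
  clipping_type: "norm"
  max_gradient: 1.0
  norm_type: 2.0
  enable_global_grad_clip: true
}
```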

@github-actions

Code Review Summary

This PR adds gradient clipping support for dense parameters. I have identified the following issues:

Critical Issues

  1. Missing Input Validation (main.py:721): The clipping_type string is used directly for enum lookup without validation. An invalid value will raise a cryptic KeyError instead of a helpful error message.

  2. Missing max_gradient Validation (main.py:726): No validation that max_gradient is positive. Setting it to 0 or negative would cause unexpected behavior.

  3. Missing norm_type Validation (main.py:727): The norm_type parameter is passed to the optimizer without validation. Invalid values could cause hard-to-debug errors.

Minor Issues

  1. Version Bump in Feature PR (version.py:12): Version bumps should typically be handled during the release process, not in feature PRs, to avoid merge conflicts.

  2. Incomplete Test Coverage: The test config does not test enable_global_grad_clip: true, which is a key feature for distributed training scenarios.

Recommendations

  • Add proper input validation with clear error messages for all configuration parameters
  • Consider adding unit tests specifically for the gradient clipping configuration validation
  • Remove the version bump from this PR and handle it during release
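Such a validation unit test could be sketched as follows (validate_grad_clipping is a hypothetical helper mirroring the checks suggested in this review, not an existing tzrec function):

```python
# Hypothetical helper consolidating the validation proposed above.
def validate_grad_clipping(clipping_type: str, max_gradient: float) -> str:
    valid_clipping_types = {"NORM", "VALUE", "NONE"}
    clipping_type_str = clipping_type.upper()
    if clipping_type_str not in valid_clipping_types:
        raise ValueError(f"Invalid clipping_type '{clipping_type}'.")
    if max_gradient <= 0:
        raise ValueError(f"max_gradient must be positive, got {max_gradient}")
    return clipping_type_str

# Valid config passes through normalized.
assert validate_grad_clipping("norm", 1.0) == "NORM"

# Each invalid config raises a ValueError with a readable message.
for bad_args in [("invalid_type", 1.0), ("norm", 0.0), ("norm", -1.0)]:
    try:
        validate_grad_clipping(*bad_args)
    except ValueError:
        pass
    else:
        raise AssertionError(f"expected ValueError for {bad_args}")
```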

