
Feat: implement GradMonitor handler for gradient norm tracking and batch skipping #3624

Open

leelakrishnaRajasimha wants to merge 8 commits into pytorch:master from leelakrishnaRajasimha:final-clean-handler-3563

Conversation

@leelakrishnaRajasimha
Contributor

Closes #3563

Overview

This PR introduces the GradMonitor handler, designed to track L2 gradient norms in real time to detect training instability, such as "entropy collapse" or exploding gradients in Transformer models.

Key Features

  • Real-time Monitoring: Hooks into Events.ITERATION_COMPLETED to calculate and log L2 gradient norms.

  • AMP & DDP Support: Fully compatible with Automatic Mixed Precision (unscaling via GradScaler) and Distributed Data Parallel (synchronization via idist.all_reduce).

  • Reactive Training Logic: Sets an engine.state.unhealthy_spike flag, enabling users to implement automated batch skipping or custom callbacks when thresholds are breached.

  • Flexible Thresholding: Supports both static L2 norm limits and dynamic spike detection using a moving average window.
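The static/dynamic thresholding described above can be sketched in a few lines of plain Python. This is a simplified illustration; `is_spike` and its arguments are stand-ins, not the handler's actual API:

```python
from collections import deque

# Hypothetical simplification of the thresholding decision: a static limit by
# default, or (when use_dynamic=True and the window is full) a spike defined
# as 5x the moving average of recent norms.
def is_spike(grad_norm, window, threshold=100.0, use_dynamic=False):
    effective_threshold = threshold
    if use_dynamic and len(window) == window.maxlen:
        effective_threshold = (sum(window) / len(window)) * 5.0
    spike = grad_norm > effective_threshold
    window.append(grad_norm)
    return spike

window = deque(maxlen=3)
for norm in (1.0, 1.2, 0.9):
    is_spike(norm, window, use_dynamic=True)
# Moving average is ~1.03, so a norm of 20.0 is flagged as a spike.
assert is_spike(20.0, window, use_dynamic=True)
```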

Performance & Benchmarking

The handler has been benchmarked on a model with 1 million parameters to ensure it does not significantly impact training throughput:

  • Local Hardware (HP Victus / RTX 4050): For 100,000 iterations, the total execution time typically falls within a 9s to 16s range, depending on background system overhead.

  • High-Performance Environment (Google Colab GPU): In optimized cloud environments with dedicated high-performance GPUs, the same benchmark has achieved results as fast as 5.4s.

  • Efficiency: This demonstrates that the per-iteration overhead remains negligible across various hardware tiers, making it highly suitable for large-scale production models.

Testing Suite

  • Verified mathematical accuracy of L2 norm calculations.

  • Validated engine.state flag toggling and callback execution.

  • Confirmed correct unscaling behavior when a GradScaler is present.

  • Included input validation for all initialization parameters.

@github-actions github-actions bot added the module: handlers Core Handlers module label Mar 3, 2026
@aaishwarymishra
Collaborator

aaishwarymishra commented Mar 3, 2026

Hmm, I read your PR; a few suggestions:

  1. Instead of a brittle user-defined threshold, use a dynamic threshold as the default.
  2. Instead of the brittle mean*0.5 threshold, we can use mean + (k*std), where k can be defined by the user and std is the standard deviation.
  3. Tracking a history and window is kind of expensive; I think using single variables for iteration_count, mean, and variance will be much more efficient.
  4. Attaching the handler on Events.ITERATION_COMPLETED is wrong: by ITERATION_COMPLETED the weights are already updated. It should be Events.ITERATION_STARTED.
  5. Also, you are updating the list before the mean calculation; it should be after.

Also, looking at the code, the handler doesn't actually skip the batch; the user has to write custom code in the train step?

like :

if not engine.state.unhealthy_spike:
    optimizer.zero_grad()
    return None
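Suggestions 2, 3, and 5 can be combined in a single O(1)-memory sketch: Welford's online algorithm maintains a running mean and variance, and a spike is any norm above mean + k*std, checked before the statistics are updated. This is an illustrative sketch under those assumptions; `RunningStats` is a made-up name, not the handler's code:

```python
import math

class RunningStats:
    def __init__(self, k=3.0):
        self.k = k
        self.count = 0
        self.mean = 0.0
        self.m2 = 0.0  # sum of squared deviations from the running mean

    def is_spike(self, norm):
        # Compare against the statistics of *previous* iterations first
        # (suggestion 5: update after the comparison, not before).
        spike = False
        if self.count >= 2:
            std = math.sqrt(self.m2 / (self.count - 1))
            spike = norm > self.mean + self.k * std
        # Welford update: O(1) memory, no history list needed.
        self.count += 1
        delta = norm - self.mean
        self.mean += delta / self.count
        self.m2 += delta * (norm - self.mean)
        return spike

stats = RunningStats(k=3.0)
for n in (1.0, 1.1, 0.9, 1.0):
    assert not stats.is_spike(n)
assert stats.is_spike(10.0)  # far above mean + 3*std of the history
```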

Contributor

Copilot AI left a comment


Pull request overview

Adds a new Ignite handler to monitor L2 gradient norms at iteration boundaries and expose an engine.state.unhealthy_spike flag for reactive training logic (e.g., batch skipping) to help detect instability during training.

Changes:

  • Introduces GradMonitor handler with optional static thresholding and moving-average-based dynamic spike detection.
  • Exposes the new handler via ignite.handlers public API.
  • Adds unit tests for basic norm math, flag toggling, AMP unscale behavior, and input validation.

Reviewed changes

Copilot reviewed 3 out of 3 changed files in this pull request and generated 3 comments.

File | Description
ignite/handlers/grad_monitor.py | New handler implementation: norm computation, optional AMP/DDP handling, thresholding, and state flagging.
ignite/handlers/__init__.py | Re-exports GradMonitor in the public handlers namespace.
tests/ignite/handlers/test_grad_monitor.py | Adds tests for the new handler and related behaviors.
Comments suppressed due to low confidence (7)

ignite/handlers/grad_monitor.py:75

  • window_size is used as deque(maxlen=window_size) but isn’t validated. Passing a non-int or a negative value will raise an unhelpful deque error, and window_size=0 makes dynamic mode impossible to trigger. Add explicit validation (e.g., positive integer) and consider validating use_dynamic is bool for consistency with other handlers’ arg checks.
        threshold: float = 100.0,
        window_size: int = 10,
        use_dynamic: bool = False,
        callback: Optional[Callable[[Engine, float], None]] = None,
    ) -> None:
        if not isinstance(model, Module):
            raise TypeError(f"Argument model should be a torch.nn.Module, but given {type(model)}.")

        if not (isinstance(threshold, (int, float)) and math.isfinite(threshold) and threshold > 0):
            raise ValueError(f"Argument threshold should be a positive finite number, but given {threshold}.")

        if callback is not None and not callable(callback):
            raise TypeError(f"Argument callback should be callable, but given {type(callback)}.")

        self.model = model
        self.threshold = float(threshold)
        self.use_dynamic = use_dynamic
        self.callback = callback
        
        # Standard Ignite logger setup.
        self.logger = setup_logger(__name__ + "." + self.__class__.__name__)
        self._history = deque(maxlen=1000)
        self._window = deque(maxlen=window_size) 
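A minimal sketch of the explicit validation this comment is asking for; the message format mirrors the constructor's existing checks, and `validate_window_size` is a hypothetical helper, not proposed API:

```python
# Reject non-int, bool, zero, and negative window sizes up front, so callers
# get a clear error instead of an unhelpful deque failure.
def validate_window_size(window_size):
    if not isinstance(window_size, int) or isinstance(window_size, bool) or window_size < 1:
        raise ValueError(
            f"Argument window_size should be a positive integer, but given {window_size}."
        )
    return window_size

validate_window_size(10)  # ok
try:
    validate_window_size(0)
except ValueError as e:
    print(e)  # Argument window_size should be a positive integer, but given 0.
```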

ignite/handlers/grad_monitor.py:22

  • The docstring says .. versionadded:: 0.5.3, but the package version in ignite/__init__.py is 0.6.0. Please update the versionadded tag to the correct upcoming release version for this new handler to avoid misleading docs.

    .. versionadded:: 0.5.3

ignite/handlers/grad_monitor.py:10

  • There are a few unused imports here (import torch and engine from ignite.engine). Please remove them to keep the module clean and avoid lint failures.
import torch
from torch.nn import Module

from ignite import distributed as idist
from ignite.engine import Engine, Events, engine
from ignite.utils import setup_logger

ignite/handlers/grad_monitor.py:116

  • Dynamic thresholding (use_dynamic / moving average window) is new behavior but isn’t covered by the tests in test_grad_monitor.py. Please add a unit test that exercises use_dynamic=True (including the warm-up phase before the window is full) and verifies the computed effective_threshold/flag behavior.
        # Logic for dynamic thresholding.
        effective_threshold = self.threshold
        if self.use_dynamic and len(self._window) == self._window.maxlen:
            # Calculates spike relative to the moving average of the last 'n' steps.
            avg_norm = sum(self._window) / len(self._window)
            effective_threshold = avg_norm * 5.0  # Spike is defined as 5x the average.
        if grad_norm > effective_threshold:

ignite/handlers/__init__.py:2

  • Callable and Optional are imported but never used in this module. Please remove unused imports to avoid lint warnings and keep ignite.handlers exports clean.
from typing import Any, Callable, Optional

ignite/handlers/grad_monitor.py:87

  • Avoid using .data on tensors (p.grad.detach().data.norm(2)): it bypasses autograd safety checks and is discouraged in PyTorch. Use p.grad.detach().norm(2) (or equivalent) instead.
                # Uses .data.norm to get raw scalar value quickly.
                param_norm = p.grad.detach().data.norm(2)

tests/ignite/handlers/test_grad_monitor.py:87

  • This test module contains a benchmark function and an if __name__ == "__main__" block that runs pytest/prints benchmark output. Test files in this suite don’t appear to include ad-hoc benchmarks or main-guards; please remove these sections or move the benchmark to a dedicated benchmarking script (so CI remains deterministic and quiet).
def benchmark_overhead():
    model = torch.nn.Sequential(torch.nn.Linear(1000, 1000)) 
    monitor = GradMonitor(model)
    for p in model.parameters():
        p.grad = torch.randn_like(p)

    iterations = 1000
    start_time = time.perf_counter()
    for _ in range(iterations):
        _ = monitor._compute_grad_norm()
    end_time = time.perf_counter()
    
    avg_time_ms = ((end_time - start_time) / iterations) * 1000
    total_params = sum(p.numel() for p in model.parameters())
    ns_per_param = ((end_time - start_time) / (iterations * total_params)) * 1e9
    
    print(f"\n--- BENCHMARK ---")
    print(f"Average time per iteration: {avg_time_ms:.6f} ms")
    print(f"Approximate overhead per parameter: {ns_per_param:.2f} ns")

if __name__ == "__main__":
    print("Running Logic Tests...")
    pytest.main([__file__, "-v"])
    benchmark_overhead()


Comment on lines +97 to +100

        if engine is not None and hasattr(engine, "scaler"):
            scaler = getattr(engine, "scaler")
            if scaler is not None:
                total_norm /= scaler.get_scale()

Copilot AI Mar 3, 2026


Ignite’s AMP trainer integration stores the GradScaler on engine.state.scaler (see ignite/engine/__init__.py), but this handler looks for engine.scaler. As a result, unscaling won’t happen in the common case and AMP support is effectively broken. Please read the scaler from engine.state (e.g., getattr(engine.state, "scaler", None)) and update the unit test to match.

Suggested change
-        if engine is not None and hasattr(engine, "scaler"):
-            scaler = getattr(engine, "scaler")
-            if scaler is not None:
-                total_norm /= scaler.get_scale()
+        scaler = None
+        if engine is not None:
+            state = getattr(engine, "state", None)
+            if state is not None:
+                scaler = getattr(state, "scaler", None)
+        if scaler is not None:
+            total_norm /= scaler.get_scale()

class MockScaler:
    def get_scale(self):
        return 1024.0

trainer.scaler = MockScaler()

Copilot AI Mar 3, 2026


This test sets trainer.scaler = MockScaler(), but Ignite’s supervised AMP setup uses trainer.state.scaler to store the GradScaler. To validate the real integration path, set the mock scaler on trainer.state.scaler (and adjust the handler accordingly).

Suggested change
-trainer.scaler = MockScaler()
+trainer.state.scaler = MockScaler()

Comment on lines +83 to +102

        total_norm_sq = 0.0
        for p in self.model.parameters():
            if p.grad is not None:
                # Uses .data.norm to get raw scalar value quickly.
                param_norm = p.grad.detach().data.norm(2)
                total_norm_sq += param_norm.item() ** 2

        # DDP Sync: Sums squared norms across all GPUs so all nodes see the same 'spike'.
        if idist.get_world_size() > 1:
            total_norm_sq = idist.all_reduce(total_norm_sq)

        total_norm = total_norm_sq ** 0.5

        # Handles GradScaler for Automatic Mixed Precision (AMP) workflows.
        if engine is not None and hasattr(engine, "scaler"):
            scaler = getattr(engine, "scaler")
            if scaler is not None:
                total_norm /= scaler.get_scale()

        return total_norm

Copilot AI Mar 3, 2026


In _compute_grad_norm, using param_norm.item() inside the parameter loop will force a device sync per parameter when gradients are on GPU, which can add significant overhead every iteration. Consider accumulating total_norm_sq as a tensor on the same device as the grads (and calling .item() only once at the end), which also makes it easier to pass a tensor into idist.all_reduce without extra conversions.

Suggested change
-        total_norm_sq = 0.0
-        for p in self.model.parameters():
-            if p.grad is not None:
-                # Uses .data.norm to get raw scalar value quickly.
-                param_norm = p.grad.detach().data.norm(2)
-                total_norm_sq += param_norm.item() ** 2
-        # DDP Sync: Sums squared norms across all GPUs so all nodes see the same 'spike'.
-        if idist.get_world_size() > 1:
-            total_norm_sq = idist.all_reduce(total_norm_sq)
-        total_norm = total_norm_sq ** 0.5
-        # Handles GradScaler for Automatic Mixed Precision (AMP) workflows.
-        if engine is not None and hasattr(engine, "scaler"):
-            scaler = getattr(engine, "scaler")
-            if scaler is not None:
-                total_norm /= scaler.get_scale()
-        return total_norm
+        # Accumulate squared norms as a tensor on the same device as the gradients to
+        # avoid per-parameter host-device syncs from .item().
+        total_norm_sq = None
+        for p in self.model.parameters():
+            if p.grad is not None:
+                # Uses .data.norm to get raw scalar value quickly.
+                param_norm = p.grad.detach().data.norm(2)
+                param_norm_sq = param_norm.pow(2)
+                if total_norm_sq is None:
+                    total_norm_sq = param_norm_sq
+                else:
+                    total_norm_sq = total_norm_sq + param_norm_sq
+        # If there were no gradients, keep previous behavior (norm == 0.0).
+        if total_norm_sq is None:
+            total_norm_sq = 0.0
+        # DDP Sync: Sums squared norms across all GPUs so all nodes see the same 'spike'.
+        if idist.get_world_size() > 1:
+            total_norm_sq = idist.all_reduce(total_norm_sq)
+        # Ensure we work with a tensor for further operations.
+        if not isinstance(total_norm_sq, torch.Tensor):
+            total_norm_sq = torch.tensor(total_norm_sq)
+        total_norm = torch.sqrt(total_norm_sq)
+        # Handles GradScaler for Automatic Mixed Precision (AMP) workflows.
+        if engine is not None and hasattr(engine, "scaler"):
+            scaler = getattr(engine, "scaler")
+            if scaler is not None:
+                scale = scaler.get_scale()
+                if isinstance(scale, torch.Tensor):
+                    total_norm = total_norm / scale
+                else:
+                    total_norm = total_norm / float(scale)
+        return float(total_norm.item())
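The core of this suggestion, one accumulation pass with a single final square root (and thus a single host-device sync on GPU), can be illustrated device-free, with plain lists of floats standing in for gradient tensors:

```python
import math

# Device-free analog of the accumulate-then-sqrt pattern: sum squared
# per-parameter norms, take one square root at the end. On GPU the analog of
# the final sqrt/conversion is the single .item() call.
def total_grad_norm(grads):
    total_sq = 0.0
    for g in grads:
        if g is not None:  # parameters without gradients are skipped
            total_sq += sum(x * x for x in g)  # squared L2 norm of this "tensor"
    return math.sqrt(total_sq)

# sqrt(3^2 + 4^2) == 5.0; a None "gradient" contributes nothing.
assert total_grad_norm([[3.0], [4.0], None]) == 5.0
```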

        for p in self.model.parameters():
            if p.grad is not None:
                # Uses .data.norm to get raw scalar value quickly.
                param_norm = p.grad.detach().data.norm(2)
Collaborator


Why do you need to detach and call .data on the grad tensor?

Collaborator

@vfdev-5 left a comment


Thanks for the draft. I left a few quick comments without yet going into details.

Comment on lines +33 to +39

# Implementation for batch skipping
@trainer.on(Events.ITERATION_COMPLETED)
def skip_batch_on_spike(engine):
    if getattr(engine.state, "unhealthy_spike", False):
        # Custom logic here, e.g., zeroing grads or logging
        optimizer.zero_grad()
        print(f"Spike detected at iteration {engine.state.iteration}, skipping.")
Collaborator


This example is unclear. Let's not use getattr, as the attribute should exist.
Second, this should be put into train_step.

If the norm exceeds the threshold, we log a warning and run an optional callback.
This is especially useful for transformers where gradients can explode suddenly.

.. versionadded:: 0.5.3
Collaborator


This should be at the end of the docstring and mention the next unreleased version.

@leelakrishnaRajasimha
Contributor Author

I have refactored the GradMonitor implementation to fully address the feedback:

  • Efficiency: Implemented Welford's algorithm for O(1) memory-efficient running statistics (mean and variance).
  • Performance: Optimized the L2 norm calculation using .pow(2).sum() (avoiding redundant square roots inside the loop) and minimized CPU-GPU synchronization by calling .item() only once per iteration.
  • Logic: Attached the handler to Events.ITERATION_STARTED to allow for proper batch skipping before weight updates.
  • Testing: Added unit tests to verify the statistical threshold logic (mean + k * std) and confirmed that the k multiplier is functioning correctly.
  • Examples: Updated the docstrings and examples to clarify how users can implement custom batch-skipping logic in their train_step.

All tests are passing locally, and the implementation follows the repository's formatting standards.

Comment on lines 1 to 3
-from collections.abc import Callable
-from typing import Any
+from typing import Any, Callable, Optional

Collaborator


Why are you removing these and replacing them with the typing module and Optional?

        self.m2 = 0.0

    def __call__(self, engine: Engine):
        device = next(self.model.parameters()).device
Collaborator


You could just save the device as self.device.

Comment on lines +19 to +25

@trainer.on(Events.ITERATION_STARTED)
def check_spike(engine):
    # The training step logic must check this flag to skip
    if getattr(engine.state, "unhealthy_spike", False):
        optimizer.zero_grad()
Collaborator


This example is not good: all it does is call optimizer.zero_grad(); train_step will still run and the weights will still get updated. The example should show a custom train_step that uses engine.state.unhealthy_spike.

.. versionadded:: 0.6.0
"""

def __init__(self, model: torch.nn.Module, k: float = 3.0):
Collaborator

@aaishwarymishra Mar 4, 2026


Didn't this handler take an optional callable which ran in case of a spike?


total_norm = torch.sqrt(total_norm_sq).item()

scaler = getattr(engine.state, "scaler", None)
Collaborator


@vfdev-5 can you check this? I am not that familiar with distributed.

Collaborator


It is fine. However, it is documented nowhere that this handler looks at the engine.state.scaler attribute. Do we really need it to be dynamic? Can't we define it as a static constant in the constructor? In any case, I would prefer to recode this part.

@leelakrishnaRajasimha
Contributor Author

Thanks, @aaishwarymishra , I have updated the code based on your feedback. I optimized the device handling to avoid redundant calls and refactored the example to show a custom_train_step for proper batch skipping. I kept the typing module for now to stay consistent with the existing library style. Let me know if anything else needs adjusting.

@@ -1,10 +1,12 @@
from typing import Any, Callable, Optional
Collaborator


Why do we need this line here?

.. versionadded:: 0.6.0
"""

def __init__(self, model: torch.nn.Module, k: float = 3.0):
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@leelakrishnaRajasimha I suggest making the algorithm that computes the threshold configurable.
Right now the rule is hard-coded; we can implement it as a helper function so that users who would like their own rule can easily override the default.
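The configurable-rule idea could look roughly like this; `threshold_fn` and its (mean, std) signature are hypothetical illustrations, not the merged API:

```python
# Default rule: mean + k*std, as discussed earlier in the thread.
def default_threshold(mean, std, k=3.0):
    return mean + k * std

# The handler would call a pluggable threshold_fn instead of a hard-coded rule.
def is_spike(norm, mean, std, threshold_fn=default_threshold):
    return norm > threshold_fn(mean, std)

# Default rule in action.
assert is_spike(10.0, mean=1.0, std=0.5)
# User-supplied rule: flag anything above 5x the mean.
assert not is_spike(4.0, mean=1.0, std=0.5, threshold_fn=lambda m, s: 5.0 * m)
```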


        self.mean += delta / self.count
        self.m2 += delta * (total_norm - self.mean)

    def attach(self, engine: Engine):
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If the user's train step looks like this:

def train_step(engine, batch):
    ...
    optimizer.step()
    optimizer.zero_grad()
    return loss_value

then the handler won't see any grads in the model and will silently do nothing.


Labels

module: handlers Core Handlers module

Projects

None yet

Development

Successfully merging this pull request may close these issues.

[Feature] Implementation of a TrainingHealthMonitor for real-time gradient norm tracking.

4 participants