Feat: implement GradMonitor handler for gradient norm tracking and batch skipping#3624
Conversation
Hmm, I read your PR; there are a few suggestions:
Also, looking at the code, the handler doesn't actually skip the batch; the user has to write custom code in the train step, right?
Pull request overview
Adds a new Ignite handler to monitor L2 gradient norms at iteration boundaries and expose an engine.state.unhealthy_spike flag for reactive training logic (e.g., batch skipping) to help detect instability during training.
Changes:
- Introduces the `GradMonitor` handler with optional static thresholding and moving-average-based dynamic spike detection.
- Exposes the new handler via the `ignite.handlers` public API.
- Adds unit tests for basic norm math, flag toggling, AMP unscale behavior, and input validation.
Reviewed changes
Copilot reviewed 3 out of 3 changed files in this pull request and generated 3 comments.
| File | Description |
|---|---|
| `ignite/handlers/grad_monitor.py` | New handler implementation: norm computation, optional AMP/DDP handling, thresholding, and state flagging. |
| `ignite/handlers/__init__.py` | Re-exports `GradMonitor` in the public handlers namespace. |
| `tests/ignite/handlers/test_grad_monitor.py` | Adds tests for the new handler and related behaviors. |
Comments suppressed due to low confidence (7)
ignite/handlers/grad_monitor.py:75
`window_size` is used as `deque(maxlen=window_size)` but isn't validated. Passing a non-int or a negative value will raise an unhelpful `deque` error, and `window_size=0` makes dynamic mode impossible to trigger. Add explicit validation (e.g., positive integer) and consider validating that `use_dynamic` is a bool, for consistency with other handlers' arg checks.
```python
    threshold: float = 100.0,
    window_size: int = 10,
    use_dynamic: bool = False,
    callback: Optional[Callable[[Engine, float], None]] = None,
) -> None:
    if not isinstance(model, Module):
        raise TypeError(f"Argument model should be a torch.nn.Module, but given {type(model)}.")
    if not (isinstance(threshold, (int, float)) and math.isfinite(threshold) and threshold > 0):
        raise ValueError(f"Argument threshold should be a positive finite number, but given {threshold}.")
    if callback is not None and not callable(callback):
        raise TypeError(f"Argument callback should be callable, but given {type(callback)}.")
    self.model = model
    self.threshold = float(threshold)
    self.use_dynamic = use_dynamic
    self.callback = callback
    # Standard Ignite logger setup.
    self.logger = setup_logger(__name__ + "." + self.__class__.__name__)
    self._history = deque(maxlen=1000)
    self._window = deque(maxlen=window_size)
```
ignite/handlers/grad_monitor.py:22
- The docstring says `.. versionadded:: 0.5.3`, but the package version in `ignite/__init__.py` is `0.6.0`. Please update the `versionadded` tag to the correct upcoming release version for this new handler to avoid misleading docs.

```
.. versionadded:: 0.5.3
```
ignite/handlers/grad_monitor.py:10
- There are a few unused imports here (`import torch` and `engine` from `ignite.engine`). Please remove them to keep the module clean and avoid lint failures.

```python
import torch
from torch.nn import Module
from ignite import distributed as idist
from ignite.engine import Engine, Events, engine
from ignite.utils import setup_logger
```
ignite/handlers/grad_monitor.py:116
- Dynamic thresholding (`use_dynamic` / moving average window) is new behavior but isn't covered by the tests in `test_grad_monitor.py`. Please add a unit test that exercises `use_dynamic=True` (including the warm-up phase before the window is full) and verifies the computed `effective_threshold` / flag behavior.

```python
# Logic for dynamic thresholding.
effective_threshold = self.threshold
if self.use_dynamic and len(self._window) == self._window.maxlen:
    # Calculates spike relative to the moving average of the last 'n' steps.
    avg_norm = sum(self._window) / len(self._window)
    effective_threshold = avg_norm * 5.0  # Spike is defined as 5x the average.
if grad_norm > effective_threshold:
```
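A standalone replica of the rule above can make the warm-up behavior such a test should assert easy to see. This is a sketch, not the handler's code; the 5x multiplier mirrors the hard-coded rule in the snippet:

```python
from collections import deque

def effective_threshold(window: deque, static_threshold: float) -> float:
    # Warm-up: until the window is full, only the static threshold applies.
    if len(window) < (window.maxlen or 0):
        return static_threshold
    # Full window: spike is defined as 5x the moving average.
    return 5.0 * sum(window) / len(window)

window = deque(maxlen=3)
window.append(1.0)
# Warm-up phase: falls back to the static threshold.
assert effective_threshold(window, 100.0) == 100.0
window.extend([1.0, 1.0])
# Full window with mean 1.0 -> dynamic threshold 5.0.
assert effective_threshold(window, 100.0) == 5.0
```

A unit test for the handler would drive `use_dynamic=True` through enough iterations to fill the window and assert the flag toggles exactly when the norm exceeds this value.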
ignite/handlers/__init__.py:2
`Callable` and `Optional` are imported but never used in this module. Please remove unused imports to avoid lint warnings and keep `ignite.handlers` exports clean.

```python
from typing import Any, Callable, Optional
```
ignite/handlers/grad_monitor.py:87
- Avoid using `.data` on tensors (`p.grad.detach().data.norm(2)`): it bypasses autograd safety checks and is discouraged in PyTorch. Use `p.grad.detach().norm(2)` (or equivalent) instead.

```python
# Uses .data.norm to get raw scalar value quickly.
param_norm = p.grad.detach().data.norm(2)
```
tests/ignite/handlers/test_grad_monitor.py:87
- This test module contains a benchmark function and an `if __name__ == "__main__"` block that runs pytest/prints benchmark output. Test files in this suite don't appear to include ad-hoc benchmarks or main-guards; please remove these sections or move the benchmark to a dedicated benchmarking script (so CI remains deterministic and quiet).

```python
def benchmark_overhead():
    model = torch.nn.Sequential(torch.nn.Linear(1000, 1000))
    monitor = GradMonitor(model)
    for p in model.parameters():
        p.grad = torch.randn_like(p)
    iterations = 1000
    start_time = time.perf_counter()
    for _ in range(iterations):
        _ = monitor._compute_grad_norm()
    end_time = time.perf_counter()
    avg_time_ms = ((end_time - start_time) / iterations) * 1000
    total_params = sum(p.numel() for p in model.parameters())
    ns_per_param = ((end_time - start_time) / (iterations * total_params)) * 1e9
    print(f"\n--- BENCHMARK ---")
    print(f"Average time per iteration: {avg_time_ms:.6f} ms")
    print(f"Approximate overhead per parameter: {ns_per_param:.2f} ns")

if __name__ == "__main__":
    print("Running Logic Tests...")
    pytest.main([__file__, "-v"])
    benchmark_overhead()
```
ignite/handlers/grad_monitor.py
Outdated
```python
if engine is not None and hasattr(engine, "scaler"):
    scaler = getattr(engine, "scaler")
    if scaler is not None:
        total_norm /= scaler.get_scale()
```
Ignite’s AMP trainer integration stores the GradScaler on engine.state.scaler (see ignite/engine/__init__.py), but this handler looks for engine.scaler. As a result, unscaling won’t happen in the common case and AMP support is effectively broken. Please read the scaler from engine.state (e.g., getattr(engine.state, "scaler", None)) and update the unit test to match.
Suggested change:

```python
scaler = None
if engine is not None:
    state = getattr(engine, "state", None)
    if state is not None:
        scaler = getattr(state, "scaler", None)
if scaler is not None:
    total_norm /= scaler.get_scale()
```
```python
class MockScaler:
    def get_scale(self): return 1024.0

trainer.scaler = MockScaler()
```
This test sets trainer.scaler = MockScaler(), but Ignite’s supervised AMP setup uses trainer.state.scaler to store the GradScaler. To validate the real integration path, set the mock scaler on trainer.state.scaler (and adjust the handler accordingly).
Suggested change:

```python
trainer.state.scaler = MockScaler()
```
ignite/handlers/grad_monitor.py
Outdated
```python
total_norm_sq = 0.0
for p in self.model.parameters():
    if p.grad is not None:
        # Uses .data.norm to get raw scalar value quickly.
        param_norm = p.grad.detach().data.norm(2)
        total_norm_sq += param_norm.item() ** 2

# DDP Sync: Sums squared norms across all GPUs so all nodes see the same 'spike'.
if idist.get_world_size() > 1:
    total_norm_sq = idist.all_reduce(total_norm_sq)

total_norm = total_norm_sq ** 0.5

# Handles GradScaler for Automatic Mixed Precision (AMP) workflows.
if engine is not None and hasattr(engine, "scaler"):
    scaler = getattr(engine, "scaler")
    if scaler is not None:
        total_norm /= scaler.get_scale()

return total_norm
```
In _compute_grad_norm, using param_norm.item() inside the parameter loop will force a device sync per parameter when gradients are on GPU, which can add significant overhead every iteration. Consider accumulating total_norm_sq as a tensor on the same device as the grads (and calling .item() only once at the end), which also makes it easier to pass a tensor into idist.all_reduce without extra conversions.
Suggested change:

```python
# Accumulate squared norms as a tensor on the same device as the gradients to
# avoid per-parameter host-device syncs from .item().
total_norm_sq = None
for p in self.model.parameters():
    if p.grad is not None:
        # Uses .data.norm to get raw scalar value quickly.
        param_norm = p.grad.detach().data.norm(2)
        param_norm_sq = param_norm.pow(2)
        if total_norm_sq is None:
            total_norm_sq = param_norm_sq
        else:
            total_norm_sq = total_norm_sq + param_norm_sq

# If there were no gradients, keep previous behavior (norm == 0.0).
if total_norm_sq is None:
    total_norm_sq = 0.0

# DDP Sync: Sums squared norms across all GPUs so all nodes see the same 'spike'.
if idist.get_world_size() > 1:
    total_norm_sq = idist.all_reduce(total_norm_sq)

# Ensure we work with a tensor for further operations.
if not isinstance(total_norm_sq, torch.Tensor):
    total_norm_sq = torch.tensor(total_norm_sq)
total_norm = torch.sqrt(total_norm_sq)

# Handles GradScaler for Automatic Mixed Precision (AMP) workflows.
if engine is not None and hasattr(engine, "scaler"):
    scaler = getattr(engine, "scaler")
    if scaler is not None:
        scale = scaler.get_scale()
        if isinstance(scale, torch.Tensor):
            total_norm = total_norm / scale
        else:
            total_norm = total_norm / float(scale)

return float(total_norm.item())
```
ignite/handlers/grad_monitor.py
Outdated
```python
for p in self.model.parameters():
    if p.grad is not None:
        # Uses .data.norm to get raw scalar value quickly.
        param_norm = p.grad.detach().data.norm(2)
```
Why do you need to detach and call `.data` on the grad tensor?
vfdev-5 left a comment
Thanks for the draft. I left a few quick comments without yet going into details.
ignite/handlers/grad_monitor.py
Outdated
```python
# Implementation for batch skipping
@trainer.on(Events.ITERATION_COMPLETED)
def skip_batch_on_spike(engine):
    if getattr(engine.state, "unhealthy_spike", False):
        # Custom logic here, e.g., zeroing grads or logging
        optimizer.zero_grad()
        print(f"Spike detected at iteration {engine.state.iteration}, skipping.")
```
This example is unclear. Let's not use getattr, as the attribute should exist.
Second, this should be put into train_step.
ignite/handlers/grad_monitor.py
Outdated
```
If the norm exceeds the threshold, we log a warning and run an optional callback.
This is especially useful for transformers where gradients can explode suddenly.

.. versionadded:: 0.5.3
```
This should be at the end of the docstring and mention the next unreleased version.
I have refactored the GradMonitor implementation to fully address the feedback:
ignite/handlers/__init__.py
Outdated
```diff
- from collections.abc import Callable
- from typing import Any
+ from typing import Any, Callable, Optional
```
Why are you removing these and replacing them with the typing module and Optional?
ignite/handlers/grad_monitor.py
Outdated
```python
        self.m2 = 0.0

    def __call__(self, engine: Engine):
        device = next(self.model.parameters()).device
```
You could just save the device once as self.device.
```python
@trainer.on(Events.ITERATION_STARTED)
def check_spike(engine):
    # The training step logic must check this flag to skip
    if getattr(engine.state, "unhealthy_spike", False):
        optimizer.zero_grad()
```
This example is not good: all it does is apply optimizer.zero_grad(), so train_step will still run and weights will still get updated. The example should show a custom train_step that uses engine.state.unhealthy_spike.
```python
    .. versionadded:: 0.6.0
    """

    def __init__(self, model: torch.nn.Module, k: float = 3.0):
```
Didn't this handler take an optional callable that ran in case of a spike?
```python
total_norm = torch.sqrt(total_norm_sq).item()

scaler = getattr(engine.state, "scaler", None)
```
@vfdev-5 can you check this? I am not that familiar with distributed.
It is fine. However, it is written nowhere that this handler can look at the engine.state.scaler attribute. Do we really need it to be dynamic? Can't we define it as a static const in the constructor? In any case, I prefer to recode this part.
Thanks, @aaishwarymishra, I have updated the code based on your feedback. I optimized the device handling to avoid redundant calls and refactored the example to show a custom_train_step for proper batch skipping. I kept the typing module for now to stay consistent with the existing library style. Let me know if anything else needs adjusting.
```diff
@@ -1,10 +1,12 @@
+ from typing import Any, Callable, Optional
```
Why do we need this line here?
```python
    .. versionadded:: 0.6.0
    """

    def __init__(self, model: torch.nn.Module, k: float = 3.0):
```
@leelakrishnaRajasimha I suggest making the algorithm that computes the threshold configurable.
Right now the rule is hard-coded; we can code it as a helper function so users who would like their own rule can easily override the default one.
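One possible shape for this suggestion, sketched standalone (all names here are hypothetical; in the handler the injectable rule would replace the hard-coded 5x-moving-average check):

```python
from collections import deque
from typing import Callable, Deque

# A threshold rule maps (recent norms, static threshold) -> effective threshold.
ThresholdFn = Callable[[Deque[float], float], float]

def default_rule(window: Deque[float], static_threshold: float) -> float:
    # Mirrors the current behavior: 5x the moving average once warmed up,
    # static threshold during warm-up.
    if len(window) < (window.maxlen or 0):
        return static_threshold
    return 5.0 * sum(window) / len(window)

def make_spike_check(rule: ThresholdFn = default_rule, window_size: int = 10):
    window: Deque[float] = deque(maxlen=window_size)

    def is_spike(grad_norm: float, static_threshold: float = 100.0) -> bool:
        spike = grad_norm > rule(window, static_threshold)
        window.append(grad_norm)
        return spike

    return is_spike

# Users can inject their own rule, e.g. twice the most recent norm:
check = make_spike_check(lambda w, t: 2.0 * w[-1] if w else t)
assert check(1.0) is False   # empty window: falls back to static threshold
assert check(3.0) is True    # 3.0 > 2.0 * 1.0
```

The handler would accept such a `rule` argument in its constructor and default to `default_rule`, keeping current behavior unchanged.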
```python
        self.mean += delta / self.count
        self.m2 += delta * (total_norm - self.mean)

    def attach(self, engine: Engine):
```
If the user's train step looks like this:

```python
def train_step(engine, batch):
    ...
    optimizer.step()
    optimizer.zero_grad()
    return loss_value
```

then the handler won't see any grads in the model and will silently do nothing.
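One way to avoid that silent no-op is to zero the gradients at the start of the next step instead of the end of the current one, so a handler firing at ITERATION_COMPLETED still sees them. A minimal stub sketch of the ordering (the stub class is hypothetical; real code would use a torch optimizer and loss.backward()):

```python
class StubOptimizer:
    def __init__(self):
        self.grads_present = False
    def backward(self):                 # stand-in for loss.backward()
        self.grads_present = True
    def step(self):
        assert self.grads_present, "step() needs gradients"
    def zero_grad(self):
        self.grads_present = False

opt = StubOptimizer()

def train_step(engine, batch):
    opt.zero_grad()   # moved to the START: clears the previous iteration's grads
    opt.backward()
    opt.step()
    return 0.0        # grads are intentionally left in place on return

train_step(None, None)
# A handler running at ITERATION_COMPLETED can still read the grads:
assert opt.grads_present
```

Alternatively, the handler could warn when it finds no gradients at all, so this misconfiguration is at least visible to the user.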
Closes #3563
Overview
This PR introduces the GradMonitor handler, designed to track L2 gradient norms in real-time to detect training instability, such as "entropy collapse" or exploding gradients in Transformer models.
Key Features
- Real-time Monitoring: Hooks into Events.ITERATION_COMPLETED to calculate and log L2 gradient norms.
- AMP & DDP Support: Fully compatible with Automatic Mixed Precision (unscaling via GradScaler) and Distributed Data Parallel (synchronization via idist.all_reduce).
- Reactive Training Logic: Sets an engine.state.unhealthy_spike flag, enabling users to implement automated batch skipping or custom callbacks when thresholds are breached.
- Flexible Thresholding: Supports both static L2 norm limits and dynamic spike detection using a moving average window.
Performance & Benchmarking
The handler has been benchmarked on a model with 1 million parameters to ensure it does not significantly impact training throughput:
- Local Hardware (HP Victus / RTX 4050): For 100,000 (1 lakh) iterations, the total execution time typically falls within a 9s to 16s range, depending on background system overhead.
- High-Performance Environment (Google Colab GPU): In optimized cloud environments with dedicated high-performance GPUs, the same benchmark has achieved results as fast as 5.4s.
- Efficiency: This demonstrates that the per-iteration overhead remains negligible across various hardware tiers, making it highly suitable for large-scale production models.
Testing Suite
- Verified mathematical accuracy of L2 norm calculations.
- Validated engine.state flag toggling and callback execution.
- Confirmed correct unscaling behavior when a GradScaler is present.
- Included input validation for all initialization parameters.