diff --git a/python/oneflow/nn/utils/clip_grad.py b/python/oneflow/nn/utils/clip_grad.py index a667d3cc00e..508295385e9 100644 --- a/python/oneflow/nn/utils/clip_grad.py +++ b/python/oneflow/nn/utils/clip_grad.py @@ -119,6 +119,8 @@ def clip_grad_norm_( ] total_norm = norms[0] if len(norms) == 1 else flow.min(flow.stack(norms)) else: + """ + # data parallel total_norm = flow.linalg.vector_norm( flow.stack( [ @@ -130,6 +132,22 @@ def clip_grad_norm_( ), norm_type, ) + """ + # tensor parallel: + partial_grad_squre_sum = flow.sum( + flow.stack( + [ + flow.sum(flow.pow(p.grad.detach(), norm_type)).to_local() + for p in parameters + ] + ) + ) + + flow.comm.all_reduce(partial_grad_squre_sum) + total_norm = flow.pow(partial_grad_squre_sum, 1 / norm_type) + total_norm = total_norm.to_global( + sbp=sbp_broadcast, placement=param0_placement + ) if error_if_nonfinite and flow.logical_or( total_norm.isnan(), total_norm.isinf() ):