Skip to content

Commit

Permalink
fix layernorm grad sbp
Browse files Browse the repository at this point in the history
  • Loading branch information
ShawnXuan committed Nov 10, 2024
1 parent df98f12 commit 4476fa4
Showing 1 changed file with 7 additions and 1 deletion.
8 changes: 7 additions & 1 deletion oneflow/user/ops/layer_norm_op.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -183,10 +183,16 @@ oneflow::DataType InferBnParamDataType(const DataType x_data_type) {
broadcast_args.emplace_back(user_op::OpArg("gamma", 0));
}
int64_t begin_norm_axis = ctx->Attr<int64_t>("begin_norm_axis");
int64_t begin_params_axis = ctx->Attr<int64_t>("begin_params_axis");
CHECK_EQ(begin_norm_axis, begin_params_axis)
<< "begin_norm_axis and begin_params_axis must be equal, but got "
<< begin_norm_axis << " and " << begin_params_axis;
for (int i = 0; i < begin_norm_axis; ++i) {
ctx->NewBuilder()
.Split(ctx->inputs(), i)
.Split(ctx->outputs(), i)
.Split(user_op::OpArg("dx", 0), i)
.PartialSum(user_op::OpArg("gamma_diff", 0))
.PartialSum(user_op::OpArg("beta_diff", 0))
.Broadcast(broadcast_args)
.Build();
}
Expand Down

0 comments on commit 4476fa4

Please sign in to comment.