[Unity] Split DecomposeOpsForTraining into two steps (apache#15954)
* [Unity] Split DecomposeOpsForTraining into two steps

Prior to this commit, the `DecomposeOpsForTraining` transform directly
replaced `relax.nn.batch_norm` with more primitive relax operations.
This required the decomposed form of `relax.nn.batch_norm` to be
duplicated in `DecomposeOpsForInference`.  This commit refactors the
pass to run in two steps: first applying training-specific mutations,
then decomposing.

Having a separate `DecomposeOps` pass also provides a single, clear
location for operator decomposition, which may be migrated into the
operator definitions in the future, similar to `FLegalize`.

* Updated ApplyPassToFunction utility to use a regex
Lunderberg authored Jan 16, 2024
1 parent cf14edd commit a2a1b53
Showing 4 changed files with 162 additions and 121 deletions.
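For orientation, here is a minimal sketch (not part of the diff) of how the refactored training pass is expected to be driven from C++, assuming the usual public declaration of `DecomposeOpsForTraining` in `tvm/relax/transform.h` and an already-constructed `IRModule`:

#include <tvm/ir/module.h>
#include <tvm/relax/transform.h>

using namespace tvm;

// DecomposeOpsForTraining is now a Sequential of two smaller passes:
//   1. MutateOpsForTraining - rewrites training-specific ops such as
//      relax.nn.batch_norm to use batch statistics and to update the
//      moving mean/variance.
//   2. DecomposeOps         - lowers composite ops into primitive relax
//      operations, shared with the inference path.
IRModule RunTrainingDecomposition(IRModule mod) {
  transform::Pass pass = relax::transform::DecomposeOpsForTraining(NullOpt);
  return pass(mod);  // a Pass is callable on an IRModule
}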
25 changes: 25 additions & 0 deletions include/tvm/ir/transform.h
@@ -525,6 +525,31 @@ TVM_DLL Pass CreateModulePass(
const runtime::TypedPackedFunc<IRModule(IRModule, PassContext)>& pass_func, int opt_level,
String name, Array<runtime::String> required, bool traceable = false);

/*!
* \brief Utility to apply a pass to specific functions in an IRModule
*
* TVM uses IRModule to IRModule transformations at all stages of
* lowering. These transformations may be useful when hand-writing an
* optimized model, or to perform optimizations on specific kernels
* within an IRModule.  This utility allows a pass to be applied to
* the specified functions, without altering other functions in the module.
*
* \param pass The IRModule to IRModule pass to be applied.
*
* \param func_name_regex A regex used to select the functions to be
* updated. The pass will be applied to all functions whose name
* matches the regex.
*
* \param error_if_no_function_matches_regex Specifies the behavior if
* an IRModule does not contain any function matching the provided
* regex. If true, an error will be raised. If false (default),
* the IRModule will be returned unmodified.
*
* \return The modified IRModule to IRModule pass.
*/
TVM_DLL Pass ApplyPassToFunction(Pass pass, String func_name_regex,
bool error_if_no_function_matches_regex = false);

/*!
* \brief A special trace pass that prints the header and IR to LOG(INFO).
* \param header The header to be attached to the output.
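To illustrate the utility declared above, the sketch below restricts an arbitrary module-level pass to the function named `main` (the helper name `ApplyOnlyToMain` and the arguments `some_pass`/`mod` are hypothetical):

#include <tvm/ir/module.h>
#include <tvm/ir/transform.h>

// The string is treated as a regex and must match the entire function name.
// With the default error_if_no_function_matches_regex = false, a module with
// no matching function is returned unchanged.
tvm::IRModule ApplyOnlyToMain(tvm::transform::Pass some_pass, tvm::IRModule mod) {
  auto scoped = tvm::transform::ApplyPassToFunction(some_pass, "main");
  return scoped(mod);
}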
31 changes: 31 additions & 0 deletions src/ir/transform.cc
@@ -31,6 +31,7 @@

#include <chrono>
#include <iomanip>
#include <regex>
#include <stack>
#include <unordered_set>

@@ -531,6 +532,36 @@ Pass CreateModulePass(const runtime::TypedPackedFunc<IRModule(IRModule, PassCont
return ModulePass(pass_func, pass_info);
}

Pass ApplyPassToFunction(Pass pass, String func_name_regex,
bool error_if_no_function_matches_regex) {
auto pass_name =
static_cast<const std::stringstream&>(std::stringstream() << "ApplyPassTo" << func_name_regex)
.str();
std::regex regex(func_name_regex.operator std::string());

auto pass_func = [pass, regex](IRModule mod, PassContext) -> IRModule {
IRModule subset;

for (const auto& [gvar, func] : mod->functions) {
std::string name = gvar->name_hint;
if (std::regex_match(name, regex)) {
subset->Add(gvar, func);
}
}

if (subset->functions.size()) {
IRModule new_subset = pass(subset);
if (!new_subset.same_as(subset)) {
mod.CopyOnWrite()->Update(new_subset);
}
}

return mod;
};

return CreateModulePass(pass_func, 0, pass_name, {});
}

TVM_REGISTER_NODE_TYPE(PassInfoNode);

TVM_REGISTER_GLOBAL("transform.PassInfo")
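Because the implementation above uses `std::regex_match`, the supplied regex must cover the entire function name rather than a substring. A small standalone illustration of the difference (plain C++, independent of TVM):

#include <iostream>
#include <regex>
#include <string>

int main() {
  std::regex pattern("fused_.*");
  std::string exact = "fused_add";
  std::string prefixed = "my_fused_add";
  std::cout << std::regex_match(exact, pattern) << "\n";     // 1: whole name matches
  std::cout << std::regex_match(prefixed, pattern) << "\n";  // 0: only a substring matches
  std::cout << std::regex_search(prefixed, pattern) << "\n"; // 1: regex_search would accept it
  return 0;
}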
156 changes: 73 additions & 83 deletions src/relax/transform/decompose_ops.cc
@@ -48,7 +48,7 @@ Expr ExpandToMatchInput(Expr data, int ndim, Array<Integer> axes) {
return expand_dims(data, expand_axes);
}

Tuple SimplifyBatchNormInference(const Call& call) {
Tuple DecomposeBatchNorm(const Call& call) {
auto attrs = call->attrs.as<BatchNormAttrs>();
ICHECK_NOTNULL(attrs);

@@ -75,14 +75,18 @@ Tuple SimplifyBatchNormInference(const Call& call) {
return Tuple({out, call->args[3], call->args[4]});
}

Tuple SimplifyBatchNormTraining(const Call& call) {
Expr MutateBatchNormForTraining(Call call) {
auto attrs = call->attrs.as<BatchNormAttrs>();
ICHECK_NOTNULL(attrs);

ICHECK_EQ(call->args.size(), 5);
Expr data = call->args[0];
TensorStructInfo sinfo = MatchTensorStructInfo(data);
Expr gamma = call->args[1];
Expr beta = call->args[2];
Expr moving_mean = call->args[3];
Expr moving_var = call->args[4];

TensorStructInfo sinfo = MatchTensorStructInfo(data);

Array<Integer> reduce_axes;
for (int i = 0; i < sinfo->ndim; ++i) {
@@ -92,35 +96,21 @@ Tuple SimplifyBatchNormTraining(const Call& call) {
}

Expr data_mean = mean(data, reduce_axes, false);
Expr data_mean_rs = ExpandToMatchInput(data_mean, sinfo->ndim, {attrs->axis});
Expr data_var = variance(data, reduce_axes, false);
Expr data_var_rs = ExpandToMatchInput(data_var, sinfo->ndim, {attrs->axis});

// output = (x - mean) / sqrt(var + epsilon) * gamma + beta
Expr epsilon = MakeConstantScalar(attrs->epsilon, sinfo->dtype);
Expr sqrt_var = sqrt(add(data_var_rs, epsilon));
Expr out = divide(subtract(data, data_mean_rs), sqrt_var);

if (attrs->scale) {
out = multiply(out, ExpandToMatchInput(gamma, sinfo->ndim, {attrs->axis}));
}
if (attrs->center) {
out = add(out, ExpandToMatchInput(beta, sinfo->ndim, {attrs->axis}));
}

Expr moving_mean = call->args[3];
Expr moving_var = call->args[4];
Expr momentum = MakeConstantScalar(attrs->momentum, sinfo->dtype);
Expr one_minus_mom = MakeConstantScalar(1 - attrs->momentum, sinfo->dtype);

return Tuple({
out,
add(multiply(one_minus_mom, moving_mean), multiply(momentum, data_mean)),
add(multiply(one_minus_mom, moving_var), multiply(momentum, data_var)),
});
Expr new_moving_mean = add(multiply(one_minus_mom, moving_mean), multiply(momentum, data_mean));
Expr new_moving_var = add(multiply(one_minus_mom, moving_var), multiply(momentum, data_var));

call.CopyOnWrite()->args = {data, gamma, beta, data_mean, data_var};

return relax::Tuple({TupleGetItem(call, 0), new_moving_mean, new_moving_var});
}

Expr SimplifyLayerNorm(const Call& call) {
Expr DecomposeLayerNorm(const Call& call) {
auto attrs = call->attrs.as<LayerNormAttrs>();
ICHECK_NOTNULL(attrs);

@@ -172,92 +162,92 @@ Expr TensorToShape(const Call& call_node, const BlockBuilder& builder) {
return ShapeExpr(shape_var);
}

class OpDecomposer : public ExprMutator {
public:
constexpr static const char* kModeInference = "inference";
constexpr static const char* kModeTraining = "training";
/*! \brief Update operators that have a training-specific form
*
* Some operators, such as relax.op.batch_norm, need additional
* processing when being run for training.  This mutator applies the
* mutations required to produce the training-specific form.
*/
class TrainingOperatorMutator : public ExprMutator {
private:
using ExprMutator::VisitExpr_;

explicit OpDecomposer(String mode) : ExprMutator(), mode_(mode) {
CHECK(mode == kModeInference || mode == kModeTraining)
<< "The argument mode must be one of the following values: \"inference\", \"training\".";
Expr VisitExpr_(const CallNode* call_node) final {
Call call = Downcast<Call>(VisitExprPostOrder_(call_node));
if (call->op == batch_norm_op_) {
return MutateBatchNormForTraining(call);
} else if (call->op == layer_norm_op_) {
// LayerNorm is decomposed only for training; for inference it is kept
// as a single op, which is more efficient.  In the future, this
// decomposition may not be needed for training either.
return DecomposeLayerNorm(call);
} else {
return call;
}
}

/* composite operator list */
const Op& batch_norm_op_ = Op::Get("relax.nn.batch_norm");
const Op& layer_norm_op_ = Op::Get("relax.nn.layer_norm");
};

class OpDecomposer : public ExprMutator {
private:
using ExprMutator::VisitExpr_;

Expr VisitExpr_(const CallNode* call_node) final {
Call call = Downcast<Call>(VisitExprPostOrder_(call_node));
if (call->op == batch_norm_op_) {
if (mode_ == kModeInference) {
return SimplifyBatchNormInference(call);
} else {
ICHECK_EQ(mode_, kModeTraining);
return SimplifyBatchNormTraining(call);
}
} else if (call->op == layer_norm_op_ && mode_ == kModeTraining) {
// Here we only decompose LayerNorm in training because it is more efficient as a single op.
// In the future maybe we can also remove this decomposition during training.
return SimplifyLayerNorm(call);
return DecomposeBatchNorm(call);
} else if (call->op == tensor_to_shape_op_) {
return TensorToShape(call, builder_);
}
return call;
}

const String mode_;

/* composite operator list */
const Op& batch_norm_op_ = Op::Get("relax.nn.batch_norm");
const Op& layer_norm_op_ = Op::Get("relax.nn.layer_norm");
const Op& tensor_to_shape_op_ = Op::Get("relax.tensor_to_shape");
};

IRModule Decompose(IRModule mod, Optional<String> func_name, String mode) {
auto op_decomposer = OpDecomposer(mode);

IRModuleNode* new_module = mod.CopyOnWrite();
namespace transform {

if (!func_name.defined()) { // simplify all functions
Map<GlobalVar, BaseFunc> functions = mod->functions;
for (const auto& func_pr : functions) {
if (const auto* relax_f = func_pr.second.as<FunctionNode>()) {
Function f = Downcast<Function>(op_decomposer(GetRef<Function>(relax_f)));
new_module->Update(func_pr.first, f);
}
}
} else { // simplify specified function
auto* func_ptr = mod->Lookup(func_name.value()).as<FunctionNode>();
CHECK(func_ptr) << func_name.value() << " is not a Relax Function";
auto gvar = mod->GetGlobalVar(func_name.value());
auto func = GetRef<Function>(func_ptr);
func = Downcast<Function>(op_decomposer(func));
new_module->Update(gvar, func);
}
Pass MutateOpsForTraining() {
auto pass_func = [](Function func, IRModule, PassContext) -> Function {
TrainingOperatorMutator mutator;
return Downcast<Function>(mutator(func));
};
return CreateFunctionPass(/*pass_function=*/pass_func,
/*opt_level=*/0,
/*pass_name=*/"MutateOpsForTraining",
/*required=*/{});
}

return GetRef<IRModule>(new_module);
Pass DecomposeOps() {
auto pass_func = [](Function func, IRModule, PassContext) -> Function {
OpDecomposer mutator;
return Downcast<Function>(mutator(func));
};
return CreateFunctionPass(/*pass_function=*/pass_func,
/*opt_level=*/0,
/*pass_name=*/"DecomposeOps",
/*required=*/{});
}

namespace transform {
Pass DecomposeOpsForInference(Optional<String> func_name) {
runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func = [=](IRModule mod,
PassContext pc) {
return Decompose(mod, func_name, OpDecomposer::kModeInference);
};
return CreateModulePass(/*pass_function=*/pass_func,
/*opt_level=*/0,
/*pass_name=*/"DecomposeOpsForInference",
/*required=*/{});
if (func_name) {
return ApplyPassToFunction(DecomposeOps(), func_name.value());
} else {
return DecomposeOps();
}
}

Pass DecomposeOpsForTraining(Optional<String> func_name) {
runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func = [=](IRModule mod,
PassContext pc) {
return Decompose(mod, func_name, OpDecomposer::kModeTraining);
};
return CreateModulePass(/*pass_function=*/pass_func,
/*opt_level=*/0,
/*pass_name=*/"DecomposeOpsForTraining",
/*required=*/{});
auto module_pass = tvm::transform::Sequential({MutateOpsForTraining(), DecomposeOps()},
"DecomposeOpsForTraining");
if (func_name) {
return ApplyPassToFunction(module_pass, func_name.value());
} else {
return module_pass;
}
}

TVM_REGISTER_GLOBAL("relax.transform.DecomposeOpsForInference")
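For reference, the training-mode batch_norm produced by `MutateBatchNormForTraining` (and the shared decomposition in `DecomposeBatchNorm`) computes the following, written here in LaTeX, with statistics taken over every axis except `attrs->axis`, momentum $m$, and epsilon $\epsilon$:

y = \gamma \cdot \frac{x - \mathrm{mean}(x)}{\sqrt{\mathrm{var}(x) + \epsilon}} + \beta

\mu_{\text{moving}}^{\text{new}} = (1 - m)\,\mu_{\text{moving}} + m\,\mathrm{mean}(x)
\qquad
\left(\sigma^{2}_{\text{moving}}\right)^{\text{new}} = (1 - m)\,\sigma^{2}_{\text{moving}} + m\,\mathrm{var}(x)

where the $\gamma$ factor is applied only when `attrs->scale` is set and the $\beta$ term only when `attrs->center` is set.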
71 changes: 33 additions & 38 deletions tests/python/relax/test_transform_decompose_ops.py
@@ -137,44 +137,39 @@ def main(
R.Tensor((64,), dtype="float32"),
):
with R.dataflow():
lv: R.Tensor((64,), dtype="float32") = R.mean(x, axis=[0, 2, 3], keepdims=False)
lv1: R.Tensor((1, 64, 1, 1), dtype="float32") = R.expand_dims(lv, axis=[0, 2, 3])
lv2: R.Tensor((1, 64, 112, 112), dtype="float32") = R.subtract(x, lv1)
lv3: R.Tensor((64,), dtype="float32") = R.variance(
x, axis=[0, 2, 3], keepdims=False
)
lv4: R.Tensor((1, 64, 1, 1), dtype="float32") = R.expand_dims(lv3, axis=[0, 2, 3])
lv5: R.Tensor((1, 64, 1, 1), dtype="float32") = R.add(
lv4, R.const(9.9999997473787516e-06, "float32")
)
lv6: R.Tensor((1, 64, 1, 1), dtype="float32") = R.sqrt(lv5)
lv7: R.Tensor((1, 64, 112, 112), dtype="float32") = R.divide(lv2, lv6)
lv8: R.Tensor((1, 64, 1, 1), dtype="float32") = R.expand_dims(gamma, axis=[0, 2, 3])
lv9: R.Tensor((1, 64, 112, 112), dtype="float32") = R.multiply(lv7, lv8)
lv10: R.Tensor((1, 64, 1, 1), dtype="float32") = R.expand_dims(beta, axis=[0, 2, 3])
lv11: R.Tensor((1, 64, 112, 112), dtype="float32") = R.add(lv9, lv10)
lv12: R.Tensor((64,), dtype="float32") = R.multiply(
R.const(0.89999997615814209, "float32"), moving_mean
)
lv13: R.Tensor((64,), dtype="float32") = R.multiply(
R.const(0.10000000149011612, "float32"), lv
)
lv14: R.Tensor((64,), dtype="float32") = R.add(lv12, lv13)
lv15: R.Tensor((64,), dtype="float32") = R.multiply(
R.const(0.89999997615814209, "float32"), moving_var
)
lv16: R.Tensor((64,), dtype="float32") = R.multiply(
R.const(0.10000000149011612, "float32"), lv3
)
lv17: R.Tensor((64,), dtype="float32") = R.add(lv15, lv16)
bn: R.Tuple(
R.Tensor((1, 64, 112, 112), dtype="float32"),
R.Tensor((64,), dtype="float32"),
R.Tensor((64,), dtype="float32"),
) = (lv11, lv14, lv17)
gv0: R.Tensor((1, 64, 112, 112), dtype="float32") = bn[0]
gv1: R.Tensor((64,), dtype="float32") = bn[1]
gv2: R.Tensor((64,), dtype="float32") = bn[2]
# This portion is training-specific, computing the
# mean/variance of the current batch.
lv = R.mean(x, axis=[0, 2, 3], keepdims=False)
lv3 = R.variance(x, axis=[0, 2, 3], keepdims=False)

# This portion is identical to the batch_norm run during inference
lv1 = R.expand_dims(lv, axis=[0, 2, 3])
lv2 = R.subtract(x, lv1)
lv4 = R.expand_dims(lv3, axis=[0, 2, 3])
lv5 = R.add(lv4, R.const(9.9999997473787516e-06, "float32"))
lv6 = R.sqrt(lv5)
lv7 = R.divide(lv2, lv6)
lv8 = R.expand_dims(gamma, axis=[0, 2, 3])
lv9 = R.multiply(lv7, lv8)
lv10 = R.expand_dims(beta, axis=[0, 2, 3])
lv11 = R.add(lv9, lv10)
inner_tuple = (lv11, lv, lv3)
# This is the result that would be returned from a
# batch_norm at inference.

# However, at training we need to update the moving
# mean/variance, and to return those updated values.
inner_res = inner_tuple[0]
lv12 = R.multiply(R.const(0.89999997615814209, "float32"), moving_mean)
lv13 = R.multiply(R.const(0.10000000149011612, "float32"), lv)
lv14 = R.add(lv12, lv13)
lv15 = R.multiply(R.const(0.89999997615814209, "float32"), moving_var)
lv16 = R.multiply(R.const(0.10000000149011612, "float32"), lv3)
lv17 = R.add(lv15, lv16)
bn = (inner_res, lv14, lv17)
gv0 = bn[0]
gv1 = bn[1]
gv2 = bn[2]
R.output(gv0, gv1, gv2)
return (gv0, gv1, gv2)
