Skip to content

Commit

Permalink
refactor adapter code
Browse files Browse the repository at this point in the history
Signed-off-by: Yuan Yao <[email protected]>
  • Loading branch information
yuanyao-nv committed Jan 5, 2024
1 parent 3932d9b commit 3019560
Showing 1 changed file with 61 additions and 40 deletions.
101 changes: 61 additions & 40 deletions onnx/version_converter/adapters/group_normalization_20_21.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,46 +19,64 @@ class GroupNormalization_20_21 final : public Adapter {
public:
explicit GroupNormalization_20_21() : Adapter("GroupNormalization", OpSetID(20), OpSetID(21)) {}

void transform_input(std::shared_ptr<Graph> graph, Node* node, int64_t input_id) const {
// Perform following sequence of ops on input, effect is similar to numpy.repeat()
void transform_input(
std::shared_ptr<Graph> graph,
Node* node,
int64_t input_id,
Value* reshape0_shape,
Value* reshape1_shape,
Value* expand_shape) const {
Node* reshape0 = graph->create(kReshape);
reshape0->addInput(node->inputs()[input_id]);
reshape0->addInput(reshape0_shape);
reshape0->insertBefore(node);

// -- Shape<start=1,end=2>() -- Div(Shape_out, num_groups)
// / |
// -- Reshape(,[-1,1]) ------------- Expand(Reshape_out,[1,Div_out]) -- Reshape(Expand_out, [-1])
Node* expand = graph->create(kExpand);
expand->addInput(reshape0->output());
expand->addInput(expand_shape);
expand->insertBefore(node);

Node* reshape1 = graph->create(kReshape);
reshape1->addInput(expand->output());
reshape1->addInput(reshape1_shape);
reshape1->insertBefore(node);

node->replaceInput(input_id, reshape1->output());
}

void adapt_group_normalization_20_21(std::shared_ptr<Graph> graph, Node* node) const {
// Perform following sequence of ops on scale/bias, effect is similar to numpy.repeat()
//
// Shape<start=1,end=2>(input0) -- Div(Shape_out (C), num_groups)
// |
// Reshape(input1/2, [-1, 1]) ----------- Expand(Reshape_out, [1, Div_out]) -- Reshape(Expand_out, [-1])
//
// The helper function transform_input() implements the bottom row of the diagram

// Get number of channels: C
Symbol kShape("Shape");
Node* shape = graph->create(kShape);
shape->i_(kstart, 1);
shape->i_(kend, 2);
shape->addInput(node->inputs()[input_id]);
shape->insertBefore(node);
Node* C = graph->create(kShape);
C->i_(kstart, 1);
C->i_(kend, 2);
C->addInput(node->inputs()[0]);
C->insertBefore(node);

// Get number of channels per group
Tensor tensor_num_groups;
tensor_num_groups.elem_type() = TensorProto_DataType_INT64;
int64_t num_groups = node->i(knum_groups);
tensor_num_groups.sizes() = {1};
tensor_num_groups.int64s() = {num_groups};
Node* constant_num_grp = graph->create(kConstant);
constant_num_grp->t_(kvalue, tensor_num_groups);
constant_num_grp->insertBefore(node);
Node* constant_num_groups = graph->create(kConstant);
constant_num_groups->t_(kvalue, tensor_num_groups);
constant_num_groups->insertBefore(node);

Node* div = graph->create(kDiv);
div->addInput(shape->output());
div->addInput(constant_num_grp->output());
div->addInput(C->output());
div->addInput(constant_num_groups->output());
div->insertBefore(node);

Tensor tensor_reshape0_shape;
tensor_reshape0_shape.elem_type() = TensorProto_DataType_INT64;
tensor_reshape0_shape.sizes() = {2};
tensor_reshape0_shape.int64s() = {-1, 1};
Node* constant_reshape0_shape = graph->create(kConstant);
constant_reshape0_shape->t_(kvalue, tensor_reshape0_shape);
constant_reshape0_shape->insertBefore(node);
Node* reshape0 = graph->create(kReshape);
reshape0->addInput(node->inputs()[input_id]);
reshape0->addInput(constant_reshape0_shape->output());
reshape0->insertBefore(node);

// Get Expand shape: [1, Div_out]
Tensor tensor_one;
tensor_one.elem_type() = TensorProto_DataType_INT64;
tensor_one.sizes() = {1};
Expand All @@ -71,29 +89,32 @@ class GroupNormalization_20_21 final : public Adapter {
concat->addInput(constant_one->output());
concat->addInput(div->output());
concat->insertBefore(node);
Node* expand = graph->create(kExpand);
expand->addInput(reshape0->output());
expand->addInput(concat->output());
expand->insertBefore(node);

// Get shape of first reshape: [-1, 1]
Tensor tensor_reshape0_shape;
tensor_reshape0_shape.elem_type() = TensorProto_DataType_INT64;
tensor_reshape0_shape.sizes() = {2};
tensor_reshape0_shape.int64s() = {-1, 1};
Node* constant_reshape0_shape = graph->create(kConstant);
constant_reshape0_shape->t_(kvalue, tensor_reshape0_shape);
constant_reshape0_shape->insertBefore(node);

// Get shape of last reshape: [-1]
Tensor tensor_reshape1_shape;
tensor_reshape1_shape.elem_type() = TensorProto_DataType_INT64;
tensor_reshape1_shape.sizes() = {1};
tensor_reshape1_shape.int64s() = {-1};
Node* constant_reshape1_shape = graph->create(kConstant);
constant_reshape1_shape->t_(kvalue, tensor_reshape1_shape);
constant_reshape1_shape->insertBefore(node);
Node* reshape1 = graph->create(kReshape);
reshape1->addInput(expand->output());
reshape1->addInput(constant_reshape1_shape->output());
reshape1->insertBefore(node);

node->replaceInput(input_id, reshape1->output());
}
// transform scale and bias
transform_input(
graph, node, 1, constant_reshape0_shape->output(), constant_reshape1_shape->output(), concat->output());
transform_input(
graph, node, 2, constant_reshape0_shape->output(), constant_reshape1_shape->output(), concat->output());

void adapt_group_normalization_20_21(std::shared_ptr<Graph> graph, Node* node) const {
transform_input(graph, node, 1);
transform_input(graph, node, 2);
// Set stash_type
node->i_(kstash_type, node->inputs()[0]->elemType());
}

Expand Down

0 comments on commit 3019560

Please sign in to comment.