Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 5 additions & 27 deletions src/fuse_attention.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
*
*/
#include <migraphx/fuse_attention.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/matcher.hpp>
#include <migraphx/match/softmax.hpp>
Expand Down Expand Up @@ -82,39 +83,16 @@ struct find_attention
}

std::vector<instruction_ref>
get_attn_instructions(module& m, instruction_ref start, instruction_ref end) const
get_attn_instructions(module& m, instruction_ref gemm1, instruction_ref gemm2) const
{
std::queue<instruction_ref> inputs;
std::unordered_set<instruction_ref> inss;
inputs.push(end);
auto attn_inss = find_instructions_between(gemm1, gemm2, &m);

static const std::unordered_set<std::string> valid_attn_ops = {
"reshape", "reduce_sum", "reduce_max", "broadcast", "multibroadcast", "@literal"};

auto is_valid_attn_op = [&](auto i) {
return i->get_operator().attributes().get("pointwise", false) or
contains(valid_attn_ops, i->get_operator().name()) or i == start or i == end;
};

while(not inputs.empty())
{
auto current_inp = inputs.front();
inputs.pop();

if(is_valid_attn_op(current_inp) and inss.insert(current_inp).second and
current_inp != start)
{
for(auto i : current_inp->inputs())
{
inputs.push(i);
}
}
}
std::vector<instruction_ref> sorted_inss(inss.begin(), inss.end());
std::vector<instruction_ref> sorted_inss(attn_inss.begin(), attn_inss.end());
std::sort(
sorted_inss.begin(), sorted_inss.end(), [&](instruction_ref x, instruction_ref y) {
return std::distance(m.begin(), x) < std::distance(m.begin(), y);
});

return sorted_inss;
}

Expand Down
5 changes: 5 additions & 0 deletions src/include/migraphx/instruction.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
#include <migraphx/erase.hpp>
#include <migraphx/config.hpp>
#include <string>
#include <unordered_set>
#include <utility>

namespace migraphx {
Expand All @@ -49,6 +50,9 @@ MIGRAPHX_EXPORT bool reaches(instruction_ref start, instruction_ref end);

MIGRAPHX_EXPORT bool reaches(instruction_ref start, instruction_ref end, const_module_ref m);

MIGRAPHX_EXPORT std::unordered_set<instruction_ref>
find_instructions_between(instruction_ref start, instruction_ref end, const_module_ref m);

struct MIGRAPHX_EXPORT instruction
{
instruction() {}
Expand Down Expand Up @@ -185,6 +189,7 @@ struct MIGRAPHX_EXPORT instruction
bool normalized = false;
std::size_t target_id = 0;
};

} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx

Expand Down
29 changes: 27 additions & 2 deletions src/instruction.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -578,9 +578,9 @@ bool reaches(instruction_ref start, instruction_ref end, const_module_ref m)
{
if(start == end)
return true;
if(not m->has_instruction(start) or not m->has_instruction(end))
if(not m->has_instruction(start) or not m->has_instruction(end) or
std::distance(m->begin(), start) > std::distance(m->begin(), end))
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should not be added here, as this traverses the entire program which can be really slow. The algorithm for find_instructions_between needs to be updated to check for this since it can check a smaller subset of the program.

return false;
assert(std::distance(m->begin(), start) < std::distance(m->begin(), end));
std::size_t initial_distance = std::distance(start, end);
std::unordered_set<instruction_ref> visited;
return fix<bool>([&](auto self, auto ins) -> bool {
Expand All @@ -596,5 +596,30 @@ bool reaches(instruction_ref start, instruction_ref end, const_module_ref m)
})(end);
}

// Return set of all instructions that are connected to both start and end nodes (inclusive)
std::unordered_set<instruction_ref>
find_instructions_between(instruction_ref start, instruction_ref end, const_module_ref m)
{
std::queue<instruction_ref> inputs;
std::unordered_set<instruction_ref> inss;
inputs.push(end);

while(not inputs.empty())
{
auto current_inp = inputs.front();
inputs.pop();

if(reaches(start, current_inp, m) and inss.insert(current_inp).second and
current_inp != start)
{
for(auto i : current_inp->inputs())
{
inputs.push(i);
}
}
}
return inss;
}

} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
141 changes: 133 additions & 8 deletions test/fuse_attention.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ TEST_CASE(gemm_pw_softmax_gemm)
auto rmax = mm->add_instruction(migraphx::make_op("reduce_max", {{"axes", {3}}}), where);
rmax = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", s1.lens()}}),
rmax);
auto sub = mm->add_instruction(migraphx::make_op("sub"), gemm1, rmax);
auto sub = mm->add_instruction(migraphx::make_op("sub"), where, rmax);
auto exp = mm->add_instruction(migraphx::make_op("exp"), sub);
auto rsum = mm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", {3}}}), exp);
rsum = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", s1.lens()}}),
Expand All @@ -154,36 +154,161 @@ TEST_CASE(gemm_pw_softmax_gemm)
auto select = mm->add_parameter("4", s2);
std::vector<float> eights(s1_elements, 0.125);
std::vector<float> tens(s1_elements, 10);
auto eight = mm->add_literal(migraphx::literal{s1, eights});
auto ten = mm->add_literal(migraphx::literal{s1, tens});
b = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), b);
b1 = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}),
b1);

auto group = add_group(
p2, "attn0", "attention", {a, b, select, b1}, [=](auto* gm, const auto& inputs) {
auto ten = gm->add_literal(migraphx::literal{s1, tens});
auto eight = gm->add_literal(migraphx::literal{s1, eights});
p2,
"attn0",
"attention",
{a, b, eight, select, ten, b1},
[=](auto* gm, const auto& inputs) {
auto gemm1 = gm->add_instruction(migraphx::make_op("dot"), inputs[0], inputs[1]);
auto mul = gm->add_instruction(migraphx::make_op("mul"), gemm1, eight);
auto where = gm->add_instruction(migraphx::make_op("where"), inputs[2], mul, ten);
auto mul = gm->add_instruction(migraphx::make_op("mul"), gemm1, inputs[2]);
auto where =
gm->add_instruction(migraphx::make_op("where"), inputs[3], mul, inputs[4]);
auto rmax =
gm->add_instruction(migraphx::make_op("reduce_max", {{"axes", {3}}}), where);
rmax = gm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", s1.lens()}}), rmax);
auto sub = gm->add_instruction(migraphx::make_op("sub"), gemm1, rmax);
auto sub = gm->add_instruction(migraphx::make_op("sub"), where, rmax);
auto exp = gm->add_instruction(migraphx::make_op("exp"), sub);
auto rsum =
gm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", {3}}}), exp);
rsum = gm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", s1.lens()}}), rsum);
auto div = gm->add_instruction(migraphx::make_op("div"), exp, rsum);

return gm->add_instruction(migraphx::make_op("dot"), div, inputs[3]);
return gm->add_instruction(migraphx::make_op("dot"), div, inputs[5]);
});
mm->add_return({group});
}
EXPECT(p1 == p2);
}

TEST_CASE(gemm_multi_use_pw_softmax_gemm)
{
migraphx::shape s1{migraphx::shape::float_type, {2, 4, 16, 8}};
migraphx::shape s2{migraphx::shape::float_type, {2, 4, 8, 16}};
migraphx::shape s3{migraphx::shape::float_type, {2, 4, 16, 16}};
migraphx::shape s_mask{migraphx::shape::int64_type, {2, 16}};
auto s1_elements = s1.elements();

migraphx::program p1;
{
auto* mm = p1.get_main_module();
auto mask = mm->add_parameter("mask", s_mask);
auto x = mm->add_parameter("x", s1);

std::vector<float> c1_vec(s1_elements, 0.125);
std::vector<float> c2_vec(s1_elements, 10);
auto c1 = mm->add_literal(migraphx::literal(s2, c1_vec));
auto c2 = mm->add_literal(migraphx::literal(s1, c2_vec));
auto ten = mm->add_literal(migraphx::literal(10.0f));
auto zero = mm->add_literal(migraphx::literal(0.0f));
auto zero_int = mm->add_literal(migraphx::literal(0));
auto scale = mm->add_literal(migraphx::literal(0.25f));

mask = mm->add_instruction(
migraphx::make_op("convert", {{"target_type", migraphx::shape::int32_type}}), mask);
mask = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1, 2}}}), mask);
zero_int = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", mask->get_shape().lens()}}),
zero_int);
auto eq = mm->add_instruction(migraphx::make_op("equal"), mask, zero_int);
eq = mm->add_instruction(
migraphx::make_op("convert", {{"target_type", migraphx::shape::bool_type}}), eq);

auto gemm1 = mm->add_instruction(migraphx::make_op("dot"), x, c1);

eq =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", s3.lens()}}), eq);
ten = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", s3.lens()}}),
ten);
zero = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", s3.lens()}}),
zero);
auto where = mm->add_instruction(migraphx::make_op("where"), eq, ten, zero);
scale = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", s3.lens()}}),
scale);
auto add = mm->add_instruction(migraphx::make_op("add"), gemm1, where);
auto mul = mm->add_instruction(migraphx::make_op("mul"), add, scale);

auto rmax = mm->add_instruction(migraphx::make_op("reduce_max", {{"axes", {3}}}), mul);
rmax = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", mul->get_shape().lens()}}), rmax);
auto sub = mm->add_instruction(migraphx::make_op("sub"), mul, rmax);
auto exp = mm->add_instruction(migraphx::make_op("exp"), sub);
auto rsum = mm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", {3}}}), exp);
rsum = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", mul->get_shape().lens()}}), rsum);
auto div = mm->add_instruction(migraphx::make_op("div"), exp, rsum);
auto gemm2 = mm->add_instruction(migraphx::make_op("dot"), div, c2);
mm->add_return({gemm2, zero, eq, scale});
}
run_pass(p1);

migraphx::program p2;
{
auto* mm = p2.get_main_module();
auto mask = mm->add_parameter("mask", s_mask);
auto x = mm->add_parameter("x", s1);

std::vector<float> c1_vec(s1_elements, 0.125);
std::vector<float> c2_vec(s1_elements, 10);
auto c1 = mm->add_literal(migraphx::literal(s2, c1_vec));
auto c2 = mm->add_literal(migraphx::literal(s1, c2_vec));
auto ten = mm->add_literal(migraphx::literal(10.0f));
auto zero = mm->add_literal(migraphx::literal(0.0f));
auto zero_int = mm->add_literal(migraphx::literal(0));
auto scale = mm->add_literal(migraphx::literal(0.25f));

mask = mm->add_instruction(
migraphx::make_op("convert", {{"target_type", migraphx::shape::int32_type}}), mask);
mask = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1, 2}}}), mask);
zero_int = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", mask->get_shape().lens()}}),
zero_int);
auto eq = mm->add_instruction(migraphx::make_op("equal"), mask, zero_int);
eq = mm->add_instruction(
migraphx::make_op("convert", {{"target_type", migraphx::shape::bool_type}}), eq);

eq =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", s3.lens()}}), eq);
ten = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", s3.lens()}}),
ten);
zero = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", s3.lens()}}),
zero);
auto where = mm->add_instruction(migraphx::make_op("where"), eq, ten, zero);
scale = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", s3.lens()}}),
scale);

auto group = add_group(
p2, "attn0", "attention", {x, c1, where, scale, c2}, [=](auto* gm, const auto& inputs) {
auto gemm1 = gm->add_instruction(migraphx::make_op("dot"), inputs[0], inputs[1]);
auto add = gm->add_instruction(migraphx::make_op("add"), gemm1, inputs[2]);
auto mul = gm->add_instruction(migraphx::make_op("mul"), add, inputs[3]);
auto rmax =
gm->add_instruction(migraphx::make_op("reduce_max", {{"axes", {3}}}), mul);
rmax = gm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", s3.lens()}}), rmax);
auto sub = gm->add_instruction(migraphx::make_op("sub"), mul, rmax);
auto exp = gm->add_instruction(migraphx::make_op("exp"), sub);
auto rsum =
gm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", {3}}}), exp);
rsum = gm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", s3.lens()}}), rsum);
auto div = gm->add_instruction(migraphx::make_op("div"), exp, rsum);

return gm->add_instruction(migraphx::make_op("dot"), div, inputs[4]);
});
mm->add_return({group, zero, eq, scale});
}
EXPECT(p1 == p2);
}

int main(int argc, const char* argv[])
{
test::run(argc, argv);
Expand Down
Loading