Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -211,27 +211,44 @@ util.func private @pingpong_large_bf16(%lhs_base: !bf16_in_ty, %rhs_base: !bf16_
// Epilogue
%lhs_vec_0 = vector.transfer_read %lhs_shared_expand[%m_outer_id, %ids#3, %c0, %inner_id], %cst {in_bounds = [true, true, true, true]} : !bf16_shared_exp, vector<8x1x1x4xbf16>
%rhs_vec_0 = vector.transfer_read %rhs_shared_expand[%n_outer_id, %ids#3, %c0, %inner_id], %cst {in_bounds = [true, true, true, true]} : !bf16_shared_exp, vector<4x1x1x4xbf16>
%lhs_vec_1 = vector.transfer_read %lhs_shared_expand[%m_outer_id, %ids#3, %c1, %inner_id], %cst {in_bounds = [true, true, true, true]} : !bf16_shared_exp, vector<8x1x1x4xbf16>
%rhs_vec_1 = vector.transfer_read %rhs_shared_expand[%n_outer_id, %ids#3, %c1, %inner_id], %cst {in_bounds = [true, true, true, true]} : !bf16_shared_exp, vector<4x1x1x4xbf16>

gpu.barrier
rocdl.s.setprio 1 { iree_gpu.swap_mfma = 1 }
rocdl.sched.barrier 0

%dot0 = iree_codegen.inner_tiled ins(%lhs_vec_0, %rhs_vec_0) outs(%3) {
indexing_maps = #contraction_accesses,
iterator_types = [#linalg.iterator_type<parallel>, #linalg.iterator_type<parallel>, #linalg.iterator_type<reduction>],
kind = #iree_gpu.mma_layout<MFMA_F32_16x16x16_BF16, col_major = true>
} : vector<8x1x1x4xbf16>, vector<4x1x1x4xbf16> into vector<8x4x1x4xf32>
%lhs_vec_1 = vector.transfer_read %lhs_shared_expand[%m_outer_id, %ids#3, %c1, %inner_id], %cst {in_bounds = [true, true, true, true]} : !bf16_shared_exp, vector<8x1x1x4xbf16>
%rhs_vec_1 = vector.transfer_read %rhs_shared_expand[%n_outer_id, %ids#3, %c1, %inner_id], %cst {in_bounds = [true, true, true, true]} : !bf16_shared_exp, vector<4x1x1x4xbf16>
%dot1 = iree_codegen.inner_tiled ins(%lhs_vec_1, %rhs_vec_1) outs(%dot0) {
indexing_maps = #contraction_accesses,
iterator_types = [#linalg.iterator_type<parallel>, #linalg.iterator_type<parallel>, #linalg.iterator_type<reduction>],
kind = #iree_gpu.mma_layout<MFMA_F32_16x16x16_BF16, col_major = true>
} : vector<8x1x1x4xbf16>, vector<4x1x1x4xbf16> into vector<8x4x1x4xf32>

rocdl.s.setprio 0
gpu.barrier
rocdl.sched.barrier 0

%lhs_vec_2 = vector.transfer_read %lhs_shared_expand[%m_outer_id, %ids#3, %c2, %inner_id], %cst {in_bounds = [true, true, true, true]} : !bf16_shared_exp, vector<8x1x1x4xbf16>
%rhs_vec_2 = vector.transfer_read %rhs_shared_expand[%n_outer_id, %ids#3, %c2, %inner_id], %cst {in_bounds = [true, true, true, true]} : !bf16_shared_exp, vector<4x1x1x4xbf16>
%lhs_vec_3 = vector.transfer_read %lhs_shared_expand[%m_outer_id, %ids#3, %c3, %inner_id], %cst {in_bounds = [true, true, true, true]} : !bf16_shared_exp, vector<8x1x1x4xbf16>
%rhs_vec_3 = vector.transfer_read %rhs_shared_expand[%n_outer_id, %ids#3, %c3, %inner_id], %cst {in_bounds = [true, true, true, true]} : !bf16_shared_exp, vector<4x1x1x4xbf16>

scf.if %cmp1 {
rocdl.s.setprio 1 { iree_gpu.swap_mfma = 1 }
gpu.barrier
rocdl.sched.barrier 0
}

%dot2 = iree_codegen.inner_tiled ins(%lhs_vec_2, %rhs_vec_2) outs(%dot1) {
indexing_maps = #contraction_accesses,
iterator_types = [#linalg.iterator_type<parallel>, #linalg.iterator_type<parallel>, #linalg.iterator_type<reduction>],
kind = #iree_gpu.mma_layout<MFMA_F32_16x16x16_BF16, col_major = true>
} : vector<8x1x1x4xbf16>, vector<4x1x1x4xbf16> into vector<8x4x1x4xf32>
%lhs_vec_3 = vector.transfer_read %lhs_shared_expand[%m_outer_id, %ids#3, %c3, %inner_id], %cst {in_bounds = [true, true, true, true]} : !bf16_shared_exp, vector<8x1x1x4xbf16>
%rhs_vec_3 = vector.transfer_read %rhs_shared_expand[%n_outer_id, %ids#3, %c3, %inner_id], %cst {in_bounds = [true, true, true, true]} : !bf16_shared_exp, vector<4x1x1x4xbf16>
%dot3 = iree_codegen.inner_tiled ins(%lhs_vec_3, %rhs_vec_3) outs(%dot2) {
indexing_maps = #contraction_accesses,
iterator_types = [#linalg.iterator_type<parallel>, #linalg.iterator_type<parallel>, #linalg.iterator_type<reduction>],
Expand Down Expand Up @@ -402,27 +419,30 @@ util.func private @pingpong_medium_bf16_expanded(%lhs_base: !mexp_in_ty_bf16, %r

scf.yield %dot2 : vector<4x4x1x4xf32>
}
scf.if %cmp1 {
rocdl.s.barrier
}

// Epilogue
%lhs_vec_0 = vector.transfer_read %lhs_shared_expand[%m_outer_id, %ids#3, %c0, %inner_id], %cst {in_bounds = [true, true, true, true]} : !mshared_exp_bf16, vector<4x1x2x4xbf16>
%rhs_vec_0 = vector.transfer_read %rhs_shared_expand[%n_outer_id, %ids#3, %c0, %inner_id], %cst {in_bounds = [true, true, true, true]} : !shared_exp_bf16, vector<4x1x2x4xbf16>
%lhs_vec_0_t = vector.transpose %lhs_vec_0, [0, 2, 1, 3] : vector<4x1x2x4xbf16> to vector<4x2x1x4xbf16>
%rhs_vec_0_t = vector.transpose %rhs_vec_0, [0, 2, 1, 3] : vector<4x1x2x4xbf16> to vector<4x2x1x4xbf16>

%lhs_vec_2 = vector.transfer_read %lhs_shared_expand[%m_outer_id, %ids#3, %c2, %inner_id], %cst {in_bounds = [true, true, true, true]} : !mshared_exp_bf16, vector<4x1x2x4xbf16>
%rhs_vec_2 = vector.transfer_read %rhs_shared_expand[%n_outer_id, %ids#3, %c2, %inner_id], %cst {in_bounds = [true, true, true, true]} : !shared_exp_bf16, vector<4x1x2x4xbf16>
%lhs_vec_2_t = vector.transpose %lhs_vec_2, [0, 2, 1, 3] : vector<4x1x2x4xbf16> to vector<4x2x1x4xbf16>
%rhs_vec_2_t = vector.transpose %rhs_vec_2, [0, 2, 1, 3] : vector<4x1x2x4xbf16> to vector<4x2x1x4xbf16>

rocdl.s.setprio 1 { iree_gpu.swap_mfma = 1 }
scf.if %cmp1 {
gpu.barrier
}
rocdl.sched.barrier 0

%dot0 = iree_codegen.inner_tiled ins(%lhs_vec_0_t, %rhs_vec_0_t) outs(%3) {
indexing_maps = #contraction_accesses,
iterator_types = [#linalg.iterator_type<parallel>, #linalg.iterator_type<parallel>, #linalg.iterator_type<reduction>],
kind = #iree_gpu.mma_layout<MFMA_F32_16x16x16_BF16, col_major = true>
} : vector<4x2x1x4xbf16>, vector<4x2x1x4xbf16> into vector<4x4x1x4xf32>

%lhs_vec_2 = vector.transfer_read %lhs_shared_expand[%m_outer_id, %ids#3, %c2, %inner_id], %cst {in_bounds = [true, true, true, true]} : !mshared_exp_bf16, vector<4x1x2x4xbf16>
%rhs_vec_2 = vector.transfer_read %rhs_shared_expand[%n_outer_id, %ids#3, %c2, %inner_id], %cst {in_bounds = [true, true, true, true]} : !shared_exp_bf16, vector<4x1x2x4xbf16>
%lhs_vec_2_t = vector.transpose %lhs_vec_2, [0, 2, 1, 3] : vector<4x1x2x4xbf16> to vector<4x2x1x4xbf16>
%rhs_vec_2_t = vector.transpose %rhs_vec_2, [0, 2, 1, 3] : vector<4x1x2x4xbf16> to vector<4x2x1x4xbf16>

%dot2 = iree_codegen.inner_tiled ins(%lhs_vec_2_t, %rhs_vec_2_t) outs(%dot0) {
indexing_maps = #contraction_accesses,
iterator_types = [#linalg.iterator_type<parallel>, #linalg.iterator_type<parallel>, #linalg.iterator_type<reduction>],
Expand Down Expand Up @@ -616,34 +636,47 @@ util.func private @pingpong_large_bf16_expanded(%lhs_base: !bf16_exp_in_ty, %rhs

scf.yield %dot3 : vector<8x4x1x4xf32>
}
scf.if %cmp1 {
rocdl.s.barrier
}

// Epilogue
%lhs_vec_0 = vector.transfer_read %lhs_shared_expand[%m_outer_id, %ids#3, %c0, %inner_id], %cst {in_bounds = [true, true, true, true]} : !bf16_shared_exp, vector<8x1x1x4xbf16>
%rhs_vec_0 = vector.transfer_read %rhs_shared_expand[%n_outer_id, %ids#3, %c0, %inner_id], %cst {in_bounds = [true, true, true, true]} : !bf16_shared_exp, vector<4x1x1x4xbf16>
%lhs_vec_1 = vector.transfer_read %lhs_shared_expand[%m_outer_id, %ids#3, %c1, %inner_id], %cst {in_bounds = [true, true, true, true]} : !bf16_shared_exp, vector<8x1x1x4xbf16>
%rhs_vec_1 = vector.transfer_read %rhs_shared_expand[%n_outer_id, %ids#3, %c1, %inner_id], %cst {in_bounds = [true, true, true, true]} : !bf16_shared_exp, vector<4x1x1x4xbf16>

rocdl.s.setprio 1 { iree_gpu.swap_mfma = 1 }
gpu.barrier
rocdl.sched.barrier 0

%dot0 = iree_codegen.inner_tiled ins(%lhs_vec_0, %rhs_vec_0) outs(%3) {
indexing_maps = #contraction_accesses,
iterator_types = [#linalg.iterator_type<parallel>, #linalg.iterator_type<parallel>, #linalg.iterator_type<reduction>],
kind = #iree_gpu.mma_layout<MFMA_F32_16x16x16_BF16, col_major = true>
} : vector<8x1x1x4xbf16>, vector<4x1x1x4xbf16> into vector<8x4x1x4xf32>
%lhs_vec_1 = vector.transfer_read %lhs_shared_expand[%m_outer_id, %ids#3, %c1, %inner_id], %cst {in_bounds = [true, true, true, true]} : !bf16_shared_exp, vector<8x1x1x4xbf16>
%rhs_vec_1 = vector.transfer_read %rhs_shared_expand[%n_outer_id, %ids#3, %c1, %inner_id], %cst {in_bounds = [true, true, true, true]} : !bf16_shared_exp, vector<4x1x1x4xbf16>
%dot1 = iree_codegen.inner_tiled ins(%lhs_vec_1, %rhs_vec_1) outs(%dot0) {
indexing_maps = #contraction_accesses,
iterator_types = [#linalg.iterator_type<parallel>, #linalg.iterator_type<parallel>, #linalg.iterator_type<reduction>],
kind = #iree_gpu.mma_layout<MFMA_F32_16x16x16_BF16, col_major = true>
} : vector<8x1x1x4xbf16>, vector<4x1x1x4xbf16> into vector<8x4x1x4xf32>

rocdl.s.setprio 0
gpu.barrier
rocdl.sched.barrier 0

%lhs_vec_2 = vector.transfer_read %lhs_shared_expand[%m_outer_id, %ids#3, %c2, %inner_id], %cst {in_bounds = [true, true, true, true]} : !bf16_shared_exp, vector<8x1x1x4xbf16>
%rhs_vec_2 = vector.transfer_read %rhs_shared_expand[%n_outer_id, %ids#3, %c2, %inner_id], %cst {in_bounds = [true, true, true, true]} : !bf16_shared_exp, vector<4x1x1x4xbf16>
%lhs_vec_3 = vector.transfer_read %lhs_shared_expand[%m_outer_id, %ids#3, %c3, %inner_id], %cst {in_bounds = [true, true, true, true]} : !bf16_shared_exp, vector<8x1x1x4xbf16>
%rhs_vec_3 = vector.transfer_read %rhs_shared_expand[%n_outer_id, %ids#3, %c3, %inner_id], %cst {in_bounds = [true, true, true, true]} : !bf16_shared_exp, vector<4x1x1x4xbf16>

scf.if %cmp1 {
rocdl.s.setprio 1 { iree_gpu.swap_mfma = 1 }
gpu.barrier
rocdl.sched.barrier 0
}

%dot2 = iree_codegen.inner_tiled ins(%lhs_vec_2, %rhs_vec_2) outs(%dot1) {
indexing_maps = #contraction_accesses,
iterator_types = [#linalg.iterator_type<parallel>, #linalg.iterator_type<parallel>, #linalg.iterator_type<reduction>],
kind = #iree_gpu.mma_layout<MFMA_F32_16x16x16_BF16, col_major = true>
} : vector<8x1x1x4xbf16>, vector<4x1x1x4xbf16> into vector<8x4x1x4xf32>
%lhs_vec_3 = vector.transfer_read %lhs_shared_expand[%m_outer_id, %ids#3, %c3, %inner_id], %cst {in_bounds = [true, true, true, true]} : !bf16_shared_exp, vector<8x1x1x4xbf16>
%rhs_vec_3 = vector.transfer_read %rhs_shared_expand[%n_outer_id, %ids#3, %c3, %inner_id], %cst {in_bounds = [true, true, true, true]} : !bf16_shared_exp, vector<4x1x1x4xbf16>
%dot3 = iree_codegen.inner_tiled ins(%lhs_vec_3, %rhs_vec_3) outs(%dot2) {
indexing_maps = #contraction_accesses,
iterator_types = [#linalg.iterator_type<parallel>, #linalg.iterator_type<parallel>, #linalg.iterator_type<reduction>],
Expand Down
Loading
Loading