From 6811c47c10fa1ec3cc5fded53fbb120260b7b2bb Mon Sep 17 00:00:00 2001
From: Ivan Butygin
Date: Mon, 7 Oct 2024 15:49:25 +0200
Subject: [PATCH 01/14] igemm layout

Signed-off-by: Ivan Butygin
---
 tests/kernel/wave/wave_e2e_test.py | 14 ++++++++++++++
 1 file changed, 14 insertions(+)

diff --git a/tests/kernel/wave/wave_e2e_test.py b/tests/kernel/wave/wave_e2e_test.py
index a3effd46..40f2d536 100644
--- a/tests/kernel/wave/wave_e2e_test.py
+++ b/tests/kernel/wave/wave_e2e_test.py
@@ -749,6 +749,20 @@ def test_igemm_conv(n, h, w, c, hf, wf, nf, stride, mem_space, layout, request):
     K = HF * WF * C
     M = SZ_OUT * N
 
+    if layout == "nchw_fchw":
+        x_type = tkl.Memory[N, C, H, W, ADDRESS_SPACE, tkl.f16]
+        we_type = tkl.Memory[NF, C, HF, WF, ADDRESS_SPACE, tkl.f16]
+        out_type = tkl.Memory[N, NF, H_OUT, W_OUT, GLOBAL_ADDRESS_SPACE, tkl.f32]
+    elif layout == "nhwc_hwcf":
+        x_type = tkl.Memory[N, H, W, C, ADDRESS_SPACE, tkl.f16]
+        we_type = tkl.Memory[HF, WF, C, NF, ADDRESS_SPACE, tkl.f16]
+        out_type = tkl.Memory[N, H_OUT, W_OUT, HF, GLOBAL_ADDRESS_SPACE, tkl.f32]
+        x = torch.permute(x, (0, 2, 3, 1)).clone(torch.contiguous_format)
+        we = torch.permute(we, (2, 3, 1, 0)).clone(torch.contiguous_format)
+        out_ref = torch.permute(out_ref, (0, 2, 3, 1)).clone(torch.contiguous_format)
+    else:
+        raise ValueError(f"Invalid layout: {layout}")
+
     i = tkw.IndexMapping.iterator(0)
     j = tkw.IndexMapping.iterator(1)
 

From ee211ca1c24216a6833a606021cc70d8284927f1 Mon Sep 17 00:00:00 2001
From: Ivan Butygin
Date: Mon, 7 Oct 2024 15:51:45 +0200
Subject: [PATCH 02/14] fix

Signed-off-by: Ivan Butygin
---
 tests/kernel/wave/wave_e2e_test.py | 14 --------------
 1 file changed, 14 deletions(-)

diff --git a/tests/kernel/wave/wave_e2e_test.py b/tests/kernel/wave/wave_e2e_test.py
index 40f2d536..a3effd46 100644
--- a/tests/kernel/wave/wave_e2e_test.py
+++ b/tests/kernel/wave/wave_e2e_test.py
@@ -749,20 +749,6 @@ def test_igemm_conv(n, h, w, c, hf, wf, nf, stride, mem_space, layout, request):
     K = HF * WF * C
     M = SZ_OUT * N
 
-    if layout == "nchw_fchw":
-        x_type = tkl.Memory[N, C, H, W, ADDRESS_SPACE, tkl.f16]
-        we_type = tkl.Memory[NF, C, HF, WF, ADDRESS_SPACE, tkl.f16]
-        out_type = tkl.Memory[N, NF, H_OUT, W_OUT, GLOBAL_ADDRESS_SPACE, tkl.f32]
-    elif layout == "nhwc_hwcf":
-        x_type = tkl.Memory[N, H, W, C, ADDRESS_SPACE, tkl.f16]
-        we_type = tkl.Memory[HF, WF, C, NF, ADDRESS_SPACE, tkl.f16]
-        out_type = tkl.Memory[N, H_OUT, W_OUT, HF, GLOBAL_ADDRESS_SPACE, tkl.f32]
-        x = torch.permute(x, (0, 2, 3, 1)).clone(torch.contiguous_format)
-        we = torch.permute(we, (2, 3, 1, 0)).clone(torch.contiguous_format)
-        out_ref = torch.permute(out_ref, (0, 2, 3, 1)).clone(torch.contiguous_format)
-    else:
-        raise ValueError(f"Invalid layout: {layout}")
-
     i = tkw.IndexMapping.iterator(0)
     j = tkw.IndexMapping.iterator(1)
 

From 0696b6ef9abd43655faac759843dab43a736ce3e Mon Sep 17 00:00:00 2001
From: Ivan Butygin
Date: Mon, 7 Oct 2024 21:42:06 +0200
Subject: [PATCH 03/14] unmasked

Signed-off-by: Ivan Butygin
---
 iree/turbine/kernel/wave/codegen.py | 10 ++++++++--
 1 file changed, 8 insertions(+), 2 deletions(-)

diff --git a/iree/turbine/kernel/wave/codegen.py b/iree/turbine/kernel/wave/codegen.py
index bc6e54ed..65eda072 100644
--- a/iree/turbine/kernel/wave/codegen.py
+++ b/iree/turbine/kernel/wave/codegen.py
@@ -667,7 +667,10 @@ def handle_read(emitter: WaveEmitter, node: fx.Node):
         mask = _build_mask(
             emitter, index, cast_py_literal(emitter, elements_per_thread)
         )
-        if mask is None:
+        if (
+            mask is None
+            or get_custom(node).memory_type.address_space == SHARED_ADDRESS_SPACE
+        ):
             result = vector_d.load(vector_type, kb_src, start_indices)
         else:
             zero = get_constant_attr(0, element_type)
@@ -730,7 +733,10 @@ def handle_write(emitter: WaveEmitter, node: fx.Node):
         mask = _build_mask(
             emitter, index, cast_py_literal(emitter, elements_per_thread)
         )
-        if mask is None:
+        if (
+            mask is None
+            or get_custom(node).memory_type.address_space == SHARED_ADDRESS_SPACE
+        ):
             vector_d.store(insert_vector, kb_dest, start_indices)
         else:
             vector_d.maskedstore(kb_dest, start_indices, mask, insert_vector)

From 4df4023535c0b8f6690726fb9fb250df2286fcad Mon Sep 17 00:00:00 2001
From: Ivan Butygin
Date: Mon, 7 Oct 2024 21:54:20 +0200
Subject: [PATCH 04/14] hack offset

Signed-off-by: Ivan Butygin
---
 iree/turbine/kernel/wave/codegen.py | 9 +++++----
 1 file changed, 5 insertions(+), 4 deletions(-)

diff --git a/iree/turbine/kernel/wave/codegen.py b/iree/turbine/kernel/wave/codegen.py
index 65eda072..ce13d13e 100644
--- a/iree/turbine/kernel/wave/codegen.py
+++ b/iree/turbine/kernel/wave/codegen.py
@@ -594,8 +594,8 @@ def _construct_gather_scatter_indices(
     for i in range(elements_per_thread):
         # Update most-minor dim, i.e. in case of identity mapping it will
         # be equivalent to just vector.load
-        subs = [(sym, idx) for sym, idx in zip(iters.keys(), start_indices_orig)]
-        subs[-1] = (subs[-1][0], start_indices_orig[-1] + i)
+        subs = [(sym, 0) for sym in iters.keys()]
+        subs[-1] = (subs[-1][0], i)
         indices = [i.subs(subs) for i in index_mapping]
 
         # First, we build indices as if resulting gather/scatter `start_indices`
@@ -605,9 +605,10 @@ def _construct_gather_scatter_indices(
         # simple cases like transpose, the resulting expression should fold into
         # simple constant while more complex expressions may requires actual
         # arith ops on dynamic values.
-        offset = _compute_offset(indices, strides) - start_indices_offset
-        offset = subs_idxc(offset)
+        offset = _compute_offset(indices, strides)
+        offset = sympy.simplify(subs_idxc(offset))
 
+        print(offset)
         if offset.is_number:
             # If resulted offset sympy expr is convertible to int constant it
             # will be directly encoded into `arith.constant`.

From 0393c897648157078715b05049503c516e6a7c3d Mon Sep 17 00:00:00 2001
From: Ivan Butygin
Date: Mon, 7 Oct 2024 21:59:00 +0200
Subject: [PATCH 05/14] Revert "hack offset"

This reverts commit 4b8cd0c2aa9b920be4d93d14520ad6642d0e8ee6.

Signed-off-by: Ivan Butygin
---
 iree/turbine/kernel/wave/codegen.py | 9 ++++-----
 1 file changed, 4 insertions(+), 5 deletions(-)

diff --git a/iree/turbine/kernel/wave/codegen.py b/iree/turbine/kernel/wave/codegen.py
index ce13d13e..65eda072 100644
--- a/iree/turbine/kernel/wave/codegen.py
+++ b/iree/turbine/kernel/wave/codegen.py
@@ -594,8 +594,8 @@ def _construct_gather_scatter_indices(
     for i in range(elements_per_thread):
         # Update most-minor dim, i.e. in case of identity mapping it will
         # be equivalent to just vector.load
-        subs = [(sym, 0) for sym in iters.keys()]
-        subs[-1] = (subs[-1][0], i)
+        subs = [(sym, idx) for sym, idx in zip(iters.keys(), start_indices_orig)]
+        subs[-1] = (subs[-1][0], start_indices_orig[-1] + i)
         indices = [i.subs(subs) for i in index_mapping]
 
         # First, we build indices as if resulting gather/scatter `start_indices`
@@ -605,10 +605,9 @@ def _construct_gather_scatter_indices(
         # simple cases like transpose, the resulting expression should fold into
         # simple constant while more complex expressions may requires actual
         # arith ops on dynamic values.
-        offset = _compute_offset(indices, strides)
-        offset = sympy.simplify(subs_idxc(offset))
+        offset = _compute_offset(indices, strides) - start_indices_offset
+        offset = subs_idxc(offset)
 
-        print(offset)
         if offset.is_number:
             # If resulted offset sympy expr is convertible to int constant it
             # will be directly encoded into `arith.constant`.

From 6dff35d0e98f5c351264960c820b093f0ab60982 Mon Sep 17 00:00:00 2001
From: Ivan Butygin
Date: Mon, 7 Oct 2024 22:43:45 +0200
Subject: [PATCH 06/14] gather indices

Signed-off-by: Ivan Butygin
---
 iree/turbine/kernel/wave/codegen.py | 33 +++++++++++++++++------------
 1 file changed, 20 insertions(+), 13 deletions(-)

diff --git a/iree/turbine/kernel/wave/codegen.py b/iree/turbine/kernel/wave/codegen.py
index 65eda072..786ba5d7 100644
--- a/iree/turbine/kernel/wave/codegen.py
+++ b/iree/turbine/kernel/wave/codegen.py
@@ -588,8 +588,8 @@ def _construct_gather_scatter_indices(
     start_indices = _get_start_indices(result_index)
    start_indices_orig = _get_start_indices(index)
 
-    dynamic_offsets = []
+    need_dynamic_offsets = False
     start_indices_offset = _compute_offset(start_indices, strides)
     for i in range(elements_per_thread):
         # Update most-minor dim, i.e. in case of identity mapping it will
         # be equivalent to just vector.load
@@ -615,22 +615,29 @@ def _construct_gather_scatter_indices(
             # arith ops and then `vector.insertelement` them into offsets vec.
             offset = int(offset)
         else:
-            dyn_offset = gen_sympy_index(add_emitter_subs(emitter), offset)
-            dynamic_offsets.append((i, dyn_offset))
-            offset = 0
+            need_dynamic_offsets = True
+            break
 
         offsets.append(IntegerAttr.get(IndexType.get(), offset))
 
-    start_indices = _build_start_indices(emitter, result_index)
     offsets_vec_type = VectorType.get([elements_per_thread], IndexType.get())
-
-    offsets_vec = arith_d.ConstantOp(
-        offsets_vec_type, DenseElementsAttr.get(offsets, offsets_vec_type)
-    )
-
-    for i, off in dynamic_offsets:
-        pos = arith_d.ConstantOp(IndexType.get(), i)
-        offsets_vec = vector_d.insertelement(off, offsets_vec, position=pos)
+    if need_dynamic_offsets:
+        result_index = {key: 0 for key in symbolc_shape}
+        start_indices = _build_start_indices(emitter, result_index)
+        subs = [(sym, idx) for sym, idx in zip(iters.keys(), start_indices_orig)]
+        subs[-1] = (
+            subs[-1][0],
+            start_indices_orig[-1] + idxc.iota(elements_per_thread),
+        )
+        indices = [i.subs(subs) for i in index_mapping]
+        offsets_vec = gen_sympy_index(
+            add_emitter_subs(emitter), _compute_offset(indices, strides)
+        )
+    else:
+        start_indices = _build_start_indices(emitter, result_index)
+        offsets_vec = arith_d.ConstantOp(
+            offsets_vec_type, DenseElementsAttr.get(offsets, offsets_vec_type)
+        )
 
     mask = _build_mask(emitter, index, elements_per_thread)
     if mask is None:

From 374b06dbf6253b0934ed0c0d9587d33844a57b96 Mon Sep 17 00:00:00 2001
From: Ivan Butygin
Date: Mon, 7 Oct 2024 23:14:00 +0200
Subject: [PATCH 07/14] tiling

Signed-off-by: Ivan Butygin
---
 tests/kernel/wave/wave_e2e_test.py | 11 ++++++-----
 1 file changed, 6 insertions(+), 5 deletions(-)

diff --git a/tests/kernel/wave/wave_e2e_test.py b/tests/kernel/wave/wave_e2e_test.py
index a3effd46..aab052db 100644
--- a/tests/kernel/wave/wave_e2e_test.py
+++ b/tests/kernel/wave/wave_e2e_test.py
@@ -781,7 +781,7 @@ def test_igemm_conv(n, h, w, c, hf, wf, nf, stride, mem_space, layout, request):
     # Workgroup tile sizes
     BLOCK_M = tkl.sym.BLOCK_M
     BLOCK_N = tkl.sym.BLOCK_N
-    BLOCK_K = 16
+    BLOCK_K = tkl.sym.BLOCK_K
     # Address space (for GPU, shared(1) or global(0))
     ADDRESS_SPACE = tkl.sym.ADDRESS_SPACE
     # Other hyperparameters
@@ -805,8 +805,8 @@ def test_igemm_conv(n, h, w, c, hf, wf, nf, stride, mem_space, layout, request):
     constraints: list[tkw.Constraint] = []
     constraints += [tkw.WorkgroupConstraint(M, BLOCK_M, 0)]
     constraints += [tkw.WorkgroupConstraint(NF, BLOCK_N, 1)]
-    constraints += [tkw.WaveConstraint(M, BLOCK_M)]
-    constraints += [tkw.WaveConstraint(NF, BLOCK_N)]
+    constraints += [tkw.WaveConstraint(M, BLOCK_M / 2)]
+    constraints += [tkw.WaveConstraint(NF, BLOCK_N / 2)]
     constraints += [tkw.TilingConstraint(K, BLOCK_K)]
 
     constraints += [
@@ -866,8 +866,9 @@ def repeat(acc: tkl.Register[M, NF, tkl.f32]) -> tkl.Register[M, NF, tkl.f32]:
             NF: nf,
             WF: wf,
             HF: hf,
-            BLOCK_M: 16,
-            BLOCK_N: 16,
+            BLOCK_M: 64,
+            BLOCK_N: 64,
+            BLOCK_K: 32,
             ELEMS_PER_THREAD: 4,
             ADDRESS_SPACE: mem_space,
         },

From ee9b9541a619bfc9e19734ee884580cbecd02159 Mon Sep 17 00:00:00 2001
From: Ivan Butygin
Date: Mon, 7 Oct 2024 23:24:22 +0200
Subject: [PATCH 08/14] fix sympy false

Signed-off-by: Ivan Butygin
---
 iree/turbine/kernel/wave/codegen.py | 6 ++++++
 1 file changed, 6 insertions(+)

diff --git a/iree/turbine/kernel/wave/codegen.py b/iree/turbine/kernel/wave/codegen.py
index 786ba5d7..f6ad5cda 100644
--- a/iree/turbine/kernel/wave/codegen.py
+++ b/iree/turbine/kernel/wave/codegen.py
@@ -401,6 +401,12 @@ def _get_const(val):
                 _enforce_non_rational(lhs, term)
                 res = arith_d.andi(*_broadcast(lhs, rhs))
                 stack.append(res)
+            case sympy.logic.boolalg.BooleanFalse():
+                res = arith_d.constant(IntegerType.get_signless(1), 0)
+                stack.append(res)
+            case sympy.logic.boolalg.BooleanTrue():
+                res = arith_d.constant(IntegerType.get_signless(1), 1)
+                stack.append(res)
             case sympy.UnevaluatedExpr():
                 continue
             case _:

From 2c007c1e269325ced78b03bbdf78caccf374e56f Mon Sep 17 00:00:00 2001
From: Ivan Butygin
Date: Mon, 7 Oct 2024 23:46:58 +0200
Subject: [PATCH 09/14] schedule

Signed-off-by: Ivan Butygin
---
 tests/kernel/wave/wave_e2e_test.py | 9 +++++++++
 1 file changed, 9 insertions(+)

diff --git a/tests/kernel/wave/wave_e2e_test.py b/tests/kernel/wave/wave_e2e_test.py
index aab052db..9e1c90e5 100644
--- a/tests/kernel/wave/wave_e2e_test.py
+++ b/tests/kernel/wave/wave_e2e_test.py
@@ -871,11 +871,20 @@ def repeat(acc: tkl.Register[M, NF, tkl.f32]) -> tkl.Register[M, NF, tkl.f32]:
             BLOCK_K: 32,
             ELEMS_PER_THREAD: 4,
             ADDRESS_SPACE: mem_space,
+            READ_SHARED_DELAY: 1,
+            WRITE_SHARED_DELAY: 1,
+            READ_GLOBAL_DELAY: 2,
+            WRITE_GLOBAL_DELAY: 2,
+            MMA_DELAY: 1,
+            SHARED_MEMORY_UNITS: 4,
+            GLOBAL_MEMORY_UNITS: 4,
+            MMA_UNITS: 4,
         },
         canonicalize=True,
         run=True,
         run_bench=run_bench,
         run_config=config,
+        schedule=True,
     ):
         out = torch.zeros_like(out_ref)
         conv(x, we, out)

From 2222a9a356cfcfc4d4afac01a198de98d493d4a0 Mon Sep 17 00:00:00 2001
From: Ivan Butygin
Date: Tue, 8 Oct 2024 02:19:40 +0200
Subject: [PATCH 10/14] block sizes

Signed-off-by: Ivan Butygin
---
 tests/kernel/wave/wave_e2e_test.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/tests/kernel/wave/wave_e2e_test.py b/tests/kernel/wave/wave_e2e_test.py
index 9e1c90e5..3d9b8aef 100644
--- a/tests/kernel/wave/wave_e2e_test.py
+++ b/tests/kernel/wave/wave_e2e_test.py
@@ -803,8 +803,8 @@ def test_igemm_conv(n, h, w, c, hf, wf, nf, stride, mem_space, layout, request):
 
     # Expose user-constraints
     constraints: list[tkw.Constraint] = []
-    constraints += [tkw.WorkgroupConstraint(M, BLOCK_M, 0)]
-    constraints += [tkw.WorkgroupConstraint(NF, BLOCK_N, 1)]
+    constraints += [tkw.WorkgroupConstraint(M, BLOCK_M, 1)]
+    constraints += [tkw.WorkgroupConstraint(NF, BLOCK_N, 0)]
     constraints += [tkw.WaveConstraint(M, BLOCK_M / 2)]
     constraints += [tkw.WaveConstraint(NF, BLOCK_N / 2)]
     constraints += [tkw.TilingConstraint(K, BLOCK_K)]
@@ -867,7 +867,7 @@ def repeat(acc: tkl.Register[M, NF, tkl.f32]) -> tkl.Register[M, NF, tkl.f32]:
             WF: wf,
             HF: hf,
             BLOCK_M: 64,
-            BLOCK_N: 64,
+            BLOCK_N: 128,
             BLOCK_K: 32,
             ELEMS_PER_THREAD: 4,
             ADDRESS_SPACE: mem_space,

From 8b5242b33bd2a45e81fe7ecbbdd1b125cf3ed4a6 Mon Sep 17 00:00:00 2001
From: Ivan Butygin
Date: Tue, 8 Oct 2024 02:21:12 +0200
Subject: [PATCH 11/14] rearrange access pattern

Signed-off-by: Ivan Butygin
---
 tests/kernel/wave/wave_e2e_test.py | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/tests/kernel/wave/wave_e2e_test.py b/tests/kernel/wave/wave_e2e_test.py
index 3d9b8aef..472f2747 100644
--- a/tests/kernel/wave/wave_e2e_test.py
+++ b/tests/kernel/wave/wave_e2e_test.py
@@ -757,14 +757,14 @@ def test_igemm_conv(n, h, w, c, hf, wf, nf, stride, mem_space, layout, request):
         inputs={
             N: i // SZ_OUT,
             C: j // (HF * WF),
-            H: (i % SZ_OUT) % W_OUT * stride + (j % (HF * WF)) % WF,
-            W: (i % SZ_OUT) // W_OUT * stride + (j % (HF * WF)) // WF,
+            W: (i % SZ_OUT) % W_OUT * stride + (j % (HF * WF)) % WF,
+            H: (i % SZ_OUT) // W_OUT * stride + (j % (HF * WF)) // WF,
         },
         outputs={M: i, K: j},
     )
     w_mapping = tkw.IndexMapping(
         num_iterators=2,
-        inputs={NF: i % NF, C: j // (HF * WF), HF: j % WF, WF: (j % (HF * WF)) // WF},
+        inputs={NF: i % NF, C: j // (HF * WF), WF: j % WF, HF: (j % (HF * WF)) // WF},
         outputs={NF: i, K: j},
     )
     out_mapping = tkw.IndexMapping(
@@ -773,8 +773,8 @@ def test_igemm_conv(n, h, w, c, hf, wf, nf, stride, mem_space, layout, request):
         outputs={
             N: i // SZ_OUT,
             NF: j,
-            H_OUT: (i % SZ_OUT) % W_OUT,
-            W_OUT: (i % SZ_OUT) // W_OUT,
+            W_OUT: (i % SZ_OUT) % W_OUT,
+            H_OUT: (i % SZ_OUT) // W_OUT,
         },
     )
 

From 6f32d2807f353be889a01e887df9a412a5869de9 Mon Sep 17 00:00:00 2001
From: Ivan Butygin
Date: Tue, 8 Oct 2024 03:10:50 +0200
Subject: [PATCH 12/14] waves_per_block

Signed-off-by: Ivan Butygin
---
 tests/kernel/wave/wave_e2e_test.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/kernel/wave/wave_e2e_test.py b/tests/kernel/wave/wave_e2e_test.py
index 472f2747..3221096e 100644
--- a/tests/kernel/wave/wave_e2e_test.py
+++ b/tests/kernel/wave/wave_e2e_test.py
@@ -812,7 +812,7 @@ def test_igemm_conv(n, h, w, c, hf, wf, nf, stride, mem_space, layout, request):
     constraints += [
         tkw.HardwareConstraint(
             threads_per_wave=64,
-            waves_per_block=(1, 1, 1),
+            waves_per_block=(2, 2, 1),
         )
     ]
 

From 639aa70be8209c4d3b6ccc3410069b76fbbff8ef Mon Sep 17 00:00:00 2001
From: Ivan Butygin
Date: Tue, 8 Oct 2024 17:42:17 +0200
Subject: [PATCH 13/14] ratios

Signed-off-by: Ivan Butygin
---
 tests/kernel/wave/wave_e2e_test.py | 9 ++++++---
 1 file changed, 6 insertions(+), 3 deletions(-)

diff --git a/tests/kernel/wave/wave_e2e_test.py b/tests/kernel/wave/wave_e2e_test.py
index 3221096e..21f89374 100644
--- a/tests/kernel/wave/wave_e2e_test.py
+++ b/tests/kernel/wave/wave_e2e_test.py
@@ -801,18 +801,21 @@ def test_igemm_conv(n, h, w, c, hf, wf, nf, stride, mem_space, layout, request):
     else:
         raise ValueError(f"Invalid layout: {layout}")
 
+    ratio_m = 2
+    ratio_n = 2
+
     # Expose user-constraints
     constraints: list[tkw.Constraint] = []
     constraints += [tkw.WorkgroupConstraint(M, BLOCK_M, 1)]
     constraints += [tkw.WorkgroupConstraint(NF, BLOCK_N, 0)]
-    constraints += [tkw.WaveConstraint(M, BLOCK_M / 2)]
-    constraints += [tkw.WaveConstraint(NF, BLOCK_N / 2)]
+    constraints += [tkw.WaveConstraint(M, BLOCK_M / ratio_m)]
+    constraints += [tkw.WaveConstraint(NF, BLOCK_N / ratio_n)]
     constraints += [tkw.TilingConstraint(K, BLOCK_K)]
 
     constraints += [
         tkw.HardwareConstraint(
             threads_per_wave=64,
-            waves_per_block=(2, 2, 1),
+            waves_per_block=(ratio_n, ratio_m, 1),
         )
     ]
 

From a6804db8a55ee797f905c5c5024faac535aa2f26 Mon Sep 17 00:00:00 2001
From: Ivan Butygin
Date: Tue, 8 Oct 2024 23:09:09 +0200
Subject: [PATCH 14/14] disable schedule

Signed-off-by: Ivan Butygin
---
 tests/kernel/wave/wave_e2e_test.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/kernel/wave/wave_e2e_test.py b/tests/kernel/wave/wave_e2e_test.py
index 21f89374..b7da8b9d 100644
--- a/tests/kernel/wave/wave_e2e_test.py
+++ b/tests/kernel/wave/wave_e2e_test.py
@@ -887,7 +887,7 @@ def repeat(acc: tkl.Register[M, NF, tkl.f32]) -> tkl.Register[M, NF, tkl.f32]:
         canonicalize=True,
         run=True,
         run_bench=run_bench,
         run_config=config,
-        schedule=True,
+        schedule=False,
     ):
         out = torch.zeros_like(out_ref)
         conv(x, we, out)
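
Note on the igemm formulation exercised by these patches: test_igemm_conv lowers the
convolution to an implicit GEMM, reading the input through an im2col-style index mapping
with M = SZ_OUT * N rows and K = HF * WF * C columns, flattening the filter to an NF x K
matrix, and scattering the product back to (N, NF, H_OUT, W_OUT). The standalone sketch
below reproduces that formulation with torch.nn.functional.unfold and checks it against
torch.nn.functional.conv2d. The shapes, the fp32 dtype, and the row ordering are
assumptions made only for this sketch; they do not match the kernel's f16 inputs or the
exact linearization chosen in the "rearrange access pattern" patch.

    import torch

    # Illustrative shapes only (the real test parametrizes n, c, h, w, nf, hf, wf, stride).
    n, c, h, w = 2, 4, 9, 9
    nf, hf, wf, stride = 8, 3, 3, 1

    x = torch.randn(n, c, h, w)
    we = torch.randn(nf, c, hf, wf)

    h_out = (h - hf) // stride + 1
    w_out = (w - wf) // stride + 1
    sz_out = h_out * w_out

    # im2col: (n, c * hf * wf, sz_out), one K entry per channel and filter tap.
    cols = torch.nn.functional.unfold(x, (hf, wf), stride=stride)

    # A is the (M, K) im2col matrix with M = sz_out * n, B is the (NF, K) filter matrix.
    a = cols.permute(0, 2, 1).reshape(n * sz_out, c * hf * wf)
    b = we.reshape(nf, c * hf * wf)

    # GEMM, then fold M back into (n, h_out, w_out) and move NF to the channel dim.
    out = (a @ b.t()).reshape(n, h_out, w_out, nf).permute(0, 3, 1, 2)

    ref = torch.nn.functional.conv2d(x, we, stride=stride)
    assert torch.allclose(out, ref, atol=1e-4, rtol=1e-4)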
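
Note on the offsets handling in patches 04 through 06: _construct_gather_scatter_indices
computes, for every element of a thread's vector, its offset relative to the gather's
start indices. When that expression folds to a plain number it is encoded into the
constant offsets vector; otherwise patch 06 sets need_dynamic_offsets and builds the
whole vector from an iota expression with arith ops. The sympy sketch below only
illustrates that distinction; the symbols and the split mapping are invented for the
example and are not taken from codegen.py.

    import sympy

    # j stands in for the most-minor iterator, W for a runtime size; both are
    # assumptions of this sketch.
    j, W = sympy.symbols("j W", integer=True, nonnegative=True)
    elements_per_thread = 4

    def relative_offsets(mapping):
        # Offset of element e relative to element 0 of the same vector.
        return [sympy.simplify(mapping.subs(j, j + e) - mapping)
                for e in range(elements_per_thread)]

    # Identity-like mapping: offsets fold to the constants 0..3, offset.is_number holds,
    # and they can be baked into a constant offsets vector.
    print(relative_offsets(j))

    # A mapping that splits the iterator with % and //, in the spirit of the igemm
    # mappings in the test: offsets past the first keep Mod/floor terms in j and W,
    # offset.is_number is False, and the dynamic (need_dynamic_offsets) path is needed.
    print(relative_offsets((j % 3) * W + (j // 3)))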