[TKW] Igemm opt #203

Draft · wants to merge 14 commits into base: main
49 changes: 34 additions & 15 deletions iree/turbine/kernel/wave/codegen.py
@@ -401,6 +401,12 @@ def _get_const(val):
_enforce_non_rational(lhs, term)
res = arith_d.andi(*_broadcast(lhs, rhs))
stack.append(res)
case sympy.logic.boolalg.BooleanFalse():
res = arith_d.constant(IntegerType.get_signless(1), 0)
stack.append(res)
case sympy.logic.boolalg.BooleanTrue():
res = arith_d.constant(IntegerType.get_signless(1), 1)
stack.append(res)
case sympy.UnevaluatedExpr():
continue
case _:
@@ -588,8 +594,8 @@ def _construct_gather_scatter_indices(

start_indices = _get_start_indices(result_index)
start_indices_orig = _get_start_indices(index)
dynamic_offsets = []

need_dynamic_offsets = False
start_indices_offset = _compute_offset(start_indices, strides)
for i in range(elements_per_thread):
# Update most-minor dim, i.e. in case of identity mapping it will
@@ -615,22 +621,29 @@ def _construct_gather_scatter_indices(
# arith ops and then `vector.insertelement` them into offsets vec.
offset = int(offset)
else:
dyn_offset = gen_sympy_index(add_emitter_subs(emitter), offset)
dynamic_offsets.append((i, dyn_offset))
offset = 0
need_dynamic_offsets = True
break

offsets.append(IntegerAttr.get(IndexType.get(), offset))

start_indices = _build_start_indices(emitter, result_index)
offsets_vec_type = VectorType.get([elements_per_thread], IndexType.get())

offsets_vec = arith_d.ConstantOp(
offsets_vec_type, DenseElementsAttr.get(offsets, offsets_vec_type)
)

for i, off in dynamic_offsets:
pos = arith_d.ConstantOp(IndexType.get(), i)
offsets_vec = vector_d.insertelement(off, offsets_vec, position=pos)
if need_dynamic_offsets:
result_index = {key: 0 for key in symbolc_shape}
start_indices = _build_start_indices(emitter, result_index)
subs = [(sym, idx) for sym, idx in zip(iters.keys(), start_indices_orig)]
subs[-1] = (
subs[-1][0],
start_indices_orig[-1] + idxc.iota(elements_per_thread),
)
indices = [i.subs(subs) for i in index_mapping]
offsets_vec = gen_sympy_index(
add_emitter_subs(emitter), _compute_offset(indices, strides)
)
else:
start_indices = _build_start_indices(emitter, result_index)
offsets_vec = arith_d.ConstantOp(
offsets_vec_type, DenseElementsAttr.get(offsets, offsets_vec_type)
)

mask = _build_mask(emitter, index, elements_per_thread)
if mask is None:
@@ -667,7 +680,10 @@ def handle_read(emitter: WaveEmitter, node: fx.Node):
mask = _build_mask(
emitter, index, cast_py_literal(emitter, elements_per_thread)
)
if mask is None:
if (
mask is None
or get_custom(node).memory_type.address_space == SHARED_ADDRESS_SPACE
):
result = vector_d.load(vector_type, kb_src, start_indices)
else:
zero = get_constant_attr(0, element_type)
@@ -730,7 +746,10 @@ def handle_write(emitter: WaveEmitter, node: fx.Node):
mask = _build_mask(
emitter, index, cast_py_literal(emitter, elements_per_thread)
)
if mask is None:
if (
mask is None
or get_custom(node).memory_type.address_space == SHARED_ADDRESS_SPACE
):
vector_d.store(insert_vector, kb_dest, start_indices)
else:
vector_d.maskedstore(kb_dest, start_indices, mask, insert_vector)
39 changes: 26 additions & 13 deletions tests/kernel/wave/wave_e2e_test.py
@@ -757,14 +757,14 @@ def test_igemm_conv(n, h, w, c, hf, wf, nf, stride, mem_space, layout, request):
inputs={
N: i // SZ_OUT,
C: j // (HF * WF),
H: (i % SZ_OUT) % W_OUT * stride + (j % (HF * WF)) % WF,
W: (i % SZ_OUT) // W_OUT * stride + (j % (HF * WF)) // WF,
W: (i % SZ_OUT) % W_OUT * stride + (j % (HF * WF)) % WF,
Contributor: Was this a bug? How were the tests passing if we had H and W switched?

Contributor (author): This changes how we unroll the conv window (rows vs. columns first); it should produce a semantically equivalent result as long as it stays consistent across all the mappings. But it also affects the memory access pattern, and this way it's slightly faster. See the sketch below.
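To make the reordering concrete, here is a minimal sketch (plain Python, not TKW code; `HF` and `WF` are stand-ins for the test's filter-size symbols, and a single output pixel is assumed so only the filter-window part of the mapping is shown). It checks that both mappings visit the same set of filter taps, just in a different traversal order:

```python
# Standalone illustration of the two unroll orders of the conv window.
HF, WF = 3, 3

def new_order(j):
    # New mapping: W gets (j % (HF * WF)) % WF, H gets (j % (HF * WF)) // WF,
    # so the window is walked row by row (W is the fastest-varying dim).
    return (j % (HF * WF)) // WF, (j % (HF * WF)) % WF  # (H, W)

def old_order(j):
    # Old mapping: the two expressions were attached the other way around,
    # so the window was walked column by column.
    return (j % (HF * WF)) % WF, (j % (HF * WF)) // WF  # (H, W)

new_taps = [new_order(j) for j in range(HF * WF)]
old_taps = [old_order(j) for j in range(HF * WF)]

# Both orders cover the same set of (H, W) taps; only the traversal order
# differs -> same result, different memory access pattern.
assert set(new_taps) == set(old_taps)
assert new_taps != old_taps
```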

H: (i % SZ_OUT) // W_OUT * stride + (j % (HF * WF)) // WF,
},
outputs={M: i, K: j},
)
w_mapping = tkw.IndexMapping(
num_iterators=2,
inputs={NF: i % NF, C: j // (HF * WF), HF: j % WF, WF: (j % (HF * WF)) // WF},
inputs={NF: i % NF, C: j // (HF * WF), WF: j % WF, HF: (j % (HF * WF)) // WF},
outputs={NF: i, K: j},
)
out_mapping = tkw.IndexMapping(
@@ -773,15 +773,15 @@ def test_igemm_conv(n, h, w, c, hf, wf, nf, stride, mem_space, layout, request):
outputs={
N: i // SZ_OUT,
NF: j,
H_OUT: (i % SZ_OUT) % W_OUT,
W_OUT: (i % SZ_OUT) // W_OUT,
W_OUT: (i % SZ_OUT) % W_OUT,
H_OUT: (i % SZ_OUT) // W_OUT,
},
)

# Workgroup tile sizes
BLOCK_M = tkl.sym.BLOCK_M
BLOCK_N = tkl.sym.BLOCK_N
BLOCK_K = 16
BLOCK_K = tkl.sym.BLOCK_K
# Address space (for GPU, shared(1) or global(0))
ADDRESS_SPACE = tkl.sym.ADDRESS_SPACE
# Other hyperparameters
@@ -801,18 +801,21 @@ def test_igemm_conv(n, h, w, c, hf, wf, nf, stride, mem_space, layout, request):
else:
raise ValueError(f"Invalid layout: {layout}")

ratio_m = 2
ratio_n = 2

# Expose user-constraints
constraints: list[tkw.Constraint] = []
constraints += [tkw.WorkgroupConstraint(M, BLOCK_M, 0)]
constraints += [tkw.WorkgroupConstraint(NF, BLOCK_N, 1)]
constraints += [tkw.WaveConstraint(M, BLOCK_M)]
constraints += [tkw.WaveConstraint(NF, BLOCK_N)]
constraints += [tkw.WorkgroupConstraint(M, BLOCK_M, 1)]
constraints += [tkw.WorkgroupConstraint(NF, BLOCK_N, 0)]
constraints += [tkw.WaveConstraint(M, BLOCK_M / ratio_m)]
constraints += [tkw.WaveConstraint(NF, BLOCK_N / ratio_n)]
constraints += [tkw.TilingConstraint(K, BLOCK_K)]

constraints += [
tkw.HardwareConstraint(
threads_per_wave=64,
waves_per_block=(1, 1, 1),
waves_per_block=(ratio_n, ratio_m, 1),
)
]

@@ -866,15 +869,25 @@ def repeat(acc: tkl.Register[M, NF, tkl.f32]) -> tkl.Register[M, NF, tkl.f32]:
NF: nf,
WF: wf,
HF: hf,
BLOCK_M: 16,
BLOCK_N: 16,
BLOCK_M: 64,
BLOCK_N: 128,
BLOCK_K: 32,
ELEMS_PER_THREAD: 4,
ADDRESS_SPACE: mem_space,
READ_SHARED_DELAY: 1,
WRITE_SHARED_DELAY: 1,
READ_GLOBAL_DELAY: 2,
WRITE_GLOBAL_DELAY: 2,
MMA_DELAY: 1,
SHARED_MEMORY_UNITS: 4,
GLOBAL_MEMORY_UNITS: 4,
MMA_UNITS: 4,
},
canonicalize=True,
run=True,
run_bench=run_bench,
run_config=config,
schedule=False,
):
out = torch.zeros_like(out_ref)
conv(x, we, out)