examples/07_gemm_all_scatter/matmul_wrapper.py (3 additions, 27 deletions)
@@ -6,37 +6,15 @@
 
 # from streamk_kernel import streamk_gemm
 from gemm_all_scatter import persistent_gemm_all_scatter
-from examples.common.utils import is_triton_interpret_set
+from examples.common.matmul_helpers import MatmulDebugMixin
 import iris
 
 gemm_kernel = persistent_gemm_all_scatter
 
 
-class matmul(torch.autograd.Function):
-    _debug = False
-    _registers = None
-    _spills = None
-
+class matmul(MatmulDebugMixin, torch.autograd.Function):
     _num_xcds = iris.hip.get_num_xcc()
 
-    @staticmethod
-    def set_debug(debug: bool):
-        matmul._debug = debug
-
-    @staticmethod
-    def get_matmul_registers():
-        if matmul._debug:
-            return matmul._registers
-        else:
-            raise RuntimeError("Debug mode is not enabled. Call set_debug(True) first.")
-
-    @staticmethod
-    def get_matmul_spills():
-        if matmul._debug:
-            return matmul._spills
-        else:
-            raise RuntimeError("Debug mode is not enabled. Call set_debug(True) first.")
-
     @staticmethod
     def _call(
         a: torch.Tensor,
@@ -119,9 +97,7 @@ def _call(
             mm_end_timestamp_ptr=mm_end_timestamp,
         )
 
-        if matmul._debug and not is_triton_interpret_set():
-            matmul._registers = kk.n_regs
-            matmul._spills = kk.n_spills
+        matmul._track_debug_info(kk)
 
         return c
 
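All of these wrappers now delegate their debug bookkeeping to `MatmulDebugMixin` from `examples.common.matmul_helpers`, which is not part of this diff. Below is a minimal sketch of what the mixin presumably provides, reconstructed from the per-class helpers removed above; the real implementation may differ (for example, it may use staticmethods rather than classmethods):

```python
# Hypothetical sketch of examples/common/matmul_helpers.py (not shown in this PR).
from examples.common.utils import is_triton_interpret_set


class MatmulDebugMixin:
    _debug = False
    _registers = None
    _spills = None

    @classmethod
    def set_debug(cls, debug: bool):
        cls._debug = debug

    @classmethod
    def get_matmul_registers(cls):
        if cls._debug:
            return cls._registers
        raise RuntimeError("Debug mode is not enabled. Call set_debug(True) first.")

    @classmethod
    def get_matmul_spills(cls):
        if cls._debug:
            return cls._spills
        raise RuntimeError("Debug mode is not enabled. Call set_debug(True) first.")

    @classmethod
    def _track_debug_info(cls, kk):
        # Record register/spill counts from the compiled Triton kernel handle,
        # skipping the interpreter, which does not expose n_regs/n_spills.
        if cls._debug and not is_triton_interpret_set():
            cls._registers = kk.n_regs
            cls._spills = kk.n_spills
```

Centralizing this logic removes roughly 25 duplicated lines per wrapper and keeps register/spill tracking consistent across the examples.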
examples/08_gemm_all_reduce_atomics/matmul_wrapper.py (3 additions, 13 deletions)
@@ -11,20 +11,12 @@
 # from streamk_kernel_atomic import streamk_gemm
 from gemm_all_reduce_atomics import persistent_gemm_all_reduce
 
-from examples.common.utils import is_triton_interpret_set
+from examples.common.matmul_helpers import MatmulDebugMixin
 import iris
 
 gemm_kernel = persistent_gemm_all_reduce
 
 
-class matmul(torch.autograd.Function):
-    _debug = True
-
-    @staticmethod
-    def set_debug(debug: bool):
-        matmul._debug = debug
-        matmul.streamk_registers = 0
-        matmul.streamk_spills = 0
+class matmul(MatmulDebugMixin, torch.autograd.Function):
 
     @staticmethod
     def _call(
@@ -109,9 +101,7 @@ def _call(
             mm_end_timestamp_ptr=mm_end_timestamp,
         )
 
-        if matmul._debug and not is_triton_interpret_set():
-            matmul.streamk_registers = kk.n_regs
-            matmul.streamk_spills = kk.n_spills
+        matmul._track_debug_info(kk)
 
         return c
 
examples/09_gemm_one_shot_all_reduce/matmul_wrapper.py (3 additions, 16 deletions)
@@ -11,23 +11,15 @@
 # from streamk_kernel_atomic import streamk_gemm
 from gemm_one_shot_all_reduce import persistent_gemm_all_reduce
 
-from examples.common.utils import is_triton_interpret_set
+from examples.common.matmul_helpers import MatmulDebugMixin
 import iris
 
 gemm_kernel = persistent_gemm_all_reduce
 
 
-class matmul(torch.autograd.Function):
-    _debug = True
+class matmul(MatmulDebugMixin, torch.autograd.Function):
 
     _num_xcds = iris.hip.get_num_xcc()
 
-    @staticmethod
-    def set_debug(debug: bool):
-        matmul._debug = debug
-        matmul.streamk_registers = 0
-        matmul.streamk_spills = 0
-
     @staticmethod
     def _call(
         a: torch.Tensor,
@@ -150,12 +142,7 @@ def _call(
             mm_end_timestamp_ptr=mm_end_timestamp,
         )
 
-        if matmul._debug and not is_triton_interpret_set():
-            matmul.streamk_registers = kk.n_regs
-            matmul.streamk_spills = kk.n_spills
-            print(f"{kk.n_regs} registers used, {kk.n_spills} spills")
-            # print(kk.asm['ttgir'])
-            # print(kk.asm['amdgcn'])
+        matmul._track_debug_info(kk)
 
         return c
 
@@ -8,37 +8,15 @@
 from gemm_all_scatter_wg_specialization import (
     persistent_gemm_all_scatter_wg_specialization,
 )
-from examples.common.utils import is_triton_interpret_set
+from examples.common.matmul_helpers import MatmulDebugMixin
 import iris
 
 gemm_kernel = persistent_gemm_all_scatter_wg_specialization
 
 
-class matmul(torch.autograd.Function):
-    _debug = False
-    _registers = None
-    _spills = None
+class matmul(MatmulDebugMixin, torch.autograd.Function):
 
     _num_xcds = iris.hip.get_num_xcc()
 
-    @staticmethod
-    def set_debug(debug: bool):
-        matmul._debug = debug
-
-    @staticmethod
-    def get_matmul_registers():
-        if matmul._debug:
-            return matmul._registers
-        else:
-            raise RuntimeError("Debug mode is not enabled. Call set_debug(True) first.")
-
-    @staticmethod
-    def get_matmul_spills():
-        if matmul._debug:
-            return matmul._spills
-        else:
-            raise RuntimeError("Debug mode is not enabled. Call set_debug(True) first.")
-
     @staticmethod
     def _call(
         a: torch.Tensor,
@@ -125,9 +103,7 @@ def _call(
             mm_end_timestamp_ptr=mm_end_timestamp,
         )
 
-        if matmul._debug and not is_triton_interpret_set():
-            matmul._registers = kk.n_regs
-            matmul._spills = kk.n_spills
+        matmul._track_debug_info(kk)
 
         return c
 
@@ -6,37 +6,15 @@
 
 # from streamk_kernel import streamk_gemm
 from gemm_all_scatter_producer_consumer import persistent_gemm
-from examples.common.utils import is_triton_interpret_set
+from examples.common.matmul_helpers import MatmulDebugMixin
 import iris
 
 gemm_kernel = persistent_gemm
 
 
-class matmul(torch.autograd.Function):
-    _debug = False
-    _registers = None
-    _spills = None
+class matmul(MatmulDebugMixin, torch.autograd.Function):
 
     _num_xcds = iris.hip.get_num_xcc()
 
-    @staticmethod
-    def set_debug(debug: bool):
-        matmul._debug = debug
-
-    @staticmethod
-    def get_matmul_registers():
-        if matmul._debug:
-            return matmul._registers
-        else:
-            raise RuntimeError("Debug mode is not enabled. Call set_debug(True) first.")
-
-    @staticmethod
-    def get_matmul_spills():
-        if matmul._debug:
-            return matmul._spills
-        else:
-            raise RuntimeError("Debug mode is not enabled. Call set_debug(True) first.")
-
     @staticmethod
     def _call(
         a: torch.Tensor,
@@ -118,9 +96,7 @@ def _call(
             mm_end_timestamp_ptr=mm_end_timestamp,
         )
 
-        if matmul._debug and not is_triton_interpret_set():
-            matmul._registers = kk.n_regs
-            matmul._spills = kk.n_spills
+        matmul._track_debug_info(kk)
 
         return c
 
examples/12_gemm_all_scatter_bulk_synchronous/matmul_wrapper.py (3 additions, 27 deletions)
@@ -6,36 +6,14 @@
 
 # from streamk_kernel import streamk_gemm
 from gemm_all_scatter_bulk_synchronous import persistent_gemm
-from examples.common.utils import is_triton_interpret_set
+from examples.common.matmul_helpers import MatmulDebugMixin
 import iris
 
 gemm_kernel = persistent_gemm
 
 
-class matmul(torch.autograd.Function):
-    _debug = False
-    _registers = None
-    _spills = None
+class matmul(MatmulDebugMixin, torch.autograd.Function):
     _num_xcds = iris.hip.get_num_xcc()
 
-    @staticmethod
-    def set_debug(debug: bool):
-        matmul._debug = debug
-
-    @staticmethod
-    def get_matmul_registers():
-        if matmul._debug:
-            return matmul._registers
-        else:
-            raise RuntimeError("Debug mode is not enabled. Call set_debug(True) first.")
-
-    @staticmethod
-    def get_matmul_spills():
-        if matmul._debug:
-            return matmul._spills
-        else:
-            raise RuntimeError("Debug mode is not enabled. Call set_debug(True) first.")
-
     @staticmethod
     def _call(
         a: torch.Tensor,
@@ -114,9 +92,7 @@ def _call(
             mm_end_timestamp_ptr=mm_end_timestamp,
        )
 
-        if matmul._debug and not is_triton_interpret_set():
-            matmul._registers = kk.n_regs
-            matmul._spills = kk.n_spills
+        matmul._track_debug_info(kk)
 
         return c
 
examples/15_gemm_all_reduce_ring_based/matmul_wrapper.py (3 additions, 27 deletions)
@@ -11,37 +11,15 @@
 # from streamk_kernel_atomic import streamk_gemm
 from gemm_all_reduce_ring_based import persistent_gemm
 
-from examples.common.utils import is_triton_interpret_set
+from examples.common.matmul_helpers import MatmulDebugMixin
 import iris
 
 gemm_kernel = persistent_gemm
 
 
-class matmul(torch.autograd.Function):
-    _debug = True
-    _registers = None
-    _spills = None
+class matmul(MatmulDebugMixin, torch.autograd.Function):
 
     _num_xcds = iris.hip.get_num_xcc()
 
-    @staticmethod
-    def set_debug(debug: bool):
-        matmul._debug = debug
-
-    @staticmethod
-    def get_matmul_registers():
-        if matmul._debug:
-            return matmul._registers
-        else:
-            raise RuntimeError("Debug mode is not enabled. Call set_debug(True) first.")
-
-    @staticmethod
-    def get_matmul_spills():
-        if matmul._debug:
-            return matmul._spills
-        else:
-            raise RuntimeError("Debug mode is not enabled. Call set_debug(True) first.")
-
     @staticmethod
     def _call(
         a: torch.Tensor,
@@ -123,9 +101,7 @@ def _call(
             mm_end_timestamp_ptr=mm_end_timestamp,
         )
 
-        # if matmul._debug and not is_triton_interpret_set():
-        matmul._registers = kk.n_regs
-        matmul._spills = kk.n_spills
+        matmul._track_debug_info(kk)
 
         return ring_buffer
 
examples/20_gemm_all_scatter_independent/matmul_wrapper.py (3 additions, 27 deletions)
@@ -6,36 +6,14 @@
 
 # from streamk_kernel import streamk_gemm
 from gemm_all_scatter_bulk_synchronous import persistent_gemm
-from examples.common.utils import is_triton_interpret_set
+from examples.common.matmul_helpers import MatmulDebugMixin
 import iris
 
 gemm_kernel = persistent_gemm
 
 
-class matmul(torch.autograd.Function):
-    _debug = False
-    _registers = None
-    _spills = None
+class matmul(MatmulDebugMixin, torch.autograd.Function):
     _num_xcds = iris.hip.get_num_xcc()
 
-    @staticmethod
-    def set_debug(debug: bool):
-        matmul._debug = debug
-
-    @staticmethod
-    def get_matmul_registers():
-        if matmul._debug:
-            return matmul._registers
-        else:
-            raise RuntimeError("Debug mode is not enabled. Call set_debug(True) first.")
-
-    @staticmethod
-    def get_matmul_spills():
-        if matmul._debug:
-            return matmul._spills
-        else:
-            raise RuntimeError("Debug mode is not enabled. Call set_debug(True) first.")
-
     @staticmethod
     def _call(
         a: torch.Tensor,
@@ -114,9 +92,7 @@ def _call(
             mm_end_timestamp_ptr=mm_end_timestamp,
         )
 
-        if matmul._debug and not is_triton_interpret_set():
-            matmul._registers = kk.n_regs
-            matmul._spills = kk.n_spills
+        matmul._track_debug_info(kk)
 
        return c
 
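Assuming the mixin keeps the old per-wrapper debug interface (`set_debug`, `get_matmul_registers`, `get_matmul_spills`), callers would exercise any of the refactored wrappers much as before; the snippet below is illustrative only, with the forward-call arguments elided because they vary per example:

```python
# Hypothetical usage; `matmul` is any of the refactored wrapper classes above.
matmul.set_debug(True)

# ... launch the GEMM through the wrapper as usual; _call() invokes
# matmul._track_debug_info(kk) immediately after the kernel launch ...

print(f"{matmul.get_matmul_registers()} registers used, "
      f"{matmul.get_matmul_spills()} spills")
```

This reproduces the diagnostic output that examples/09_gemm_one_shot_all_reduce previously printed inline, but now as an explicit, opt-in query on the caller's side.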