From 3d5186c05e45a6f903d42f23d470393bb0bdb1ce Mon Sep 17 00:00:00 2001 From: Hongtao Yu Date: Sun, 10 Nov 2024 15:58:16 -0800 Subject: [PATCH 1/2] [Frontend] Print config when autotune fails --- python/triton/runtime/autotuner.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/python/triton/runtime/autotuner.py b/python/triton/runtime/autotuner.py index b967f136a966..043c500fc62f 100644 --- a/python/triton/runtime/autotuner.py +++ b/python/triton/runtime/autotuner.py @@ -42,6 +42,7 @@ def __init__( self.keys = key self.cache = {} self.arg_names = arg_names + self.debug = os.environ.get("TRITON_DEBUG", "0") == "1" # Reset to zero or restore values self.reset_to_zero = [] @@ -131,6 +132,9 @@ def _post_hook(kwargs, exception): def _bench(self, *args, config, **meta): from ..compiler.errors import CompileTimeAssertionFailure + if self.debug: + print(f"Autotuning kernel {self.base_fn.__name__} with config {config}") + # check for conflicts, i.e. meta-parameters both provided # as kwargs and by the autotuner conflicts = meta.keys() & config.kwargs.keys() @@ -161,7 +165,9 @@ def kernel_call(): try: return self.do_bench(kernel_call, quantiles=(0.5, 0.2, 0.8)) - except (OutOfResources, CompileTimeAssertionFailure, PTXASError): + except (OutOfResources, CompileTimeAssertionFailure, PTXASError) as e: + if self.debug: + print(f"Autotuning failed with {e}") return [float("inf"), float("inf"), float("inf")] def run(self, *args, **kwargs): From 116dd9fae73634fcfea4cff9c1b863f59a06cfc5 Mon Sep 17 00:00:00 2001 From: Hongtao Yu Date: Wed, 13 Nov 2024 09:14:06 -0800 Subject: [PATCH 2/2] print under TRITON_PRINT_AUTOTUNING=1 --- python/triton/runtime/autotuner.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/python/triton/runtime/autotuner.py b/python/triton/runtime/autotuner.py index 043c500fc62f..573d9d41913d 100644 --- a/python/triton/runtime/autotuner.py +++ b/python/triton/runtime/autotuner.py @@ -42,7 +42,6 @@ def __init__( self.keys = key self.cache = {} self.arg_names = arg_names - self.debug = os.environ.get("TRITON_DEBUG", "0") == "1" # Reset to zero or restore values self.reset_to_zero = [] @@ -132,7 +131,8 @@ def _post_hook(kwargs, exception): def _bench(self, *args, config, **meta): from ..compiler.errors import CompileTimeAssertionFailure - if self.debug: + verbose = os.environ.get("TRITON_PRINT_AUTOTUNING", None) == "1" + if verbose: print(f"Autotuning kernel {self.base_fn.__name__} with config {config}") # check for conflicts, i.e. meta-parameters both provided @@ -166,7 +166,7 @@ def kernel_call(): try: return self.do_bench(kernel_call, quantiles=(0.5, 0.2, 0.8)) except (OutOfResources, CompileTimeAssertionFailure, PTXASError) as e: - if self.debug: + if verbose: print(f"Autotuning failed with {e}") return [float("inf"), float("inf"), float("inf")]