From 385671ed9b3fad53d3efbe5971d933c039c541ff Mon Sep 17 00:00:00 2001
From: simonidaa <126631512+simonidaa@users.noreply.github.com>
Date: Wed, 13 Nov 2024 02:53:33 +0100
Subject: [PATCH] Ensure Consistent Use of List Type for block_type Shape
 Argument in semantic.py and interpreter.py (#5128)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

## Description:

This pull request addresses a type-compatibility issue in the `block_type`
constructor across `semantic.py` and `interpreter.py`: the `shape` parameter
is expected to be a `List`, but was being passed as a `Tuple` in certain
cases. This PR fixes the issues identified in #4860. A minimal sketch after
the change summary illustrates the failure mode.

### Summary of Changes:

1. `semantic.py` - updates to the `histogram` and `device_assert` functions:
   - `histogram`: the `shape` argument was previously passed to the
     `block_type` constructor as a `Tuple`. It is now passed as a `List`,
     matching the constructor's requirements and ensuring type consistency.
   - `device_assert`: the `cond_ty` and `cond` assignments are updated in
     the same way; both the `block_type` constructor and the `create_splat`
     call previously received a `Tuple` and now receive a `List`.

2. `interpreter.py` - adjustment in the `ReduceOps.sum` method:
   - In `ReduceOps.sum`, the `to_tensor` function was called with a
     `numpy.ndarray` whose `shape` attribute is a `Tuple`. That shape is now
     converted to a `List` before being passed to `block_type`.
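For illustration, here is a minimal, self-contained sketch of the failure
mode this patch guards against. The `block_type` class below is a
hypothetical stand-in that models only the new assert, not Triton's real
class:

```python
from typing import List

import numpy as np


class block_type:
    """Stand-in for tl.block_type; models only the shape check added here."""

    def __init__(self, element_ty, shape: List):
        # The check this patch adds: shape must be a list, never a tuple.
        assert (isinstance(shape, list))
        self.element_ty = element_ty
        self.shape = shape


ret = np.zeros((4, 8), dtype=np.int32)

# With the assert in place, forwarding numpy's tuple shape directly fails:
#   block_type("int32", ret.shape)  # AssertionError
# so callers convert to a list first, as interpreter.py now does:
ty = block_type("int32", list(ret.shape))
print(ty.shape)  # [4, 8]
```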
### New Test Addition:

Added a new test covering the scenario where a histogram operation is
followed by a broadcasting operation.

Co-authored-by: maxim.m66
Co-authored-by: Maxim <134897529+maxim-m66@users.noreply.github.com>
Co-authored-by: peterbell10
---
 python/test/unit/language/test_core.py | 3 +++
 python/triton/language/core.py         | 2 ++
 python/triton/language/semantic.py     | 2 +-
 python/triton/runtime/interpreter.py   | 2 +-
 4 files changed, 7 insertions(+), 2 deletions(-)

diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py
index 8e516f22eb60..b12441f3a08a 100644
--- a/python/test/unit/language/test_core.py
+++ b/python/test/unit/language/test_core.py
@@ -2489,6 +2489,9 @@ def histogram_kernel(x_ptr, z_ptr, M: tl.constexpr, N: tl.constexpr):
         offset2 = tl.arange(0, N)
         x = tl.load(x_ptr + offset1)
         z = tl.histogram(x, N)
+        bias = tl.full([M, N], 1, dtype=tl.int32)
+        # check that histogram produces object compatible with broadcasting
+        biased = z + bias
         tl.store(z_ptr + offset2, z)
 
     torch.manual_seed(17)
diff --git a/python/triton/language/core.py b/python/triton/language/core.py
index 27af6c4ebb4b..d3b5269b4461 100644
--- a/python/triton/language/core.py
+++ b/python/triton/language/core.py
@@ -609,6 +609,8 @@ def __init__(self, element_ty: dtype, shape: List):
 
         # Note that block_type's shape is a list of int
         # while tensor's shape is a list of constexpr.
+        assert (isinstance(shape, list))
+
         # shape can be empty ([]) when an input is a 0D tensor.
         self.shape = _unwrap_shape(shape)
         if not self.shape:
diff --git a/python/triton/language/semantic.py b/python/triton/language/semantic.py
index 3ab72bc87a94..61f6d3948d97 100644
--- a/python/triton/language/semantic.py
+++ b/python/triton/language/semantic.py
@@ -1679,7 +1679,7 @@ def associative_scan(inputs: Sequence[tl.tensor], axis: int, region_builder_fn,
 def histogram(input: tl.tensor, num_bins: int, builder: ir.builder) -> tl.tensor:
     assert len(input.shape) == 1, "histogram only supports 1D input"
     assert input.dtype.is_int(), "histogram only supports integer input"
-    return tl.tensor(builder.create_histogram(input.handle, num_bins), tl.block_type(tl.int32, (num_bins, )))
+    return tl.tensor(builder.create_histogram(input.handle, num_bins), tl.block_type(tl.int32, [num_bins]))
 
 ##
diff --git a/python/triton/runtime/interpreter.py b/python/triton/runtime/interpreter.py
index 8e279b318ca2..7c53697429ac 100644
--- a/python/triton/runtime/interpreter.py
+++ b/python/triton/runtime/interpreter.py
@@ -728,7 +728,7 @@ def check_tensor(self, input):
 
     def to_tensor(self, ret, dtype):
         if hasattr(ret, "shape") and ret.shape:
-            ret_type = tl.block_type(dtype, ret.shape)
+            ret_type = tl.block_type(dtype, list(ret.shape))
         else:
             ret = np.array([ret]).astype(_get_np_dtype(dtype))
             ret_type = dtype