
Commit 385671e

simonidaa, maxim-m66, and peterbell10 authored
Ensure Consistent Use of List Type for block_type Shape Argument in semantic.py and interpreter.py (#5128)
## Description

This pull request addresses a type-compatibility issue in the `block_type` constructor across `semantic.py` and `interpreter.py`: the `shape` parameter, which is expected as a `List`, was being provided as a `Tuple` in certain cases. This PR is intended to fix the issues identified in #4860.

### Summary of Changes

1. `semantic.py` - updates to the `histogram` and `device_assert` functions:
   - `histogram`: the shape was previously passed as a `Tuple` when calling the `block_type` constructor. It is now passed as a `List` to match the constructor's requirements and keep the types consistent.
   - `device_assert`: the `cond_ty` and `cond` assignments have been updated in the same way. The `block_type` constructor and the `create_splat` call were previously given a `Tuple`; both now receive a `List`.
2. `interpreter.py` - adjustment to the `ReduceOps.sum` method:
   - In `ReduceOps.sum`, `to_tensor` was called with a `numpy.ndarray`, whose `shape` attribute is a `Tuple`. That shape is now converted to a `List` before being passed to `block_type`.

### New Test Addition

Added a new test covering the scenario where a histogram operation is followed by a broadcasting operation.

Co-authored-by: maxim.m66 <[email protected]>
Co-authored-by: Maxim <[email protected]>
Co-authored-by: peterbell10 <[email protected]>
1 parent a4a490b commit 385671e
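
For reference, a minimal host-side sketch (illustrative, not part of the PR) of the stricter contract this change introduces: after it, `tl.block_type` only accepts a list-valued shape, so a caller holding a tuple, such as a literal `(num_bins, )` or a NumPy shape, needs to convert it first. No GPU is required to run this.

```python
import triton.language as tl

# A list-valued shape is what block_type expects.
ty = tl.block_type(tl.int32, [32])
print(ty.shape)  # [32]

# After this change, a tuple-valued shape trips the new assert in core.py,
# so callers convert tuples explicitly before constructing the type.
shape = (32, )
try:
    tl.block_type(tl.int32, shape)
except AssertionError:
    ty = tl.block_type(tl.int32, list(shape))
```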

4 files changed: +7 −2 lines changed

python/test/unit/language/test_core.py

Lines changed: 3 additions & 0 deletions
@@ -2489,6 +2489,9 @@ def histogram_kernel(x_ptr, z_ptr, M: tl.constexpr, N: tl.constexpr):
         offset2 = tl.arange(0, N)
         x = tl.load(x_ptr + offset1)
         z = tl.histogram(x, N)
+        bias = tl.full([M, N], 1, dtype=tl.int32)
+        # check that histogram produces object compatible with broadcasting
+        biased = z + bias
         tl.store(z_ptr + offset2, z)

     torch.manual_seed(17)
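
For context, a self-contained sketch of the scenario the new test exercises: a histogram result used directly in a broadcasting expression. The kernel and variable names below are illustrative rather than the test's own, and a CUDA-capable device is assumed.

```python
import torch
import triton
import triton.language as tl


@triton.jit
def hist_broadcast_kernel(x_ptr, z_ptr, M: tl.constexpr, N: tl.constexpr):
    x = tl.load(x_ptr + tl.arange(0, M))
    z = tl.histogram(x, N)                     # 1D block of N int32 bin counts
    bias = tl.full([M, N], 1, dtype=tl.int32)
    biased = z + bias                          # z must broadcast against a 2D block
    tl.store(z_ptr + tl.arange(0, N), z)


M, N = 256, 32
x = torch.randint(0, N, (M, ), dtype=torch.int32, device="cuda")
z = torch.empty(N, dtype=torch.int32, device="cuda")
hist_broadcast_kernel[(1, )](x, z, M=M, N=N)
assert torch.equal(z, torch.bincount(x, minlength=N).to(torch.int32))
```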

python/triton/language/core.py

Lines changed: 2 additions & 0 deletions
@@ -609,6 +609,8 @@ def __init__(self, element_ty: dtype, shape: List):
         # Note that block_type's shape is a list of int
         # while tensor's shape is a list of constexpr.

+        assert (isinstance(shape, list))
+
         # shape can be empty ([]) when an input is a 0D tensor.
         self.shape = _unwrap_shape(shape)
         if not self.shape:

python/triton/language/semantic.py

Lines changed: 1 addition & 1 deletion
@@ -1679,7 +1679,7 @@ def associative_scan(inputs: Sequence[tl.tensor], axis: int, region_builder_fn,
 def histogram(input: tl.tensor, num_bins: int, builder: ir.builder) -> tl.tensor:
     assert len(input.shape) == 1, "histogram only supports 1D input"
     assert input.dtype.is_int(), "histogram only supports integer input"
-    return tl.tensor(builder.create_histogram(input.handle, num_bins), tl.block_type(tl.int32, (num_bins, )))
+    return tl.tensor(builder.create_histogram(input.handle, num_bins), tl.block_type(tl.int32, [num_bins]))


 ##

python/triton/runtime/interpreter.py

Lines changed: 1 addition & 1 deletion
@@ -728,7 +728,7 @@ def check_tensor(self, input):

     def to_tensor(self, ret, dtype):
         if hasattr(ret, "shape") and ret.shape:
-            ret_type = tl.block_type(dtype, ret.shape)
+            ret_type = tl.block_type(dtype, list(ret.shape))
         else:
             ret = np.array([ret]).astype(_get_np_dtype(dtype))
             ret_type = dtype
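
The interpreter-side fix follows the same rule. A tiny pure-NumPy illustration (no Triton required) of why the explicit conversion is needed: `ndarray.shape` is a tuple, while the stricter `block_type` constructor now expects a list.

```python
import numpy as np

ret = np.zeros((8, 16), dtype=np.int32)
print(type(ret.shape))  # <class 'tuple'>
print(list(ret.shape))  # [8, 16], the form block_type now expects
```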
