
Commit

Ensure Consistent Use of List Type for block_type Shape Argument in semantic.py and interpreter.py (#5128)

## Description:

This pull request fixes a type-compatibility issue in the block_type
constructor across semantic.py and interpreter.py: the shape parameter,
which the constructor expects as a List, was being passed as a Tuple at
several call sites.

This PR is intended to fix issues identified in #4860.
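To see why the mismatch matters, here is a minimal, hypothetical stand-in (not Triton's actual class) that enforces the same list-only shape check that `block_type` performs:

```python
# Hypothetical stand-in for tl.block_type, illustrating the list-only
# shape check that rejects tuples at construction time.

class BlockType:
    """Minimal sketch: shape must be a list, mirroring block_type's assert."""

    def __init__(self, element_ty, shape):
        assert isinstance(shape, list), "shape must be a list"
        self.element_ty = element_ty
        self.shape = shape

# A list is accepted:
ok = BlockType("int32", [8])

# A tuple trips the assertion, which is the failure mode this PR fixes:
try:
    BlockType("int32", (8,))
except AssertionError as e:
    print("rejected:", e)  # rejected: shape must be a list
```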

### Summary of Changes:

1. semantic.py - Updates to the histogram and device_assert functions:

- histogram: the shape parameter was previously passed to the block_type
constructor as a Tuple. It is now passed as a List, matching the
constructor's requirement and ensuring type consistency.
- device_assert: the cond_ty and cond assignments have been updated in
the same way. The block_type constructor and the create_splat call each
previously received a Tuple; both now receive a List.

2. interpreter.py - Adjustment in ReduceOps.sum Method:

- In ReduceOps.sum, the to_tensor function was called with a
numpy.ndarray whose shape attribute is a Tuple. That shape is now
converted to a List before being passed to block_type.
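The interpreter change rests on the fact that `numpy.ndarray.shape` is always a tuple; a short illustration of the conversion applied before the value reaches `block_type`:

```python
import numpy as np

arr = np.zeros((4, 8), dtype=np.int32)

# ndarray.shape is a tuple, which block_type's list check would reject.
shape = arr.shape
assert isinstance(shape, tuple)

# The fix converts it to a list before constructing the block type.
shape_list = list(shape)
assert isinstance(shape_list, list)
assert shape_list == [4, 8]
```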

### New Test Addition:

Added a new test to cover scenarios where a histogram operation is
followed by a broadcasting operation.

Co-authored-by: maxim.m66 <[email protected]>
Co-authored-by: Maxim <[email protected]>
Co-authored-by: peterbell10 <[email protected]>
4 people authored Nov 13, 2024
1 parent a4a490b commit 385671e
Showing 4 changed files with 7 additions and 2 deletions.
3 changes: 3 additions & 0 deletions python/test/unit/language/test_core.py
@@ -2489,6 +2489,9 @@ def histogram_kernel(x_ptr, z_ptr, M: tl.constexpr, N: tl.constexpr):
     offset2 = tl.arange(0, N)
     x = tl.load(x_ptr + offset1)
     z = tl.histogram(x, N)
+    bias = tl.full([M, N], 1, dtype=tl.int32)
+    # check that histogram produces object compatible with broadcasting
+    biased = z + bias
     tl.store(z_ptr + offset2, z)

torch.manual_seed(17)
2 changes: 2 additions & 0 deletions python/triton/language/core.py
@@ -609,6 +609,8 @@ def __init__(self, element_ty: dtype, shape: List):
         # Note that block_type's shape is a list of int
         # while tensor's shape is a list of constexpr.

+        assert (isinstance(shape, list))
+
         # shape can be empty ([]) when an input is a 0D tensor.
         self.shape = _unwrap_shape(shape)
         if not self.shape:
2 changes: 1 addition & 1 deletion python/triton/language/semantic.py
@@ -1679,7 +1679,7 @@ def associative_scan(inputs: Sequence[tl.tensor], axis: int, region_builder_fn,
 def histogram(input: tl.tensor, num_bins: int, builder: ir.builder) -> tl.tensor:
     assert len(input.shape) == 1, "histogram only supports 1D input"
     assert input.dtype.is_int(), "histogram only supports integer input"
-    return tl.tensor(builder.create_histogram(input.handle, num_bins), tl.block_type(tl.int32, (num_bins, )))
+    return tl.tensor(builder.create_histogram(input.handle, num_bins), tl.block_type(tl.int32, [num_bins]))


##
2 changes: 1 addition & 1 deletion python/triton/runtime/interpreter.py
@@ -728,7 +728,7 @@ def check_tensor(self, input):

     def to_tensor(self, ret, dtype):
         if hasattr(ret, "shape") and ret.shape:
-            ret_type = tl.block_type(dtype, ret.shape)
+            ret_type = tl.block_type(dtype, list(ret.shape))
         else:
             ret = np.array([ret]).astype(_get_np_dtype(dtype))
             ret_type = dtype
