-
Notifications
You must be signed in to change notification settings - Fork 1.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Ensure Consistent Use of List Type for block_type Shape Argument in semantic.py and interpreter.py #5128
Merged
Conversation
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
peterbell10
reviewed
Nov 12, 2024
simonidaa
commented
Nov 12, 2024
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I agree, running pre-commit produced this formatting
peterbell10
approved these changes
Nov 13, 2024
Luosuu
pushed a commit
to Luosuu/triton
that referenced
this pull request
Nov 13, 2024
…emantic.py and interpreter.py (triton-lang#5128) ## Description: This pull request addresses a type compatibility issue in the block_type constructor across semantic.py and interpreter.py, specifically where the shape parameter, expected as a List, was being provided as a Tuple in certain cases. This PR is intended to fix issues identified in triton-lang#4860. ### Summary of Changes: 1. semantic.py - Updates in histogram Function and device_assert: - histogram Function: The shape parameter was previously passed as a Tuple when calling the block_type constructor. This has now been updated to a List to align with the constructor's requirements and ensure type consistency. - device_assert Function: The device_assert function’s cond_ty and cond assignments have been updated similarly. The block_type constructor and create_splat call were previously provided a Tuple; this has been updated to a List in both cases. 2. interpreter.py - Adjustment in ReduceOps.sum Method: - In ReduceOps.sum, the to_tensor function was called with a numpy.ndarray whose shape attribute is a Tuple. This shape is now converted to a List before passing it to block_type. ### New Test Addition: Added a new test to cover scenarios where a histogram operation is followed by a broadcasting operation. Co-authored-by: maxim.m66 <[email protected]> Co-authored-by: Maxim <[email protected]> Co-authored-by: peterbell10 <[email protected]>
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Description:
This pull request addresses a type compatibility issue in the block_type constructor across semantic.py and interpreter.py, specifically where the shape parameter, expected as a List, was being provided as a Tuple in certain cases.
This PR is intended to fix issues identified in #4860.
Summary of Changes:
New Test Addition:
Added a new test to cover scenarios where a histogram operation is followed by a broadcasting operation.
New contributor declaration
I am not making a trivial change, such as fixing a typo in a comment.
I have written a PR description following these
rules.
I have run
pre-commit run --from-ref origin/main --to-ref HEAD
.Select one of the following.
/test
forlit
tests/unittest
for C++ tests/python/test
for end-to-end testsFILL THIS IN
.Select one of the following.
lit
tests.lit
tests I have added follow these best practices,including the "tests should be minimal" section. (Usually running Python code
and using the instructions it generates is not minimal.)