-
Notifications
You must be signed in to change notification settings - Fork 443
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Onnx op topk #2305
base: main
Are you sure you want to change the base?
Onnx op topk #2305
Conversation
Codecov ReportAttention: Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## main #2305 +/- ##
==========================================
- Coverage 85.79% 85.43% -0.36%
==========================================
Files 754 766 +12
Lines 95189 98065 +2876
==========================================
+ Hits 81671 83786 +2115
- Misses 13518 14279 +761 ☔ View full report in Codecov by Sentry. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the addition! 🙏
Overall, implementation looks good! I have a few comments below.
Anytime had fun working on it. Also I finished making the changes. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sorry for the delay of the follow-up review 😅 been busy the past couple of days.
The changes look good to me! But we can remove the tensor API changes and ONNX op config for topk largest (see my comments)
let largest = match node.attrs.get("largest") { | ||
Some(val) => val.clone().into_i64(), | ||
_ => 1, | ||
}; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If you're not checking for "k" as the second input of the node (for opsets 10, 11) and just adding support for opset 1, then we don't need to check for the "largest" attribute here. It's only present in the later version 11 of the op.
So we can remove this from the config and node.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
sure thing will remove and resubmit tonight
@@ -726,20 +726,53 @@ where | |||
} | |||
|
|||
/// Returns the `k` largest elements of the given input tensor along a given dimension. | |||
pub fn topk(self, k: usize, dim: usize) -> Tensor<B, D, K> { | |||
pub fn topk(self, k: usize, dim: usize, largest: Option<usize>) -> Tensor<B, D, K> { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Since we're only adding support for opset 1, the "largest" option is not necessary. We can remove the changes to the tensor API.
In any case, it would have been preferable to add it as another method, similar to sort
& sort_descending
. Something like topk
and topk_smallest
. But not required for this PR 🙂
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
sounds good will remove the tensor api changes as well
@@ -107,6 +107,7 @@ fn main() { | |||
.input("tests/sum/sum_int.onnx") | |||
.input("tests/tanh/tanh.onnx") | |||
.input("tests/tile/tile.onnx") | |||
.input("tests/top_k/top_k.onnx") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks the CI caught something I missed! This file doesn't exist anymore with your changes :)
This PR has been marked as stale because it has not been updated for over a month |
Pull Request Template
Checklist
run-checks all
script has been executed.test_rotary_encoding_forward
failed but passed when I ran it via the uiRelated Issues/PRs
#1714
Provide links to relevant issues and dependent PRs.
#1714
Changes
fn tanh_should_not_have_numerical_bugs_on_macos()
only run on macosIOEntry::Node
in theinput_names_map
. This was important as before it was never incremented with the number of outputs so you would always have _1 as the output name suffix even if there were x >= 2 outputs.Summarize the problem being addressed and your solution.
Testing
ran the below:
cargo test
./run-checks.sh all
Describe how these changes have been tested.
instructions listed here and here