Add argmax/argmin as symbolic reduce ops with Z3 reasoning#316
Add argmax/argmin as symbolic reduce ops with Z3 reasoning#316
Conversation
Register argmax and argmin in REDUCE_OPS, _SUPPORTED_OPS, and _Z3_BUILDERS. The Z3 implementation uses an If-chain that tracks the best value and index, avoiding the concretize() fallback which would fail for nodes that only implement _to_z3_impl (e.g. where). GPC: support-tl-max-1-argmax-argmin-z3
f9ee11c to
2c8c0c7
Compare
Sanitizer Performance Benchmark
Iterations: 1 warmup + 40 measured |
|
@codex review. |
There was a problem hiding this comment.
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: 9e7ad4954a
ℹ️ About Codex in GitHub
Your team has set up Codex to review pull requests in this repo. Reviews are triggered when you
- Open a pull request for review
- Mark a draft as ready
- Comment "@codex review".
If Codex has suggestions, it will comment; otherwise it will react with 👍.
Codex can also answer questions or update the PR. Try commenting "@codex address that feedback".
| for i in range(1, len(arr)): | ||
| is_better = arr[i] > best_val | ||
| best_idx = If(is_better, IntVal(i), best_idx) | ||
| best_val = If(is_better, arr[i], best_val) | ||
| return best_idx, constraints |
There was a problem hiding this comment.
Respect axis in argmax/argmin reduction logic
This loop always scans the flattened arr and returns a single global index, but ReduceSymbolicExpr accepts axis/keepdims and computes shaped outputs from them. For any multi-dimensional input with axis set (the normal argmax/argmin case), the result is incorrect: it returns a flat index over all elements instead of per-axis indices, so symbolic checks can validate the wrong element and wrong index range.
Useful? React with 👍 / 👎.
Code reviewFound 1 issue:
The docstring for triton-viz/tests/unit/test_sanitizer.py Lines 99 to 106 in 2c8c0c7 🤖 Generated with Claude Code - If this code review was useful, please react with 👍. Otherwise, react with 👎. |
The docstring incorrectly referenced a non-existent "old concretize() fallback". Before this branch, argmax/argmin were not registered ops at all, so no such fallback ever existed on main.
fixed. |
|
|
||
|
|
||
| @pytest.mark.parametrize("op,np_op", [("argmax", np.argmax), ("argmin", np.argmin)]) | ||
| def test_reduce_argmax_argmin_z3_through_where(op: str, np_op): |
There was a problem hiding this comment.
Do you have any end to end examples?
| return IntVal(0), constraints | ||
| best_val = arr[0] | ||
| best_idx: Z3Expr = IntVal(0) | ||
| for i in range(1, len(arr)): |
There was a problem hiding this comment.
This is confusing to me. If arr is a SymbolicExpr, what's gonna happen?
Resolve conflict in ReduceSymbolicExpr.__init__: keep argmax/argmin dtype logic from this branch and add self.shape from main.
Summary
argmaxandargmininREDUCE_OPS,_SUPPORTED_OPS, and_Z3_BUILDERSconcretize()fallback_to_z3_impl(e.g.where)Test plan
pytest tests/unit/test_sanitizer.py::test_reduce_argmax_argmin_z3_through_where -xvs[FEAT] Add argmax/argmin as symbolic reduce ops with Z3 reasoning
Register argmax and argmin in REDUCE_OPS, _SUPPORTED_OPS, and
_Z3_BUILDERS. The Z3 implementation uses an If-chain that tracks
the best value and index, avoiding the concretize() fallback which
would fail for nodes that only implement _to_z3_impl (e.g. where).
PR chain