Add argmax/argmin as symbolic reduce ops with Z3 reasoning #316

Open
mark14wu wants to merge 4 commits into main from support-tl-max-1-argmax-argmin-z3

Conversation

mark14wu (Collaborator) commented Mar 6, 2026

Summary

  • Register argmax and argmin in REDUCE_OPS, _SUPPORTED_OPS, and _Z3_BUILDERS
  • Z3 implementation uses an If-chain that tracks best value and index, avoiding concretize() fallback
  • Fixes failure when input flows through nodes that only have _to_z3_impl (e.g. where)
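
The If-chain idea in the summary can be sketched without Z3 at all. The snippet below is a minimal illustration, not the PR's code: nested tuples stand in for Z3 expressions, and `build_argmax_chain` and `evaluate` are hypothetical helper names. In the real implementation the same fold would presumably build nested Z3 `If` terms instead of tuples.

```python
# Minimal stand-in for a symbolic If-chain (hypothetical helpers, not the PR's code).
# Expressions are nested tuples; symbolic variables are plain strings.

def build_argmax_chain(names):
    """Fold over symbolic element names, tracking the best value and its index."""
    best_val = names[0]
    best_idx = 0
    for i in range(1, len(names)):
        # Strict '>' means an equal later element never replaces the current best.
        is_better = ("gt", names[i], best_val)
        best_idx = ("if", is_better, i, best_idx)
        best_val = ("if", is_better, names[i], best_val)
    return best_idx

def evaluate(expr, env):
    """Evaluate an expression tree under a concrete assignment (name -> int)."""
    if isinstance(expr, str):
        return env[expr]
    if isinstance(expr, int):
        return expr
    tag = expr[0]
    if tag == "gt":
        return evaluate(expr[1], env) > evaluate(expr[2], env)
    if tag == "if":
        branch = expr[2] if evaluate(expr[1], env) else expr[3]
        return evaluate(branch, env)
    raise ValueError(f"unknown node: {tag!r}")

chain = build_argmax_chain(["x0", "x1", "x2"])
result = evaluate(chain, {"x0": 3, "x1": 7, "x2": 7})  # ties resolve to the first maximum
```

The chain is built once from the symbolic inputs and never needs concrete values, which is why this style of reduction can stay on the symbolic path. For argmin the comparison would flip to a strict less-than.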

Test plan

  • pytest tests/unit/test_sanitizer.py::test_reduce_argmax_argmin_z3_through_where -xvs

[FEAT] Add argmax/argmin as symbolic reduce ops with Z3 reasoning

Register argmax and argmin in REDUCE_OPS, _SUPPORTED_OPS, and
_Z3_BUILDERS. The Z3 implementation uses an If-chain that tracks
the best value and index, avoiding the concretize() fallback which
would fail for nodes that only implement _to_z3_impl (e.g. where).

PR chain

  1. 👉 Add argmax/argmin as symbolic reduce ops with Z3 reasoning #316 👈 YOU ARE HERE
  2. Add concretize() fallback for ReduceSymbolicExpr #317
  3. Support tl.max/tl.min with return_indices=True #318

GPC: support-tl-max-1-argmax-argmin-z3

github-actions bot commented Mar 6, 2026

Sanitizer Performance Benchmark

| Benchmark | main (min) | PR (min) | Change |
|---|---|---|---|
| simple_load_store | 0.005s | 0.005s | +0.7% |
| gemm | 0.020s | 0.021s | +1.4% |
| gemm_oob | 0.021s | 0.021s | +0.3% |
| indirect_load | 0.066s | 0.067s | +0.8% |
| nested_loop | 0.022s | 0.022s | +0.5% |
| block_pointer_loop_advance | 0.007s | 0.007s | +0.4% |
| liger_jsd | 0.144s | 0.145s | +0.5% |
| flaggems_layernorm | 2.474s | N/A | N/A |
| Total | 2.760s | N/A | N/A |

Iterations: 1 warmup + 40 measured

mark14wu (Collaborator, Author) commented Mar 6, 2026

@codex review.

chatgpt-codex-connector bot left a comment

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: 9e7ad4954a

Comment on lines +1192 to +1196
for i in range(1, len(arr)):
    is_better = arr[i] > best_val
    best_idx = If(is_better, IntVal(i), best_idx)
    best_val = If(is_better, arr[i], best_val)
return best_idx, constraints

P1: Respect axis in argmax/argmin reduction logic

This loop always scans the flattened arr and returns a single global index, but ReduceSymbolicExpr accepts axis/keepdims and computes shaped outputs from them. For any multi-dimensional input with axis set (the normal argmax/argmin case), the result is incorrect: it returns a flat index over all elements instead of per-axis indices, so symbolic checks can validate the wrong element and wrong index range.
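
To make the difference concrete, here is a small illustration using plain nested lists in place of a 2-D tensor. It is only a sketch of the concern raised in this comment; the `argmax` helper and the sample data are made up for the example.

```python
# Illustrative only: plain nested lists stand in for a 2-D tensor.

def argmax(values):
    """Index of the first maximum in a 1-D sequence."""
    best_idx = 0
    for i in range(1, len(values)):
        if values[i] > values[best_idx]:
            best_idx = i
    return best_idx

grid = [[1, 4, 3],
        [7, 2, 9]]

# Scanning the flattened data yields one global index.
flat_index = argmax([v for row in grid for v in row])

# Reducing along the last axis yields one index per row.
per_row = [argmax(row) for row in grid]

# Reducing along the first axis yields one index per column.
per_col = [argmax(list(col)) for col in zip(*grid)]
```

Here `flat_index` is 5, which is not a valid position along either axis of a 2x3 input, while the per-axis results are `[1, 2]` for rows and `[1, 0, 1]` for columns.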

mark14wu (Collaborator, Author) commented Mar 6, 2026

Code review

Found 1 issue:

  1. Misleading test docstring references a non-existent "old concretize() fallback" (factually incorrect documentation)

The docstring for test_reduce_argmax_argmin_z3_through_where claims "the old concretize() fallback would raise NotImplementedError." Before this PR, argmax/argmin were not registered ops at all — calling SymbolicExpr.create("argmax", ...) would raise NotImplementedError("Unsupported reduce op: argmax") from ReduceSymbolicExpr.__init__, not from any concretize() path. The docstring appears to describe intermediate development commits that were squashed and never existed on main.

def test_reduce_argmax_argmin_z3_through_where(op: str, np_op):
    """argmax/argmin should use Z3 symbolic path, not concretize().
    When the input flows through a node that only has _to_z3_impl (like
    ``where``), the old concretize() fallback would raise
    NotImplementedError. The Z3 If-chain implementation avoids this by
    staying on the symbolic path end-to-end.
    """
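
The failure mode described in this review (an unregistered op rejected at construction time, before any concretize() path could run) can be sketched with a toy class. Everything here is a hypothetical stand-in: `ToyReduceExpr` is not the project's `ReduceSymbolicExpr`, and the contents and type of `REDUCE_OPS` are assumed for illustration.

```python
# Hypothetical stand-in: the real ReduceSymbolicExpr has more fields and logic.
REDUCE_OPS = {"sum", "max", "min"}  # assumed contents before this PR

class ToyReduceExpr:
    def __init__(self, op, operand):
        # Unregistered ops are rejected here, before any evaluation happens.
        if op not in REDUCE_OPS:
            raise NotImplementedError(f"Unsupported reduce op: {op}")
        self.op = op
        self.operand = operand

try:
    ToyReduceExpr("argmax", [1, 2, 3])
    message = None
except NotImplementedError as exc:
    message = str(exc)

# Registering the op is what makes construction succeed.
REDUCE_OPS.add("argmax")
registered = ToyReduceExpr("argmax", [1, 2, 3])
```

Under these assumptions the error comes from the constructor's registry check, which matches the review's point that no fallback path was involved.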

The docstring incorrectly referenced a non-existent "old concretize()
fallback". Before this branch, argmax/argmin were not registered ops
at all, so no such fallback ever existed on main.
mark14wu (Collaborator, Author) commented Mar 6, 2026

fixed.



@pytest.mark.parametrize("op,np_op", [("argmax", np.argmax), ("argmin", np.argmin)])
def test_reduce_argmax_argmin_z3_through_where(op: str, np_op):
Member

Do you have any end-to-end examples?

    return IntVal(0), constraints
best_val = arr[0]
best_idx: Z3Expr = IntVal(0)
for i in range(1, len(arr)):
Member

This is confusing to me. If arr is a SymbolicExpr, what's gonna happen?

Resolve conflict in ReduceSymbolicExpr.__init__: keep argmax/argmin
dtype logic from this branch and add self.shape from main.