[REFACTOR] Separate dtype and shape in SymbolicExpr #322
SymbolicExpr.dtype previously bundled scalar type and shape into tl.block_type, causing bugs when subclasses (like CastSymbolicExpr) set dtype without preserving shape. This mirrors Triton's own interpreter.py design where TensorHandle stores scalar-only dtype.

- Add _decompose_dtype() helper to split block_type into (scalar, shape)
- Add self.shape field to base SymbolicExpr; dtype now always stores scalar types (never block_type)
- Add block_dtype property to reconstruct full block_type when needed
- Update all 20+ subclasses to maintain separate dtype/shape
- Fix AddPtrSymbolicExpr._to_z3_impl: remove .scalar unwrap
- Fix _collect_tensor_base: skip block pointer consts
- Add TensorPointerSymbolicExpr._resolve_block_shape for block ptr loads
- Update patch.py to use block_dtype for tl.core.tensor() reconstruction
- Add test_dtype_is_always_scalar invariant test
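To make the split described in this commit message concrete, here is a minimal, self-contained sketch. It does not import Triton: ScalarType and BlockType are hypothetical stand-ins for a scalar Triton dtype and tl.block_type, and decompose_dtype and Expr are illustrative names, not the actual code in symbolic_engine.py.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class ScalarType:
    """Hypothetical stand-in for a scalar Triton dtype such as tl.float32."""
    name: str


@dataclass(frozen=True)
class BlockType:
    """Hypothetical stand-in for tl.block_type: element type plus shape."""
    scalar: ScalarType
    shape: tuple


def decompose_dtype(dtype):
    """Split a possibly-block dtype into (scalar, shape); scalars get shape ()."""
    if isinstance(dtype, BlockType):
        return dtype.scalar, tuple(dtype.shape)
    return dtype, ()


class Expr:
    """Toy expression node: dtype is always scalar, shape is a separate field."""

    def __init__(self, dtype):
        self.dtype, self.shape = decompose_dtype(dtype)

    @property
    def block_dtype(self):
        # Rebuild the bundled type only when a caller asks for it.
        if self.shape:
            return BlockType(self.dtype, self.shape)
        return self.dtype


f32 = ScalarType("fp32")
node = Expr(BlockType(f32, (4, 8)))
scalar_node = Expr(f32)
```

In this sketch, node.dtype is the scalar f32 and node.shape is (4, 8), while node.block_dtype reassembles the bundled type on demand. A node built from a plain scalar has an empty shape, and its block_dtype is just the scalar.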
Sanitizer Performance Benchmark
Iterations: 1 warmup + 40 measured
Inline the dtype decomposition logic at each usage site and replace the public block_dtype property with a private _full_dtype() method.
Jokeren
left a comment
Why not use Triton's dtype? Can you give me a concrete example?
The construction order of a tl.core.tensor is like:
(See triton_viz/core/patch.py:179-188.) So we cannot simply create a reference to the tensor's dtype and shape; we need to create and maintain one first in our
This should be mentioned in the commit message.
Summary
- Separate dtype and shape in SymbolicExpr: self.dtype now always stores scalar types (never tl.block_type), and self.shape is an explicit tuple[int, ...] field. This mirrors Triton's own interpreter.py design where TensorHandle stores scalar-only dtype.
- Previously, .to() / cast_impl would overwrite dtype with a scalar type, silently dropping shape information and causing downstream crashes (e.g. in tl.sum after tl.load(block_ptr).to(fp32)).
- Add block_dtype property to reconstruct full tl.block_type on demand (used by patch.py and concretize()).

Changes

- symbolic_engine.py: add _decompose_dtype() helper; add self.shape field and block_dtype property to base class; update all 20+ subclasses; add _resolve_block_shape() for block pointer loads; fix AddPtrSymbolicExpr._to_z3_impl .scalar unwrap
- patch.py: use block_dtype to reconstruct full dtype for tl.core.tensor()
- sanitizer.py: fix _collect_tensor_base to skip block pointer consts (check not node.shape)
- test_sanitizer.py (unit): add test_dtype_is_always_scalar invariant test
- test_sanitizer.py (e2e): test_reduce_symbolic_core_dtype, test_reduce_symbolic_nonetype, test_expand_dims_scalar_attr from fix/cast-block-type-preserve

Test plan

- pytest tests/unit/test_sanitizer.py: 57 passed
- pytest tests/end_to_end/test_sanitizer.py: 24 passed
- pytest tests/: 210 passed, 4 skipped

Supersedes

Closes #313 (fix/cast-block-type-preserve). This refactor addresses the root cause rather than patching the symptom.
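
The cast failure mode that motivates this PR can be illustrated with a small toy model. This is only a sketch under stated assumptions: BlockType, BundledExpr, and SplitExpr are hypothetical stand-ins that do not use Triton or the real SymbolicExpr classes, and dtypes are plain strings for brevity.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class BlockType:
    """Hypothetical stand-in for tl.block_type: element type name plus shape."""
    scalar: str
    shape: tuple


class BundledExpr:
    """Old style: dtype may be a BlockType, so the shape lives inside dtype."""

    def __init__(self, dtype):
        self.dtype = dtype

    def to(self, scalar):
        out = BundledExpr(self.dtype)
        out.dtype = scalar  # overwrites the BlockType, so the shape is lost
        return out


class SplitExpr:
    """New style: dtype is always scalar, shape is its own field."""

    def __init__(self, scalar, shape=()):
        self.dtype = scalar
        self.shape = tuple(shape)

    def to(self, scalar):
        # Only the element type changes; the shape carries over untouched.
        return SplitExpr(scalar, self.shape)


old = BundledExpr(BlockType("fp16", (16, 16))).to("fp32")
new = SplitExpr("fp16", (16, 16)).to("fp32")

# After the cast, the bundled node has no shape left to query,
# while the split node still reports the original block shape.
old_shape = getattr(old.dtype, "shape", None)
new_shape = new.shape
```

With the bundled layout, any consumer that needs the shape after a cast has nothing to read it from. With the split layout, the cast cannot touch the shape, which is also the property an invariant test such as test_dtype_is_always_scalar can check by asserting that dtype is never a block type.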