
[REFACTOR] Separate dtype and shape in SymbolicExpr#322

Merged
mark14wu merged 3 commits into main from refactor/separate-dtype-shape on Mar 9, 2026

Conversation


@mark14wu mark14wu commented Mar 8, 2026

Summary

  • Separates dtype and shape in SymbolicExpr: self.dtype now always stores scalar types (never tl.block_type), and self.shape is an explicit tuple[int, ...] field. This mirrors Triton's own interpreter.py design where TensorHandle stores scalar-only dtype.
  • Fixes the CastSymbolicExpr bug where .to() / cast_impl would overwrite dtype with a scalar type, silently dropping shape information and causing downstream crashes (e.g. in tl.sum after tl.load(block_ptr).to(fp32)).
  • Adds block_dtype property to reconstruct full tl.block_type on demand (used by patch.py and concretize()).
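The split described above can be sketched in a minimal, self-contained model. `ScalarType` and `BlockType` here are hypothetical stand-ins for Triton's `tl.dtype` and `tl.block_type` (the real `_decompose_dtype()` and `block_dtype` live in `symbolic_engine.py`); the point is the invariant: `dtype` stays scalar, `shape` is a separate tuple, and the full block type is reconstructed on demand.

```python
from dataclasses import dataclass
from typing import Tuple

# Hypothetical stand-ins for tl.dtype / tl.block_type.
@dataclass(frozen=True)
class ScalarType:
    name: str

@dataclass(frozen=True)
class BlockType:
    scalar: ScalarType
    shape: Tuple[int, ...]

def _decompose_dtype(dtype):
    """Split a possibly-block dtype into (scalar_dtype, shape)."""
    if isinstance(dtype, BlockType):
        return dtype.scalar, tuple(dtype.shape)
    return dtype, ()

class SymbolicExpr:
    def __init__(self, dtype, shape=None):
        scalar, inferred = _decompose_dtype(dtype)
        self.dtype = scalar  # invariant: always a scalar type
        self.shape = inferred if shape is None else tuple(shape)

    @property
    def block_dtype(self):
        # Reconstruct the full block type on demand.
        if self.shape:
            return BlockType(self.dtype, self.shape)
        return self.dtype
```

Because `shape` is no longer buried inside `dtype`, any subclass that reassigns `dtype` (e.g. a cast) can no longer accidentally erase the shape.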

Changes

| File | What changed |
| --- | --- |
| symbolic_engine.py | Add `_decompose_dtype()` helper; add `self.shape` field and `block_dtype` property to the base class; update all 20+ subclasses; add `_resolve_block_shape()` for block pointer loads; fix `AddPtrSymbolicExpr._to_z3_impl` `.scalar` unwrap |
| patch.py | Use `block_dtype` to reconstruct the full dtype for `tl.core.tensor()` |
| sanitizer.py | Fix `_collect_tensor_base` to skip block pointer consts (check `not node.shape`) |
| test_sanitizer.py (unit) | Update assertions for scalar dtype; add `test_dtype_is_always_scalar` invariant test |
| test_sanitizer.py (e2e) | Add `test_reduce_symbolic_core_dtype`, `test_reduce_symbolic_nonetype`, `test_expand_dims_scalar_attr` from `fix/cast-block-type-preserve` |
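The `CastSymbolicExpr` fix is easiest to see side by side. This is an illustrative toy model, not the real classes: in the old design, `dtype` bundled scalar type and shape, so a cast that assigned a scalar dtype silently dropped the shape; with `shape` as its own field, a cast only touches `dtype`.

```python
from dataclasses import dataclass
from typing import Tuple

@dataclass(frozen=True)
class BlockType:  # stand-in for tl.block_type
    scalar: str
    shape: Tuple[int, ...]

# Before the refactor: dtype may be a BlockType carrying the shape.
class OldExpr:
    def __init__(self, dtype):
        self.dtype = dtype

    def to(self, scalar_dtype):
        out = OldExpr(self.dtype)
        out.dtype = scalar_dtype  # bug: the shape inside BlockType is lost
        return out

# After the refactor: shape is a separate field, so casts preserve it.
class NewExpr:
    def __init__(self, dtype, shape=()):
        self.dtype = dtype  # always scalar
        self.shape = shape

    def to(self, scalar_dtype):
        return NewExpr(scalar_dtype, self.shape)  # shape carried over
```

In the old model, `OldExpr(BlockType("fp16", (16,))).to("fp32")` ends up with a bare `"fp32"` dtype and no shape anywhere, which is exactly the `tl.load(block_ptr).to(fp32)` crash the PR describes.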

Test plan

  • pytest tests/unit/test_sanitizer.py — 57 passed
  • pytest tests/end_to_end/test_sanitizer.py — 24 passed
  • pytest tests/ — 210 passed, 4 skipped

Supersedes

Closes #313 (fix/cast-block-type-preserve) — this refactor addresses the root cause rather than patching the symptom.

SymbolicExpr.dtype previously bundled scalar type and shape into
tl.block_type, causing bugs when subclasses (like CastSymbolicExpr)
set dtype without preserving shape. This mirrors Triton's own
interpreter.py design where TensorHandle stores scalar-only dtype.

- Add _decompose_dtype() helper to split block_type into (scalar, shape)
- Add self.shape field to base SymbolicExpr; dtype now always stores
  scalar types (never block_type)
- Add block_dtype property to reconstruct full block_type when needed
- Update all 20+ subclasses to maintain separate dtype/shape
- Fix AddPtrSymbolicExpr._to_z3_impl: remove .scalar unwrap
- Fix _collect_tensor_base: skip block pointer consts
- Add TensorPointerSymbolicExpr._resolve_block_shape for block ptr loads
- Update patch.py to use block_dtype for tl.core.tensor() reconstruction
- Add test_dtype_is_always_scalar invariant test
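A sketch of what the `test_dtype_is_always_scalar` invariant check could look like (the real test lives in the unit `test_sanitizer.py`; `Node` and `BlockType` here are hypothetical stand-ins): walk every node of an expression tree and fail if any node's `dtype` is a block type.

```python
class BlockType:  # stand-in for tl.block_type
    pass

class Node:  # stand-in for a SymbolicExpr subtree
    def __init__(self, dtype, children=()):
        self.dtype = dtype
        self.children = list(children)

def assert_dtype_is_scalar(expr):
    """Walk the tree; fail if any node stores a block type in .dtype."""
    stack = [expr]
    while stack:
        node = stack.pop()
        assert not isinstance(node.dtype, BlockType), (
            f"{type(node).__name__}.dtype must be scalar, got {node.dtype!r}"
        )
        stack.extend(node.children)
```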

github-actions bot commented Mar 8, 2026

Sanitizer Performance Benchmark

| Benchmark | main (min) | PR (min) | Change |
| --- | --- | --- | --- |
| simple_load_store | 0.005s | 0.005s | +0.6% |
| gemm | 0.023s | 0.023s | -0.1% |
| gemm_oob | 0.024s | 0.024s | +0.4% |
| indirect_load | 0.076s | 0.076s | -0.2% |
| nested_loop | 0.024s | 0.025s | +1.5% |
| block_pointer_loop_advance | 0.007s | 0.007s | -0.0% |
| liger_jsd | 0.149s | 0.150s | +0.2% |
| flaggems_layernorm | 2.921s | 2.795s | -4.3% |
| Total | 3.230s | 3.105s | -3.9% |

Iterations: 1 warmup + 40 measured

mark14wu added 2 commits March 9, 2026 00:13
Inline the dtype decomposition logic at each usage site and replace
the public block_dtype property with a private _full_dtype() method.

@Jokeren Jokeren left a comment


Why not using triton's dtype? Can you provide me a concrete example?


mark14wu commented Mar 9, 2026

Why not using triton's dtype? Can you provide me a concrete example?

The construction order of a tl.core.tensor is as follows:

  1. Create the SymbolicExpr first (at this point, no corresponding Triton tensor exists yet)
  2. Derive ret_dtype from the SymbolicExpr's dtype and shape
  3. Construct the Triton tensor with tl.core.tensor(symbolic_ret, ret_dtype)

(See triton_viz/core/patch.py:179-188)

```python
symbolic_ret = self.callbacks.op_overrider(
    args[0].handle, *args[1:], **kwargs
)
_shape = getattr(symbolic_ret, "shape", ())
_dtype = getattr(symbolic_ret, "dtype", None)
if _shape and _dtype:
    ret_dtype = tl.block_type(_dtype, list(_shape))
else:
    ret_dtype = _dtype or args[0].dtype
ret = tl.core.tensor(symbolic_ret, ret_dtype)
```

So we cannot simply hold a reference to the Triton tensor's dtype and shape; the SymbolicExpr must create and maintain its own dtype and shape first.


Jokeren commented Mar 9, 2026

> The construction order of a tl.core.tensor is like: […] So we cannot simply create a reference to the tensor's dtype and shape, but we need to create and maintain one first in our SymbolicExpr.

This should be mentioned in the commit message

@mark14wu mark14wu merged commit fe1f620 into main Mar 9, 2026
4 checks passed
@mark14wu mark14wu deleted the refactor/separate-dtype-shape branch March 9, 2026 22:13