🐛 Bug
Shape-check logic in broadcast currently isn't translated properly to the prologue trace. In the examples below, these logic checks should be plumbed through to the prologue so that we reject a cache hit when it isn't appropriate to reuse the cached trace.
I think the root cause is the limited constraint support we have today. Right now propagate_constraints only inserts constraints on NumberProxy objects fed from the prologue trace into the compute trace; TensorProxy.shape isn't handled explicitly that way.
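For illustration only (this is not thunder's actual prologue API, and the prologue function below is a hypothetical sketch), the prologue would conceptually need to re-evaluate the same shape predicate that the compute trace specialized on:

# Hypothetical sketch: what a shape constraint in the prologue would express.
# Today, constraints are only propagated for NumberProxy inputs, so a check
# like this one on TensorProxy.shape never makes it into the prologue.
def prologue(a, b):
    # the compute trace was specialized on this predicate being True;
    # a cache hit is only valid if it still holds for the new inputs
    assert tuple(a.shape) == (2, 2), "shape condition changed; reject cache hit"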
To Reproduce
On a script like this (with foo taken from one of the examples below):
import torch
import thunder

jfoo = thunder.jit(foo, cache="symbolic values")
a = torch.randn(2, 2, device="cuda")
b = torch.randn(1, device="cuda")
out = jfoo(a, b)
print(out.shape)
print("====================")
print("--- run1 ---")
print("\n\tprologue:\n", thunder.last_prologue_traces(jfoo)[0])
print("\n\tcompute:\n", thunder.last_traces(jfoo)[0])
a = torch.randn(2, 1, device="cuda")
b = torch.randn(8, device="cuda")
out = jfoo(a, b)
print(out.shape)
print("cache_hit: ", thunder.cache_hits(jfoo))
e.g. 0
def foo(a, b):
    if tuple(a.shape) == tuple((2, 2)):
        return a
    else:
        return b
The compute trace shows the comparison of each shape element, but that check is not propagated to the prologue, so even though the condition flips on the second call, we still hit the cache and return the wrong result (the second call should return b, but the cached trace still returns a).
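As a sanity check, here is a minimal sketch of the expected eager behavior for e.g. 0 (assuming a CUDA device is available); under the bug, the second jitted call diverges from this:

# expected un-jitted behavior of the e.g. 0 foo
a1, b1 = torch.randn(2, 2, device="cuda"), torch.randn(1, device="cuda")
a2, b2 = torch.randn(2, 1, device="cuda"), torch.randn(8, device="cuda")
print(foo(a1, b1).shape)  # torch.Size([2, 2]): condition holds, returns a
print(foo(a2, b2).shape)  # torch.Size([8]): condition fails, returns b
# Under the bug, the second jitted call instead hits the cache and returns
# a tensor of shape (2, 1), i.e. it still takes the `return a` branch.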
Similarly, there are many checks like this in thunder's broadcast handling, so a simple binary add with the same inputs from the example above also shows a cache hit, even though the broadcast pattern changes on the second call.
e.g. 1
def foo(a, b):
    return a + b
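A minimal sketch of why the broadcast differs between the two calls (plain PyTorch, no thunder involved):

# first call:  (2, 2) + (1,) -> b broadcasts across a, result shape (2, 2)
# second call: (2, 1) + (8,) -> both operands broadcast, result shape (2, 8)
print((torch.randn(2, 2) + torch.randn(1)).shape)  # torch.Size([2, 2])
print((torch.randn(2, 1) + torch.randn(8)).shape)  # torch.Size([2, 8])
# A cached compute trace specialized on the first broadcast pattern is invalid
# for the second call, so the prologue should reject the cache hit.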