unexpected cache hit with symbolic values cache #1525

Open
jjsjann123 opened this issue Dec 6, 2024 · 0 comments
🐛 Bug

The shape-check logic in broadcast currently isn't translated to the prologue trace. In the example below, those checks should be plumbed into the prologue so that a cache hit is rejected when it's not appropriate.

I think the root cause is the limited constraint support we have today: propagate_constraints only inserts constraints on NumberProxy values fed from the prologue trace into the compute trace, and TensorProxy.shape entries aren't explicitly covered by that.
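
To see the failure mode in isolation, here is a toy, hand-rolled sketch (not thunder's implementation) of a rank-keyed symbolic cache: a shape comparison evaluated during the first call is baked into the cached path and never re-checked, so a later call with the same ranks but different sizes hits the cache and runs the wrong branch.

import torch

_cache = {}

def toy_jfoo(a, b):
    key = (a.dim(), b.dim())  # symbolic cache key: ranks only, not concrete sizes
    if key not in _cache:
        # "Tracing": the shape comparison runs once with this call's
        # concrete sizes, and only the taken branch is recorded; the
        # comparison itself never becomes a guard.
        _cache[key] = (lambda a, b: a) if tuple(a.shape) == (2, 2) else (lambda a, b: b)
    return _cache[key](a, b)

print(toy_jfoo(torch.randn(2, 2), torch.randn(1)).shape)  # torch.Size([2, 2])
# Same ranks -> cache hit, but the cached branch is now the wrong one:
print(toy_jfoo(torch.randn(2, 1), torch.randn(8)).shape)  # torch.Size([2, 1]), should be [8]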

To Reproduce

On a script like this (with foo defined as in the examples below):

import torch
import thunder

jfoo = thunder.jit(foo, cache="symbolic values")

a = torch.randn(2, 2, device="cuda")
b = torch.randn(1, device="cuda")
 
out = jfoo(a, b)
print(out.shape)

print("====================")
print("--- run1 ---")
print("\n\tprologue:\n", thunder.last_prologue_traces(jfoo)[0])
print("\n\tcompute:\n", thunder.last_traces(jfoo)[0])

a = torch.randn(2, 1, device="cuda")
b = torch.randn(8, device="cuda")

out = jfoo(a, b)
print(out.shape)
print("cache_hit: ", thunder.cache_hits(jfoo))

Example 0:

def foo(a, b):
    if tuple(a.shape) == tuple((2, 2)):
        return a
    else:
        return b

We get a compute trace like this:

@torch.no_grad()
@no_autocast
def computation(a):
  # a: "cuda:0 f32[[IntegerProxy name=i0, value=2, static=CONSTRAINT.CONSTRAINABLE], [IntegerProxy name=i1, value=2, static=CONSTRAINT.CONSTRAINABLE]]"
  # /volume/thunder_dynamic/bool2.py:6:             if tuple(a.shape) == tuple((2, 2)):
  (i0, i1) = prims.shape(a)
  b3 = prims.eq(i0, 2)  # b3: "bool True"
  b4 = prims.eq(i1, 2)  # b4: "bool True"
  return a

The trace shows the comparison of each shape element, but those checks are not propagated to the prologue, so even though the condition changes on the second call, we still hit the cache and return the wrong result.
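
Conceptually, the prologue produced for this trace should also assert the comparisons the compute trace relied on. A schematic sketch of the missing guards (plain Python, not actual thunder prologue prims):

def prologue(a):
    # Existing metadata checks (schematic): rank, dtype, device.
    assert a.ndim == 2
    i0, i1 = a.shape
    # Missing guards: the traced branch assumed both of these, so a
    # cache hit must be rejected when either no longer holds.
    assert i0 == 2
    assert i1 == 2
    return (a,)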

Similarly, broadcast handling in thunder performs lots of checks like this, so a simple binary add on the same inputs as the example above also shows a cache hit, even though the broadcast behavior changes in the second iteration.

Example 1:

def foo(a, b):
    return a + b
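
For the inputs from the script above, the broadcast result shapes differ between the two calls, so replaying the first compute trace on the second call yields the wrong shape. A quick torch-only check (no thunder involved):

import torch

a1, b1 = torch.randn(2, 2), torch.randn(1)
a2, b2 = torch.randn(2, 1), torch.randn(8)

print((a1 + b1).shape)  # torch.Size([2, 2]); b is broadcast across both dims
print((a2 + b2).shape)  # torch.Size([2, 8]); now a's last dim broadcasts against b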