patching shape query in TensorProxy to avoid prims#1696

Merged
mruberry merged 4 commits into backward_transform_dependency_fix from prims_shape_patch
Jan 27, 2025

Conversation

@jjsjann123
Collaborator

@jjsjann123 jjsjann123 commented Jan 25, 2025

We should avoid calling TensorProxy.shape directly during trace transformations and use TensorProxy._shape instead.

The difference is that shape leaves a prims.shape symbol in the trace when queried under a tracing context, which is the case during the grad transform.

Without this change, the trace is corrupted with a prims.shape call on a non-existent TensorProxy.
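The distinction described above can be illustrated with a minimal sketch. This is not Thunder's actual implementation; MockTensorProxy and the list-based trace are stand-ins for TensorProxy and the real trace machinery, used only to show why querying the public shape property under a tracing context leaves a symbol behind while reading the cached _shape does not.

```python
# Illustrative sketch only (hypothetical names, not Thunder's real API).
_active_trace = None  # symbols get appended here while a trace is being built

class MockTensorProxy:
    def __init__(self, shape):
        self._shape = tuple(shape)  # cached metadata; reading it records nothing

    @property
    def shape(self):
        # Under a tracing context, querying .shape records a prims.shape-like
        # symbol. During a trace transform, that symbol can reference a proxy
        # that no longer exists in the transformed trace, corrupting it.
        if _active_trace is not None:
            _active_trace.append(("prims.shape", self))
        return self._shape

trace = []
_active_trace = trace
t = MockTensorProxy((2, 3))

assert t.shape == (2, 3) and len(trace) == 1   # .shape left a symbol in the trace
assert t._shape == (2, 3) and len(trace) == 1  # ._shape recorded nothing
```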

@jjsjann123 jjsjann123 marked this pull request as ready for review January 27, 2025 17:00
@mruberry
Collaborator

This is really interesting. fyi @t-vi and @IvanYashchuk

As a note for the future, when we're transforming traces, maybe we shouldn't generally be operating within the trace context, to avoid accidentally putting operations into the trace. I guess if we did that then calling .shape on these tensors would throw an error that we're calling a tracing operation outside a trace context? If it doesn't, then maybe it should? Otherwise it would be easy for a developer to accidentally write .shape when they mean ._shape.

I also wonder if, in the future, we should make these queries more explicit. For example, instead of telling people to query ._shape to avoid appearing in a trace (which seems totally fine) we could have properties like .shape_notrace, or a utility library with functions like notrace_shape(x), to acquire metadata without appearing in a trace.

I don't think we have to pursue either of these ideas in this PR, of course.
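The two ideas floated above can be sketched together. Neither notrace_shape nor the trace-context guard exists in Thunder; this is a hypothetical illustration of how an explicit no-trace query and an error-outside-trace-context guard would make an accidental .shape loud instead of silently corrupting a trace.

```python
# Hypothetical sketch of the suggested design; none of these names are real
# Thunder APIs.
_tracing = False  # whether a trace context is currently active

class Proxy:
    def __init__(self, shape):
        self._shape = tuple(shape)

    @property
    def shape(self):
        # Tracing operations error outside a trace context, so a developer
        # who writes .shape where they meant ._shape finds out immediately.
        if not _tracing:
            raise RuntimeError("tracing operation called outside a trace context")
        return self._shape  # (recording of the prims.shape symbol elided)

def notrace_shape(t):
    """Explicitly query metadata without ever appearing in a trace."""
    return t._shape

p = Proxy((4, 5))
assert notrace_shape(p) == (4, 5)  # fine outside a trace
try:
    p.shape  # raises: the accidental tracing query is caught loudly
    raised = False
except RuntimeError:
    raised = True
assert raised
```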

Collaborator

@mruberry mruberry left a comment


@mruberry mruberry merged commit 91ce25f into backward_transform_dependency_fix Jan 27, 2025
46 checks passed
@mruberry mruberry deleted the prims_shape_patch branch January 27, 2025 17:50
@jjsjann123
Collaborator Author

err.. I should have made it clear that these three PRs are stacked. so they really should be merged in the order of #1673 #1693 #1696

@IvanYashchuk
Collaborator

err.. I should have made it clear that these three PRs are stacked. so they really should be merged in the order of #1673 #1693 #1696

This happens 😅 I usually keep the stacked PRs in the draft mode to prevent merges
