patching shape query in TensorProxy to avoid prims.shape #1696
mruberry merged 4 commits into backward_transform_dependency_fix
Conversation
This is really interesting. fyi @t-vi and @IvanYashchuk

As a note for the future, when we're transforming traces, maybe we shouldn't generally be operating within the trace context, to avoid accidentally putting operations into the trace. I guess if we did that, then calling `shape` would be safe. I also wonder if, in the future, we should make these queries more explicit, for example, instead of telling people to query `_shape` directly.

I don't think we have to pursue either of these ideas in this PR, of course.

We should avoid calling `TensorProxy.shape` directly during trace transformations, and use `TensorProxy._shape` instead.

The difference is that `shape` would leave a `prims.shape` symbol in the trace when under a tracing context, which is the case during the grad transform. Without this change, the trace would be corrupted with `prims.shape` being called on a non-existent TensorProxy.