
[FXImporter] Find a better way to get a buffer representation of a torch.Tensor (support bfloat and complex dtypes when faced with lift_fresh_copy) #3653

Open
123epsilon opened this issue Aug 20, 2024 · 0 comments

This issue is being moved from SHARK-Turbine to torch-mlir as the fx importer lives in torch-mlir now.

Currently, we address issues stemming from lift_fresh_copy by creating a tensor literal op (nod-ai/SHARK-ModelDev#37). This is problematic because doing so requires a buffer representation of a torch.Tensor. Unfortunately, torch.Tensor does not fully implement the Python array interface, which precludes us from directly grabbing the tensor's in-memory representation; instead we are forced to take an indirect route through numpy to obtain a Python buffer that MLIR can parse into a tensor literal. This has the unfortunate side effect that we cannot support bfloat and complex<> datatypes with this operation, because 1) numpy has no bfloat datatype, and hence no representation for such a buffer, and 2) numpy's buffer format for complex<> datatypes is incompatible with the buffer format that MLIR's DenseElementsAttr expects.
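As a minimal illustration of limitation 1), the sketch below (numpy only, since the failure is on the numpy side of the detour) shows that stock numpy simply has no bfloat16 dtype, so the usual tensor.detach().cpu().numpy() route has nothing to convert a bfloat16 tensor into:

```python
import numpy as np

# Stock numpy has no native bfloat16 dtype, so there is no numpy
# buffer that could represent a bfloat16 torch.Tensor after the
# indirect tensor -> numpy -> buffer conversion described above.
try:
    np.dtype("bfloat16")
    numpy_has_bfloat16 = True
except TypeError:
    numpy_has_bfloat16 = False

print(numpy_has_bfloat16)  # False on stock numpy
```

(Third-party packages such as ml_dtypes register a bfloat16 extension dtype, but the importer cannot assume one is installed.)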

The best solution would be a first-class mechanism for getting a memoryview of a torch.Tensor by fully implementing the Python array interface for this class. This pytorch issue tracks the shortcoming: pytorch/pytorch#54138

Tracking the implementation of this interface: pytorch/pytorch#58743
Actually, the immediately relevant interface is the Python buffer interface: pytorch/pytorch#19143
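To make the missing piece concrete: numpy arrays implement the Python buffer interface, so a raw byte view is available directly via memoryview; per pytorch/pytorch#19143, torch.Tensor does not, which is what forces the detour through numpy. A small sketch of the behavior we would like from torch.Tensor itself:

```python
import numpy as np

# numpy arrays support the buffer protocol, so memoryview() yields
# the raw bytes directly -- no intermediate copy or conversion.
a = np.arange(4, dtype=np.float32)
mv = memoryview(a)

print(mv.nbytes)  # 16 (four float32 values)
print(mv.format)  # 'f' (struct-style format code for float32)

# memoryview(torch.tensor(...)) raises TypeError instead, because
# torch.Tensor does not implement the buffer protocol (see
# pytorch/pytorch#19143); a bfloat16 or complex tensor therefore has
# no direct path to a buffer MLIR could consume.
```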

The above issue refers to the following source in torch-mlir:

# torch.qint8: None, # no equivalent np datatype

# We need a raw buffer of data in order to create an ElementsAttr for the invocation of torch.vtensor.literal,
