[FXImporter] Find a better way to get a buffer representation of a torch.Tensor (support bfloat and complex dtypes when faced with lift_fresh_copy)
#3653
Open · 123epsilon opened this issue Aug 20, 2024 · 0 comments
This issue is being moved from SHARK-Turbine to torch-mlir as the fx importer lives in torch-mlir now.
Currently, we address issues stemming from lift_fresh_copy by creating a tensor literal op (nod-ai/SHARK-ModelDev#37). This is problematic because doing so requires a buffer representation of a torch.Tensor. Unfortunately, torch.Tensor does not fully implement the Python array interface, which precludes us from grabbing the tensor's in-memory representation directly; instead, we are forced to take an indirect route through numpy to obtain a Python buffer that MLIR can parse into a tensor literal. This has the unfortunate side effect that we cannot support bfloat and complex<> datatypes with this operation, because 1) numpy has no bfloat datatype and hence no representation for such a buffer, and 2) numpy's buffer format for complex<> datatypes is incompatible with the buffer format that MLIR's DenseElementsAttr expects.
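Both failure modes of the numpy detour can be reproduced directly. A minimal sketch using only numpy (the source tensor would come from torch, which is omitted here so the snippet stays self-contained; the claim about what the MLIR bindings accept is an assumption based on the incompatibility described above):

```python
import numpy as np

# The current indirect route is roughly: torch.Tensor -> numpy array ->
# Python buffer (e.g. via `tensor.detach().cpu().numpy()` in the importer).

# Failure mode 1: numpy has no bfloat16 dtype, so a bfloat16 torch.Tensor
# has no numpy representation from which to take a buffer.
try:
    np.dtype("bfloat16")
    has_bfloat16 = True
except TypeError:
    has_bfloat16 = False
print(has_bfloat16)  # False

# Failure mode 2: numpy exposes complex buffers with a PEP 3118 "Z"-prefixed
# format string ("Zf" = pair of float32), rather than a plain struct format.
# A consumer expecting standard formats (assumed here to include the MLIR
# Python bindings constructing a DenseElementsAttr) cannot parse it.
mv = memoryview(np.array([1 + 2j], dtype=np.complex64))
print(mv.format)  # "Zf"
```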
The best solution would be a first-class mechanism for getting a memoryview of a torch.Tensor by fully implementing the Python array interface for this class. This issue tracks that shortcoming in pytorch: pytorch/pytorch#54138
Tracking the implementation of this interface: pytorch/pytorch#58743
Actually, the immediately relevant interface is the Python buffer interface: pytorch/pytorch#19143
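For reference, the buffer interface the linked issue asks torch.Tensor to implement is the one numpy arrays already provide. A sketch of what a first-class path would look like, with a numpy array standing in for the tensor (the `memoryview(tensor)` call at the end is hypothetical; torch does not support it today):

```python
import numpy as np

# What PEP 3118 buffer support gives a consumer: element format, shape, and
# raw bytes in one call, with no copy through an intermediate library.
a = np.arange(4, dtype=np.float32)
mv = memoryview(a)
assert mv.format == "f"   # element type as a struct format code
assert mv.shape == (4,)   # logical shape
assert mv.nbytes == 16    # raw size: 4 elements * 4 bytes each
raw = mv.tobytes()        # the bytes MLIR could parse into a tensor literal
print(len(raw))           # 16

# If torch.Tensor implemented the same protocol, the importer could call
# `memoryview(tensor)` directly and skip numpy entirely (hypothetical).
```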
The above issue refers to the following source in torch-mlir:
torch-mlir/python/torch_mlir/extras/fx_importer.py, line 195 (at af67f9e)
torch-mlir/python/torch_mlir/extras/fx_importer.py, line 2110 (at af67f9e)