Description
I have a PyTorch model that I would like to convert to a TFLite model with float16 weights and activations.
I am actually not sure whether this is a bug report or a feature request. There is more or less no documentation on this matter, so I don't know if I am doing something wrong, if there is a bug, or if the feature simply is not there 🤔.
I have been testing it on a dummy model:

```python
import torch.nn as nn

class SimpleCNN(nn.Module):
    def __init__(self):
        super(SimpleCNN, self).__init__()
        self.model = nn.Sequential(
            nn.Conv2d(3, 1, kernel_size=8, stride=8, padding=2),
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(32 * 32, 10),
        )

    def forward(self, x):
        return self.model(x)
```

First I tried to set the model to float16 in PyTorch, i.e.
```python
edge_model = ai_edge_torch.convert(
    model.half(), tuple([x.half() for x in sample_inputs])
)
```

which fails catastrophically with
```
loc("main.SimpleCNN/torch.nn.modules.container.Sequential_model/torch.nn.modules.conv.Conv2d_0;"("conv2d"("SimpleCNN/main.py":17:0))): error: failed to legalize operation 'tfl.transpose' that was explicitly marked illegal
```
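As an aside for anyone reproducing this: the dummy model constrains the input size, because the `Linear` layer expects `32 * 32` flattened features from the single conv output channel. Under the standard Conv2d output-size formula, a 256×256 input satisfies this (my assumption about the sample inputs; the issue does not state their shape):

```python
def conv2d_out(size, kernel=8, stride=8, padding=2):
    # Standard Conv2d output-size formula (floor division),
    # matching the nn.Conv2d(3, 1, kernel_size=8, stride=8, padding=2) layer above.
    return (size + 2 * padding - kernel) // stride + 1

# The Linear layer needs a 32x32 conv output, so e.g. a 256x256 input works:
assert conv2d_out(256) == 32
```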
Then I tried to persuade the TensorFlow Lite quantization machinery to do it:
```python
converter_flags["optimizations"] = [tf.lite.Optimize.DEFAULT]
converter_flags["target_spec.supported_types"] = [tf.float16]

edge_model = ai_edge_torch.convert(
    model, sample_inputs, _ai_edge_converter_flags=converter_flags
)
```

This does not crash, but the converted model does not have float16 weights; it has a mix of float32 and int8 weight types.