I'm trying to write a wrapper class that subclasses `torch.nn.Module`. I need to apply some custom post-processing to the model's output in the forward step and then export the model to ONNX.
Here is the wrapper class:
```python
import torch
import torch.nn as nn
from fastsam import FastSAM, FastSAMPrompt
import cv2
import matplotlib.pyplot as plt
import onnxruntime as ort
import numpy as np


class FastSAMWrapper(nn.Module):
    def __init__(self):
        super(FastSAMWrapper, self).__init__()
        self.model = FastSAM("FastSAM-s.pt")
        device = torch.device('cpu')
        self.model.to(device)

    def forward(self, img_tt):
        # img_tt is the input image tensor, shape (1, 3, H, W).
        # FastSAM takes an image rather than a tensor, so img_tt is converted
        # back to an HWC NumPy array first.
        with torch.no_grad():
            image_np = img_tt.squeeze(0).permute(1, 2, 0).numpy()
            # Run the model's forward pass
            everything_results = self.model(image_np, device='cpu', retina_masks=True, imgsz=1024, conf=0.4, iou=0.9)
            prompt_process = FastSAMPrompt(image_np, everything_results, device='cpu')
            ann = prompt_process.everything_prompt()

        # HERE I'm going to do some post-processing on ann

        if not isinstance(ann, torch.Tensor):
            ann = torch.tensor(ann)
        return ann
```
If I run the model as follows:
```python
my_model = FastSAMWrapper()

image_np = cv2.imread('../images/IMG_3619.jpeg')
image_np = cv2.cvtColor(image_np, cv2.COLOR_BGR2RGB)
image_np = cv2.resize(image_np, (1024, 1024))
image_tensor = torch.from_numpy(image_np).float()  # Ensure the tensor is of floating point type
reshaped_tensor = image_tensor.permute(2, 0, 1)    # Change to (3, 1024, 1024)
reshaped_tensor = reshaped_tensor.unsqueeze(0)     # Add batch dimension -> (1, 3, 1024, 1024)
results = my_model(reshaped_tensor)
```
everything works fine and the output is correct.
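The preprocessing above (HWC `uint8` RGB image to a `(1, 3, H, W)` float tensor) is undone by the first line of `forward`. A small sketch of that round trip, using a random array as a hypothetical stand-in for the image loaded with `cv2`; `to_nchw` and `to_hwc` are illustrative helper names, not part of FastSAM:

```python
import numpy as np
import torch


def to_nchw(image_hwc: np.ndarray) -> torch.Tensor:
    """HWC RGB array -> float32 tensor of shape (1, 3, H, W), as in the snippet above."""
    return torch.from_numpy(image_hwc).float().permute(2, 0, 1).unsqueeze(0)


def to_hwc(image_nchw: torch.Tensor) -> np.ndarray:
    """Inverse of to_nchw, mirroring the first line of FastSAMWrapper.forward."""
    return image_nchw.squeeze(0).permute(1, 2, 0).numpy()


# Hypothetical stand-in for the resized 1024x1024 RGB image
rgb = np.random.default_rng(0).integers(0, 256, size=(1024, 1024, 3), dtype=np.uint8)
x = to_nchw(rgb)
back = to_hwc(x)
print(x.shape)                  # torch.Size([1, 3, 1024, 1024])
print(back.shape, back.dtype)   # (1024, 1024, 3) float32
```

Note that the array FastSAM receives inside `forward` is `float32` with values in 0..255, not the `uint8` array that `cv2.imread` returns.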
But when I export the model with the following code:
```python
dummy_input = torch.randn((1, 3, 1024, 1024), device='cpu')
torch.onnx.export(
    my_model,                  # The PyTorch model
    dummy_input,               # The input tensor
    "fastsam_model.onnx",      # The output ONNX file name
    export_params=True,        # Store the trained parameter weights
    opset_version=11,          # ONNX opset version (11 is widely compatible, but you can try 12 or higher if needed)
    do_constant_folding=True,  # Whether to apply constant folding for optimization
    input_names=['input'],
    output_names=['masks_data'],
    # dynamic_axes keys must match the names used in input_names / output_names
    dynamic_axes={'input': {0: 'image_tensor'}, 'masks_data': {0: 'masks'}},
    verbose=True,
)
```
the model inputs returned by the following code are empty.
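A likely explanation for the empty inputs (hedged: this is based on how tracing works, not on inspecting this particular export): the TorchScript-based `torch.onnx.export` builds the graph by tracing `forward` on the example input, and the tracer cannot follow data through `.numpy()` or `torch.tensor(...)`. Everything computed from `image_np` would then be recorded as a constant, so no exported node depends on `input`. The toy modules below are hypothetical stand-ins (not FastSAM) that show the effect with `torch.jit.trace`: the NumPy round trip freezes the output at trace time, while the pure-torch version follows new inputs.

```python
import warnings

import torch
import torch.nn as nn


class NumpyRoundTrip(nn.Module):
    """Toy stand-in for FastSAMWrapper: leaves PyTorch via NumPy, then re-wraps."""

    def forward(self, x):
        arr = x.numpy()           # tracer cannot record data flow through NumPy
        out = arr * 2.0           # NumPy op, invisible to the tracer
        return torch.tensor(out)  # recorded as a constant in the trace


class TorchOnly(nn.Module):
    """Same computation kept in torch ops, so the tracer records it."""

    def forward(self, x):
        return x * 2.0


x_trace = torch.ones(1, 3, 4, 4)
x_new = torch.full((1, 3, 4, 4), 5.0)

with warnings.catch_warnings():
    warnings.simplefilter("ignore")  # silence the TracerWarnings this demo triggers on purpose
    traced_np = torch.jit.trace(NumpyRoundTrip(), x_trace, check_trace=False)
    traced_torch = torch.jit.trace(TorchOnly(), x_trace)

print(traced_np(x_new)[0, 0, 0, 0].item())     # 2.0  (frozen at trace time)
print(traced_torch(x_new)[0, 0, 0, 0].item())  # 10.0 (follows the new input)
```

Running the real wrapper through `torch.jit.trace` in the same way should print TracerWarnings pointing at the lines where the trace loses the input.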
I've also tried exporting with `reshaped_tensor` in place of `dummy_input`, but nothing changed.
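Separately from the empty inputs, the keys of `dynamic_axes` are matched against the names in `input_names` and `output_names`; as far as I know, PyTorch warns about a key it does not recognize (such as `'output'` when the output is named `'masks_data'`) and ignores it. A small hypothetical helper (not part of PyTorch) to catch such mismatches before exporting:

```python
def unknown_dynamic_axes(input_names, output_names, dynamic_axes):
    """Hypothetical check: return dynamic_axes keys that match no input or output name."""
    known = set(input_names) | set(output_names)
    return sorted(key for key in dynamic_axes if key not in known)


# The configuration from the original export call
print(unknown_dynamic_axes(['input'], ['masks_data'],
                           {'input': {0: 'image_tensor'}, 'output': {0: 'masks'}}))
# ['output']
```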
Here is what I see when I load the model into Netron:
I need this because the model will be used in a Unity app, and we want to reduce the number of operations performed in the app (the post-processing on `ann`, i.e. the model output).
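Since the goal is to move the post-processing on `ann` into the exported model, that step has to be written with tensor operations for the tracer to record it. A minimal sketch, assuming `ann` is an `(N, H, W)` float tensor of masks; `MaskPostprocess`, its threshold, and the area/union outputs are hypothetical placeholders for the real post-processing. This covers only the post-processing: the FastSAM call in `forward` also goes through NumPy and is subject to the same tracing limitation.

```python
import torch
import torch.nn as nn


class MaskPostprocess(nn.Module):
    """Hypothetical post-processing on an (N, H, W) stack of masks, written only
    with tensor ops so that torch.jit.trace / torch.onnx.export can record it."""

    def __init__(self, threshold: float = 0.5):
        super().__init__()
        self.threshold = threshold  # hypothetical binarization threshold

    def forward(self, masks: torch.Tensor):
        binary = (masks > self.threshold).to(torch.float32)   # (N, H, W), values 0 or 1
        areas = binary.sum(dim=(1, 2))                        # (N,), pixel count per mask
        union = (binary.sum(dim=0) > 0).to(torch.float32)     # (H, W), 1 where any mask is on
        return areas, union


# Trace on one example, then run on a different input with a different N
traced = torch.jit.trace(MaskPostprocess(), torch.rand(2, 8, 8))

masks = torch.zeros(3, 8, 8)
masks[0, :2, :2] = 1.0   # 4-pixel mask
masks[1, 5, 5] = 0.9     # 1-pixel mask
masks[2, :2, :2] = 0.7   # overlaps mask 0
areas, union = traced(masks)
print(areas.tolist())      # [4.0, 1.0, 4.0]
print(union.sum().item())  # 5.0
```

Checking that a traced module gives correct results on an input other than the one it was traced with is a quick way to confirm a step will survive export before wiring it into the wrapper.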