
Export FastSAM to ONNX using a wrapper class #254

Open
mattia-z opened this issue Oct 30, 2024 · 0 comments

I'm trying to write a wrapper class that subclasses torch.nn.Module. I need to apply some custom post-processing to the model's output in the forward step and then export the model to ONNX.
Here is the wrapper class:

import torch
import torch.nn as nn
from fastsam import FastSAM, FastSAMPrompt
import cv2
import matplotlib.pyplot as plt
import onnxruntime as ort
import numpy as np

class FastSAMWrapper(nn.Module):
    def __init__(self):
        super(FastSAMWrapper, self).__init__()
        self.model = FastSAM("FastSAM-s.pt")
        device = torch.device('cpu')
        self.model.to(device)

    def forward(self, img_tt):
        # img_tt is the input image tensor with shape (1, 3, H, W).
        # FastSAM expects an (H, W, 3) image, so the tensor is converted to a NumPy array first,
        # then FastSAM's "everything" segmentation is run on it.
        with torch.no_grad():
            image_np = img_tt.squeeze(0).permute(1, 2, 0).numpy()
            everything_results = self.model(image_np, device='cpu', retina_masks=True, imgsz=1024, conf=0.4, iou=0.9,)
            prompt_process = FastSAMPrompt(image_np, everything_results, device='cpu')
            ann = prompt_process.everything_prompt()
            # Custom post-processing of ann goes here.
            if not isinstance(ann, torch.Tensor):
                ann = torch.tensor(ann)
            return ann
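The placeholder comment in forward could be filled with post-processing written purely in torch tensor operations. As a hypothetical example (the name masks_to_label_map and the minimum-area filter are illustrative assumptions, not part of FastSAM), the sketch below assumes ann is an (N, H, W) stack of binary masks with N ≥ 1 and collapses it into a single (H, W) label map:

```python
import torch


def masks_to_label_map(masks: torch.Tensor, min_area: float = 0.0) -> torch.Tensor:
    """Collapse (N, H, W) binary masks into an (H, W) label map (hypothetical helper).

    Pixel value k (1-based) means the pixel belongs to the k-th mask; 0 is background.
    Where masks overlap, the higher label wins. Masks with fewer than `min_area`
    pixels are dropped. Assumes N >= 1.
    """
    masks = masks.float()
    areas = masks.flatten(1).sum(dim=1)                             # (N,) pixel count per mask
    keep = (areas >= min_area).float().view(-1, 1, 1)               # (N, 1, 1): 1.0 keeps the mask
    labels = torch.ones_like(areas).cumsum(dim=0).view(-1, 1, 1)    # (N, 1, 1): 1, 2, ..., N
    painted = masks * keep * labels                                 # each kept mask paints its label
    return painted.max(dim=0).values.long()                         # overlaps resolve to the higher label


masks = torch.zeros(2, 3, 3)
masks[0, 0, :] = 1   # mask 1: top row
masks[1, :, 0] = 1   # mask 2: left column (overlaps mask 1 at (0, 0))
label_map = masks_to_label_map(masks)
```

Keeping this step in torch operations only helps the ONNX export if everything upstream of it in forward is traceable as well; in the wrapper above, ann still comes out of FastSAM and FastSAMPrompt after a NumPy conversion.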

If I run the model as follows:

my_model = FastSAMWrapper()

image_np = cv2.imread('../images/IMG_3619.jpeg')
image_np = cv2.cvtColor(image_np, cv2.COLOR_BGR2RGB)
image_np = cv2.resize(image_np, (1024, 1024))
image_tensor = torch.from_numpy(image_np).float()  # Ensure the tensor is of floating point type
reshaped_tensor = image_tensor.permute(2, 0, 1)  # Change to (3, 1024, 1024)
reshaped_tensor = reshaped_tensor.unsqueeze(0)  # Add batch dimension

results = my_model(reshaped_tensor)

everything works fine and the output is correct.
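The preprocessing above and the first line of forward are inverse layout conversions between OpenCV's (H, W, 3) arrays and PyTorch's (N, C, H, W) tensors. A minimal sketch with hypothetical helper names (hwc_to_nchw, nchw_to_hwc) makes the round trip explicit and checkable without an image file:

```python
import numpy as np
import torch


def hwc_to_nchw(image_rgb: np.ndarray) -> torch.Tensor:
    """(H, W, 3) RGB array -> (1, 3, H, W) float32 tensor, mirroring the snippet above."""
    return torch.from_numpy(image_rgb).float().permute(2, 0, 1).unsqueeze(0)


def nchw_to_hwc(img_tt: torch.Tensor) -> np.ndarray:
    """(1, 3, H, W) tensor -> (H, W, 3) array, mirroring FastSAMWrapper.forward."""
    return img_tt.squeeze(0).permute(1, 2, 0).numpy()


# Synthetic 4x5 RGB image standing in for the output of cv2.imread + cvtColor.
img = np.random.randint(0, 256, size=(4, 5, 3), dtype=np.uint8)
t = hwc_to_nchw(img)
back = nchw_to_hwc(t)
```

Note that the round trip yields float32 values in the 0 to 255 range rather than the uint8 array cv2.imread produced, so inside forward FastSAM receives a float image.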

But when I export the model with the following code:

dummy_input = torch.randn((1, 3, 1024, 1024), device='cpu')

torch.onnx.export(
    my_model,                     # The PyTorch model
    dummy_input,               # The input tensor
    "fastsam_model.onnx",      # The output ONNX file name
    export_params=True,        # Store the trained parameter weights
    opset_version=11,          # ONNX opset version (11 is widely compatible, but you can try 12 or higher if needed)
    do_constant_folding=True,  # Whether to apply constant folding for optimization
    input_names=['input'], 
    output_names=['masks_data'], 
    dynamic_axes={'input': {0: 'batch'}, 'masks_data': {0: 'num_masks'}},  # keys must match input_names/output_names
    verbose=True
)
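The TorchScript-based torch.onnx.export matches the keys of dynamic_axes against input_names and output_names, and a key that matches neither (for example 'output' when output_names=['masks_data']) produces a warning rather than an error, so the intended dynamic axis is silently not applied. A small pre-export check catches such typos; invalid_dynamic_axes_keys is a hypothetical helper, not a PyTorch function:

```python
from collections.abc import Iterable, Mapping


def invalid_dynamic_axes_keys(
    dynamic_axes: Mapping, input_names: Iterable[str], output_names: Iterable[str]
) -> list:
    """Return the dynamic_axes keys that name neither an input nor an output."""
    valid = set(input_names) | set(output_names)
    return sorted(key for key in dynamic_axes if key not in valid)


bad = invalid_dynamic_axes_keys(
    {'input': {0: 'image_tensor'}, 'output': {0: 'masks'}},
    input_names=['input'],
    output_names=['masks_data'],
)
# bad lists 'output': that axis spec would not be attached to 'masks_data'.
```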

the list of model inputs returned by the following code is empty:

onnx_model = ort.InferenceSession("fastsam_model.onnx")
print("Model Inputs:", onnx_model.get_inputs())

I've also tried exporting with reshaped_tensor in place of dummy_input, but nothing changes.
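A likely explanation, sketched here rather than confirmed against FastSAM: the TorchScript-based torch.onnx.export builds the graph by tracing forward with the example input and records only torch operations. Once forward converts img_tt to a NumPy array, everything after it (FastSAM, FastSAMPrompt, torch.tensor(ann)) is invisible to the tracer, and the result is stored as a constant that does not read the input, which is consistent with the empty input list. The toy modules below (hypothetical names NumpyRoundTrip and TorchOnly) reproduce the effect with torch.jit.trace, which relies on the same tracing mechanism:

```python
import warnings

import torch
import torch.nn as nn


class NumpyRoundTrip(nn.Module):
    """Toy stand-in for FastSAMWrapper: the work happens on a NumPy copy."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        arr = x.detach().numpy()         # leaves torch: the tracer cannot follow this
        return torch.tensor(arr * 2.0)   # re-enters torch as a constant baked into the trace


class TorchOnly(nn.Module):
    """The same computation in torch ops, so it is recorded in the trace."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x * 2.0


example = torch.ones(1, 3, 2, 2)
other = torch.zeros(1, 3, 2, 2)

with warnings.catch_warnings():
    warnings.simplefilter("ignore")      # hide the TracerWarnings the NumPy round trip triggers
    traced_np = torch.jit.trace(NumpyRoundTrip(), example, check_trace=False)
traced_torch = torch.jit.trace(TorchOnly(), example)

frozen = traced_np(other)   # replays the value captured at trace time, ignoring `other`
live = traced_torch(other)  # actually computed from `other`
```

Running the NumPy version with zeros still returns the doubled ones captured during tracing, while the torch-only version follows its input.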

This is what I see when I load the model into Netron:
[Screenshot 2024-10-30 at 18.18.48: the exported model opened in Netron]

I need this because the model will be used in a Unity app, and we want to reduce the number of operations performed in the app (the post-processing of ann, i.e. of the model output).
