Bug fix for multi batch size and ONNX export support for the yolo-s model #296

Open · wants to merge 2 commits into master
3 changes: 3 additions & 0 deletions yolo_world/models/detectors/yolo_world.py
@@ -79,6 +79,9 @@ def extract_feat(
        if batch_data_samples is None:
            texts = self.texts
            txt_feats = self.text_feats
+           batch_size = batch_inputs.shape[0]
+           texts = texts * batch_size
+           txt_feats = txt_feats.repeat(batch_size, 1, 1)
        elif isinstance(batch_data_samples,
                        dict) and 'texts' in batch_data_samples:
            texts = batch_data_samples['texts']
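Without these lines, the cached self.text_feats keeps its single-image batch dimension of 1, so inference with batch size > 1 fails on a shape mismatch. A minimal sketch of what the added lines do; the shapes (80 classes, 512-dim text embeddings) are assumed for illustration, not taken from the repo:

import torch

texts = [['person', 'dog']]                  # one prompt list, cached for one image
txt_feats = torch.randn(1, 80, 512)          # (1, num_classes, embed_dim)
batch_inputs = torch.randn(4, 3, 640, 640)   # a batch of 4 images

batch_size = batch_inputs.shape[0]
texts = texts * batch_size                   # replicate the prompt list per image
txt_feats = txt_feats.repeat(batch_size, 1, 1)
print(txt_feats.shape)                       # torch.Size([4, 80, 512])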
78 changes: 59 additions & 19 deletions yolo_world/models/layers/yolo_bricks.py
@@ -11,6 +11,42 @@
from mmyolo.registry import MODELS
from mmyolo.models.layers import CSPLayerWithTwoConv

+# AdaptiveAvgPool2dCustom and AdaptiveMaxPool2dCustom are replacements for the
+# adaptive pooling layers that remain compatible with ONNX export.
+# reference: https://github.com/pytorch/pytorch/issues/42653#issuecomment-1168816422
+class AdaptiveAvgPool2dCustom(nn.Module):
+
+    def __init__(self, output_size):
+        super(AdaptiveAvgPool2dCustom, self).__init__()
+        self.output_size = torch.tensor(output_size)
+
+    def forward(self, x: torch.Tensor):
+        # Calculate the stride required to achieve the desired output size
+        stride_size = torch.floor(
+            torch.tensor(x.shape[-2:]) / self.output_size).to(torch.int32)
+
+        # Calculate the kernel size from the stride and desired output size
+        kernel_size = torch.tensor(
+            x.shape[-2:]) - (self.output_size - 1) * stride_size
+
+        # Create an AvgPool2d layer with the calculated kernel and stride sizes
+        avg = nn.AvgPool2d(kernel_size.tolist(), stride=stride_size.tolist())
+
+        x = avg(x)
+        return x
+
+
+class AdaptiveMaxPool2dCustom(nn.Module):
+
+    def __init__(self, output_size):
+        super(AdaptiveMaxPool2dCustom, self).__init__()
+        self.output_size = torch.tensor(output_size)
+
+    def forward(self, x: torch.Tensor):
+        # Calculate the stride required to achieve the desired output size
+        stride_size = torch.floor(
+            torch.tensor(x.shape[-2:]) / self.output_size).to(torch.int32)
+
+        # Calculate the kernel size from the stride and desired output size
+        kernel_size = torch.tensor(
+            x.shape[-2:]) - (self.output_size - 1) * stride_size
+
+        # Create a MaxPool2d layer with the calculated kernel and stride sizes
+        max_pool = nn.MaxPool2d(kernel_size.tolist(), stride=stride_size.tolist())
+
+        x = max_pool(x)
+        return x
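A quick sanity check, not part of the PR: when the spatial size is divisible by the output size, the decomposed pooling matches nn.AdaptiveMaxPool2d exactly; for non-divisible sizes the two can differ slightly, because true adaptive pooling uses windows of varying size.

import torch
import torch.nn as nn

pool_ref = nn.AdaptiveMaxPool2d((3, 3))
pool_custom = AdaptiveMaxPool2dCustom((3, 3))

x = torch.randn(1, 8, 12, 12)   # 12 is divisible by 3
assert torch.equal(pool_ref(x), pool_custom(x))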

@MODELS.register_module()
class MaxSigmoidAttnBlock(BaseModule):
@@ -31,7 +67,7 @@ def __init__(self,
                                             momentum=0.03,
                                             eps=0.001),
                 init_cfg: OptMultiConfig = None,
-                use_einsum: bool = True) -> None:
+                export_onnx: bool = True) -> None:
        super().__init__(init_cfg=init_cfg)
        conv = DepthwiseSeparableConvModule if use_depthwise else ConvModule

@@ -40,7 +76,7 @@ def __init__(self,
            'out_channels and embed_channels should be divisible by num_heads.'
        self.num_heads = num_heads
        self.head_channels = out_channels // num_heads
-       self.use_einsum = use_einsum
+       self.export_onnx = export_onnx

        self.embed_conv = ConvModule(
            in_channels,
@@ -73,8 +109,7 @@ def forward(self, x: Tensor, guide: Tensor) -> Tensor:
        guide = guide.reshape(B, -1, self.num_heads, self.head_channels)
        embed = self.embed_conv(x) if self.embed_conv is not None else x
        embed = embed.reshape(B, self.num_heads, self.head_channels, H, W)
-
-       if self.use_einsum:
+       if not self.export_onnx:
            attn_weight = torch.einsum('bmchw,bnmc->bmhwn', embed, guide)
        else:
            batch, m, channel, height, width = embed.shape
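The diff is truncated before the body of the matmul branch, but its purpose is an ONNX-friendly replacement for the einsum (note that the default export_onnx=True makes the matmul path the one actually taken). A sketch of one equivalent formulation; all sizes are assumed for illustration:

import torch

B, m, c, H, W, n = 2, 4, 16, 8, 8, 5
embed = torch.randn(B, m, c, H, W)
guide = torch.randn(B, n, m, c)

ref = torch.einsum('bmchw,bnmc->bmhwn', embed, guide)

# (B, m, H*W, c) @ (B, m, c, n) -> (B, m, H*W, n), then restore H and W
alt = torch.matmul(embed.flatten(3).permute(0, 1, 3, 2),
                   guide.permute(0, 2, 3, 1)).reshape(B, m, H, W, n)

assert torch.allclose(ref, alt, atol=1e-5)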
@@ -116,7 +151,7 @@ def __init__(self,
                                             momentum=0.03,
                                             eps=0.001),
                 init_cfg: OptMultiConfig = None,
-                use_einsum: bool = True) -> None:
+                export_onnx: bool = True) -> None:
        super().__init__(init_cfg=init_cfg)
        conv = DepthwiseSeparableConvModule if use_depthwise else ConvModule

@@ -125,7 +160,7 @@ def __init__(self,
            'out_channels and embed_channels should be divisible by num_heads.'
        self.num_heads = num_heads
        self.head_channels = out_channels // num_heads
-       self.use_einsum = use_einsum
+       self.export_onnx = export_onnx

        self.embed_conv = ConvModule(
            in_channels,
@@ -272,7 +307,7 @@ def __init__(
            norm_cfg: ConfigType = dict(type='BN', momentum=0.03, eps=0.001),
            act_cfg: ConfigType = dict(type='SiLU', inplace=True),
            init_cfg: OptMultiConfig = None,
-           use_einsum: bool = True) -> None:
+           export_onnx: bool = True) -> None:
        super().__init__(in_channels=in_channels,
                         out_channels=out_channels,
                         expand_ratio=expand_ratio,
@@ -298,7 +333,7 @@ def __init__(
                         with_scale=with_scale,
                         conv_cfg=conv_cfg,
                         norm_cfg=norm_cfg,
-                        use_einsum=use_einsum)
+                        export_onnx=export_onnx)

    def forward(self, x: Tensor, guide: Tensor) -> Tensor:
        """Forward process."""
@@ -328,7 +363,7 @@ def __init__(
            norm_cfg: ConfigType = dict(type='BN', momentum=0.03, eps=0.001),
            act_cfg: ConfigType = dict(type='SiLU', inplace=True),
            init_cfg: OptMultiConfig = None,
-           use_einsum: bool = True) -> None:
+           export_onnx: bool = True) -> None:
        super().__init__(in_channels=in_channels,
                         out_channels=out_channels,
                         expand_ratio=expand_ratio,
@@ -412,7 +447,7 @@ def __init__(
                         with_scale=with_scale,
                         conv_cfg=conv_cfg,
                         norm_cfg=norm_cfg,
-                        use_einsum=use_einsum)
+                        export_onnx=export_onnx)

    def forward(self, x: Tensor, guide: Tensor) -> Tensor:
        """Forward process."""
@@ -434,7 +469,7 @@ def __init__(self,
                 num_feats: int = 3,
                 num_heads: int = 8,
                 pool_size: int = 3,
-                use_einsum: bool = True):
+                export_onnx: bool = True):
        super().__init__()

        self.text_channels = text_channels
@@ -443,7 +478,7 @@ def __init__(self,
        self.num_feats = num_feats
        self.head_channels = embed_channels // num_heads
        self.pool_size = pool_size
-       self.use_einsum = use_einsum
+       self.export_onnx = export_onnx
        if with_scale:
            self.scale = nn.Parameter(torch.tensor([0.]), requires_grad=True)
        else:
@@ -459,11 +494,16 @@ def __init__(self,
        self.value = nn.Sequential(nn.LayerNorm(embed_channels),
                                   Linear(embed_channels, embed_channels))
        self.proj = Linear(embed_channels, text_channels)

-       self.image_pools = nn.ModuleList([
-           nn.AdaptiveMaxPool2d((pool_size, pool_size))
-           for _ in range(num_feats)
-       ])
+       if not self.export_onnx:
+           self.image_pools = nn.ModuleList([
+               nn.AdaptiveMaxPool2d((pool_size, pool_size))
+               for _ in range(num_feats)
+           ])
+       else:
+           self.image_pools = nn.ModuleList([
+               AdaptiveMaxPool2dCustom((pool_size, pool_size))
+               for _ in range(num_feats)
+           ])

    def forward(self, text_features, image_features):
        B = image_features[0].shape[0]
@@ -483,7 +523,7 @@ def forward(self, text_features, image_features):
        q = q.reshape(B, -1, self.num_heads, self.head_channels)
        k = k.reshape(B, -1, self.num_heads, self.head_channels)
        v = v.reshape(B, -1, self.num_heads, self.head_channels)
-       if self.use_einsum:
+       if not self.export_onnx:
            attn_weight = torch.einsum('bnmc,bkmc->bmnk', q, k)
        else:
            q = q.permute(0, 2, 1, 3)
@@ -492,7 +532,7 @@ def forward(self, text_features, image_features):

        attn_weight = attn_weight / (self.head_channels**0.5)
        attn_weight = F.softmax(attn_weight, dim=-1)
-       if self.use_einsum:
+       if not self.export_onnx:
            x = torch.einsum('bmnk,bkmc->bnmc', attn_weight, v)
        else:
            v = v.permute(0, 2, 1, 3)
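As in MaxSigmoidAttnBlock, the export path replaces both einsums with permute/matmul sequences, and the diff cuts off after the first permute. A sketch of the full equivalence under assumed sizes:

import torch

B, n, K, m, c = 2, 9, 27, 8, 32   # illustrative sizes only
q = torch.randn(B, n, m, c)
k = torch.randn(B, K, m, c)
v = torch.randn(B, K, m, c)

# 'bnmc,bkmc->bmnk' as a batched matmul
ref_w = torch.einsum('bnmc,bkmc->bmnk', q, k)
alt_w = torch.matmul(q.permute(0, 2, 1, 3), k.permute(0, 2, 3, 1))
assert torch.allclose(ref_w, alt_w, atol=1e-5)

# 'bmnk,bkmc->bnmc' as a batched matmul
ref_x = torch.einsum('bmnk,bkmc->bnmc', ref_w, v)
alt_x = torch.matmul(alt_w, v.permute(0, 2, 1, 3)).permute(0, 2, 1, 3)
assert torch.allclose(ref_x, alt_x, atol=1e-5)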
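For context on why the pooling swap is needed at all: the ONNX exporter rejects nn.AdaptiveMaxPool2d when the requested output size is not a factor of the input size (the pytorch/pytorch issue referenced above). A hypothetical repro; the file names and sizes are made up, and AdaptiveMaxPool2dCustom is the class added in this PR:

import torch
import torch.nn as nn

dummy = torch.randn(1, 256, 20, 20)   # 20 is not a multiple of 3

# plain adaptive pooling: export is expected to fail here
try:
    torch.onnx.export(nn.AdaptiveMaxPool2d((3, 3)), dummy, 'pool.onnx')
except Exception as e:
    print('export failed:', e)

# the decomposed version traces to a fixed-kernel MaxPool2d, which exports fine
torch.onnx.export(AdaptiveMaxPool2dCustom((3, 3)), dummy, 'pool_custom.onnx')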