From 2bb5ba5128774ade59a8dfe10d6a4ba9387a5cda Mon Sep 17 00:00:00 2001
From: psky1111 <1034487479@qq.com>
Date: Wed, 22 Nov 2023 16:16:36 +0800
Subject: [PATCH] compatible with static graph

---
 .../legendary_models/swin_transformer.py     | 56 +++++++++++----------
 1 file changed, 30 insertions(+), 26 deletions(-)

diff --git a/ppcls/arch/backbone/legendary_models/swin_transformer.py b/ppcls/arch/backbone/legendary_models/swin_transformer.py
index 4534df6f43..133bb856f1 100644
--- a/ppcls/arch/backbone/legendary_models/swin_transformer.py
+++ b/ppcls/arch/backbone/legendary_models/swin_transformer.py
@@ -147,22 +147,27 @@ def pading_for_not_divisible(pixel_values, function="split"):
     if isinstance(patch_size, int):
         patch_size = (patch_size, patch_size)
-    if height % patch_size[0] == 0 and width % patch_size[1] == 0:
-        return pixel_values, (0, 0, 0, 0, 0, 0, 0, 0)
     if function == "split":
-        pading_width = patch_size[1] - width % patch_size[1]
-        pading_height = patch_size[0] - height % patch_size[0]
+        pading_width = (patch_size[1] - width % patch_size[1]) % patch_size[1]
+        pading_height = (patch_size[0] - height % patch_size[0]) % patch_size[0]
     elif function == "merge":
         pading_width = width % 2
         pading_height = height % 2
+    data_format = "NCHW"
     if format == "BCHW":
-        pad_index = (0, 0, 0, 0, 0, pading_height, 0, pading_width)
+        zero_pad = paddle.to_tensor([0], dtype=paddle.int32)
+        pad_index_h = paddle.concat([zero_pad, pading_height])
+        pad_index_w = paddle.concat([zero_pad, pading_width])
+        pad_index = paddle.concat([pad_index_w, pad_index_h])
     elif format == "BHWC":
-        pad_index = (0, 0, 0, pading_height, 0, pading_width, 0, 0)
+        data_format = "NHWC"
+        zero_pad = paddle.to_tensor([0], dtype=paddle.int32)
+        pad_index_h = paddle.concat([zero_pad, pading_height])
+        pad_index_w = paddle.concat([zero_pad, pading_width])
+        pad_index = paddle.concat([pad_index_w, pad_index_h])
     else:
         assert ("vaild format")
-
-    return F.pad(pixel_values, pad_index), pad_index
+    return F.pad(pixel_values, pad_index, data_format=data_format), pad_index
 
 
 def window_partition(x, window_size):
@@ -442,19 +447,18 @@ def get_attn_mask(self, height, width, dtype):
         return attn_mask
 
     def forward(self, x, input_dimensions):
-        H, W = input_dimensions
-        B, L, C = x.shape
+        B, H, W = input_dimensions
+        _, L, C = paddle.shape(x)
         shortcut = x
 
         x = self.norm1(x)
-        x = x.reshape([B, H, W, C])
-
+        x = paddle.reshape(x, [B, H, W, C])
         x, pad_values = pading_for_not_divisible(x, H, W, self.window_size,
                                                  "BHWC")
-        _, height_pad, width_pad, _ = x.shape
+        _, height_pad, width_pad, _ = paddle.shape(x)
 
-        padding_state = pad_values[3] > 0 or pad_values[
-            5] > 0  # change variable name
+        padding_state = pad_values[1] > 0 or pad_values[
+            3] > 0  # change variable name
         # cyclic shift
         if self.shift_size > 0:
             shifted_x = RollWrapper.roll(
@@ -493,7 +497,7 @@ def forward(self, x, input_dimensions):
 
         if padding_state:
             x = x[:, :H, :W, :]
-        x = x.reshape([B, H * W, C])
+        x = paddle.reshape(x, [B, H * W, C])
 
         # FFN
         x = shortcut + self.drop_path(x)
@@ -541,9 +545,9 @@ def forward(self, x, input_dimensions):
         """
         x: B, H*W, C
         """
-        H, W = input_dimensions
-        B, L, C = x.shape
-        x = x.reshape((B, H, W, C))
+        B, H, W = input_dimensions
+        B_, L, C = paddle.shape(x)[0], paddle.shape(x)[1], paddle.shape(x)[2]
+        x = paddle.reshape(x, [B, H, W, C])
         x, _ = pading_for_not_divisible(x, H, W, 2, "BHWC", function="merge")
 
         x0 = x[:, 0::2, 0::2, :]  # B H/2 W/2 C
@@ -555,7 +559,7 @@ def forward(self, x, input_dimensions):
 
         # x = x.reshape([B, H // 2, 2, W // 2, 2, C])
         # x = x.transpose((0, 1, 3, 4, 2, 5))
-        x = x.reshape([B, -1, 4 * C])  # B H/2*W/2 4*C
+        x = paddle.reshape(x, [B_, -1, self.dim * 4])  # B H/2*W/2 4*C
         x = self.norm(x)
         x = self.reduction(x)
 
@@ -567,7 +571,7 @@ def extra_repr(self):
             self.dim)
 
     def flops(self):
-        H, W = self.input_resolution
+        B, H, W = self.input_resolution
         flops = H * W * self.dim
         flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim
         return flops
@@ -641,13 +645,13 @@ def __init__(self,
             self.downsample = None
 
     def forward(self, x, input_dimensions):
-        H, W = input_dimensions
+        B, H, W = input_dimensions
         for blk in self.blocks:
             x = blk(x, input_dimensions)
         if self.downsample is not None:
             H, W = (H + 1) // 2, (W + 1) // 2
             x = self.downsample(x, input_dimensions)
-        return x, (H, W)
+        return x, (B, H, W)
 
     def extra_repr(self):
         return "dim={}, input_resolution={}, depth={}".format(
@@ -701,11 +705,11 @@ def __init__(self,
             self.norm = None
 
     def forward(self, x):
-        B, C, H, W = x.shape
+        B, C, H, W = paddle.shape(x)
         x, _ = pading_for_not_divisible(x, H, W, self.patch_size, "BCHW")
         x = self.proj(x)
-        _, _, height, width = x.shape
-        output_dimensions = (height, width)
+        B_, _, height, width = paddle.shape(x)
+        output_dimensions = (B_, height, width)
         x = x.flatten(2).transpose([0, 2, 1])  # B Ph*Pw C
         if self.norm is not None:
             x = self.norm(x)
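
Note on the pattern the patch applies throughout, with a runnable sketch. Under paddle.jit.to_static, x.shape on a tensor whose spatial sizes are declared dynamic (None in the InputSpec) yields trace-time placeholders, while paddle.shape(x) yields an int32 tensor evaluated at run time. Any quantity derived from it, such as a pad amount, must therefore stay a tensor, which is why the patch builds pad_index with paddle.concat and passes an explicit data_format to F.pad; the 4-element pad for a 4-D input is laid out as (pad_left, pad_right, pad_top, pad_bottom), width pair first, and the trailing modulo replaces the Python-level early return for already-divisible sizes, which cannot be traced. The sketch below shows the same recipe in isolation; it is a minimal illustration assuming Paddle 2.x, and the layer name PadToMultiple and the example shapes are made up for this note, not part of the repo.

import paddle
import paddle.nn.functional as F


class PadToMultiple(paddle.nn.Layer):
    """Pads H and W up to a multiple of window_size (illustration only)."""

    def __init__(self, window_size=7):
        super().__init__()
        self.window_size = window_size

    def forward(self, x):  # x: NCHW, H and W dynamic
        shape = paddle.shape(x)  # int32 tensor, stays symbolic when traced
        height, width = shape[2], shape[3]
        ws = self.window_size
        # Trailing % keeps the pad at 0 when the size is already divisible,
        # standing in for the untraceable Python-level early return.
        pad_h = (ws - height % ws) % ws
        pad_w = (ws - width % ws) % ws
        zero = paddle.zeros([1], dtype="int32")
        # 4-element pad layout: (pad_left, pad_right, pad_top, pad_bottom),
        # so the width pair comes before the height pair.
        pad = paddle.concat([
            zero, paddle.reshape(pad_w, [1]),
            zero, paddle.reshape(pad_h, [1]),
        ])
        return F.pad(x, pad, data_format="NCHW")


net = paddle.jit.to_static(
    PadToMultiple(7),
    input_spec=[paddle.static.InputSpec([None, 3, None, None], "float32")])
print(net(paddle.rand([1, 3, 30, 45])).shape)  # [1, 3, 35, 49]

With the x.shape variant, exporting via paddle.jit.save would bake the trace-time placeholders into the pad computation; with paddle.shape the pads are recomputed for every input size, which is what the subject line's static-graph compatibility refers to.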