Compatible with static graph #3043

Open · wants to merge 1 commit into base: develop
54 changes: 30 additions & 24 deletions ppcls/arch/backbone/legendary_models/swin_transformer.py
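For context on the pattern used throughout this diff: in dynamic-graph PaddlePaddle, x.shape yields concrete Python ints, but after conversion with paddle.jit.to_static, dynamic dimensions appear as -1, so shape arithmetic built from x.shape breaks. paddle.shape(x) instead returns an int32 tensor evaluated at run time, which is why the diff swaps x.shape for paddle.shape(x) and x.reshape(...) for paddle.reshape(x, ...). A minimal sketch of the contrast (flatten_tokens is an illustrative name, not part of this file):

import paddle

def flatten_tokens(x):
    # paddle.shape(x) is an int32 tensor, so the target shape stays
    # symbolic until run time even under paddle.jit.to_static.
    b = paddle.shape(x)[0]
    c = paddle.shape(x)[3]
    return paddle.reshape(x, [b, -1, c])

static_fn = paddle.jit.to_static(flatten_tokens)
print(static_fn(paddle.rand([2, 7, 7, 96])).shape)  # [2, 49, 96]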
@@ -147,22 +147,27 @@ def pading_for_not_divisible(pixel_values,
function="split"):
if isinstance(patch_size, int):
patch_size = (patch_size, patch_size)
if height % patch_size[0] == 0 and width % patch_size[1] == 0:
return pixel_values, (0, 0, 0, 0, 0, 0, 0, 0)
if function == "split":
pading_width = patch_size[1] - width % patch_size[1]
pading_height = patch_size[0] - height % patch_size[0]
elif function == "merge":
pading_width = width % 2
pading_height = height % 2
data_format = "NCHW"
if format == "BCHW":
pad_index = (0, 0, 0, 0, 0, pading_height, 0, pading_width)
zero_pad = paddle.to_tensor([0], dtype=paddle.int32)
pad_index_h = paddle.concat([zero_pad, pading_height])
pad_index_w = paddle.concat([zero_pad, pading_width])
pad_index = paddle.concat([pad_index_h, pad_index_w])
elif format == "BHWC":
pad_index = (0, 0, 0, pading_height, 0, pading_width, 0, 0)
data_format = "NHWC"
zero_pad = paddle.to_tensor([0], dtype=paddle.int32)
pad_index_h = paddle.concat([zero_pad, pading_height])
pad_index_w = paddle.concat([zero_pad, pading_width])
pad_index = paddle.concat([pad_index_h, pad_index_w])
else:
raise ValueError("invalid format")

return F.pad(pixel_values, pad_index), pad_index
return F.pad(pixel_values, pad_index, data_format=data_format), pad_index


def window_partition(x, window_size):
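The tuple pad_index above had to become a tensor because, under static graph, pading_height and pading_width are themselves tensors derived from paddle.shape, and a Python tuple cannot hold them for F.pad. F.pad does accept a Tensor pad argument, so the sizes are concatenated instead; note they must be 1-D for paddle.concat. A self-contained sketch of the same idea (pad_to_multiple and the patch size 7 are illustrative, not names from this file):

import paddle
import paddle.nn.functional as F

def pad_to_multiple(x, multiple=7):
    # x: [N, C, H, W]; the pad amounts are int32 tensors, reshaped to 1-D
    # so paddle.concat can assemble them into a single pad tensor.
    shape = paddle.shape(x)
    pad_h = ((multiple - shape[2] % multiple) % multiple).reshape([1])
    pad_w = ((multiple - shape[3] % multiple) % multiple).reshape([1])
    zero = paddle.zeros([1], dtype=paddle.int32)
    # The ordering must follow F.pad's convention for the chosen data_format.
    pad = paddle.concat([zero, pad_w, zero, pad_h])
    return F.pad(x, pad, data_format="NCHW")

print(pad_to_multiple(paddle.rand([1, 3, 30, 30])).shape)  # [1, 3, 35, 35]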
@@ -442,19 +447,19 @@ def get_attn_mask(self, height, width, dtype):
return attn_mask

def forward(self, x, input_dimensions):
H, W = input_dimensions
B, L, C = x.shape
B, H, W = input_dimensions
_, L, C = paddle.shape(x)

shortcut = x
x = self.norm1(x)
x = x.reshape([B, H, W, C])

x = paddle.reshape(x, [B, H, W, C])
x, pad_values = pading_for_not_divisible(x, H, W, self.window_size,
"BHWC")
_, height_pad, width_pad, _ = x.shape
_, height_pad, width_pad, _ = paddle.shape(x)

padding_state = pad_values[3] > 0 or pad_values[5] > 0  # change variable name
padding_state = pad_values[1] > 0 or pad_values[3] > 0  # change variable name
# cyclic shift
if self.shift_size > 0:
shifted_x = RollWrapper.roll(
@@ -493,7 +498,7 @@ def forward(self, x, input_dimensions):

if padding_state:
x = x[:, :H, :W, :]
x = x.reshape([B, H * W, C])
x = paddle.reshape(x, [B, H * W, C])

# FFN
x = shortcut + self.drop_path(x)
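The crop x = x[:, :H, :W, :] followed by the reshape also survives to_static because Paddle allows tensor-valued slice bounds and tensor entries in reshape targets. A small sketch of that crop-and-flatten step (pad_crop_flatten is an illustrative name, not part of the PR):

import paddle

@paddle.jit.to_static
def pad_crop_flatten(x, h, w):
    # x: [B, H_pad, W_pad, C]; h and w are tensors holding the unpadded size.
    b = paddle.shape(x)[0]
    c = paddle.shape(x)[3]
    x = x[:, :h, :w, :]  # tensor-valued slice bounds crop away the padding
    return paddle.reshape(x, [b, h * w, c])

out = pad_crop_flatten(
    paddle.rand([2, 35, 35, 96]), paddle.to_tensor(30), paddle.to_tensor(30))
print(out.shape)  # [2, 900, 96]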
Expand Down Expand Up @@ -541,9 +546,10 @@ def forward(self, x, input_dimensions):
"""
x: B, H*W, C
"""
H, W = input_dimensions
B, L, C = x.shape
x = x.reshape((B, H, W, C))
B, H, W = input_dimensions
B_, L, C = paddle.shape(x)[0], paddle.shape(x)[1], paddle.shape(x)[2]
x = paddle.reshape(x, [B, H, W, C])
x, _ = pading_for_not_divisible(x, H, W, 2, "BHWC", function="merge")

x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
@@ -555,7 +561,7 @@
# x = x.reshape([B, H // 2, 2, W // 2, 2, C])
# x = x.transpose((0, 1, 3, 4, 2, 5))

x = x.reshape([B, -1, 4 * C]) # B H/2*W/2 4*C
x = paddle.reshape(x, [B_, -1, self.dim * 4])  # B H/2*W/2 4*C

x = self.norm(x)
x = self.reduction(x)
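In the merge above, the target shape [B_, -1, self.dim * 4] leaves -1 as the only unknown dimension: the channel count self.dim is fixed at construction time, while H and W are tensors, which is exactly the form a static-graph reshape can handle. A self-contained sketch of the merge step (merge_patches is an illustrative name):

import paddle

def merge_patches(x, dim):
    # x: [B, H, W, C] with H and W already padded to be even.
    x0 = x[:, 0::2, 0::2, :]
    x1 = x[:, 1::2, 0::2, :]
    x2 = x[:, 0::2, 1::2, :]
    x3 = x[:, 1::2, 1::2, :]
    merged = paddle.concat([x0, x1, x2, x3], axis=-1)  # [B, H/2, W/2, 4C]
    b = paddle.shape(x)[0]
    # dim is a Python int known at build time, so -1 is the only unknown.
    return paddle.reshape(merged, [b, -1, dim * 4])

print(merge_patches(paddle.rand([2, 8, 8, 16]), 16).shape)  # [2, 16, 64]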
Expand All @@ -567,7 +573,7 @@ def extra_repr(self):
self.dim)

def flops(self):
H, W = self.input_resolution
B, H, W = self.input_resolution
flops = H * W * self.dim
flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim
return flops
@@ -641,13 +647,13 @@ def __init__(self,
self.downsample = None

def forward(self, x, input_dimensions):
H, W = input_dimensions
B, H, W = input_dimensions
for blk in self.blocks:
x = blk(x, input_dimensions)
if self.downsample is not None:
H, W = (H + 1) // 2, (W + 1) // 2
x = self.downsample(x, input_dimensions)
return x, (H, W)
return x, (B, H, W)

def extra_repr(self):
return "dim={}, input_resolution={}, depth={}".format(
@@ -701,11 +707,11 @@ def __init__(self,
self.norm = None

def forward(self, x):
B, C, H, W = x.shape
B, C, H, W = paddle.shape(x)
x, _ = pading_for_not_divisible(x, H, W, self.patch_size, "BCHW")
x = self.proj(x)
_, _, height, width = x.shape
output_dimensions = (height, width)
B_, _, height, width = paddle.shape(x)
output_dimensions = (B_, height, width)
x = x.flatten(2).transpose([0, 2, 1]) # B Ph*Pw C
if self.norm is not None:
x = self.norm(x)
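With all shape manipulation expressed through paddle.shape and paddle.reshape, the network can be converted and exported with variable input sizes. A hedged usage sketch with a stand-in layer (in practice the SwinTransformer instance built from this file would take its place):

import paddle
from paddle.static import InputSpec

# Stand-in for the real model; substitute the SwinTransformer instance.
net = paddle.nn.Conv2D(3, 96, kernel_size=4, stride=4)

# None marks batch, height, and width as dynamic in the exported graph.
static_net = paddle.jit.to_static(
    net, input_spec=[InputSpec(shape=[None, 3, None, None], dtype="float32")])
paddle.jit.save(static_net, "inference/model")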