
BUG in Transformer2D #48

@lordsoffallen

Description


Hi there,

First of all, thanks for putting out the code, great work! I noticed one thing in the code.

In the transformer_2d code, the forward method is updated to add prompt/width/height information:

def forward(
    self,
    hidden_states: torch.Tensor,
    encoder_hidden_states: Optional[torch.Tensor] = None,
    timestep: Optional[torch.LongTensor] = None,
    class_labels: Optional[torch.LongTensor] = None,
    cross_attention_kwargs: Dict[str, Any] = None,
    attention_mask: Optional[torch.Tensor] = None,
    encoder_attention_mask: Optional[torch.Tensor] = None,
    return_dict: bool = True,
    prompt_book_info: list = None,
    layout_mask=None,
    height=None,
    width=None,
):

These parameters are then passed through to the following attention-block call:

hidden_states, cross_attn_prob = block(
    hidden_states,
    attention_mask=attention_mask,
    encoder_hidden_states=encoder_hidden_states,
    encoder_attention_mask=encoder_attention_mask,
    timestep=timestep,
    cross_attention_kwargs=cross_attention_kwargs,
    class_labels=class_labels,
    prompt_book_info=prompt_book_info,
    layout_mask=layout_mask,
    height=height,
    width=width,
)

However, before that call is reached, the following code gets executed:

if self.is_input_continuous:
    batch, _, height, width = hidden_states.shape
    residual = hidden_states

    hidden_states = self.norm(hidden_states)
    if not self.use_linear_projection:
        hidden_states = self.proj_in(hidden_states, lora_scale)
        inner_dim = hidden_states.shape[1]
        hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
    else:
        inner_dim = hidden_states.shape[1]
        hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
        hidden_states = self.proj_in(hidden_states, scale=lora_scale)
elif self.is_input_vectorized:
    hidden_states = self.latent_image_embedding(hidden_states)
elif self.is_input_patches:
    hidden_states = self.pos_embed(hidden_states)

Since we pass continuous input to this code, the height and width arguments are overwritten by the tensor's spatial dimensions before the block call above. I was wondering if this is a bug? If width and height are meant to be inferred from the tensor dimensions, perhaps we don't need to pass them here at all?
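For illustration, here is a minimal sketch of how the shadowing could be avoided, assuming the intent is to keep the caller-supplied height/width for the block call and use the latent's own spatial size only for the reshape. This is just my sketch, not the repository's actual fix, and latent_height/latent_width are names I made up:

if self.is_input_continuous:
    # Unpack into distinct local names so the forward() arguments
    # `height` and `width` are not overwritten.
    batch, _, latent_height, latent_width = hidden_states.shape
    residual = hidden_states

    hidden_states = self.norm(hidden_states)
    if not self.use_linear_projection:
        hidden_states = self.proj_in(hidden_states, lora_scale)
        inner_dim = hidden_states.shape[1]
        # Reshape with the latent's own spatial size, leaving the
        # caller-supplied height/width untouched for the block call.
        hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, latent_height * latent_width, inner_dim)

Alternatively, if the values are always meant to come from the tensor shape, the height and width arguments could simply be dropped from the signature and the block call.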
