Hi there,
First of all, thanks for putting out the code, great work! I have noticed one thing in the code.
In the transformer_2d code, the forward method was updated to accept prompt/width/height information:
AutoStudio/model/transformer_2d.py
Lines 213 to 227 in 9c9820f
```python
def forward(
    self,
    hidden_states: torch.Tensor,
    encoder_hidden_states: Optional[torch.Tensor] = None,
    timestep: Optional[torch.LongTensor] = None,
    class_labels: Optional[torch.LongTensor] = None,
    cross_attention_kwargs: Dict[str, Any] = None,
    attention_mask: Optional[torch.Tensor] = None,
    encoder_attention_mask: Optional[torch.Tensor] = None,
    return_dict: bool = True,
    prompt_book_info: list = None,
    layout_mask=None,
    height=None,
    width=None,
):
```
which is then passed on to the following attention block call:
AutoStudio/model/transformer_2d.py
Lines 319 to 331 in 9c9820f
```python
hidden_states, cross_attn_prob = block(
    hidden_states,
    attention_mask=attention_mask,
    encoder_hidden_states=encoder_hidden_states,
    encoder_attention_mask=encoder_attention_mask,
    timestep=timestep,
    cross_attention_kwargs=cross_attention_kwargs,
    class_labels=class_labels,
    prompt_book_info=prompt_book_info,
    layout_mask=layout_mask,
    height=height,
    width=width,
)
```
However, before that call, the following code gets executed:
AutoStudio/model/transformer_2d.py
Lines 285 to 302 in 9c9820f
```python
if self.is_input_continuous:
    batch, _, height, width = hidden_states.shape
    residual = hidden_states

    hidden_states = self.norm(hidden_states)
    if not self.use_linear_projection:
        hidden_states = self.proj_in(hidden_states, lora_scale)
        inner_dim = hidden_states.shape[1]
        hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
    else:
        inner_dim = hidden_states.shape[1]
        hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
        hidden_states = self.proj_in(hidden_states, scale=lora_scale)
elif self.is_input_vectorized:
    hidden_states = self.latent_image_embedding(hidden_states)
elif self.is_input_patches:
    hidden_states = self.pos_embed(hidden_states)
```
Since we pass continuous input to this code, the height and width parameters get overwritten by the tensor dimensions. I was wondering if this is a bug? If height and width are meant to be inferred from the tensor dimensions, perhaps we don't need to pass them here at all?
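To make the point concrete, here is a minimal, self-contained sketch of what I mean. `forward_sketch`, `forward_sketch_guarded`, and the explicit 128 values are hypothetical stand-ins for illustration, not code from the repo; the guarded version is just one possible alternative if the caller-supplied values are meant to take precedence:

```python
import torch

def forward_sketch(hidden_states: torch.Tensor, height=None, width=None):
    # Stand-in for the continuous-input branch in transformer_2d.py:
    # unpacking the tensor shape rebinds height/width, so whatever the
    # caller passed in is silently discarded.
    batch, _, height, width = hidden_states.shape
    return height, width

def forward_sketch_guarded(hidden_states: torch.Tensor, height=None, width=None):
    # Possible alternative if the arguments are meant to win:
    # only fall back to the tensor dimensions when nothing was passed in.
    batch, _, h, w = hidden_states.shape
    height = height if height is not None else h
    width = width if width is not None else w
    return height, width

latents = torch.randn(1, 4, 64, 64)
print(forward_sketch(latents, height=128, width=128))          # (64, 64)  - arguments ignored
print(forward_sketch_guarded(latents, height=128, width=128))  # (128, 128) - arguments kept
```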