4 changes: 4 additions & 0 deletions __init__.py
@@ -45,6 +45,7 @@ def clear_modules():
from .operators.dream_texture import DreamTexture, kill_generator
from .operators.upscale import upscale_options
from .property_groups.dream_prompt import DreamPrompt
from .property_groups.object_prompt import ObjectPrompt
from .preferences import StableDiffusionPreferences
from .ui.presets import register_default_presets

@@ -93,6 +94,9 @@ def get_selection_preview(self):

bpy.types.Scene.dream_textures_project_prompt = PointerProperty(type=DreamPrompt)
bpy.types.Scene.dream_textures_project_framebuffer_arguments = EnumProperty(name="Inputs", items=framebuffer_arguments)
bpy.types.Scene.dream_textures_project_use_object_prompts = BoolProperty(name="Use Per-Object Prompts", default=False, description="Specify a separate prompt for each object in the 'Object Properties' panel")

bpy.types.Object.dream_textures_prompt = PointerProperty(type=ObjectPrompt)

for cls in CLASSES:
bpy.utils.register_class(cls)
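The new `property_groups/object_prompt.py` file is not part of this excerpt. A minimal sketch of what the `ObjectPrompt` group registered above might contain (the property names are assumptions, not confirmed by the diff):

```python
# Hypothetical sketch only -- property_groups/object_prompt.py is not in this excerpt,
# and the property names here are assumptions.
import bpy
from bpy.props import BoolProperty, StringProperty

class ObjectPrompt(bpy.types.PropertyGroup):
    use_prompt: BoolProperty(name="Use Prompt", default=False)
    prompt: StringProperty(name="Prompt", description="Prompt applied to this object during projection")
```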
5 changes: 4 additions & 1 deletion classes.py
@@ -6,7 +6,8 @@
from .operators.upscale import Upscale
from .operators.project import ProjectDreamTexture, dream_texture_projection_panels
from .property_groups.dream_prompt import DreamPrompt
from .ui.panels import dream_texture, history, upscaling, render_properties
from .property_groups.object_prompt import ObjectPrompt
from .ui.panels import dream_texture, history, upscaling, render_properties, object_prompt
from .preferences import OpenHuggingFace, OpenContributors, StableDiffusionPreferences, OpenDreamStudio, ImportWeights, Model, DeleteSelectedWeights, ModelSearch, InstallModel, PREFERENCES_UL_ModelList

from .ui.presets import DREAM_PT_AdvancedPresets, DREAM_MT_AdvancedPresets, AddAdvancedPreset, RestoreDefaultPresets
@@ -37,6 +38,7 @@
*upscaling.upscaling_panels(),
*history.history_panels(),
*dream_texture_projection_panels(),
object_prompt.ObjectPromptPanel,

dream_texture.OpenClipSegDownload,
dream_texture.OpenClipSegWeightsDirectory,
@@ -49,6 +51,7 @@
DeleteSelectedWeights,
Model,
DreamPrompt,
ObjectPrompt,
InstallDependencies,
OpenHuggingFace,
ImportWeights,
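Likewise, `ui/panels/object_prompt.py`, which provides the `ObjectPromptPanel` registered above, is outside this excerpt; a hedged sketch under the same assumptions:

```python
# Hypothetical sketch only -- ui/panels/object_prompt.py is not in this excerpt.
import bpy

class ObjectPromptPanel(bpy.types.Panel):
    """Panel in Object Properties exposing the per-object prompt."""
    bl_label = "Dream Textures"                    # assumed label
    bl_idname = "OBJECT_PT_dream_textures_prompt"  # assumed idname
    bl_space_type = 'PROPERTIES'
    bl_region_type = 'WINDOW'
    bl_context = "object"  # matches the "Object Properties" hint in the BoolProperty description

    def draw(self, context):
        # Property name assumed to match the ObjectPrompt sketch above.
        self.layout.prop(context.object.dream_textures_prompt, "prompt")
```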
159 changes: 155 additions & 4 deletions generator_process/actions/depth_to_image.py
@@ -1,9 +1,10 @@
from typing import Union, Generator, Callable, List, Optional
from typing import Union, Generator, Callable, List, Optional, Dict, Tuple
import os
from contextlib import nullcontext
from numpy.typing import NDArray
import numpy as np
import random
import math
from .prompt_to_image import Pipeline, Scheduler, Optimizations, StepPreviewMode, approximate_decoded_latents, ImageGenerationResult

def depth_to_image(
@@ -18,6 +19,8 @@ def depth_to_image(

depth: NDArray | str,
image: NDArray | str | None,
segmentation_map: NDArray | str | None,
segmentation_prompts: Dict[int, str],
strength: float,
prompt: str,
steps: int,
@@ -48,6 +51,14 @@ def prepare_depth(depth):
depth = depth[None, None]
depth = torch.from_numpy(depth)
return depth

def _img_importance_flatten(img: torch.Tensor, ratio: int) -> torch.Tensor:
return torch.nn.functional.interpolate(
img.unsqueeze(0).unsqueeze(1),
scale_factor=1 / ratio,
mode="bilinear",
align_corners=True,
).squeeze()
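`_img_importance_flatten` simply downsamples a 2D mask bilinearly by `ratio`, matching it to one of the UNet's attention resolutions. For example (shapes assume a 512×512 mask, which is illustrative, not from the diff):

```python
# Illustrative usage: a 512x512 mask at ratio 8 becomes the 64x64 map
# used at the highest-resolution attention layer.
import torch
mask = torch.rand(512, 512)
print(_img_importance_flatten(mask, 8).shape)  # torch.Size([64, 64])
```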

class GeneratorPipeline(diffusers.StableDiffusionInpaintPipeline):
def prepare_depth_latents(
@@ -106,12 +117,68 @@ def get_timesteps(self, num_inference_steps, strength, device):

return timesteps, num_inference_steps - t_start

#region Segmentation
def _segmentation_encode_prompts(
self,
segmentation_map: PIL.Image.Image,
segmentation_prompts: Dict[int, str]
):
tokenized_segments = []

for segment_color, segment_prompt in segmentation_prompts.items():
f = 1.5 # strength multiplier applied to each segment's mask
v_input = self.tokenizer(
segment_prompt,
max_length=self.tokenizer.model_max_length,
truncation=True,
)
v_as_tokens = v_input["input_ids"][1:-1]

img_where_color = (np.array(segmentation_map) == segment_color)

img_where_color = torch.tensor(img_where_color, dtype=torch.float32) * f

tokenized_segments.append((v_as_tokens, img_where_color))

if len(tokenized_segments) == 0:
tokenized_segments.append(([-1], torch.zeros(segmentation_map.size[:2], dtype=torch.float32)))
return tokenized_segments
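For reference, `segmentation_prompts` maps a gray value in the segmentation map to that object's prompt; each entry becomes a `(token_ids, mask)` pair, with the mask scaled by the fixed factor `f = 1.5`. A hypothetical input (the colors and prompts are illustrative):

```python
# Illustrative only: gray value in the segmentation map -> per-object prompt.
segmentation_prompts = {
    64: "a weathered bronze statue",
    192: "a marble pedestal",
}
```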

def _segmentation_tokens_weight(
self,
encoded_prompts,
tokenized,
ratio
):
token_lis = tokenized["input_ids"][0].tolist()
w, h = encoded_prompts[0][1].shape

w_r, h_r = w // ratio, h // ratio

weights = torch.zeros((w_r * h_r, len(token_lis)), dtype=torch.float32)

for v_as_tokens, img_where_color in encoded_prompts:
is_in = 0

for i in range(len(token_lis)):
if token_lis[i : i + len(v_as_tokens)] == v_as_tokens:
is_in = 1
weights[:, i : i + len(v_as_tokens)] += _img_importance_flatten(img_where_color, ratio).reshape(-1, 1).repeat(1, len(v_as_tokens))

if is_in != 1:
    print(f"Warning: tokens {v_as_tokens} not found in prompt at ratio {ratio}")

return weights
#endregion
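`_segmentation_tokens_weight` returns one row per flattened latent position and one column per token of the padded prompt; the columns covered by a segment's token run receive that segment's downsampled mask. For a 512×512 map and CLIP's usual 77-token padding (`pipe` stands in for the `GeneratorPipeline` instance; an illustration, not code from the diff):

```python
# Shape sketch, assuming a 512x512 segmentation map and 77-token padding.
weights = pipe._segmentation_tokens_weight(segmented_embeddings, text_inputs, 8)
print(weights.shape)  # torch.Size([4096, 77]) -> 64*64 latent positions x 77 tokens
```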

@torch.no_grad()
def __call__(
self,
prompt: Union[str, List[str]],
depth_image: Union[torch.FloatTensor, PIL.Image.Image],
image: Optional[Union[torch.FloatTensor, PIL.Image.Image]] = None,
segmentation_map: Optional[Union[torch.FloatTensor, PIL.Image.Image]] = None,
segmentation_prompts: Optional[Dict[int, str]] = None,
strength: float = 0.8,
height: Optional[int] = None,
width: Optional[int] = None,
@@ -128,7 +195,8 @@ def __call__(
callback_steps: Optional[int] = 1,
**kwargs,
):

_configure_paint_with_words_attention(self.unet)

# 0. Default height and width to unet
height = height or self.unet.config.sample_size * self.vae_scale_factor
width = width or self.unet.config.sample_size * self.vae_scale_factor
@@ -149,6 +217,17 @@ def __call__(
prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt
)

# 3b. Segmentation (skipped when no segmentation map was provided)
if segmentation_map is not None:
    text_inputs = self.tokenizer(
        prompt,
        padding="max_length",
        max_length=self.tokenizer.model_max_length,
        truncation=True,
        return_tensors="pt",
    )
    segmented_embeddings = self._segmentation_encode_prompts(segmentation_map, segmentation_prompts)
    # One weight map per attention resolution, ordered [weight_64, weight_256, weight_1024, weight_4096]
    # so that indexing by `attention_weight_index` in the patched attention forward lines up.
    cross_attention_weights = [self._segmentation_tokens_weight(segmented_embeddings, text_inputs, 8 * (2 ** ratio)).to(device) for ratio in reversed(range(4))]
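The four ratios line up with the UNet's cross-attention map sizes. For a 512×512 render (64×64 latents), the correspondence, worked out from `attention_weight_index` below, is:

```python
# ratio -> weight map -> flattened attention size -> attention_weight_index
#   64  ->   8x8  ->   64 -> 0
#   32  ->  16x16 ->  256 -> 1
#   16  ->  32x32 -> 1024 -> 2
#    8  ->  64x64 -> 4096 -> 3
```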

# 4. Prepare the depth image
depth = prepare_depth(depth_image)

@@ -226,8 +305,14 @@ def __call__(
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
latent_model_input = torch.cat([latent_model_input, depth], dim=1)

# NOTE: Segmentation passed through encoder_hidden_states
if segmentation_map is not None:
encoder_hidden_states = [text_embeddings, cross_attention_weights, self.scheduler.sigmas[i]]
else:
encoder_hidden_states = text_embeddings

# predict the noise residual
noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=encoder_hidden_states).sample

# perform guidance
if do_classifier_free_guidance:
@@ -333,10 +418,14 @@ def __call__(
(torch.autocast(device) if optimizations.can_use("amp", device) else nullcontext()):
depth_image = PIL.ImageOps.flip(PIL.Image.fromarray(np.uint8(depth * 255), 'L')).resize(rounded_size)
init_image = None if image is None else (PIL.Image.open(image) if isinstance(image, str) else PIL.Image.fromarray(image.astype(np.uint8))).convert('RGB').resize(rounded_size)
if segmentation_map is not None:
segmentation_map = PIL.ImageOps.flip(PIL.Image.fromarray(np.uint8(segmentation_map * 255), 'L')).resize(rounded_size)
yield from pipe(
prompt=prompt,
depth_image=depth_image,
image=init_image,
segmentation_map=segmentation_map,
segmentation_prompts=segmentation_prompts,
strength=strength,
width=rounded_size[0],
height=rounded_size[1],
@@ -357,4 +446,66 @@ def __call__(
import stability_sdk
raise NotImplementedError()
case _:
raise Exception(f"Unsupported pipeline {pipeline}.")
raise Exception(f"Unsupported pipeline {pipeline}.")

def _configure_paint_with_words_attention(unet):
import torch

def forward(
self,
hidden_states,
context=None,
mask=None
):
"""Paint with words attention based on https://github.com/cloneofsimo/paint-with-words-sd

`context` is expected to be in the following format:
```python
[text_embeddings, [weight_64, weight_256, weight_1024, weight_4096], sigma]
```
"""

# `context` is a plain tensor for normal cross-attention, or the
# [text_embeddings, weight_maps, sigma] list documented above.
if isinstance(context, list):
    context_tensor = context[0]
elif context is not None:
    context_tensor = context
else:
    context_tensor = hidden_states

query = self.to_q(hidden_states)

key = self.to_k(context_tensor)
value = self.to_v(context_tensor)

query = self.reshape_heads_to_batch_dim(query)
key = self.reshape_heads_to_batch_dim(key)
value = self.reshape_heads_to_batch_dim(value)

attention_scores = torch.matmul(query, key.transpose(-1, -2))

attention_size_of_img = attention_scores.shape[-2]
attention_weight_index = int(math.log2(attention_size_of_img // 64) // 2) # 64 -> 0, 256 -> 1, 1024 -> 2, 4096 -> 3
if isinstance(context, list):
w = context[1][attention_weight_index]
sigma = context[2]

cross_attention_weight = 0.1 * w * math.log(sigma + 1) * attention_scores.max()
else:
cross_attention_weight = 0.0

attention_scores = (attention_scores + cross_attention_weight) * self.scale

attention_probs = attention_scores.softmax(dim=-1)

hidden_states = torch.matmul(attention_probs, value)

hidden_states = self.reshape_batch_dim_to_heads(hidden_states)

# linear proj
hidden_states = self.to_out[0](hidden_states)
# dropout
hidden_states = self.to_out[1](hidden_states)

return hidden_states

for _module in unet.modules():
if _module.__class__.__name__ == "CrossAttention":
_module.__class__.__call__ = forward
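On the Blender side, something still has to assemble `segmentation_prompts` from the scene. A hedged sketch of that plumbing (the object-id scheme and the helper itself are assumptions, not shown in this excerpt):

```python
# Hypothetical helper: collect one prompt per opted-in object, keyed by the
# integer id assumed to be baked into the segmentation map (0 = background).
import bpy

def collect_segmentation_prompts(scene: bpy.types.Scene) -> dict:
    prompts = {}
    for i, obj in enumerate(scene.objects):
        object_prompt = getattr(obj, "dream_textures_prompt", None)
        if object_prompt is not None and object_prompt.prompt:
            prompts[i + 1] = object_prompt.prompt
    return prompts
```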