
Garbage image output when using whole_ref=True #17

Open

@GShef

Description

Hello,

I've been trying to perform something similar to style transfer from a reference image while preserving the structure of the input image. I wrote the following code to run it via the CLI on a machine that isn't compatible with the Gradio implementation you used, but I get a garbage output image.

def main(args):
    # Initialize model
    print("Loading model...")
    model = create_model('configs/pair_diff.yaml').cpu()
    model.load_state_dict(load_state_dict('checkpoints/model_e91.ckpt', location='cuda'))
    model = model.cuda()
    sampler = DDIMSamplerSpaCFG(model)

    # Initialize ImageComp with edit_app mode
    comp = ImageComp(edit_operation='edit_app')

    # Load and prepare images
    input_img, ref_img = load_images(args.input_image, args.reference_image)
    assert input_img.shape == ref_img.shape, f"Input and reference images must match — got {input_img.shape} vs {ref_img.shape}"
    input_img = comp.init_input_canvas(input_img)
    ref_img = comp.init_ref_canvas(ref_img)

    # Set input and reference image into comp
    comp.input_image = input_img
    comp.ref_image = ref_img

    comp.init_segmentation_model(mask_model='Oneformer', segment_model='Oneformer')
    comp.annotator = comp.segment_model

    # Run segmentation for the input image
    comp.input_segmask = comp.annotator.segment(input_img)
    print(f"Input segmentation mask unique classes: {torch.unique(comp.input_segmask)}")
    # Run segmentation for the reference image
    comp.ref_segmask = comp.annotator.segment(ref_img)
    print(f"Reference segmentation mask unique classes: {torch.unique(comp.ref_segmask)}")
    # Select object masks by class index
    input_mask = comp.select_input_object(idx=0)
    ref_mask = comp.select_ref_object(idx=0)
    whole_ref = True

    # Run the editing process
    print("Running editing process...")
    output = comp.process(input_mask=input_mask, ref_mask=ref_mask, prompt=args.prompt,
        a_prompt=args.a_prompt, n_prompt=args.n_prompt, num_samples=args.num_samples,
        ddim_steps=args.ddim_steps, guess_mode=args.guess_mode, strength=args.strength,
        scale_s=args.scale_s, scale_f=args.scale_f, scale_t=args.scale_t, seed=args.seed,
        eta=args.eta, dilation_fac=args.dilation_fac, masking=not args.no_masking,
        whole_ref=whole_ref, inter=args.inter, free_form_obj_var=args.free_form_obj_var,
        ignore_structure=args.ignore_structure
    )

    # Save the output
    comp.save_result(input_mask=None, prompt=None, a_prompt=None, n_prompt=None,
        ddim_steps=args.ddim_steps, scale_s=args.scale_s, scale_f=args.scale_f, scale_t=args.scale_t,
        seed=args.seed, dilation_fac=args.dilation_fac, free_form_obj_var=args.free_form_obj_var,
        ignore_structure=args.ignore_structure
    )
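
For completeness, this is roughly the argument parser I drive main() with. The flag names mirror the args.* fields used above, but the types and defaults are my own placeholders rather than values taken from your repo.

import argparse

def build_parser():
    # Sketch of my CLI; flag names match the args.* fields above, defaults are placeholders.
    p = argparse.ArgumentParser(description='PAIR Diffusion appearance editing via CLI')
    p.add_argument('--input_image', required=True)
    p.add_argument('--reference_image', required=True)
    p.add_argument('--prompt', default='')
    p.add_argument('--a_prompt', default='')
    p.add_argument('--n_prompt', default='')
    p.add_argument('--num_samples', type=int, default=1)
    p.add_argument('--ddim_steps', type=int, default=50)
    p.add_argument('--guess_mode', action='store_true')
    p.add_argument('--strength', type=float, default=1.0)
    p.add_argument('--scale_s', type=float, default=5.0)
    p.add_argument('--scale_f', type=float, default=8.0)
    p.add_argument('--scale_t', type=float, default=10.0)
    p.add_argument('--seed', type=int, default=42)
    p.add_argument('--eta', type=float, default=0.0)
    p.add_argument('--dilation_fac', type=float, default=1.0)
    p.add_argument('--no_masking', action='store_true')
    p.add_argument('--inter', type=float, default=1.0)
    p.add_argument('--free_form_obj_var', type=float, default=0.0)
    p.add_argument('--ignore_structure', action='store_true')
    return p

if __name__ == '__main__':
    main(build_parser().parse_args())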

I also made the following changes to the ImageComp class's methods; before these changes, there were numerous issues related to None objects and dimension mismatches.

def _edit_app(self, whole_ref=False):
    # Ensure input_mask is available 
    if self.input_mask is None:
        print("[INFO] input_mask is None, generating default full-image mask.")
        if self.input_image.ndim == 3 and self.input_image.shape[2] == 3:  # HWC
            H, W = self.input_image.shape[:2]
        else:  # CHW format
            _, H, W = self.input_image.shape
        self.input_mask = torch.ones((H, W), dtype=torch.float32).cuda()
    ........ <Same content as in original code>
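
Since most of the earlier failures were dimension mismatches, I also sanity-check the layouts with a small helper like this before running segmentation (my own debugging code, not part of the repo):

def check_shapes(image, mask=None):
    # Debugging helper: infer (H, W) from an HWC or CHW image and verify that an
    # optional mask has matching spatial dimensions.
    if image.ndim != 3:
        raise ValueError(f"Expected a 3D image, got shape {tuple(image.shape)}")
    if image.shape[2] == 3:      # HWC layout
        h, w = image.shape[:2]
    elif image.shape[0] == 3:    # CHW layout
        h, w = image.shape[1:]
    else:
        raise ValueError(f"Cannot infer layout from shape {tuple(image.shape)}")
    if mask is not None and tuple(mask.shape[-2:]) != (h, w):
        raise ValueError(f"Mask shape {tuple(mask.shape)} does not match image ({h}, {w})")
    return h, w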

def _edit(self, input_mask, ref_mask, dilation_fac=1, whole_ref=False, inter=1):
    ...... <Same content as in original code>
    if isinstance(mean_feat_inpt_conc, list):
        appearance_conc = []
        for i in range(len(mean_feat_inpt_conc)):
            feat_inpt = mean_feat_inpt_conc[i]
            feat_ref = mean_feat_ref_conc[i]

            region_count = feat_inpt.shape[1]
            target_idx = ma + 1 if (ma + 1) < region_count else 0

            feat_inpt[:, target_idx] = (1 - inter) * feat_inpt[:, target_idx] + inter * feat_ref[:, 0]
            splatted_feat_conc = torch.einsum('nmc, nmhw->nchw', feat_inpt, one_hot_inpt_conc)
            splatted_feat_conc = torch.nn.functional.normalize(splatted_feat_conc)
            splatted_feat_conc = torch.nn.functional.interpolate(splatted_feat_conc, (self.H // 8, self.W // 8))
            appearance_conc.append(splatted_feat_conc)
        appearance_conc = torch.cat(appearance_conc, dim=1)
    else:
        region_count = mean_feat_inpt_conc.shape[1]
        target_idx = ma + 1 if (ma + 1) < region_count else 0
        mean_feat_inpt_conc[:, target_idx] = (1 - inter) * mean_feat_inpt_conc[:, target_idx] + inter * mean_feat_ref_conc[:, 0]
        splatted_feat_conc = ......
    ......... <Same content as in original code>

    if isinstance(mean_feat_inpt_ca, list):
        appearance_ca = []
        for i in range(len(mean_feat_inpt_ca)):
            feat_inpt = mean_feat_inpt_ca[i]
            feat_ref = mean_feat_ref_ca[i]

            region_count = feat_inpt.shape[1]
            target_idx = ma + 1 if (ma + 1) < region_count else 0

            feat_inpt[:, target_idx] = (1 - inter) * feat_inpt[:, target_idx] + inter * feat_ref[:, 0]
            splatted_feat_ca = torch.einsum('nmc, nmhw->nchw', feat_inpt, one_hot_inpt_ca)
            splatted_feat_ca = torch.nn.functional.normalize(splatted_feat_ca)
            splatted_feat_ca = torch.nn.functional.interpolate(splatted_feat_ca, (self.H // 8, self.W // 8))
            appearance_ca.append(splatted_feat_ca)
        appearance_ca = torch.cat(appearance_ca, dim=1)
    else:
        region_count = mean_feat_inpt_ca.shape[1]
        target_idx = ma + 1 if (ma + 1) < region_count else 0
        mean_feat_inpt_ca[:, target_idx] = (1 - inter) * mean_feat_inpt_ca[:, target_idx] + inter * mean_feat_ref_ca[:, 0]
        splatted_feat_ca = .....
    ........... <Same content as in original code>
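
For context, my understanding of the block above is: inter linearly interpolates the appearance vector of the selected input region towards the reference's first region (inter=0 keeps the input appearance, inter=1 replaces it entirely), and the einsum splats the per-region vectors back onto the spatial grid via the one-hot region masks. A stripped-down sketch of that logic, with shapes and names of my own choosing:

import torch
import torch.nn.functional as F

def blend_and_splat(mean_feat_inpt, mean_feat_ref, one_hot_inpt, target_idx, inter, out_hw):
    # mean_feat_inpt: (N, M, C) per-region mean appearance features of the input image
    # mean_feat_ref:  (N, M_ref, C) per-region mean appearance features of the reference
    # one_hot_inpt:   (N, M, H, W) one-hot region masks of the input image
    feats = mean_feat_inpt.clone()
    # Move the target region's appearance towards the reference's first region.
    feats[:, target_idx] = (1 - inter) * feats[:, target_idx] + inter * mean_feat_ref[:, 0]
    # Splat per-region vectors back to a dense (N, C, H, W) feature map.
    splatted = torch.einsum('nmc, nmhw->nchw', feats, one_hot_inpt)
    splatted = F.normalize(splatted)
    return F.interpolate(splatted, out_hw)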

I'm trying to understand why I still get garbage images. I have also tried different values for the interpolation parameter (inter in the code above), but still get no useful output. Please help.
