Description
Hello,
I've been trying to do something similar to style transfer from a reference image while preserving the structure of the input image. I wrote the following code to run it via the CLI on a machine that isn't compatible with the Gradio implementation you use, but I get a garbage output image.
def main(args):
    # Initialize model
    print("Loading model...")
    model = create_model('configs/pair_diff.yaml').cpu()
    model.load_state_dict(load_state_dict('checkpoints/model_e91.ckpt', location='cuda'))
    model = model.cuda()
    sampler = DDIMSamplerSpaCFG(model)

    # Initialize ImageComp with edit_app mode
    comp = ImageComp(edit_operation='edit_app')

    # Load and prepare images
    input_img, ref_img = load_images(args.input_image, args.reference_image)
    assert input_img.shape == ref_img.shape, f"Input and reference images must match; got {input_img.shape} vs {ref_img.shape}"
    input_img = comp.init_input_canvas(input_img)
    ref_img = comp.init_ref_canvas(ref_img)

    # Set input and reference image on comp
    comp.input_image = input_img
    comp.ref_image = ref_img

    comp.init_segmentation_model(mask_model='Oneformer', segment_model='Oneformer')
    comp.annotator = comp.segment_model

    # Run segmentation for the input image
    comp.input_segmask = comp.annotator.segment(input_img)
    print(f"Input segmentation mask unique classes: {torch.unique(comp.input_segmask)}")

    # Run segmentation for the reference image
    comp.ref_segmask = comp.annotator.segment(ref_img)
    print(f"Reference segmentation mask unique classes: {torch.unique(comp.ref_segmask)}")

    # Select object masks by class index
    input_mask = comp.select_input_object(idx=0)
    ref_mask = comp.select_ref_object(idx=0)
    whole_ref = True

    # Run the editing process
    print("Running editing process...")
    output = comp.process(input_mask=input_mask, ref_mask=ref_mask, prompt=args.prompt,
                          a_prompt=args.a_prompt, n_prompt=args.n_prompt, num_samples=args.num_samples,
                          ddim_steps=args.ddim_steps, guess_mode=args.guess_mode, strength=args.strength,
                          scale_s=args.scale_s, scale_f=args.scale_f, scale_t=args.scale_t, seed=args.seed,
                          eta=args.eta, dilation_fac=args.dilation_fac, masking=not args.no_masking,
                          whole_ref=whole_ref, inter=args.inter, free_form_obj_var=args.free_form_obj_var,
                          ignore_structure=args.ignore_structure)

    # Save the output
    comp.save_result(input_mask=None, prompt=None, a_prompt=None, n_prompt=None,
                     ddim_steps=args.ddim_steps, scale_s=args.scale_s, scale_f=args.scale_f, scale_t=args.scale_t,
                     seed=args.seed, dilation_fac=args.dilation_fac, free_form_obj_var=args.free_form_obj_var,
                     ignore_structure=args.ignore_structure)
I also made the following changes to the ImageComp class's methods; before these changes I ran into numerous issues with None objects and dimension mismatches.
def _edit_app(self, whole_ref=False):
    # Ensure input_mask is available
    if self.input_mask is None:
        print("[INFO] input_mask is None, generating default full-image mask.")
        if self.input_image.ndim == 3 and self.input_image.shape[2] == 3:  # HWC
            H, W = self.input_image.shape[:2]
        else:  # CHW format
            _, H, W = self.input_image.shape
        self.input_mask = torch.ones((H, W), dtype=torch.float32).cuda()
    # ... <same content as in the original code>
def _edit(self, input_mask, ref_mask, dilation_fac=1, whole_ref=False, inter=1):
    # ... <same content as in the original code>
    if isinstance(mean_feat_inpt_conc, list):
        appearance_conc = []
        for i in range(len(mean_feat_inpt_conc)):
            feat_inpt = mean_feat_inpt_conc[i]
            feat_ref = mean_feat_ref_conc[i]
            region_count = feat_inpt.shape[1]
            target_idx = ma + 1 if (ma + 1) < region_count else 0
            feat_inpt[:, target_idx] = (1 - inter) * feat_inpt[:, target_idx] + inter * feat_ref[:, 0]
            splatted_feat_conc = torch.einsum('nmc, nmhw->nchw', feat_inpt, one_hot_inpt_conc)
            splatted_feat_conc = torch.nn.functional.normalize(splatted_feat_conc)
            splatted_feat_conc = torch.nn.functional.interpolate(splatted_feat_conc, (self.H // 8, self.W // 8))
            appearance_conc.append(splatted_feat_conc)
        appearance_conc = torch.cat(appearance_conc, dim=1)
    else:
        region_count = mean_feat_inpt_conc.shape[1]
        target_idx = ma + 1 if (ma + 1) < region_count else 0
        mean_feat_inpt_conc[:, target_idx] = (1 - inter) * mean_feat_inpt_conc[:, target_idx] + inter * mean_feat_ref_conc[:, 0]
        splatted_feat_conc = ...  # <same content as in the original code>

    # ... <same content as in the original code>

    if isinstance(mean_feat_inpt_ca, list):
        appearance_ca = []
        for i in range(len(mean_feat_inpt_ca)):
            feat_inpt = mean_feat_inpt_ca[i]
            feat_ref = mean_feat_ref_ca[i]
            region_count = feat_inpt.shape[1]
            target_idx = ma + 1 if (ma + 1) < region_count else 0
            feat_inpt[:, target_idx] = (1 - inter) * feat_inpt[:, target_idx] + inter * feat_ref[:, 0]
            splatted_feat_ca = torch.einsum('nmc, nmhw->nchw', feat_inpt, one_hot_inpt_ca)
            splatted_feat_ca = torch.nn.functional.normalize(splatted_feat_ca)
            splatted_feat_ca = torch.nn.functional.interpolate(splatted_feat_ca, (self.H // 8, self.W // 8))
            appearance_ca.append(splatted_feat_ca)
        appearance_ca = torch.cat(appearance_ca, dim=1)
    else:
        region_count = mean_feat_inpt_ca.shape[1]
        target_idx = ma + 1 if (ma + 1) < region_count else 0
        mean_feat_inpt_ca[:, target_idx] = (1 - inter) * mean_feat_inpt_ca[:, target_idx] + inter * mean_feat_ref_ca[:, 0]
        splatted_feat_ca = ...  # <same content as in the original code>
    # ... <same content as in the original code>
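To make explicit what the interpolation step does, here is the same blend isolated on toy tensors (all shapes and values below are made up; the real features come from the appearance encoder): with inter=0 the selected input region keeps its own mean appearance feature, and with inter=1 it is fully replaced by the reference's first region feature before being splatted back onto the spatial grid.

import torch

# Toy shapes: n=1 image, m=4 regions, c=8 feature channels, 16x16 spatial map
n, m, c, h, w = 1, 4, 8, 16, 16
mean_feat_inpt = torch.randn(n, m, c)   # per-region mean appearance features (input)
mean_feat_ref = torch.randn(n, m, c)    # per-region mean appearance features (reference)
one_hot_inpt = torch.rand(n, m, h, w)   # soft region assignment of the input image

ma = 1        # selected region index (made up, stands in for the value used in _edit)
inter = 0.5   # 0 -> keep input appearance, 1 -> take reference appearance

region_count = mean_feat_inpt.shape[1]
target_idx = ma + 1 if (ma + 1) < region_count else 0

# Linear blend of the selected region's mean feature with the reference's first region feature
mean_feat_inpt[:, target_idx] = (1 - inter) * mean_feat_inpt[:, target_idx] + inter * mean_feat_ref[:, 0]

# Splat per-region features back onto the spatial grid, normalize, downsample
splatted = torch.einsum('nmc, nmhw->nchw', mean_feat_inpt, one_hot_inpt)
splatted = torch.nn.functional.normalize(splatted)
splatted = torch.nn.functional.interpolate(splatted, (h // 8, w // 8))
print(splatted.shape)  # torch.Size([1, 8, 2, 2])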
I'm trying to understand why I still get garbage images. I have also tried different values for the interpolation parameter (inter in the code above), but still get no useful output. Please help.