RuntimeError: Tensor on device meta is not on the expected device xla:0! #8568

Open

chaowenguo opened this issue Jan 14, 2025 · 0 comments

❓ Questions and Help

import torch, torch_xla, diffusers, imageio, easy_dwpose, builtins, PIL.Image, numpy

def process(index):
    openpose = easy_dwpose.DWposeDetector()
    with imageio.get_reader(f'pose{index}.mp4') as reader, imageio.get_writer(f'out{index}.mp4', fps=reader.get_meta_data().get('fps')) as writer:
        conditioning_frames = [PIL.Image.fromarray(reader.get_data(_)).resize((512, 960)) for _ in builtins.range(reader.count_frames())]
        controlnet = diffusers.ControlNetModel.from_pretrained('chaowenguo/control_v11p_sd15_openpose', torch_dtype=torch.bfloat16, variant='fp16', use_safetensors=True)
        motion_adapter = diffusers.MotionAdapter.from_pretrained('chaowenguo/AnimateLCM', torch_dtype=torch.bfloat16, variant='fp16', use_safetensors=True)
        vae = diffusers.AutoencoderKL.from_single_file('https://huggingface.co/chaowenguo/pal/blob/main/vae-ft-mse-840000-ema-pruned.safetensors', torch_dtype=torch.bfloat16, use_safetensors=True)
        pipeline = diffusers.AnimateDiffControlNetPipeline.from_single_file('https://huggingface.co/chaowenguo/pal/blob/main/chilloutMix-Ni.safetensors', config='chaowenguo/stable-diffusion-v1-5', safety_checker=None, controlnet=controlnet, use_safetensors=True, torch_dtype=torch.bfloat16, motion_adapter=motion_adapter, vae=vae).to(torch_xla.core.xla_model.xla_device())
        pipeline.scheduler = diffusers.LCMScheduler.from_config(pipeline.scheduler.config, beta_schedule='linear')
        pipeline.load_lora_weights('chaowenguo/AnimateLCM', weight_name='AnimateLCM_sd15_t2v_lora.safetensors', adapter_name='lcm-lora')
        pipeline.enable_vae_slicing()
        pipeline.enable_vae_tiling()
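        # the next call is the one that ends up triggering the meta-device error (see the traceback below)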
        pipeline.enable_sequential_cpu_offload(device=torch_xla.core.xla_model.xla_device())
        pipeline.enable_free_init()
        pipeline.enable_free_noise(4, 2)
        pipeline.enable_free_noise_split_inference(4, 2)
        pipeline.unet.enable_forward_chunking(2)
        for frame in pipeline(prompt='A gorgeous smiling slim young cleavage robust boob bare-armed japanese girl, beautiful face, hands with five fingers, light background, best quality, extremely detailed, HD, ultra-realistic, 8K, HQ, masterpiece, trending on artstation, art, smooth', num_frames=builtins.len(conditioning_frames), conditioning_frames=conditioning_frames, generator=torch.manual_seed(index), num_inference_steps=20, negative_prompt='monochrome, dark background, longbody, lowres, bad anatomy, bad hands, fused fingers, missing fingers, too many fingers, extra digit, fewer digits, cropped, worst quality, low quality, deformed body, bloated, ugly, unrealistic, extra hands and arms').frames[0]:
            writer.append_data(numpy.asarray(frame))

if __name__ == '__main__': torch_xla.launch(process)
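
For context, as far as I can tell enable_sequential_cpu_offload just walks the pipeline's components and hands each one to accelerate's cpu_offload, which parks the weights on the meta device and is supposed to materialize them on the execution device right before every forward call. Below is a minimal sketch of that mechanism pointed at an XLA device (a plain Linear stands in for the real sub-models; I have not checked whether this tiny case already reproduces the failure):

import torch
import torch_xla.core.xla_model as xm
from accelerate import cpu_offload

device = xm.xla_device()

# stand-in for the unet / vae / text encoder that the pipeline would offload
model = torch.nn.Linear(4, 4)

# attaches accelerate's AlignDevicesHook with offload=True: the parameters are
# kept on the meta device and should be copied onto `device` in pre_forward
cpu_offload(model, execution_device=device)

x = torch.randn(1, 4).to(device)
print(model(x).device)  # should print xla:0 if the hook restores the weights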

The call

pipeline.enable_sequential_cpu_offload(device=torch_xla.core.xla_model.xla_device())

causes:

File "/home/ubuntu/.local/lib/python3.10/site-packages/diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py", line 1070, in __call__
    noise_pred = self.unet(
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/accelerate/hooks.py", line 170, in new_forward
    output = module._old_forward(*args, **kwargs)
  File "/home/ubuntu/.local/lib/python3.10/site-packages/diffusers/models/unets/unet_motion_model.py", line 2179, in forward
    sample, res_samples = downsample_block(
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/ubuntu/.local/lib/python3.10/site-packages/diffusers/models/unets/unet_motion_model.py", line 546, in forward
    hidden_states = motion_module(
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/ubuntu/.local/lib/python3.10/site-packages/diffusers/models/unets/unet_motion_model.py", line 194, in forward
    hidden_states = block(
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/ubuntu/.local/lib/python3.10/site-packages/diffusers/pipelines/free_noise_utils.py", line 131, in forward
    intermediate_tensor_or_tensor_tuple = self.module(*args, **inputs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/ubuntu/.local/lib/python3.10/site-packages/diffusers/models/attention.py", line 1115, in forward
    norm_hidden_states = self.norm1(hidden_states_chunk)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/normalization.py", line 217, in forward
    return F.layer_norm(
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/functional.py", line 2900, in layer_norm
    return torch.layer_norm(
  File "/usr/local/lib/python3.10/dist-packages/torch/_decomp/__init__.py", line 88, in _fn
    return f(*args, **kwargs, out=None if is_none else out_kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_prims_common/wrappers.py", line 273, in _fn
    result = fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_refs/__init__.py", line 3249, in native_layer_norm
    out = out * weight + bias
  File "/usr/local/lib/python3.10/dist-packages/torch/_prims_common/wrappers.py", line 273, in _fn
    result = fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_prims_common/wrappers.py", line 141, in _fn
    result = fn(**bound.arguments)
  File "/usr/local/lib/python3.10/dist-packages/torch/_refs/__init__.py", line 1050, in _ref
    output = prim(a, b)
  File "/usr/local/lib/python3.10/dist-packages/torch/_refs/__init__.py", line 1657, in mul
    return prims.mul(a, b)
File "/usr/local/lib/python3.10/dist-packages/torch/_ops.py", line 716, in __call__
    return self._op(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_library/fake_impl.py", line 93, in meta_kernel
    return fake_impl_holder.kernel(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_library/utils.py", line 20, in __call__
    return self.func(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/library.py", line 1151, in inner
    return func(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_library/custom_ops.py", line 614, in fake_impl
    return self._abstract_fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_prims/__init__.py", line 402, in _prim_elementwise_meta
    utils.check_same_device(*args_, allow_cpu_scalar_tensors=True)
  File "/usr/local/lib/python3.10/dist-packages/torch/_prims_common/__init__.py", line 742, in check_same_device
    raise RuntimeError(msg)
RuntimeError: Tensor on device meta is not on the expected device xla:0!

How do I use pipeline.enable_sequential_cpu_offload(device=torch_xla.core.xla_model.xla_device()) correctly with an XLA device?
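
If sequential offload is simply not supported on XLA yet, my fallback would be to drop that call and keep the whole pipeline resident on the device, roughly as below (an untested sketch against the script above, assuming the meta tensors really do come from the offload hooks):

        device = torch_xla.core.xla_model.xla_device()
        pipeline = pipeline.to(device)  # keep every sub-model on xla:0
        pipeline.enable_vae_slicing()
        pipeline.enable_vae_tiling()
        # pipeline.enable_sequential_cpu_offload(device=device)  # skipped: installs the hooks that leave weights on meta
        pipeline.enable_free_init()
        pipeline.enable_free_noise(4, 2)

A quick way to confirm where the meta tensors come from would be to list them after the offload call, e.g. [name for name, p in pipeline.unet.named_parameters() if p.device.type == 'meta'].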
