Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/diffusers/pipelines/pipeline_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@
is_accelerate_version,
is_hpu_available,
is_mlu_available,
is_torch_npu_available,
is_torch_version,
is_transformers_version,
logging,
Expand Down
1 change: 1 addition & 0 deletions src/diffusers/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@
is_google_colab,
is_hf_hub_version,
is_hpu_available,
is_inflect_available,
is_mlu_available,
is_invisible_watermark_available,
is_k_diffusion_available,
Expand Down
5 changes: 5 additions & 0 deletions src/diffusers/utils/import_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,7 @@ def _is_package_available(pkg_name: str, get_dist_name: bool = False) -> Tuple[b

_torch_xla_available, _torch_xla_version = _is_package_available("torch_xla")
_torch_npu_available, _torch_npu_version = _is_package_available("torch_npu")
_torch_mlu_available, _torch_mlu_version = _is_package_available("torch_mlu")
_transformers_available, _transformers_version = _is_package_available("transformers")
_hf_hub_available, _hf_hub_version = _is_package_available("huggingface_hub")
_kernels_available, _kernels_version = _is_package_available("kernels")
Expand Down Expand Up @@ -243,6 +244,10 @@ def is_torch_npu_available():
return _torch_npu_available


def is_mlu_available():
    """Return whether the `torch_mlu` package (Cambricon MLU backend) is importable."""
    mlu_ready = _torch_mlu_available
    return mlu_ready


def is_flax_available():
    """Return whether the `flax` package is importable in this environment."""
    flax_ready = _flax_available
    return flax_ready

Expand Down
2 changes: 2 additions & 0 deletions src/diffusers/utils/torch_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,6 +286,8 @@ def get_device():
return "xpu"
elif torch.backends.mps.is_available():
return "mps"
elif torch.mlu.is_available():
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
elif torch.mlu.is_available():
elif is_mlu_available():

You also need to import `is_mlu_available` in this module for the suggested change to work.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fixed

return "mlu"
else:
return "cpu"

Expand Down
Loading