Skip to content

Commit

Permalink
[lang] Fix MPS / other torch backend from_torch (#8298)
Browse files Browse the repository at this point in the history
Issue: #6861

### Brief Summary

<!--
copilot:summary
-->
### <samp>🤖 Generated by Copilot at fd38fea</samp>

Fix a bug that prevents using torch tensors on non-CPU devices other
than CUDA in Taichi kernels. Update the condition in `kernel_impl.py` to
copy tensors to the correct device before and after kernel execution.

### Walkthrough

<!--
copilot:walkthrough
-->
### <samp>🤖 Generated by Copilot at fd38fea</samp>

* Fix a bug that prevents using torch tensors on non-CPU devices other
than CUDA
([link](https://github.com/taichi-dev/taichi/pull/8298/files?diff=unified&w=0#diff-a157043b38542c8145447ff342fda65fe4d54fb777fe514daa70007e83e20dc1L752-R752))

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
bobcao3 and pre-commit-ci[bot] authored Jul 26, 2023
1 parent daeb013 commit cd620d6
Showing 1 changed file with 3 additions and 1 deletion.
4 changes: 3 additions & 1 deletion python/taichi/lang/kernel_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -749,7 +749,9 @@ def call_back():
v.grad = torch.zeros_like(v)

tmp = v
if str(v.device).startswith("cuda") and taichi_arch != _ti_core.Arch.cuda:
if (str(v.device) != "cpu") and not (
str(v.device).startswith("cuda") and taichi_arch == _ti_core.Arch.cuda
):
# Getting a torch CUDA tensor on Taichi non-cuda arch:
# We just replace it with a CPU tensor and by the end of kernel execution we'll use the
# callback to copy the values back to the original CUDA tensor.
Expand Down

0 comments on commit cd620d6

Please sign in to comment.