diff --git a/python/taichi/lang/kernel_impl.py b/python/taichi/lang/kernel_impl.py index 5260208dcfbe3..abd40834f88fd 100644 --- a/python/taichi/lang/kernel_impl.py +++ b/python/taichi/lang/kernel_impl.py @@ -749,7 +749,9 @@ def call_back(): v.grad = torch.zeros_like(v) tmp = v - if str(v.device).startswith("cuda") and taichi_arch != _ti_core.Arch.cuda: + if (str(v.device) != "cpu") and not ( + str(v.device).startswith("cuda") and taichi_arch == _ti_core.Arch.cuda + ): # Getting a torch CUDA tensor on Taichi non-cuda arch: # We just replace it with a CPU tensor and by the end of kernel execution we'll use the # callback to copy the values back to the original CUDA tensor.