Skip to content

Commit 01a5692

Browse files
authored
RoPE in float32 for MPS or NPU compatibility (#12665)
Compute the RoPE frequency table in float32 on MPS and NPU devices, which do not support float64.
1 parent a9e4883 commit 01a5692

File tree

1 file changed

+6
-1
lines changed

1 file changed

+6
-1
lines changed

src/diffusers/models/transformers/transformer_prx.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -275,7 +275,12 @@ def __init__(self, dim: int, theta: int, axes_dim: List[int]):
275275

276276
def rope(self, pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor:
277277
assert dim % 2 == 0
278-
scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
278+
279+
is_mps = pos.device.type == "mps"
280+
is_npu = pos.device.type == "npu"
281+
dtype = torch.float32 if (is_mps or is_npu) else torch.float64
282+
283+
scale = torch.arange(0, dim, 2, dtype=dtype, device=pos.device) / dim
279284
omega = 1.0 / (theta**scale)
280285
out = pos.unsqueeze(-1) * omega.unsqueeze(0)
281286
out = torch.stack([torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1)

0 commit comments

Comments (0)