Skip to content

Commit c9c9b95

Browse files
committed
Fix errors
1 parent 06d4b28 commit c9c9b95

File tree

3 files changed

+7
-7
lines changed

3 files changed

+7
-7
lines changed

test/pytorch_test_base.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -559,7 +559,7 @@ def _alt_lookup(d, keys, defval):
559559
def instantiate_test(cls, name, test, *, generic_cls):
560560
test_name = name + '_' + cls.device_type
561561
class_name = cls.__name__
562-
real_device_type = xm.xla_device_hw(str(torch.device('xla')))
562+
real_device_type = xm.xla_device_hw(str(torch_xla.device()))
563563
assert real_device_type in DISABLED_TORCH_TESTS, 'Unsupported device type:' + real_device_type
564564
disabled_torch_tests = DISABLED_TORCH_TESTS[real_device_type]
565565

@@ -632,7 +632,7 @@ def get_primary_device(cls):
632632
@classmethod
633633
def setUpClass(cls):
634634
# Sets the primary test device to the xla_device (CPU or TPU)
635-
cls.primary_device = str(torch.device('xla'))
635+
cls.primary_device = str(torch_xla.device())
636636
torch_xla._XLAC._xla_set_mat_mul_precision('highest')
637637

638638
def setUp(self):

test/test_operations.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -440,12 +440,11 @@ def test_nonzero_cast(self):
440440
class TestOptimizationBarrier(test_utils.XlaTestCase):
441441

442442
def test_optimization_barrier_correctness(self):
443-
device = torch.device('xla')
444443
# only test optimization_barrier on TPU
445-
if xm.xla_device_hw(device) != 'TPU':
444+
if xm.xla_device_hw(torch_xla.device()) != 'TPU':
446445
return
447-
x = torch.randn(5, 5, device=device)
448-
y = torch.randn(5, 5, device=device)
446+
x = torch.randn(5, 5, device='xla')
447+
y = torch.randn(5, 5, device='xla')
449448
z = x + y
450449
xm.optimization_barrier_([x, y])
451450
self.assertEqual(z, x + y)

torch_xla/runtime.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -156,7 +156,8 @@ def local_ordinal() -> int:
156156
Local ordinal is in range [0, local_device_count)."""
157157
local_rank = xu.getenv_as(xenv.PJRT_LOCAL_PROCESS_RANK, int, 0)
158158
devices_per_process = addressable_device_count()
159-
return local_rank * devices_per_process + torch.device('xla').index
159+
return local_rank * devices_per_process + torch.device(
160+
torch_xla._XLAC._xla_get_default_device()).index
160161

161162

162163
def process_index() -> int:

0 commit comments

Comments (0)