Commit 5771cd3

Deprecate devkind field from xla_model.xla_device (#9284)

1 parent a6f2b27

8 files changed: +15 -32 lines
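Every call site below follows the same migration: a devkind-filtered device listing becomes either a plain listing or a hardware check on the default device. A minimal before/after sketch of the pattern (the device names shown are illustrative, not guaranteed output):

import torch_xla
import torch_xla.core.xla_model as xm

# Before (deprecated): filter the supported-device list by kind.
# devices = xm.get_xla_supported_devices(devkind='TPU')

# After: list all addressable devices ...
devices = xm.get_xla_supported_devices()  # e.g. ['xla:0', 'xla:1', ...]

# ... and query the hardware type of the default device instead.
on_tpu = xm.xla_device_hw(torch_xla.device()) == 'TPU'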

benchmarks/check_xla_device.py

Lines changed: 1 addition & 1 deletion

@@ -17,7 +17,7 @@ def use_torch_xla2():
 
 if not use_torch_xla2():
   import torch_xla.core.xla_model as xm
-  devlist = xm.get_xla_supported_devices(devkind=devkind)
+  devlist = xm.get_xla_supported_devices()
 else:
   # torch_xla2 needs jax to detect the device
   os.environ["JAX_PLATFORMS"] = devkind.lower(

scripts/bench_tensor_io.py

Lines changed: 1 addition & 1 deletion

@@ -10,7 +10,7 @@
 
 
 def run_benchmark(args, pos_args):
-  devices = xm.get_xla_supported_devices(max_devices=args.max_devices)
+  devices = xm.get_xla_supported_devices(args.max_devices)
   shape = [int(x) for x in args.shape.split(',')]
 
   send_list = []
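With devkind removed, max_devices becomes the first parameter of get_xla_supported_devices, which is why this call site can now pass args.max_devices positionally; the keyword form max_devices=args.max_devices would behave identically.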

test/ds/test_dynamic_shape_models.py

Lines changed: 1 addition & 1 deletion

@@ -44,7 +44,7 @@ def forward(self, x):
 
 
 @unittest.skipIf(
-    not xm.get_xla_supported_devices("TPU"),
+    xm.xla_device_hw(torch_xla.device()) != 'TPU',
     f"The tests fail on CPU. See https://github.com/pytorch/xla/issues/4298 for more detail."
 )
 class TestDynamicShapeModels(unittest.TestCase):
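The remaining test files repeat this idiom: a truthiness test on a kind-filtered device list becomes an equality test on the hardware backing the default device. A small self-contained sketch (the test class and its body are placeholders, not part of this commit):

import unittest

import torch_xla
import torch_xla.core.xla_model as xm

# Skip the whole class unless the default XLA device is a TPU.
@unittest.skipIf(
    xm.xla_device_hw(torch_xla.device()) != 'TPU', "TPU-only test.")
class TpuOnlyTest(unittest.TestCase):  # placeholder class name

  def test_device_is_tpu(self):
    self.assertEqual(xm.xla_device_hw(torch_xla.device()), 'TPU')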

test/pjrt/test_dynamic_plugin_tpu.py

Lines changed: 1 addition & 1 deletion

@@ -20,7 +20,7 @@ def setUpClass(cls):
   @staticmethod
   def _assert_tpus_exist(index=0):
     del index
-    assert len(xm.get_xla_supported_devices('TPU')) > 0
+    assert xm.xla_device_hw(torch_xla.device()) == 'TPU'
 
   def test_single_process(self):
     with concurrent.futures.ProcessPoolExecutor(max_workers=1) as executor:

test/pjrt/test_runtime_tpu.py

Lines changed: 1 addition & 1 deletion

@@ -129,7 +129,7 @@ def test_xla_devices_single_process_one_chip_one_device_spawn(self):
 
   def test_default_xla_devices(self):
     with concurrent.futures.ProcessPoolExecutor(max_workers=1) as e:
-      f = e.submit(xm.get_xla_supported_devices, 'TPU')
+      f = e.submit(xm.get_xla_supported_devices)
       devices = [torch.device(d) for d in f.result()]
 
     self.assertListEqual(

test/spmd/test_xla_spmd_python_api_interaction.py

Lines changed: 1 addition & 2 deletions

@@ -22,8 +22,7 @@ def setUpClass(cls):
     super().setUpClass()
 
   def test_get_xla_supported_devices(self):
-    device_type = os.environ['PJRT_DEVICE']
-    devices = xm.get_xla_supported_devices(device_type)
+    devices = xm.get_xla_supported_devices()
     self.assertEqual(len(devices), 1)
 
   def test_world_size(self):

test/test_autocast.py

Lines changed: 3 additions & 2 deletions

@@ -348,7 +348,8 @@ def compare(first, second):
     self.assertFalse(self.is_autocast_enabled())
 
 
-@unittest.skipIf(not xm.get_xla_supported_devices("TPU"), f"TPU autocast test.")
+@unittest.skipIf(
+    xm.xla_device_hw(torch_xla.device()) != 'TPU', f"TPU autocast test.")
 class TestAutocastTPU(TestAutocastBase):
 
   @classmethod
@@ -404,7 +405,7 @@ class TestOtherOps(unittest.TestCase):
 
   # On TPU, the input of batch norm is casted into fp32, see torch_xla/csrc/autocast_mode.cpp
   @unittest.skipIf(
-      not xm.get_xla_supported_devices("TPU"),
+      xm.xla_device_hw(torch_xla.device()) != 'TPU',
       "the behavior of batch_norm autocast on TPU is different from others")
   def test_batch_norm_tpu(self):
     device = torch_xla.device()

torch_xla/core/xla_model.py

Lines changed: 6 additions & 23 deletions

@@ -66,38 +66,21 @@ def is_xla_tensor(tensor: torch.Tensor) -> bool:
   return tensor.device.type == 'xla'
 
 
-def get_xla_supported_devices(devkind: Optional[str] = None,
-                              max_devices: Optional[int] = None) -> List[str]:
+def get_xla_supported_devices(max_devices: Optional[int] = None) -> List[str]:
   """Returns a list of supported devices of a given kind.
 
   Args:
-    devkind (string..., optional): If specified, a device type such as `TPU`,
-      `CUDA`, `CPU`, or name of custom PJRT device.
     max_devices (int, optional): The maximum number of devices to be returned of
       that kind.
 
   Returns:
     The list of device strings such as ['xla:0', 'xla:1', ...]
   """
-  # TODO(wcromar): Remove `devkind` after 2.3 release cut. We no longer support
-  # multiple device types.
-  if not devkind:
-    devices = torch_xla._XLAC._xla_get_devices()
-    return [
-        f'xla:{i}'
-        for i, _ in enumerate(devices[:max_devices] if max_devices else devices)
-    ]
-  else:
-    warnings.warn("`devkind` argument is deprecated and will be removed in a "
-                  "future release.")
-
-    xla_devices = _DEVICES.value
-    kind_devices = []
-    for i, device in enumerate(xla_devices):
-      if re.match(devkind + r':\d+$', device):
-        kind_devices.append('xla:{}'.format(i))
-    if kind_devices:
-      return kind_devices[:max_devices] if max_devices else kind_devices
+  devices = torch_xla._XLAC._xla_get_devices()
+  return [
+      f'xla:{i}'
+      for i, _ in enumerate(devices[:max_devices] if max_devices else devices)
+  ]
 
 
 def get_local_ordinal() -> int:
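For reference, a short usage sketch of the simplified function (the printed device lists are an assumption about a four-chip host, not guaranteed output):

import torch_xla.core.xla_model as xm

all_devices = xm.get_xla_supported_devices()             # every addressable device
first_two = xm.get_xla_supported_devices(max_devices=2)  # capped at two entries

print(all_devices)  # e.g. ['xla:0', 'xla:1', 'xla:2', 'xla:3']
print(first_two)    # e.g. ['xla:0', 'xla:1']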
