Re-introduce "XLA_USE_32BIT_LONG" flag #8589

Merged · 3 commits · Jan 17, 2025
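
For context, a minimal usage sketch of the re-introduced flag, assuming the semantics shown in the dtype.cpp change below (torch.int64/uint64 lowered as 32-bit s32/u32 when the variable is set). The snippet is illustrative and not part of the PR:

import os
os.environ['XLA_USE_32BIT_LONG'] = '1'  # set before torch_xla initializes the device

import torch
import torch_xla
import torch_xla.core.xla_model as xm

# With the flag on, int64 device data is expected to show up as s32 in the HLO dump.
t1 = torch.tensor([2, 3], dtype=torch.int64, device=xm.xla_device())
t2 = torch.tensor([2, 3], dtype=torch.int64, device=xm.xla_device())
hlo_text = torch_xla._XLAC._get_xla_tensors_text([t1 + t2])
print('s32' in hlo_text)  # expected: True when the flag is honored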
4 changes: 1 addition & 3 deletions test/neuron/run_tests.sh
@@ -160,9 +160,7 @@ function run_xla_op_tests1 {
run_test "$CDIR/dynamo/test_graph_input_matcher.py"
run_test "$CDIR/dynamo/test_dynamo_config.py"
run_save_tensor_ir run_test "$CDIR/dynamo/test_dynamo_graph_dump.py"
#run_test "$CDIR/test_data_type.py"
run_use_bf16 "$CDIR/test_data_type.py"
run_downcast_bf16 "$CDIR/test_data_type.py"
run_test "$CDIR/test_data_type.py"
#run_test "$CDIR/test_fp8.py"
run_xla_ir_debug "$CDIR/test_env_var_mapper.py"
run_xla_hlo_debug "$CDIR/test_env_var_mapper.py"
4 changes: 2 additions & 2 deletions test/neuron/test_neuron_data_types.py
@@ -27,10 +27,10 @@ def test_datatypes(self):
(torch.double, "f32", torch.floor_divide),
(torch.int16, "s32", torch.add),
(torch.int32, "s32", torch.add),
(torch.int64, "s32", torch.add),
(torch.int64, "s64", torch.add),
(torch.uint16, "u32", torch.add),
(torch.uint32, "u32", torch.add),
(torch.uint64, "u32", torch.add)]
(torch.uint64, "u64", torch.add)]

for dtype, op_xla_dtype, op in test_cases:
with self.subTest(dtype=dtype, op_xla_dtype=op_xla_dtype, op=op):
2 changes: 0 additions & 2 deletions test/run_tests.sh
@@ -178,8 +178,6 @@ function run_xla_op_tests1 {
run_test "$CDIR/dynamo/test_dynamo_config.py"
run_save_tensor_ir run_test "$CDIR/dynamo/test_dynamo_graph_dump.py"
run_test "$CDIR/test_data_type.py"
run_use_bf16 "$CDIR/test_data_type.py"
run_downcast_bf16 "$CDIR/test_data_type.py"
run_test "$CDIR/test_fp8.py"
run_xla_ir_debug run_test "$CDIR/test_env_var_mapper.py"
run_xla_hlo_debug run_test "$CDIR/test_env_var_mapper.py"
95 changes: 52 additions & 43 deletions test/test_data_type.py
@@ -1,73 +1,82 @@
import os
import sys
import unittest

import torch
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.utils.utils as xu
import unittest


def check_env_flag(name, default=''):
return os.getenv(name, default).upper() in ['ON', '1', 'YES', 'TRUE', 'Y']
class XlaDataTypeTest(unittest.TestCase):

def setUp(cls):
cls.original_env = {
'XLA_USE_BF16': os.environ.get('XLA_USE_BF16'),
'XLA_DOWNCAST_BF16': os.environ.get('XLA_DOWNCAST_BF16'),
'XLA_USE_32BIT_LONG': os.environ.get('XLA_USE_32BIT_LONG')
}

class XlaDataTypeTest(unittest.TestCase):
def tearDown(self):
for key, value in self.original_env.items():
if value is None:
os.environ.pop(key, None)
else:
os.environ[key] = value

def test_datatype_f32(self):
t1 = torch.tensor([2.0, 3.0], dtype=torch.float, device=xm.xla_device())
t2 = torch.tensor([2.0, 3.0], dtype=torch.float, device=xm.xla_device())
t3 = torch.div(t1, t2, rounding_mode='floor')
assert t3.dtype == torch.float
def _set_env(self, **kwargs):
for key, value in kwargs.items():
os.environ[key] = value

hlo_text = torch_xla._XLAC._get_xla_tensors_text([t3])
device_data_hlo = hlo_text.split('\n')[1]
assert 'xla::device_data' in device_data_hlo, device_data_hlo
if check_env_flag('XLA_USE_BF16') or check_env_flag('XLA_DOWNCAST_BF16'):
assert 'bf16' in device_data_hlo, device_data_hlo
elif check_env_flag('XLA_USE_FP16') or check_env_flag('XLA_DOWNCAST_FP16'):
assert 'f16' in device_data_hlo, device_data_hlo
else:
assert 'f32' in device_data_hlo, device_data_hlo

def test_datatype_f64(self):
t1 = torch.tensor([2.0, 3.0], dtype=torch.double, device=xm.xla_device())
t2 = torch.tensor([2.0, 3.0], dtype=torch.double, device=xm.xla_device())
t3 = torch.div(t1, t2, rounding_mode='floor')
assert t3.dtype == torch.double
def _test_datatype(self, dtype, expected_type, op):
t1 = torch.tensor([2, 3], dtype=dtype, device=xm.xla_device())
t2 = torch.tensor([2, 3], dtype=dtype, device=xm.xla_device())
t3 = op(t1, t2)
self.assertEqual(t3.dtype, dtype)

hlo_text = torch_xla._XLAC._get_xla_tensors_text([t3])
device_data_hlo = hlo_text.split('\n')[1]
assert 'xla::device_data' in device_data_hlo, device_data_hlo
if check_env_flag('XLA_USE_BF16'):
assert 'bf16' in device_data_hlo, device_data_hlo
elif check_env_flag('XLA_USE_FP16'):
assert 'f16' in device_data_hlo, device_data_hlo
elif check_env_flag('XLA_DOWNCAST_BF16') or check_env_flag(
'XLA_DOWNCAST_FP16'):
assert 'f32' in device_data_hlo, device_data_hlo
else:
assert 'f64' in device_data_hlo, device_data_hlo
device_data_hlo = hlo_text.split('\n')[2]
self.assertIn('xla::device_data', device_data_hlo)
self.assertIn(expected_type, device_data_hlo)

def test_datatype_use_bf16(self):
self._set_env(XLA_USE_BF16='1')
self._test_datatype(torch.double, 'bf16', torch.floor_divide)
self._test_datatype(torch.float, 'bf16', torch.floor_divide)

def test_datatype_downcast_bf16(self):
self._set_env(XLA_DOWNCAST_BF16='1')
self._test_datatype(torch.double, 'bf16', torch.floor_divide)
self._test_datatype(torch.float, 'bf16', torch.floor_divide)

def test_datatype_use_32bit_long(self):
self._set_env(XLA_USE_32BIT_LONG='1')
self._test_datatype(torch.int64, 's32', torch.add)
self._test_datatype(torch.uint64, 'u32', torch.add)

def test_module_to_dtype(self):
device = torch_xla.device()
linear = torch.nn.Linear(
5, 10, dtype=torch.float32).to(device).to(torch.bfloat16)
input = torch.randn(
10,
5,
).to(device).to(torch.bfloat16)
input = torch.randn(10, 5).to(device).to(torch.bfloat16)
xm.mark_step()
res = linear(input)

hlo_text = torch_xla._XLAC._get_xla_tensors_text([res])
res_hlo = hlo_text.split('\n')[-3]
assert 'bf16' in res_hlo, res_hlo
self.assertIn('bf16', res_hlo)

linear_weight_hlo = torch_xla._XLAC._get_xla_tensors_text([linear.weight
]).split('\n')[-3]
assert 'bf16' in linear_weight_hlo, linear_weight_hlo
self.assertIn('bf16', linear_weight_hlo)


if __name__ == '__main__':
test = unittest.main()
sys.exit(0 if test.result.wasSuccessful() else 1)
suite = unittest.TestSuite()
suite.addTest(XlaDataTypeTest("test_datatype_use_bf16"))
suite.addTest(XlaDataTypeTest("test_datatype_downcast_bf16"))
suite.addTest(XlaDataTypeTest("test_datatype_use_32bit_long"))
suite.addTest(XlaDataTypeTest("test_module_to_dtype"))
runner = unittest.TextTestRunner(failfast=True)
result = runner.run(suite)
sys.exit(0 if result.wasSuccessful() else 1)
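
Since the rewritten test sets and restores XLA_USE_BF16, XLA_DOWNCAST_BF16 and XLA_USE_32BIT_LONG itself in setUp/tearDown, the run scripts no longer need the env-var wrapper invocations; a plain run (illustrative) suffices:

python3 test/test_data_type.py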
1 change: 1 addition & 0 deletions test/tpu/run_tests.sh
@@ -45,6 +45,7 @@ python3 "$TEST_CDIR/torch_distributed/test_torch_distributed_multi_all_reduce_xl
python3 "$TEST_CDIR/torch_distributed/test_torch_distributed_reduce_scatter_xla_backend.py"
python3 "$TEST_CDIR/quantized_ops/test_dot_general.py"
run_xla_ir_hlo_debug python3 "$TEST_CDIR/test_user_computation_debug_cache.py"
python3 "$TEST_CDIR/test_data_type.py"

# run examples, each test should takes <2 minutes
python3 "$TEST_CDIR/../examples/data_parallel/train_resnet_spmd_data_parallel.py"
4 changes: 1 addition & 3 deletions torch_xla/__init__.py
@@ -178,9 +178,7 @@ def _setup_tpu_vm_library_path() -> bool:


def _check_deprecated_env_var():
deprecated_env_vars = [
'XLA_USE_FP16', 'XLA_DOWNCAST_FP16', 'XLA_USE_32BIT_LONG'
]
deprecated_env_vars = ['XLA_USE_FP16', 'XLA_DOWNCAST_FP16']
for env_var in deprecated_env_vars:
if os.environ.get(env_var):
warnings.warn(f"The environment variable '{env_var}' is deprecated "
22 changes: 18 additions & 4 deletions torch_xla/csrc/dtype.cpp
@@ -30,6 +30,17 @@ bool ShouldDowncastToBF16() {
return downcast_bf16;
}

bool ShouldUse32BitLong() {
bool use_32bit_long =
runtime::sys_util::GetEnvBool("XLA_USE_32BIT_LONG", false);
if (use_32bit_long) {
std::cout
<< "XLA_USE_32BIT_LONG will be deprecated after the 2.6 release\n";
TF_LOG(INFO) << "Using 32bit integers for kLong values";
}
return use_32bit_long;
}

bool UseBF16() {
static bool use_bf16 = ShouldUseBF16();
return use_bf16;
@@ -40,6 +51,11 @@ bool DowncastBF16() {
return downcast_bf16;
}

bool Use32BitLong() {
static bool use_32bit_long = ShouldUse32BitLong();
return use_32bit_long;
}

} // namespace

at::ScalarType TorchTypeFromXlaType(xla::PrimitiveType xla_type) {
@@ -143,11 +159,9 @@ xla::PrimitiveType MaybeDowncastToXlaDeviceType(
return CheckNeuronDevice(hw_type) ? xla::PrimitiveType::S32
: xla::PrimitiveType::S16;
case xla::PrimitiveType::S64:
return CheckNeuronDevice(hw_type) ? xla::PrimitiveType::S32
: xla::PrimitiveType::S64;
return Use32BitLong() ? xla::PrimitiveType::S32 : xla::PrimitiveType::S64;
case xla::PrimitiveType::U64:
return CheckNeuronDevice(hw_type) ? xla::PrimitiveType::U32
: xla::PrimitiveType::U64;
return Use32BitLong() ? xla::PrimitiveType::U32 : xla::PrimitiveType::U64;
case xla::PrimitiveType::C128:
return xla::PrimitiveType::C128;
default: