diff --git a/test/neuron/run_tests.sh b/test/neuron/run_tests.sh
index c1bb656b8b37..1f5025239b32 100755
--- a/test/neuron/run_tests.sh
+++ b/test/neuron/run_tests.sh
@@ -172,7 +172,8 @@ function run_xla_op_tests1 {
 function run_xla_op_tests2 {
   run_test "$CDIR/pjrt/test_dtypes.py"
   #run_test "$CDIR/test_while_loop.py"
-  run_test "$CDIR/test_scan.py"
+  run_test "$CDIR/scan/test_scan.py"
+  run_xla_hlo_debug "$CDIR/scan/test_scan_debug.py"
   run_test "$CDIR/test_autocast.py"
   run_test "$CDIR/test_grad_checkpoint.py"
   run_test "$CDIR/test_grad_checkpoint.py" "$@" --test_autocast
@@ -319,4 +320,4 @@ if [ "$LOGFILE" != "" ]; then
   run_tests 2>&1 | tee $LOGFILE
 else
   run_tests
-fi
\ No newline at end of file
+fi
diff --git a/test/run_tests.sh b/test/run_tests.sh
index bf72d9d075b6..ba87a3ce3653 100755
--- a/test/run_tests.sh
+++ b/test/run_tests.sh
@@ -192,6 +192,7 @@ function run_xla_op_tests2 {
   run_test "$CDIR/scan/test_scan.py"
   run_test "$CDIR/scan/test_scan_spmd.py"
   run_test "$CDIR/scan/test_scan_layers.py"
+  run_xla_hlo_debug run_test "$CDIR/scan/test_scan_debug.py"
   run_test "$CDIR/test_autocast.py"
   run_test "$CDIR/eager/test_eager.py"
   run_test "$CDIR/eager/test_eager_with_xla_compile.py"
diff --git a/test/scan/test_scan_debug.py b/test/scan/test_scan_debug.py
new file mode 100644
index 000000000000..d800a36998df
--- /dev/null
+++ b/test/scan/test_scan_debug.py
@@ -0,0 +1,69 @@
+import sys
+import os
+import unittest
+
+import torch
+import torch_xla
+import torch_xla.debug.metrics as met
+from torch_xla.experimental.scan import scan
+
+parent_folder = os.path.dirname(os.path.dirname(__file__))
+sys.path.append(parent_folder)
+from test_utils import XlaTestCase  # type:ignore
+
+
+class ScanDebugTest(XlaTestCase):
+
+  def test_scan_no_recompile_with_debug_annotations(self):
+    """
+    When someone adds debugging annotations to the HLO via env vars, the
+    HLO graph of the combine function captured by scan would have additional metadata
+    such as line numbers and scopes. Still, that should not cause the final IR
+    graph hash to change. This is subtle because the IR of the `scan` operation will
+    reference the HLO computation within.
+    """
+    assert os.environ["XLA_HLO_DEBUG"] == "1"
+    met.clear_all()
+
+    def fn(carry, x):
+      carry = carry + x
+      y = x + 42
+      return carry, y
+
+    # fn2 should trace to the same graph despite having different line numbers
+    def fn2(carry, x):
+      carry = carry + x
+      y = x + 42
+      return carry, y
+
+    init = torch.tensor([0.0, 0.0],
+                        requires_grad=True,
+                        device=torch_xla.device())
+    xs = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]],
+                      requires_grad=True,
+                      device=torch_xla.device())
+
+    # Run some graph involving a scan operation two times.
+    for i in range(2):
+      init.grad = None
+      xs.grad = None
+      carry, ys = scan(fn, init, xs)
+      (carry.sum() + ys.sum()).backward()
+      torch_xla.sync()
+
+    # Use a differently named but semantically the same combine function.
+    # This should still trace to identical HLO and hence reuse the cache.
+    init.grad = None
+    xs.grad = None
+    carry, ys = scan(fn2, init, xs)
+    (carry.sum() + ys.sum()).backward()
+    torch_xla.sync()
+
+    # Should only compile once and cache the last two times.
+    self.assertEqual(int(met.counter_value("UncachedCompile")), 1)
+    self.assertEqual(int(met.counter_value("CachedCompile")), 2)
+
+
+if __name__ == '__main__':
+  test = unittest.main()
+  sys.exit(0 if test.result.wasSuccessful() else 1)
diff --git a/test/tpu/run_tests.sh b/test/tpu/run_tests.sh
index d33d50cd9580..76959214300a 100755
--- a/test/tpu/run_tests.sh
+++ b/test/tpu/run_tests.sh
@@ -33,6 +33,7 @@ python3 "$TEST_CDIR/test_while_loop.py"
 python3 "$TEST_CDIR/scan/test_scan.py"
 python3 "$TEST_CDIR/scan/test_scan_spmd.py"
 python3 "$TEST_CDIR/scan/test_scan_layers.py"
+run_xla_hlo_debug python3 "$TEST_CDIR/scan/test_scan_debug.py"
 python3 "$TEST_CDIR/test_pallas.py" -v
 python3 "$TEST_CDIR/test_pallas_spmd.py"
 python3 "$TEST_CDIR/test_tpu_paged_attention_kernel.py"