19 changes: 9 additions & 10 deletions test/export/test_export_opinfo.py
@@ -37,7 +37,7 @@
 
 # following are failing fake export on cuda device
 fake_export_failures = {
-    xfail("geqrf"),
+    # xfail("geqrf"),
     xfail("histogram"),
     xfail("masked.amax"),
     xfail("masked.amin"),
@@ -50,17 +50,17 @@
     xfail("masked.std"),
     xfail("masked.sum"),
     xfail("masked.var"),
-    xfail("nn.functional.grid_sample"),
-    xfail("to_sparse"),
+    # xfail("nn.functional.grid_sample"),
+    # xfail("to_sparse"),
     # cannot xfail as it is passing for cpu-only build
     skip("nn.functional.conv2d"),
     skip("nn.functional.scaled_dot_product_attention"),
     # following are failing due to OptionalDeviceGuard
-    xfail("__getitem__"),
-    xfail("nn.functional.batch_norm"),
-    xfail("nn.functional.instance_norm"),
-    xfail("nn.functional.multi_margin_loss"),
-    xfail("nonzero"),
+    # xfail("__getitem__"),
+    # xfail("nn.functional.batch_norm"),
+    # xfail("nn.functional.instance_norm"),
+    # xfail("nn.functional.multi_margin_loss"),
+    # xfail("nonzero"),
 }
 
 fake_decomposition_failures = {
@@ -78,8 +78,7 @@ def _test_export_helper(self, dtype, op):
 
         mode = FakeTensorMode(allow_non_fake_inputs=True)
         converter = mode.fake_tensor_converter
-        # intentionally avoid cuda:0 to flush out some bugs
-        target_device = "cuda:1"
+        target_device = "cuda:0"
 
         def to_fake_device(x):
             x = converter.from_real_tensor(mode, x)
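For context, a minimal sketch of the fake-device conversion pattern this test exercises, assuming `FakeTensor.fake_device` is assignable (as the test's `to_fake_device` helper relies on); illustrative, not the full test:

```python
import torch
from torch._subclasses.fake_tensor import FakeTensorMode

mode = FakeTensorMode(allow_non_fake_inputs=True)
converter = mode.fake_tensor_converter
target_device = "cuda:0"

def to_fake_device(x):
    # Wrap the real tensor in a FakeTensor (metadata only, no real storage),
    # then relabel its device so ops trace as if they ran on target_device.
    x = converter.from_real_tensor(mode, x)
    x.fake_device = torch.device(target_device)  # assumed assignable, per the test helper
    return x

fake = to_fake_device(torch.randn(4, 4))
print(fake.device)  # cuda:0, with no real GPU allocation
```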
12 changes: 7 additions & 5 deletions test/inductor/test_benchmark_fusion.py
@@ -220,19 +220,21 @@ def relu(x):
 
         x = torch.randn(int(16e6), device="cuda:1")
 
-        orig_benchmark_fused_nodes = TritonScheduling.benchmark_fused_nodes
+        orig_benchmark_codegened_module = (
+            TritonScheduling.benchmark_codegened_module
+        )
 
-        def mock_benchmark_fused_nodes(*args, **kwargs):
+        def benchmark_codegened_module(*args, **kwargs):
             nonlocal hit_count
             hit_count += 1
-            ms, path = orig_benchmark_fused_nodes(*args, **kwargs)
+            ms, path = orig_benchmark_codegened_module(*args, **kwargs)
             self.assertTrue(ms > 0)
             return ms, path
 
         with unittest.mock.patch.object(
             TritonScheduling,
-            "benchmark_fused_nodes",
-            mock_benchmark_fused_nodes,
+            "benchmark_codegened_module",
+            benchmark_codegened_module,
         ):
             relu(x)
         self.assertTrue(hit_count > 0)
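As an aside, the wrap-and-count pattern above is easy to study in isolation. A self-contained sketch, with a toy class standing in for the real `TritonScheduling` (the method name and the `(ms, path)` return shape mirror the hunk; everything else is illustrative):

```python
import unittest.mock

class Scheduler:
    def benchmark_codegened_module(self, module):
        return 1.5, "/tmp/mod.py"  # (milliseconds, path), as in the hunk above

hit_count = 0
orig = Scheduler.benchmark_codegened_module

def counting_wrapper(*args, **kwargs):
    # Forward to the saved original while recording each invocation.
    global hit_count
    hit_count += 1
    return orig(*args, **kwargs)

with unittest.mock.patch.object(
    Scheduler, "benchmark_codegened_module", counting_wrapper
):
    ms, path = Scheduler().benchmark_codegened_module("mod")
assert hit_count == 1 and ms > 0
```

Saving the original before patching matters: looking the method up on the class again inside the wrapper would hit the mock and recurse.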
4 changes: 3 additions & 1 deletion test/inductor/test_torchinductor_opinfo.py
@@ -1329,8 +1329,10 @@ def _get_tolerances(dtype):
         # Triton
         if has_triton():
             adjusted_kwargs.update(
-                copy_to_gpu=False, reference_in_float=False
+                copy_to_gpu=False,
             )
+            if device_type == GPU_TYPE:
+                adjusted_kwargs["reference_in_float"] = False
 
         # skip checking gradient on CPU for now
         if device_type == GPU_TYPE:
4 changes: 2 additions & 2 deletions test/profiler/test_profiler.py
@@ -2107,7 +2107,7 @@ def test_dynamic_toggle(self):
 
         self.assertTrue(any("aten" in e.name for e in p.events()))
 
-        self.assertTrue(any(device in e.name for e in p.events()))
+        self.assertTrue(any(device in e.name.lower() for e in p.events()))
 
         self.assertTrue(any("kernel" in e.name.lower() for e in p.events()))
 
@@ -2232,7 +2232,7 @@ def check_correlations(event, disable_external_correlation):
             if "cat" in event and event["cat"] in cuda_external_id_events:
                 if disable_external_correlation:
                     self.assertTrue("External id" not in event["args"])
-                elif event["name"] != "cudaDeviceSynchronize":
+                elif event["name"] not in ["cudaDeviceSynchronize", "hipDeviceSynchronize"]:
                     self.assertTrue("External id" in event["args"])
                     self.assertTrue(event["args"]["External id"] > 0)
 
5 changes: 4 additions & 1 deletion test/test_flop_counter.py
@@ -853,7 +853,10 @@ def get_flops(model):
         "FP8 is only supported on H100+, SM 8.9 and MI300+ devices",
     )
     def test_scaled_mm(self):
-        dtype = torch.float8_e4m3fnuz if torch.version.hip else torch.float8_e4m3fn
+        if (torch.version.hip and 'gfx942' in torch.cuda.get_device_properties(0).gcnArchName):
+            dtype = torch.float8_e4m3fnuz
+        else:
+            dtype = torch.float8_e4m3fn
         with FlopCounterMode() as mode:
             torch._scaled_mm(
                 torch.randn((3 * 16, 5 * 16), device="cuda").to(dtype),
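The arch check above factors cleanly into a helper; a sketch of the same selection logic (same calls as the hunk, assuming a visible GPU at index 0):

```python
import torch

def pick_fp8_dtype() -> torch.dtype:
    # MI300 (gfx942) uses the fnuz variant of float8 e4m3; other ROCm
    # architectures and all CUDA devices use the standard OCP format.
    if torch.version.hip and "gfx942" in torch.cuda.get_device_properties(0).gcnArchName:
        return torch.float8_e4m3fnuz
    return torch.float8_e4m3fn
```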
6 changes: 3 additions & 3 deletions test/test_nn.py
@@ -3956,7 +3956,7 @@ def test_rnn_check_device(self):
 
             # input and weights are not at the same device
             with self.assertRaisesRegex(RuntimeError,
-                                        "Input and parameter tensors are not at the same device"):
+                                        "Expected all tensors to be on the same device"):
                 model(input.to('cuda:0'))
             with self.assertRaisesRegex(RuntimeError,
                                         "Input and parameter tensors are not at the same device"):
@@ -3970,7 +3970,7 @@ def test_rnn_check_device(self):
             else:
                 model(input, (hidden.to('cuda:0')))
             with self.assertRaisesRegex(RuntimeError,
-                                        r"Input and hidden tensors are not at the same device"):
+                                        r"Expected all tensors to be on the same device"):
                 if mode == 'LSTM':
                     model_cuda(input.to('cuda:0'), (hidden, hidden))
                 else:
@@ -3979,7 +3979,7 @@ def test_rnn_check_device(self):
             # hidden tensors are not at the same CUDA device
             if mode == 'LSTM':
                 with self.assertRaisesRegex(RuntimeError,
-                                            "Input and hidden tensors are not at the same device"):
+                                            "Expected all tensors to be on the same device"):
                     model(input.to('cuda:0'), (hidden.to('cuda:0'), hidden.to('cuda:1')))
 
     @unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported")
3 changes: 2 additions & 1 deletion test/test_sparse_csr.py
@@ -2527,7 +2527,8 @@ def run_test(c, a, b):
             self.assertEqual(actual.shape, c.shape)
 
         for m, n, k in itertools.product([0, 5], repeat=3):
-            c = torch.empty(m, n, dtype=dtype, device=device, layout=torch.sparse_csr)
+            with torch.sparse.check_sparse_tensor_invariants(enable=False):
+                c = torch.empty(m, n, dtype=dtype, device=device, layout=torch.sparse_csr)
             a = make_tensor((m, k), dtype=dtype, device=device)
             b = make_tensor((k, n), dtype=dtype, device=device)
             run_test(c, a, b)
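For reference, a minimal usage sketch of the context manager the hunk introduces; the premise (implied by the hunk) is that `torch.empty` with a CSR layout can produce index buffers that fail invariant validation when validation is switched on:

```python
import torch

# Turn invariant validation on globally, as a test harness might.
torch.sparse.check_sparse_tensor_invariants.enable()
try:
    # Suspend validation just for the uninitialized-allocation step.
    with torch.sparse.check_sparse_tensor_invariants(enable=False):
        c = torch.empty(5, 5, dtype=torch.float64, layout=torch.sparse_csr)
    print(c.shape)  # torch.Size([5, 5])
finally:
    torch.sparse.check_sparse_tensor_invariants.disable()
```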
11 changes: 6 additions & 5 deletions test/test_tensorboard.py
@@ -82,13 +82,14 @@ def tearDown(self):
         if os.path.exists(temp_dir):
             shutil.rmtree(temp_dir)
 
-    def assertProto(self, str_to_compare):
+    def assertProto(self, actual_proto):
         if expecttest.ACCEPT:
-            write_proto(str_to_compare, self)
+            write_proto(actual_proto, self)
             return True
-        expected = read_expected_content(self)
-        str_to_compare = str(str_to_compare)
-        self.assertEqual(remove_whitespace(str_to_compare), remove_whitespace(expected))
+        expected_str = read_expected_content(self)
+        expected_proto = Summary()
+        text_format.Parse(expected_str, expected_proto)
+        self.assertEqual(actual_proto, expected_proto)
 
     def assertImageProto(self, actual_proto):
         if expecttest.ACCEPT:
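The switch from string comparison to message comparison is the heart of this hunk; a sketch under the assumption that the import paths match what test_tensorboard.py already uses:

```python
from google.protobuf import text_format
from tensorboard.compat.proto.summary_pb2 import Summary

# Parsing the expected text proto into a message and comparing messages
# is robust to whitespace and field formatting, unlike comparing strings.
expected_str = 'value { tag: "loss" simple_value: 0.5 }'
expected_proto = Summary()
text_format.Parse(expected_str, expected_proto)

actual_proto = Summary()
actual_proto.value.add(tag="loss", simple_value=0.5)
assert actual_proto == expected_proto
```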
7 changes: 7 additions & 0 deletions test/test_type_hints.py
@@ -47,11 +47,18 @@ def get_all_examples():
         "import io",
         "import itertools",
         "",
+        "from typing import Any, ClassVar, Generic, List, Tuple, Union",
+        "from typing_extensions import Literal, get_origin, TypeAlias",
+        "T: TypeAlias = object",
+        "",
         "import numpy",
         "",
         "import torch",
         "import torch.nn.functional as F",
         "",
+        "from typing_extensions import ParamSpec as _ParamSpec",
+        "ParamSpec = _ParamSpec",
+        "",
         # for requires_grad_ example
         # NB: We are parsing this file as Python 2, so we must use
         # Python 2 type annotation syntax
10 changes: 8 additions & 2 deletions torch/_inductor/utils.py
@@ -1080,8 +1080,14 @@ def to_symbol(
                integer=replaced.is_integer,  # type: ignore[attr-defined]
                nonnegative=replaced.is_nonnegative,  # type: ignore[attr-defined]
            )
-        else:
-            return replacement
+
+        if isinstance(replacement, bool):
+            return sympy.true if replacement else sympy.false
+        if isinstance(replacement, int):
+            return sympy.Integer(replacement)
+        if isinstance(replacement, float):
+            return sympy.Float(replacement)
+        return replacement
 
     # xreplace is faster than subs, but is way more picky
     return sympy.sympify(expr).xreplace(
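A short sketch of why the explicit wrapping matters, and why the `bool` branch must precede the `int` branch (`bool` is a subclass of `int` in Python, so the checks are order-sensitive):

```python
import sympy

def wrap_constant(value):
    # xreplace expects sympy objects as mapping values; wrap plain Python
    # constants. bool is checked first: isinstance(True, int) is True.
    if isinstance(value, bool):
        return sympy.true if value else sympy.false
    if isinstance(value, int):
        return sympy.Integer(value)
    if isinstance(value, float):
        return sympy.Float(value)
    return value

x, y = sympy.symbols("x y")
print((x + 1).xreplace({x: wrap_constant(3)}))             # 4
print(sympy.And(x, y).xreplace({x: wrap_constant(True)}))  # y
```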
2 changes: 1 addition & 1 deletion torch/_tensor.py
@@ -1529,7 +1529,7 @@ def dim_order(
        ... )  # It can be mapped to contiguous format
        (0, 1, 2, 3)
        >>> try:
-       ...     torch.empty((1, 2, 3, 4)).dim_order(ambiguity_check="ILLEGAL")
+       ...     torch.empty((1, 2, 3, 4)).dim_order(ambiguity_check="ILLEGAL")  # type: ignore[arg-type]
        ... except TypeError as e:
        ...     print(e)
        The ambiguity_check argument must be a bool or a list of memory formats.
2 changes: 1 addition & 1 deletion torch/serialization.py
@@ -1377,7 +1377,7 @@ def load(
        # Load all tensors onto GPU 1
        >>> torch.load(
        ...     "tensors.pt",
-       ...     map_location=lambda storage, loc: storage.cuda(1),
+       ...     map_location=lambda storage, loc: storage.cuda(1),  # type: ignore[attr-defined]
        ...     weights_only=True,
        ... )  # type: ignore[attr-defined]
        # Map tensors from GPU 1 to GPU 0