Skip to content

Commit 1b72cf0

Browse files
mlazospytorchmergebot
authored andcommitted
Add hasattr for tensor variable (pytorch#131008)
Pull Request resolved: pytorch#131008 Approved by: https://github.com/anijain2305 ghstack dependencies: pytorch#131007
1 parent 1f961ad commit 1b72cf0

File tree

5 files changed

+35
-0
lines changed

5 files changed

+35
-0
lines changed

test/dynamo/test_misc.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1364,6 +1364,20 @@ def fn(x):
13641364
r2 = opt_fn(i)
13651365
self.assertEqual(r1, r2)
13661366

1367+
def test_tensor_hasattr(self):
1368+
@torch.compile(fullgraph=True)
1369+
def fn(x):
1370+
if hasattr(x, "test"):
1371+
return x + 2
1372+
else:
1373+
return x + 1
1374+
1375+
self.assertEqual(torch.ones(2, 2) + 1, fn(torch.ones(2, 2)))
1376+
1377+
inp = torch.ones(2, 2)
1378+
inp.test = None
1379+
self.assertEqual(torch.ones(2, 2) + 2, fn(inp))
1380+
13671381
def test_shape_unpack(self):
13681382
def fn(x):
13691383
a, b = x.size()

test/dynamo_expected_failures/TestTorchDeviceTypeCPU.test_broadcast_fn_copy_cpu

Whitespace-only changes.

test/dynamo_expected_failures/TestTorchDeviceTypeCPU.test_broadcast_fn_map2_cpu

Whitespace-only changes.

test/dynamo_expected_failures/TestTorchDeviceTypeCPU.test_broadcast_fn_map_cpu

Whitespace-only changes.

torch/_dynamo/variables/tensor.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -333,6 +333,27 @@ def method_attr__version(self, tx):
333333
tx, [self], {}
334334
)
335335

336+
def call_hasattr(self, tx, name):
337+
from . import GetAttrVariable
338+
from .builtin import BuiltinVariable
339+
340+
try:
341+
var = BuiltinVariable(getattr).call_function(
342+
tx, [self, ConstantVariable(name)], {}
343+
)
344+
# in the event that TensorVariable returns NotImplemented
345+
# BuiltinVariable.call_getattr returns GetAttrVariable
346+
ret_val = not isinstance(var, GetAttrVariable)
347+
except AttributeError:
348+
ret_val = False
349+
350+
if self.source:
351+
install_guard(
352+
AttrSource(self.source, name).make_guard(GuardBuilder.HASATTR)
353+
)
354+
355+
return ConstantVariable(ret_val)
356+
336357
def var_getattr(self, tx, name):
337358
from . import UserDefinedClassVariable
338359

0 commit comments

Comments
 (0)