Make caling type on user defined class UserError (#98366)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/98366
Approved by: https://github.com/anijain2305
diff --git a/test/dynamo/test_export.py b/test/dynamo/test_export.py
index 67b5dfd..909c7d2 100644
--- a/test/dynamo/test_export.py
+++ b/test/dynamo/test_export.py
@@ -2322,6 +2322,31 @@
# this should be captured as static, as export won't generate any symbols.
self.assertEqual(gm(torch.ones(2, 4)), torch.ones(2, 4).sin())
+ def test_access_class_method_from_user_class(self):
+ class A:
+ @classmethod
+ def func(cls):
+ return torch.Tensor([4, 5])
+
+ def f(x):
+ a = A()
+ return x.sum() + type(a).func().sum()
+
+ with self.assertRaisesRegex(torch._dynamo.exc.UserError, "Can't call type()"):
+ gm, _ = torch._dynamo.export(
+ f, torch.ones(6, 4), aten_graph=True, tracing_mode="symbolic"
+ )
+
+ def f_correct(x):
+ a = A()
+ return x.sum() + a.__class__.func().sum()
+
+ gm, _ = torch._dynamo.export(
+ f_correct, torch.ones(6, 4), aten_graph=True, tracing_mode="symbolic"
+ )
+
+ self.assertEqual(f_correct(torch.ones(6, 4)), gm(torch.ones(6, 4)))
+
common_utils.instantiate_parametrized_tests(ExportTests)
diff --git a/torch/_dynamo/exc.py b/torch/_dynamo/exc.py
index c35e1ce..bafb581 100644
--- a/torch/_dynamo/exc.py
+++ b/torch/_dynamo/exc.py
@@ -80,6 +80,7 @@
class UserErrorType(Enum):
DYNAMIC_CONTROL_FLOW = auto()
+ ANTI_PATTERN = auto()
class UserError(Unsupported):
diff --git a/torch/_dynamo/variables/builtin.py b/torch/_dynamo/variables/builtin.py
index a10b11f..ea84921 100644
--- a/torch/_dynamo/variables/builtin.py
+++ b/torch/_dynamo/variables/builtin.py
@@ -12,7 +12,7 @@
from .. import config, variables
from ..allowed_functions import is_allowed
-from ..exc import unimplemented, Unsupported
+from ..exc import unimplemented, Unsupported, UserError, UserErrorType
from ..guards import GuardBuilder
from ..replay_record import DummyModule
from ..source import AttrSource, is_constant_source, SuperSource, TypeSource
@@ -1075,7 +1075,11 @@
self, obj
)
- unimplemented(f"type({obj})")
+ raise UserError(
+ UserErrorType.ANTI_PATTERN,
+ "Can't call type() on generated custom object. "
+ "Please use __class__ instead",
+ )
def call_reversed(self, tx, obj: VariableTracker):
if obj.has_unpack_var_sequence(tx):