[dynamo] Minor compile time optimizations in torch.py (#121615)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/121615
Approved by: https://github.com/oulgen
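
The compile-time wins here come from two patterns visible in the diff below:
(1) the module-level constants that Dynamo probes with `in` checks are built
with `dict.fromkeys(...)`: `constant_fold_functions` goes from a list to a
dict for O(1) membership (per the in-diff comment), and the two former set
constants (`supported_ctx_manager_classes`, `REWRITE_OPS_TO_TENSOR_SIZE_METHOD`)
are converted to dicts as well, presumably because dict lookups are marginally
cheaper on this hot path; and (2) the paired
`check_constant_args(...) or check_unspec_python_args(...)` calls are folded
into one short-circuiting helper, `check_unspec_or_constant_args`, so the
second check only runs when the first fails. Below is a minimal, self-contained
sketch of both patterns with toy placeholder bodies; it is not the real Dynamo
code, whose predicates inspect VariableTracker objects rather than raw values.

    import math

    # A list needs an O(n) scan for every `fn in candidates` probe ...
    candidates_list = [math.sin, math.cos, math.sqrt, abs, len]

    # ... while dict.fromkeys builds a hash table, so membership is O(1).
    candidates = dict.fromkeys(candidates_list)

    def can_constant_fold(fn):
        return fn in candidates  # hash lookup instead of a linear scan

    # Toy stand-ins for the real predicates in torch/_dynamo/utils.py;
    # the actual ones examine Dynamo variable types, not Python values.
    def check_constant_args(args, kwargs):
        return all(isinstance(a, (int, float, str)) for a in args)

    def check_unspec_python_args(args, kwargs):
        return False  # placeholder for the more involved unspec check

    def check_unspec_or_constant_args(args, kwargs):
        # `or` short-circuits: the second predicate runs only when the
        # first one fails, saving a call on the common constant-args path.
        return check_constant_args(args, kwargs) or check_unspec_python_args(args, kwargs)

    print(can_constant_fold(len))                       # True
    print(check_unspec_or_constant_args((1, 2.0), {}))  # True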
diff --git a/test/dynamo/test_misc.py b/test/dynamo/test_misc.py
index a73de8b..fee1680 100644
--- a/test/dynamo/test_misc.py
+++ b/test/dynamo/test_misc.py
@@ -776,7 +776,7 @@
                     """\
 def forward(self, arg0_1: "f32[3]", arg1_1: "f32[3]", arg2_1: "f32[3]", arg3_1: "f32[3]"):
         # No stacktrace found for following nodes
-        foo_default = torch.ops.mylib.foo.default(None, [arg0_1, arg3_1], arg1_1, 2, arg2_1);  arg0_1 = arg3_1 = arg1_1 = arg2_1 = None
+        foo_default = torch.ops.mylib.foo.default(None, [arg2_1, arg3_1], arg0_1, 2, arg1_1);  arg2_1 = arg3_1 = arg0_1 = arg1_1 = None
         return ()""",
                 )
 
diff --git a/torch/_dynamo/utils.py b/torch/_dynamo/utils.py
index eb6b35c..1974eb0 100644
--- a/torch/_dynamo/utils.py
+++ b/torch/_dynamo/utils.py
@@ -1033,6 +1033,10 @@
     return unspec_count > 0
 
 
+def check_unspec_or_constant_args(args, kwargs):
+    return check_constant_args(args, kwargs) or check_unspec_python_args(args, kwargs)
+
+
 def check_numpy_ndarray_args(args, kwargs):
     from .variables.tensor import NumpyNdarrayVariable
 
diff --git a/torch/_dynamo/variables/torch.py b/torch/_dynamo/variables/torch.py
index dc5fc0a..a027f7f 100644
--- a/torch/_dynamo/variables/torch.py
+++ b/torch/_dynamo/variables/torch.py
@@ -23,8 +23,7 @@
 from ..guards import GuardBuilder, install_guard
 from ..source import SyntheticLocalSource
 from ..utils import (
-    check_constant_args,
-    check_unspec_python_args,
+    check_unspec_or_constant_args,
     guard_if_dyn,
     has_torch_function,
     hashable,
@@ -49,29 +48,33 @@
 
 log = logging.getLogger(__name__)
 
-supported_ctx_manager_classes = {
-    torch.profiler.profiler.profile,
-    torch.autograd.profiler.profile,
-    torch.autograd.profiler.record_function,
-    torch._C.DisableTorchFunctionSubclass,
-    torch._functorch.vmap.vmap_increment_nesting,
-    torch._functorch.eager_transforms.grad_increment_nesting,
-    torch._functorch.eager_transforms.enable_inplace_requires_grad,
-    torch.amp.autocast_mode.autocast,
-    torch.autograd.grad_mode.enable_grad,
-    torch.autograd.grad_mode.inference_mode,
-    torch.autograd.grad_mode.no_grad,
-    torch.autograd.grad_mode.set_grad_enabled,
-    torch.autograd.graph.disable_saved_tensors_hooks,
-    torch.cpu.amp.autocast_mode.autocast,
-    torch.cuda.amp.autocast_mode.autocast,
-}
+supported_ctx_manager_classes = dict.fromkeys(
+    [
+        torch.profiler.profiler.profile,
+        torch.autograd.profiler.profile,
+        torch.autograd.profiler.record_function,
+        torch._C.DisableTorchFunctionSubclass,
+        torch._functorch.vmap.vmap_increment_nesting,
+        torch._functorch.eager_transforms.grad_increment_nesting,
+        torch._functorch.eager_transforms.enable_inplace_requires_grad,
+        torch.amp.autocast_mode.autocast,
+        torch.autograd.grad_mode.enable_grad,
+        torch.autograd.grad_mode.inference_mode,
+        torch.autograd.grad_mode.no_grad,
+        torch.autograd.grad_mode.set_grad_enabled,
+        torch.autograd.graph.disable_saved_tensors_hooks,
+        torch.cpu.amp.autocast_mode.autocast,
+        torch.cuda.amp.autocast_mode.autocast,
+    ]
+)
 
 
-REWRITE_OPS_TO_TENSOR_SIZE_METHOD = [
-    torch.onnx.operators.shape_as_tensor,
-    torch._shape_as_tensor,
-]
+REWRITE_OPS_TO_TENSOR_SIZE_METHOD = dict.fromkeys(
+    [
+        torch.onnx.operators.shape_as_tensor,
+        torch._shape_as_tensor,
+    ]
+)
 
 constant_fold_functions = [
     torch._assert,
@@ -91,8 +94,6 @@
     torch.promote_types,
     torch._C._get_privateuse1_backend_name,
 ]
-
-
 if torch.distributed.is_available():
     constant_fold_functions.extend(
         [
@@ -101,6 +102,8 @@
             torch.distributed.get_world_size,
         ]
     )
+# Convert to dict for O(1) access times
+constant_fold_functions = dict.fromkeys(constant_fold_functions)
 
 
 tracing_state_functions = {
@@ -287,13 +290,11 @@
             TensorVariable,
             UserDefinedObjectVariable,
         )
-
         from .builder import wrap_fx_proxy, wrap_fx_proxy_cls
 
-        constant_args = check_constant_args(args, kwargs)
-        unspec_python_args = check_unspec_python_args(args, kwargs)
-
-        if self.can_constant_fold_through() and (constant_args or unspec_python_args):
+        if self.can_constant_fold_through() and check_unspec_or_constant_args(
+            args, kwargs
+        ):
             # constant fold
             return ConstantVariable.create(
                 self.as_python_constant()(
@@ -328,7 +329,9 @@
             return tx.inline_user_function_return(
                 SourcelessBuilder()(tx, polyfill.accumulate_grad), args, kwargs
             )
-        elif self.value == math.radians and not (constant_args or unspec_python_args):
+        elif self.value == math.radians and not check_unspec_or_constant_args(
+            args, kwargs
+        ):
             # Use polyfill to convert math.radians(x) into math.pi * x / 180.0
             from .builder import SourcelessBuilder