[dynamo] Minor compile time optimizations in torch.py (#121615)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/121615
Approved by: https://github.com/oulgen
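
This patch makes two small compile-time optimizations in torch/_dynamo:

* supported_ctx_manager_classes, REWRITE_OPS_TO_TENSOR_SIZE_METHOD, and
  constant_fold_functions are now dicts built with dict.fromkeys(...) instead
  of a set/list, so membership tests on constant_fold_functions (previously a
  list) become O(1) lookups.
* The repeated pattern check_constant_args(args, kwargs) or
  check_unspec_python_args(args, kwargs) is folded into a single
  check_unspec_or_constant_args helper in torch/_dynamo/utils.py, and the
  call sites in torch/_dynamo/variables/torch.py now use it. The expected
  graph string in test/dynamo/test_misc.py is updated to match the new
  argument ordering.

The sketch below is a minimal, self-contained illustration of the same idea;
the names used here (constant_fold_functions, can_constant_fold, and the
placeholder predicates) are illustrative stand-ins, not code from this patch:

    import math

    # A list has O(n) membership tests; a dict built with dict.fromkeys()
    # gives O(1) hash lookups and keeps insertion order.
    constant_fold_functions = [math.radians, math.degrees, pow]
    constant_fold_functions = dict.fromkeys(constant_fold_functions)

    def can_constant_fold(fn):
        return fn in constant_fold_functions

    # Placeholder predicates standing in for the real Dynamo checks.
    def check_constant_args(args, kwargs):
        return all(
            isinstance(a, (int, float, bool, str, type(None))) for a in args
        ) and not kwargs

    def check_unspec_python_args(args, kwargs):
        return False  # stand-in for the unspecialized python-number check

    # Mirrors the new helper: one short-circuiting call instead of computing
    # both predicates separately at every call site.
    def check_unspec_or_constant_args(args, kwargs):
        return check_constant_args(args, kwargs) or check_unspec_python_args(args, kwargs)

    assert can_constant_fold(math.radians)
    assert check_unspec_or_constant_args((1, 2.0), {})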
diff --git a/test/dynamo/test_misc.py b/test/dynamo/test_misc.py
index a73de8b..fee1680 100644
--- a/test/dynamo/test_misc.py
+++ b/test/dynamo/test_misc.py
@@ -776,7 +776,7 @@
"""\
def forward(self, arg0_1: "f32[3]", arg1_1: "f32[3]", arg2_1: "f32[3]", arg3_1: "f32[3]"):
# No stacktrace found for following nodes
- foo_default = torch.ops.mylib.foo.default(None, [arg0_1, arg3_1], arg1_1, 2, arg2_1); arg0_1 = arg3_1 = arg1_1 = arg2_1 = None
+ foo_default = torch.ops.mylib.foo.default(None, [arg2_1, arg3_1], arg0_1, 2, arg1_1); arg2_1 = arg3_1 = arg0_1 = arg1_1 = None
return ()""",
)
diff --git a/torch/_dynamo/utils.py b/torch/_dynamo/utils.py
index eb6b35c..1974eb0 100644
--- a/torch/_dynamo/utils.py
+++ b/torch/_dynamo/utils.py
@@ -1033,6 +1033,10 @@
return unspec_count > 0
+def check_unspec_or_constant_args(args, kwargs):
+ return check_constant_args(args, kwargs) or check_unspec_python_args(args, kwargs)
+
+
def check_numpy_ndarray_args(args, kwargs):
from .variables.tensor import NumpyNdarrayVariable
diff --git a/torch/_dynamo/variables/torch.py b/torch/_dynamo/variables/torch.py
index dc5fc0a..a027f7f 100644
--- a/torch/_dynamo/variables/torch.py
+++ b/torch/_dynamo/variables/torch.py
@@ -23,8 +23,7 @@
from ..guards import GuardBuilder, install_guard
from ..source import SyntheticLocalSource
from ..utils import (
- check_constant_args,
- check_unspec_python_args,
+ check_unspec_or_constant_args,
guard_if_dyn,
has_torch_function,
hashable,
@@ -49,29 +48,33 @@
log = logging.getLogger(__name__)
-supported_ctx_manager_classes = {
- torch.profiler.profiler.profile,
- torch.autograd.profiler.profile,
- torch.autograd.profiler.record_function,
- torch._C.DisableTorchFunctionSubclass,
- torch._functorch.vmap.vmap_increment_nesting,
- torch._functorch.eager_transforms.grad_increment_nesting,
- torch._functorch.eager_transforms.enable_inplace_requires_grad,
- torch.amp.autocast_mode.autocast,
- torch.autograd.grad_mode.enable_grad,
- torch.autograd.grad_mode.inference_mode,
- torch.autograd.grad_mode.no_grad,
- torch.autograd.grad_mode.set_grad_enabled,
- torch.autograd.graph.disable_saved_tensors_hooks,
- torch.cpu.amp.autocast_mode.autocast,
- torch.cuda.amp.autocast_mode.autocast,
-}
+supported_ctx_manager_classes = dict.fromkeys(
+ [
+ torch.profiler.profiler.profile,
+ torch.autograd.profiler.profile,
+ torch.autograd.profiler.record_function,
+ torch._C.DisableTorchFunctionSubclass,
+ torch._functorch.vmap.vmap_increment_nesting,
+ torch._functorch.eager_transforms.grad_increment_nesting,
+ torch._functorch.eager_transforms.enable_inplace_requires_grad,
+ torch.amp.autocast_mode.autocast,
+ torch.autograd.grad_mode.enable_grad,
+ torch.autograd.grad_mode.inference_mode,
+ torch.autograd.grad_mode.no_grad,
+ torch.autograd.grad_mode.set_grad_enabled,
+ torch.autograd.graph.disable_saved_tensors_hooks,
+ torch.cpu.amp.autocast_mode.autocast,
+ torch.cuda.amp.autocast_mode.autocast,
+ ]
+)
-REWRITE_OPS_TO_TENSOR_SIZE_METHOD = [
- torch.onnx.operators.shape_as_tensor,
- torch._shape_as_tensor,
-]
+REWRITE_OPS_TO_TENSOR_SIZE_METHOD = dict.fromkeys(
+ [
+ torch.onnx.operators.shape_as_tensor,
+ torch._shape_as_tensor,
+ ]
+)
constant_fold_functions = [
torch._assert,
@@ -91,8 +94,6 @@
torch.promote_types,
torch._C._get_privateuse1_backend_name,
]
-
-
if torch.distributed.is_available():
constant_fold_functions.extend(
[
@@ -101,6 +102,8 @@
torch.distributed.get_world_size,
]
)
+# Convert to dict for O(1) access times
+constant_fold_functions = dict.fromkeys(constant_fold_functions)
tracing_state_functions = {
@@ -287,13 +290,11 @@
TensorVariable,
UserDefinedObjectVariable,
)
-
from .builder import wrap_fx_proxy, wrap_fx_proxy_cls
- constant_args = check_constant_args(args, kwargs)
- unspec_python_args = check_unspec_python_args(args, kwargs)
-
- if self.can_constant_fold_through() and (constant_args or unspec_python_args):
+ if self.can_constant_fold_through() and check_unspec_or_constant_args(
+ args, kwargs
+ ):
# constant fold
return ConstantVariable.create(
self.as_python_constant()(
@@ -328,7 +329,9 @@
return tx.inline_user_function_return(
SourcelessBuilder()(tx, polyfill.accumulate_grad), args, kwargs
)
- elif self.value == math.radians and not (constant_args or unspec_python_args):
+ elif self.value == math.radians and not check_unspec_or_constant_args(
+ args, kwargs
+ ):
# Use polyfill to convert math.radians(x) into math.pi * x / 180.0
from .builder import SourcelessBuilder