Replace invoking self.value if there is a user defined init, avoiding arbitrary code execution (#117818)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/117818
Approved by: https://github.com/ezyang
diff --git a/test/dynamo/test_ctx_manager.py b/test/dynamo/test_ctx_manager.py
index d82524a..adb3e92 100644
--- a/test/dynamo/test_ctx_manager.py
+++ b/test/dynamo/test_ctx_manager.py
@@ -744,11 +744,23 @@
                 x = torch.relu(x)
             return x - 1
 
+        x = torch.rand(2, 3)
+        cnts = torch._dynamo.testing.CompileCounter()
+        opt_fn = torch.compile(backend=cnts, fullgraph=False)(fn)
+
         with torch.no_grad():
-            torch._dynamo.testing.standard_test(self, fn=fn, nargs=1, expected_ops=6)
+            ref = fn(x)
+            res = opt_fn(x)
+            self.assertTrue(same(ref, res))
+            self.assertEqual(cnts.frame_count, 2)
+            self.assertEqual(cnts.op_count, 2)
 
         with torch.enable_grad():
-            torch._dynamo.testing.standard_test(self, fn=fn, nargs=1, expected_ops=6)
+            ref = fn(x)
+            res = opt_fn(x)
+            self.assertTrue(same(ref, res))
+            self.assertEqual(cnts.frame_count, 4)
+            self.assertEqual(cnts.op_count, 4)
 
     def test_nested_generic_context_manager(self):
         def fn(x):
@@ -763,11 +775,23 @@
                 x = torch.relu(x)
             return x - 1
 
+        x = torch.rand(2, 3)
+        cnts = torch._dynamo.testing.CompileCounter()
+        opt_fn = torch.compile(backend=cnts, fullgraph=False)(fn)
+
         with torch.no_grad():
-            torch._dynamo.testing.standard_test(self, fn=fn, nargs=1, expected_ops=9)
+            ref = fn(x)
+            res = opt_fn(x)
+            self.assertTrue(same(ref, res))
+            self.assertEqual(cnts.frame_count, 4)
+            self.assertEqual(cnts.op_count, 4)
 
         with torch.enable_grad():
-            torch._dynamo.testing.standard_test(self, fn=fn, nargs=1, expected_ops=9)
+            ref = fn(x)
+            res = opt_fn(x)
+            self.assertTrue(same(ref, res))
+            self.assertEqual(cnts.frame_count, 6)
+            self.assertEqual(cnts.op_count, 6)
 
     def test_generic_context_manager_with_graph_break(self):
         def fn(x):
diff --git a/test/dynamo/test_repros.py b/test/dynamo/test_repros.py
index 9ae10bf..976bceb 100644
--- a/test/dynamo/test_repros.py
+++ b/test/dynamo/test_repros.py
@@ -4000,6 +4000,73 @@
                 # frame_count should stay at 1.
                 self.assertEqual(cnt.frame_count, 1)
 
+    def test_user_ctor_ctx_manager(self):
+        class UserCtxManager:
+            def __enter__(self):
+                return 1
+
+            def __exit__(self, exc_type, exc_val, exc_tb):
+                pass
+
+        def fn(x, y):
+            ucm = UserCtxManager()
+            return x * x
+
+        cnt = torch._dynamo.testing.CompileCounter()
+        opt_fn = torch._dynamo.optimize(cnt, nopython=True)(fn)
+        x = torch.rand([2, 2])
+        opt_fn(x, x)
+        self.assertEqual(cnt.frame_count, 1)
+
+    def test_user_ctor_ctx_manager_custom_init(self):
+        class UserCtxManager:
+            def __init__(self, x):
+                x[0] = 10
+
+            def __enter__(self):
+                return 1
+
+            def __exit__(self, exc_type, exc_val, exc_tb):
+                pass
+
+        def fn(x, y):
+            ucm = UserCtxManager(y)
+            return x * y[0]
+
+        cnt = torch._dynamo.testing.CompileCounter()
+        opt_fn = torch._dynamo.optimize(cnt, nopython=True)(fn)
+        x = torch.rand([2, 2])
+        self.assertEqual(opt_fn(x, [5]), fn(x, [5]))
+        self.assertEqual(cnt.frame_count, 1)
+
+    def test_user_ctor_ctx_manager_custom_init_graph_break(self):
+        counter = [0]
+
+        class UserCtxManager:
+            def __init__(self, k):
+                k[0] += 1
+
+            def __enter__(self):
+                return 1
+
+            def __exit__(self, exc_type, exc_val, exc_tb):
+                pass
+
+        def fn(x, counter):
+            x = x * x
+            ucm = UserCtxManager(counter)
+            return x * x
+
+        cnt = torch._dynamo.testing.CompileCounter()
+        opt_fn = torch._dynamo.optimize(cnt)(fn)
+        x = torch.rand([2, 2])
+        self.assertEqual(opt_fn(x, counter), fn(x, counter))
+        self.assertEqual(counter[0], 2)
+        for i in range(0, 10):
+            opt_fn(x, counter)
+        self.assertEqual(counter[0], 12)
+        self.assertEqual(cnt.frame_count, torch._dynamo.utils.ifdynstaticdefault(3, 2))
+
 
 if __name__ == "__main__":
     from torch._dynamo.test_case import run_tests
diff --git a/test/functorch/test_eager_transforms.py b/test/functorch/test_eager_transforms.py
index e6528c4..625826c 100644
--- a/test/functorch/test_eager_transforms.py
+++ b/test/functorch/test_eager_transforms.py
@@ -2669,7 +2669,6 @@
         _, y = jvp(lambda x: jvp(f, (x,), (t,))[1], (x,), (t,))
         self.assertEqual(y, 2)
 
-    @xfailIfTorchDynamo
     def test_disable_fwd_grad_mixed(self, device):
         def f(x):
             with fwAD._set_fwd_grad_enabled(False):
diff --git a/torch/_dynamo/variables/dicts.py b/torch/_dynamo/variables/dicts.py
index b9dd9d2..910f527 100644
--- a/torch/_dynamo/variables/dicts.py
+++ b/torch/_dynamo/variables/dicts.py
@@ -127,8 +127,14 @@
 
         def __eq__(self, other: "ConstDictVariable._HashableTracker") -> bool:
             Hashable = ConstDictVariable._HashableTracker
-            assert isinstance(other, Hashable), type(other)
-            return Hashable._eq_impl(self.underlying_value, other.underlying_value)
+            assert isinstance(other, Hashable) or ConstantVariable.is_literal(
+                other
+            ), type(other)
+            if isinstance(other, Hashable):
+                return Hashable._eq_impl(self.underlying_value, other.underlying_value)
+
+            # constant
+            return Hashable._eq_impl(self.underlying_value, other)
 
     def __init__(
         self, items: Dict[VariableTracker, VariableTracker], user_cls=dict, **kwargs
diff --git a/torch/_dynamo/variables/user_defined.py b/torch/_dynamo/variables/user_defined.py
index b1bf60e..617f02f 100644
--- a/torch/_dynamo/variables/user_defined.py
+++ b/torch/_dynamo/variables/user_defined.py
@@ -285,9 +285,14 @@
             )
         elif (
             issubclass(type(self.value), type)
-            and hasattr(self.value, "__enter__")
-            and hasattr(self.value, "__exit__")
+            and hasattr(
+                self.value, "__enter__"
+            )  # TODO(voz): These can invoke user code!
+            and hasattr(
+                self.value, "__exit__"
+            )  # TODO(voz): These can invoke user code!
             and check_constant_args(args, kwargs)
+            and self.value.__init__ == object.__init__
             and len(kwargs) == 0  # TODO(ybliang): support kwargs
         ):
             unwrapped_args = [x.as_python_constant() for x in args]
@@ -295,6 +300,7 @@
                 unwrapped_args,
                 cm_obj=self.value(*unwrapped_args),
             )
+
         elif is_namedtuple_cls(self.value):
             fields = namedtuple_fields(self.value)
             field_defaults = self.value._field_defaults
diff --git a/torch/testing/_internal/dynamo_test_failures.py b/torch/testing/_internal/dynamo_test_failures.py
index e36e388..0947765 100644
--- a/torch/testing/_internal/dynamo_test_failures.py
+++ b/torch/testing/_internal/dynamo_test_failures.py
@@ -1878,9 +1878,7 @@
     "TestShapeOpsCPU.test_flip_cpu_bfloat16",  # test_shape_ops
     "TestShapeOpsCPU.test_clamp_cpu_float32",  # test_shape_ops
     "TestSubclassSerialization.test_tensor_subclass_deepcopy",  # test_serialization
-    "TestOldSerialization.test_save_different_dtype_unallocated",  # test_serialization
     "TestSubclassSerialization.test_tensor_subclass_getstate_overwrite",  # test_serialization
-    "TestSerialization.test_save_different_dtype_unallocated",  # test_serialization
     "TestSubclassSerialization.test_tensor_subclass_wrapper_serialization",  # test_serialization
     "TestScatterGatherCPU.test_scatter_reduce_sum_cpu_float32",  # test_scatter_gather_ops
     "TestScatterGatherCPU.test_scatter_reduce_mean_cpu_int16",  # test_scatter_gather_ops
@@ -1950,7 +1948,6 @@
     "TestNNDeviceTypeCPU.test_upsamplingTrilinear3d_align_corners_False_memory_format1_cpu",  # test_nn
     "TestNNDeviceTypeCPU.test_batchnorm_grad_cpu",  # test_nn
     "TestNN.test_interpolate",  # test_nn
-    "TestNN.test_register_state_dict_pre_hook",  # test_nn
     "TestNNDeviceTypeCPU.test_upsamplingTrilinear3d_align_corners_True_memory_format0_cpu",  # test_nn
     "TestNNDeviceTypeCPU.test_upsamplingTrilinear3d_align_corners_True_memory_format1_cpu",  # test_nn
     "TestNN.test_fb_fc_packed",  # test_nn
@@ -1958,7 +1955,6 @@
     "TestNNDeviceTypeCPU.test_invalid_reduction_strings_cpu",  # test_nn
     "TestNNDeviceTypeCPU.test_nll_loss_total_weight_is_zero_cpu",  # test_nn
     "TestNNDeviceTypeCPU.test_nll_loss_empty_tensor_reduction_mean_cpu",  # test_nn
-    "TestNN.test_register_state_dict_pre_hook_lazy_module",  # test_nn
     "TestNN.test_ParameterDict_replication",  # test_nn
     "TestNN.test_Sequential_iadd",  # test_nn
     "TestNN.test_upsamplingLinear1d",  # test_nn
@@ -2030,33 +2026,17 @@
     "PackedSequenceTest.test_total_length",  # nn/test_packed_sequence
     "TestModuleHooks.test_forward_pre_hooks_named_tuple_True",  # nn/test_module_hooks
     "TestModuleHooks.test_full_backward_pre_hooks_named_tuple_True",  # nn/test_module_hooks
-    "TestModuleHookNN.test_hook_submodule_registration",  # nn/test_module_hooks
     "TestModuleHooks.test_forward_hooks_named_tuple_False",  # nn/test_module_hooks
     "TestModuleHooks.test_full_backward_hooks_named_tuple_False",  # nn/test_module_hooks
     "TestModuleHooks.test_forward_hooks_named_tuple_True",  # nn/test_module_hooks
-    "TestStateDictHooks.test_pickled_hook",  # nn/test_module_hooks
     "TestModuleHookNN.test_hook_inplace",  # nn/test_module_hooks
-    "TestModuleGlobalHooks.test_module_backward_global_hook_writeable",  # nn/test_module_hooks
-    "TestModuleHookNN.test_hook_buffer_registration",  # nn/test_module_hooks
     "TestModuleHooks.test_full_backward_hooks_named_tuple_True",  # nn/test_module_hooks
-    "TestModuleHookNN.test_hook_no_requires_grad",  # nn/test_module_hooks
-    "TestModuleHookNN.test_hook_backward_writeable",  # nn/test_module_hooks
     "TestModuleHooks.test_forward_pre_hooks_named_tuple_False",  # nn/test_module_hooks
-    "TestModuleHookNN.test_hook_parameter_registration",  # nn/test_module_hooks
     "TestModuleHooks.test_full_backward_pre_hooks_named_tuple_False",  # nn/test_module_hooks
-    "TestModuleHookNN.test_hook_cpp",  # nn/test_module_hooks
-    "TestStateDictHooks.test_load_state_dict_pre_hook",  # nn/test_module_hooks
-    "TestModuleHookNN.test_hook_invalid_outputs",  # nn/test_module_hooks
-    "TestModuleHookNN.test_backward_hooks_interaction",  # nn/test_module_hooks
-    "TestModuleHookNN.test_hooks",  # nn/test_module_hooks
-    "TestModuleHookNN.test_hook_last_arg_requires_grad",  # nn/test_module_hooks
-    "TestModuleGlobalHooks.test_module_global_hook_invalid_outputs",  # nn/test_module_hooks
-    "TestLazyModules.test_lazy_module_parameter",  # nn/test_lazy_modules
     "TestLazyModules.test_lazy_batchnorm2d_state",  # nn/test_lazy_modules
     "TestLazyModules.test_lazy_conv3d",  # nn/test_lazy_modules
     "TestLazyModules.test_lazy_conv_transposed1d",  # nn/test_lazy_modules
     "TestLazyModules.test_lazy_conv2d",  # nn/test_lazy_modules
-    "TestLazyModules.test_optimizer_pass",  # nn/test_lazy_modules
     "TestLazyModules.test_lazy_instancenorm3d_state",  # nn/test_lazy_modules
     "TestLazyModules.test_lazy_batchnorm3d_state",  # nn/test_lazy_modules
     "TestLazyModules.test_lazy_conv_transpose1d_pickle",  # nn/test_lazy_modules
@@ -2072,13 +2052,10 @@
     "TestLazyModules.test_lazy_batchnorm3d",  # nn/test_lazy_modules
     "TestLazyModules.test_lazy_conv2d_pickle",  # nn/test_lazy_modules
     "TestLazyModules.test_lazy_conv1d_pickle",  # nn/test_lazy_modules
-    "TestLazyModules.test_lazy_module_jit_buffer",  # nn/test_lazy_modules
     "TestLazyModules.test_lazy_conv1d",  # nn/test_lazy_modules
     "TestLazyModules.test_linear",  # nn/test_lazy_modules
-    "TestLazyModules.test_materialize_dtype",  # nn/test_lazy_modules
     "TestLazyModules.test_lazy_module_buffer",  # nn/test_lazy_modules
     "TestLazyModules.test_lazy_batchnorm1d_state",  # nn/test_lazy_modules
-    "TestLazyModules.test_lazy_module_jit_param",  # nn/test_lazy_modules
     "TestLazyModules.test_lazy_batchnorm_with_dict_input",  # nn/test_lazy_modules
     "TestLazyModules.test_lazy_conv_transpose2d",  # nn/test_lazy_modules
     "TestLazyModules.test_lazy_conv_transpose2d_pickle",  # nn/test_lazy_modules
@@ -2937,7 +2914,6 @@
     "TestRecordFunction.test_record_function",  # profiler/test_profiler
     "TestTorchTidyProfiler.test_optimizer_parameters_adam",  # profiler/test_profiler
     "TestTorchTidyProfiler.test_tensor_properties",  # profiler/test_profiler
-    "TestProfiler.test_record_function_fast",  # profiler/test_profiler
     "TestProfiler.test_profiler_fwd_bwd_link",  # profiler/test_profiler
     "TestProfiler.test_concrete_inputs_profiling",  # profiler/test_profiler
     "TestTorchTidyProfiler.test_tensorimpl_invalidation_scalar_args",  # profiler/test_profiler
@@ -3178,7 +3154,6 @@
     "TestAutogradForwardMode.test_forward_level_cleanup",  # test_autograd
     "TestAutograd.test_gradcheck_check_forward_or_backward_only",  # test_autograd
     "TestAutogradDeviceTypeCPU.test_inplace_on_view_modify_base_cpu",  # test_autograd
-    "TestAutograd.test_full_backward_hook_double_backward",  # test_autograd
     "TestAutograd.test_gradcheck_forward_ad_batched_grad",  # test_autograd
     "TestAutograd.test_custom_function_non_tensor_inputs_outputs",  # test_autograd
     "TestNestedCheckpoint.test_nested_checkpoint_non_tensor_inputs_and_outputs_early_stop_True",  # test_autograd
@@ -3234,9 +3209,7 @@
     "TestAutogradInferenceMode.test_inference_mode_inf_tensor_in_inf_mode_functional_op",  # test_autograd
     "TestAutogradInferenceMode.test_inference_mode_inf_tensor_in_normal_mode_functional_op",  # test_autograd
     "TestAutogradInferenceMode.test_inference_mode_inf_tensor_in_inf_mode_inplace_op",  # test_autograd
-    "TestMultithreadAutograd.test_set_multithreading_enabled_as_context_manager_and_function",  # test_autograd
     "TestAutogradDeviceTypeCPU.test_scatter_index_reduce_prod_gradgrad_error_cpu",  # test_autograd
-    "TestAutograd.test_current_graph_task_execution_order",  # test_autograd
     "TestAutograd.test_nested_anomaly_detect_nan",  # test_autograd
     "TestAutograd.test_nested_anomaly_printstack_cleanup",  # test_autograd
     "TestAutograd.test_post_accumulate_grad_hook_gets_cleaned_up",  # test_autograd
@@ -3687,7 +3660,6 @@
     "TestCustomOp.test_impl_meta",  # test_custom_ops
     "TestCustomOp.test_impl_invalid_devices",  # test_custom_ops
     "TestCustomOp.test_new_data_dependent_symint",  # test_custom_ops
-    "TestCustomOpTestingCPU.test_missing_abstract_impl_cpu",  # test_custom_ops
     "TestCustomOp.test_define_with_tags_list",  # test_custom_ops
     "TestCustomOp.test_backward_tensorlist_input_requires_list_grads",  # test_custom_ops
     "TestCustomOp.test_not_implemented_error",  # test_custom_ops
@@ -3701,7 +3673,6 @@
     "TestCustomOp.test_impl_device_function",  # test_custom_ops
     "TestCustomOp.test_builtin_torchscript_ops",  # test_custom_ops
     "TestCustomOpTestingCPU.test_missing_functionalization_cpu",  # test_custom_ops
-    "TestCustomOpTestingCPU.test_incorrect_schema_view_cpu",  # test_custom_ops
     "TestCustomOp.test_define_with_tags_tuple",  # test_custom_ops
     "TestCustomOp.test_builtin_aten_ops_are_pt2_compliant",  # test_custom_ops
     "TestCustomOp.test_save_for_backward_inputs_are_namedtuple",  # test_custom_ops
@@ -3710,7 +3681,6 @@
     "TestCustomOp.test_backward_dict_invalid_keys",  # test_custom_ops
     "TestCustomOp.test_backward_tensorlist_input_requires_list_grads_with_same_numel",  # test_custom_ops
     "TestCustomOp.test_duplicate_impl",  # test_custom_ops
-    "TestCustomOpTestingCPU.test_incorrect_abstract_impl_cpu",  # test_custom_ops
     "TestCustomOp.test_backward_output_differentiability_numel",  # test_custom_ops
     "TestCustomOp.test_backward_dict_requires_keys_for_input_tensors",  # test_custom_ops
     "TestCustomOp.test_legacy_define",  # test_custom_ops