Replace invoking self.value if there is a user defined init, avoiding arbitrary code execution (#117818)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/117818
Approved by: https://github.com/ezyang
diff --git a/test/dynamo/test_ctx_manager.py b/test/dynamo/test_ctx_manager.py
index d82524a..adb3e92 100644
--- a/test/dynamo/test_ctx_manager.py
+++ b/test/dynamo/test_ctx_manager.py
@@ -744,11 +744,23 @@
x = torch.relu(x)
return x - 1
+ x = torch.rand(2, 3)
+ cnts = torch._dynamo.testing.CompileCounter()
+ opt_fn = torch.compile(backend=cnts, fullgraph=False)(fn)
+
with torch.no_grad():
- torch._dynamo.testing.standard_test(self, fn=fn, nargs=1, expected_ops=6)
+ ref = fn(x)
+ res = opt_fn(x)
+ self.assertTrue(same(ref, res))
+ self.assertEqual(cnts.frame_count, 2)
+ self.assertEqual(cnts.op_count, 2)
with torch.enable_grad():
- torch._dynamo.testing.standard_test(self, fn=fn, nargs=1, expected_ops=6)
+ ref = fn(x)
+ res = opt_fn(x)
+ self.assertTrue(same(ref, res))
+ self.assertEqual(cnts.frame_count, 4)
+ self.assertEqual(cnts.op_count, 4)
def test_nested_generic_context_manager(self):
def fn(x):
@@ -763,11 +775,23 @@
x = torch.relu(x)
return x - 1
+ x = torch.rand(2, 3)
+ cnts = torch._dynamo.testing.CompileCounter()
+ opt_fn = torch.compile(backend=cnts, fullgraph=False)(fn)
+
with torch.no_grad():
- torch._dynamo.testing.standard_test(self, fn=fn, nargs=1, expected_ops=9)
+ ref = fn(x)
+ res = opt_fn(x)
+ self.assertTrue(same(ref, res))
+ self.assertEqual(cnts.frame_count, 4)
+ self.assertEqual(cnts.op_count, 4)
with torch.enable_grad():
- torch._dynamo.testing.standard_test(self, fn=fn, nargs=1, expected_ops=9)
+ ref = fn(x)
+ res = opt_fn(x)
+ self.assertTrue(same(ref, res))
+ self.assertEqual(cnts.frame_count, 6)
+ self.assertEqual(cnts.op_count, 6)
def test_generic_context_manager_with_graph_break(self):
def fn(x):
diff --git a/test/dynamo/test_repros.py b/test/dynamo/test_repros.py
index 9ae10bf..976bceb 100644
--- a/test/dynamo/test_repros.py
+++ b/test/dynamo/test_repros.py
@@ -4000,6 +4000,73 @@
# frame_count should stay at 1.
self.assertEqual(cnt.frame_count, 1)
+ def test_user_ctor_ctx_manager(self):
+ class UserCtxManager:
+ def __enter__(self):
+ return 1
+
+ def __exit__(self, exc_type, exc_val, exc_tb):
+ pass
+
+ def fn(x, y):
+ ucm = UserCtxManager()
+ return x * x
+
+ cnt = torch._dynamo.testing.CompileCounter()
+ opt_fn = torch._dynamo.optimize(cnt, nopython=True)(fn)
+ x = torch.rand([2, 2])
+ opt_fn(x, x)
+ self.assertEqual(cnt.frame_count, 1)
+
+ def test_user_ctor_ctx_manager_custom_init(self):
+ class UserCtxManager:
+ def __init__(self, x):
+ x[0] = 10
+
+ def __enter__(self):
+ return 1
+
+ def __exit__(self, exc_type, exc_val, exc_tb):
+ pass
+
+ def fn(x, y):
+ ucm = UserCtxManager(y)
+ return x * y[0]
+
+ cnt = torch._dynamo.testing.CompileCounter()
+ opt_fn = torch._dynamo.optimize(cnt, nopython=True)(fn)
+ x = torch.rand([2, 2])
+ self.assertEqual(opt_fn(x, [5]), fn(x, [5]))
+ self.assertEqual(cnt.frame_count, 1)
+
+ def test_user_ctor_ctx_manager_custom_init_graph_break(self):
+ counter = [0]
+
+ class UserCtxManager:
+ def __init__(self, k):
+ k[0] += 1
+
+ def __enter__(self):
+ return 1
+
+ def __exit__(self, exc_type, exc_val, exc_tb):
+ pass
+
+ def fn(x, counter):
+ x = x * x
+ ucm = UserCtxManager(counter)
+ return x * x
+
+ cnt = torch._dynamo.testing.CompileCounter()
+ opt_fn = torch._dynamo.optimize(cnt)(fn)
+ x = torch.rand([2, 2])
+ self.assertEqual(opt_fn(x, counter), fn(x, counter))
+ self.assertEqual(counter[0], 2)
+ for i in range(0, 10):
+ opt_fn(x, counter)
+ self.assertEqual(counter[0], 12)
+ self.assertEqual(cnt.frame_count, torch._dynamo.utils.ifdynstaticdefault(3, 2))
+
if __name__ == "__main__":
from torch._dynamo.test_case import run_tests
diff --git a/test/functorch/test_eager_transforms.py b/test/functorch/test_eager_transforms.py
index e6528c4..625826c 100644
--- a/test/functorch/test_eager_transforms.py
+++ b/test/functorch/test_eager_transforms.py
@@ -2669,7 +2669,6 @@
_, y = jvp(lambda x: jvp(f, (x,), (t,))[1], (x,), (t,))
self.assertEqual(y, 2)
- @xfailIfTorchDynamo
def test_disable_fwd_grad_mixed(self, device):
def f(x):
with fwAD._set_fwd_grad_enabled(False):
diff --git a/torch/_dynamo/variables/dicts.py b/torch/_dynamo/variables/dicts.py
index b9dd9d2..910f527 100644
--- a/torch/_dynamo/variables/dicts.py
+++ b/torch/_dynamo/variables/dicts.py
@@ -127,8 +127,14 @@
def __eq__(self, other: "ConstDictVariable._HashableTracker") -> bool:
Hashable = ConstDictVariable._HashableTracker
- assert isinstance(other, Hashable), type(other)
- return Hashable._eq_impl(self.underlying_value, other.underlying_value)
+ assert isinstance(other, Hashable) or ConstantVariable.is_literal(
+ other
+ ), type(other)
+ if isinstance(other, Hashable):
+ return Hashable._eq_impl(self.underlying_value, other.underlying_value)
+
+ # constant
+ return Hashable._eq_impl(self.underlying_value, other)
def __init__(
self, items: Dict[VariableTracker, VariableTracker], user_cls=dict, **kwargs
diff --git a/torch/_dynamo/variables/user_defined.py b/torch/_dynamo/variables/user_defined.py
index b1bf60e..617f02f 100644
--- a/torch/_dynamo/variables/user_defined.py
+++ b/torch/_dynamo/variables/user_defined.py
@@ -285,9 +285,14 @@
)
elif (
issubclass(type(self.value), type)
- and hasattr(self.value, "__enter__")
- and hasattr(self.value, "__exit__")
+ and hasattr(
+ self.value, "__enter__"
+ ) # TODO(voz): These can invoke user code!
+ and hasattr(
+ self.value, "__exit__"
+ ) # TODO(voz): These can invoke user code!
and check_constant_args(args, kwargs)
+ and self.value.__init__ == object.__init__
and len(kwargs) == 0 # TODO(ybliang): support kwargs
):
unwrapped_args = [x.as_python_constant() for x in args]
@@ -295,6 +300,7 @@
unwrapped_args,
cm_obj=self.value(*unwrapped_args),
)
+
elif is_namedtuple_cls(self.value):
fields = namedtuple_fields(self.value)
field_defaults = self.value._field_defaults
diff --git a/torch/testing/_internal/dynamo_test_failures.py b/torch/testing/_internal/dynamo_test_failures.py
index e36e388..0947765 100644
--- a/torch/testing/_internal/dynamo_test_failures.py
+++ b/torch/testing/_internal/dynamo_test_failures.py
@@ -1878,9 +1878,7 @@
"TestShapeOpsCPU.test_flip_cpu_bfloat16", # test_shape_ops
"TestShapeOpsCPU.test_clamp_cpu_float32", # test_shape_ops
"TestSubclassSerialization.test_tensor_subclass_deepcopy", # test_serialization
- "TestOldSerialization.test_save_different_dtype_unallocated", # test_serialization
"TestSubclassSerialization.test_tensor_subclass_getstate_overwrite", # test_serialization
- "TestSerialization.test_save_different_dtype_unallocated", # test_serialization
"TestSubclassSerialization.test_tensor_subclass_wrapper_serialization", # test_serialization
"TestScatterGatherCPU.test_scatter_reduce_sum_cpu_float32", # test_scatter_gather_ops
"TestScatterGatherCPU.test_scatter_reduce_mean_cpu_int16", # test_scatter_gather_ops
@@ -1950,7 +1948,6 @@
"TestNNDeviceTypeCPU.test_upsamplingTrilinear3d_align_corners_False_memory_format1_cpu", # test_nn
"TestNNDeviceTypeCPU.test_batchnorm_grad_cpu", # test_nn
"TestNN.test_interpolate", # test_nn
- "TestNN.test_register_state_dict_pre_hook", # test_nn
"TestNNDeviceTypeCPU.test_upsamplingTrilinear3d_align_corners_True_memory_format0_cpu", # test_nn
"TestNNDeviceTypeCPU.test_upsamplingTrilinear3d_align_corners_True_memory_format1_cpu", # test_nn
"TestNN.test_fb_fc_packed", # test_nn
@@ -1958,7 +1955,6 @@
"TestNNDeviceTypeCPU.test_invalid_reduction_strings_cpu", # test_nn
"TestNNDeviceTypeCPU.test_nll_loss_total_weight_is_zero_cpu", # test_nn
"TestNNDeviceTypeCPU.test_nll_loss_empty_tensor_reduction_mean_cpu", # test_nn
- "TestNN.test_register_state_dict_pre_hook_lazy_module", # test_nn
"TestNN.test_ParameterDict_replication", # test_nn
"TestNN.test_Sequential_iadd", # test_nn
"TestNN.test_upsamplingLinear1d", # test_nn
@@ -2030,33 +2026,17 @@
"PackedSequenceTest.test_total_length", # nn/test_packed_sequence
"TestModuleHooks.test_forward_pre_hooks_named_tuple_True", # nn/test_module_hooks
"TestModuleHooks.test_full_backward_pre_hooks_named_tuple_True", # nn/test_module_hooks
- "TestModuleHookNN.test_hook_submodule_registration", # nn/test_module_hooks
"TestModuleHooks.test_forward_hooks_named_tuple_False", # nn/test_module_hooks
"TestModuleHooks.test_full_backward_hooks_named_tuple_False", # nn/test_module_hooks
"TestModuleHooks.test_forward_hooks_named_tuple_True", # nn/test_module_hooks
- "TestStateDictHooks.test_pickled_hook", # nn/test_module_hooks
"TestModuleHookNN.test_hook_inplace", # nn/test_module_hooks
- "TestModuleGlobalHooks.test_module_backward_global_hook_writeable", # nn/test_module_hooks
- "TestModuleHookNN.test_hook_buffer_registration", # nn/test_module_hooks
"TestModuleHooks.test_full_backward_hooks_named_tuple_True", # nn/test_module_hooks
- "TestModuleHookNN.test_hook_no_requires_grad", # nn/test_module_hooks
- "TestModuleHookNN.test_hook_backward_writeable", # nn/test_module_hooks
"TestModuleHooks.test_forward_pre_hooks_named_tuple_False", # nn/test_module_hooks
- "TestModuleHookNN.test_hook_parameter_registration", # nn/test_module_hooks
"TestModuleHooks.test_full_backward_pre_hooks_named_tuple_False", # nn/test_module_hooks
- "TestModuleHookNN.test_hook_cpp", # nn/test_module_hooks
- "TestStateDictHooks.test_load_state_dict_pre_hook", # nn/test_module_hooks
- "TestModuleHookNN.test_hook_invalid_outputs", # nn/test_module_hooks
- "TestModuleHookNN.test_backward_hooks_interaction", # nn/test_module_hooks
- "TestModuleHookNN.test_hooks", # nn/test_module_hooks
- "TestModuleHookNN.test_hook_last_arg_requires_grad", # nn/test_module_hooks
- "TestModuleGlobalHooks.test_module_global_hook_invalid_outputs", # nn/test_module_hooks
- "TestLazyModules.test_lazy_module_parameter", # nn/test_lazy_modules
"TestLazyModules.test_lazy_batchnorm2d_state", # nn/test_lazy_modules
"TestLazyModules.test_lazy_conv3d", # nn/test_lazy_modules
"TestLazyModules.test_lazy_conv_transposed1d", # nn/test_lazy_modules
"TestLazyModules.test_lazy_conv2d", # nn/test_lazy_modules
- "TestLazyModules.test_optimizer_pass", # nn/test_lazy_modules
"TestLazyModules.test_lazy_instancenorm3d_state", # nn/test_lazy_modules
"TestLazyModules.test_lazy_batchnorm3d_state", # nn/test_lazy_modules
"TestLazyModules.test_lazy_conv_transpose1d_pickle", # nn/test_lazy_modules
@@ -2072,13 +2052,10 @@
"TestLazyModules.test_lazy_batchnorm3d", # nn/test_lazy_modules
"TestLazyModules.test_lazy_conv2d_pickle", # nn/test_lazy_modules
"TestLazyModules.test_lazy_conv1d_pickle", # nn/test_lazy_modules
- "TestLazyModules.test_lazy_module_jit_buffer", # nn/test_lazy_modules
"TestLazyModules.test_lazy_conv1d", # nn/test_lazy_modules
"TestLazyModules.test_linear", # nn/test_lazy_modules
- "TestLazyModules.test_materialize_dtype", # nn/test_lazy_modules
"TestLazyModules.test_lazy_module_buffer", # nn/test_lazy_modules
"TestLazyModules.test_lazy_batchnorm1d_state", # nn/test_lazy_modules
- "TestLazyModules.test_lazy_module_jit_param", # nn/test_lazy_modules
"TestLazyModules.test_lazy_batchnorm_with_dict_input", # nn/test_lazy_modules
"TestLazyModules.test_lazy_conv_transpose2d", # nn/test_lazy_modules
"TestLazyModules.test_lazy_conv_transpose2d_pickle", # nn/test_lazy_modules
@@ -2937,7 +2914,6 @@
"TestRecordFunction.test_record_function", # profiler/test_profiler
"TestTorchTidyProfiler.test_optimizer_parameters_adam", # profiler/test_profiler
"TestTorchTidyProfiler.test_tensor_properties", # profiler/test_profiler
- "TestProfiler.test_record_function_fast", # profiler/test_profiler
"TestProfiler.test_profiler_fwd_bwd_link", # profiler/test_profiler
"TestProfiler.test_concrete_inputs_profiling", # profiler/test_profiler
"TestTorchTidyProfiler.test_tensorimpl_invalidation_scalar_args", # profiler/test_profiler
@@ -3178,7 +3154,6 @@
"TestAutogradForwardMode.test_forward_level_cleanup", # test_autograd
"TestAutograd.test_gradcheck_check_forward_or_backward_only", # test_autograd
"TestAutogradDeviceTypeCPU.test_inplace_on_view_modify_base_cpu", # test_autograd
- "TestAutograd.test_full_backward_hook_double_backward", # test_autograd
"TestAutograd.test_gradcheck_forward_ad_batched_grad", # test_autograd
"TestAutograd.test_custom_function_non_tensor_inputs_outputs", # test_autograd
"TestNestedCheckpoint.test_nested_checkpoint_non_tensor_inputs_and_outputs_early_stop_True", # test_autograd
@@ -3234,9 +3209,7 @@
"TestAutogradInferenceMode.test_inference_mode_inf_tensor_in_inf_mode_functional_op", # test_autograd
"TestAutogradInferenceMode.test_inference_mode_inf_tensor_in_normal_mode_functional_op", # test_autograd
"TestAutogradInferenceMode.test_inference_mode_inf_tensor_in_inf_mode_inplace_op", # test_autograd
- "TestMultithreadAutograd.test_set_multithreading_enabled_as_context_manager_and_function", # test_autograd
"TestAutogradDeviceTypeCPU.test_scatter_index_reduce_prod_gradgrad_error_cpu", # test_autograd
- "TestAutograd.test_current_graph_task_execution_order", # test_autograd
"TestAutograd.test_nested_anomaly_detect_nan", # test_autograd
"TestAutograd.test_nested_anomaly_printstack_cleanup", # test_autograd
"TestAutograd.test_post_accumulate_grad_hook_gets_cleaned_up", # test_autograd
@@ -3687,7 +3660,6 @@
"TestCustomOp.test_impl_meta", # test_custom_ops
"TestCustomOp.test_impl_invalid_devices", # test_custom_ops
"TestCustomOp.test_new_data_dependent_symint", # test_custom_ops
- "TestCustomOpTestingCPU.test_missing_abstract_impl_cpu", # test_custom_ops
"TestCustomOp.test_define_with_tags_list", # test_custom_ops
"TestCustomOp.test_backward_tensorlist_input_requires_list_grads", # test_custom_ops
"TestCustomOp.test_not_implemented_error", # test_custom_ops
@@ -3701,7 +3673,6 @@
"TestCustomOp.test_impl_device_function", # test_custom_ops
"TestCustomOp.test_builtin_torchscript_ops", # test_custom_ops
"TestCustomOpTestingCPU.test_missing_functionalization_cpu", # test_custom_ops
- "TestCustomOpTestingCPU.test_incorrect_schema_view_cpu", # test_custom_ops
"TestCustomOp.test_define_with_tags_tuple", # test_custom_ops
"TestCustomOp.test_builtin_aten_ops_are_pt2_compliant", # test_custom_ops
"TestCustomOp.test_save_for_backward_inputs_are_namedtuple", # test_custom_ops
@@ -3710,7 +3681,6 @@
"TestCustomOp.test_backward_dict_invalid_keys", # test_custom_ops
"TestCustomOp.test_backward_tensorlist_input_requires_list_grads_with_same_numel", # test_custom_ops
"TestCustomOp.test_duplicate_impl", # test_custom_ops
- "TestCustomOpTestingCPU.test_incorrect_abstract_impl_cpu", # test_custom_ops
"TestCustomOp.test_backward_output_differentiability_numel", # test_custom_ops
"TestCustomOp.test_backward_dict_requires_keys_for_input_tensors", # test_custom_ops
"TestCustomOp.test_legacy_define", # test_custom_ops