[compiled autograd] match eager behavior for ctx.saved_variables (#134286)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/134286
Approved by: https://github.com/jansel
ghstack dependencies: #134186, #134200, #134205
diff --git a/test/inductor/test_compiled_autograd.py b/test/inductor/test_compiled_autograd.py
index 09c8e12..6a657b0 100644
--- a/test/inductor/test_compiled_autograd.py
+++ b/test/inductor/test_compiled_autograd.py
@@ -2476,6 +2476,13 @@
"test_accumulate_grad_posthooks_can_observe_tensor_prehook", # allclose
"test_save_tensor_hook_version_counter_not_shared", # assertEqual
"test_post_accumulate_grad_hook_returns_not_None", # throws
+ "test_custom_function_cycle", # assertEqual
+ "test_mark_non_differentiable_mixed", # assertTrue
+ "test_materialize_grads", # assertEqual
+ "test_return_leaf", # assertEqual
+ "test_save_none_for_backward", # assertIsNone
+ "test_saved_variables_deprecated", # warnings.warn
+ "test_autograd_node_isinstance", # assertIsInstance
}
test_contexts = {
@@ -2496,6 +2503,7 @@
# Running these tests succeed, but somehow cause other tests to fail
"test_saved_tensor_hooks_extra_exit_during_bw_no_crash",
"test_saved_tensor_hooks_extra_enter_during_bw_no_leak",
+ "test_callback_propagates_errors_from_device_thread", # fullgraph for queue_callback, but graph break for RuntimeError
}
known_failing_tests = {
@@ -2543,10 +2551,34 @@
"test_prehook_ordering", # retains_grad_hooks
"test_retain_grad", # retains_grad_hooks
"test_saved_variable_packing_unpacking_saved_original_with_hooks", # create_graph
+ "test_select_sum", # create_graph, also needs graph breaks
+ "test_will_engine_execute_node", # retains_grad_hooks
+ "test_backward_to_node", # retains_grad_hooks NYI
+ "test_anomaly_detect_nan", # anomaly mode
+ "test_custom_autograd_no_early_free", # create_graph
+ "test_custom_function_error", # vjp
+ "test_custom_function_save_for_forward", # vjp
+ "test_deep_reentrant", # hangs with graph breaks
+ "test_dont_materialize_grads", # undefined grad
+ "test_grad_mode_restored_reentrant", # hangs with graph breaks
+ "test_no_grad_copy", # setting static member in lifted backward
+ "test_no_grad_copy_sparse", # setting static member in lifted backward
+ "test_reentrant_priority", # hangs with graph breaks
+ "test_reentrant_with_callbacks_both_depths", # hangs with graph breaks
+ "test_reentrant_with_callbacks_depth_0", # probably hangs with graph breaks
+ "test_reentrant_with_callbacks_depth_1", # probably hangs with graph breaks
+ "test_save_output_nr", # output_nr grad passed as None
# Category: Dynamo
"test_accumulate_grad_tensor_reference", # Out of bounds: frame_state_entry.stride[i] is None
"test_custom_function_exception", # torch.no_grad(), torch._dynamo.exc.Unsupported: missing: WITH_EXCEPT_START
"test_to_sparse_backward", # Out of bounds: frame_state_entry.stride[i] is None
+ "test_autograd_simple_views_python", # gradient is None
+ "test_function_returns_undefined_tensor", # gradient is None
+ "test_naughty_autograd_function_stashing_ctx", # bytecode issue
+ "test_unrelated_inputs", # gradient batching rule not implemented for aten::sym_size.int
+ "test_custom_function_non_tensor_inputs_outputs", # gradient batching rule not implemented for aten::sym_size.int
+ "test_return_duplicate", # gradient batching rule not implemented for aten::sym_size.int
+ "test_return_duplicate_inplace", # gradient batching rule not implemented for aten::sym_size.int
# Category: Inductor
"test_input_buffer_accum", # does not support sparse_grad=True: https://github.com/pytorch/pytorch/issues/120267
"test_graph_save_on_cpu", # does not support pin_memory: https://github.com/pytorch/pytorch/issues/134173
@@ -2554,43 +2586,10 @@
"test_saving_variable_to_disk", # torch.save should no-op and be recorded in the graph
"test_wrapped_number_saved_variable_hooks", # Proxy tensor should carryover is_wrapped_number_ of its original
"test_grad_batched_grad", # torch._subclasses.fake_tensor.UnsupportedFakeTensorException: meta converter nyi
+ # Category: Divergence from eager
+ "test_invalid_gradients", # can't give autograd error due to inaccurate output metadata of lifted backward
+ "test_autograd_node_isinstance", # backward ctx is a fake cls and not directly a Node instance
# Uncategorized
- "test_select_sum", # torch.autograd.gradcheck.GradcheckError: While computing batched gradients
- "test_unrelated_inputs", # torch.autograd.gradcheck.GradcheckError: While computing batched gradients
- "test_will_engine_execute_node", # retains_grad_hooks NYI
- "test_backward_to_node", # retains_grad_hooks NYI
- "test_anomaly_detect_nan", # torch._dynamo.exc.TorchRuntimeError: Failed running call_function aten.add.Tensor(
- "test_autograd_multiple_views_python", # torch._dynamo.exc.Unsupported: call_function args: TensorVariable(
- "test_autograd_node_isinstance", # torch._dynamo.exc.Unsupported: 'inline in skipfiles: TestCase.assertIsInstance
- "test_autograd_simple_views_python", # torch._dynamo.exc.TorchRuntimeError: Failed running call_function
- "test_callback_propagates_errors_from_device_thread", # AssertionError: "blah" does not match "call_method
- "test_custom_autograd_no_early_free", # torch.autograd.gradcheck.GradcheckError: While computing batched gradients
- "test_custom_function_cycle", # torch._dynamo.exc.Unsupported: call_function UserDefinedClassVariable() [] {}
- "test_custom_function_error", # AssertionError: "must implement either the backward" does not match "call_function
- "test_custom_function_non_tensor_inputs_outputs", # torch._dynamo.exc.Unsupported: call_function
- "test_custom_function_save_for_forward", # torch._dynamo.exc.Unsupported: call_function
- "test_custom_function_setup_context_multi_input", # torch._dynamo.exc.Unsupported: call_function args
- "test_custom_function_setup_context_multi_output", # torch._dynamo.exc.Unsupported: call_function args
- "test_deep_reentrant", # torch._dynamo.exc.InternalTorchDynamoError: '<' not supported between instances of
- "test_dont_materialize_grads", # torch._dynamo.exc.Unsupported: 'inline in skipfiles: TestCase.assertIsNone
- "test_function_returns_undefined_tensor", # torch._dynamo.exc.TorchRuntimeError: Failed running call_function
- "test_grad_mode_restored_reentrant", # torch._dynamo.exc.Unsupported: 'inline in skipfiles: TestCase.assertTrue
- "test_invalid_gradients", # AssertionError: "expected shape" does not match "The size of tensor a (5) must match
- "test_mark_non_differentiable_mixed", # torch._dynamo.exc.Unsupported: 'inline in skipfiles: TestCase.assertTrue
- "test_materialize_grads", # torch._dynamo.exc.Unsupported: call_function UserDefinedClassVariable() [] {}
- "test_naughty_autograd_function_stashing_ctx", # torch._dynamo.exc.TorchRuntimeError: Failed running call_function
- "test_no_grad_copy", # torch._dynamo.exc.Unsupported: call_function args: TensorVariable() SkipFunctionVariable()
- "test_no_grad_copy_sparse", # torch._dynamo.exc.Unsupported: Tensor.data_ptr
- "test_reentrant_priority", # torch._dynamo.exc.InternalTorchDynamoError: '<' not supported between instances of
- "test_reentrant_with_callbacks_both_depths", # torch._dynamo.exc.Unsupported: call_method UserDefinedObjectVariable
- "test_reentrant_with_callbacks_depth_0", # torch._dynamo.exc.Unsupported: call_method UserDefinedObjectVariable
- "test_reentrant_with_callbacks_depth_1", # torch._dynamo.exc.Unsupported: Tensor.requires_grad_
- "test_return_duplicate", # torch.autograd.gradcheck.GradcheckError: While computing batched gradients
- "test_return_duplicate_inplace", # torch.autograd.gradcheck.GradcheckError: While computing batched gradients
- "test_return_leaf", # torch._dynamo.exc.Unsupported: call_function UserDefinedClassVariable() [] {}
- "test_save_none_for_backward", # AssertionError:
- "test_save_output_nr", # torch._dynamo.exc.Unsupported: call_function UserDefinedClassVariable() [] {}
- "test_saved_variables_deprecated", # torch._dynamo.exc.Unsupported: UNPACK_SEQUENCE SkipFunctionVariable()
"test_set_materialize_non_diff_grads", # torch._dynamo.exc.Unsupported: 'inline in skipfiles: TestCase.assertIsNone
"test_setup_context_when_forward_has_default_args", # torch._dynamo.exc.Unsupported: call_function args
"test_simple_reentrant", # torch._dynamo.exc.Unsupported: call_method SkipFunctionVariable() sum [] {}
diff --git a/torch/_dynamo/external_utils.py b/torch/_dynamo/external_utils.py
index 7ecea5d..91663b4 100644
--- a/torch/_dynamo/external_utils.py
+++ b/torch/_dynamo/external_utils.py
@@ -2,6 +2,7 @@
# This module contains functions that *will be allowed* by dynamo
import functools
+import warnings
from typing import List
import torch
@@ -81,6 +82,13 @@
self.saved_tensors = saved_tensors
def __getattr__(self, name):
+ if name == "saved_variables":
+ warnings.warn(
+ "'saved_variables' is deprecated; use 'saved_tensors'",
+ DeprecationWarning,
+ )
+ return self.saved_tensors
+
# route any attribute that isn't defined on this obj
return getattr(self.real, name)