Add warning for legacy autograd function (#22922)
Summary:
When working on https://github.com/pytorch/pytorch/pull/22762, we discovered that we haven't actually deprecated legacy autograd function. This PR puts up the deprecation warning for 1.2, with the goal to remove legacy function support completely in the near future.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/22922
Differential Revision: D16363916
Pulled By: yf225
fbshipit-source-id: 4b554010a3d1f87a3fa45cc1aa29d019c8f1033c
diff --git a/test/test_autograd.py b/test/test_autograd.py
index 74ec2f2..01ee7f7 100644
--- a/test/test_autograd.py
+++ b/test/test_autograd.py
@@ -167,6 +167,26 @@
MyFunction()(y).sum().backward()
self.assertEqual(v.grad.data, torch.zeros(shape))
+ def test_legacy_function_deprecation_warning(self):
+ with warnings.catch_warnings(record=True) as w:
+ # Ensure warnings are being shown
+ warnings.simplefilter("always")
+
+ # Trigger Warning
+ class MyFunction(Function):
+ def forward(self, x):
+ return x
+
+ def backward(self, grad_output):
+ return grad_output
+
+ MyFunction()(torch.randn(3, 4))
+
+ # Check warning occurs
+ self.assertIn(
+ 'Legacy autograd function with non-static forward method is deprecated',
+ str(w[0]))
+
def test_invalid_gradients(self):
class MyFunction(Function):
@staticmethod
diff --git a/torch/csrc/autograd/python_function.cpp b/torch/csrc/autograd/python_function.cpp
index e21ab3f..8abfd46 100644
--- a/torch/csrc/autograd/python_function.cpp
+++ b/torch/csrc/autograd/python_function.cpp
@@ -631,6 +631,10 @@
std::vector<c10::IValue>(),
autograd::Function::peek_at_next_sequence_nr());
+ TORCH_WARN("Legacy autograd function with non-static forward method is deprecated and will be removed in 1.3. ",
+ "Please use new-style autograd function with static forward method. ",
+ "(Example: https://pytorch.org/docs/stable/autograd.html#torch.autograd.Function)");
+
auto info_pair = unpack_input<true>(_inputs);
auto& unpacked_input = info_pair.first;
auto& input_info = info_pair.second;