[checkpoint] Improve error message when use_reentrant=True is used with .grad() (#125155)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/125155
Approved by: https://github.com/albanD
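
Before this change, hitting this path raised "Checkpointing is not compatible with .grad() or when an `inputs` parameter is passed to .backward()...", which never mentioned that the restriction only applies to the reentrant implementation. The message now names the `use_reentrant=True` condition and spells out both ways to resolve it. A minimal repro, adapted from the updated test in `test/test_autograd.py`:

```python
import torch
from torch.utils.checkpoint import checkpoint

a = torch.randn(2, 2, requires_grad=True)
b = checkpoint(torch.exp, a, use_reentrant=True).sum()

# Raises RuntimeError: "When use_reentrant=True, torch.utils.checkpoint is
# incompatible with .grad() or passing an `inputs` parameter to .backward()..."
torch.autograd.grad(b, (a,))
```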
diff --git a/test/test_autograd.py b/test/test_autograd.py
index 80af13d..79880a1 100644
--- a/test/test_autograd.py
+++ b/test/test_autograd.py
@@ -6916,7 +6916,7 @@
         a = torch.randn(2, 2, requires_grad=True)
         with self.assertRaisesRegex(
-            Exception, "Checkpointing is not compatible with .grad()"
+            Exception, "torch.utils.checkpoint is incompatible"
         ):
             b = checkpoint(torch.exp, a, use_reentrant=True).sum()
             torch.autograd.grad(b, (a,))
diff --git a/test/test_utils.py b/test/test_utils.py
index 5dd946f..b151b51 100644
--- a/test/test_utils.py
+++ b/test/test_utils.py
@@ -147,7 +147,7 @@
         chunks = 2
         modules = list(model.children())
         out = checkpoint_sequential(modules, chunks, input_var, use_reentrant=True)
-        with self.assertRaisesRegex(RuntimeError, "Checkpointing is not compatible"):
+        with self.assertRaisesRegex(RuntimeError, "torch.utils.checkpoint is incompatible"):
             torch.autograd.grad(
                 outputs=[out], grad_outputs=[torch.ones(1, 5)], inputs=[input_var], create_graph=True
             )
diff --git a/torch/utils/checkpoint.py b/torch/utils/checkpoint.py
index 7c74b4c..5cdfc55 100644
--- a/torch/utils/checkpoint.py
+++ b/torch/utils/checkpoint.py
@@ -258,9 +258,10 @@
     def backward(ctx, *args):
         if not torch.autograd._is_checkpoint_valid():
             raise RuntimeError(
-                "Checkpointing is not compatible with .grad() or when an `inputs` parameter"
-                " is passed to .backward(). Please use .backward() and do not pass its `inputs`"
-                " argument."
+                "When use_reentrant=True, torch.utils.checkpoint is incompatible"
+                " with .grad() or passing an `inputs` parameter to .backward()."
+                " To resolve this error, you can either set use_reentrant=False,"
+                " or call .backward() without passing the `inputs` argument."
             )
         # Copy the list to avoid modifying original list.
         inputs = list(ctx.inputs)
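
For reference, a minimal sketch of the two resolutions the new message suggests (illustration only, not part of this diff): switch to the non-reentrant implementation, which supports `.grad()`, or keep `use_reentrant=True` and call `.backward()` without its `inputs` argument:

```python
import torch
from torch.utils.checkpoint import checkpoint

a = torch.randn(2, 2, requires_grad=True)

# Option 1: the non-reentrant implementation supports torch.autograd.grad().
b = checkpoint(torch.exp, a, use_reentrant=False).sum()
(grad_a,) = torch.autograd.grad(b, (a,))

# Option 2: keep use_reentrant=True and call .backward() without `inputs`.
c = checkpoint(torch.exp, a, use_reentrant=True).sum()
c.backward()
assert a.grad is not None
```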