[checkpoint] Improve error message when use_reentrant=True is used with .grad() (#125155)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/125155
Approved by: https://github.com/albanD
diff --git a/test/test_autograd.py b/test/test_autograd.py
index 80af13d..79880a1 100644
--- a/test/test_autograd.py
+++ b/test/test_autograd.py
@@ -6916,7 +6916,7 @@
         a = torch.randn(2, 2, requires_grad=True)
 
         with self.assertRaisesRegex(
-            Exception, "Checkpointing is not compatible with .grad()"
+            Exception, "torch.utils.checkpoint is incompatible"
         ):
             b = checkpoint(torch.exp, a, use_reentrant=True).sum()
             torch.autograd.grad(b, (a,))
diff --git a/test/test_utils.py b/test/test_utils.py
index 5dd946f..b151b51 100644
--- a/test/test_utils.py
+++ b/test/test_utils.py
@@ -147,7 +147,7 @@
         chunks = 2
         modules = list(model.children())
         out = checkpoint_sequential(modules, chunks, input_var, use_reentrant=True)
-        with self.assertRaisesRegex(RuntimeError, "Checkpointing is not compatible"):
+        with self.assertRaisesRegex(RuntimeError, "torch.utils.checkpoint is incompatible"):
             torch.autograd.grad(
                 outputs=[out], grad_outputs=[torch.ones(1, 5)], inputs=[input_var], create_graph=True
             )
diff --git a/torch/utils/checkpoint.py b/torch/utils/checkpoint.py
index 7c74b4c..5cdfc55 100644
--- a/torch/utils/checkpoint.py
+++ b/torch/utils/checkpoint.py
@@ -258,9 +258,10 @@
     def backward(ctx, *args):
         if not torch.autograd._is_checkpoint_valid():
             raise RuntimeError(
-                "Checkpointing is not compatible with .grad() or when an `inputs` parameter"
-                " is passed to .backward(). Please use .backward() and do not pass its `inputs`"
-                " argument."
+                "When use_reentrant=True, torch.utils.checkpoint is incompatible"
+                " with .grad() or passing an `inputs` parameter to .backward()."
+                " To resolve this error, you can either set use_reentrant=False,"
+                " or call .backward() without passing the `inputs` argument."
             )
         # Copy the list to avoid modifying original list.
         inputs = list(ctx.inputs)