warn about orphan StopGradient output

Summary: Quite common confusion is how to use StopGradient, and typical bug is to forget to specify input=output. This adds a sanity check to gradient builder that checks if some StopGradient outputs are orphaned.

Reviewed By: dzhulgakov

Differential Revision: D5458341

fbshipit-source-id: 056fef4f0ee53eb10e66e9be0ecb55b55f9cc3d7
diff --git a/caffe2/python/core.py b/caffe2/python/core.py
index 3b9ad2e..f410b7e 100644
--- a/caffe2/python/core.py
+++ b/caffe2/python/core.py
@@ -425,6 +425,18 @@
         for op in operators:
             self.Play(op)
 
+        self.SanityCheck(operators)
+
+    def SanityCheck(self, operators):
+        # Validate StopGradient usage by checking that StopGradient's output
+        # is actually passed forward
+        for op in operators:
+            if op.type == 'StopGradient':
+                if op.output[0] not in self.input_usages:
+                    raise Exception("""StopGradient's output '{}' is orphan.
+You typically want to specify same input and output for
+StopGradient. Op:\n\n{}""".format(op.output[0], str(op)))
+
     def Play(self, op):
         """"Adds an op to the current IR, and update the internal states to
         reflect the blobs and versions after the execution of the op.
diff --git a/caffe2/python/core_gradients_test.py b/caffe2/python/core_gradients_test.py
index 1811832..e891d8e 100644
--- a/caffe2/python/core_gradients_test.py
+++ b/caffe2/python/core_gradients_test.py
@@ -493,6 +493,21 @@
             operators, {'out': 'out_grad'})
         self.assertEqual(gradients, desired_grad_operators)
 
+    @unittest.expectedFailure
+    def testStopGradientOrphan(self):
+        operators = [
+            CreateOperator('Direct', 'in', 'hidden'),
+            CreateOperator('StopGradient', 'hidden', 'auto_blobx'),
+            CreateOperator('Direct', 'hidden', 'out'),
+        ]
+        try:
+            # This should complain about incorrect use of StopGradient
+            gradients, _ = GradientRegistry.GetBackwardPass(
+                operators, {'out': 'out_grad'})
+        except Exception as e:
+            print(e)
+            raise e
+
     def testStopGradientInplace(self):
         operators = [
             CreateOperator('Direct', 'in', 'hidden'),