warn about orphan StopGradient output
Summary: A common source of confusion is how to use StopGradient, and a typical bug is forgetting to specify the same blob as both input and output. This adds a sanity check to the gradient builder that flags StopGradient outputs that are never consumed (orphaned).
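
For example, the intended in-place pattern vs. the bug this check catches (a minimal sketch, assuming net is a core.Net and 'data' is a placeholder blob name):

  # correct: input == output, so downstream ops consume the blob that
  # StopGradient produced and the gradient really stops here
  net.StopGradient('data', 'data')

  # typical bug: downstream ops keep reading 'data', so 'stopped' is
  # orphaned and gradients flow through unimpeded
  net.StopGradient('data', 'stopped')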
Reviewed By: dzhulgakov
Differential Revision: D5458341
fbshipit-source-id: 056fef4f0ee53eb10e66e9be0ecb55b55f9cc3d7
diff --git a/caffe2/python/core.py b/caffe2/python/core.py
index 3b9ad2e..f410b7e 100644
--- a/caffe2/python/core.py
+++ b/caffe2/python/core.py
@@ -425,6 +425,19 @@
for op in operators:
self.Play(op)
+ self.SanityCheck(operators)
+
+ def SanityCheck(self, operators):
+ # Validate StopGradient usage by checking that StopGradient's output
+ # is actually passed forward
+ for op in operators:
+ if op.type == 'StopGradient':
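+                # A blob absent from input_usages is never read by any op.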
+ if op.output[0] not in self.input_usages:
+                    raise Exception("""StopGradient's output '{}' is orphaned.
+You typically want to specify the same input and output for
+StopGradient. Op:\n\n{}""".format(op.output[0], str(op)))
+
def Play(self, op):
""""Adds an op to the current IR, and update the internal states to
reflect the blobs and versions after the execution of the op.
diff --git a/caffe2/python/core_gradients_test.py b/caffe2/python/core_gradients_test.py
index 1811832..e891d8e 100644
--- a/caffe2/python/core_gradients_test.py
+++ b/caffe2/python/core_gradients_test.py
@@ -493,6 +493,18 @@
operators, {'out': 'out_grad'})
self.assertEqual(gradients, desired_grad_operators)
+    def testStopGradientOrphan(self):
+ operators = [
+ CreateOperator('Direct', 'in', 'hidden'),
+ CreateOperator('StopGradient', 'hidden', 'auto_blobx'),
+ CreateOperator('Direct', 'hidden', 'out'),
+ ]
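+        # 'auto_blobx' is never consumed by any later op, so the gradient
+        # builder should reject this as incorrect use of StopGradient.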
+        with self.assertRaises(Exception):
+            GradientRegistry.GetBackwardPass(
+                operators, {'out': 'out_grad'})
+
def testStopGradientInplace(self):
operators = [
CreateOperator('Direct', 'in', 'hidden'),