Fix higher-order tape gradients of cond and case

This piggybacks on tf.function's tape-handling logic: higher-order non-tape gradients already work, and tf.function takes non-tape gradients of its contents, so wrapping the cond/case op in a function makes its higher-order tape gradients come out right.

The same fix applies to While, but the test I have in mind needs another fix before it's viable. Starting small here since cond is easier.
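
For reference, the pattern this enables is roughly the following (a sketch
using the public tf.function / tf.cond / tf.GradientTape API, mirroring the
new cond_v2_test case rather than quoting it; names here are illustrative):

  import tensorflow as tf

  @tf.function
  def f(x):
    return tf.cond(tf.constant(True),
                   lambda: tf.cos(tf.identity(x)),
                   lambda: tf.sin(tf.identity(x)))

  x = tf.constant(1.)
  with tf.GradientTape() as outer_tape:
    outer_tape.watch(x)
    with tf.GradientTape() as inner_tape:
      inner_tape.watch(x)
      y = f(x)                        # cos(x)
    dy = inner_tape.gradient(y, x)    # -sin(x)
  d2y = outer_tape.gradient(dy, x)    # -cos(x): the higher-order tape gradient
                                      # of the cond that this change fixes.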

PiperOrigin-RevId: 337924627
Change-Id: Ife7e05a2c0818f6310c4cad19ec4fc46c8382000
diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py
index ab32c83..9d9cf0b 100644
--- a/tensorflow/python/eager/function.py
+++ b/tensorflow/python/eager/function.py
@@ -1418,13 +1418,6 @@
             num_output_tangents)
 
 
-# Represents the output of TFE_Py_TapeSetPossibleGradientTypes. Real enums are
-# unfortunately too slow to use here.
-_POSSIBLE_GRADIENT_TYPES_NONE = 0
-_POSSIBLE_GRADIENT_TYPES_FIRST_ORDER = 1
-_POSSIBLE_GRADIENT_TYPES_HIGHER_ORDER = 2
-
-
 class _ForwardBackwardCall(object):
   """Holds the state of a function call between execution and recording."""
 
@@ -1918,9 +1911,8 @@
                          "on invocation of %s, the %d-th input (%s) was not a "
                          "Tensor." % (self._func_graph.name, i, str(arg)))
     args = tensor_inputs + captured_inputs
-    possible_gradient_type = (
-        pywrap_tfe.TFE_Py_TapeSetPossibleGradientTypes(args))
-    if (possible_gradient_type == _POSSIBLE_GRADIENT_TYPES_NONE
+    possible_gradient_type = gradients_util.PossibleTapeGradientTypes(args)
+    if (possible_gradient_type == gradients_util.POSSIBLE_GRADIENT_TYPES_NONE
         and executing_eagerly):
       # No tape is watching; skip to running the function.
       return self._build_call_outputs(self._inference_function.call(
@@ -2080,7 +2072,7 @@
     Args:
       args: A flat list of Tensors with all of the inputs to the forward
         function (including user-specified and captured inputs).
-      possible_gradient_type: One of _POSSIBLE_GRADIENT_TYPES_*.
+      possible_gradient_type: One of gradients_util.POSSIBLE_GRADIENT_TYPES_*.
       executing_eagerly: Boolean, the value of context.executing_eagerly().
 
     Returns:
@@ -2098,7 +2090,8 @@
     # Allows re-use of forward and backward function pairs depending on the
     # tapes and forward accumulators watching its inputs.
     cache_key = (need_gradients_for_jvps, input_tangents.indices)
-    if possible_gradient_type == _POSSIBLE_GRADIENT_TYPES_FIRST_ORDER:
+    if (possible_gradient_type
+        == gradients_util.POSSIBLE_GRADIENT_TYPES_FIRST_ORDER):
       if input_tangents.indices or executing_eagerly:
         # There is a single non-persistent tape active, so the user can only
         # request first-order gradients from a tape. We can spend less time
@@ -2129,7 +2122,8 @@
         return _ForwardBackwardCall(
             self._delayed_rewrite_functions, args, input_tangents.tangents,
             tape_watching=True)
-    elif possible_gradient_type == _POSSIBLE_GRADIENT_TYPES_HIGHER_ORDER:
+    elif (possible_gradient_type
+          == gradients_util.POSSIBLE_GRADIENT_TYPES_HIGHER_ORDER):
       # Either there's a persistent tape watching, or there are multiple nested
       # tapes. Either way, the user may request higher-order gradients. We'll
       # spend a bit more time and make sure higher-order gradients are correct.
@@ -2144,7 +2138,7 @@
         self._higher_order_tape_functions[cache_key] = functions
       return _ForwardBackwardCall(functions, args, input_tangents.tangents,
                                   tape_watching=True)
-    # else possible_gradient_type == _POSSIBLE_GRADIENT_TYPES_NONE, meaning no
+    # else possible_gradient_type == POSSIBLE_GRADIENT_TYPES_NONE, meaning no
     # tape is recording.
     return _ForwardBackwardCall(
         self._delayed_rewrite_functions, args, input_tangents.tangents,
diff --git a/tensorflow/python/kernel_tests/cond_v2_test.py b/tensorflow/python/kernel_tests/cond_v2_test.py
index 70d7b25..fb60dc2 100644
--- a/tensorflow/python/kernel_tests/cond_v2_test.py
+++ b/tensorflow/python/kernel_tests/cond_v2_test.py
@@ -960,6 +960,42 @@
 
       self.assertAllEqual(fn_with_cond(), 12.0)
 
+  def _CheckIteratedCosGradients(self, func):
+
+    def _grad(f):
+      def _grad_function(primal):
+        with backprop.GradientTape() as tape:
+          tape.watch(primal)
+          primal_out = f(primal)
+        return tape.gradient(primal_out, primal)
+      return _grad_function
+
+    f = func
+    one = constant_op.constant(1.)
+    for expected in [math_ops.cos,
+                     lambda x: -math_ops.sin(x),
+                     lambda x: -math_ops.cos(x),
+                     math_ops.sin,
+                     math_ops.cos]:
+      self.assertAllClose(expected(one), def_function.function(f)(one))
+      f = _grad(f)
+
+  def testIteratedGradientsCond(self):
+    def _func(x):
+      return cond_v2.cond_v2(
+          constant_op.constant(True),
+          lambda: math_ops.cos(array_ops.identity(x)),
+          lambda: math_ops.sin(array_ops.identity(x)))
+    self._CheckIteratedCosGradients(_func)
+
+  def testIteratedGradientsCase(self):
+    def _func(x):
+      return cond_v2.indexed_case(
+          constant_op.constant(1),
+          [lambda: math_ops.sin(array_ops.identity(x)),
+           lambda: math_ops.cos(array_ops.identity(x))])
+    self._CheckIteratedCosGradients(_func)
+
   def testLowering(self):
     with ops.Graph().as_default() as g:
       with self.session(graph=g) as sess:
diff --git a/tensorflow/python/ops/cond_v2.py b/tensorflow/python/ops/cond_v2.py
index 5bdd249..75130fc 100644
--- a/tensorflow/python/ops/cond_v2.py
+++ b/tensorflow/python/ops/cond_v2.py
@@ -26,6 +26,7 @@
 import collections
 
 from tensorflow.python.eager import backprop_util
+from tensorflow.python.eager import function
 from tensorflow.python.framework import auto_control_deps
 from tensorflow.python.framework import auto_control_deps_utils as acd
 from tensorflow.python.framework import constant_op
@@ -192,6 +193,39 @@
   return [None] + outputs
 
 
+def _run_as_function_for_tape_gradients(make_op, cond_inputs):
+  """Fix higher-order tape gradients by wrapping `make_op` in a function."""
+  # GradientTapes created inside a function currently don't work well with
+  # un-wrapped control flow ops in that same function. Wrapping in an extra
+  # layer of intermediate function means we run extra logic in the function
+  # gradient code to record the correct intermediates on the tape.
+  #
+  # The function attribute inputs to cond/case ops are not hashable, so we pass
+  # everything as a capture to bypass defun's caching.
+  if (gradients_util.PossibleTapeGradientTypes(cond_inputs)
+      == gradients_util.POSSIBLE_GRADIENT_TYPES_HIGHER_ORDER
+      # We only need one function between the tape and the cond; if we've
+      # already wrapped once, we stop wrapping to avoid infinite recursion.
+      and not (ops.get_default_graph().building_function
+               and "cond_gradient_wrapper" in ops.get_default_graph().name)):
+
+    op = None
+    def _run_make_and_extract_op():
+      # Post-processing happens on the cond op, not the function call op.
+      nonlocal op
+      tensors = make_op()
+      op, tensors = _get_op_and_outputs(tensors)  # pylint: disable=unused-variable
+      return tensors
+
+    tensors = function.defun_with_attributes(
+        _run_make_and_extract_op,
+        attributes=dict(func_name="cond_gradient_wrapper"))()
+    # `op` is set via `nonlocal` while the wrapper runs; read it afterwards.
+    return op, tensors
+  else:
+    return _get_op_and_outputs(make_op())
+
+
 def _build_cond(pred,
                 true_graph,
                 false_graph,
@@ -268,16 +300,17 @@
     else:
       op_fn = gen_functional_ops.stateless_if
 
-    tensors = op_fn(
-        pred,
-        cond_inputs, [t.dtype for t in true_graph.outputs],
-        util.create_new_tf_function(true_graph),
-        util.create_new_tf_function(false_graph),
-        output_shapes=_get_output_shapes(true_graph.outputs,
-                                         false_graph.outputs),
-        name=name)
+    def make_op():
+      return op_fn(
+          pred,
+          cond_inputs, [t.dtype for t in true_graph.outputs],
+          util.create_new_tf_function(true_graph),
+          util.create_new_tf_function(false_graph),
+          output_shapes=_get_output_shapes(true_graph.outputs,
+                                           false_graph.outputs),
+          name=name)
+    if_op, tensors = _run_as_function_for_tape_gradients(make_op, cond_inputs)
 
-  if_op, tensors = _get_op_and_outputs(tensors)
   # `if_op` is None if this is a `StatelessIf` op with no outputs.
   if if_op is not None:
     if_op._true_graph = true_graph
@@ -1156,14 +1189,16 @@
   # Create the Case op.
   with ops.control_dependencies(
       sum((list(bg.control_captures) for bg in branch_graphs), [])):
-    tensors = op_fn(
-        branch_index,
-        case_inputs, [t.dtype for t in branch_graphs[0].outputs],
-        [util.create_new_tf_function(g) for g in branch_graphs],
-        output_shapes=_get_output_shapes(*[g.outputs for g in branch_graphs]),
-        name=name)
 
-  case_op, tensors = _get_op_and_outputs(tensors)
+    def _make_op():
+      return op_fn(
+          branch_index,
+          case_inputs, [t.dtype for t in branch_graphs[0].outputs],
+          [util.create_new_tf_function(g) for g in branch_graphs],
+          output_shapes=_get_output_shapes(*[g.outputs for g in branch_graphs]),
+          name=name)
+    case_op, tensors = _run_as_function_for_tape_gradients(
+        _make_op, case_inputs)
 
   if case_op is not None:
     util.maybe_set_lowering_attr(case_op, lower_using_switch_merge)
diff --git a/tensorflow/python/ops/gradients_util.py b/tensorflow/python/ops/gradients_util.py
index 4d4df0f..c356e82 100644
--- a/tensorflow/python/ops/gradients_util.py
+++ b/tensorflow/python/ops/gradients_util.py
@@ -24,6 +24,7 @@
 from six.moves import xrange, zip  # pylint: disable=redefined-builtin
 
 from tensorflow.core.framework import attr_value_pb2
+from tensorflow.python import pywrap_tfe
 from tensorflow.python.eager import backprop
 from tensorflow.python.eager import backprop_util
 from tensorflow.python.eager import context
@@ -1007,3 +1008,15 @@
       # out_grads[i] is [], thus its aggregation is simply None.
       out_grads[i] = None
   return out_grads
+
+
+# Represents the output of TFE_Py_TapeSetPossibleGradientTypes. Real enums are
+# unfortunately too slow to use here.
+POSSIBLE_GRADIENT_TYPES_NONE = 0
+POSSIBLE_GRADIENT_TYPES_FIRST_ORDER = 1
+POSSIBLE_GRADIENT_TYPES_HIGHER_ORDER = 2
+
+
+def PossibleTapeGradientTypes(tensors):
+  """Determines whether and how `args` may require tape gradients."""
+  return pywrap_tfe.TFE_Py_TapeSetPossibleGradientTypes(tensors)