Fix higher-order tape gradients of cond and case
Piggybacks on the tf.function tape interface logic: higher-order non-tape gradients work, and tf.function takes non-tape gradients of its contents.
The same fix applies to While, but the test I have in mind needs another fix before it's viable. Starting small here since cond is easier.
PiperOrigin-RevId: 337924627
Change-Id: Ife7e05a2c0818f6310c4cad19ec4fc46c8382000
diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py
index ab32c83..9d9cf0b 100644
--- a/tensorflow/python/eager/function.py
+++ b/tensorflow/python/eager/function.py
@@ -1418,13 +1418,6 @@
num_output_tangents)
-# Represents the output of TFE_Py_TapeSetPossibleGradientTypes. Real enums are
-# unfortunately too slow to use here.
-_POSSIBLE_GRADIENT_TYPES_NONE = 0
-_POSSIBLE_GRADIENT_TYPES_FIRST_ORDER = 1
-_POSSIBLE_GRADIENT_TYPES_HIGHER_ORDER = 2
-
-
class _ForwardBackwardCall(object):
"""Holds the state of a function call between execution and recording."""
@@ -1918,9 +1911,8 @@
"on invocation of %s, the %d-th input (%s) was not a "
"Tensor." % (self._func_graph.name, i, str(arg)))
args = tensor_inputs + captured_inputs
- possible_gradient_type = (
- pywrap_tfe.TFE_Py_TapeSetPossibleGradientTypes(args))
- if (possible_gradient_type == _POSSIBLE_GRADIENT_TYPES_NONE
+ possible_gradient_type = gradients_util.PossibleTapeGradientTypes(args)
+ if (possible_gradient_type == gradients_util.POSSIBLE_GRADIENT_TYPES_NONE
and executing_eagerly):
# No tape is watching; skip to running the function.
return self._build_call_outputs(self._inference_function.call(
@@ -2080,7 +2072,7 @@
Args:
args: A flat list of Tensors with all of the inputs to the forward
function (including user-specified and captured inputs).
- possible_gradient_type: One of _POSSIBLE_GRADIENT_TYPES_*.
+ possible_gradient_type: One of gradients_util.POSSIBLE_GRADIENT_TYPES_*.
executing_eagerly: Boolean, the value of context.executing_eagerly().
Returns:
@@ -2098,7 +2090,8 @@
# Allows re-use of forward and backward function pairs depending on the
# tapes and forward accumulators watching its inputs.
cache_key = (need_gradients_for_jvps, input_tangents.indices)
- if possible_gradient_type == _POSSIBLE_GRADIENT_TYPES_FIRST_ORDER:
+ if (possible_gradient_type
+ == gradients_util.POSSIBLE_GRADIENT_TYPES_FIRST_ORDER):
if input_tangents.indices or executing_eagerly:
# There is a single non-persistent tape active, so the user can only
# request first-order gradients from a tape. We can spend less time
@@ -2129,7 +2122,8 @@
return _ForwardBackwardCall(
self._delayed_rewrite_functions, args, input_tangents.tangents,
tape_watching=True)
- elif possible_gradient_type == _POSSIBLE_GRADIENT_TYPES_HIGHER_ORDER:
+ elif (possible_gradient_type
+ == gradients_util.POSSIBLE_GRADIENT_TYPES_HIGHER_ORDER):
# Either there's a persistent tape watching, or there are multiple nested
# tapes. Either way, the user may request higher-order gradients. We'll
# spend a bit more time and make sure higher-order gradients are correct.
@@ -2144,7 +2138,7 @@
self._higher_order_tape_functions[cache_key] = functions
return _ForwardBackwardCall(functions, args, input_tangents.tangents,
tape_watching=True)
- # else possible_gradient_type == _POSSIBLE_GRADIENT_TYPES_NONE, meaning no
+ # else possible_gradient_type == POSSIBLE_GRADIENT_TYPES_NONE, meaning no
# tape is recording.
return _ForwardBackwardCall(
self._delayed_rewrite_functions, args, input_tangents.tangents,
diff --git a/tensorflow/python/kernel_tests/cond_v2_test.py b/tensorflow/python/kernel_tests/cond_v2_test.py
index 70d7b25..fb60dc2 100644
--- a/tensorflow/python/kernel_tests/cond_v2_test.py
+++ b/tensorflow/python/kernel_tests/cond_v2_test.py
@@ -960,6 +960,42 @@
self.assertAllEqual(fn_with_cond(), 12.0)
+ def _CheckIteratedCosGradients(self, func):
+
+ def _grad(f):
+ def _grad_function(primal):
+ with backprop.GradientTape() as tape:
+ tape.watch(primal)
+ primal_out = f(primal)
+ return tape.gradient(primal_out, primal)
+ return _grad_function
+
+ f = func
+ one = constant_op.constant(1.)
+ for expected in [math_ops.cos,
+ lambda x: -math_ops.sin(x),
+ lambda x: -math_ops.cos(x),
+ math_ops.sin,
+ math_ops.cos]:
+ self.assertAllClose(expected(one), def_function.function(f)(one))
+ f = _grad(f)
+
+ def testIteratedGradientsCond(self):
+ def _func(x):
+ return cond_v2.cond_v2(
+ constant_op.constant(True),
+ lambda: math_ops.cos(array_ops.identity(x)),
+ lambda: math_ops.sin(array_ops.identity(x)))
+ self._CheckIteratedCosGradients(_func)
+
+ def testIteratedGradientsCase(self):
+ def _func(x):
+ return cond_v2.indexed_case(
+ constant_op.constant(1),
+ [lambda: math_ops.sin(array_ops.identity(x)),
+ lambda: math_ops.cos(array_ops.identity(x))])
+ self._CheckIteratedCosGradients(_func)
+
def testLowering(self):
with ops.Graph().as_default() as g:
with self.session(graph=g) as sess:
diff --git a/tensorflow/python/ops/cond_v2.py b/tensorflow/python/ops/cond_v2.py
index 5bdd249..75130fc 100644
--- a/tensorflow/python/ops/cond_v2.py
+++ b/tensorflow/python/ops/cond_v2.py
@@ -26,6 +26,7 @@
import collections
from tensorflow.python.eager import backprop_util
+from tensorflow.python.eager import function
from tensorflow.python.framework import auto_control_deps
from tensorflow.python.framework import auto_control_deps_utils as acd
from tensorflow.python.framework import constant_op
@@ -192,6 +193,37 @@
return [None] + outputs
+def _run_as_function_for_tape_gradients(make_op, cond_inputs):
+ """Fix higher-order tape gradients by wrapping `make_op` in a function."""
+ # GradientTapes created inside a function currently don't work well with
+ # un-wrapped control flow ops in that same function. Wrapping in an extra
+ # layer of intermediate function means we run extra logic in the function
+ # gradient code to record the correct intermediates on the tape.
+ #
+ # The function attribute inputs to cond/case ops are not hashable, so we pass
+ # everything as a capture to bypass defun's caching.
+ if (gradients_util.PossibleTapeGradientTypes(cond_inputs)
+ == gradients_util.POSSIBLE_GRADIENT_TYPES_HIGHER_ORDER
+ # We only need one function between the tape and the cond; if we've
+ # already wrapped once, we stop wrapping to avoid infinite recursion.
+ and not (ops.get_default_graph().building_function
+ and "cond_gradient_wrapper" in ops.get_default_graph().name)):
+
+ op = None
+ def _run_make_and_extract_op():
+ # Post-processing happens on the cond op, not the function call op.
+ nonlocal op
+ tensors = make_op()
+ op, tensors = _get_op_and_outputs(tensors) # pylint: disable=unused-variable
+ return tensors
+
+ return op, function.defun_with_attributes(
+ _run_make_and_extract_op,
+ attributes=dict(func_name="cond_gradient_wrapper"))()
+ else:
+ return _get_op_and_outputs(make_op())
+
+
def _build_cond(pred,
true_graph,
false_graph,
@@ -268,16 +300,17 @@
else:
op_fn = gen_functional_ops.stateless_if
- tensors = op_fn(
- pred,
- cond_inputs, [t.dtype for t in true_graph.outputs],
- util.create_new_tf_function(true_graph),
- util.create_new_tf_function(false_graph),
- output_shapes=_get_output_shapes(true_graph.outputs,
- false_graph.outputs),
- name=name)
+ def make_op():
+ return op_fn(
+ pred,
+ cond_inputs, [t.dtype for t in true_graph.outputs],
+ util.create_new_tf_function(true_graph),
+ util.create_new_tf_function(false_graph),
+ output_shapes=_get_output_shapes(true_graph.outputs,
+ false_graph.outputs),
+ name=name)
+ if_op, tensors = _run_as_function_for_tape_gradients(make_op, cond_inputs)
- if_op, tensors = _get_op_and_outputs(tensors)
# `if_op` is None if this is a `StatelessIf` op with no outputs.
if if_op is not None:
if_op._true_graph = true_graph
@@ -1156,14 +1189,16 @@
# Create the Case op.
with ops.control_dependencies(
sum((list(bg.control_captures) for bg in branch_graphs), [])):
- tensors = op_fn(
- branch_index,
- case_inputs, [t.dtype for t in branch_graphs[0].outputs],
- [util.create_new_tf_function(g) for g in branch_graphs],
- output_shapes=_get_output_shapes(*[g.outputs for g in branch_graphs]),
- name=name)
- case_op, tensors = _get_op_and_outputs(tensors)
+ def _make_op():
+ return op_fn(
+ branch_index,
+ case_inputs, [t.dtype for t in branch_graphs[0].outputs],
+ [util.create_new_tf_function(g) for g in branch_graphs],
+ output_shapes=_get_output_shapes(*[g.outputs for g in branch_graphs]),
+ name=name)
+ case_op, tensors = _run_as_function_for_tape_gradients(
+ _make_op, case_inputs)
if case_op is not None:
util.maybe_set_lowering_attr(case_op, lower_using_switch_merge)
diff --git a/tensorflow/python/ops/gradients_util.py b/tensorflow/python/ops/gradients_util.py
index 4d4df0f..c356e82 100644
--- a/tensorflow/python/ops/gradients_util.py
+++ b/tensorflow/python/ops/gradients_util.py
@@ -24,6 +24,7 @@
from six.moves import xrange, zip # pylint: disable=redefined-builtin
from tensorflow.core.framework import attr_value_pb2
+from tensorflow.python import pywrap_tfe
from tensorflow.python.eager import backprop
from tensorflow.python.eager import backprop_util
from tensorflow.python.eager import context
@@ -1007,3 +1008,15 @@
# out_grads[i] is [], thus its aggregation is simply None.
out_grads[i] = None
return out_grads
+
+
+# Represents the output of TFE_Py_TapeSetPossibleGradientTypes. Real enums are
+# unfortunately too slow to use here.
+POSSIBLE_GRADIENT_TYPES_NONE = 0
+POSSIBLE_GRADIENT_TYPES_FIRST_ORDER = 1
+POSSIBLE_GRADIENT_TYPES_HIGHER_ORDER = 2
+
+
+def PossibleTapeGradientTypes(tensors):
+ """Determines whether and how `args` may require tape gradients."""
+ return pywrap_tfe.TFE_Py_TapeSetPossibleGradientTypes(tensors)