# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import functools
import gc
import weakref
from absl.testing import parameterized
import numpy as np
from tensorflow.python import pywrap_tfe
from tensorflow.python.distribute import mirrored_strategy
from tensorflow.python.eager import backprop
from tensorflow.python.eager import def_function
from tensorflow.python.eager import forwardprop
from tensorflow.python.eager import forwardprop_util
from tensorflow.python.eager import tape as tape_lib
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.module import module
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import custom_gradient
from tensorflow.python.ops import gradient_checker_v2
from tensorflow.python.ops import map_fn
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_impl
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import variables
from tensorflow.python.ops.parallel_for import control_flow_ops
from tensorflow.python.ops.unconnected_gradients import UnconnectedGradients
from tensorflow.python.platform import test
from tensorflow.python.util import nest
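# x ** 3.5 evaluated at x = 1.1, followed by its first three derivatives.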
_X11_35_DERIVATIVES = [
1.1 ** 3.5,
3.5 * 1.1 ** 2.5,
3.5 * 2.5 * 1.1 ** 1.5,
3.5 * 2.5 * 1.5 * 1.1 ** 0.5]
# TODO(allenl): Move this somewhere useful once forward gradients are stable.
def _jvp(f, primals, tangents):
"""Compute the jacobian of `f` at `primals` multiplied by `tangents`."""
with forwardprop.ForwardAccumulator(primals, tangents) as acc:
primals_out = f(*primals)
return primals_out, acc.jvp(
primals_out, unconnected_gradients=UnconnectedGradients.ZERO)
def _jacfwd(f, primals):
"""Compute the jacobian of `f` at `primals` using forward-mode autodiff."""
jac_flat = []
flat_primals = nest.flatten(primals)
tangent_mask = [array_ops.zeros_like(primal) for primal in flat_primals]
for primal_index, primal in enumerate(flat_primals):
primal_vector = array_ops.reshape(primal, [-1])
primal_vector_length = array_ops.size(primal_vector)
jac_columns = []
for element_index in math_ops.range(primal_vector_length):
mask = array_ops.one_hot(element_index, primal_vector_length)
tangent_mask[primal_index] = array_ops.reshape(mask,
array_ops.shape(primal))
jac_columns.append(
nest.map_structure(
functools.partial(array_ops.reshape, shape=[-1]),
_jvp(f, primals,
nest.pack_sequence_as(primals, tangent_mask))[1]))
jac_flat.append(array_ops.stack(jac_columns, axis=1))
tangent_mask[primal_index] = array_ops.zeros_like(primal)
return nest.pack_sequence_as(primals, jac_flat)
def _jvp_batch(f, primal, tangents):
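  """Vectorize `_jvp` over a batch of `tangents` at a fixed `primal`."""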
tf_function = def_function.function(f)
return control_flow_ops.vectorized_map(
functools.partial(_jvp, tf_function, primal), tangents)
def _jvp_batch_matmul(f, primals, tangent_batch):
"""Compute the jacobian of `f` at `primals` multiplied by `tangents`."""
jac_fwd = _jacfwd(f, primals)
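  # Reference implementation: multiply each flattened tangent by the dense
  # Jacobian from `_jacfwd`.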
def jac_mul(tangent):
flat_tangent = array_ops.reshape(tangent, shape=[-1])
tangent_vector = array_ops.expand_dims(flat_tangent, 1)
jvp_vector = math_ops.matmul(jac_fwd, tangent_vector)
return array_ops.reshape(jvp_vector, tangent.shape)
return control_flow_ops.vectorized_map(jac_mul, tangent_batch)
def _grad(f, argnums=0):
"""Return a function which computes the gradient of `f`."""
def _f(*params):
with backprop.GradientTape() as tape:
tape.watch(params)
primals_out = f(*params)
return tape.gradient(
primals_out,
params[argnums],
unconnected_gradients=UnconnectedGradients.ZERO)
return _f
def _gradfwd(f, argnums=0, f_out_dtypes=dtypes.float32):
"""Return a function which computes the gradient of `f` in forward mode."""
def _f(*params):
def _single_jvp(param_mask):
with forwardprop.ForwardAccumulator(primals=[params[argnums]],
tangents=param_mask) as acc:
primals_out = f(*params)
return acc.jvp(primals_out)
# Building up a function to run with pfor takes a bit too long since we're
# only running it a handful of times.
return _vectorize_parameters(_single_jvp, [params[argnums]],
use_pfor=False, dtype=f_out_dtypes)
return _f
def _hvp(f, primals, tangents):
"""Compute a forward-over-back Hessian-vector product."""
with forwardprop.ForwardAccumulator(primals, tangents) as acc:
with backprop.GradientTape() as tape:
tape.watch(primals)
f_out = f(*primals)
f_out.shape.assert_is_compatible_with([])
return acc.jvp(tape.gradient(f_out, primals))
def _vectorize_parameters(f, params, use_pfor, dtype):
"""Loop over `params`, providing a one-hot mask to `f` for each."""
parameter_sizes = [array_ops.size(param) for param in params]
total_size = math_ops.add_n(parameter_sizes)
def _wrapper(index):
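    # Build a one-hot tangent over the flattened parameters, then split and
    # reshape it to match each parameter.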
full_onehot = array_ops.one_hot(index, total_size)
split_onehot = array_ops.split(full_onehot, parameter_sizes)
tangents = [array_ops.reshape(v, array_ops.shape(param))
for param, v in zip(params, split_onehot)]
return f(tangents)
if use_pfor:
return control_flow_ops.vectorized_map(_wrapper, math_ops.range(total_size))
return map_fn.map_fn(_wrapper, math_ops.range(total_size), dtype)
def _forward_over_back_hessian(f, params, use_pfor, dtype=None):
"""Computes the full Hessian matrix for the scalar-valued f(*params).
Args:
f: A function taking `params` and returning a scalar.
params: A possibly nested structure of tensors.
use_pfor: If true, uses `tf.vectorized_map` calls instead of looping.
dtype: Required if `use_pfor=False`. A possibly nested structure of dtypes
(e.g. `tf.float32`) matching the structure of `f`'s returns.
Returns:
A possibly nested structure of matrix slices corresponding to `params`. Each
slice has shape [P, p_s] where `p_s` is the number of parameters (`tf.size`)
in the corresponding element of `params` and `P` is the total number of
parameters (`sum_s(p_s)`). The full matrix can be obtained by concatenating
along the second axis.
"""
return _vectorize_parameters(
functools.partial(_hvp, f, params),
params, use_pfor=use_pfor, dtype=dtype)
def _test_gradients(testcase,
f,
primals,
order,
delta=1e-3,
rtol=1e-2,
atol=1e-6):
"""Tests forward/backward jacobians of `f`'s [0, `order`)-order gradients."""
if order < 1:
raise ValueError(
"`order` should be a positive integer, got '{}'.".format(order))
if order > 1:
_test_gradients(
testcase=testcase,
f=_grad(f),
primals=primals,
order=order - 1,
delta=delta,
rtol=rtol,
atol=atol)
sym_jac_back, num_jac = gradient_checker_v2.compute_gradient(
f, primals, delta=delta)
testcase.assertAllClose(num_jac, sym_jac_back, rtol=rtol, atol=atol)
sym_jac_fwd = _jacfwd(f, primals)
testcase.assertAllClose(num_jac, sym_jac_fwd, rtol=rtol, atol=atol)
# And the symbolic computations should be much closer.
testcase.assertAllClose(sym_jac_back, sym_jac_fwd)
class ForwardpropTest(test.TestCase, parameterized.TestCase):
def testJVPFunction(self):
add_outputs = (constant_op.constant(4.),)
vp, = forwardprop._jvp_dispatch(
op_name="Add",
attr_tuple=(),
inputs=(constant_op.constant(1.), constant_op.constant(3.)),
outputs=add_outputs,
tangents=(
constant_op.constant(1.),
constant_op.constant(5.),
))
self.assertAllClose(1. + 5., self.evaluate(vp))
mul_outputs = (constant_op.constant([20.]),)
vp, = forwardprop._jvp_dispatch(
op_name="Mul",
attr_tuple=(),
inputs=(constant_op.constant([4.]), constant_op.constant([5.])),
outputs=mul_outputs,
tangents=(
constant_op.constant([2.]),
constant_op.constant([3.]),
))
self.assertAllClose([2. * 5. + 3. * 4.], self.evaluate(vp))
def testJVPFunctionWithBatchOfTangents(self):
add_outputs = (constant_op.constant(4.),)
jvp_flat = forwardprop._jvp_dispatch(
op_name="Add",
attr_tuple=(),
inputs=(constant_op.constant(1.), constant_op.constant(3.)),
outputs=add_outputs,
tangents=(
constant_op.constant([1., 2., 3.]),
constant_op.constant([4., 5., 6.]),
),
use_batch=True)
    # Using evaluate and asserting with just a list works too,
    # but the output is more explicit this way.
self.assertAllClose([constant_op.constant([1. + 4., 2. + 5., 3. + 6.])],
jvp_flat)
mul_outputs = (constant_op.constant([20.]),)
jvp_flat = forwardprop._jvp_dispatch(
op_name="Mul",
attr_tuple=(),
inputs=(constant_op.constant([4.]), constant_op.constant([5.])),
outputs=mul_outputs,
tangents=(
constant_op.constant([[1.], [0.], [1.]]),
constant_op.constant([[0.], [1.], [1.]]),
),
use_batch=True)
self.assertAllClose([constant_op.constant([[5.], [4.], [5. + 4.]])],
jvp_flat)
def testJVPFunctionRaisesError(self):
sum_outputs = (constant_op.constant(6.),)
with self.assertRaisesRegex(ValueError, r".*was expected to be of shape*"):
forwardprop._jvp_dispatch(
op_name="Add",
attr_tuple=(),
inputs=(constant_op.constant(2.), constant_op.constant(4.)),
outputs=sum_outputs,
tangents=(constant_op.constant([1., 2.]),
constant_op.constant([[1.], [2.]])),
use_batch=True)
def testNonDifferentiableOpWithInputTangent(self):
x = constant_op.constant(1.)
with forwardprop.ForwardAccumulator(x, 2.) as acc1:
with forwardprop.ForwardAccumulator(x, 2.) as acc2:
y = array_ops.zeros_like(x)
self.assertIsNone(acc1.jvp(y))
self.assertIsNone(acc2.jvp(y))
def testRunFunctionsEagerly(self):
try:
original_setting = def_function.functions_run_eagerly()
def_function.run_functions_eagerly(True)
x = constant_op.constant(1.)
with forwardprop.ForwardAccumulator(x, 2.) as acc:
y = x * 3.
self.assertAllClose(6., acc.jvp(y))
finally:
def_function.run_functions_eagerly(original_setting)
def testJVPFunctionUsedByAccumulatorForOps(self):
previous_fn = forwardprop._jvp_dispatch
try:
x = constant_op.constant(1.)
with forwardprop.ForwardAccumulator(x, 2.) as acc:
y = x + x
pywrap_tfe.TFE_Py_RegisterJVPFunction(
lambda *args, **kwargs: [constant_op.constant(-15.)])
z = x + x
self.assertAllClose(4., acc.jvp(y))
self.assertAllClose(-15., acc.jvp(z))
finally:
pywrap_tfe.TFE_Py_RegisterJVPFunction(previous_fn)
@test_util.assert_no_new_pyobjects_executing_eagerly
def testFunctionCacheLimited(self):
# Every time this test is executed, it will create a slightly larger Tensor
# and push it through Add's gradient. Since we check for new pyobjects after
# the warmup, retracing each time without cleaning up old traces fails the
# test. It works because of experimental_relax_shapes.
for _ in range(forwardprop._TRACE_COUNT_LIMIT):
execution_count = getattr(self, "_execution_count", 0)
self._execution_count = execution_count + 1
x = array_ops.zeros([execution_count])
with forwardprop.ForwardAccumulator(
x, array_ops.ones_like(x)) as acc:
y = x + x
self.assertAllClose(2. * array_ops.ones_like(x), acc.jvp(y))
@test_util.assert_no_new_pyobjects_executing_eagerly
def testMultipleWatchesAdd(self):
x = constant_op.constant(-2.)
with self.assertRaisesRegex(ValueError, "multiple times"):
with forwardprop.ForwardAccumulator(
[x, x], [1., 2.]):
pass
with forwardprop.ForwardAccumulator(
[x], [3.]) as acc:
self.assertAllClose(3., acc.jvp(x))
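      # Watching `x` again adds the new tangent to the existing one.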
acc._watch(x, constant_op.constant(10.))
self.assertAllClose(13., acc.jvp(x))
acc._watch(x, constant_op.constant(11.))
self.assertAllClose(24., acc.jvp(x))
y = constant_op.constant(3.) * x
self.assertAllClose(24., acc.jvp(x))
self.assertAllClose(24. * 3., acc.jvp(y))
@test_util.assert_no_new_pyobjects_executing_eagerly
def testReenter(self):
x = constant_op.constant(-2.)
with forwardprop.ForwardAccumulator(x, 1.5) as acc:
self.assertAllClose(1.5, acc.jvp(x))
y = 4. * x
self.assertAllClose(6., acc.jvp(y))
with self.assertRaisesRegex(ValueError, "already recording"):
with acc:
pass
z = 4. * x
self.assertIsNone(acc.jvp(z))
with acc:
yy = y * y
self.assertAllClose(6. * -8. * 2., acc.jvp(yy))
@test_util.assert_no_new_pyobjects_executing_eagerly
def testDeadTensorsJVPCleared(self):
x = array_ops.ones([100])
x_weak = weakref.ref(x)
grad_tensor = constant_op.constant(array_ops.zeros([100]))
grad_tensor_weak = weakref.ref(grad_tensor)
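    # The accumulator should not keep deleted primals, tangents, or JVPs alive.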
with forwardprop.ForwardAccumulator(x, grad_tensor) as acc:
derived_tensor = constant_op.constant(2.) * x
del grad_tensor
self.assertAllClose(array_ops.zeros([100]), acc.jvp(x))
del x
self.assertIsNone(x_weak())
self.assertIsNone(grad_tensor_weak())
derived_tensor_weak = weakref.ref(derived_tensor)
derived_tensor_grad = acc.jvp(derived_tensor)
derived_tensor_grad_weak = weakref.ref(derived_tensor_grad)
del derived_tensor
del derived_tensor_grad
self.assertIsNone(derived_tensor_weak())
self.assertIsNone(derived_tensor_grad_weak())
@test_util.assert_no_new_pyobjects_executing_eagerly
def testJVPManual(self):
primal, tangent = _jvp(math_ops.sin, (constant_op.constant(0.1),),
(constant_op.constant(0.2),))
self.assertAllClose(math_ops.sin(0.1), primal)
self.assertAllClose(math_ops.cos(0.1) * 0.2, tangent)
@test_util.assert_no_new_pyobjects_executing_eagerly
def testNumericHigherOrder(self):
def f(x):
pointwise = math_ops.sin(x) * math_ops.tan(x)
return math_ops.reduce_prod(
pointwise + math_ops.reduce_sum(pointwise), axis=1)
_test_gradients(
self, f, [constant_op.constant([[2.0, 3.0], [1.0, 4.0]])], order=3)
@test_util.assert_no_new_pyobjects_executing_eagerly
def testCustomGradient(self):
@custom_gradient.custom_gradient
def f(x):
def grad(dy):
return dy * math_ops.cos(x)
return np.sin(x.numpy()), grad
_test_gradients(self, f, [constant_op.constant([1., 2.])], order=3)
  # TODO(allenl): Investigate why assert_no_new_pyobjects_executing_eagerly
  # fails around this test.
def testExceptionCustomGradientRecomputeGradForward(self):
@custom_gradient.recompute_grad
def f(x):
return math_ops.reduce_prod(math_ops.tanh(x)**2)
with self.assertRaisesRegex(NotImplementedError,
"recompute_grad tried to transpose"):
primals = [constant_op.constant([1.])]
sym_jac_fwd = _jacfwd(f, primals)
def testExceptionInCustomGradientNotSwallowed(self):
@custom_gradient.custom_gradient
def f(unused_x):
def grad(unused_dy):
raise ValueError("test_error_string")
return 1., grad
c = constant_op.constant(1.)
d = constant_op.constant(2.)
with forwardprop.ForwardAccumulator(c, d):
with self.assertRaisesRegex(ValueError, "test_error_string"):
f(c)
@parameterized.named_parameters(
[("EluM5", -0.5, nn_ops.elu),
("EluP5", [0.5], nn_ops.elu),
("SwishP5", 0.5, nn_impl.swish),
("SwishM5", [-0.5], nn_impl.swish)])
def testElementwiseNNOps(self, value, op_fn):
_test_gradients(self, op_fn, [constant_op.constant(value)], order=3)
def testFusedBatchNormGradsInference(self):
if test.is_built_with_rocm():
      # This test was added recently and has been failing on the ROCm platform
      # since it was added.
      # TODO(rocm): Do root-cause analysis of the test failure and fix it.
self.skipTest("Test fails on ROCm platform, needs further analysis")
x_shape = [4, 10, 10, 2]
increment = 3. / math_ops.reduce_prod(
constant_op.constant(x_shape, dtype=dtypes.float32))
x = array_ops.reshape(math_ops.range(-2., 1., increment), x_shape)
scale = constant_op.constant([1., 1.1])
offset = constant_op.constant([-0.5, -0.6])
mean = constant_op.constant([-1.3, 1.4])
variance = constant_op.constant([0.7, 0.9])
epsilon = 0.001
def _bn_fused(x_arg, scale_arg, offset_arg):
return nn_impl.fused_batch_norm(x_arg, scale_arg, offset_arg,
mean, variance,
epsilon=epsilon, is_training=False)[0]
_test_gradients(self, _bn_fused, [x, scale, offset],
order=2, atol=1e-2)
def testPushPopAccumulatorState(self):
# Note that this example is somewhat contrived. push_forwardprop_state is
# probably only useful in practice for building functions that compute jvps
# alongside their usual outputs.
c = constant_op.constant(1.)
d = constant_op.constant(2.)
with forwardprop.ForwardAccumulator(c, d) as acc:
@custom_gradient.custom_gradient
def f(x):
y = math_ops.sin(x.numpy())
def grad(dy):
with forwardprop_util.push_forwardprop_state():
x_copy = constant_op.constant(x.numpy())
acc._watch(x_copy, dy)
y_copy = math_ops.sin(x_copy)
return dy * acc.jvp(y_copy)
return y, grad
output = f(c)
self.assertAllClose(d * math_ops.cos(c), acc.jvp(output))
@parameterized.named_parameters(
[("Order{}".format(order), order, expected)
for order, expected in enumerate(_X11_35_DERIVATIVES)])
@test_util.assert_no_new_pyobjects_executing_eagerly
def testHigherOrderPureForward(self, order, expected):
def _forwardgrad(f):
def _compute_forwardgrad(primal):
tangent = constant_op.constant(1.)
with forwardprop.ForwardAccumulator(primal, tangent) as acc:
primal_out = f(primal)
return acc.jvp(primal_out)
return _compute_forwardgrad
def _forward(x):
return x ** 3.5
f = _forward
primal = constant_op.constant(1.1)
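    # Differentiate `order` times by nesting forward-mode accumulators.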
for _ in range(order):
f = _forwardgrad(f)
self.assertAllClose(expected, f(primal))
@parameterized.named_parameters(
[("Function", def_function.function),
("NoFunction", lambda f: f)])
def testGradPureForward(self, decorator):
@decorator
def f(x):
return x ** 3.5
primal = constant_op.constant(1.1)
with forwardprop.ForwardAccumulator(
primal, constant_op.constant(1.)) as outer_acc:
with forwardprop.ForwardAccumulator(
primal, constant_op.constant(1.)) as acc:
primal_out = f(primal)
inner_jvp = acc.jvp(primal_out)
outer_jvp = outer_acc.jvp(inner_jvp)
self.assertAllClose(1.1 ** 3.5, primal_out)
self.assertAllClose(3.5 * 1.1 ** 2.5, inner_jvp)
self.assertAllClose(3.5 * 2.5 * 1.1 ** 1.5, outer_jvp)
self.assertIsNone(acc.jvp(outer_acc.jvp(primal_out)))
@test_util.assert_no_new_pyobjects_executing_eagerly
def testJVPPacking(self):
two = constant_op.constant(2.)
primal_in = constant_op.constant(1.)
inner_jvp = constant_op.constant(3.)
with forwardprop.ForwardAccumulator(
[primal_in, inner_jvp],
[constant_op.constant(2.), constant_op.constant(4.)]) as outer_acc:
with forwardprop.ForwardAccumulator(
primal_in, inner_jvp) as inner_acc:
packed_input_indices, packed_input_tangents = (
forwardprop_util.pack_tangents([primal_in]))
self.assertAllClose([3., 2., 4.], packed_input_tangents)
expected_indices = (
# inner_acc watches primal_in
((0, 1),),
# outer_acc watches primal_in and inner_jvp
((0, 2),
(1, 3)))
self.assertAllEqual(expected_indices, packed_input_indices)
primal_out = primal_in * two
self.assertAllClose(6., inner_acc.jvp(primal_out))
self.assertAllClose(4., outer_acc.jvp(primal_out))
self.assertAllClose(8., outer_acc.jvp(inner_acc.jvp(primal_out)))
packed_output_indices, packed_output_tangents = (
forwardprop_util.pack_tangents([primal_out]))
self.assertAllClose([6., 4., 8.], packed_output_tangents)
self.assertAllEqual(expected_indices, packed_output_indices)
def testFunctionGradInFunctionPureForward(self):
@def_function.function
def take_gradients():
@def_function.function
def f(x):
return x ** 3.5
primal = constant_op.constant(1.1)
with forwardprop.ForwardAccumulator(
primal, constant_op.constant(1.)) as outer_acc:
with forwardprop.ForwardAccumulator(
primal, constant_op.constant(1.)) as acc:
primal_out = f(primal)
inner_jvp = acc.jvp(primal_out)
outer_jvp = outer_acc.jvp(inner_jvp)
self.assertIsNone(acc.jvp(outer_acc.jvp(primal_out)))
return primal_out, inner_jvp, outer_jvp
primal_out, inner_jvp, outer_jvp = take_gradients()
self.assertAllClose(1.1 ** 3.5, primal_out)
self.assertAllClose(3.5 * 1.1 ** 2.5, inner_jvp)
self.assertAllClose(3.5 * 2.5 * 1.1 ** 1.5, outer_jvp)
def testFunctionGrad(self):
@def_function.function
def f(x):
return math_ops.reduce_prod(math_ops.tanh(x)**2)
_test_gradients(
self,
f,
[constant_op.constant([1., 2.])],
order=3)
def testReusingJVP(self):
m1 = random_ops.random_uniform((256, 2096))
m2 = array_ops.identity(m1)
tangent1 = random_ops.random_uniform((256, 2096))
tangent2 = random_ops.random_uniform((256, 2096))
matmul = def_function.function(math_ops.matmul)
with forwardprop.ForwardAccumulator(
primals=[m1, m2], tangents=[tangent1, tangent2]) as acc:
result1 = matmul(m1, m1, transpose_b=True)
result2 = matmul(m2, m2, transpose_b=True)
def _expected(mat, tangent):
return (math_ops.matmul(tangent, mat, transpose_b=True)
+ math_ops.matmul(mat, tangent, transpose_b=True))
self.assertAllClose(result1, result2)
self.assertAllClose(_expected(m1, tangent1), acc.jvp(result1))
self.assertAllClose(_expected(m2, tangent2), acc.jvp(result2))
@test_util.assert_no_new_pyobjects_executing_eagerly
def testHVPMemory(self):
def fun(x):
return math_ops.reduce_prod(math_ops.tanh(x)**2)
primals = constant_op.constant([1., 2., 3.])
tangents = constant_op.constant([3., 4., 5.])
_hvp(fun, (primals,), (tangents,))
@test_util.assert_no_new_pyobjects_executing_eagerly
def testHVPCorrectness(self):
def fun(x):
return math_ops.reduce_prod(math_ops.tanh(x)**2)
primals = constant_op.constant([1., 2., 3.])
tangents = constant_op.constant([3., 4., 5.])
forwardback_hvp_eager, = _hvp(fun, (primals,), (tangents,))
forwardback_hvp_function, = def_function.function(_hvp)(fun, (primals,),
(tangents,))
with backprop.GradientTape(persistent=True) as g:
g.watch(primals)
with backprop.GradientTape() as gg:
gg.watch(primals)
out = fun(primals)
grad = array_ops.unstack(gg.gradient(out, primals))
hessian = []
for i in range(3):
hessian.append(g.gradient(grad[i], primals))
hessian = array_ops.stack(hessian, axis=0)
backback_hvp = math_ops.tensordot(hessian, tangents, axes=1)
self.assertAllClose(backback_hvp, forwardback_hvp_eager)
self.assertAllClose(backback_hvp, forwardback_hvp_function)
@test_util.assert_no_new_pyobjects_executing_eagerly
def testShouldRecordAndStopRecord(self):
c = constant_op.constant(1.)
c_tangent = constant_op.constant(2.)
with forwardprop.ForwardAccumulator(c, c_tangent) as acc:
with backprop.GradientTape() as tape:
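        # Possible gradient types: 0 = none, 1 = first order only, 2 = higher order.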
self.assertFalse(tape_lib.should_record_backprop([c]))
self.assertEqual(1,
pywrap_tfe.TFE_Py_TapeSetPossibleGradientTypes([c]))
tape.watch(c)
self.assertEqual(2,
pywrap_tfe.TFE_Py_TapeSetPossibleGradientTypes([c]))
self.assertTrue(tape_lib.should_record_backprop([c]))
with tape_lib.stop_recording():
self.assertEqual(0,
pywrap_tfe.TFE_Py_TapeSetPossibleGradientTypes([c]))
self.assertFalse(tape_lib.should_record_backprop([c]))
d = c * 2.
self.assertEqual(2,
pywrap_tfe.TFE_Py_TapeSetPossibleGradientTypes([c]))
self.assertTrue(tape_lib.should_record_backprop([c]))
self.assertFalse(tape_lib.should_record_backprop([d]))
self.assertIsNone(acc.jvp(d))
self.assertIsNone(tape.gradient(d, c))
@test_util.assert_no_new_pyobjects_executing_eagerly
def testRecordingSelectively(self):
c = constant_op.constant(1.)
c_tangent = constant_op.constant(2.)
with forwardprop.ForwardAccumulator(c, c_tangent) as acc:
with backprop.GradientTape(persistent=True) as tape:
tape.watch(c)
with tape_lib.stop_recording():
two = constant_op.constant(2.)
d = c * two
three = constant_op.constant(3.)
e = c * three
self.assertIsNone(acc.jvp(d))
self.assertIsNone(acc.jvp(e))
self.assertIsNone(tape.gradient(d, c))
self.assertIsNone(tape.gradient(e, c))
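        # Forwardprop-only records are visible to accumulators but not tapes;
        # backprop-only records are the reverse.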
tape_lib.record_operation_forwardprop_only(
"CustomForwardMul", [d], [c, two],
lambda dd: (two * dd, c * dd), None)
tape_lib.record_operation_backprop_only(
"CustomBackwardMul", [e], [c, three],
lambda de: (three * de, c * de))
self.assertAllClose(4., acc.jvp(d))
self.assertIsNone(acc.jvp(e))
self.assertIsNone(tape.gradient(d, c))
self.assertAllClose(3., tape.gradient(e, c))
@test_util.assert_no_new_pyobjects_executing_eagerly
def testOpWithNoTrainableOutputs(self):
v = variables.Variable(1.)
with forwardprop.ForwardAccumulator(v, 11.):
v.assign_sub(0.5)
self.assertAllClose(0.5, self.evaluate(v))
# TODO(b/141025187): Add a no_new_pyobjects decorator.
def testVariableReadInFunction(self):
v = variables.Variable(1.)
with forwardprop.ForwardAccumulator(v, 11.) as acc:
@def_function.function
def f():
return v.read_value(), 2. * v.read_value()
result = f()
self.assertAllClose((1.0, 2.), result)
self.assertAllClose((11., 22.), acc.jvp(result))
@parameterized.named_parameters(
[("ForwardPropFirst", True),
("TapeFirst", False)])
def testForwardOverBackwardMemoryEfficiency(self, forward_prop_first):
    # Watching depends on nesting, not creation order.
c = constant_op.constant(1.)
if forward_prop_first:
forward_accumulator = forwardprop.ForwardAccumulator(c, .1)
gradient_tape = backprop.GradientTape()
else:
gradient_tape = backprop.GradientTape()
forward_accumulator = forwardprop.ForwardAccumulator(c, .1)
try:
gc.disable()
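      # Disable the cycle collector so any freed memory comes from reference
      # counting rather than garbage collection.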
with gradient_tape as tape:
# Adding and removing the tape multiple times in different nesting
# patterns does not affect watch ordering.
pass
with forward_accumulator as acc:
with gradient_tape as tape:
tape.watch(c)
d = math_ops.cos(c)
self.assertFalse(tape_lib.should_record_backprop((acc.jvp(d),)))
e = math_ops.cos(acc.jvp(d))
math_ops.cos(e)
weak_e = weakref.ref(e)
del e
self.assertIsNone(weak_e())
self.assertIsNone(tape.gradient(acc.jvp(d), c))
finally:
gc.enable()
@parameterized.named_parameters(
[("ForwardPropFirst", True),
("TapeFirst", False)])
def testBackwardOverForward(self, forward_prop_first):
c = constant_op.constant(1.)
    # Watching depends on nesting, not creation order.
if forward_prop_first:
forward_accumulator = forwardprop.ForwardAccumulator(c, .1)
gradient_tape = backprop.GradientTape()
else:
gradient_tape = backprop.GradientTape()
forward_accumulator = forwardprop.ForwardAccumulator(c, .1)
with gradient_tape as tape:
with forward_accumulator as acc:
tape.watch(c)
d = math_ops.cos(c)
self.assertTrue(tape_lib.should_record_backprop((acc.jvp(d),)))
self.assertAllClose(-.1 * math_ops.cos(1.),
tape.gradient(acc.jvp(d), c))
@test_util.assert_no_new_pyobjects_executing_eagerly
def testRecordingWithJVPIndices(self):
c = constant_op.constant(1.)
with forwardprop.ForwardAccumulator(c, 10.) as acc:
packed_input_tangents = forwardprop_util.pack_tangents([c]).tangents
self.assertAllClose([10.], packed_input_tangents)
d = constant_op.constant(2.)
d_tangent = constant_op.constant(3.)
tape_lib.record_operation_forwardprop_only(
"FunctionWithInlineJVPs",
[d] + [d_tangent],
[c] + packed_input_tangents,
None, (((0, 1),),))
self.assertAllClose(3., acc.jvp(d))
@test_util.assert_no_new_pyobjects_executing_eagerly
def testSpecialForwardFunctionUsed(self):
c = constant_op.constant(1.)
d = constant_op.constant(2.)
e = constant_op.constant(3.)
with forwardprop.ForwardAccumulator(c, 10.) as acc:
tape_lib.record_operation(
"ForwardIsSpecial",
[d], [c],
None, lambda jvp: [-2. * jvp])
self.assertAllClose(-20., acc.jvp(d))
tape_lib.record_operation(
"ForwardIsSpecial2",
[], [],
None, lambda: [])
tape_lib.record_operation(
"ForwardIsSpecial3",
[e], [d],
None, lambda x: [x])
self.assertAllClose(-20., acc.jvp(e))
@test_util.assert_no_new_pyobjects_executing_eagerly
def testVariableWatched(self):
v = variables.Variable([1., 2., 3.])
with forwardprop.ForwardAccumulator(
v, constant_op.constant([.1, -.2, .3])) as acc:
self.assertAllClose([.1, -.2, .3], acc.jvp(v))
x = v * 2.
self.assertAllClose([.2, -.4, .6], acc.jvp(x))
x2 = v + .1
self.assertAllClose([.1, -.2, .3], acc.jvp(x2))
def testUnconnectedGradients(self):
x = constant_op.constant(-1.)
with forwardprop.ForwardAccumulator(x, 0.1) as acc:
self.assertAllClose(0.1, acc.jvp(x, unconnected_gradients="zero"))
self.assertAllClose(0.1, acc.jvp(x, unconnected_gradients="none"))
y = constant_op.constant(-2.)
self.assertAllClose(0.0, acc.jvp(y, unconnected_gradients="zero"))
self.assertIsNone(acc.jvp(y, unconnected_gradients="none"))
# TODO(kkb): One weakref instance is created with warmup_iters=2,
# investigate.
@test_util.assert_no_new_pyobjects_executing_eagerly(warmup_iters=3)
def testVariableWatchedFunction(self):
class _Model(module.Module):
def __init__(self):
self._v = None
@def_function.function
def compute_jvps(self):
if self._v is None:
self._v = variables.Variable([1., 2., 3.])
with forwardprop.ForwardAccumulator(
self._v, constant_op.constant([.1, -.2, .3])) as acc:
x = self._v * 2.
x2 = self._v + .1
return acc.jvp((self._v, x, x2))
model = _Model()
v_jvp, x_jvp, x2_jvp = model.compute_jvps()
self.assertAllClose([.1, -.2, .3], v_jvp)
self.assertAllClose([.2, -.4, .6], x_jvp)
self.assertAllClose([.1, -.2, .3], x2_jvp)
# NOTE: assert_no_new_pyobjects_executing_eagerly fails flakily on this
# test... could be something wrong with the test decorator, or some sort of
# nondeterministic caching.
def testMirroredVariableWatched(self):
def _replicated(input_tangent):
with forwardprop.ForwardAccumulator(v, input_tangent) as acc:
self.assertAllClose([.1, -.2, .3], acc.jvp(v))
x = v * 2.
self.assertAllClose([.2, -.4, .6], acc.jvp(x))
x2 = v + .1
self.assertAllClose([.1, -.2, .3], acc.jvp(x2))
strategy = mirrored_strategy.MirroredStrategy()
with strategy.scope():
v = variables.Variable([1., 2., 3.])
strategy.run(_replicated, args=(constant_op.constant([.1, -.2, .3]),))
# TODO(b/141025187): Add a no_new_pyobjects decorator.
def testArgumentUnused(self):
v = constant_op.constant(1.)
with forwardprop.ForwardAccumulator(v, 11.) as acc:
@def_function.function
def _f(x):
del x
return constant_op.constant(1.)
result = _f(v)
self.assertAllClose(1.0, result)
self.assertIsNone(acc.jvp(result))
@def_function.function
def _has_loop(iters, y):
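  """Accumulate `y * i` over `iters` loop iterations."""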
ret = 0.
for i in math_ops.range(iters):
ret += y * math_ops.cast(i, dtypes.float32)
return ret
@def_function.function
def _has_cond(k, y):
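  """Return `3. * y` if `k > 1`, else `0.`."""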
if k > 1:
ret = 3. * y
else:
ret = 0.
return ret
@def_function.function
def _fprop_while(iters, y):
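  """Like `_has_loop`, but returns the JVP with respect to `y`."""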
with forwardprop.ForwardAccumulator(y, 1.) as acc:
ret = 0.
for i in math_ops.range(iters):
ret += y * math_ops.cast(i, dtypes.float32)
return acc.jvp(ret)
@def_function.function
def _fprop_cond(k, y):
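  """Like `_has_cond`, but returns the JVP with respect to `y`."""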
with forwardprop.ForwardAccumulator(y, 1.) as acc:
if k > 1:
ret = 3. * y
else:
ret = 0.
return acc.jvp(ret)
class ControlFlowTests(test.TestCase):
@test_util.assert_no_new_pyobjects_executing_eagerly
def testOfFunctionWhile(self):
y = constant_op.constant(1.)
with forwardprop.ForwardAccumulator(y, 1.) as acc:
self.assertAllClose(
10., acc.jvp(_has_loop(constant_op.constant(5), y)))
@test_util.assert_no_new_pyobjects_executing_eagerly
def testOfFunctionCond(self):
y = constant_op.constant(1.)
with forwardprop.ForwardAccumulator(y, 1.) as acc:
self.assertAllClose(
3., acc.jvp(_has_cond(constant_op.constant(5), y)))
self.assertAllClose(
0., acc.jvp(_has_cond(constant_op.constant(0), y)))
@test_util.assert_no_new_pyobjects_executing_eagerly
def testInFunctionWhile(self):
self.assertAllClose(
10., _fprop_while(constant_op.constant(5), constant_op.constant(1.)))
@test_util.assert_no_new_pyobjects_executing_eagerly
def testInFunctionCond(self):
self.assertAllClose(
3., _fprop_cond(constant_op.constant(5), constant_op.constant(1.)))
self.assertAllClose(
0., _fprop_cond(constant_op.constant(0), constant_op.constant(1.)))
class HessianTests(test.TestCase, parameterized.TestCase):
def testHessian1D(self):
# Note: stolen from ops/gradients_test.py
m = 4
rng = np.random.RandomState([1, 2, 3])
mat_value = rng.randn(m, m).astype("float32")
x_value = rng.randn(m).astype("float32")
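    # For f(x) = x^T M x, the Hessian is M + M^T.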
hess_value = mat_value + mat_value.T
mat = variables.Variable(mat_value)
def _f(x):
return math_ops.reduce_sum(x[:, None] * mat * x[None, :])
hessian_eager, = _forward_over_back_hessian(
_f, [constant_op.constant(x_value)],
use_pfor=False, dtype=[dtypes.float32])
self.assertAllClose(hess_value, hessian_eager)
hessian_function, = def_function.function(_forward_over_back_hessian)(
_f, [constant_op.constant(x_value)],
use_pfor=False, dtype=[dtypes.float32])
self.assertAllClose(hess_value, hessian_function)
hessian_pfor, = def_function.function(_forward_over_back_hessian)(
_f, [constant_op.constant(x_value)],
use_pfor=True, dtype=[dtypes.float32])
self.assertAllClose(hess_value, hessian_pfor)
class BatchTests(test.TestCase, parameterized.TestCase):
@parameterized.parameters([(math_ops.sin, (2, 3), 5),
(math_ops.sin, (2, 3, 4), 10)])
def testJVPBatchCorrectness(self, f, primal_shape, batch_size):
primals = [random_ops.random_uniform(primal_shape)]
tangent_batch = [random_ops.random_uniform([batch_size, *primal_shape])]
self.assertAllClose(
_jvp_batch(f, primals, tangent_batch)[1],
_jvp_batch_matmul(f, primals, *tangent_batch))
def testBatchCorrectness(self):
x = constant_op.constant(2.0)
y = constant_op.constant(5.0)
tangents = (
constant_op.constant([1., 0., 1.]),
constant_op.constant([0., 1., 1.]),
)
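    # For z = x * y, each tangent row (dx, dy) gives jvp = y * dx + x * dy.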
with forwardprop.ForwardAccumulator((x, y), tangents, True) as acc:
z = x * y
self.assertAllClose(
acc.jvp(z),
constant_op.constant([5.0, 2.0, 7.0]
))
if __name__ == "__main__":
# TODO(allenl): Also test with 1.x-style graph mode.
ops.enable_eager_execution()
test.main()