# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Basic tests for gradients."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
from tensorflow.python.eager import tape
from tensorflow.python.eager import test
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import custom_gradient
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import math_ops
# Importing nn_grad for the registration functions.
from tensorflow.python.ops import nn_grad # pylint: disable=unused-import
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import variables
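

# two_outputs is a custom_gradient function with two forward outputs (the
# matmul result and its reduced sum); its grad function receives one upstream
# gradient per output and returns one gradient per input.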
@custom_gradient.custom_gradient
def two_outputs(a, b):
  mm = math_ops.matmul(a, b)
  r = math_ops.reduce_sum(mm)

  def grad(dmm, dr):
    return [
        math_ops.matmul(dmm, b, transpose_b=True) +
        math_ops.matmul(array_ops.ones_like(b * dr), b, transpose_b=True),
        math_ops.matmul(a, dmm, transpose_b=True) +
        math_ops.matmul(a, array_ops.ones_like(a) * dr, transpose_b=True)
    ]

  return [mm, r], grad
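

# gradient_is_constant squares its input but registers a gradient that simply
# passes the upstream gradient through unchanged.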
@custom_gradient.custom_gradient
def gradient_is_constant(x):
  result = x * x

  def grad(dr):
    return [dr]

  return result, grad


class TapeTest(test.TestCase):

  def testMultiOutput(self):

    def fn(x, y):
      c = x + y
      # Multiple outputs from split.
      d, f = array_ops.split(c, 2)
      return d + f

    a = constant_op.constant([[1., 0.], [0., 1.]])
    b = constant_op.constant([[1., 2.], [3., 4.]])
    da, db = backprop.gradients_function(fn, [0, 1])(a, b)
    with context.graph_mode(), self.cached_session():
      tf_a = constant_op.constant([[1, 0], [0, 1]], dtype=dtypes.float32)
      tf_b = constant_op.constant([[1, 2], [3, 4]], dtype=dtypes.float32)
      tf_c = tf_a + tf_b
      tf_d, tf_f = array_ops.split(tf_c, 2, axis=1)
      tf_e = tf_d + tf_f
      tf_da, tf_db = gradients_impl.gradients(tf_e, [tf_a, tf_b])
      self.assertAllEqual(da, self.evaluate(tf_da))
      self.assertAllEqual(db, self.evaluate(tf_db))

  def testBasicFunctional(self):

    def forward(a, b):
      mm = math_ops.matmul(a, b)
      return math_ops.reduce_sum(mm)

    aa = constant_op.constant([[1., 0.], [0., 1.]])
    bb = constant_op.constant([[1., 2.], [3., 4.]])
    da, = backprop.gradients_function(forward, ['a'])(aa, bb)
    self.assertAllEqual(da,
                        math_ops.matmul(
                            array_ops.ones_like(aa),
                            array_ops.transpose(bb)).numpy())

  def testBasicFunctionalPositionalArg(self):

    def forward(a, b):
      mm = math_ops.matmul(a, b)
      return math_ops.reduce_sum(mm)

    aa = constant_op.constant([[1., 0.], [0., 1.]])
    bb = constant_op.constant([[1., 2.], [3., 4.]])
    da, = backprop.gradients_function(forward, [0])(aa, bb)
    self.assertAllEqual(da,
                        math_ops.matmul(
                            array_ops.ones_like(aa),
                            array_ops.transpose(bb)).numpy())

  def testBasicFunctionalWithValue(self):

    def forward(a, b):
      mm = math_ops.matmul(a, b)
      return math_ops.reduce_sum(mm)

    aa = constant_op.constant([[1., 0.], [0., 1.]])
    bb = constant_op.constant([[1., 2.], [3., 4.]])
    val, (da,) = backprop.val_and_grad_function(forward, ['a'])(aa, bb)
    self.assertAllEqual(da,
                        math_ops.matmul(
                            array_ops.ones_like(aa),
                            array_ops.transpose(bb)))
    self.assertAllEqual(val, forward(aa, bb))

  def testTwoOutputs(self):

    def fn(x, y):
      mm, r = two_outputs(x, y)
      return r + math_ops.reduce_sum(mm)

    a = constant_op.constant([[1., 0.], [0., 1.]])
    b = constant_op.constant([[1., 2.], [3., 4.]])
    da, db = backprop.gradients_function(fn, [0, 1])(a, b)
    with context.graph_mode(), self.cached_session():
      tf_a = constant_op.constant([[1, 0], [0, 1]], dtype=dtypes.float32)
      tf_b = constant_op.constant([[1, 2], [3, 4]], dtype=dtypes.float32)
      tf_mm = math_ops.matmul(tf_a, tf_b)
      # fn returns r + reduce_sum(mm) with r = reduce_sum(mm), so the graph
      # reference value is 2 * reduce_sum(mm).
      tf_rr = 2 * math_ops.reduce_sum(tf_mm)
      tf_da, tf_db = gradients_impl.gradients(tf_rr, [tf_a, tf_b])
      self.assertAllEqual(da, self.evaluate(tf_da))
      self.assertAllEqual(db, self.evaluate(tf_db))
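
  # sparse_softmax_cross_entropy_with_logits is backed by an op with two
  # outputs (loss and backprop); only the loss is used here, so the unused
  # output may be garbage-collected before the gradient is taken.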
  def testGcTwoOutputs(self):

    def fn(x, y):
      return nn_ops.sparse_softmax_cross_entropy_with_logits(logits=x,
                                                             labels=y)[0]

    labels = constant_op.constant([0])
    logits = constant_op.constant([[0.0]])
    grad, = backprop.gradients_function(fn, [0])(logits, labels)
    self.assertAllEqual(grad, [[0.0]])

  def testTfTensor(self):

    def fn(x):
      return x

    t = constant_op.constant(1.0)
    g, = backprop.gradients_function(fn, [0])(t)
    self.assertAllEqual(g, 1.0)
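

# tape.VariableWatcher records the trainable variables created or assigned
# while its scope is active; the tests below check the recorded set.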
class VariableWatcherTest(test.TestCase):

  def testBasic(self):
    var1 = variables.Variable(0.0)
    var2 = variables.Variable(1.0)
    with tape.VariableWatcher() as variable_watcher:
      var1.assign_add(1.0)
      var2.assign_add(2.0)

    self.assertAllEqual(variable_watcher.watched_variables(), (var1, var2))

  def testNonTrainableVariables(self):
    var1 = variables.Variable(0.0)
    var2 = variables.Variable(1.0, trainable=False)
    with tape.VariableWatcher() as variable_watcher:
      var1.assign_add(1.0)
      var2.assign_add(2.0)

    self.assertAllEqual(variable_watcher.watched_variables(), (var1,))

  def testMultipleScopes(self):
    var1 = variables.Variable(0.0)
    var2 = variables.Variable(1.0)
    with tape.VariableWatcher() as variable_watcher1:
      var1.assign_add(1.0)
      with tape.VariableWatcher() as variable_watcher2:
        var2.assign_add(2.0)

    # variable_watcher1 should see both vars, while variable_watcher2 only
    # sees var2.
    self.assertAllEqual(variable_watcher1.watched_variables(), (var1, var2))
    self.assertAllEqual(variable_watcher2.watched_variables(), (var2,))

  def testCreateVariables(self):
    with tape.VariableWatcher() as variable_watcher:
      var1 = variables.Variable(0.0)
      var2 = variables.Variable(1.0)
      var1.assign_add(1.0)
      var2.assign_add(2.0)

    self.assertAllEqual(variable_watcher.watched_variables(), (var1, var2))


if __name__ == '__main__':
  test.main()