# blob: 31e5895fd0bd43a9666185287932d27a381dc069 [file] [log] [blame]
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Gradients for operators defined in math_ops.py."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
from tensorflow.python.compat import compat
from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_array_ops
from tensorflow.python.ops import gen_math_ops
from tensorflow.python.ops import math_ops
def _safe_shape_div(x, y):
  """Floor-divides `x` by `y` with the divisor clamped to at least 1.

  Both arguments are assumed to be non-negative, so clamping the divisor
  makes `0 / 0` evaluate to 0 instead of failing.
  """
  clamped_divisor = math_ops.maximum(y, 1)
  return x // clamped_divisor
@ops.RegisterGradient("ArgMax")
def _ArgMaxGrad(op, grad):
  """Returns no gradient for either of the two ArgMax inputs."""
  del op, grad  # Neither argument is needed.
  return [None] * 2
@ops.RegisterGradient("ArgMin")
def _ArgMinGrad(op, grad):
  """Returns no gradient for either of the two ArgMin inputs."""
  del op, grad  # Neither argument is needed.
  no_gradients = [None, None]
  return no_gradients
# TODO(rmlarsen): Implement gradient.
# Until a gradient exists, EuclideanNorm is explicitly marked as
# non-differentiable.
ops.NotDifferentiable("EuclideanNorm")
# Module-level empty tuple that `_IsScalar` compares shape tuples against.
_empty_tuple = ()


def _IsScalar(x):
  """Returns whether `x._shape_tuple()` is the empty tuple.

  The comparison is by identity against the module-level `_empty_tuple`, so
  an unknown shape (`None`) or any non-empty shape yields False.
  """
  shape_tuple = x._shape_tuple()  # pylint: disable=protected-access
  return shape_tuple is _empty_tuple
@ops.RegisterGradient("Sum")
def _SumGrad(op, grad):
  """Gradient for Sum.

  Args:
    op: The Sum op; `op.inputs[0]` is the data and `op.inputs[1]` the
      reduction axes.
    grad: Gradient with respect to the op's output.

  Returns:
    A two-element list: the gradient for the data (grad tiled back up to the
    input shape) and None for the reduction axes.
  """
  data = op.inputs[0]
  # Fast path: when the rank is statically known and every dimension is
  # reduced, only Reshape and Tile ops (and possibly a Shape) are added.
  static_shape = data._shape_tuple()  # pylint: disable=protected-access
  static_axes = None
  if static_shape is not None:
    static_axes = tensor_util.constant_value(op.inputs[1])
  if static_axes is not None:
    ndims = len(static_shape)
    if np.array_equal(static_axes, np.arange(ndims)):  # Reduce all dims.
      if context.executing_eagerly():
        # Reuse a cached all-ones shape tensor for this rank when available.
        ctx = context.context()
        ones_shape = ctx.ones_rank_cache().get(ndims)
        if ones_shape is None:
          ones_shape = constant_op.constant([1] * ndims, dtype=dtypes.int32)
          ctx.ones_rank_cache().put(ndims, ones_shape)
      else:
        ones_shape = [1] * ndims
      grad = array_ops.reshape(grad, ones_shape)
      # Fall back to a Shape op when the rank is known but some dimension
      # is not.
      if None in static_shape:
        multiples = array_ops.shape(data)
      else:
        multiples = constant_op.constant(static_shape, dtype=dtypes.int32)
      return [array_ops.tile(grad, multiples), None]

  # General path: derive the tiling from the dynamic shapes.
  dynamic_shape = array_ops.shape(data)
  # TODO(apassos) remove this once device placement for eager ops makes more
  # sense.
  with ops.colocate_with(dynamic_shape):
    kept_dims_shape = math_ops.reduced_shape(dynamic_shape, op.inputs[1])
    multiples = _safe_shape_div(dynamic_shape, kept_dims_shape)
  grad = array_ops.reshape(grad, kept_dims_shape)
  return [array_ops.tile(grad, multiples), None]
def _MinOrMaxGrad(op, grad):
  """Shared gradient for Min and Max, which use exactly the same code.

  Args:
    op: The Min or Max op; `op.inputs[0]` is the data and `op.inputs[1]` the
      reduction axes.
    grad: Gradient with respect to the op's output.

  Returns:
    A two-element list: the gradient for the data and None for the axes.
  """
  data = op.inputs[0]
  axes = op.inputs[1]
  data_shape = array_ops.shape(data)
  kept_dims_shape = math_ops.reduced_shape(data_shape, axes)
  extremum = array_ops.reshape(op.outputs[0], kept_dims_shape)
  grad = array_ops.reshape(grad, kept_dims_shape)

  # Mark every entry equal to the selected extremum. When several entries
  # tie, the gradient is split evenly between them.
  is_extremum = math_ops.cast(math_ops.equal(extremum, data), grad.dtype)
  ties = array_ops.reshape(
      math_ops.reduce_sum(is_extremum, axes), kept_dims_shape)

  return [math_ops.divide(is_extremum, ties) * grad, None]
@ops.RegisterGradient("Max")
def _MaxGrad(op, grad):
  """Gradient for Max; delegates to the shared Min/Max implementation."""
  input_grads = _MinOrMaxGrad(op, grad)
  return input_grads
@ops.RegisterGradient("Min")
def _MinGrad(op, grad):
  """Gradient for Min; delegates to the shared Min/Max implementation."""
  input_grads = _MinOrMaxGrad(op, grad)
  return input_grads
@ops.RegisterGradient("Mean")
def _MeanGrad(op, grad):
  """Gradient for Mean.

  Computes the Sum gradient and divides it by the number of input elements
  averaged into each output element.

  Args:
    op: The Mean op.
    grad: Gradient with respect to the op's output.

  Returns:
    A pair `(gradient for the data, None)`.
  """
  sum_grad = _SumGrad(op, grad)[0]
  in_static = op.inputs[0]._shape_tuple()  # pylint: disable=protected-access
  out_static = op.outputs[0]._shape_tuple()  # pylint: disable=protected-access
  shapes_fully_known = (
      in_static is not None and out_static is not None and
      None not in in_static and None not in out_static)
  if not shapes_fully_known:
    # Some dimension is unknown: compute the element-count ratio in the graph.
    in_dynamic = array_ops.shape(op.inputs[0])
    out_dynamic = array_ops.shape(op.outputs[0])
    factor = _safe_shape_div(
        math_ops.reduce_prod(in_dynamic), math_ops.reduce_prod(out_dynamic))
  else:
    # Everything is static: fold the ratio into a constant.
    num_in = np.prod(in_static)
    num_out = np.prod(out_static)
    factor = constant_op.constant(
        num_in // max(num_out, 1), dtype=sum_grad.dtype)
  return math_ops.truediv(sum_grad, math_ops.cast(factor, sum_grad.dtype)), None
@ops.RegisterGradient("Prod")
def _ProdGrad(op, grad):
  """Gradient for Prod.

  Dividing the product by each input entry cannot handle zeros in the input,
  so the per-entry factor is instead composed as the product of a forward and
  a reverse exclusive cumprod over the reduced dimensions.

  Args:
    op: The Prod op; `op.inputs[0]` is the data and `op.inputs[1]` the
      reduction indices.
    grad: Gradient with respect to the op's output.

  Returns:
    A pair `(gradient for the data, None)`.
  """
  data = op.inputs[0]
  data_shape = array_ops.shape(data)
  # The reduction indices may be given as a scalar; flatten them to a vector.
  axes = array_ops.reshape(op.inputs[1], [-1])

  # Broadcast grad back up to the full input shape.
  kept_dims_shape = math_ops.reduced_shape(data_shape, op.inputs[1])
  multiples = _safe_shape_div(data_shape, kept_dims_shape)
  grad = array_ops.tile(array_ops.reshape(grad, kept_dims_shape), multiples)

  # Collapse all reduced dimensions into one so cumprod can run over it. An
  # empty reduction-dims list defaults to float32, hence the cast. The
  # shape-related ops are pinned to the CPU to avoid copying back and forth,
  # and because listdiff is CPU only.
  with ops.device("/cpu:0"):
    ndims = array_ops.rank(data)
    axes = (axes + ndims) % ndims
    reduced_axes = math_ops.cast(axes, dtypes.int32)
    all_axes = math_ops.range(0, ndims)
    kept_axes, _ = array_ops.setdiff1d(all_axes, reduced_axes)
    perm = array_ops.concat([reduced_axes, kept_axes], 0)
    reduced_size = math_ops.reduce_prod(
        array_ops.gather(data_shape, reduced_axes))
    kept_size = math_ops.reduce_prod(array_ops.gather(data_shape, kept_axes))
  permuted = array_ops.transpose(data, perm)
  permuted_shape = array_ops.shape(permuted)
  flat = array_ops.reshape(permuted, (reduced_size, kept_size))

  # Product of all entries except the current one.
  prefix = math_ops.cumprod(flat, axis=0, exclusive=True)
  suffix = math_ops.cumprod(flat, axis=0, exclusive=True, reverse=True)
  # For complex inputs, the gradient is in the conjugate direction.
  factors = array_ops.reshape(
      math_ops.conj(prefix) * math_ops.conj(suffix), permuted_shape)

  # Undo the transpose; the final reshape restores the statically known
  # shape information.
  unpermuted = grad * array_ops.transpose(
      factors, array_ops.invert_permutation(perm))
  return array_ops.reshape(unpermuted, data_shape), None
@ops.RegisterGradient("SegmentSum")
def _SegmentSumGrad(op, grad):
  """Gradient for SegmentSum: gathers grad rows by segment id."""
  segment_ids = op.inputs[1]
  data_grad = array_ops.gather(grad, segment_ids)
  return data_grad, None
@ops.RegisterGradient("SegmentMean")
def _SegmentMeanGrad(op, grad):
  """Gradient for SegmentMean.

  Divides grad by a per-segment count (a segment sum of ones) and gathers the
  result back by segment id.
  """
  segment_ids = op.inputs[1]
  data_rank = array_ops.rank(op.inputs[0])
  # Shape of the ones tensor: shape(segment_ids) followed by `rank - 1` ones.
  ids_shape = array_ops.shape(segment_ids)
  trailing_ones = array_ops.fill(array_ops.expand_dims(data_rank - 1, 0), 1)
  ones_shape = array_ops.concat([ids_shape, trailing_ones], 0)
  ones = array_ops.fill(ones_shape, constant_op.constant(1, dtype=grad.dtype))
  segment_counts = math_ops.segment_sum(ones, segment_ids)
  averaged_grad = math_ops.divide(grad, segment_counts)
  return array_ops.gather(averaged_grad, segment_ids), None
@ops.RegisterGradient("SparseSegmentSum")
def _SparseSegmentSumGrad(op, grad):
  """Gradient for SparseSegmentSum.

  Gathers grad by `op.inputs[2]` and scatter-adds it by `op.inputs[1]` into
  as many rows as the data input has. Only the data input gets a gradient.
  """
  num_rows = array_ops.shape(op.inputs[0])[0]
  gathered_grad = array_ops.gather(grad, op.inputs[2])
  data_grad = math_ops.unsorted_segment_sum(
      gathered_grad, op.inputs[1], num_rows)
  return (data_grad, None, None)
@ops.RegisterGradient("SparseSegmentSumWithNumSegments")
def _SparseSegmentSumWithNumSegmentsGrad(op, grad):
  """Gradient for SparseSegmentSumWithNumSegments.

  Same computation as the SparseSegmentSum gradient, with one extra None for
  the fourth input.
  """
  num_rows = array_ops.shape(op.inputs[0])[0]
  gathered_grad = array_ops.gather(grad, op.inputs[2])
  data_grad = math_ops.unsorted_segment_sum(
      gathered_grad, op.inputs[1], num_rows)
  return (data_grad, None, None, None)
@ops.RegisterGradient("SparseSegmentMean")
def _SparseSegmentMeanGrad(op, grad):
  """Gradient for SparseSegmentMean; only the data input gets a gradient."""
  num_rows = array_ops.shape(op.inputs[0])[0]
  data_grad = math_ops.sparse_segment_mean_grad(
      grad, op.inputs[1], op.inputs[2], num_rows)
  return (data_grad, None, None)
@ops.RegisterGradient("SparseSegmentMeanWithNumSegments")
def _SparseSegmentMeanWithNumSegmentsGrad(op, grad):
  """Gradient for SparseSegmentMeanWithNumSegments (data input only)."""
  num_rows = array_ops.shape(op.inputs[0])[0]
  data_grad = math_ops.sparse_segment_mean_grad(
      grad, op.inputs[1], op.inputs[2], num_rows)
  return (data_grad, None, None, None)
@ops.RegisterGradient("SparseSegmentSqrtN")
def _SparseSegmentSqrtNGrad(op, grad):
  """Gradient for SparseSegmentSqrtN; only the data input gets a gradient."""
  num_rows = array_ops.shape(op.inputs[0])[0]
  data_grad = math_ops.sparse_segment_sqrt_n_grad(
      grad, op.inputs[1], op.inputs[2], num_rows)
  return (data_grad, None, None)
@ops.RegisterGradient("SparseSegmentSqrtNWithNumSegments")
def _SparseSegmentSqrtNWithNumSegmentsGrad(op, grad):
  """Gradient for SparseSegmentSqrtNWithNumSegments (data input only)."""
  num_rows = array_ops.shape(op.inputs[0])[0]
  data_grad = math_ops.sparse_segment_sqrt_n_grad(
      grad, op.inputs[1], op.inputs[2], num_rows)
  return (data_grad, None, None, None)
def _SegmentMinOrMaxGrad(op, grad):
  """Shared gradient for SegmentMin and SegmentMax.

  Args:
    op: The segment op; `op.inputs[0]` is the data and `op.inputs[1]` the
      segment ids.
    grad: Gradient with respect to the op's output.

  Returns:
    A pair `(gradient for the data, None)`.
  """
  data = op.inputs[0]
  segment_ids = op.inputs[1]
  zero_grads = array_ops.zeros_like(data, dtype=data.dtype)

  # Count the selected (minimum or maximum) elements in each segment.
  segment_extrema = array_ops.gather(op.outputs[0], segment_ids)
  is_selected = math_ops.equal(data, segment_extrema)
  selected_counts = math_ops.segment_sum(
      math_ops.cast(is_selected, grad.dtype), segment_ids)

  # Each segment's gradient is shared evenly among its selected elements.
  shared_grads = math_ops.divide(grad, selected_counts)
  per_element_grads = array_ops.gather(shared_grads, segment_ids)
  return array_ops.where(is_selected, per_element_grads, zero_grads), None
@ops.RegisterGradient("SegmentMin")
def _SegmentMinGrad(op, grad):
  """Gradient for SegmentMin; delegates to the shared Min/Max helper."""
  input_grads = _SegmentMinOrMaxGrad(op, grad)
  return input_grads
@ops.RegisterGradient("SegmentMax")
def _SegmentMaxGrad(op, grad):
  """Gradient for SegmentMax; delegates to the shared Min/Max helper."""
  input_grads = _SegmentMinOrMaxGrad(op, grad)
  return input_grads
def _GatherDropNegatives(params,
                         ids,
                         zero_clipped_indices=None,
                         is_positive=None):
  """Gathers `params` by `ids`, substituting zeros for negative ids.

  Helper for the unsorted segment gradients. The clipped indices and the
  broadcast mask are returned so that a later call can pass them back in and
  skip recomputing them.

  Args:
    params: Tensor to gather from.
    ids: Segment ids; only read when `zero_clipped_indices` or `is_positive`
      has to be computed.
    zero_clipped_indices: Optional precomputed `max(ids, 0)`.
    is_positive: Optional precomputed mask, already shaped like the gathered
      result.

  Returns:
    A tuple `(gathered, zero_clipped_indices, is_positive)` where `gathered`
    is zero wherever `is_positive` is false.
  """
  if zero_clipped_indices is None:
    zero_clipped_indices = math_ops.maximum(ids, array_ops.zeros_like(ids))
  gathered = array_ops.gather(params, zero_clipped_indices)
  if is_positive is None:
    is_positive = math_ops.greater_equal(ids, 0)
    # tf.where(condition, x, y) requires condition to have the same shape as
    # x and y, so append trailing unit dimensions and broadcast the mask.
    # todo(philjd): remove this if tf.where supports broadcasting (#9284)
    missing_dims = gathered.shape.ndims - is_positive.shape.ndims
    for _ in range(missing_dims):
      is_positive = array_ops.expand_dims(is_positive, -1)
    is_positive = (
        is_positive & array_ops.ones_like(gathered, dtype=dtypes.bool))
  # Entries gathered for negative ids are replaced with 0.
  zeros = array_ops.zeros_like(gathered)
  masked = array_ops.where(is_positive, gathered, zeros)
  return (masked, zero_clipped_indices, is_positive)
def _UnsortedSegmentMinOrMaxGrad(op, grad):
  """Shared gradient for UnsortedSegmentMin and UnsortedSegmentMax.

  Args:
    op: The segment op; inputs are the data, the segment ids and the number
      of segments.
    grad: Gradient with respect to the op's output.

  Returns:
    A triple `(gradient for the data, None, None)`.
  """
  data = op.inputs[0]
  segment_ids = op.inputs[1]
  # Count the selected (minimum or maximum) elements in each segment,
  # ignoring entries whose segment id is negative.
  segment_extrema, clipped_ids, is_positive = _GatherDropNegatives(
      op.outputs[0], segment_ids)
  is_selected = math_ops.equal(data, segment_extrema)
  is_selected = math_ops.logical_and(is_selected, is_positive)
  selected_counts = math_ops.unsorted_segment_sum(
      math_ops.cast(is_selected, grad.dtype), segment_ids, op.inputs[2])

  # Each segment's gradient is shared evenly among its selected elements.
  shared_grads = math_ops.divide(grad, selected_counts)
  per_element_grads, _, _ = _GatherDropNegatives(shared_grads, None,
                                                 clipped_ids, is_positive)
  zero_grads = array_ops.zeros_like(per_element_grads)
  return array_ops.where(is_selected, per_element_grads, zero_grads), None, None
@ops.RegisterGradient("UnsortedSegmentSum")
def _UnsortedSegmentSumGrad(op, grad):
  """Gradient for UnsortedSegmentSum: grad gathered by id, 0 for negatives."""
  gathered_grad = _GatherDropNegatives(grad, op.inputs[1])[0]
  return gathered_grad, None, None
@ops.RegisterGradient("UnsortedSegmentMax")
def _UnsortedSegmentMaxGrad(op, grad):
  """Gradient for UnsortedSegmentMax; delegates to the shared helper."""
  input_grads = _UnsortedSegmentMinOrMaxGrad(op, grad)
  return input_grads
@ops.RegisterGradient("UnsortedSegmentMin")
def _UnsortedSegmentMinGrad(op, grad):
  """Gradient for UnsortedSegmentMin; delegates to the shared helper."""
  input_grads = _UnsortedSegmentMinOrMaxGrad(op, grad)
  return input_grads
@ops.RegisterGradient("UnsortedSegmentProd")
def _UnsortedSegmentProdGrad(op, grad):
  """Gradient for UnsortedSegmentProd.

  Per segment, the gradient could be written as the segment's product divided
  by each element, but that breaks on zeros in the input. The cumprod
  construction used by the Prod gradient does not apply either, because
  segments may contain different numbers of elements. Three cases are
  therefore distinguished:
    1) The segment has no zeros: dividing by the input is safe.
    2) The segment has exactly one zero: every gradient in the segment is
       zero except at the zero entry, where it is the product of the
       remaining entries.
    3) The segment has two or more zeros: the gradient is zero throughout.

  Args:
    op: The UnsortedSegmentProd op; inputs are the data, the segment ids and
      the number of segments.
    grad: Gradient with respect to the op's output.

  Returns:
    A triple `(gradient for the data, None, None)`.
  """
  data = op.inputs[0]
  segment_ids = op.inputs[1]
  num_segments = op.inputs[2]

  # unsorted_segment_sum already filters out negative indices, so no
  # logical_and with an is-positive mask is needed here.
  is_zero = math_ops.equal(data, 0)
  zeros_per_segment = gen_math_ops.unsorted_segment_sum(
      math_ops.cast(is_zero, dtype=dtypes.int32), segment_ids, num_segments)
  # Case 3: segments with more than one zero get a zero gradient.
  grad = array_ops.where(
      math_ops.greater(zeros_per_segment, 1), array_ops.zeros_like(grad), grad)

  # Segment products with every zero replaced by one.
  data_without_zeros = array_ops.where(is_zero, array_ops.ones_like(data),
                                       data)
  prod_without_zeros = gen_math_ops.unsorted_segment_prod(
      data_without_zeros, segment_ids, num_segments)

  # Clip the indices so that gather never sees a negative one.
  clipped_ids = math_ops.maximum(segment_ids,
                                 array_ops.zeros_like(segment_ids))
  gathered_prod = array_ops.gather(op.outputs[0], clipped_ids)
  gathered_prod_without_zeros = array_ops.gather(prod_without_zeros,
                                                 clipped_ids)
  prod_over_element = gathered_prod / data  # May contain nan/inf.

  # Pick case 2 at zero entries and case 1 elsewhere. Entries with a negative
  # index are also picked up here, but the gather below zeroes their grad.
  partial_derivative = array_ops.where(is_zero, gathered_prod_without_zeros,
                                       prod_over_element)
  gathered_grad = _GatherDropNegatives(grad, segment_ids, clipped_ids)[0]
  return gathered_grad * partial_derivative, None, None
@ops.RegisterGradient("Abs")
def _AbsGrad(op, grad):
  """Returns grad * sign(x)."""
  sign_x = math_ops.sign(op.inputs[0])
  return grad * sign_x
@ops.RegisterGradient("Neg")
def _NegGrad(_, grad):
  """Gradient for Neg: the negated incoming gradient."""
  negated = -grad
  return negated
@ops.RegisterGradient("Inv")
def _InvGrad(op, grad):
  """Returns -grad * (1 / x^2), computed from the op output y = 1 / x."""
  reciprocal_x = op.outputs[0]
  return gen_math_ops.reciprocal_grad(reciprocal_x, grad)
@ops.RegisterGradient("Reciprocal")
def _ReciprocalGrad(op, grad):
  """Returns -grad * (1 / x^2), computed from the op output y = 1 / x."""
  reciprocal_x = op.outputs[0]
  return gen_math_ops.reciprocal_grad(reciprocal_x, grad)
@ops.RegisterGradient("InvGrad")
def _InvGradGrad(op, grad):
  """Gradients of InvGrad with respect to its two inputs `a` and `b`.

  The op's output is op.outputs[0]: y = -b * conj(a)^2.
  """
  b = op.inputs[1]
  with ops.control_dependencies([grad]):
    conj_a = math_ops.conj(op.inputs[0])
    conj_grad = math_ops.conj(grad)
    grad_a = conj_grad * -2.0 * b * conj_a
    grad_b = gen_math_ops.reciprocal_grad(conj_a, grad)
    return grad_a, grad_b
@ops.RegisterGradient("ReciprocalGrad")
def _ReciprocalGradGrad(op, grad):
  """Gradients of ReciprocalGrad with respect to its inputs `a` and `b`.

  The op's output is op.outputs[0]: y = -b * conj(a)^2.
  """
  b = op.inputs[1]
  with ops.control_dependencies([grad]):
    conj_a = math_ops.conj(op.inputs[0])
    conj_grad = math_ops.conj(grad)
    grad_a = conj_grad * -2.0 * b * conj_a
    grad_b = gen_math_ops.reciprocal_grad(conj_a, grad)
    return grad_a, grad_b
@ops.RegisterGradient("Square")
def _SquareGrad(op, grad):
  """Returns grad * 2 * conj(x)."""
  x = op.inputs[0]
  # The control dependency keeps 2*x from being computed too early.
  with ops.control_dependencies([grad]):
    conj_x = math_ops.conj(x)
    two = constant_op.constant(2.0, dtype=conj_x.dtype)
    doubled = math_ops.multiply(conj_x, two)
    return math_ops.multiply(grad, doubled)
@ops.RegisterGradient("Sqrt")
def _SqrtGrad(op, grad):
  """Gradient for Sqrt, computed from the op output y = x^(1/2)."""
  sqrt_x = op.outputs[0]
  return gen_math_ops.sqrt_grad(sqrt_x, grad)
@ops.RegisterGradient("SqrtGrad")
def _SqrtGradGrad(op, grad):
  """Gradients of SqrtGrad, y = 0.5 * b / conj(a), w.r.t. `a` and `b`."""
  a = op.inputs[0]
  y = op.outputs[0]  # y = 0.5 * b / conj(a)
  with ops.control_dependencies([grad]):
    if not compat.forward_compatible(2019, 9, 14):
      # Path used before the forward-compatibility date.
      ga = grad / a
      return -math_ops.conj(ga) * y, 0.5 * ga
    ga = gen_math_ops.xdivy(grad, a)
    return -gen_math_ops.mul_no_nan(y, math_ops.conj(ga)), 0.5 * ga
@ops.RegisterGradient("Rsqrt")
def _RsqrtGrad(op, grad):
  """Returns -0.5 * grad * conj(y)^3, where y = x^(-1/2) is the op output."""
  rsqrt_x = op.outputs[0]
  return gen_math_ops.rsqrt_grad(rsqrt_x, grad)
@ops.RegisterGradient("RsqrtGrad")
def _RsqrtGradGrad(op, grad):
  """Returns backprop gradient for f(a,b) = -0.5 * b * conj(a)^3.

  Here `a = x^{-1/2}` is the first input and `b` is the backprop gradient
  for `a`.
  """
  a, b = op.inputs[0], op.inputs[1]
  with ops.control_dependencies([grad]):
    conj_a = math_ops.conj(a)
    conj_grad = math_ops.conj(grad)
    wrt_a = -1.5 * conj_grad * b * math_ops.square(conj_a)
    wrt_b = gen_math_ops.rsqrt_grad(conj_a, grad)
    return wrt_a, wrt_b
@ops.RegisterGradient("Exp")
def _ExpGrad(op, grad):
  """Returns grad * exp(x), reusing the op output y = e^x."""
  exp_x = op.outputs[0]
  with ops.control_dependencies([grad]):
    conj_y = math_ops.conj(exp_x)
    if not compat.forward_compatible(2019, 9, 14):
      return grad * conj_y
    return math_ops.mul_no_nan(conj_y, grad)
@ops.RegisterGradient("Expm1")
def _Expm1Grad(op, grad):
  """Returns grad * exp(x)."""
  with ops.control_dependencies([grad]):
    conj_x = math_ops.conj(op.inputs[0])
    exp_x = math_ops.exp(conj_x)
    if not compat.forward_compatible(2019, 9, 14):
      return grad * exp_x
    return math_ops.mul_no_nan(exp_x, grad)
@ops.RegisterGradient("Log")
def _LogGrad(op, grad):
  """Returns grad * (1/x)."""
  with ops.control_dependencies([grad]):
    conj_x = math_ops.conj(op.inputs[0])
    if not compat.forward_compatible(2019, 9, 14):
      return grad * math_ops.reciprocal(conj_x)
    return gen_math_ops.xdivy(grad, conj_x)
@ops.RegisterGradient("Log1p")
def _Log1pGrad(op, grad):
  """Returns grad * (1/(1 + x))."""
  with ops.control_dependencies([grad]):
    conj_x = math_ops.conj(op.inputs[0])
    if not compat.forward_compatible(2019, 9, 14):
      return grad * math_ops.reciprocal(1 + conj_x)
    return gen_math_ops.xdivy(grad, 1 + conj_x)
@ops.RegisterGradient("Xlogy")
def _XLogyGrad(op, grad):
  """Returns gradient of xlogy(x, y) with respect to x and y.

  Each partial is multiplied by grad, summed over the broadcast dimensions
  and reshaped back to the shape of the corresponding input.
  """
  x, y = op.inputs[0], op.inputs[1]
  x_shape = array_ops.shape(x)
  y_shape = array_ops.shape(y)
  x_axes, y_axes = gen_array_ops.broadcast_gradient_args(x_shape, y_shape)
  with ops.control_dependencies([grad]):
    # 1 where x is non-zero and 0 where it is zero, in x's dtype.
    zero = math_ops.cast(0., dtype=x.dtype)
    x_nonzero = math_ops.cast(math_ops.not_equal(x, zero), dtype=x.dtype)
    dx = gen_math_ops.xlogy(x_nonzero, y)
    dy = gen_math_ops.xdivy(x, y)
    grad_x = array_ops.reshape(math_ops.reduce_sum(dx * grad, x_axes), x_shape)
    grad_y = array_ops.reshape(math_ops.reduce_sum(dy * grad, y_axes), y_shape)
    return (grad_x, grad_y)
@ops.RegisterGradient("Xdivy")
def _XDivyGrad(op, grad):
  """Returns gradient of xdivy(x, y) with respect to x and y.

  Each partial is multiplied by grad, summed over the broadcast dimensions
  and reshaped back to the shape of the corresponding input.
  """
  x, y = op.inputs[0], op.inputs[1]
  x_shape = array_ops.shape(x)
  y_shape = array_ops.shape(y)
  x_axes, y_axes = gen_array_ops.broadcast_gradient_args(x_shape, y_shape)
  with ops.control_dependencies([grad]):
    # 1 where x is non-zero and 0 where it is zero, in x's dtype.
    zero = math_ops.cast(0., dtype=x.dtype)
    x_nonzero = math_ops.cast(math_ops.not_equal(x, zero), dtype=x.dtype)
    dx = gen_math_ops.xdivy(x_nonzero, y)
    dy = gen_math_ops.xdivy(math_ops.negative(x), y**2)
    grad_x = array_ops.reshape(math_ops.reduce_sum(dx * grad, x_axes), x_shape)
    grad_y = array_ops.reshape(math_ops.reduce_sum(dy * grad, y_axes), y_shape)
    return (grad_x, grad_y)
@ops.RegisterGradient("Sinh")
def _SinhGrad(op, grad):
  """Returns grad * cosh(x)."""
  with ops.control_dependencies([grad]):
    conj_x = math_ops.conj(op.inputs[0])
    cosh_x = math_ops.cosh(conj_x)
    return grad * cosh_x
@ops.RegisterGradient("Cosh")
def _CoshGrad(op, grad):
  """Returns grad * sinh(x)."""
  with ops.control_dependencies([grad]):
    conj_x = math_ops.conj(op.inputs[0])
    sinh_x = math_ops.sinh(conj_x)
    return grad * sinh_x
@ops.RegisterGradient("Tanh")
def _TanhGrad(op, grad):
  """Returns grad * (1 - tanh(x) * tanh(x)), from the output y = tanh(x)."""
  tanh_x = op.outputs[0]
  with ops.control_dependencies([grad]):
    conj_y = math_ops.conj(tanh_x)
    return gen_math_ops.tanh_grad(conj_y, grad)
@ops.RegisterGradient("Asinh")
def _AsinhGrad(op, grad):
  """Returns grad * 1/cosh(y), where y is the op output."""
  with ops.control_dependencies([grad]):
    conj_y = math_ops.conj(op.outputs[0])
    cosh_y = math_ops.cosh(conj_y)
    return grad / cosh_y
@ops.RegisterGradient("Acosh")
def _AcoshGrad(op, grad):
  """Returns grad * 1/sinh(y), where y is the op output."""
  with ops.control_dependencies([grad]):
    conj_y = math_ops.conj(op.outputs[0])
    if not compat.forward_compatible(2019, 9, 14):
      return grad / math_ops.sinh(conj_y)
    return math_ops.xdivy(grad, math_ops.sinh(conj_y))
@ops.RegisterGradient("Atanh")
def _AtanhGrad(op, grad):
  """Returns grad * 1/ (1 - x^2)."""
  with ops.control_dependencies([grad]):
    conj_x = math_ops.conj(op.inputs[0])
    x_squared = math_ops.square(conj_x)
    one = constant_op.constant(1, dtype=grad.dtype)
    denominator = math_ops.subtract(one, x_squared)
    return grad * math_ops.reciprocal(denominator)
@ops.RegisterGradient("TanhGrad")
def _TanhGradGrad(op, grad):
  """Gradients of TanhGrad with respect to its two inputs."""
  with ops.control_dependencies([grad]):
    conj_a = math_ops.conj(op.inputs[0])
    conj_b = math_ops.conj(op.inputs[1])
    wrt_a = grad * -2.0 * conj_b * conj_a
    wrt_b = gen_math_ops.tanh_grad(conj_a, grad)
    return wrt_a, wrt_b
@ops.RegisterGradient("Erf")
def _ErfGrad(op, grad):
  """Returns grad * 2/sqrt(pi) * exp(-x**2)."""
  x = op.inputs[0]
  scale = constant_op.constant(2 / np.sqrt(np.pi), dtype=grad.dtype)
  with ops.control_dependencies([grad]):
    conj_x = math_ops.conj(x)
    return grad * scale * math_ops.exp(-math_ops.square(conj_x))
@ops.RegisterGradient("Erfc")
def _ErfcGrad(op, grad):
  """Returns -grad * 2/sqrt(pi) * exp(-x**2)."""
  x = op.inputs[0]
  negative_scale = constant_op.constant(
      -2 / np.sqrt(np.pi), dtype=grad.dtype)
  with ops.control_dependencies([grad]):
    conj_x = math_ops.conj(x)
    return grad * negative_scale * math_ops.exp(-math_ops.square(conj_x))
@ops.RegisterGradient("Lgamma")
def _LgammaGrad(op, grad):
  """Returns grad * digamma(x)."""
  with ops.control_dependencies([grad]):
    conj_x = math_ops.conj(op.inputs[0])
    if not compat.forward_compatible(2019, 9, 14):
      return grad * math_ops.digamma(conj_x)
    return math_ops.mul_no_nan(math_ops.digamma(conj_x), grad)
@ops.RegisterGradient("Digamma")
def _DigammaGrad(op, grad):
  """Compute gradient of the digamma function with respect to its argument.

  The partial derivative is polygamma(1, x), evaluated at the conjugate of
  the op input.

  Args:
    op: The Digamma op; `op.inputs[0]` is its argument.
    grad: Gradient with respect to the op's output.

  Returns:
    The gradient with respect to the op input.
  """
  x = op.inputs[0]
  with ops.control_dependencies([grad]):
    x = math_ops.conj(x)
    # Build the order-1 constant through constant_op, like every other
    # gradient in this file, instead of via array_ops.
    order_one = constant_op.constant(1, dtype=x.dtype)
    partial_x = math_ops.polygamma(order_one, x)
    if compat.forward_compatible(2019, 9, 14):
      return math_ops.mul_no_nan(partial_x, grad)
    else:
      return grad * partial_x
@ops.RegisterGradient("BesselI0e")
def _BesselI0eGrad(op, grad):
  """Compute gradient of bessel_i0e(x) with respect to its argument.

  The partial derivative is bessel_i1e(x) - sign(x) * y, where y is the op
  output.
  """
  x = op.inputs[0]
  i0e_x = op.outputs[0]
  with ops.control_dependencies([grad]):
    dy_dx = (math_ops.bessel_i1e(x) - math_ops.sign(x) * i0e_x)
    if not compat.forward_compatible(2019, 9, 14):
      return grad * dy_dx
    return math_ops.mul_no_nan(dy_dx, grad)
@ops.RegisterGradient("BesselI1e")
def _BesselI1eGrad(op, grad):
  """Compute gradient of bessel_i1e(x) with respect to its argument.

  Inputs whose magnitude does not exceed the dtype's machine epsilon are
  treated as zero, and their derivative is set to 0.5 directly.
  """
  x = op.inputs[0]
  i1e_x = op.outputs[0]
  with ops.control_dependencies([grad]):
    # The correct gradient at x = 0 is 0.5, but the general formula divides
    # by x and yields NaN there, so that value is imputed by hand. Expressing
    # the gradient through bessel_i0e and bessel_i2e would avoid this, but
    # the latter is not yet implemented in Eigen.
    eps = np.finfo(x.dtype.as_numpy_dtype).eps
    zeros = array_ops.zeros_like(x)
    is_safe = math_ops.abs(x) > eps
    guarded_x = array_ops.where(is_safe, x, eps + zeros)
    general = math_ops.bessel_i0e(guarded_x) - i1e_x * (
        math_ops.sign(guarded_x) + math_ops.reciprocal(guarded_x))
    dy_dx = array_ops.where(is_safe, general, 0.5 + zeros)
    if not compat.forward_compatible(2019, 9, 14):
      return grad * dy_dx
    return math_ops.mul_no_nan(dy_dx, grad)
@ops.RegisterGradient("Igamma")
def _IgammaGrad(op, grad):
  """Returns gradient of igamma(a, x) with respect to a and x.

  Each partial is multiplied by grad, summed over the broadcast dimensions
  and reshaped back to the shape of the corresponding input.
  """
  a, x = op.inputs[0], op.inputs[1]
  a_shape = array_ops.shape(a)
  x_shape = array_ops.shape(x)
  a_axes, x_axes = gen_array_ops.broadcast_gradient_args(a_shape, x_shape)
  with ops.control_dependencies([grad]):
    da = gen_math_ops.igamma_grad_a(a, x)
    # Work in log space before summing, because Gamma(a) and Gamma'(a) can
    # grow large.
    dx = math_ops.exp(-x + (a - 1) * math_ops.log(x) - math_ops.lgamma(a))
    if not compat.forward_compatible(2019, 9, 14):
      return (array_ops.reshape(math_ops.reduce_sum(da * grad, a_axes),
                                a_shape),
              array_ops.reshape(math_ops.reduce_sum(dx * grad, x_axes),
                                x_shape))
    grad_a = array_ops.reshape(
        math_ops.reduce_sum(math_ops.mul_no_nan(da, grad), a_axes), a_shape)
    grad_x = array_ops.reshape(
        math_ops.reduce_sum(math_ops.mul_no_nan(dx, grad), x_axes), x_shape)
    return (grad_a, grad_x)
@ops.RegisterGradient("Igammac")
def _IgammacGrad(op, grad):
  """Returns gradient of igammac(a, x) = 1 - igamma(a, x) w.r.t. a and x.

  This is the Igamma gradient with both components negated.
  """
  grad_a, grad_x = _IgammaGrad(op, grad)
  return (-grad_a, -grad_x)
@ops.RegisterGradient("Betainc")
def _BetaincGrad(op, grad):
  """Returns gradient of betainc(a, b, x) with respect to x.

  No gradient is returned for `a` or `b`.
  """
  # TODO(ebrevdo): Perhaps add the derivative w.r.t. a, b
  a, b, x = op.inputs

  # Either x is a scalar and a/b are same-shaped tensors, or vice versa, so
  # checking against shape(a) is sufficient.
  a_shape = array_ops.shape(a)
  x_shape = array_ops.shape(x)
  _, x_axes = gen_array_ops.broadcast_gradient_args(a_shape, x_shape)

  # Work in log space before summing, because the terms can grow large.
  log_beta = (
      gen_math_ops.lgamma(a) + gen_math_ops.lgamma(b) -
      gen_math_ops.lgamma(a + b))
  dx = math_ops.exp((b - 1) * math_ops.log(1 - x) +
                    (a - 1) * math_ops.log(x) - log_beta)

  # TODO(b/36815900): Mark None return values as NotImplemented
  if not compat.forward_compatible(2019, 9, 14):
    weighted = dx * grad
  else:
    weighted = math_ops.mul_no_nan(dx, grad)
  grad_x = array_ops.reshape(math_ops.reduce_sum(weighted, x_axes), x_shape)
  return (
      None,  # da
      None,  # db
      grad_x)
@ops.RegisterGradient("Zeta")
def _ZetaGrad(op, grad):
  """Returns gradient of zeta(x, q) with respect to x and q.

  Only the gradient for `q` is computed; None is returned for `x`.
  """
  # TODO(tillahoffmann): Add derivative with respect to x
  x, q = op.inputs[0], op.inputs[1]
  # Reduction axes that undo broadcasting.
  x_shape = array_ops.shape(x)
  q_shape = array_ops.shape(q)
  unused_x_axes, q_axes = gen_array_ops.broadcast_gradient_args(
      x_shape, q_shape)
  with ops.control_dependencies([grad]):
    conj_x = math_ops.conj(x)
    conj_q = math_ops.conj(q)
    dq = -conj_x * math_ops.zeta(conj_x + 1, conj_q)
    # TODO(b/36815900): Mark None return values as NotImplemented
    if not compat.forward_compatible(2019, 9, 14):
      weighted = dq * grad
    else:
      weighted = math_ops.mul_no_nan(dq, grad)
    grad_q = array_ops.reshape(math_ops.reduce_sum(weighted, q_axes), q_shape)
    return (None, grad_q)
@ops.RegisterGradient("Polygamma")
def _PolygammaGrad(op, grad):
  """Returns gradient of psi(n, x) with respect to n and x.

  Only the gradient for `x` is computed; None is returned for `n`.
  """
  # TODO(tillahoffmann): Add derivative with respect to n
  n, x = op.inputs[0], op.inputs[1]
  # Reduction axes that undo broadcasting.
  n_shape = array_ops.shape(n)
  x_shape = array_ops.shape(x)
  unused_n_axes, x_axes = gen_array_ops.broadcast_gradient_args(
      n_shape, x_shape)
  with ops.control_dependencies([grad]):
    conj_n = math_ops.conj(n)
    conj_x = math_ops.conj(x)
    dx = math_ops.polygamma(conj_n + 1, conj_x)
    # TODO(b/36815900): Mark None return values as NotImplemented
    if not compat.forward_compatible(2019, 9, 14):
      weighted = dx * grad
    else:
      weighted = math_ops.mul_no_nan(dx, grad)
    grad_x = array_ops.reshape(math_ops.reduce_sum(weighted, x_axes), x_shape)
    return (None, grad_x)
@ops.RegisterGradient("Sigmoid")
def _SigmoidGrad(op, grad):
  """Returns grad * sigmoid(x) * (1 - sigmoid(x)), from y = sigmoid(x)."""
  sigmoid_x = op.outputs[0]
  with ops.control_dependencies([grad]):
    conj_y = math_ops.conj(sigmoid_x)
    return gen_math_ops.sigmoid_grad(conj_y, grad)
@ops.RegisterGradient("SigmoidGrad")
def _SigmoidGradGrad(op, grad):
  """Gradients of SigmoidGrad with respect to its two inputs."""
  with ops.control_dependencies([grad]):
    conj_a = math_ops.conj(op.inputs[0])
    conj_b = math_ops.conj(op.inputs[1])
    grad_times_b = grad * conj_b
    wrt_a = grad_times_b - 2.0 * grad_times_b * conj_a
    wrt_b = gen_math_ops.sigmoid_grad(conj_a, grad)
    return wrt_a, wrt_b
@ops.RegisterGradient("Sign")
def _SignGrad(op, _):
  """Returns zeros with the shape and dtype of the op input."""
  x = op.inputs[0]
  x_shape = array_ops.shape(x)
  return array_ops.zeros(x_shape, dtype=x.dtype)
@ops.RegisterGradient("Sin")
def _SinGrad(op, grad):
  """Returns grad * cos(x)."""
  with ops.control_dependencies([grad]):
    conj_x = math_ops.conj(op.inputs[0])
    cos_x = math_ops.cos(conj_x)
    return grad * cos_x
@ops.RegisterGradient("Cos")
def _CosGrad(op, grad):
  """Returns grad * -sin(x)."""
  with ops.control_dependencies([grad]):
    conj_x = math_ops.conj(op.inputs[0])
    negated_grad = -grad
    return negated_grad * math_ops.sin(conj_x)
@ops.RegisterGradient("Tan")
def _TanGrad(op, grad):
  """Returns grad * sec^2(x), i.e. grad * 1/cos^2(x)."""
  with ops.control_dependencies([grad]):
    conj_x = math_ops.conj(op.inputs[0])
    sec_x = math_ops.reciprocal(math_ops.cos(conj_x))
    sec_x_squared = math_ops.square(sec_x)
    if not compat.forward_compatible(2019, 9, 14):
      return sec_x_squared * grad
    return math_ops.mul_no_nan(sec_x_squared, grad)
@ops.RegisterGradient("Asin")
def _AsinGrad(op, grad):
  """Returns grad * 1/sqrt(1-x^2)."""
  with ops.control_dependencies([grad]):
    conj_x = math_ops.conj(op.inputs[0])
    x_squared = math_ops.square(conj_x)
    one = constant_op.constant(1, dtype=grad.dtype)
    denominator = math_ops.sqrt(math_ops.subtract(one, x_squared))
    if not compat.forward_compatible(2019, 9, 14):
      inverse = math_ops.reciprocal(denominator)
      return grad * inverse
    return math_ops.xdivy(grad, denominator)
@ops.RegisterGradient("Acos")
def _AcosGrad(op, grad):
  """Returns grad * -1/sqrt(1-x^2)."""
  with ops.control_dependencies([grad]):
    conj_x = math_ops.conj(op.inputs[0])
    x_squared = math_ops.square(conj_x)
    one = constant_op.constant(1, dtype=grad.dtype)
    denominator = math_ops.sqrt(math_ops.subtract(one, x_squared))
    if not compat.forward_compatible(2019, 9, 14):
      inverse = math_ops.reciprocal(denominator)
      return -grad * inverse
    return -math_ops.xdivy(grad, denominator)
@ops.RegisterGradient("Atan")
def _AtanGrad(op, grad):
  """Returns grad * 1/ (1 + x^2)."""
  with ops.control_dependencies([grad]):
    conj_x = math_ops.conj(op.inputs[0])
    x_squared = math_ops.square(conj_x)
    one = constant_op.constant(1, dtype=grad.dtype)
    denominator = math_ops.add(one, x_squared)
    return grad * math_ops.reciprocal(denominator)
@ops.RegisterGradient("Atan2")
def _Atan2Grad(op, grad):
  """Returns grad * x / (x^2 + y^2), grad * -y / (x^2 + y^2).

  The first value is the gradient for `y` (`op.inputs[0]`) and the second
  the gradient for `x` (`op.inputs[1]`).
  """
  y, x = op.inputs[0], op.inputs[1]
  with ops.control_dependencies([grad]):
    if not compat.forward_compatible(2019, 9, 14):
      scaled_grad = grad / (math_ops.square(x) + math_ops.square(y))
    else:
      scaled_grad = math_ops.xdivy(
          grad, (math_ops.square(x) + math_ops.square(y)))
    return x * scaled_grad, -y * scaled_grad
@ops.RegisterGradient("AddN")
def _AddNGrad(op, grad):
  """Copies the gradient to all inputs."""
  # AddN does not broadcast, so every input receives grad unchanged.
  num_inputs = len(op.inputs)
  return [grad] * num_inputs
def _ShapesFullySpecifiedAndEqual(x, y, grad):
  """Returns whether x, y and grad share one fully known shape tuple.

  The result is False when the shape of `x` is unknown (`None`), contains an
  unknown dimension, or differs from the shape of `y` or `grad`.
  """
  # pylint: disable=protected-access
  x_shape = x._shape_tuple()
  y_shape = y._shape_tuple()
  grad_shape = grad._shape_tuple()
  # pylint: enable=protected-access
  if x_shape != y_shape:
    return False
  if x_shape != grad_shape:
    return False
  if x_shape is None:
    return False
  return None not in x_shape
@ops.RegisterGradient("Add")
@ops.RegisterGradient("AddV2")
def _AddGrad(op, grad):
  """Computes gradients for Add and AddV2, summing over broadcast dims.

  Args:
    op: The forward op; `op.inputs` holds (x, y). If it exposes
      `skip_input_indices`, gradients for the listed inputs may be omitted.
    grad: Gradient with respect to the op's output.

  Returns:
    A pair (gx, gy). Either entry may be None when that input is listed in
    `op.skip_input_indices`.
  """
  y = op.inputs[1]
  skip_input_indices = None
  y_is_skipped_scalar = False
  try:
    skip_input_indices = op.skip_input_indices
    y_is_skipped_scalar = (
        skip_input_indices is not None and 1 in skip_input_indices and
        _IsScalar(y))
  except AttributeError:
    # No gradient skipping, so do the full gradient computation
    pass
  if y_is_skipped_scalar:
    # Only x needs a gradient and y is a scalar: grad passes through as is.
    return grad, None

  skipped = () if skip_input_indices is None else skip_input_indices
  x = op.inputs[0]
  if (isinstance(grad, ops.Tensor) and
      _ShapesFullySpecifiedAndEqual(x, y, grad)):
    # Identical static shapes: no broadcasting happened.
    return grad, grad

  sx = array_ops.shape(x)
  sy = array_ops.shape(y)
  rx, ry = gen_array_ops.broadcast_gradient_args(sx, sy)
  gx = (None if 0 in skipped else
        array_ops.reshape(math_ops.reduce_sum(grad, rx), sx))
  gy = (None if 1 in skipped else
        array_ops.reshape(math_ops.reduce_sum(grad, ry), sy))
  return (gx, gy)
@ops.RegisterGradient("Sub")
def _SubGrad(op, grad):
  """Computes gradients for Sub, summing over broadcast dims.

  Args:
    op: The forward Sub op; `op.inputs` holds (x, y).
    grad: Gradient with respect to the op's output.

  Returns:
    A pair (gx, gy) where gy carries the negated gradient.
  """
  x, y = op.inputs[0], op.inputs[1]
  shapes_match = (
      isinstance(grad, ops.Tensor) and
      _ShapesFullySpecifiedAndEqual(x, y, grad))
  if shapes_match:
    # Identical static shapes: no broadcasting happened.
    return grad, -grad

  sx = array_ops.shape(x)
  sy = array_ops.shape(y)
  rx, ry = gen_array_ops.broadcast_gradient_args(sx, sy)
  gx = array_ops.reshape(math_ops.reduce_sum(grad, rx), sx)
  gy = array_ops.reshape(-math_ops.reduce_sum(grad, ry), sy)
  return (gx, gy)
@ops.RegisterGradient("Mul")
def _MulGrad(op, grad):
  """Computes gradients for Mul, summing over broadcast dims.

  Args:
    op: The forward Mul op; `op.inputs` holds (x, y).
    grad: Gradient with respect to the op's output.

  Returns:
    A pair (gx, gy) built from grad * y and x * grad. The general path
    conjugates x and y first.
  """
  x, y = op.inputs[0], op.inputs[1]
  use_fast_path = (
      isinstance(grad, ops.Tensor) and
      _ShapesFullySpecifiedAndEqual(x, y, grad) and
      grad.dtype in (dtypes.int32, dtypes.float32))
  if use_fast_path:
    # Same static shapes and an int32/float32 grad: no reduction or conj.
    return gen_math_ops.mul(grad, y), gen_math_ops.mul(grad, x)

  assert x.dtype.base_dtype == y.dtype.base_dtype, (x.dtype, " vs. ", y.dtype)
  sx = array_ops.shape(x)
  sy = array_ops.shape(y)
  rx, ry = gen_array_ops.broadcast_gradient_args(sx, sy)
  x = math_ops.conj(x)
  y = math_ops.conj(y)
  gx = array_ops.reshape(
      math_ops.reduce_sum(gen_math_ops.mul(grad, y), rx), sx)
  gy = array_ops.reshape(
      math_ops.reduce_sum(gen_math_ops.mul(x, grad), ry), sy)
  return (gx, gy)
@ops.RegisterGradient("MulNoNan")
def _MulNoNanGrad(op, grad):
  """Computes gradients for MulNoNan, summing over broadcast dims.

  Args:
    op: The forward MulNoNan op; `op.inputs` holds (x, y).
    grad: Gradient with respect to the op's output.

  Returns:
    A pair (gx, gy) built from mul_no_nan(grad, y) and mul_no_nan(x, grad).
  """
  x, y = op.inputs[0], op.inputs[1]
  shapes_match = (
      isinstance(grad, ops.Tensor) and
      _ShapesFullySpecifiedAndEqual(x, y, grad))
  if shapes_match:
    # Identical static shapes: no broadcasting happened.
    return gen_math_ops.mul_no_nan(grad, y), gen_math_ops.mul_no_nan(x, grad)

  assert x.dtype.base_dtype == y.dtype.base_dtype, (x.dtype, " vs. ", y.dtype)
  sx = array_ops.shape(x)
  sy = array_ops.shape(y)
  rx, ry = gen_array_ops.broadcast_gradient_args(sx, sy)
  gx = array_ops.reshape(
      math_ops.reduce_sum(gen_math_ops.mul_no_nan(grad, y), rx), sx)
  gy = array_ops.reshape(
      math_ops.reduce_sum(gen_math_ops.mul_no_nan(x, grad), ry), sy)
  return (gx, gy)
@ops.RegisterGradient("Div")
def _DivGrad(op, grad):
  """Computes gradients for Div, summing over broadcast dims.

  Args:
    op: The forward Div op; `op.inputs` holds (x, y).
    grad: Gradient with respect to the op's output.

  Returns:
    A pair (gx, gy) built from grad / y and grad * (-x / y / y), using the
    conjugates of x and y.
  """
  x, y = op.inputs[0], op.inputs[1]
  sx = array_ops.shape(x)
  sy = array_ops.shape(y)
  rx, ry = gen_array_ops.broadcast_gradient_args(sx, sy)
  x = math_ops.conj(x)
  y = math_ops.conj(y)
  # Past the compatibility date, xdivy / mul_no_nan replace plain div / mul.
  use_safe_ops = compat.forward_compatible(2019, 9, 14)

  if use_safe_ops:
    dx = math_ops.xdivy(grad, y)
  else:
    dx = math_ops.divide(grad, y)
  gx = array_ops.reshape(math_ops.reduce_sum(dx, rx), sx)

  neg_x_over_y2 = math_ops.divide(math_ops.divide(-x, y), y)
  if use_safe_ops:
    dy = math_ops.mul_no_nan(neg_x_over_y2, grad)
  else:
    dy = grad * neg_x_over_y2
  gy = array_ops.reshape(math_ops.reduce_sum(dy, ry), sy)
  return (gx, gy)
@ops.RegisterGradient("FloorDiv")
def _FloorDivGrad(_, unused_grad):
  """Returns no gradient (None) for either input of FloorDiv."""
  no_gradients = (None, None)
  return no_gradients
@ops.RegisterGradient("FloorMod")
def _FloorModGrad(op, grad):
  """Computes gradients for FloorMod: grad * (1, -floor(x / y)).

  Args:
    op: The forward FloorMod op; `op.inputs` holds (x, y).
    grad: Gradient with respect to the op's output.

  Returns:
    A pair (gx, gy), each reduced and reshaped to its input's shape.
  """
  dividend = math_ops.conj(op.inputs[0])
  divisor = math_ops.conj(op.inputs[1])
  dividend_shape = array_ops.shape(dividend)
  divisor_shape = array_ops.shape(divisor)
  reduce_dividend, reduce_divisor = gen_array_ops.broadcast_gradient_args(
      dividend_shape, divisor_shape)
  quotient = math_ops.floor_div(dividend, divisor)
  gx = array_ops.reshape(
      math_ops.reduce_sum(grad, reduce_dividend), dividend_shape)
  gy = array_ops.reshape(
      math_ops.reduce_sum(grad * math_ops.negative(quotient), reduce_divisor),
      divisor_shape)
  return gx, gy
@ops.RegisterGradient("TruncateDiv")
def _TruncateDivGrad(_, unused_grad):
  """Returns no gradient (None) for either input of TruncateDiv."""
  no_gradients = (None, None)
  return no_gradients
@ops.RegisterGradient("RealDiv")
def _RealDivGrad(op, grad):
  """Computes gradients for RealDiv, summing over broadcast dims.

  Args:
    op: The forward RealDiv op; `op.inputs` holds (x, y).
    grad: Gradient with respect to the op's output.

  Returns:
    A pair (gx, gy) built from grad / y and grad * (-x / y / y), using the
    conjugates of x and y.
  """
  x, y = op.inputs[0], op.inputs[1]
  sx = array_ops.shape(x)
  sy = array_ops.shape(y)
  rx, ry = gen_array_ops.broadcast_gradient_args(sx, sy)
  x = math_ops.conj(x)
  y = math_ops.conj(y)
  # Past the compatibility date, xdivy / mul_no_nan replace realdiv / mul.
  use_safe_ops = compat.forward_compatible(2019, 9, 14)

  if use_safe_ops:
    dx = math_ops.xdivy(grad, y)
  else:
    dx = math_ops.realdiv(grad, y)
  gx = array_ops.reshape(math_ops.reduce_sum(dx, rx), sx)

  neg_x_over_y2 = math_ops.realdiv(math_ops.realdiv(-x, y), y)
  if use_safe_ops:
    dy = math_ops.mul_no_nan(neg_x_over_y2, grad)
  else:
    dy = grad * neg_x_over_y2
  gy = array_ops.reshape(math_ops.reduce_sum(dy, ry), sy)
  return (gx, gy)
@ops.RegisterGradient("DivNoNan")
def _DivNoNanGrad(op, grad):
  """Computes gradients for DivNoNan, summing over broadcast dims.

  Args:
    op: The forward DivNoNan op; `op.inputs` holds (x, y).
    grad: Gradient with respect to the op's output.

  Returns:
    A pair (gx, gy) built with div_no_nan from grad / y and
    grad * (-x / y / y), using the conjugates of x and y.
  """
  x, y = op.inputs[0], op.inputs[1]
  sx = array_ops.shape(x)
  sy = array_ops.shape(y)
  rx, ry = gen_array_ops.broadcast_gradient_args(sx, sy)
  x = math_ops.conj(x)
  y = math_ops.conj(y)

  # The x-gradient and the quotient term are the same in both branches.
  gx = array_ops.reshape(
      math_ops.reduce_sum(math_ops.div_no_nan(grad, y), rx), sx)
  neg_x_over_y2 = math_ops.div_no_nan(math_ops.div_no_nan(-x, y), y)
  if compat.forward_compatible(2019, 9, 14):
    dy = math_ops.mul_no_nan(neg_x_over_y2, grad)
  else:
    dy = grad * neg_x_over_y2
  gy = array_ops.reshape(math_ops.reduce_sum(dy, ry), sy)
  return (gx, gy)
@ops.RegisterGradient("Pow")
def _PowGrad(op, grad):
  """Computes gradients for Pow: grad * (y * x^(y-1), z * log(x)).

  Args:
    op: The forward Pow op; `op.inputs` holds (x, y), `op.outputs[0]` is z.
    grad: Gradient with respect to the op's output.

  Returns:
    A pair (gx, gy), each reduced and reshaped to its input's shape.
  """
  x, y = op.inputs[0], op.inputs[1]
  z = op.outputs[0]
  sx = array_ops.shape(x)
  sy = array_ops.shape(y)
  rx, ry = gen_array_ops.broadcast_gradient_args(sx, sy)
  x = math_ops.conj(x)
  y = math_ops.conj(y)
  z = math_ops.conj(z)
  # Past the compatibility date, mul_no_nan replaces plain multiplication.
  use_mul_no_nan = compat.forward_compatible(2019, 9, 14)

  if use_mul_no_nan:
    dx = gen_math_ops.mul_no_nan(y * math_ops.pow(x, y - 1), grad)
  else:
    dx = grad * y * math_ops.pow(x, y - 1)
  gx = array_ops.reshape(math_ops.reduce_sum(dx, rx), sx)

  # Avoid false singularity at x = 0
  if x.dtype.is_complex:
    # real(x) < 0 is fine for the complex case
    valid = math_ops.not_equal(x, 0)
  else:
    # There's no sensible real value to return if x < 0, so return 0
    valid = x > 0
  # Take the log of 1 at masked-out positions, then zero those entries.
  safe_x = array_ops.where(valid, x, array_ops.ones_like(x))
  log_x = array_ops.where(valid, math_ops.log(safe_x), array_ops.zeros_like(x))

  if use_mul_no_nan:
    dy = gen_math_ops.mul_no_nan(z * log_x, grad)
  else:
    dy = grad * z * log_x
  gy = array_ops.reshape(math_ops.reduce_sum(dy, ry), sy)
  return gx, gy
def _MaximumMinimumGradInputOnly(op, grad, selector_op):
  """Computes only the first-input gradient for Maximum or Minimum.

  Args:
    op: The forward op; `op.inputs` holds (x, y).
    grad: Gradient with respect to the op's output.
    selector_op: Callable applied to (x, y) that yields the mask of positions
      where x receives the gradient.

  Returns:
    A pair (xgrad, None).
  """
  lhs, rhs = op.inputs[0], op.inputs[1]
  zero_grad = array_ops.zeros_like(grad)
  lhs_selected = selector_op(lhs, rhs)
  lhs_grad = array_ops.where(lhs_selected, grad, zero_grad)
  # Return None for ygrad since the config allows that.
  return (lhs_grad, None)
def _MaximumMinimumGrad(op, grad, selector_op):
  """Shared gradient implementation for Maximum and Minimum.

  Args:
    op: The forward op; `op.inputs` holds (x, y). If it exposes
      `skip_input_indices`, gradients for the listed inputs may be omitted.
    grad: Gradient with respect to the op's output.
    selector_op: Callable applied to (x, y) that yields the mask of positions
      where x (rather than y) receives the gradient.

  Returns:
    A pair (gx, gy). Either entry may be None when that input is listed in
    `op.skip_input_indices`.
  """
  y = op.inputs[1]
  skip_input_indices = None
  try:
    skip_input_indices = op.skip_input_indices
    wants_x_only = skip_input_indices is not None and 1 in skip_input_indices
    if wants_x_only and _IsScalar(y):
      # When we want to get gradients for the first input only, and the second
      # input tensor is a scalar, we can do a much simpler calculation
      return _MaximumMinimumGradInputOnly(op, grad, selector_op)
  except AttributeError:
    # No gradient skipping, so do the full gradient computation
    pass

  skipped = () if skip_input_indices is None else skip_input_indices
  x = op.inputs[0]
  grad_dtype = grad.dtype
  sx = array_ops.shape(x)
  sy = array_ops.shape(y)
  grad_shape = array_ops.shape(grad)
  zero_grad = array_ops.zeros(grad_shape, grad_dtype)
  x_selected = selector_op(x, y)
  rx, ry = gen_array_ops.broadcast_gradient_args(sx, sy)

  gx = None
  if 0 not in skipped:
    x_part = array_ops.where(x_selected, grad, zero_grad)
    gx = array_ops.reshape(math_ops.reduce_sum(x_part, rx), sx)

  gy = None
  if 1 not in skipped:
    y_part = array_ops.where(x_selected, zero_grad, grad)
    gy = array_ops.reshape(math_ops.reduce_sum(y_part, ry), sy)
  return (gx, gy)
@ops.RegisterGradient("Maximum")
def _MaximumGrad(op, grad):
  """Returns grad*(x >= y, x < y) with type of grad."""
  # greater_equal routes the gradient to x wherever x >= y, ties included.
  x_wins = math_ops.greater_equal
  return _MaximumMinimumGrad(op, grad, x_wins)
@ops.RegisterGradient("Minimum")
def _MinimumGrad(op, grad):
  """Returns grad*(x <= y, x > y) with type of grad."""
  # less_equal routes the gradient to x wherever x <= y, ties included.
  x_wins = math_ops.less_equal
  return _MaximumMinimumGrad(op, grad, x_wins)
@ops.RegisterGradient("SquaredDifference")
def _SquaredDifferenceGrad(op, grad):
  """Computes gradients for (x - y)^2, summing over broadcast dims.

  Args:
    op: The forward SquaredDifference op; `op.inputs` holds (x, y).
    grad: Gradient with respect to the op's output.

  Returns:
    A pair (gx, gy) built from 2 * grad * (x - y); gy is the negated,
    y-shaped reduction of the same product.
  """
  x, y = op.inputs[0], op.inputs[1]
  x_shape = array_ops.shape(x)
  y_shape = array_ops.shape(y)
  reduce_x, reduce_y = gen_array_ops.broadcast_gradient_args(x_shape, y_shape)
  with ops.control_dependencies([grad]):
    # The parens ensure that if grad is IndexedSlices, it'll get multiplied by
    # Tensor (not a number like 2.0) which causes it to convert to Tensor.
    product = math_ops.scalar_mul(2.0, grad) * (x - y)
  gx = array_ops.reshape(math_ops.reduce_sum(product, reduce_x), x_shape)
  gy = -array_ops.reshape(math_ops.reduce_sum(product, reduce_y), y_shape)
  return (gx, gy)
# Comparison and logical operations have no gradients; each op type below is
# registered as not differentiable.
ops.NotDifferentiable("Less")
ops.NotDifferentiable("LessEqual")
ops.NotDifferentiable("Greater")
ops.NotDifferentiable("GreaterEqual")
ops.NotDifferentiable("Equal")
ops.NotDifferentiable("ApproximateEqual")
ops.NotDifferentiable("NotEqual")
ops.NotDifferentiable("LogicalAnd")
ops.NotDifferentiable("LogicalOr")
ops.NotDifferentiable("LogicalNot")
@ops.RegisterGradient("Select")
def _SelectGrad(op, grad):
  """Computes gradients for Select by routing grad with the condition.

  Args:
    op: The forward Select op; `op.inputs` holds (condition, x, y).
    grad: Gradient with respect to the op's output.

  Returns:
    A triple (None, gx, gy): gx takes grad where the condition holds and
    zeros elsewhere; gy takes the complement.
  """
  condition = op.inputs[0]
  then_value = op.inputs[1]
  zero_fill = array_ops.zeros_like(then_value)
  gx = array_ops.where(condition, grad, zero_fill)
  gy = array_ops.where(condition, zero_fill, grad)
  return (None, gx, gy)
@ops.RegisterGradient("SelectV2")
def _SelectGradV2(op, grad):
  """Computes gradients for SelectV2, summing over broadcast dims.

  Args:
    op: The forward SelectV2 op; `op.inputs` holds (condition, x, y).
    grad: Gradient with respect to the op's output.

  Returns:
    A triple (None, gx, gy) where gx and gy are reshaped to the shapes of
    x and y respectively.
  """
  condition = op.inputs[0]
  x = op.inputs[1]
  y = op.inputs[2]
  scalar_zero = array_ops.zeros([], dtype=grad.dtype.base_dtype)

  masked_for_x = array_ops.where_v2(condition, grad, scalar_zero)
  x_shape = array_ops.shape(x)
  output_shape = array_ops.shape(op.outputs[0])
  # Reduce away broadcasted leading dims.
  x_axes, _ = gen_array_ops.broadcast_gradient_args(x_shape, output_shape)
  gx = array_ops.reshape(
      math_ops.reduce_sum(masked_for_x, keepdims=True, axis=x_axes), x_shape)

  masked_for_y = array_ops.where_v2(condition, scalar_zero, grad)
  y_shape = array_ops.shape(y)
  # Reduce away broadcasted leading dims.
  y_axes, _ = gen_array_ops.broadcast_gradient_args(y_shape, output_shape)
  gy = array_ops.reshape(
      math_ops.reduce_sum(masked_for_y, keepdims=True, axis=y_axes), y_shape)
  return (None, gx, gy)
def _MatMulGradAgainstFirstOnly(op, grad):
  """Computes the MatMul gradient for the first input only.

  Args:
    op: The forward MatMul op, with `transpose_a` / `transpose_b` attrs.
    grad: Gradient with respect to the op's output.

  Returns:
    A pair (grad_a, None).
  """
  transposed_a = op.get_attr("transpose_a")
  transposed_b = op.get_attr("transpose_b")
  rhs = math_ops.conj(op.inputs[1])
  if transposed_a:
    if transposed_b:
      grad_a = gen_math_ops.mat_mul(
          rhs, grad, transpose_a=True, transpose_b=True)
    else:
      grad_a = gen_math_ops.mat_mul(rhs, grad, transpose_b=True)
  else:
    if transposed_b:
      grad_a = gen_math_ops.mat_mul(grad, rhs)
    else:
      grad_a = gen_math_ops.mat_mul(grad, rhs, transpose_b=True)
  return grad_a, None
def _MatMulGradAgainstSecondOnly(op, grad):
  """Computes the MatMul gradient for the second input only.

  Args:
    op: The forward MatMul op, with `transpose_a` / `transpose_b` attrs.
    grad: Gradient with respect to the op's output.

  Returns:
    A pair (None, grad_b).
  """
  transposed_a = op.get_attr("transpose_a")
  transposed_b = op.get_attr("transpose_b")
  lhs = math_ops.conj(op.inputs[0])
  if transposed_a:
    if transposed_b:
      grad_b = gen_math_ops.mat_mul(
          grad, lhs, transpose_a=True, transpose_b=True)
    else:
      grad_b = gen_math_ops.mat_mul(lhs, grad)
  else:
    if transposed_b:
      grad_b = gen_math_ops.mat_mul(grad, lhs, transpose_a=True)
    else:
      grad_b = gen_math_ops.mat_mul(lhs, grad, transpose_a=True)
  return None, grad_b
@ops.RegisterGradient("MatMul")
def _MatMulGrad(op, grad):
  """Computes gradients for MatMul with respect to both inputs.

  Args:
    op: The forward MatMul op, with `transpose_a` / `transpose_b` attrs. If
      it exposes `skip_input_indices`, a single-input helper is used instead.
    grad: Gradient with respect to the op's output.

  Returns:
    A pair (grad_a, grad_b); one entry is None when a single-input helper
    handled the call.
  """
  try:
    skipped = op.skip_input_indices
    if skipped is not None:
      if 1 in skipped:
        return _MatMulGradAgainstFirstOnly(op, grad)
      if 0 in skipped:
        return _MatMulGradAgainstSecondOnly(op, grad)
  except AttributeError:
    # No gradient skipping, so do the full gradient computation
    pass

  transposed_a = op.get_attr("transpose_a")
  transposed_b = op.get_attr("transpose_b")
  lhs = math_ops.conj(op.inputs[0])
  rhs = math_ops.conj(op.inputs[1])
  if transposed_a:
    if transposed_b:
      grad_a = gen_math_ops.mat_mul(
          rhs, grad, transpose_a=True, transpose_b=True)
      grad_b = gen_math_ops.mat_mul(
          grad, lhs, transpose_a=True, transpose_b=True)
    else:
      grad_a = gen_math_ops.mat_mul(rhs, grad, transpose_b=True)
      grad_b = gen_math_ops.mat_mul(lhs, grad)
  else:
    if transposed_b:
      grad_a = gen_math_ops.mat_mul(grad, rhs)
      grad_b = gen_math_ops.mat_mul(grad, lhs, transpose_a=True)
    else:
      grad_a = gen_math_ops.mat_mul(grad, rhs, transpose_b=True)
      grad_b = gen_math_ops.mat_mul(lhs, grad, transpose_a=True)
  return grad_a, grad_b
@ops.RegisterGradient("SparseMatMul")
def _SparseMatMulGrad(op, grad):
  """Computes gradients for SparseMatMul with respect to both inputs.

  Args:
    op: The forward SparseMatMul op, with `transpose_a`, `transpose_b`,
      `a_is_sparse` and `b_is_sparse` attrs.
    grad: Gradient with respect to the op's output.

  Returns:
    A pair (grad_a, grad_b), each cast to its input's dtype when the product
    dtype differs.
  """
  transposed_a = op.get_attr("transpose_a")
  transposed_b = op.get_attr("transpose_b")
  lhs = op.inputs[0]
  rhs = op.inputs[1]
  sparsity_hint = {
      lhs: op.get_attr("a_is_sparse"),
      rhs: op.get_attr("b_is_sparse"),
      # Use heuristic to figure out if grad might be sparse
      grad: not context.executing_eagerly() and (grad.op.type == "ReluGrad")
  }

  def _Product(t1, t2, out_dtype, transpose_a=False, transpose_b=False):
    """Multiplies t1 by t2 with sparsity hints, casting to out_dtype."""
    assert t1 in sparsity_hint and t2 in sparsity_hint
    # Look the hints up first: a transposed t2 is a new tensor with no entry.
    t1_hint = sparsity_hint[t1]
    t2_hint = sparsity_hint[t2]
    if transpose_b:
      # Transpose explicitly instead of passing transpose_b to matmul.
      t2 = array_ops.transpose(t2)
    result = math_ops.matmul(
        t1,
        t2,
        transpose_a=transpose_a,
        transpose_b=False,
        a_is_sparse=t1_hint,
        b_is_sparse=t2_hint)
    if result.dtype != out_dtype:
      result = math_ops.cast(result, out_dtype)
    return result

  dtype_a = lhs.dtype
  dtype_b = rhs.dtype
  if transposed_a:
    if transposed_b:
      return (_Product(
          rhs, grad, dtype_a, transpose_a=True, transpose_b=True),
              _Product(
                  grad, lhs, dtype_b, transpose_a=True, transpose_b=True))
    return (_Product(rhs, grad, dtype_a, transpose_b=True),
            _Product(lhs, grad, dtype_b))
  if transposed_b:
    return (_Product(grad, rhs, dtype_a),
            _Product(grad, lhs, dtype_b, transpose_a=True))
  return (_Product(grad, rhs, dtype_a, transpose_b=True),
          _Product(lhs, grad, dtype_b, transpose_a=True))
@ops.RegisterGradient("Floor")
def _FloorGrad(_, unused_grad):
  """Returns no gradient (None) for the single input of Floor."""
  no_gradient = [None]
  return no_gradient
@ops.RegisterGradient("Ceil")
def _CeilGrad(_, unused_grad):
  """Returns no gradient (None) for the single input of Ceil."""
  no_gradient = [None]
  return no_gradient
@ops.RegisterGradient("Round")
def _RoundGrad(_, unused_grad):
  """Returns no gradient (None) for the single input of Round."""
  no_gradient = [None]
  return no_gradient
@ops.RegisterGradient("Rint")
def _RintGrad(_, unused_grad):
  """Returns no gradient (None) for the single input of Rint."""
  # the gradient of Rint is zero
  no_gradient = [None]
  return no_gradient
@ops.RegisterGradient("BatchMatMul")
def _BatchMatMul(op, grad):
  """Computes gradients for BatchMatMul with respect to x and y.

  Args:
    op: The forward BatchMatMul op, with `adj_x` / `adj_y` attrs;
      `op.inputs` holds (x, y).
    grad: Gradient with respect to the op's output.

  Returns:
    A pair (grad_x, grad_y).
  """
  x, y = op.inputs[0], op.inputs[1]
  adj_x = op.get_attr("adj_x")
  adj_y = op.get_attr("adj_y")

  # One branch per (adj_x, adj_y) combination.
  if adj_x and adj_y:
    grad_x = math_ops.matmul(y, grad, adjoint_a=True, adjoint_b=True)
    grad_y = math_ops.matmul(grad, x, adjoint_a=True, adjoint_b=True)
  elif adj_x:
    grad_x = math_ops.matmul(y, grad, adjoint_a=False, adjoint_b=True)
    grad_y = math_ops.matmul(x, grad, adjoint_a=False, adjoint_b=False)
  elif adj_y:
    grad_x = math_ops.matmul(grad, y, adjoint_a=False, adjoint_b=False)
    grad_y = math_ops.matmul(grad, x, adjoint_a=True, adjoint_b=False)
  else:
    grad_x = math_ops.matmul(grad, y, adjoint_a=False, adjoint_b=True)
    grad_y = math_ops.matmul(x, grad, adjoint_a=True, adjoint_b=False)
  return grad_x, grad_y
@ops.RegisterGradient("BatchMatMulV2")
def _BatchMatMulV2(op, grad):
  """Computes gradients for BatchMatMulV2 with respect to x and y.

  Args:
    op: The forward BatchMatMulV2 op, with `adj_x` / `adj_y` attrs;
      `op.inputs` holds (x, y).
    grad: Gradient with respect to the op's output.

  Returns:
    A pair (grad_x, grad_y). Unless the static shapes of x and y are fully
    defined and equal, each is summed over the broadcast batch dimensions
    and reshaped to its input's shape.
  """
  x, y = op.inputs[0], op.inputs[1]
  adj_x = op.get_attr("adj_x")
  adj_y = op.get_attr("adj_y")

  # One branch per (adj_x, adj_y) combination.
  if adj_x and adj_y:
    grad_x = math_ops.matmul(y, grad, adjoint_a=True, adjoint_b=True)
    grad_y = math_ops.matmul(grad, x, adjoint_a=True, adjoint_b=True)
  elif adj_x:
    grad_x = math_ops.matmul(y, grad, adjoint_a=False, adjoint_b=True)
    grad_y = math_ops.matmul(x, grad, adjoint_a=False, adjoint_b=False)
  elif adj_y:
    grad_x = math_ops.matmul(grad, y, adjoint_a=False, adjoint_b=False)
    grad_y = math_ops.matmul(grad, x, adjoint_a=True, adjoint_b=False)
  else:
    grad_x = math_ops.matmul(grad, y, adjoint_a=False, adjoint_b=True)
    grad_y = math_ops.matmul(x, grad, adjoint_a=True, adjoint_b=False)

  # Reduce along the broadcasted batch dimensions, if broadcasting is required.
  x_static = x.get_shape()
  y_static = y.get_shape()
  same_known_shape = (
      x_static.is_fully_defined() and y_static.is_fully_defined() and
      x_static == y_static)
  if same_known_shape:
    return grad_x, grad_y

  sx = array_ops.shape(x)
  sy = array_ops.shape(y)
  # Only the batch dims (everything but the last two) can be broadcast.
  rx, ry = gen_array_ops.broadcast_gradient_args(sx[:-2], sy[:-2])
  grad_x = array_ops.reshape(math_ops.reduce_sum(grad_x, rx), sx)
  grad_y = array_ops.reshape(math_ops.reduce_sum(grad_y, ry), sy)
  return grad_x, grad_y
# Range and LinSpace are registered as not differentiable.
ops.NotDifferentiable("Range")
ops.NotDifferentiable("LinSpace")
@ops.RegisterGradient("Complex")
def _ComplexGrad(op, grad):
  """Splits `grad` into real and imaginary parts for the two inputs.

  Args:
    op: The forward Complex op; `op.inputs` holds (x, y).
    grad: Gradient with respect to the op's output.

  Returns:
    A pair (gx, gy): the real part of grad reduced to x's shape and the
    imaginary part reduced to y's shape.
  """
  real_input, imag_input = op.inputs[0], op.inputs[1]
  real_shape = array_ops.shape(real_input)
  imag_shape = array_ops.shape(imag_input)
  reduce_real, reduce_imag = gen_array_ops.broadcast_gradient_args(
      real_shape, imag_shape)
  gx = array_ops.reshape(
      math_ops.reduce_sum(math_ops.real(grad), reduce_real), real_shape)
  gy = array_ops.reshape(
      math_ops.reduce_sum(math_ops.imag(grad), reduce_imag), imag_shape)
  return (gx, gy)
@ops.RegisterGradient("Real")
def _RealGrad(_, grad):
  """Builds a complex gradient with `grad` as real part and 0 as imaginary."""
  imaginary_part = constant_op.constant(0, dtype=grad.dtype)
  complex_grad = math_ops.complex(grad, imaginary_part)
  return complex_grad
@ops.RegisterGradient("Imag")
def _ImagGrad(_, grad):
  """Builds a complex gradient with 0 as real part and `grad` as imaginary."""
  real_part = constant_op.constant(0, dtype=grad.dtype)
  complex_grad = math_ops.complex(real_part, grad)
  return complex_grad
@ops.RegisterGradient("Angle")
def _AngleGrad(op, grad):
  """Computes the Angle gradient, -grad / (Im(x) + i * Re(x)).

  Args:
    op: The forward Angle op; `op.inputs[0]` is x.
    grad: Gradient with respect to the op's output.

  Returns:
    The complex-valued gradient with respect to x.
  """
  operand = op.inputs[0]
  # Every op below is built under a control dependency on `grad`.
  with ops.control_dependencies([grad]):
    real_part = math_ops.real(operand)
    imag_part = math_ops.imag(operand)
    # Note the swapped order: the imaginary part becomes the real component.
    inverse = math_ops.reciprocal(math_ops.complex(imag_part, real_part))
    zero = constant_op.constant(0, dtype=grad.dtype)
    grad_as_complex = math_ops.complex(grad, zero)
    return -grad_as_complex * inverse
@ops.RegisterGradient("Conj")
def _ConjGrad(_, grad):
  """Computes the Conj gradient: the complex conjugate of `grad`."""
  conjugated_grad = math_ops.conj(grad)
  return conjugated_grad
@ops.RegisterGradient("ComplexAbs")
def _ComplexAbsGrad(op, grad):
  """Computes the ComplexAbs gradient as div_no_nan(grad * x, |x|).

  Args:
    op: The forward ComplexAbs op; `op.inputs[0]` is x and `op.outputs[0]`
      is its absolute value.
    grad: Gradient with respect to the op's output.

  Returns:
    The gradient with respect to x.
  """
  # grad and |x| are lifted to complex with a zero imaginary part.
  grad_as_complex = math_ops.complex(grad, array_ops.zeros_like(grad))
  numerator = grad_as_complex * op.inputs[0]
  magnitude = op.outputs[0]
  denominator = math_ops.complex(magnitude, array_ops.zeros_like(magnitude))
  return math_ops.div_no_nan(numerator, denominator)
@ops.RegisterGradient("Cast")
def _CastGrad(op, grad):
  """Casts `grad` back to the source dtype when both dtypes allow it.

  Args:
    op: The forward Cast op; `op.inputs[0]` carries the source dtype.
    grad: Gradient with respect to the op's output.

  Returns:
    `grad` cast to the source dtype if both the source dtype and grad's
    dtype are float or complex types from the list below; otherwise None.
  """
  differentiable_types = (
      dtypes.float16, dtypes.float32, dtypes.float64, dtypes.bfloat16,
      dtypes.complex64, dtypes.complex128)
  src_type = op.inputs[0].dtype.base_dtype
  dst_type = grad.dtype.base_dtype
  if src_type not in differentiable_types:
    return None
  if dst_type not in differentiable_types:
    return None
  return math_ops.cast(grad, src_type)
@ops.RegisterGradient("Cross")
def _CrossGrad(op, grad):
  """Computes the Cross gradients as (cross(v, grad), cross(grad, u)).

  Args:
    op: The forward Cross op; `op.inputs` holds (u, v).
    grad: Gradient with respect to the op's output.

  Returns:
    A pair of gradients with respect to u and v.
  """
  u, v = op.inputs[0], op.inputs[1]
  grad_u = math_ops.cross(v, grad)
  grad_v = math_ops.cross(grad, u)
  return (grad_u, grad_v)
@ops.RegisterGradient("Cumsum")
def _CumsumGrad(op, grad):
  """Computes the Cumsum gradient: a cumsum of `grad` in reverse direction.

  Args:
    op: The forward Cumsum op, with `exclusive` / `reverse` attrs;
      `op.inputs[1]` is the axis.
    grad: Gradient with respect to the op's output.

  Returns:
    A list [input_grad, None]; the axis input gets no gradient.
  """
  axis = op.inputs[1]
  exclusive = op.get_attr("exclusive")
  flipped = not op.get_attr("reverse")
  input_grad = math_ops.cumsum(
      grad, axis, exclusive=exclusive, reverse=flipped)
  return [input_grad, None]
@ops.RegisterGradient("Cumprod")
def _CumprodGrad(op, grad):
  """Computes the Cumprod gradient by dividing a reversed cumsum by x.

  Args:
    op: The forward Cumprod op, with `exclusive` / `reverse` attrs;
      `op.inputs` holds (x, axis).
    grad: Gradient with respect to the op's output.

  Returns:
    A list [input_grad, None]; the axis input gets no gradient.
  """
  x, axis = op.inputs[0], op.inputs[1]
  exclusive = op.get_attr("exclusive")
  reverse = op.get_attr("reverse")

  # TODO This fails when x contains 0 and should be fixed
  forward_prod = math_ops.cumprod(
      x, axis, exclusive=exclusive, reverse=reverse)
  accumulated = math_ops.cumsum(
      forward_prod * grad, axis, exclusive=exclusive, reverse=not reverse)
  input_grad = accumulated / x
  return [input_grad, None]
@ops.RegisterGradient("CumulativeLogsumexp")
def _CumulativeLogsumexpGrad(op, grad):
  """Computes the CumulativeLogsumexp gradient in log space.

  Args:
    op: The forward CumulativeLogsumexp op, with `exclusive` / `reverse`
      attrs; `op.inputs` holds (x, axis).
    grad: Gradient with respect to the op's output.

  Returns:
    A list [input_grad, None]; the axis input gets no gradient.
  """
  x, axis = op.inputs[0], op.inputs[1]
  forward_output = op.outputs[0]
  exclusive = op.get_attr("exclusive")
  reverse = op.get_attr("reverse")

  # Split the incoming gradient into positive and negative part
  # in order to take logs. This is required for stable results.
  lowest = grad.dtype.min
  log_positive_part = array_ops.where_v2(
      math_ops.greater(grad, 0), math_ops.log(grad), lowest)
  log_negative_part = array_ops.where_v2(
      math_ops.less(grad, 0), math_ops.log(-grad), lowest)

  def _Backpropagate(log_part):
    """Accumulates one sign's log-gradient in the opposite direction."""
    return math_ops.exp(
        math_ops.cumulative_logsumexp(
            log_part - forward_output,
            axis=axis, reverse=not reverse, exclusive=exclusive) + x)

  positive_contribution = _Backpropagate(log_positive_part)
  negative_contribution = _Backpropagate(log_negative_part)
  return [positive_contribution - negative_contribution, None]
@ops.RegisterGradient("NextAfter")
def _NextAfterGrad(op, grad):
  """Computes gradients of nextafter(x1, x2) with respect to x1 and x2.

  Args:
    op: The forward NextAfter op; `op.inputs` holds (x1, x2).
    grad: Gradient with respect to the op's output.

  Returns:
    A pair (g1, g2): grad times ones for x1 and grad times zeros for x2,
    each reduced and reshaped to its input's shape.
  """
  x1, x2 = op.inputs[0], op.inputs[1]
  shape_x1 = array_ops.shape(x1)
  shape_x2 = array_ops.shape(x2)
  reduce_x1, reduce_x2 = gen_array_ops.broadcast_gradient_args(
      shape_x1, shape_x2)
  with ops.control_dependencies([grad]):
    # Partial derivatives: one with respect to x1, zero with respect to x2.
    d_x1 = array_ops.ones(shape_x1, dtype=x1.dtype)
    d_x2 = array_ops.zeros(shape_x2, dtype=x2.dtype)
    g1 = array_ops.reshape(
        math_ops.reduce_sum(d_x1 * grad, reduce_x1), shape_x1)
    g2 = array_ops.reshape(
        math_ops.reduce_sum(d_x2 * grad, reduce_x2), shape_x2)
    return (g1, g2)