| # Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
| # |
| # Licensed under the Apache License, Version 2.0 (the "License"); |
| # you may not use this file except in compliance with the License. |
| # You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| # ============================================================================== |
| """Gradients for operators defined in math_ops.py.""" |
| from __future__ import absolute_import |
| from __future__ import division |
| from __future__ import print_function |
| |
| import numpy as np |
| |
| from tensorflow.python.compat import compat |
| from tensorflow.python.eager import context |
| from tensorflow.python.framework import constant_op |
| from tensorflow.python.framework import dtypes |
| from tensorflow.python.framework import ops |
| from tensorflow.python.framework import tensor_util |
| from tensorflow.python.ops import array_ops |
| from tensorflow.python.ops import gen_array_ops |
| from tensorflow.python.ops import gen_math_ops |
| from tensorflow.python.ops import math_ops |
| |
| |
| def _safe_shape_div(x, y): |
| """Divides `x / y` assuming `x, y >= 0`, treating `0 / 0 = 0`.""" |
| return x // math_ops.maximum(y, 1) |
| |
| |
| @ops.RegisterGradient("ArgMax") |
| def _ArgMaxGrad(op, grad): |
| del op, grad |
| return [None, None] |
| |
| |
| @ops.RegisterGradient("ArgMin") |
| def _ArgMinGrad(op, grad): |
| del op, grad |
| return [None, None] |
| |
| |
| # TODO(rmlarsen): Implement gradient. |
| ops.NotDifferentiable("EuclideanNorm") |
| |
| |
| _empty_tuple = () |
| |
| |
| def _IsScalar(x): |
| return x._shape_tuple() is _empty_tuple # pylint: disable=protected-access |
| |
| |
| @ops.RegisterGradient("Sum") |
| def _SumGrad(op, grad): |
| """Gradient for Sum.""" |
| # Fast path for when reducing to a scalar and ndims is known: adds only |
| # Reshape and Tile ops (and possibly a Shape). |
| input_0_shape = op.inputs[0]._shape_tuple() # pylint: disable=protected-access |
| if input_0_shape is not None: |
| axes = tensor_util.constant_value(op.inputs[1]) |
| if axes is not None: |
| rank = len(input_0_shape) |
| if np.array_equal(axes, np.arange(rank)): # Reduce all dims. |
| if context.executing_eagerly(): |
| ctx = context.context() |
| new_shape = ctx.ones_rank_cache().get(rank) |
| if new_shape is None: |
| new_shape = constant_op.constant([1] * rank, dtype=dtypes.int32) |
| ctx.ones_rank_cache().put(rank, new_shape) |
| else: |
| new_shape = [1] * rank |
| grad = array_ops.reshape(grad, new_shape) |
| # If shape is not fully defined (but rank is), we use Shape. |
| if None not in input_0_shape: |
| input_shape = constant_op.constant(input_0_shape, dtype=dtypes.int32) |
| else: |
| input_shape = array_ops.shape(op.inputs[0]) |
| return [array_ops.tile(grad, input_shape), None] |
| |
| input_shape = array_ops.shape(op.inputs[0]) |
| # TODO(apassos) remove this once device placement for eager ops makes more |
| # sense. |
| with ops.colocate_with(input_shape): |
| output_shape_kept_dims = math_ops.reduced_shape(input_shape, op.inputs[1]) |
| tile_scaling = _safe_shape_div(input_shape, output_shape_kept_dims) |
| grad = array_ops.reshape(grad, output_shape_kept_dims) |
| return [array_ops.tile(grad, tile_scaling), None] |
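| |
| # Worked example (illustrative only): for a [2, 3] float input reduced over |
| # all axes, an upstream scalar gradient g is reshaped to [1, 1] and tiled by |
| # the input shape, so every input element receives gradient g. E.g. for |
| # x = [[1., 2., 3.], [4., 5., 6.]] and an upstream gradient of 2., the result |
| # is a [2, 3] tensor filled with 2. |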
| |
| |
| def _MinOrMaxGrad(op, grad): |
| """Gradient for Min or Max. Amazingly it's precisely the same code.""" |
| input_shape = array_ops.shape(op.inputs[0]) |
| output_shape_kept_dims = math_ops.reduced_shape(input_shape, op.inputs[1]) |
| y = op.outputs[0] |
| y = array_ops.reshape(y, output_shape_kept_dims) |
| grad = array_ops.reshape(grad, output_shape_kept_dims) |
| |
| # Compute the number of selected (maximum or minimum) elements in each |
| # reduction dimension. If there are multiple minimum or maximum elements |
| # then the gradient will be divided between them. |
| indicators = math_ops.cast(math_ops.equal(y, op.inputs[0]), grad.dtype) |
| num_selected = array_ops.reshape( |
| math_ops.reduce_sum(indicators, op.inputs[1]), output_shape_kept_dims) |
| |
| return [math_ops.divide(indicators, num_selected) * grad, None] |
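| |
| # Worked example (illustrative only): for x = [1., 3., 3.] and |
| # y = reduce_max(x) = 3., indicators is [0., 1., 1.] and num_selected is 2., |
| # so an upstream gradient of 1. is split evenly as [0., 0.5, 0.5] between the |
| # two maximal entries. |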
| |
| |
| @ops.RegisterGradient("Max") |
| def _MaxGrad(op, grad): |
| """Gradient for Max.""" |
| return _MinOrMaxGrad(op, grad) |
| |
| |
| @ops.RegisterGradient("Min") |
| def _MinGrad(op, grad): |
| return _MinOrMaxGrad(op, grad) |
| |
| |
| @ops.RegisterGradient("Mean") |
| def _MeanGrad(op, grad): |
| """Gradient for Mean.""" |
| sum_grad = _SumGrad(op, grad)[0] |
| input_shape = op.inputs[0]._shape_tuple() # pylint: disable=protected-access |
| output_shape = op.outputs[0]._shape_tuple() # pylint: disable=protected-access |
| if (input_shape is not None and output_shape is not None and |
| None not in input_shape and None not in output_shape): |
| input_size = np.prod(input_shape) |
| output_size = np.prod(output_shape) |
| factor = input_size // max(output_size, 1) |
| factor = constant_op.constant(factor, dtype=sum_grad.dtype) |
| else: |
| input_shape = array_ops.shape(op.inputs[0]) |
| output_shape = array_ops.shape(op.outputs[0]) |
| factor = _safe_shape_div( |
| math_ops.reduce_prod(input_shape), math_ops.reduce_prod(output_shape)) |
| return math_ops.truediv(sum_grad, math_ops.cast(factor, sum_grad.dtype)), None |
| |
| |
| @ops.RegisterGradient("Prod") |
| def _ProdGrad(op, grad): |
| """Gradient for Prod.""" |
| # The gradient can be expressed by dividing the product by each entry of the |
| # input tensor, but this approach can't deal with zeros in the input. |
| # Here, we avoid this problem by composing the output as a product of two |
| # cumprod operations. |
| |
| input_shape = array_ops.shape(op.inputs[0]) |
| # Reshape reduction indices for the case where the parameter is a scalar |
| reduction_indices = array_ops.reshape(op.inputs[1], [-1]) |
| |
| # Expand grad to full input shape |
| output_shape_kept_dims = math_ops.reduced_shape(input_shape, op.inputs[1]) |
| tile_scaling = _safe_shape_div(input_shape, output_shape_kept_dims) |
| grad = array_ops.reshape(grad, output_shape_kept_dims) |
| grad = array_ops.tile(grad, tile_scaling) |
| |
| # Pack all reduced dimensions into a single one, so we can perform the |
| # cumprod ops. If the reduction dims list is empty, the reduction indices |
| # default to float32, so we need to cast to int32 here. We put all the |
| # shape-related ops on CPU to avoid copying back and forth, and because |
| # setdiff1d (listdiff) is CPU only. |
| with ops.device("/cpu:0"): |
| rank = array_ops.rank(op.inputs[0]) |
| reduction_indices = (reduction_indices + rank) % rank |
| reduced = math_ops.cast(reduction_indices, dtypes.int32) |
| idx = math_ops.range(0, rank) |
| other, _ = array_ops.setdiff1d(idx, reduced) |
| perm = array_ops.concat([reduced, other], 0) |
| reduced_num = math_ops.reduce_prod(array_ops.gather(input_shape, reduced)) |
| other_num = math_ops.reduce_prod(array_ops.gather(input_shape, other)) |
| permuted = array_ops.transpose(op.inputs[0], perm) |
| permuted_shape = array_ops.shape(permuted) |
| reshaped = array_ops.reshape(permuted, (reduced_num, other_num)) |
| |
| # Calculate product, leaving out the current entry |
| left = math_ops.cumprod(reshaped, axis=0, exclusive=True) |
| right = math_ops.cumprod(reshaped, axis=0, exclusive=True, reverse=True) |
| # For complex inputs, the gradient is in the conjugate direction. |
| y = array_ops.reshape( |
| math_ops.conj(left) * math_ops.conj(right), permuted_shape) |
| |
| # Invert the transpose and reshape operations. |
| # Make sure to set the statically known shape information through a reshape. |
| out = grad * array_ops.transpose(y, array_ops.invert_permutation(perm)) |
| return array_ops.reshape(out, input_shape), None |
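| |
| # Worked example (illustrative only): for x = [2., 3., 4.] reduced over all |
| # axes, prod(x) = 24. The exclusive cumprods are left = [1., 2., 6.] and |
| # right = [12., 4., 1.], so left * right = [12., 8., 6.], which equals |
| # prod(x) / x without ever dividing by a possibly-zero input entry. |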
| |
| |
| @ops.RegisterGradient("SegmentSum") |
| def _SegmentSumGrad(op, grad): |
| """Gradient for SegmentSum.""" |
| return array_ops.gather(grad, op.inputs[1]), None |
| |
| |
| @ops.RegisterGradient("SegmentMean") |
| def _SegmentMeanGrad(op, grad): |
| """Gradient for SegmentMean.""" |
| input_rank = array_ops.rank(op.inputs[0]) |
| ones_shape = array_ops.concat([ |
| array_ops.shape(op.inputs[1]), |
| array_ops.fill(array_ops.expand_dims(input_rank - 1, 0), 1) |
| ], 0) |
| ones = array_ops.fill(ones_shape, constant_op.constant(1, dtype=grad.dtype)) |
| scaled_grad = math_ops.divide(grad, math_ops.segment_sum(ones, op.inputs[1])) |
| return array_ops.gather(scaled_grad, op.inputs[1]), None |
| |
| |
| @ops.RegisterGradient("SparseSegmentSum") |
| def _SparseSegmentSumGrad(op, grad): |
| """Gradient for SparseSegmentSum.""" |
| input_rows = array_ops.shape(op.inputs[0])[0] |
| return (math_ops.unsorted_segment_sum( |
| array_ops.gather(grad, op.inputs[2]), op.inputs[1], input_rows), None, |
| None) |
| |
| |
| @ops.RegisterGradient("SparseSegmentSumWithNumSegments") |
| def _SparseSegmentSumWithNumSegmentsGrad(op, grad): |
| """Gradient for SparseSegmentSumWithNumSegments.""" |
| input_rows = array_ops.shape(op.inputs[0])[0] |
| return (math_ops.unsorted_segment_sum( |
| array_ops.gather(grad, op.inputs[2]), op.inputs[1], input_rows), None, |
| None, None) |
| |
| |
| @ops.RegisterGradient("SparseSegmentMean") |
| def _SparseSegmentMeanGrad(op, grad): |
| """Gradient for SparseSegmentMean.""" |
| dim0 = array_ops.shape(op.inputs[0])[0] |
| return (math_ops.sparse_segment_mean_grad(grad, op.inputs[1], op.inputs[2], |
| dim0), None, None) |
| |
| |
| @ops.RegisterGradient("SparseSegmentMeanWithNumSegments") |
| def _SparseSegmentMeanWithNumSegmentsGrad(op, grad): |
| """Gradient for SparseSegmentMeanWithNumSegments.""" |
| dim0 = array_ops.shape(op.inputs[0])[0] |
| return (math_ops.sparse_segment_mean_grad(grad, op.inputs[1], op.inputs[2], |
| dim0), None, None, None) |
| |
| |
| @ops.RegisterGradient("SparseSegmentSqrtN") |
| def _SparseSegmentSqrtNGrad(op, grad): |
| """Gradient for SparseSegmentSqrtN.""" |
| dim0 = array_ops.shape(op.inputs[0])[0] |
| return (math_ops.sparse_segment_sqrt_n_grad(grad, op.inputs[1], op.inputs[2], |
| dim0), None, None) |
| |
| |
| @ops.RegisterGradient("SparseSegmentSqrtNWithNumSegments") |
| def _SparseSegmentSqrtNWithNumSegmentsGrad(op, grad): |
| """Gradient for SparseSegmentSqrtNWithNumSegments.""" |
| dim0 = array_ops.shape(op.inputs[0])[0] |
| return (math_ops.sparse_segment_sqrt_n_grad(grad, op.inputs[1], op.inputs[2], |
| dim0), None, None, None) |
| |
| |
| def _SegmentMinOrMaxGrad(op, grad): |
| """ Gradient for SegmentMin and SegmentMax. """ |
| zeros = array_ops.zeros_like(op.inputs[0], dtype=op.inputs[0].dtype) |
| # Get the number of selected (minimum or maximum) elements in each segment. |
| gathered_outputs = array_ops.gather(op.outputs[0], op.inputs[1]) |
| is_selected = math_ops.equal(op.inputs[0], gathered_outputs) |
| num_selected = math_ops.segment_sum( |
| math_ops.cast(is_selected, grad.dtype), op.inputs[1]) |
| # Compute the gradient for each segment. The gradient for the ith segment is |
| # divided evenly among the selected elements in that segment. |
| weighted_grads = math_ops.divide(grad, num_selected) |
| gathered_grads = array_ops.gather(weighted_grads, op.inputs[1]) |
| return array_ops.where(is_selected, gathered_grads, zeros), None |
| |
| |
| @ops.RegisterGradient("SegmentMin") |
| def _SegmentMinGrad(op, grad): |
| """Gradient for SegmentMin.""" |
| return _SegmentMinOrMaxGrad(op, grad) |
| |
| |
| @ops.RegisterGradient("SegmentMax") |
| def _SegmentMaxGrad(op, grad): |
| """Gradient for SegmentMax.""" |
| return _SegmentMinOrMaxGrad(op, grad) |
| |
| |
| def _GatherDropNegatives(params, |
| ids, |
| zero_clipped_indices=None, |
| is_positive=None): |
| """ Helper function for unsorted segment ops. |
| |
| Gathers params for |
| positive segment ids and gathers 0 for inputs with negative segment id. |
| Also returns the clipped indices and a boolean mask with the same shape |
| as ids where a positive id is masked as true. With this, the latter two |
| can be passed as arguments to this function to reuse them. |
| """ |
| if zero_clipped_indices is None: |
| zero_clipped_indices = math_ops.maximum(ids, array_ops.zeros_like(ids)) |
| gathered = array_ops.gather(params, zero_clipped_indices) |
| if is_positive is None: |
| is_positive = math_ops.greater_equal(ids, 0) |
| # tf.where(condition, x, y) requires condition to have the same shape as x |
| # and y. |
| # TODO(philjd): remove this once tf.where supports broadcasting (#9284) |
| for _ in range(gathered.shape.ndims - is_positive.shape.ndims): |
| is_positive = array_ops.expand_dims(is_positive, -1) |
| is_positive = ( |
| is_positive & array_ops.ones_like(gathered, dtype=dtypes.bool)) |
| # replace gathered params of negative indices with 0 |
| zero_slice = array_ops.zeros_like(gathered) |
| return (array_ops.where(is_positive, gathered, zero_slice), |
| zero_clipped_indices, is_positive) |
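| |
| # Illustrative example: for params = [[1., 2.], [3., 4.]] and ids = [1, -1], |
| # the clipped indices are [1, 0] and the returned gather is |
| # [[3., 4.], [0., 0.]], since the entry with a negative id is zeroed out. |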
| |
| |
| def _UnsortedSegmentMinOrMaxGrad(op, grad): |
| """ Gradient for UnsortedSegmentMin and UnsortedSegmentMax. """ |
| # Get the number of selected (minimum or maximum) elements in each segment. |
| gathered_outputs, zero_clipped_indices, is_positive = \ |
| _GatherDropNegatives(op.outputs[0], op.inputs[1]) |
| is_selected = math_ops.equal(op.inputs[0], gathered_outputs) |
| is_selected = math_ops.logical_and(is_selected, is_positive) |
| num_selected = math_ops.unsorted_segment_sum( |
| math_ops.cast(is_selected, grad.dtype), op.inputs[1], op.inputs[2]) |
| # Compute the gradient for each segment. The gradient for the ith segment is |
| # divided evenly among the selected elements in that segment. |
| weighted_grads = math_ops.divide(grad, num_selected) |
| gathered_grads, _, _ = _GatherDropNegatives(weighted_grads, None, |
| zero_clipped_indices, is_positive) |
| zeros = array_ops.zeros_like(gathered_grads) |
| return array_ops.where(is_selected, gathered_grads, zeros), None, None |
| |
| |
| @ops.RegisterGradient("UnsortedSegmentSum") |
| def _UnsortedSegmentSumGrad(op, grad): |
| """Gradient for UnsortedSegmentSum.""" |
| return _GatherDropNegatives(grad, op.inputs[1])[0], None, None |
| |
| |
| @ops.RegisterGradient("UnsortedSegmentMax") |
| def _UnsortedSegmentMaxGrad(op, grad): |
| """ Gradient for UnsortedSegmentMax. """ |
| return _UnsortedSegmentMinOrMaxGrad(op, grad) |
| |
| |
| @ops.RegisterGradient("UnsortedSegmentMin") |
| def _UnsortedSegmentMinGrad(op, grad): |
| """ Gradient for UnsortedSegmentMin. """ |
| return _UnsortedSegmentMinOrMaxGrad(op, grad) |
| |
| |
| @ops.RegisterGradient("UnsortedSegmentProd") |
| def _UnsortedSegmentProdGrad(op, grad): |
| """ Gradient for UnsortedSegmentProd. |
| |
| The gradient can be expressed for each segment by dividing the segment's |
| product by each element of the segment input tensor, but this approach can't |
| deal with zeros in the input. |
| Unlike reduce_prod we can't use cumsum here as individual segments may have |
| a different number of elements. Therefore we consider three cases: |
| 1) A segment input contains no zeros and we can safely divide by the input |
| tensor. |
| 2) A segment contains exactly one zero. Then the gradient of each input of |
| the segment is zero except for the 0-input, there the gradient is |
| the product of the remaining segment entries. |
| 3) A segment contains at least two zeros. The gradient is zero for all |
| segment inputs. |
| """ |
| # Note that unsorted_segment_sum will filter out the negative indices, |
| # so we don't need to do a logical_and with is_positive here |
| is_zero = math_ops.equal(op.inputs[0], 0) |
| num_zeros = gen_math_ops.unsorted_segment_sum( |
| math_ops.cast(is_zero, dtype=dtypes.int32), op.inputs[1], op.inputs[2]) |
| # Handle case 3 and set the gradient to 0 for segments with more than one |
| # zero in the input. |
| grad = array_ops.where( |
| math_ops.greater(num_zeros, 1), array_ops.zeros_like(grad), grad) |
| # replace all zeros with ones and compute the unsorted_segment_prod |
| non_zero_data = array_ops.where(is_zero, array_ops.ones_like(op.inputs[0]), |
| op.inputs[0]) |
| non_zero_prod = gen_math_ops.unsorted_segment_prod(non_zero_data, |
| op.inputs[1], op.inputs[2]) |
| # clip the indices for gather to be positive |
| zero_clipped_indices = math_ops.maximum(op.inputs[1], |
| array_ops.zeros_like(op.inputs[1])) |
| gathered_prod = array_ops.gather(op.outputs[0], zero_clipped_indices) |
| gathered_non_zero_prod = array_ops.gather(non_zero_prod, zero_clipped_indices) |
| prod_divided_by_el = gathered_prod / op.inputs[0] # May contain nan/inf. |
| # Now fetch the individual results for entries that are zero and those that |
| # are not. is_zero will also fetch results for entries with a negative index, |
| # but the following _GatherDropNegatives sets the corresponding entries in |
| # grad to 0 for these. |
| partial_derivative = array_ops.where(is_zero, gathered_non_zero_prod, |
| prod_divided_by_el) |
| gathered_grad = _GatherDropNegatives(grad, op.inputs[1], |
| zero_clipped_indices)[0] |
| return gathered_grad * partial_derivative, None, None |
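| |
| # Worked example for case 2 above (illustrative only): for data = [2., 0., 3.] |
| # with segment_ids = [0, 0, 0] and num_segments = 1, the segment product is 0. |
| # The zero entry receives the product of the remaining entries (2. * 3. = 6.) |
| # while the non-zero entries receive 0, matching the array_ops.where selection |
| # between gathered_non_zero_prod and prod_divided_by_el. |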
| |
| |
| @ops.RegisterGradient("Abs") |
| def _AbsGrad(op, grad): |
| x = op.inputs[0] |
| return grad * math_ops.sign(x) |
| |
| |
| @ops.RegisterGradient("Neg") |
| def _NegGrad(_, grad): |
| """Returns -grad.""" |
| return -grad |
| |
| |
| @ops.RegisterGradient("Inv") |
| def _InvGrad(op, grad): |
| """Returns -grad * (1 / x^2).""" |
| y = op.outputs[0] # y = 1 / x |
| return gen_math_ops.reciprocal_grad(y, grad) |
| |
| |
| @ops.RegisterGradient("Reciprocal") |
| def _ReciprocalGrad(op, grad): |
| """Returns -grad * (1 / x^2).""" |
| y = op.outputs[0] # y = 1 / x |
| return gen_math_ops.reciprocal_grad(y, grad) |
| |
| |
| @ops.RegisterGradient("InvGrad") |
| def _InvGradGrad(op, grad): |
| b = op.inputs[1] |
| # op.outputs[0]: y = -b * conj(a)^2 |
| with ops.control_dependencies([grad]): |
| ca = math_ops.conj(op.inputs[0]) |
| cg = math_ops.conj(grad) |
| return cg * -2.0 * b * ca, gen_math_ops.reciprocal_grad(ca, grad) |
| |
| |
| @ops.RegisterGradient("ReciprocalGrad") |
| def _ReciprocalGradGrad(op, grad): |
| b = op.inputs[1] |
| # op.outputs[0]: y = -b * conj(a)^2 |
| with ops.control_dependencies([grad]): |
| ca = math_ops.conj(op.inputs[0]) |
| cg = math_ops.conj(grad) |
| return cg * -2.0 * b * ca, gen_math_ops.reciprocal_grad(ca, grad) |
| |
| |
| @ops.RegisterGradient("Square") |
| def _SquareGrad(op, grad): |
| x = op.inputs[0] |
| # Added control dependencies to prevent 2*x from being computed too early. |
| with ops.control_dependencies([grad]): |
| x = math_ops.conj(x) |
| y = constant_op.constant(2.0, dtype=x.dtype) |
| return math_ops.multiply(grad, math_ops.multiply(x, y)) |
| |
| |
| @ops.RegisterGradient("Sqrt") |
| def _SqrtGrad(op, grad): |
| y = op.outputs[0] # y = x^(1/2) |
| return gen_math_ops.sqrt_grad(y, grad) |
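| |
| # Note (illustrative): sqrt_grad(y, grad) evaluates to grad * 0.5 / y, the |
| # chain rule for y = sqrt(x) where dy/dx = 0.5 / sqrt(x). For example, at |
| # x = 4. (so y = 2.) an upstream gradient of 1. becomes 0.25. |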
| |
| |
| @ops.RegisterGradient("SqrtGrad") |
| def _SqrtGradGrad(op, grad): |
| a = op.inputs[0] |
| y = op.outputs[0] # y = 0.5 * b / conj(a) |
| with ops.control_dependencies([grad]): |
| if compat.forward_compatible(2019, 9, 14): |
| ga = gen_math_ops.xdivy(grad, a) |
| return -gen_math_ops.mul_no_nan(y, math_ops.conj(ga)), 0.5 * ga |
| else: |
| ga = grad / a |
| return -math_ops.conj(ga) * y, 0.5 * ga |
| |
| |
| @ops.RegisterGradient("Rsqrt") |
| def _RsqrtGrad(op, grad): |
| """Returns -0.5 * grad * conj(y)^3.""" |
| y = op.outputs[0] # y = x^(-1/2) |
| return gen_math_ops.rsqrt_grad(y, grad) |
| |
| |
| @ops.RegisterGradient("RsqrtGrad") |
| def _RsqrtGradGrad(op, grad): |
| """Returns backprop gradient for f(a,b) = -0.5 * b * conj(a)^3.""" |
| a = op.inputs[0] # a = x^{-1/2} |
| b = op.inputs[1] # backprop gradient for a |
| with ops.control_dependencies([grad]): |
| ca = math_ops.conj(a) |
| cg = math_ops.conj(grad) |
| grad_a = -1.5 * cg * b * math_ops.square(ca) |
| grad_b = gen_math_ops.rsqrt_grad(ca, grad) |
| return grad_a, grad_b |
| |
| |
| @ops.RegisterGradient("Exp") |
| def _ExpGrad(op, grad): |
| """Returns grad * exp(x).""" |
| y = op.outputs[0] # y = e^x |
| with ops.control_dependencies([grad]): |
| y = math_ops.conj(y) |
| if compat.forward_compatible(2019, 9, 14): |
| return math_ops.mul_no_nan(y, grad) |
| else: |
| return grad * y |
| |
| |
| @ops.RegisterGradient("Expm1") |
| def _Expm1Grad(op, grad): |
| """Returns grad * exp(x).""" |
| x = op.inputs[0] |
| with ops.control_dependencies([grad]): |
| x = math_ops.conj(x) |
| y = math_ops.exp(x) |
| if compat.forward_compatible(2019, 9, 14): |
| return math_ops.mul_no_nan(y, grad) |
| else: |
| return grad * y |
| |
| |
| @ops.RegisterGradient("Log") |
| def _LogGrad(op, grad): |
| """Returns grad * (1/x).""" |
| x = op.inputs[0] |
| with ops.control_dependencies([grad]): |
| x = math_ops.conj(x) |
| if compat.forward_compatible(2019, 9, 14): |
| return gen_math_ops.xdivy(grad, x) |
| else: |
| return grad * math_ops.reciprocal(x) |
| |
| |
| @ops.RegisterGradient("Log1p") |
| def _Log1pGrad(op, grad): |
| """Returns grad * (1/(1 + x)).""" |
| x = op.inputs[0] |
| with ops.control_dependencies([grad]): |
| x = math_ops.conj(x) |
| if compat.forward_compatible(2019, 9, 14): |
| return gen_math_ops.xdivy(grad, 1 + x) |
| else: |
| return grad * math_ops.reciprocal(1 + x) |
| |
| |
| @ops.RegisterGradient("Xlogy") |
| def _XLogyGrad(op, grad): |
| """Returns gradient of xlogy(x, y) with respect to x and y.""" |
| x = op.inputs[0] |
| y = op.inputs[1] |
| sx = array_ops.shape(x) |
| sy = array_ops.shape(y) |
| rx, ry = gen_array_ops.broadcast_gradient_args(sx, sy) |
| with ops.control_dependencies([grad]): |
| not_zero_x = math_ops.cast( |
| math_ops.not_equal(x, math_ops.cast(0., dtype=x.dtype)), dtype=x.dtype) |
| partial_x = gen_math_ops.xlogy(not_zero_x, y) |
| partial_y = gen_math_ops.xdivy(x, y) |
| return (array_ops.reshape(math_ops.reduce_sum(partial_x * grad, rx), sx), |
| array_ops.reshape(math_ops.reduce_sum(partial_y * grad, ry), sy)) |
| |
| |
| @ops.RegisterGradient("Xdivy") |
| def _XDivyGrad(op, grad): |
| """Returns gradient of xdivy(x, y) with respect to x and y.""" |
| x = op.inputs[0] |
| y = op.inputs[1] |
| sx = array_ops.shape(x) |
| sy = array_ops.shape(y) |
| rx, ry = gen_array_ops.broadcast_gradient_args(sx, sy) |
| with ops.control_dependencies([grad]): |
| not_zero_x = math_ops.cast( |
| math_ops.not_equal(x, math_ops.cast(0., dtype=x.dtype)), dtype=x.dtype) |
| partial_x = gen_math_ops.xdivy(not_zero_x, y) |
| partial_y = gen_math_ops.xdivy(math_ops.negative(x), y**2) |
| return (array_ops.reshape(math_ops.reduce_sum(partial_x * grad, rx), sx), |
| array_ops.reshape(math_ops.reduce_sum(partial_y * grad, ry), sy)) |
| |
| |
| @ops.RegisterGradient("Sinh") |
| def _SinhGrad(op, grad): |
| """Returns grad * cosh(x).""" |
| x = op.inputs[0] |
| with ops.control_dependencies([grad]): |
| x = math_ops.conj(x) |
| return grad * math_ops.cosh(x) |
| |
| |
| @ops.RegisterGradient("Cosh") |
| def _CoshGrad(op, grad): |
| """Returns grad * sinh(x).""" |
| x = op.inputs[0] |
| with ops.control_dependencies([grad]): |
| x = math_ops.conj(x) |
| return grad * math_ops.sinh(x) |
| |
| |
| @ops.RegisterGradient("Tanh") |
| def _TanhGrad(op, grad): |
| """Returns grad * (1 - tanh(x) * tanh(x)).""" |
| y = op.outputs[0] # y = tanh(x) |
| with ops.control_dependencies([grad]): |
| y = math_ops.conj(y) |
| return gen_math_ops.tanh_grad(y, grad) |
| |
| |
| @ops.RegisterGradient("Asinh") |
| def _AsinhGrad(op, grad): |
| """Returns grad * 1/cosh(y).""" |
| y = op.outputs[0] |
| with ops.control_dependencies([grad]): |
| y = math_ops.conj(y) |
| return grad / math_ops.cosh(y) |
| |
| |
| @ops.RegisterGradient("Acosh") |
| def _AcoshGrad(op, grad): |
| """Returns grad * 1/sinh(y).""" |
| y = op.outputs[0] |
| with ops.control_dependencies([grad]): |
| y = math_ops.conj(y) |
| if compat.forward_compatible(2019, 9, 14): |
| return math_ops.xdivy(grad, math_ops.sinh(y)) |
| else: |
| return grad / math_ops.sinh(y) |
| |
| |
| @ops.RegisterGradient("Atanh") |
| def _AtanhGrad(op, grad): |
| """Returns grad * 1/ (1 - x^2).""" |
| x = op.inputs[0] |
| with ops.control_dependencies([grad]): |
| x = math_ops.conj(x) |
| x2 = math_ops.square(x) |
| one = constant_op.constant(1, dtype=grad.dtype) |
| inv = math_ops.reciprocal(math_ops.subtract(one, x2)) |
| return grad * inv |
| |
| |
| @ops.RegisterGradient("TanhGrad") |
| def _TanhGradGrad(op, grad): |
| with ops.control_dependencies([grad]): |
| a = math_ops.conj(op.inputs[0]) |
| b = math_ops.conj(op.inputs[1]) |
| return grad * -2.0 * b * a, gen_math_ops.tanh_grad(a, grad) |
| |
| |
| @ops.RegisterGradient("Erf") |
| def _ErfGrad(op, grad): |
| """Returns grad * 2/sqrt(pi) * exp(-x**2).""" |
| x = op.inputs[0] |
| two_over_root_pi = constant_op.constant(2 / np.sqrt(np.pi), dtype=grad.dtype) |
| with ops.control_dependencies([grad]): |
| x = math_ops.conj(x) |
| return grad * two_over_root_pi * math_ops.exp(-math_ops.square(x)) |
| |
| |
| @ops.RegisterGradient("Erfc") |
| def _ErfcGrad(op, grad): |
| """Returns -grad * 2/sqrt(pi) * exp(-x**2).""" |
| x = op.inputs[0] |
| minus_two_over_root_pi = constant_op.constant( |
| -2 / np.sqrt(np.pi), dtype=grad.dtype) |
| with ops.control_dependencies([grad]): |
| x = math_ops.conj(x) |
| return grad * minus_two_over_root_pi * math_ops.exp(-math_ops.square(x)) |
| |
| |
| @ops.RegisterGradient("Lgamma") |
| def _LgammaGrad(op, grad): |
| """Returns grad * digamma(x).""" |
| x = op.inputs[0] |
| with ops.control_dependencies([grad]): |
| x = math_ops.conj(x) |
| if compat.forward_compatible(2019, 9, 14): |
| return math_ops.mul_no_nan(math_ops.digamma(x), grad) |
| else: |
| return grad * math_ops.digamma(x) |
| |
| |
| @ops.RegisterGradient("Digamma") |
| def _DigammaGrad(op, grad): |
| """Compute gradient of the digamma function with respect to its argument.""" |
| x = op.inputs[0] |
| with ops.control_dependencies([grad]): |
| x = math_ops.conj(x) |
| partial_x = math_ops.polygamma(array_ops.constant(1, dtype=x.dtype), x) |
| if compat.forward_compatible(2019, 9, 14): |
| return math_ops.mul_no_nan(partial_x, grad) |
| else: |
| return grad * partial_x |
| |
| |
| @ops.RegisterGradient("BesselI0e") |
| def _BesselI0eGrad(op, grad): |
| """Compute gradient of bessel_i0e(x) with respect to its argument.""" |
| x = op.inputs[0] |
| y = op.outputs[0] |
| with ops.control_dependencies([grad]): |
| partial_x = (math_ops.bessel_i1e(x) - math_ops.sign(x) * y) |
| if compat.forward_compatible(2019, 9, 14): |
| return math_ops.mul_no_nan(partial_x, grad) |
| else: |
| return grad * partial_x |
| |
| |
| @ops.RegisterGradient("BesselI1e") |
| def _BesselI1eGrad(op, grad): |
| """Compute gradient of bessel_i1e(x) with respect to its argument.""" |
| x = op.inputs[0] |
| y = op.outputs[0] |
| with ops.control_dependencies([grad]): |
| # For x = 0, the correct gradient is 0.5. |
| # However, the main branch gives NaN because of the division by x, so |
| # we impute the gradient manually. |
| # An alternative solution is to express the gradient via bessel_i0e and |
| # bessel_i2e, but the latter is not yet implemented in Eigen. |
| eps = np.finfo(x.dtype.as_numpy_dtype).eps |
| zeros = array_ops.zeros_like(x) |
| x_is_not_tiny = math_ops.abs(x) > eps |
| safe_x = array_ops.where(x_is_not_tiny, x, eps + zeros) |
| dy_dx = math_ops.bessel_i0e(safe_x) - y * ( |
| math_ops.sign(safe_x) + math_ops.reciprocal(safe_x)) |
| dy_dx = array_ops.where(x_is_not_tiny, dy_dx, 0.5 + zeros) |
| if compat.forward_compatible(2019, 9, 14): |
| return math_ops.mul_no_nan(dy_dx, grad) |
| else: |
| return grad * dy_dx |
| |
| |
| @ops.RegisterGradient("Igamma") |
| def _IgammaGrad(op, grad): |
| """Returns gradient of igamma(a, x) with respect to a and x.""" |
| a = op.inputs[0] |
| x = op.inputs[1] |
| sa = array_ops.shape(a) |
| sx = array_ops.shape(x) |
| ra, rx = gen_array_ops.broadcast_gradient_args(sa, sx) |
| |
| with ops.control_dependencies([grad]): |
| partial_a = gen_math_ops.igamma_grad_a(a, x) |
| # Perform operations in log space before summing, because Gamma(a) |
| # and Gamma'(a) can grow large. |
| partial_x = math_ops.exp(-x + (a - 1) * math_ops.log(x) - |
| math_ops.lgamma(a)) |
| if compat.forward_compatible(2019, 9, 14): |
| return (array_ops.reshape( |
| math_ops.reduce_sum(math_ops.mul_no_nan(partial_a, grad), ra), sa), |
| array_ops.reshape( |
| math_ops.reduce_sum(math_ops.mul_no_nan(partial_x, grad), rx), |
| sx)) |
| else: |
| return (array_ops.reshape(math_ops.reduce_sum(partial_a * grad, ra), sa), |
| array_ops.reshape(math_ops.reduce_sum(partial_x * grad, rx), sx)) |
| |
| |
| @ops.RegisterGradient("Igammac") |
| def _IgammacGrad(op, grad): |
| """Returns gradient of igammac(a, x) = 1 - igamma(a, x) w.r.t. a and x.""" |
| igamma_grad_a, igamma_grad_x = _IgammaGrad(op, grad) |
| return (-igamma_grad_a, -igamma_grad_x) |
| |
| |
| @ops.RegisterGradient("Betainc") |
| def _BetaincGrad(op, grad): |
| """Returns gradient of betainc(a, b, x) with respect to x.""" |
| # TODO(ebrevdo): Perhaps add the derivative w.r.t. a, b |
| a, b, x = op.inputs |
| |
| # Two cases: x is a scalar and a/b are same-shaped tensors, or vice |
| # versa; so it's sufficient to check against shape(a). |
| sa = array_ops.shape(a) |
| sx = array_ops.shape(x) |
| _, rx = gen_array_ops.broadcast_gradient_args(sa, sx) |
| |
| # Perform operations in log space before summing, because terms |
| # can grow large. |
| log_beta = ( |
| gen_math_ops.lgamma(a) + gen_math_ops.lgamma(b) - |
| gen_math_ops.lgamma(a + b)) |
| partial_x = math_ops.exp((b - 1) * math_ops.log(1 - x) + |
| (a - 1) * math_ops.log(x) - log_beta) |
| |
| # TODO(b/36815900): Mark None return values as NotImplemented |
| if compat.forward_compatible(2019, 9, 14): |
| return ( |
| None, # da |
| None, # db |
| array_ops.reshape( |
| math_ops.reduce_sum(math_ops.mul_no_nan(partial_x, grad), rx), sx)) |
| else: |
| return ( |
| None, # da |
| None, # db |
| array_ops.reshape(math_ops.reduce_sum(partial_x * grad, rx), sx)) |
| |
| |
| @ops.RegisterGradient("Zeta") |
| def _ZetaGrad(op, grad): |
| """Returns gradient of zeta(x, q) with respect to x and q.""" |
| # TODO(tillahoffmann): Add derivative with respect to x |
| x = op.inputs[0] |
| q = op.inputs[1] |
| # Broadcast gradients |
| sx = array_ops.shape(x) |
| sq = array_ops.shape(q) |
| unused_rx, rq = gen_array_ops.broadcast_gradient_args(sx, sq) |
| # Evaluate gradient |
| with ops.control_dependencies([grad]): |
| x = math_ops.conj(x) |
| q = math_ops.conj(q) |
| partial_q = -x * math_ops.zeta(x + 1, q) |
| # TODO(b/36815900): Mark None return values as NotImplemented |
| if compat.forward_compatible(2019, 9, 14): |
| return (None, |
| array_ops.reshape( |
| math_ops.reduce_sum(math_ops.mul_no_nan(partial_q, grad), rq), |
| sq)) |
| else: |
| return (None, |
| array_ops.reshape(math_ops.reduce_sum(partial_q * grad, rq), sq)) |
| |
| |
| @ops.RegisterGradient("Polygamma") |
| def _PolygammaGrad(op, grad): |
| """Returns gradient of psi(n, x) with respect to n and x.""" |
| # TODO(tillahoffmann): Add derivative with respect to n |
| n = op.inputs[0] |
| x = op.inputs[1] |
| # Broadcast gradients |
| sn = array_ops.shape(n) |
| sx = array_ops.shape(x) |
| unused_rn, rx = gen_array_ops.broadcast_gradient_args(sn, sx) |
| # Evaluate gradient |
| with ops.control_dependencies([grad]): |
| n = math_ops.conj(n) |
| x = math_ops.conj(x) |
| partial_x = math_ops.polygamma(n + 1, x) |
| # TODO(b/36815900): Mark None return values as NotImplemented |
| if compat.forward_compatible(2019, 9, 14): |
| return (None, |
| array_ops.reshape( |
| math_ops.reduce_sum(math_ops.mul_no_nan(partial_x, grad), rx), |
| sx)) |
| else: |
| return (None, |
| array_ops.reshape(math_ops.reduce_sum(partial_x * grad, rx), sx)) |
| |
| |
| @ops.RegisterGradient("Sigmoid") |
| def _SigmoidGrad(op, grad): |
| """Returns grad * sigmoid(x) * (1 - sigmoid(x)).""" |
| y = op.outputs[0] # y = sigmoid(x) |
| with ops.control_dependencies([grad]): |
| y = math_ops.conj(y) |
| return gen_math_ops.sigmoid_grad(y, grad) |
| |
| |
| @ops.RegisterGradient("SigmoidGrad") |
| def _SigmoidGradGrad(op, grad): |
| with ops.control_dependencies([grad]): |
| a = math_ops.conj(op.inputs[0]) |
| b = math_ops.conj(op.inputs[1]) |
| gb = grad * b |
| return gb - 2.0 * gb * a, gen_math_ops.sigmoid_grad(a, grad) |
| |
| |
| @ops.RegisterGradient("Sign") |
| def _SignGrad(op, _): |
| """Returns 0.""" |
| x = op.inputs[0] |
| return array_ops.zeros(array_ops.shape(x), dtype=x.dtype) |
| |
| |
| @ops.RegisterGradient("Sin") |
| def _SinGrad(op, grad): |
| """Returns grad * cos(x).""" |
| x = op.inputs[0] |
| with ops.control_dependencies([grad]): |
| x = math_ops.conj(x) |
| return grad * math_ops.cos(x) |
| |
| |
| @ops.RegisterGradient("Cos") |
| def _CosGrad(op, grad): |
| """Returns grad * -sin(x).""" |
| x = op.inputs[0] |
| with ops.control_dependencies([grad]): |
| x = math_ops.conj(x) |
| return -grad * math_ops.sin(x) |
| |
| |
| @ops.RegisterGradient("Tan") |
| def _TanGrad(op, grad): |
| """Returns grad * 1/sec^2(x).""" |
| x = op.inputs[0] |
| with ops.control_dependencies([grad]): |
| x = math_ops.conj(x) |
| secx = math_ops.reciprocal(math_ops.cos(x)) |
| secx2 = math_ops.square(secx) |
| if compat.forward_compatible(2019, 9, 14): |
| return math_ops.mul_no_nan(secx2, grad) |
| else: |
| return secx2 * grad |
| |
| |
| @ops.RegisterGradient("Asin") |
| def _AsinGrad(op, grad): |
| """Returns grad * 1/sqrt(1-x^2).""" |
| x = op.inputs[0] |
| with ops.control_dependencies([grad]): |
| x = math_ops.conj(x) |
| x2 = math_ops.square(x) |
| one = constant_op.constant(1, dtype=grad.dtype) |
| den = math_ops.sqrt(math_ops.subtract(one, x2)) |
| if compat.forward_compatible(2019, 9, 14): |
| return math_ops.xdivy(grad, den) |
| else: |
| inv = math_ops.reciprocal(den) |
| return grad * inv |
| |
| |
| @ops.RegisterGradient("Acos") |
| def _AcosGrad(op, grad): |
| """Returns grad * -1/sqrt(1-x^2).""" |
| x = op.inputs[0] |
| with ops.control_dependencies([grad]): |
| x = math_ops.conj(x) |
| x2 = math_ops.square(x) |
| one = constant_op.constant(1, dtype=grad.dtype) |
| den = math_ops.sqrt(math_ops.subtract(one, x2)) |
| if compat.forward_compatible(2019, 9, 14): |
| return -math_ops.xdivy(grad, den) |
| else: |
| inv = math_ops.reciprocal(den) |
| return -grad * inv |
| |
| |
| @ops.RegisterGradient("Atan") |
| def _AtanGrad(op, grad): |
| """Returns grad * 1/ (1 + x^2).""" |
| x = op.inputs[0] |
| with ops.control_dependencies([grad]): |
| x = math_ops.conj(x) |
| x2 = math_ops.square(x) |
| one = constant_op.constant(1, dtype=grad.dtype) |
| inv = math_ops.reciprocal(math_ops.add(one, x2)) |
| return grad * inv |
| |
| |
| @ops.RegisterGradient("Atan2") |
| def _Atan2Grad(op, grad): |
| """Returns grad * x / (x^2 + y^2), grad * -y / (x^2 + y^2).""" |
| y = op.inputs[0] |
| x = op.inputs[1] |
| with ops.control_dependencies([grad]): |
| if compat.forward_compatible(2019, 9, 14): |
| grad_inv = math_ops.xdivy(grad, (math_ops.square(x) + math_ops.square(y))) |
| else: |
| grad_inv = grad / (math_ops.square(x) + math_ops.square(y)) |
| return x * grad_inv, -y * grad_inv |
| |
| |
| @ops.RegisterGradient("AddN") |
| def _AddNGrad(op, grad): |
| """Copies the gradient to all inputs.""" |
| # Not broadcasting. |
| return [grad] * len(op.inputs) |
| |
| |
| def _ShapesFullySpecifiedAndEqual(x, y, grad): |
| # pylint: disable=protected-access |
| x_shape = x._shape_tuple() |
| y_shape = y._shape_tuple() |
| grad_shape = grad._shape_tuple() |
| # pylint: enable=protected-access |
| return (x_shape == y_shape and x_shape == grad_shape and |
| x_shape is not None and None not in x_shape) |
| |
| |
| @ops.RegisterGradient("Add") |
| @ops.RegisterGradient("AddV2") |
| def _AddGrad(op, grad): |
| """Gradient for Add.""" |
| y = op.inputs[1] |
| skip_input_indices = None |
| try: |
| skip_input_indices = op.skip_input_indices |
| if skip_input_indices is not None and 1 in skip_input_indices and _IsScalar( |
| y): |
| return grad, None |
| except AttributeError: |
| # No gradient skipping, so do the full gradient computation |
| pass |
| x = op.inputs[0] |
| if (isinstance(grad, ops.Tensor) and |
| _ShapesFullySpecifiedAndEqual(x, y, grad)): |
| return grad, grad |
| sx = array_ops.shape(x) |
| sy = array_ops.shape(y) |
| rx, ry = gen_array_ops.broadcast_gradient_args(sx, sy) |
| if skip_input_indices is not None and 0 in skip_input_indices: |
| gx = None |
| else: |
| gx = array_ops.reshape(math_ops.reduce_sum(grad, rx), sx) |
| if skip_input_indices is not None and 1 in skip_input_indices: |
| gy = None |
| else: |
| gy = array_ops.reshape(math_ops.reduce_sum(grad, ry), sy) |
| return (gx, gy) |
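| |
| # Illustrative sketch of the broadcasting pattern used here and in the other |
| # binary-op gradients below: for x of shape [2, 3] and y of shape [3], |
| # broadcast_gradient_args([2, 3], [3]) yields rx = [] and ry = [0], so gx is |
| # the upstream [2, 3] gradient unchanged while gy sums it over axis 0 and |
| # reshapes the result to [3]. |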
| |
| |
| @ops.RegisterGradient("Sub") |
| def _SubGrad(op, grad): |
| """Gradient for Sub.""" |
| x = op.inputs[0] |
| y = op.inputs[1] |
| if (isinstance(grad, ops.Tensor) and |
| _ShapesFullySpecifiedAndEqual(x, y, grad)): |
| return grad, -grad |
| sx = array_ops.shape(x) |
| sy = array_ops.shape(y) |
| rx, ry = gen_array_ops.broadcast_gradient_args(sx, sy) |
| return (array_ops.reshape(math_ops.reduce_sum(grad, rx), sx), |
| array_ops.reshape(-math_ops.reduce_sum(grad, ry), sy)) |
| |
| |
| @ops.RegisterGradient("Mul") |
| def _MulGrad(op, grad): |
| """The gradient of scalar multiplication.""" |
| x = op.inputs[0] |
| y = op.inputs[1] |
| if (isinstance(grad, ops.Tensor) and |
| _ShapesFullySpecifiedAndEqual(x, y, grad) and |
| grad.dtype in (dtypes.int32, dtypes.float32)): |
| return gen_math_ops.mul(grad, y), gen_math_ops.mul(grad, x) |
| assert x.dtype.base_dtype == y.dtype.base_dtype, (x.dtype, " vs. ", y.dtype) |
| sx = array_ops.shape(x) |
| sy = array_ops.shape(y) |
| rx, ry = gen_array_ops.broadcast_gradient_args(sx, sy) |
| x = math_ops.conj(x) |
| y = math_ops.conj(y) |
| return (array_ops.reshape( |
| math_ops.reduce_sum(gen_math_ops.mul(grad, y), rx), sx), |
| array_ops.reshape( |
| math_ops.reduce_sum(gen_math_ops.mul(x, grad), ry), sy)) |
| |
| |
| @ops.RegisterGradient("MulNoNan") |
| def _MulNoNanGrad(op, grad): |
| """The gradient of scalar multiplication with NaN-suppression.""" |
| x = op.inputs[0] |
| y = op.inputs[1] |
| if (isinstance(grad, ops.Tensor) and |
| _ShapesFullySpecifiedAndEqual(x, y, grad)): |
| return gen_math_ops.mul_no_nan(grad, y), gen_math_ops.mul_no_nan(x, grad) |
| assert x.dtype.base_dtype == y.dtype.base_dtype, (x.dtype, " vs. ", y.dtype) |
| sx = array_ops.shape(x) |
| sy = array_ops.shape(y) |
| rx, ry = gen_array_ops.broadcast_gradient_args(sx, sy) |
| return (array_ops.reshape( |
| math_ops.reduce_sum(gen_math_ops.mul_no_nan(grad, y), rx), sx), |
| array_ops.reshape( |
| math_ops.reduce_sum(gen_math_ops.mul_no_nan(x, grad), ry), sy)) |
| |
| |
| @ops.RegisterGradient("Div") |
| def _DivGrad(op, grad): |
| """The gradient for the Div operator.""" |
| x = op.inputs[0] |
| y = op.inputs[1] |
| sx = array_ops.shape(x) |
| sy = array_ops.shape(y) |
| rx, ry = gen_array_ops.broadcast_gradient_args(sx, sy) |
| x = math_ops.conj(x) |
| y = math_ops.conj(y) |
| if compat.forward_compatible(2019, 9, 14): |
| return (array_ops.reshape( |
| math_ops.reduce_sum(math_ops.xdivy(grad, y), rx), sx), |
| array_ops.reshape( |
| math_ops.reduce_sum( |
| math_ops.mul_no_nan( |
| math_ops.divide(math_ops.divide(-x, y), y), grad), ry), |
| sy)) |
| else: |
| return (array_ops.reshape( |
| math_ops.reduce_sum(math_ops.divide(grad, y), rx), sx), |
| array_ops.reshape( |
| math_ops.reduce_sum( |
| grad * math_ops.divide(math_ops.divide(-x, y), y), ry), sy)) |
| |
| |
| @ops.RegisterGradient("FloorDiv") |
| def _FloorDivGrad(_, unused_grad): |
| """The gradient for the FloorDiv operator.""" |
| return None, None |
| |
| |
| @ops.RegisterGradient("FloorMod") |
| def _FloorModGrad(op, grad): |
| """Returns grad * (1, -floor(x/y)).""" |
| x = math_ops.conj(op.inputs[0]) |
| y = math_ops.conj(op.inputs[1]) |
| |
| sx = array_ops.shape(x) |
| sy = array_ops.shape(y) |
| rx, ry = gen_array_ops.broadcast_gradient_args(sx, sy) |
| floor_xy = math_ops.floor_div(x, y) |
| gx = array_ops.reshape(math_ops.reduce_sum(grad, rx), sx) |
| gy = array_ops.reshape( |
| math_ops.reduce_sum(grad * math_ops.negative(floor_xy), ry), sy) |
| return gx, gy |
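| |
| # Illustrative note: treating floor(x / y) as locally constant, |
| # x mod y = x - floor(x / y) * y, so the partials are 1 w.r.t. x and |
| # -floor(x / y) w.r.t. y; e.g. for x = 7. and y = 2., gy = grad * -3. |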
| |
| |
| @ops.RegisterGradient("TruncateDiv") |
| def _TruncateDivGrad(_, unused_grad): |
| return None, None |
| |
| |
| @ops.RegisterGradient("RealDiv") |
| def _RealDivGrad(op, grad): |
| """RealDiv op gradient.""" |
| x = op.inputs[0] |
| y = op.inputs[1] |
| sx = array_ops.shape(x) |
| sy = array_ops.shape(y) |
| rx, ry = gen_array_ops.broadcast_gradient_args(sx, sy) |
| x = math_ops.conj(x) |
| y = math_ops.conj(y) |
| if compat.forward_compatible(2019, 9, 14): |
| return (array_ops.reshape( |
| math_ops.reduce_sum(math_ops.xdivy(grad, y), rx), sx), |
| array_ops.reshape( |
| math_ops.reduce_sum( |
| math_ops.mul_no_nan( |
| math_ops.realdiv(math_ops.realdiv(-x, y), y), grad), |
| ry), sy)) |
| else: |
| return (array_ops.reshape( |
| math_ops.reduce_sum(math_ops.realdiv(grad, y), rx), sx), |
| array_ops.reshape( |
| math_ops.reduce_sum( |
| grad * math_ops.realdiv(math_ops.realdiv(-x, y), y), ry), |
| sy)) |
| |
| |
| @ops.RegisterGradient("DivNoNan") |
| def _DivNoNanGrad(op, grad): |
| """DivNoNan op gradient.""" |
| x = op.inputs[0] |
| y = op.inputs[1] |
| sx = array_ops.shape(x) |
| sy = array_ops.shape(y) |
| rx, ry = gen_array_ops.broadcast_gradient_args(sx, sy) |
| x = math_ops.conj(x) |
| y = math_ops.conj(y) |
| if compat.forward_compatible(2019, 9, 14): |
| return (array_ops.reshape( |
| math_ops.reduce_sum(math_ops.div_no_nan(grad, y), rx), sx), |
| array_ops.reshape( |
| math_ops.reduce_sum( |
| math_ops.mul_no_nan( |
| math_ops.div_no_nan(math_ops.div_no_nan(-x, y), y), |
| grad), ry), sy)) |
| else: |
| return (array_ops.reshape( |
| math_ops.reduce_sum(math_ops.div_no_nan(grad, y), rx), sx), |
| array_ops.reshape( |
| math_ops.reduce_sum( |
| grad * math_ops.div_no_nan(math_ops.div_no_nan(-x, y), y), |
| ry), sy)) |
| |
| |
| @ops.RegisterGradient("Pow") |
| def _PowGrad(op, grad): |
| """Returns grad * (y*x^(y-1), z*log(x)).""" |
| x = op.inputs[0] |
| y = op.inputs[1] |
| z = op.outputs[0] |
| sx = array_ops.shape(x) |
| sy = array_ops.shape(y) |
| rx, ry = gen_array_ops.broadcast_gradient_args(sx, sy) |
| x = math_ops.conj(x) |
| y = math_ops.conj(y) |
| z = math_ops.conj(z) |
| |
| if compat.forward_compatible(2019, 9, 14): |
| gx = array_ops.reshape( |
| math_ops.reduce_sum( |
| gen_math_ops.mul_no_nan(y * math_ops.pow(x, y - 1), grad), rx), sx) |
| else: |
| gx = array_ops.reshape( |
| math_ops.reduce_sum(grad * y * math_ops.pow(x, y - 1), rx), sx) |
| # Avoid false singularity at x = 0 |
| if x.dtype.is_complex: |
| # real(x) < 0 is fine for the complex case |
| mask = math_ops.not_equal(x, 0) |
| else: |
| # There's no sensible real value to return if x < 0, so return 0 |
| mask = x > 0 |
| safe_x = array_ops.where(mask, x, array_ops.ones_like(x)) |
| log_x = array_ops.where(mask, math_ops.log(safe_x), array_ops.zeros_like(x)) |
| if compat.forward_compatible(2019, 9, 14): |
| gy = array_ops.reshape( |
| math_ops.reduce_sum(gen_math_ops.mul_no_nan(z * log_x, grad), ry), sy) |
| else: |
| gy = array_ops.reshape(math_ops.reduce_sum(grad * z * log_x, ry), sy) |
| return gx, gy |
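| |
| # Worked example (illustrative only): for z = x ** y with x = 2. and y = 3., |
| # gx = grad * y * x**(y - 1) = grad * 12. and gy = grad * z * log(x) = |
| # grad * 8. * log(2.). For real x <= 0 the log term is masked to 0 above, so |
| # gy is 0 instead of NaN at that singularity. |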
| |
| |
| def _MaximumMinimumGradInputOnly(op, grad, selector_op): |
| x = op.inputs[0] |
| y = op.inputs[1] |
| zeros = array_ops.zeros_like(grad) |
| xmask = selector_op(x, y) |
| xgrad = array_ops.where(xmask, grad, zeros) |
| ygrad = None # Return None for ygrad since the config allows that. |
| return (xgrad, ygrad) |
| |
| |
| def _MaximumMinimumGrad(op, grad, selector_op): |
| """Factor out the code for the gradient of Maximum or Minimum.""" |
| y = op.inputs[1] |
| skip_input_indices = None |
| try: |
| skip_input_indices = op.skip_input_indices |
| if skip_input_indices is not None and 1 in skip_input_indices and _IsScalar( |
| y): |
| # When we want to get gradients for the first input only, and the second |
| # input tensor is a scalar, we can do a much simpler calculation |
| return _MaximumMinimumGradInputOnly(op, grad, selector_op) |
| except AttributeError: |
| # No gradient skipping, so do the full gradient computation |
| pass |
| x = op.inputs[0] |
| gdtype = grad.dtype |
| sx = array_ops.shape(x) |
| sy = array_ops.shape(y) |
| gradshape = array_ops.shape(grad) |
| zeros = array_ops.zeros(gradshape, gdtype) |
| xmask = selector_op(x, y) |
| rx, ry = gen_array_ops.broadcast_gradient_args(sx, sy) |
| if skip_input_indices is not None and 0 in skip_input_indices: |
| gx = None |
| else: |
| xgrad = array_ops.where(xmask, grad, zeros) |
| gx = array_ops.reshape(math_ops.reduce_sum(xgrad, rx), sx) |
| |
| if skip_input_indices is not None and 1 in skip_input_indices: |
| gy = None |
| else: |
| ygrad = array_ops.where(xmask, zeros, grad) |
| gy = array_ops.reshape(math_ops.reduce_sum(ygrad, ry), sy) |
| |
| return (gx, gy) |
| |
| |
| @ops.RegisterGradient("Maximum") |
| def _MaximumGrad(op, grad): |
| """Returns grad*(x > y, x <= y) with type of grad.""" |
| return _MaximumMinimumGrad(op, grad, math_ops.greater_equal) |
| |
| |
| @ops.RegisterGradient("Minimum") |
| def _MinimumGrad(op, grad): |
| """Returns grad*(x < y, x >= y) with type of grad.""" |
| return _MaximumMinimumGrad(op, grad, math_ops.less_equal) |
| |
| |
| @ops.RegisterGradient("SquaredDifference") |
| def _SquaredDifferenceGrad(op, grad): |
| """Returns the gradient for (x-y)^2.""" |
| x = op.inputs[0] |
| y = op.inputs[1] |
| sx = array_ops.shape(x) |
| sy = array_ops.shape(y) |
| rx, ry = gen_array_ops.broadcast_gradient_args(sx, sy) |
| with ops.control_dependencies([grad]): |
| # The parens ensure that if grad is IndexedSlices, it'll get multiplied by a |
| # Tensor (not a Python scalar like 2.0), which causes it to be converted to |
| # a Tensor. |
| x_grad = math_ops.scalar_mul(2.0, grad) * (x - y) |
| return (array_ops.reshape(math_ops.reduce_sum(x_grad, rx), sx), |
| -array_ops.reshape(math_ops.reduce_sum(x_grad, ry), sy)) |
| |
| |
| # Logical operations have no gradients. |
| ops.NotDifferentiable("Less") |
| ops.NotDifferentiable("LessEqual") |
| ops.NotDifferentiable("Greater") |
| ops.NotDifferentiable("GreaterEqual") |
| ops.NotDifferentiable("Equal") |
| ops.NotDifferentiable("ApproximateEqual") |
| ops.NotDifferentiable("NotEqual") |
| ops.NotDifferentiable("LogicalAnd") |
| ops.NotDifferentiable("LogicalOr") |
| ops.NotDifferentiable("LogicalNot") |
| |
| |
| @ops.RegisterGradient("Select") |
| def _SelectGrad(op, grad): |
| c = op.inputs[0] |
| x = op.inputs[1] |
| zeros = array_ops.zeros_like(x) |
| return (None, array_ops.where(c, grad, zeros), array_ops.where( |
| c, zeros, grad)) |
| |
| |
| @ops.RegisterGradient("SelectV2") |
| def _SelectGradV2(op, grad): |
| c = op.inputs[0] |
| x = op.inputs[1] |
| y = op.inputs[2] |
| zeros = array_ops.zeros([], dtype=grad.dtype.base_dtype) |
| gx = array_ops.where_v2(c, grad, zeros) |
| x_shape = array_ops.shape(x) |
| output_shape = array_ops.shape(op.outputs[0]) |
| # Reduce away broadcasted leading dims. |
| reduce_x, _ = gen_array_ops.broadcast_gradient_args(x_shape, output_shape) |
| gx = math_ops.reduce_sum(gx, keepdims=True, axis=reduce_x) |
| gx = array_ops.reshape(gx, x_shape) |
| |
| gy = array_ops.where_v2(c, zeros, grad) |
| y_shape = array_ops.shape(y) |
| # Reduce away broadcasted leading dims. |
| reduce_y, _ = gen_array_ops.broadcast_gradient_args(y_shape, output_shape) |
| gy = math_ops.reduce_sum(gy, keepdims=True, axis=reduce_y) |
| gy = array_ops.reshape(gy, y_shape) |
| |
| return (None, gx, gy) |
| |
| |
| def _MatMulGradAgainstFirstOnly(op, grad): |
| """Gradient for MatMul, only for the first input.""" |
| t_a = op.get_attr("transpose_a") |
| t_b = op.get_attr("transpose_b") |
| b = math_ops.conj(op.inputs[1]) |
| if not t_a and not t_b: |
| grad_a = gen_math_ops.mat_mul(grad, b, transpose_b=True) |
| elif not t_a and t_b: |
| grad_a = gen_math_ops.mat_mul(grad, b) |
| elif t_a and not t_b: |
| grad_a = gen_math_ops.mat_mul(b, grad, transpose_b=True) |
| elif t_a and t_b: |
| grad_a = gen_math_ops.mat_mul(b, grad, transpose_a=True, transpose_b=True) |
| return grad_a, None |
| |
| |
| def _MatMulGradAgainstSecondOnly(op, grad): |
| """Gradient for MatMul, only for the second input.""" |
| t_a = op.get_attr("transpose_a") |
| t_b = op.get_attr("transpose_b") |
| a = math_ops.conj(op.inputs[0]) |
| if not t_a and not t_b: |
| grad_b = gen_math_ops.mat_mul(a, grad, transpose_a=True) |
| elif not t_a and t_b: |
| grad_b = gen_math_ops.mat_mul(grad, a, transpose_a=True) |
| elif t_a and not t_b: |
| grad_b = gen_math_ops.mat_mul(a, grad) |
| elif t_a and t_b: |
| grad_b = gen_math_ops.mat_mul(grad, a, transpose_a=True, transpose_b=True) |
| return None, grad_b |
| |
| |
| @ops.RegisterGradient("MatMul") |
| def _MatMulGrad(op, grad): |
| """Gradient for MatMul.""" |
| try: |
| skip_input_indices = op.skip_input_indices |
| if skip_input_indices is not None: |
| if 1 in skip_input_indices: |
| return _MatMulGradAgainstFirstOnly(op, grad) |
| elif 0 in skip_input_indices: |
| return _MatMulGradAgainstSecondOnly(op, grad) |
| except AttributeError: |
| # No gradient skipping, so do the full gradient computation |
| pass |
| |
| t_a = op.get_attr("transpose_a") |
| t_b = op.get_attr("transpose_b") |
| a = math_ops.conj(op.inputs[0]) |
| b = math_ops.conj(op.inputs[1]) |
| if not t_a and not t_b: |
| grad_a = gen_math_ops.mat_mul(grad, b, transpose_b=True) |
| grad_b = gen_math_ops.mat_mul(a, grad, transpose_a=True) |
| elif not t_a and t_b: |
| grad_a = gen_math_ops.mat_mul(grad, b) |
| grad_b = gen_math_ops.mat_mul(grad, a, transpose_a=True) |
| elif t_a and not t_b: |
| grad_a = gen_math_ops.mat_mul(b, grad, transpose_b=True) |
| grad_b = gen_math_ops.mat_mul(a, grad) |
| elif t_a and t_b: |
| grad_a = gen_math_ops.mat_mul(b, grad, transpose_a=True, transpose_b=True) |
| grad_b = gen_math_ops.mat_mul(grad, a, transpose_a=True, transpose_b=True) |
| return grad_a, grad_b |
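| |
| # Derivation sketch (illustrative only): with C = matmul(A, B) and upstream |
| # gradient dC, the non-transposed case uses dA = matmul(dC, B^T) and |
| # dB = matmul(A^T, dC); the other branches apply the same identities with the |
| # transpose attributes folded in so no separate Transpose op is required. |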
| |
| |
| @ops.RegisterGradient("SparseMatMul") |
| def _SparseMatMulGrad(op, grad): |
| """Gradient for SparseMatMul.""" |
| |
| t_a = op.get_attr("transpose_a") |
| t_b = op.get_attr("transpose_b") |
| is_sparse = { |
| op.inputs[0]: op.get_attr("a_is_sparse"), |
| op.inputs[1]: op.get_attr("b_is_sparse"), |
| # Use heuristic to figure out if grad might be sparse |
| grad: not context.executing_eagerly() and (grad.op.type == "ReluGrad") |
| } |
| |
| def _SparseMatMul(t1, t2, out_dtype, transpose_a=False, transpose_b=False): |
| """Helper function to create SparseMatMul op.""" |
| |
| assert t1 in is_sparse and t2 in is_sparse |
| t1_sparse = is_sparse[t1] |
| t2_sparse = is_sparse[t2] |
| if transpose_b: |
| t2 = array_ops.transpose(t2) |
| transpose_b = False |
| prod = math_ops.matmul( |
| t1, |
| t2, |
| transpose_a=transpose_a, |
| transpose_b=transpose_b, |
| a_is_sparse=t1_sparse, |
| b_is_sparse=t2_sparse) |
| if prod.dtype != out_dtype: |
| prod = math_ops.cast(prod, out_dtype) |
| return prod |
| |
| dtype_a = op.inputs[0].dtype |
| dtype_b = op.inputs[1].dtype |
| if not t_a and not t_b: |
| return (_SparseMatMul(grad, op.inputs[1], dtype_a, transpose_b=True), |
| _SparseMatMul(op.inputs[0], grad, dtype_b, transpose_a=True)) |
| elif not t_a and t_b: |
| return (_SparseMatMul(grad, op.inputs[1], dtype_a), |
| _SparseMatMul(grad, op.inputs[0], dtype_b, transpose_a=True)) |
| elif t_a and not t_b: |
| return (_SparseMatMul(op.inputs[1], grad, dtype_a, transpose_b=True), |
| _SparseMatMul(op.inputs[0], grad, dtype_b)) |
| elif t_a and t_b: |
| return (_SparseMatMul( |
| op.inputs[1], grad, dtype_a, transpose_a=True, transpose_b=True), |
| _SparseMatMul( |
| grad, op.inputs[0], dtype_b, transpose_a=True, |
| transpose_b=True)) |
| |
| |
| @ops.RegisterGradient("Floor") |
| def _FloorGrad(_, unused_grad): |
| return [None] |
| |
| |
| @ops.RegisterGradient("Ceil") |
| def _CeilGrad(_, unused_grad): |
| return [None] |
| |
| |
| @ops.RegisterGradient("Round") |
| def _RoundGrad(_, unused_grad): |
| return [None] |
| |
| |
| @ops.RegisterGradient("Rint") |
| def _RintGrad(_, unused_grad): |
| # The gradient of Rint is zero. |
| return [None] |
| |
| |
| @ops.RegisterGradient("BatchMatMul") |
| def _BatchMatMul(op, grad): |
| """Returns the gradient of x and y given the gradient of x * y.""" |
| x = op.inputs[0] |
| y = op.inputs[1] |
| adj_x = op.get_attr("adj_x") |
| adj_y = op.get_attr("adj_y") |
| |
| if not adj_x: |
| if not adj_y: |
| grad_x = math_ops.matmul(grad, y, adjoint_a=False, adjoint_b=True) |
| grad_y = math_ops.matmul(x, grad, adjoint_a=True, adjoint_b=False) |
| else: |
| grad_x = math_ops.matmul(grad, y, adjoint_a=False, adjoint_b=False) |
| grad_y = math_ops.matmul(grad, x, adjoint_a=True, adjoint_b=False) |
| else: |
| if not adj_y: |
| grad_x = math_ops.matmul(y, grad, adjoint_a=False, adjoint_b=True) |
| grad_y = math_ops.matmul(x, grad, adjoint_a=False, adjoint_b=False) |
| else: |
| grad_x = math_ops.matmul(y, grad, adjoint_a=True, adjoint_b=True) |
| grad_y = math_ops.matmul(grad, x, adjoint_a=True, adjoint_b=True) |
| |
| return grad_x, grad_y |
| |
| |
| @ops.RegisterGradient("BatchMatMulV2") |
| def _BatchMatMulV2(op, grad): |
| """Returns the gradient of x and y given the gradient of x * y.""" |
| x = op.inputs[0] |
| y = op.inputs[1] |
| adj_x = op.get_attr("adj_x") |
| adj_y = op.get_attr("adj_y") |
| |
| if not adj_x: |
| if not adj_y: |
| grad_x = math_ops.matmul(grad, y, adjoint_a=False, adjoint_b=True) |
| grad_y = math_ops.matmul(x, grad, adjoint_a=True, adjoint_b=False) |
| else: |
| grad_x = math_ops.matmul(grad, y, adjoint_a=False, adjoint_b=False) |
| grad_y = math_ops.matmul(grad, x, adjoint_a=True, adjoint_b=False) |
| else: |
| if not adj_y: |
| grad_x = math_ops.matmul(y, grad, adjoint_a=False, adjoint_b=True) |
| grad_y = math_ops.matmul(x, grad, adjoint_a=False, adjoint_b=False) |
| else: |
| grad_x = math_ops.matmul(y, grad, adjoint_a=True, adjoint_b=True) |
| grad_y = math_ops.matmul(grad, x, adjoint_a=True, adjoint_b=True) |
| |
| # Reduce along the broadcasted batch dimensions, if broadcasting is required. |
| shape_x_static = x.get_shape() |
| shape_y_static = y.get_shape() |
| if not (shape_x_static.is_fully_defined() and |
| shape_y_static.is_fully_defined() and |
| shape_x_static == shape_y_static): |
| sx = array_ops.shape(x) |
| sy = array_ops.shape(y) |
| rx, ry = gen_array_ops.broadcast_gradient_args(sx[:-2], sy[:-2]) |
| grad_x = array_ops.reshape(math_ops.reduce_sum(grad_x, rx), sx) |
| grad_y = array_ops.reshape(math_ops.reduce_sum(grad_y, ry), sy) |
| |
| return grad_x, grad_y |
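| |
| # Illustrative example of the batch-broadcast reduction above: for x of shape |
| # [1, 2, 3] and y of shape [5, 3, 4], the raw grad_x has shape [5, 2, 3]; |
| # broadcast_gradient_args([1], [5]) marks batch axis 0 of x for reduction, so |
| # grad_x is summed over that axis and reshaped back to [1, 2, 3]. |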
| |
| |
| ops.NotDifferentiable("Range") |
| ops.NotDifferentiable("LinSpace") |
| |
| |
| @ops.RegisterGradient("Complex") |
| def _ComplexGrad(op, grad): |
| """Returns the real and imaginary components of 'grad', respectively.""" |
| x = op.inputs[0] |
| y = op.inputs[1] |
| sx = array_ops.shape(x) |
| sy = array_ops.shape(y) |
| rx, ry = gen_array_ops.broadcast_gradient_args(sx, sy) |
| return (array_ops.reshape(math_ops.reduce_sum(math_ops.real(grad), rx), sx), |
| array_ops.reshape(math_ops.reduce_sum(math_ops.imag(grad), ry), sy)) |
| |
| |
| @ops.RegisterGradient("Real") |
| def _RealGrad(_, grad): |
| """Returns 'grad' as the real part and set the imaginary part 0.""" |
| zero = constant_op.constant(0, dtype=grad.dtype) |
| return math_ops.complex(grad, zero) |
| |
| |
| @ops.RegisterGradient("Imag") |
| def _ImagGrad(_, grad): |
| """Returns 'grad' as the imaginary part and set the real part 0.""" |
| zero = constant_op.constant(0, dtype=grad.dtype) |
| return math_ops.complex(zero, grad) |
| |
| |
| @ops.RegisterGradient("Angle") |
| def _AngleGrad(op, grad): |
| """Returns -grad / (Im(x) + iRe(x))""" |
| x = op.inputs[0] |
| with ops.control_dependencies([grad]): |
| re = math_ops.real(x) |
| im = math_ops.imag(x) |
| z = math_ops.reciprocal(math_ops.complex(im, re)) |
| zero = constant_op.constant(0, dtype=grad.dtype) |
| complex_grad = math_ops.complex(grad, zero) |
| return -complex_grad * z |
| |
| |
| @ops.RegisterGradient("Conj") |
| def _ConjGrad(_, grad): |
| """Returns the complex conjugate of grad.""" |
| return math_ops.conj(grad) |
| |
| |
| @ops.RegisterGradient("ComplexAbs") |
| def _ComplexAbsGrad(op, grad): |
| """Returns the gradient of ComplexAbs.""" |
| return math_ops.div_no_nan( |
| math_ops.complex( |
| grad, array_ops.zeros_like(grad)) * op.inputs[0], |
| math_ops.complex( |
| op.outputs[0], array_ops.zeros_like(op.outputs[0]))) |
| |
| |
| @ops.RegisterGradient("Cast") |
| def _CastGrad(op, grad): |
| t = [ |
| dtypes.float16, dtypes.float32, dtypes.float64, dtypes.bfloat16, |
| dtypes.complex64, dtypes.complex128 |
| ] |
| src_type = op.inputs[0].dtype.base_dtype |
| dst_type = grad.dtype.base_dtype |
| if src_type in t and dst_type in t: |
| return math_ops.cast(grad, src_type) |
| else: |
| return None |
| |
| |
| @ops.RegisterGradient("Cross") |
| def _CrossGrad(op, grad): |
| u = op.inputs[0] |
| v = op.inputs[1] |
| return (math_ops.cross(v, grad), math_ops.cross(grad, u)) |
| |
| |
| @ops.RegisterGradient("Cumsum") |
| def _CumsumGrad(op, grad): |
| axis = op.inputs[1] |
| exclusive = op.get_attr("exclusive") |
| reverse = op.get_attr("reverse") |
| return [ |
| math_ops.cumsum(grad, axis, exclusive=exclusive, reverse=not reverse), |
| None |
| ] |
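| |
| # Illustrative example: for y = cumsum(x) with x = [a, b, c], |
| # y = [a, a + b, a + b + c], so dL/dx_j sums the upstream gradients at |
| # positions >= j, which is exactly a reversed cumsum of grad (with the |
| # exclusive attribute carried over unchanged). |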
| |
| |
| @ops.RegisterGradient("Cumprod") |
| def _CumprodGrad(op, grad): |
| x = op.inputs[0] |
| axis = op.inputs[1] |
| exclusive = op.get_attr("exclusive") |
| reverse = op.get_attr("reverse") |
| |
| # TODO: This fails when x contains 0 and should be fixed. |
| prod = math_ops.cumprod(x, axis, exclusive=exclusive, reverse=reverse) |
| out = math_ops.cumsum( |
| prod * grad, axis, exclusive=exclusive, reverse=not reverse) |
| return [out / x, None] |
| |
| |
| @ops.RegisterGradient("CumulativeLogsumexp") |
| def _CumulativeLogsumexpGrad(op, grad): |
| x = op.inputs[0] |
| axis = op.inputs[1] |
| cumulative_logsumexp = op.outputs[0] |
| |
| exclusive = op.get_attr("exclusive") |
| reverse = op.get_attr("reverse") |
| |
| # Split the incoming gradient into positive and negative part |
| # in order to take logs. This is required for stable results. |
| log_grad_positive = array_ops.where_v2( |
| math_ops.greater(grad, 0), |
| math_ops.log(grad), |
| grad.dtype.min) |
| |
| log_grad_negative = array_ops.where_v2( |
| math_ops.less(grad, 0), |
| math_ops.log(-grad), |
| grad.dtype.min) |
| |
| output_pos = math_ops.exp( |
| math_ops.cumulative_logsumexp( |
| log_grad_positive - cumulative_logsumexp, |
| axis=axis, reverse=not reverse, exclusive=exclusive) + x) |
| |
| output_neg = math_ops.exp( |
| math_ops.cumulative_logsumexp( |
| log_grad_negative - cumulative_logsumexp, |
| axis=axis, reverse=not reverse, exclusive=exclusive) + x) |
| |
| return [output_pos - output_neg, None] |
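| |
| # Derivation sketch (illustrative only): for y_i = log(sum_{j <= i} exp(x_j)), |
| # dy_i/dx_j = exp(x_j - y_i) for j <= i, so dL/dx_j = |
| # sum_{i >= j} grad_i * exp(x_j - y_i). For positive grad this equals |
| # exp(x_j + reverse_cumulative_logsumexp(log(grad) - y)_j), which is what |
| # output_pos computes; negative upstream gradients are handled symmetrically |
| # by output_neg. |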
| |
| |
| @ops.RegisterGradient("NextAfter") |
| def _NextAfterGrad(op, grad): |
| """Returns gradient of nextafter(x1, x2) with respect to x1 and x2.""" |
| x1 = op.inputs[0] |
| x2 = op.inputs[1] |
| s_x1 = array_ops.shape(x1) |
| s_x2 = array_ops.shape(x2) |
| r_x1, r_x2 = gen_array_ops.broadcast_gradient_args(s_x1, s_x2) |
| with ops.control_dependencies([grad]): |
| partial_x1 = array_ops.ones(s_x1, dtype=x1.dtype) |
| partial_x2 = array_ops.zeros(s_x2, dtype=x2.dtype) |
| return (array_ops.reshape( |
| math_ops.reduce_sum(partial_x1 * grad, r_x1), s_x1), |
| array_ops.reshape( |
| math_ops.reduce_sum(partial_x2 * grad, r_x2), s_x2)) |