# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
| """Gradients for operators defined in linalg_ops.py. |
| |
| Useful reference for derivative formulas is (Mike Giles, 2008). |
| |
| Ionescu et al. (2015) provide a detailed derivation of formulas for |
| backpropagating through spectral layers (SVD and Eig). |
| |
| References: |
| An extended collection of matrix derivative results for |
| forward and reverse mode automatic differentiation: |
| [Mike Giles, 2008] |
| (https://ora.ox.ac.uk/objects/uuid:8d0c0a29-c92b-4153-a1d2-38b276e93124) |
| ([pdf](http://eprints.maths.ox.ac.uk/1079/1/NA-08-01.pdf)) |
| Matrix Backpropagation for Deep Networks with Structured Layers |
| [Ionescu et al., 2015] |
| (https://www.cv-foundation.org/openaccess/content_iccv_2015/html/Ionescu_Matrix_Backpropagation_for_ICCV_2015_paper.html) |
| ([pdf](https://www.cv-foundation.org/openaccess/content_iccv_2015/papers/Ionescu_Matrix_Backpropagation_for_ICCV_2015_paper.pdf)) |
| Training Deep Networks with Structured Layers by Matrix Backpropagation: |
| [Ionescu et al., 2015](https://arxiv.org/abs/1509.07838) |
| ([pdf](https://arxiv.org/pdf/1509.07838.pdf)) |
| """ |
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gen_linalg_ops
from tensorflow.python.ops import linalg_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.linalg import linalg_impl as _linalg


| @ops.RegisterGradient("MatrixInverse") |
| def _MatrixInverseGrad(op, grad): |
| """Gradient for MatrixInverse.""" |
  ainv = op.outputs[0]
  return -math_ops.matmul(
      ainv, math_ops.matmul(grad, ainv, adjoint_b=True), adjoint_a=True)


| @ops.RegisterGradient("Einsum") |
| def _EinsumGrad(op, grad): |
| """Gradient for Einsum.""" |
| ellipsis = "..." |
| |
| def _GetAxisFromLabel(subscripts, label): |
| """Returns the axis (possibly negative) corresponding to a label. |
| |
| Returns the axis index of the axis label if it is before an ellipsis (or if |
| the ellipsis is not present), and the negative index if it occurs after the |
| ellipsis. E.g. index of `b` in `ab...cd`, is `1`, but that of `c` is `-2`. |
| |
| For multiple occurrences, returns the leftmost one. If not found, returns |
| None. |
| |
| Args: |
| subscripts: A string denoting the einsum subscript (e.g. `ab...cd`) |
| label: The single character axis label. |
| """ |
| splits = subscripts.split(ellipsis) |
| index = splits[0].find(label) |
| if index != -1: |
| return index |
| if len(splits) < 2: |
| return None |
| index = splits[1].find(label) |
| if index != -1: |
| return index - len(splits[1]) |
| return None |
| |
| def _GetBcastSubshape(subscripts): |
| """Returns a tuple denoting the slice mapping to ellipsis. |
| |
| For a given subscript, returns a tuple (start, end) denoting the start |
| axis index and the (negative) end axis index respectively. For any input |
| Tensor `x` described by the subscript, `x[start:end]` would be the slice |
| represented by the ellipsis. E.g. For `ab...cd` returns `[1, -2]`. |
| |
| If ellipsis is not present in `subscripts`, returns `(0, 0)`. |
| |
| Args: |
| subscripts: A string denoting the einsum subscript. |
| """ |
| start = subscripts.find(ellipsis) |
| if start == -1: |
| return 0, 0 |
| remaining = len(subscripts) - (start + len(ellipsis)) |
| end = -remaining if remaining > 0 else None |
| return start, end |

  def _GetReducedSubscripts(reduced_label_set, input_shape, subscripts):
    """Returns reduced subscripts and their corresponding dimensions and axes.

    Given a set of axis labels, returns their concatenated subscript, their
    corresponding dimensions from input_shape, and their corresponding axes.
    Note that the concatenated subscript `reduced_subs` may have axis labels
    from `reduced_label_set` in any order. For example, for the reduced label
    set `{b, d}`, subscripts `aabbcd` and input shape `[2,2,5,5,3,4]`, returns
    subscripts `bd`, dimensions `[5,4]` and axes `[2,5]`.

    Args:
      reduced_label_set: Set of axis labels which appear in `subscripts`.
      input_shape: A `Tensor` representing the shape of the einsum operand
        corresponding to `subscripts`.
      subscripts: A string denoting the einsum subscript.

    Returns:
      reduced_subs: Subscripts formed by a concatenation of labels in
        `reduced_label_set`.
      reduced_dims: Dimensions from `input_shape` corresponding to each label
        in `reduced_subs`.
      reduced_axes: Axes described by `subscripts` corresponding to each label
        in `reduced_subs`. If there are multiple occurrences in `subscripts`,
        we consider only the leftmost one.
    """
    # Concatenate the sequence of reduced axis labels.
    reduced_subs = "".join(list(reduced_label_set))
    # Get the axis (may be positive, negative or zero) for each of the reduced
    # labels. If the same label appears multiple times, get the left-most axis.
    reduced_axes = [_GetAxisFromLabel(subscripts, s) for s in reduced_subs]
    # Get the corresponding dimensions for each reduced axis.
    reduced_dims = array_ops.stack([input_shape[ax] for ax in reduced_axes])
    return reduced_subs, reduced_dims, reduced_axes

  def _GetGradReduced(output_grad, output_subs, input_subs, input_shape,
                      reduced_label_set):
    """Returns the gradient wrt input for a unary einsum with reductions.

    Args:
      output_grad: The gradient wrt the output of a unary einsum operation.
      output_subs: The output subscript. (E.g. `ac` for equation `abc->ac`).
      input_subs: The input subscript. (E.g. `abc` for equation `abc->ac`).
      input_shape: A `Tensor` representing the shape of the input operand.
      reduced_label_set: The set of axis labels appearing in `input_subs` but
        not in `output_subs`.
    """
    # Let's say the einsum operation was "aabbcd->ca", where axis labels 'b'
    # and 'd' are reduced with input_shape [2,2,5,5,3,4]. Then obtain the
    # reduced subscripts "bd", corresponding dimensions [5,4] and axes [2,5].
    reduced_subs, reduced_dims, reduced_axes = _GetReducedSubscripts(
        reduced_label_set, input_shape, input_subs)
    # Whether either the input or the output subscripts have a repeated label.
    # This is true for "aabbcd->ca" or "abd->cca" but false for "abcd->ca".
    has_repeated_labels = (
        len(set(input_subs)) + len(set(output_subs)) <
        len(input_subs) + len(output_subs))
    # Compute the input subscripts without the reduced axis labels, e.g. "aac"
    # for the equation "aabbcd->ca".
    input_subs_without_reduced_labels = "".join(
        [s for s in input_subs if s not in reduced_label_set])

    # The gradient wrt the input for the equation "abc->ac" (or, equivalently
    # reduce_sum(..., axis=1)) is just the gradient of the output tiled N
    # times along axis 1, where label 'b' represents a dimension of size N.
    #
    # If we're not dealing with repeated labels, and the non-reduced labels
    # don't need to be transposed, then just tiling is enough and there is no
    # need to call another einsum. For example, tiling is sufficient for
    # "abcd->ac". But for equations like "aabbcd->ac" (generalized traces) or
    # "abc->ca" (transpose), we'd need another einsum operation after tiling.
    if (not has_repeated_labels and
        input_subs_without_reduced_labels == output_subs):
      # Obtain the shape of the output, as if keepdims=True on reduce sum.
      # E.g. for the equation "abcd->ac" with input shape [2,5,3,4], we get
      # the reduced shape [2,1,3,1].
      reduced_shape = math_ops.reduced_shape(
          input_shape, ops.convert_to_tensor(reduced_axes))
      # Reshaping the gradient (wrt "ac") to [2,1,3,1] and broadcasting it to
      # the shape [2,5,3,4] results in the gradient wrt "abcd".
      return array_ops.broadcast_to(
          array_ops.reshape(output_grad, reduced_shape), input_shape)

    # If we *do* have traces or transpose operations, then prepend the extra
    # reduced dimensions to the front. E.g. given the equation "aabbcd->ca"
    # we'd first obtain the VJP for "bdca->ca", and then the VJP for
    # "aabbcd->bdca".
    #
    # Obtain the input shape with reduced dimensions prepended, viz.
    # [5,4,3,2]. This is the shape of the intermediate "bdca".
    grad_shape_with_reduced_labels = array_ops.concat(
        [reduced_dims, array_ops.shape(output_grad)], axis=0)
    # Obtain the output shape of the reduction-only equation "bdca->ca" as if
    # keepdims=True; viz. [1,1,3,2]. Since we prepended the reduced labels,
    # we just have to prepend that many 1s to the output shape.
    reduced_shape = array_ops.concat(
        [array_ops.ones(len(reduced_label_set), dtype=dtypes.int32),
         array_ops.shape(output_grad)],
        axis=0)
    # Compute the VJP for the intermediate (viz. "bdca->ca") for which
    # broadcasting is sufficient.
    broadcasted_grad = array_ops.broadcast_to(
        array_ops.reshape(output_grad, reduced_shape),
        grad_shape_with_reduced_labels)
    # Compute the VJP for the final step (viz. "aabbcd->bdca"). We can use
    # einsum with the input and output subscripts reversed (viz.
    # "bdca->aabbcd") since the output axis labels now appear in the input
    # subscripts.
    return gen_linalg_ops.einsum([broadcasted_grad],
                                 "{}->{}".format(reduced_subs + output_subs,
                                                 input_subs))

  def _GetGradWrt(output_grad, other_operand, input_shape, input_subs,
                  other_subs, output_subs):
    """Returns the gradient wrt an input operand for a binary einsum.

    This function does not handle (un)broadcasting. This must be done
    separately on the returned gradient.

    Args:
      output_grad: The gradient wrt the output of a binary einsum operation.
      other_operand: The complementary `Tensor` operand, i.e. the operand
        that is not the input operand.
      input_shape: A `Tensor` representing the shape of the input operand.
      input_subs: The subscripts of the input operand.
      other_subs: The subscripts of the complementary operand.
      output_subs: The output subscripts.
    """
    # Claim: For the einsum operation z = einsum("{eq_x},{eq_y}->{eq_z}", x, y),
    #   where the equation involves only Tensor contractions, generalized
    #   traces and transposes, the input gradients are given by the
    #   vector-Jacobian products (VJPs):
    #
    #     grad_wrt_x = einsum("{eq_y},{eq_z}->{eq_x}", y, grad_wrt_z)
    #     grad_wrt_y = einsum("{eq_x},{eq_z}->{eq_y}", x, grad_wrt_z)
    #
    #   where grad_wrt_x and grad_wrt_y are the gradients with respect to
    #   inputs x and y and grad_wrt_z is the given gradient with respect to
    #   output z.
    #
    # Proof: For unary einsum equations involving only transposes ("ij->ji")
    #   and traces ("ii->i"), the mapping is linear, so its Jacobian at input
    #   x is given by the mapping itself. One can verify that the
    #   corresponding VJPs are einsums with the equations "ji->ij" and
    #   "i->ii" respectively, where the latter represents 'un-tracing':
    #   placing the input along the diagonal and setting all off-diagonal
    #   entries to zero.
    #   Furthermore, recall that matrix multiplication, which is
    #   represented by the equation "ab,bc->ac", has its VJPs given by the
    #   einsum equations "ac,bc->ab" and "ab,ac->bc" (see, for example,
    #   https://math.stackexchange.com/a/2755680). Combined with transposes
    #   and traces we can rewrite Tensor contractions as regular matrix
    #   multiplication. Since each of these operations has its VJP described
    #   by einsums of the required pattern, the result follows.
    #
    # Accordingly, einsum operations, except for those with reductions, have
    # their VJPs defined by:
    #   "{output_subs},{other_subs}->{input_subs}".
    #
    # But if there is a reduction, as in "abc,cd->ad", this would lead to the
    # equation "ad,cd->abc", which is invalid because the reduced axis label
    # 'b' is present in the output but not in any of the inputs. Therefore,
    # we compute the VJP in two steps: first we obtain the VJP for
    # "ac,cd->ad" and then we compute the VJP of "abc->ac" or, equivalently,
    # reduce_sum(..., axis=1).
    #
    # Compute the set of input axis labels which don't appear in either the
    # output subscripts or the other operand's subscripts. E.g. the set {'b'}
    # for the equation "abc,cd->ad".
    reduced_label_set = set(input_subs).difference(
        set(output_subs + other_subs + "."))
    # Obtain the input subscripts with the reduced axis labels removed. E.g.
    # "ac" in the above example.
    left_subs = "".join(s for s in input_subs if s not in reduced_label_set)

    # Compute the gradient wrt the input, without accounting for the
    # operation "abc->ac". So, now we have the VJP of the operation
    # "ac,cd->ad".
    grad_reduced = gen_linalg_ops.einsum([output_grad, other_operand],
                                         "{},{}->{}".format(
                                             output_subs, other_subs,
                                             left_subs))
    # If the reduced_label_set is empty, then we already have the gradient
    # wrt the input.
    if not reduced_label_set:
      return grad_reduced
    # Otherwise, we currently have the gradient wrt the output of the
    # reduction operation "abc->ac". Invoke the subroutine for the gradient
    # for unary einsum with reductions.
    return _GetGradReduced(grad_reduced, left_subs, input_subs, input_shape,
                           reduced_label_set)

  equation = op.get_attr("equation")
  if isinstance(equation, bytes):
    equation = equation.decode()
  input_subs, output_subs = equation.split("->")

  if len(op.inputs) == 1:
    # For the unary einsum z = einsum("{eq_x}->{eq_z}", x), the gradient wrt
    # the input (VJP) is given by the reversed equation:
    #   grad_wrt_x = einsum("{eq_z}->{eq_x}", grad_wrt_z)
    # (See the justification in _GetGradWrt). This is valid unless there are
    # reduced axis labels; i.e. axis labels appearing in the input but not in
    # the output subscripts.
    input_shape = array_ops.shape(op.inputs[0])
    # Find the axis labels which appear only in the input.
    reduced_label_set = set(input_subs).difference(set(output_subs + ellipsis))
    if not reduced_label_set:
      # Return the einsum given by the reversed equation, since we don't have
      # reduced axes.
      return gen_linalg_ops.einsum([grad],
                                   "{}->{}".format(output_subs, input_subs))
    # We do have reduced axes, so we invoke the subroutine for reduced unary
    # einsums.
    return _GetGradReduced(grad, output_subs, input_subs, input_shape,
                           reduced_label_set)

  x_subs, y_subs = input_subs.split(",")
  # Add an ellipsis to any operand subscript that lacks one whenever the
  # output contains an ellipsis. E.g. the equation "...ij,jk->ik" may be
  # valid if the 0th input's batch shape is empty, but the VJP equation
  # "jk,ik->...ij" is not, because the ellipsis would appear only in the
  # output subscripts.
  if ellipsis in output_subs:
    if ellipsis not in x_subs:
      x_subs += ellipsis
    if ellipsis not in y_subs:
      y_subs += ellipsis

  # Obtain the gradients wrt the inputs x and y, without taking into account
  # the unbroadcasting.
  x, y = op.inputs[0], op.inputs[1]
  x_shape = array_ops.shape(x)
  y_shape = array_ops.shape(y)
  grad_x = _GetGradWrt(grad, y, x_shape, x_subs, y_subs, output_subs)
  grad_y = _GetGradWrt(grad, x, y_shape, y_subs, x_subs, output_subs)

  if ellipsis not in output_subs:
    # If there is no ellipsis in the output, there is no need to unbroadcast.
    return grad_x, grad_y

  # Below we handle the case that broadcasting between x and y was necessary,
  # with x and y having possibly different batch shapes.

  # Obtain the range of axes which map to ellipsis. E.g. for subscripts
  # 'ab...c' and a shape of rank 10, the range [2:-1] denotes the broadcasted
  # axes.
  bx_start, bx_end = _GetBcastSubshape(x_subs)
  by_start, by_end = _GetBcastSubshape(y_subs)
  # If the static batch shapes are equal, we don't need to unbroadcast.
  x_shape_static = x.get_shape()
  y_shape_static = y.get_shape()
  if (x_shape_static.is_fully_defined() and
      y_shape_static.is_fully_defined() and
      x_shape_static[bx_start:bx_end] == y_shape_static[by_start:by_end]):
    return grad_x, grad_y

  # Sum the gradient across the broadcasted axes.
  rx, ry = array_ops.broadcast_gradient_args(x_shape[bx_start:bx_end],
                                             y_shape[by_start:by_end])
  grad_x = array_ops.reshape(
      math_ops.reduce_sum(grad_x, bx_start + rx), x_shape)
  grad_y = array_ops.reshape(
      math_ops.reduce_sum(grad_y, by_start + ry), y_shape)
  return grad_x, grad_y
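

# A quick eager-mode sanity check of the einsum VJP rule used above
# (an illustrative sketch, not executed as part of this module; assumes
# TF 2.x eager execution with `tf` imported):
#
#   x = tf.random.normal([2, 3])
#   y = tf.random.normal([3, 4])
#   with tf.GradientTape() as tape:
#     tape.watch([x, y])
#     z = tf.einsum("ab,bc->ac", x, y)
#   gx, gy = tape.gradient(z, [x, y], output_gradients=tf.ones_like(z))
#   # These should match the reversed equations:
#   #   gx == tf.einsum("ac,bc->ab", tf.ones_like(z), y)
#   #   gy == tf.einsum("ab,ac->bc", x, tf.ones_like(z))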


@ops.RegisterGradient("MatrixDeterminant")
def _MatrixDeterminantGrad(op, grad):
  """Gradient for MatrixDeterminant."""
  a = op.inputs[0]
  c = op.outputs[0]
  a_adj_inv = linalg_ops.matrix_inverse(a, adjoint=True)
  multipliers = array_ops.reshape(grad * c,
                                  array_ops.concat([array_ops.shape(c), [1, 1]],
                                                   0))
  return multipliers * a_adj_inv


@ops.RegisterGradient("MatrixSquareRoot")
def _MatrixSquareRootGrad(op, grad):
  """Gradient for MatrixSquareRoot."""

  # Let A be an m x m square matrix (or batch of matrices).
  # Let R = sqrtm(A).
  # By definition, A = RR.
  # Take the differential: dA = d(RR) = RdR + dRR.
  # Solve the resulting Sylvester equation for dR.
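  # Concretely (real case): with the column-stacking identity
  # vec(AXB) = (B^T kron A) vec(X), the Sylvester equation R dR + dR R = dA
  # becomes a linear system in vec(dR); the code below assembles such a
  # system from Kronecker products and solves it with matrix_solve.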

  # Used to find Kronecker products within the Sylvester equation.
  def _KroneckerProduct(b1, b2):
    """Computes the Kronecker product of two batches of square matrices."""
    b1_shape = array_ops.shape(b1)
    b2_shape = array_ops.shape(b2)
    b1_order = b1_shape[-1]
    b2_order = b2_shape[-1]

    shape_slice_size = [math_ops.subtract(array_ops.size(b1_shape), 2)]
    shape_slice = array_ops.slice(b1_shape, [0],
                                  shape_slice_size)  # Same for both batches.
    b1_reshape_shape = array_ops.concat(
        [shape_slice, [b1_order], [1], [b1_order], [1]], 0)
    b2_reshape_shape = array_ops.concat(
        [shape_slice, [1], [b2_order], [1], [b2_order]], 0)

    b1_reshape = array_ops.reshape(b1, b1_reshape_shape)
    b2_reshape = array_ops.reshape(b2, b2_reshape_shape)

    order_prod = b1_order * b2_order
    kprod_shape = array_ops.concat([shape_slice, [order_prod], [order_prod]], 0)
    return array_ops.reshape(b1_reshape * b2_reshape, kprod_shape)

  sqrtm = op.outputs[0]  # R
  shape = array_ops.shape(sqrtm)
  order = shape[-1]  # m
  matrix_count = math_ops.reduce_prod(shape[0:-2])

  # Get a batch of m x m identity matrices.
  eye = linalg_ops.eye(order, dtype=sqrtm.dtype)  # m x m identity matrix
  eye_flat = array_ops.reshape(eye, [-1])
  eye_tiled = array_ops.tile(eye_flat, [matrix_count])
  eye_batch = array_ops.reshape(eye_tiled, shape)

  # The transpose of R is taken in the k1 term instead of k2 in
  # order to prevent redundant transposition of R (i.e. (R')' = R).
  sqrtm_transpose = array_ops.matrix_transpose(sqrtm)
  k1 = _KroneckerProduct(eye_batch, sqrtm_transpose)
  k2 = _KroneckerProduct(sqrtm, eye_batch)
  ksum = math_ops.add(k1, k2)

  # Vectorize dA.
  shape_slice_size = [math_ops.subtract(array_ops.size(shape), 2)]
  shape_slice = array_ops.slice(shape, [0], shape_slice_size)
  shape_vec_da = array_ops.concat([shape_slice, [order * order], [1]], 0)
  vec_da = array_ops.reshape(array_ops.matrix_transpose(grad), shape_vec_da)

  # Solve for vec(dR).
  vec_dsqrtm = linalg_ops.matrix_solve(ksum, vec_da)

  # Solve for dR by inverse vectorizing vec(dR).
  dsqrtm_transpose = array_ops.reshape(vec_dsqrtm, shape)
  return array_ops.matrix_transpose(dsqrtm_transpose)


@ops.RegisterGradient("LogMatrixDeterminant")
def _LogMatrixDeterminantGrad(op, _, grad_b):
  """Gradient for LogMatrixDeterminant."""
  a = op.inputs[0]
  c = op.outputs[1]
  a_adj_inv = linalg_ops.matrix_inverse(a, adjoint=True)
  multipliers = array_ops.reshape(
      grad_b, array_ops.concat([array_ops.shape(c), [1, 1]], 0))
  return multipliers * a_adj_inv


@ops.RegisterGradient("Cholesky")
def _CholeskyGrad(op, grad):
  """Gradient for Cholesky."""

  # Gradient is l^{-H} @ ((l^{H} @ grad) * (tril(ones)-1/2*eye)) @ l^{-1}
  l = op.outputs[0]
  num_rows = array_ops.shape(l)[-1]
  batch_shape = array_ops.shape(l)[:-2]
  l_inverse = linalg_ops.matrix_triangular_solve(l,
                                                 linalg_ops.eye(
                                                     num_rows,
                                                     batch_shape=batch_shape,
                                                     dtype=l.dtype))

  middle = math_ops.matmul(l, grad, adjoint_a=True)
  middle = array_ops.matrix_set_diag(middle,
                                     0.5 * array_ops.matrix_diag_part(middle))
  middle = array_ops.matrix_band_part(middle, -1, 0)

  grad_a = math_ops.matmul(
      math_ops.matmul(l_inverse, middle, adjoint_a=True), l_inverse)

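  # Average the result with its adjoint to produce a Hermitian gradient,
  # matching the Hermitian structure assumed of the decomposed input.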
  grad_a += _linalg.adjoint(grad_a)
  return grad_a * 0.5


@ops.RegisterGradient("Qr")
def _QrGrad(op, dq, dr):
  """Gradient for Qr."""
  q, r = op.outputs
  if q.dtype.is_complex:
    raise NotImplementedError("QrGrad not implemented for dtype: %s" % q.dtype)
  if (r.shape.ndims is None or r.shape.as_list()[-2] is None or
      r.shape.as_list()[-1] is None):
    raise NotImplementedError("QrGrad not implemented with dynamic shapes.")
  if r.shape.dims[-2].value != r.shape.dims[-1].value:
    raise NotImplementedError("QrGrad not implemented when ncols > nrows "
                              "or full_matrices is true and ncols != nrows.")

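  # With A = Q R, the differentials satisfy dA = dQ R + Q dR, where Q^H dQ is
  # skew-Hermitian (since Q^H Q = I) and dR is upper triangular. The steps
  # below combine these constraints to reconstruct the gradient wrt A.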
  qdq = math_ops.matmul(q, dq, adjoint_a=True)
  qdq_ = qdq - _linalg.adjoint(qdq)
  rdr = math_ops.matmul(r, dr, adjoint_b=True)
  rdr_ = rdr - _linalg.adjoint(rdr)
  tril = array_ops.matrix_band_part(qdq_ + rdr_, -1, 0)

  def _TriangularSolve(x, r):
    """Equiv to matmul(x, adjoint(matrix_inverse(r))) if r is upper-tri."""
    return _linalg.adjoint(
        linalg_ops.matrix_triangular_solve(
            r, _linalg.adjoint(x), lower=False, adjoint=False))

  grad_a = math_ops.matmul(q, dr + _TriangularSolve(tril, r))
  grad_b = _TriangularSolve(dq - math_ops.matmul(q, qdq), r)
  return grad_a + grad_b


@ops.RegisterGradient("MatrixSolve")
def _MatrixSolveGrad(op, grad):
  """Gradient for MatrixSolve."""
  a = op.inputs[0]
  adjoint_a = op.get_attr("adjoint")
  c = op.outputs[0]
  grad_b = linalg_ops.matrix_solve(a, grad, adjoint=not adjoint_a)
  if adjoint_a:
    grad_a = -math_ops.matmul(c, grad_b, adjoint_b=True)
  else:
    grad_a = -math_ops.matmul(grad_b, c, adjoint_b=True)
  return (grad_a, grad_b)


@ops.RegisterGradient("MatrixSolveLs")
def _MatrixSolveLsGrad(op, grad):
  """Gradients for MatrixSolveLs."""

  # TODO(rmlarsen): The implementation could be more efficient:
  # a) Output the Cholesky factorization from forward op instead of
  #    recomputing it here.
  # b) Implement a symmetric rank-k update op instead of computing
  #    x*z + transpose(x*z). This pattern occurs other places in TensorFlow.

  def _Overdetermined(op, grad):
    """Gradients for the overdetermined case of MatrixSolveLs.

    This is the backprop for the solution to the normal equations of the
    first kind:
      X = F(A, B) = (A^T * A + lambda * I)^{-1} * A^T * B
    which solves the least squares problem
      min ||A * X - B||_F^2 + lambda ||X||_F^2.
    """
    a = op.inputs[0]
    b = op.inputs[1]
    x = op.outputs[0]
    l2_regularizer = math_ops.cast(op.inputs[2], a.dtype.base_dtype)
    # pylint: disable=protected-access
    chol = linalg_ops._RegularizedGramianCholesky(
        a, l2_regularizer=l2_regularizer, first_kind=True)
    # pylint: enable=protected-access
    # Temporary z = (A^T * A + lambda * I)^{-1} * grad.
    z = linalg_ops.cholesky_solve(chol, grad)
    xzt = math_ops.matmul(x, z, adjoint_b=True)
    zx_sym = xzt + array_ops.matrix_transpose(xzt)
    grad_a = -math_ops.matmul(a, zx_sym) + math_ops.matmul(b, z, adjoint_b=True)
    grad_b = math_ops.matmul(a, z)
    return (grad_a, grad_b, None)

  def _Underdetermined(op, grad):
    """Gradients for the underdetermined case of MatrixSolveLs.

    This is the backprop for the solution to the normal equations of the
    second kind:
      X = F(A, B) = A * (A*A^T + lambda*I)^{-1} * B
    which (for lambda=0) solves the least squares problem
      min ||X||_F subject to A*X = B.
    """
    a = op.inputs[0]
    b = op.inputs[1]
    l2_regularizer = math_ops.cast(op.inputs[2], a.dtype.base_dtype)
    # pylint: disable=protected-access
    chol = linalg_ops._RegularizedGramianCholesky(
        a, l2_regularizer=l2_regularizer, first_kind=False)
    # pylint: enable=protected-access
    grad_b = linalg_ops.cholesky_solve(chol, math_ops.matmul(a, grad))
    # Temporary tmp = (A * A^T + lambda * I)^{-1} * B.
    tmp = linalg_ops.cholesky_solve(chol, b)
    a1 = math_ops.matmul(tmp, a, adjoint_a=True)
    a1 = -math_ops.matmul(grad_b, a1)
    a2 = grad - math_ops.matmul(a, grad_b, adjoint_a=True)
    a2 = math_ops.matmul(tmp, a2, adjoint_b=True)
    grad_a = a1 + a2
    return (grad_a, grad_b, None)

  fast = op.get_attr("fast")
  if fast is False:
    raise ValueError("Gradient not defined for fast=False")
  matrix_shape = op.inputs[0].get_shape()[-2:]
  if matrix_shape.is_fully_defined():
    if matrix_shape[-2] >= matrix_shape[-1]:
      return _Overdetermined(op, grad)
    else:
      return _Underdetermined(op, grad)
  else:
    # We have to defer determining the shape to runtime and use
    # conditional execution of the appropriate graph.
    matrix_shape = array_ops.shape(op.inputs[0])[-2:]
    return control_flow_ops.cond(matrix_shape[-2] >= matrix_shape[-1],
                                 lambda: _Overdetermined(op, grad),
                                 lambda: _Underdetermined(op, grad))


@ops.RegisterGradient("MatrixTriangularSolve")
def _MatrixTriangularSolveGrad(op, grad):
  """Gradient for MatrixTriangularSolve."""
  a = op.inputs[0]
  b = op.inputs[1]
  adjoint_a = op.get_attr("adjoint")
  lower_a = op.get_attr("lower")
  c = op.outputs[0]
  grad_b = linalg_ops.matrix_triangular_solve(
      a, grad, lower=lower_a, adjoint=not adjoint_a)
  if adjoint_a:
    grad_a = -math_ops.matmul(c, grad_b, adjoint_b=True)
  else:
    grad_a = -math_ops.matmul(grad_b, c, adjoint_b=True)
  if lower_a:
    grad_a = array_ops.matrix_band_part(grad_a, -1, 0)
  else:
    grad_a = array_ops.matrix_band_part(grad_a, 0, -1)
  # If the static batch shapes are equal, we don't need to unbroadcast.
  if (a.shape.is_fully_defined() and b.shape.is_fully_defined() and
      a.shape[:-2] == b.shape[:-2]):
    return grad_a, grad_b
  a_shape = array_ops.shape(a)
  b_shape = array_ops.shape(b)
  ra, rb = array_ops.broadcast_gradient_args(a_shape[:-2], b_shape[:-2])
  grad_a = array_ops.reshape(math_ops.reduce_sum(grad_a, axis=ra), a_shape)
  grad_b = array_ops.reshape(math_ops.reduce_sum(grad_b, axis=rb), b_shape)
  return grad_a, grad_b


# To avoid NaNs in cases of degenerate eigenvalues or degenerate/zero
# singular values in the calculations of f and s_inv_mat below, we introduce
# Lorentz broadening.
def _SafeReciprocal(x, epsilon=1E-20):
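  # E.g. _SafeReciprocal(0.) evaluates to 0 rather than Inf; away from zero,
  # x / (x^2 + epsilon) is essentially 1 / x.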
  return x * math_ops.reciprocal(x * x + epsilon)


| @ops.RegisterGradient("Eig") |
| def _EigGrad(op, grad_e, grad_v): |
| """Gradient for Eig. |
| Based on eq. 4.77 from paper by |
| Christoph Boeddeker et al. |
| https://arxiv.org/abs/1701.00392 |
| See also |
| "Computation of eigenvalue and eigenvector derivatives |
| for a general complex-valued eigensystem" by Nico van der Aa. |
| As for now only distinct eigenvalue case is considered. |
| """ |
| e = op.outputs[0] |
| compute_v = op.get_attr("compute_v") |
| # a = op.inputs[0], which satisfies |
| # a[...,:,:] * v[...,:,i] = e[...,i] * v[...,i] |
| with ops.control_dependencies([grad_e, grad_v]): |
| if compute_v: |
| v = op.outputs[1] |
| vt = _linalg.adjoint(v) |
| # Construct the matrix f(i,j) = (i != j ? 1 / (e_i - e_j) : 0). |
| # Notice that because of the term involving f, the gradient becomes |
| # infinite (or NaN in practice) when eigenvalues are not unique. |
| # Mathematically this should not be surprising, since for (k-fold) |
| # degenerate eigenvalues, the corresponding eigenvectors are only defined |
| # up to arbitrary rotation in a (k-dimensional) subspace. |
| f = array_ops.matrix_set_diag( |
| _SafeReciprocal( |
| array_ops.expand_dims(e, -2) - array_ops.expand_dims(e, -1)), |
| array_ops.zeros_like(e)) |
| f = math_ops.conj(f) |
| vgv = math_ops.matmul(vt, grad_v) |
| mid = array_ops.matrix_diag(grad_e) |
| diag_grad_part = array_ops.matrix_diag(array_ops.matrix_diag_part( |
| math_ops.cast(math_ops.real(vgv), vgv.dtype))) |
| mid += f * (vgv - math_ops.matmul(math_ops.matmul(vt, v), diag_grad_part)) |
| # vt is formally invertible as long as the original matrix is |
| # diagonalizable. However, in practice, vt may |
| # be ill-conditioned when matrix original matrix is close to |
| # non-diagonalizable one |
| grad_a = linalg_ops.solve(vt, math_ops.matmul(mid, vt)) |
| else: |
| _, v = linalg_ops.eig(op.inputs[0]) |
| vt = _linalg.adjoint(v) |
| # vt is formally invertible as long as the original matrix is |
| # diagonalizable. However, in practice, vt may |
| # be ill-conditioned when matrix original matrix is close to |
| # non-diagonalizable one |
| grad_a = linalg_ops.solve(vt, |
| math_ops.matmul( |
| array_ops.matrix_diag(grad_e), |
| vt)) |
| return math_ops.cast(grad_a, op.inputs[0].dtype) |


| @ops.RegisterGradient("SelfAdjointEigV2") |
| def _SelfAdjointEigV2Grad(op, grad_e, grad_v): |
| """Gradient for SelfAdjointEigV2.""" |
| e = op.outputs[0] |
| compute_v = op.get_attr("compute_v") |
| # a = op.inputs[0], which satisfies |
| # a[...,:,:] * v[...,:,i] = e[...,i] * v[...,i] |
| with ops.control_dependencies([grad_e, grad_v]): |
| if compute_v: |
| v = op.outputs[1] |
| # Construct the matrix f(i,j) = (i != j ? 1 / (e_i - e_j) : 0). |
| # Notice that because of the term involving f, the gradient becomes |
| # infinite (or NaN in practice) when eigenvalues are not unique. |
| # Mathematically this should not be surprising, since for (k-fold) |
| # degenerate eigenvalues, the corresponding eigenvectors are only defined |
| # up to arbitrary rotation in a (k-dimensional) subspace. |
| f = array_ops.matrix_set_diag( |
| _SafeReciprocal( |
| array_ops.expand_dims(e, -2) - array_ops.expand_dims(e, -1)), |
| array_ops.zeros_like(e)) |
| grad_a = math_ops.matmul( |
| v, |
| math_ops.matmul( |
| array_ops.matrix_diag(grad_e) + |
| f * math_ops.matmul(v, grad_v, adjoint_a=True), |
| v, |
| adjoint_b=True)) |
| else: |
| _, v = linalg_ops.self_adjoint_eig(op.inputs[0]) |
| grad_a = math_ops.matmul(v, |
| math_ops.matmul( |
| array_ops.matrix_diag(grad_e), |
| v, |
| adjoint_b=True)) |
| # The forward op only depends on the lower triangular part of a, so here we |
| # symmetrize and take the lower triangle |
| grad_a = array_ops.matrix_band_part(grad_a + _linalg.adjoint(grad_a), -1, 0) |
| grad_a = array_ops.matrix_set_diag(grad_a, |
| 0.5 * array_ops.matrix_diag_part(grad_a)) |
| return grad_a |


@ops.RegisterGradient("Svd")
def _SvdGrad(op, grad_s, grad_u, grad_v):
  """Gradient for the singular value decomposition."""

  # The derivation for the compute_uv=False case, and most of
  # the derivation for the full_matrices=True case, are in
  # Giles' paper (see reference at top of file). A derivation for
  # the full_matrices=False case is available at
  # https://j-towns.github.io/papers/svd-derivative.pdf
  # The derivation for complex valued SVD can be found in
  # https://re-ra.xyz/misc/complexsvd.pdf or
  # https://giggleliu.github.io/2019/04/02/einsumbp.html
  a = op.inputs[0]
  a_shape = a.get_shape().with_rank_at_least(2)
  grad_s = math_ops.cast(grad_s, a.dtype)
  grad_s_mat = array_ops.matrix_diag(grad_s)

  if not op.get_attr("compute_uv"):
    s, u, v = linalg_ops.svd(a, compute_uv=True)
    grad_a = math_ops.matmul(u, math_ops.matmul(grad_s_mat, v, adjoint_b=True))
    grad_a.set_shape(a_shape)
    return grad_a

  full_matrices = op.get_attr("full_matrices")

  grad_u_shape = grad_u.get_shape().with_rank_at_least(2)
  grad_v_shape = grad_v.get_shape().with_rank_at_least(2)
  m = a_shape.dims[-2].merge_with(grad_u_shape[-2])
  n = a_shape.dims[-1].merge_with(grad_v_shape[-2])
  batch_shape = a_shape[:-2].merge_with(grad_u_shape[:-2]).merge_with(
      grad_v_shape[:-2])
  a_shape = batch_shape.concatenate([m, n])

  m = a_shape.dims[-2].value
  n = a_shape.dims[-1].value
  # TODO(rmlarsen): Make this work with placeholders.
  if m is None or n is None:
    raise NotImplementedError(
        "SVD gradient has not been implemented for input with unknown "
        "inner matrix shape.")

  s = op.outputs[0]
  u = op.outputs[1]
  v = op.outputs[2]
  s = math_ops.cast(s, a.dtype)

  use_adjoint = False
  if m > n:
    # Compute the gradient for A^H = V * S^T * U^H, and (implicitly) take the
    # Hermitian transpose of the gradient at the end.
    use_adjoint = True
    m, n = n, m
    u, v = v, u
    grad_u, grad_v = grad_v, grad_u

  with ops.control_dependencies([grad_s, grad_u, grad_v]):
    if full_matrices and abs(m - n) > 1:
      raise NotImplementedError(
          "svd gradient is not implemented for abs(m - n) > 1 "
          "when full_matrices is True")
    s_mat = array_ops.matrix_diag(s)
    s2 = math_ops.square(s)

    # NOTICE: Because of the term involving f, the gradient becomes
    # infinite (or NaN in practice) when singular values are not unique.
    # Mathematically this should not be surprising, since for (k-fold)
    # degenerate singular values, the corresponding singular vectors are
    # only defined up to rotation in a (k-dimensional) subspace. In practice,
    # this can lead to numerical instability when singular values are close
    # but not exactly equal.

    s_shape = array_ops.shape(s)
    f = array_ops.matrix_set_diag(
        _SafeReciprocal(
            array_ops.expand_dims(s2, -2) - array_ops.expand_dims(s2, -1)),
        array_ops.zeros_like(s))
    s_inv_mat = array_ops.matrix_diag(_SafeReciprocal(s))

    v1 = v[..., :, :m]
    grad_v1 = grad_v[..., :, :m]

    u_gu = math_ops.matmul(u, grad_u, adjoint_a=True)
    v_gv = math_ops.matmul(v1, grad_v1, adjoint_a=True)

    f_u = f * u_gu
    f_v = f * v_gv

    term1_nouv = (
        grad_s_mat + math_ops.matmul(f_u + _linalg.adjoint(f_u), s_mat) +
        math_ops.matmul(s_mat, f_v + _linalg.adjoint(f_v)))

    term1 = math_ops.matmul(u, math_ops.matmul(term1_nouv, v1, adjoint_b=True))

    if m == n:
      grad_a_before_transpose = term1
    else:
      gv1t = array_ops.matrix_transpose(grad_v1, conjugate=True)
      gv1t_v1 = math_ops.matmul(gv1t, v1)
      term2_nous = gv1t - math_ops.matmul(gv1t_v1, v1, adjoint_b=True)

      if full_matrices:
        v2 = v[..., :, m:n]
        grad_v2 = grad_v[..., :, m:n]

        v1t_gv2 = math_ops.matmul(v1, grad_v2, adjoint_a=True)
        term2_nous -= math_ops.matmul(v1t_gv2, v2, adjoint_b=True)

      u_s_inv = math_ops.matmul(u, s_inv_mat)
      term2 = math_ops.matmul(u_s_inv, term2_nous)

      grad_a_before_transpose = term1 + term2

    if a.dtype.is_complex:
      eye = _linalg.eye(s_shape[-1], batch_shape=s_shape[:-1], dtype=a.dtype)
      l = eye * v_gv
      term3_nouv = math_ops.matmul(s_inv_mat, _linalg.adjoint(l) - l)
      term3 = 1 / 2. * math_ops.matmul(
          u, math_ops.matmul(term3_nouv, v1, adjoint_b=True))

      grad_a_before_transpose += term3

    if use_adjoint:
      grad_a = array_ops.matrix_transpose(
          grad_a_before_transpose, conjugate=True)
    else:
      grad_a = grad_a_before_transpose

    grad_a.set_shape(a_shape)
    return grad_a


def _LeftShift(x):
  """Shifts next-to-last dimension to the left, adding zero on the right."""
  rank = array_ops.rank(x)
  zeros = array_ops.zeros((rank - 2, 2), dtype=dtypes.int32)
  pad = array_ops.concat([zeros, array_ops.constant([[0, 1], [0, 0]])], axis=0)
  return array_ops.pad(x[..., 1:, :], pad)


def _RightShift(x):
  """Shifts next-to-last dimension to the right, adding zero on the left."""
  rank = array_ops.rank(x)
  zeros = array_ops.zeros((rank - 2, 2), dtype=dtypes.int32)
  pad = array_ops.concat([zeros, array_ops.constant([[1, 0], [0, 0]])], axis=0)
  return array_ops.pad(x[..., :-1, :], pad)


@ops.RegisterGradient("TridiagonalMatMul")
def _TridiagonalMatMulGrad(op, grad):
  """Gradient for TridiagonalMatMul."""
  superdiag_conj = array_ops.matrix_transpose(op.inputs[0], conjugate=True)
  maindiag_conj = array_ops.matrix_transpose(op.inputs[1], conjugate=True)
  subdiag_conj = array_ops.matrix_transpose(op.inputs[2], conjugate=True)
  rhs_conj = math_ops.conj(op.inputs[3])

  superdiag_grad = math_ops.reduce_sum(_LeftShift(rhs_conj) * grad, axis=-1)
  maindiag_grad = math_ops.reduce_sum(rhs_conj * grad, axis=-1)
  subdiag_grad = math_ops.reduce_sum(_RightShift(rhs_conj) * grad, axis=-1)
  rhs_grad = (_RightShift(superdiag_conj * grad) + maindiag_conj * grad +
              _LeftShift(subdiag_conj * grad))

  superdiag_grad = array_ops.expand_dims(superdiag_grad, -2)
  maindiag_grad = array_ops.expand_dims(maindiag_grad, -2)
  subdiag_grad = array_ops.expand_dims(subdiag_grad, -2)

  return superdiag_grad, maindiag_grad, subdiag_grad, rhs_grad


@ops.RegisterGradient("TridiagonalSolve")
def _TridiagonalSolveGrad(op, grad):
  """Gradient for TridiagonalSolve."""
  diags = op.inputs[0]
  x = op.outputs[0]
  partial_pivoting = op.get_attr("partial_pivoting")

  # Transposing the matrix within the tridiagonal_solve kernel by
  # interchanging the superdiagonal and subdiagonal wouldn't work on GPU due
  # to a mismatch with the paddings required by the cusparse*gtsv routines,
  # so we construct the transposed matrix in Python instead.
  diags_transposed = _TransposeTridiagonalMatrix(diags)

  grad_rhs = linalg_ops.tridiagonal_solve(diags_transposed, grad,
                                          partial_pivoting=partial_pivoting)
  grad_diags = -_MatmulExtractingThreeDiagonals(grad_rhs, x)
  return grad_diags, grad_rhs


def _TransposeTridiagonalMatrix(diags):
  """Transposes a tridiagonal matrix.

  Args:
    diags: the diagonals of the input matrix in the compact form (see
      linalg_ops.tridiagonal_solve).

  Returns:
    Diagonals of the transposed matrix in the compact form.
  """

  diag = diags[..., 1, :]

  if diags.shape.is_fully_defined():
    # For a fully defined tensor shape we can concat with a tensor of zeros,
    # which is faster than using array_ops.pad().
    zeros = array_ops.zeros(list(diags.shape[:-2]) + [1], dtype=diags.dtype)
    superdiag = array_ops.concat((diags[..., 2, 1:], zeros), axis=-1)
    subdiag = array_ops.concat((zeros, diags[..., 0, :-1]), axis=-1)
  else:
    rank = array_ops.rank(diags)
    zeros = array_ops.zeros((rank - 2, 2), dtype=dtypes.int32)
    superdiag_pad = array_ops.concat((zeros, array_ops.constant([[0, 1]])),
                                     axis=0)
    superdiag = array_ops.pad(diags[..., 2, 1:], superdiag_pad)
    subdiag_pad = array_ops.concat((zeros, array_ops.constant([[1, 0]])),
                                   axis=0)
    subdiag = array_ops.pad(diags[..., 0, :-1], subdiag_pad)
  return array_ops.stack([superdiag, diag, subdiag], axis=-2)


def _MatmulExtractingThreeDiagonals(x, y_tr):
  """Multiplies matrices and extracts three diagonals from the product.

  With sizes M x K and K x M, this function takes O(MK) time and O(M) space,
  whereas using math_ops.matmul and then extracting the diagonals would take
  O(M^2 K) time and O(M^2) space.

  Args:
    x: first matrix
    y_tr: second matrix transposed

  Returns:
    Diagonals of the product in compact format (see
    linalg_ops.tridiagonal_solve).
  """
  diag = math_ops.reduce_sum(x * y_tr, axis=-1)

  if y_tr.shape.is_fully_defined():
    zeros = array_ops.zeros(
        list(x.shape[:-2]) + [1, x.shape[-1]], dtype=x.dtype)
    superdiag = math_ops.reduce_sum(
        x * array_ops.concat((y_tr[..., 1:, :], zeros), axis=-2), axis=-1)
    subdiag = math_ops.reduce_sum(
        x * array_ops.concat((zeros, y_tr[..., :-1, :]), axis=-2), axis=-1)
  else:
    rank = array_ops.rank(y_tr)
    zeros = array_ops.zeros((rank - 2, 2), dtype=dtypes.int32)
    superdiag_pad = array_ops.concat(
        (zeros, array_ops.constant([[0, 1], [0, 0]])), axis=0)
    superdiag = math_ops.reduce_sum(
        x * array_ops.pad(y_tr[..., 1:, :], superdiag_pad), axis=-1)
    subdiag_pad = array_ops.concat(
        (zeros, array_ops.constant([[1, 0], [0, 0]])), axis=0)
    subdiag = math_ops.reduce_sum(
        x * array_ops.pad(y_tr[..., :-1, :], subdiag_pad), axis=-1)
  return array_ops.stack([superdiag, diag, subdiag], axis=-2)