Replace matrix_inverse + matmul with linalg.solve in the eig gradient
diff --git a/tensorflow/python/ops/linalg_grad.py b/tensorflow/python/ops/linalg_grad.py
index 25f889c..6dbb7a3 100644
--- a/tensorflow/python/ops/linalg_grad.py
+++ b/tensorflow/python/ops/linalg_grad.py
@@ -658,7 +658,6 @@
if compute_v:
v = op.outputs[1]
vt = _linalg.adjoint(v)
- w = linalg_ops.matrix_inverse(vt)
# Construct the matrix f(i,j) = (i != j ? 1 / (e_i - e_j) : 0).
# Notice that because of the term involving f, the gradient becomes
# infinite (or NaN in practice) when eigenvalues are not unique.
@@ -675,12 +674,19 @@
diag_grad_part = array_ops.matrix_diag(array_ops.matrix_diag_part(
math_ops.cast(math_ops.real(vgv), vgv.dtype)))
mid += f * (vgv - math_ops.matmul(math_ops.matmul(vt, v), diag_grad_part))
- grad_a = math_ops.matmul(w, math_ops.matmul(mid, vt))
+ # vt is formally invertible as long as the original matrix is
+ # diagonalizable. However, in practice, vt may be
+ # ill-conditioned when the original matrix is close to a
+ # non-diagonalizable one.
+ grad_a = linalg_ops.matrix_solve(vt, math_ops.matmul(mid, vt))
else:
_, v = linalg_ops.eig(op.inputs[0])
vt = _linalg.adjoint(v)
- w = linalg_ops.matrix_inverse(vt)
- grad_a = math_ops.matmul(w,
+ # vt is formally invertible as long as the original matrix is
+ # diagonalizable. However, in practice, vt may be
+ # ill-conditioned when the original matrix is close to a
+ # non-diagonalizable one.
+ grad_a = linalg_ops.matrix_solve(vt,
math_ops.matmul(
array_ops.matrix_diag(grad_e),
vt))
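
The numerical motivation, as a minimal standalone sketch (assumes TensorFlow 2.x in eager mode; the near-defective matrix `a` below is a hypothetical example, not taken from the patch): solving `vt @ x = b` directly avoids forming an explicit inverse, and its residual degrades more gracefully as `vt` becomes ill-conditioned.

```python
# Standalone sketch (assumes TensorFlow 2.x); not part of the patch.
import tensorflow as tf

# A nearly defective matrix: as eps -> 0 its two eigenvectors coalesce,
# so the eigenvector matrix v (and hence vt) becomes ill-conditioned.
eps = 1e-8
a = tf.constant([[1.0, 1.0],
                 [0.0, 1.0 + eps]], dtype=tf.float64)

_, v = tf.linalg.eig(a)            # complex eigenvalues/eigenvectors
vt = tf.linalg.adjoint(v)
b = tf.eye(2, dtype=tf.complex128)

# Old approach: explicit inverse followed by a matmul.
x_inv = tf.linalg.inv(vt) @ b
# New approach: a single LU-based linear solve, no explicit inverse.
x_solve = tf.linalg.solve(vt, b)

# Compare max-abs residuals of vt @ x - b; the solve-based result is
# typically no worse, and often better, when vt is ill-conditioned.
print("inv   residual:", tf.reduce_max(tf.abs(vt @ x_inv - b)).numpy())
print("solve residual:", tf.reduce_max(tf.abs(vt @ x_solve - b)).numpy())
```

Both paths cost O(n^3), but the solve performs one LU factorization plus triangular solves against the right-hand side, whereas the old code pays for a full inverse and then a matmul, and the explicit inverse is the less backward-stable of the two.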