Format according to code style and add tests for complex SVD backprop
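
Rename safe_reciprocal to _SafeReciprocal and switch it to two-space
indentation per the TensorFlow Python style, wrap long lines at 80
columns, strip trailing whitespace, and enable the SvdGradGrad test
for complex128 (float32 stays excluded as too inaccurate).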
diff --git a/tensorflow/python/kernel_tests/svd_op_test.py b/tensorflow/python/kernel_tests/svd_op_test.py
index 278ec9d..bbcab12 100644
--- a/tensorflow/python/kernel_tests/svd_op_test.py
+++ b/tensorflow/python/kernel_tests/svd_op_test.py
@@ -406,7 +406,7 @@
             _AddTest(SvdGradOpTest, "SvdGrad", name,
                      _GetSvdGradOpTest(dtype, shape, compute_uv, full_matrices))
-            # The results are too inacurate for float32.
-            if dtype == np.float64:
+            # The results are too inaccurate for float32.
+            if dtype in (np.float64, np.complex128):
               _AddTest(
                   SvdGradGradOpTest, "SvdGradGrad", name,
                   _GetSvdGradGradOpTest(dtype, shape, compute_uv,
diff --git a/tensorflow/python/ops/linalg_grad.py b/tensorflow/python/ops/linalg_grad.py
index ea8174e..1d67b8a 100644
--- a/tensorflow/python/ops/linalg_grad.py
+++ b/tensorflow/python/ops/linalg_grad.py
@@ -352,7 +352,7 @@
   # Giles' paper (see reference at top of file).  A derivation for
   # the full_matrices=False case is available at
   # https://j-towns.github.io/papers/svd-derivative.pdf
-  # The derivation for complex valued SVD can be found in 
+  # The derivation for complex valued SVD can be found in
   # https://re-ra.xyz/misc/complexsvd.pdf or
   # https://giggleliu.github.io/2019/04/02/einsumbp.html
   a = op.inputs[0]
@@ -413,18 +413,18 @@
-    # only defined up a (k-dimensional) subspace. In practice, this can
+    # only defined up to a (k-dimensional) subspace. In practice, this can
     # lead to numerical instability when singular values are close but not
     # exactly equal.
-    # To avoid nan in cases with degenrate sigular values or zero sigular values 
-    # in calculating f and s_inv_mat, we introduce a Lorentz brodening.
+    # To avoid NaN in cases with degenerate or zero singular values when
+    # calculating f and s_inv_mat, we introduce Lorentz broadening.
-    
-    def safe_reciprocal(x, epsilon=1E-20):
-        return x * math_ops.reciprocal(x * x + epsilon)
-    
+
+    def _SafeReciprocal(x, epsilon=1E-20):
+      return x * math_ops.reciprocal(x * x + epsilon)
+
     s_shape = array_ops.shape(s)
     f = array_ops.matrix_set_diag(
-        safe_reciprocal(
+        _SafeReciprocal(
             array_ops.expand_dims(s2, -2) - array_ops.expand_dims(s2, -1)
         ), array_ops.zeros_like(s))
-    s_inv_mat = array_ops.matrix_diag(safe_reciprocal(s))
+    s_inv_mat = array_ops.matrix_diag(_SafeReciprocal(s))
 
     v1 = v[..., :, :m]
     grad_v1 = grad_v[..., :, :m]
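
As a standalone illustration of the Lorentz broadening used above (a minimal
NumPy sketch, not part of this patch; safe_reciprocal here is a local
stand-in for the op-level helper):

    import numpy as np

    def safe_reciprocal(x, epsilon=1e-20):
      # ~= 1/x when |x| >> sqrt(epsilon), but goes smoothly to 0 as x -> 0
      # instead of overflowing to inf.
      return x * np.reciprocal(x * x + epsilon)

    print(safe_reciprocal(2.0))    # 0.5, agrees with 1/x
    print(safe_reciprocal(0.0))    # 0.0 rather than inf
    print(safe_reciprocal(1e-12))  # ~1e8, large but finite near a degeneracy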
@@ -459,17 +459,19 @@
       term2 = math_ops.matmul(u_s_inv, term2_nous)
 
       grad_a_before_transpose = term1 + term2
-    
+
     if a.dtype.is_complex:
       eye = _linalg.eye(s_shape[-1], batch_shape=s_shape[:-1], dtype=a.dtype)
       l = eye * v_gv
       term3_nouv = math_ops.matmul(s_inv_mat, _linalg.adjoint(l)-l)
-      term3 = 1/2. * math_ops.matmul(u, math_ops.matmul(term3_nouv, v1, adjoint_b=True))
-        
+      term3 = 1/2. * math_ops.matmul(
+          u, math_ops.matmul(term3_nouv, v1, adjoint_b=True))
+
       grad_a_before_transpose += term3
 
     if use_adjoint:
-      grad_a = array_ops.matrix_transpose(grad_a_before_transpose, conjugate=True)
+      grad_a = array_ops.matrix_transpose(
+          grad_a_before_transpose, conjugate=True)
     else:
       grad_a = grad_a_before_transpose
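
A quick end-to-end exercise of the complex gradient (a hypothetical
eager-mode snippet assuming TF 2.x; the maintained coverage is the
SvdGradGrad case enabled in svd_op_test.py above):

    import numpy as np
    import tensorflow as tf

    rng = np.random.RandomState(0)
    a = tf.constant(rng.randn(4, 3) + 1j * rng.randn(4, 3),
                    dtype=tf.complex128)
    with tf.GradientTape() as tape:
      tape.watch(a)
      s, u, v = tf.linalg.svd(a, compute_uv=True, full_matrices=False)
      # Any real scalar of (s, u, v) works; this touches all three gradients.
      loss = (tf.reduce_sum(s) + tf.reduce_sum(tf.abs(u)) +
              tf.reduce_sum(tf.abs(v)))
    grad = tape.gradient(loss, a)
    print(grad.shape)                          # (4, 3)
    print(np.all(np.isfinite(grad.numpy())))   # True: no NaN from small gaps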