Added compatibility guard to new sparse segment sum gradient code.

PiperOrigin-RevId: 375106664
Change-Id: I67fe3d5515ba86a4b3b003f881fface7a507ff9a
diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD
index 2a87d0f..12ab4fe 100644
--- a/tensorflow/python/BUILD
+++ b/tensorflow/python/BUILD
@@ -1928,7 +1928,10 @@
         ":math_ops",
         ":math_ops_gen",
         ":pywrap_tf_session",
+        ":special_math_ops",
+        "//tensorflow/python/compat",
         "//tensorflow/python/eager:context",
+        "//tensorflow/python/framework:constant_op",
         "//tensorflow/python/framework:for_generated_wrappers",
         "//tensorflow/python/framework:tensor_util",
         "//third_party/py/numpy",
diff --git a/tensorflow/python/ops/math_grad.py b/tensorflow/python/ops/math_grad.py
index ed42179..a2c437e 100644
--- a/tensorflow/python/ops/math_grad.py
+++ b/tensorflow/python/ops/math_grad.py
@@ -20,6 +20,7 @@
 import numpy as np
 
 from tensorflow.python.client import pywrap_tf_session as c_api
+from tensorflow.python.compat import compat
 from tensorflow.python.eager import context
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
@@ -342,16 +343,25 @@
 def _SparseSegmentSumGrad(op, grad):
   """Gradient for SparseSegmentSum."""
   dim0 = array_ops.shape(op.inputs[0])[0]
-  return (math_ops.sparse_segment_sum_grad(grad, op.inputs[1], op.inputs[2],
-                                           dim0), None, None)
+  if compat.forward_compatible(2021, 6, 10):
+    return (math_ops.sparse_segment_sum_grad(grad, op.inputs[1], op.inputs[2],
+                                             dim0), None, None)
+  else:
+    return (math_ops.unsorted_segment_sum(
+        array_ops.gather(grad, op.inputs[2]), op.inputs[1], dim0), None, None)
 
 
 @ops.RegisterGradient("SparseSegmentSumWithNumSegments")
 def _SparseSegmentSumWithNumSegmentsGrad(op, grad):
   """Gradient for SparseSegmentSumWithNumSegments."""
   dim0 = array_ops.shape(op.inputs[0])[0]
-  return (math_ops.sparse_segment_sum_grad(grad, op.inputs[1], op.inputs[2],
-                                           dim0), None, None, None)
+  if compat.forward_compatible(2021, 6, 10):
+    return (math_ops.sparse_segment_sum_grad(grad, op.inputs[1], op.inputs[2],
+                                             dim0), None, None, None)
+  else:
+    return (math_ops.unsorted_segment_sum(
+        array_ops.gather(grad, op.inputs[2]), op.inputs[1],
+        dim0), None, None, None)
 
 
 @ops.RegisterGradient("SparseSegmentMean")