Added compatibility guard to new sparse segment sum gradient code.
PiperOrigin-RevId: 375106664
Change-Id: I67fe3d5515ba86a4b3b003f881fface7a507ff9a
diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD
index 2a87d0f..12ab4fe 100644
--- a/tensorflow/python/BUILD
+++ b/tensorflow/python/BUILD
@@ -1928,7 +1928,10 @@
":math_ops",
":math_ops_gen",
":pywrap_tf_session",
+ ":special_math_ops",
+ "//tensorflow/python/compat",
"//tensorflow/python/eager:context",
+ "//tensorflow/python/framework:constant_op",
"//tensorflow/python/framework:for_generated_wrappers",
"//tensorflow/python/framework:tensor_util",
"//third_party/py/numpy",
diff --git a/tensorflow/python/ops/math_grad.py b/tensorflow/python/ops/math_grad.py
index ed42179..a2c437e 100644
--- a/tensorflow/python/ops/math_grad.py
+++ b/tensorflow/python/ops/math_grad.py
@@ -20,6 +20,7 @@
import numpy as np
from tensorflow.python.client import pywrap_tf_session as c_api
+from tensorflow.python.compat import compat
from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
@@ -342,16 +343,25 @@
def _SparseSegmentSumGrad(op, grad):
"""Gradient for SparseSegmentSum."""
dim0 = array_ops.shape(op.inputs[0])[0]
- return (math_ops.sparse_segment_sum_grad(grad, op.inputs[1], op.inputs[2],
- dim0), None, None)
+ if compat.forward_compatible(2021, 6, 10):
+ return (math_ops.sparse_segment_sum_grad(grad, op.inputs[1], op.inputs[2],
+ dim0), None, None)
+ else:
+ return (math_ops.unsorted_segment_sum(
+ array_ops.gather(grad, op.inputs[2]), op.inputs[1], dim0), None, None)
@ops.RegisterGradient("SparseSegmentSumWithNumSegments")
def _SparseSegmentSumWithNumSegmentsGrad(op, grad):
"""Gradient for SparseSegmentSumWithNumSegments."""
dim0 = array_ops.shape(op.inputs[0])[0]
- return (math_ops.sparse_segment_sum_grad(grad, op.inputs[1], op.inputs[2],
- dim0), None, None, None)
+ if compat.forward_compatible(2021, 6, 10):
+ return (math_ops.sparse_segment_sum_grad(grad, op.inputs[1], op.inputs[2],
+ dim0), None, None, None)
+ else:
+ return (math_ops.unsorted_segment_sum(
+ array_ops.gather(grad, op.inputs[2]), op.inputs[1],
+ dim0), None, None, None)
@ops.RegisterGradient("SparseSegmentMean")