Use new GPU kernel for [unsorted] segment reductions

- Replaces the old atomics-based kernels with calls to SegmentReduceGPU
  (the same kernel already used for sparse segment reductions). The new
  kernel is used by default; the old kernels can be re-enabled by setting
  the environment variable TF_USE_ATOMIC_SEGMENT_REDUCTIONS=1 (see the
  usage sketch at the end of this list). On Windows, the old kernels are
  always used due to a build issue with the new kernel.
- This improves performance and guarantees that these ops are
  deterministic. In the future, it is hoped that the old kernels can be
  removed entirely.
- Also adds a GPU kernel registration for SegmentMean, which didn't
  previously exist.
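- Illustrative usage (a minimal sketch, not part of this change; assumes
  a CUDA-enabled TensorFlow build with at least one visible GPU):

    import os
    # The new deterministic kernel is used by default. To fall back to the
    # old atomics-based kernels (non-Windows only), set this before
    # TensorFlow initializes its GPU kernels:
    # os.environ["TF_USE_ATOMIC_SEGMENT_REDUCTIONS"] = "1"
    import tensorflow as tf

    data = tf.random.uniform([8, 4])
    segment_ids = tf.constant([0, 0, 1, 1, 2, 2, 3, 3])
    with tf.device("/GPU:0"):
      # Sorted segment reductions, including the newly registered GPU
      # kernel for SegmentMean.
      seg_sum = tf.math.segment_sum(data, segment_ids)
      seg_mean = tf.math.segment_mean(data, segment_ids)
      # Unsorted segment reductions.
      unsorted_sum = tf.math.unsorted_segment_sum(
          data, segment_ids, num_segments=4)
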
diff --git a/tensorflow/core/kernels/segment_reduction_ops.h b/tensorflow/core/kernels/segment_reduction_ops.h
index 8faf8ac..49035bf 100644
--- a/tensorflow/core/kernels/segment_reduction_ops.h
+++ b/tensorflow/core/kernels/segment_reduction_ops.h
@@ -16,13 +16,6 @@
 #ifndef TENSORFLOW_CORE_KERNELS_SEGMENT_REDUCTION_OPS_H_
 #define TENSORFLOW_CORE_KERNELS_SEGMENT_REDUCTION_OPS_H_
 
-// This file requires the following include because it uses GpuAtomicMax:
-// #include "tensorflow/core/util/gpu_kernel_helper.h"
-
-// Unfortunately we can't add the #include, since it breaks compilation for
-// non-GPU targets. This only breaks in clang, because it's more strict for
-// template code and GpuAtomicMax is used in template context.
-
 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
 #include "tensorflow/core/framework/tensor.h"
 #include "tensorflow/core/framework/tensor_shape.h"
@@ -32,6 +25,7 @@
 
 class OpKernelContext;
 
+bool UseAtomicSegmentReductions();
 bool DisableSegmentReductionOpDeterminismExceptions();
 
 // Type of SparseSegmentReduction operation to perform gradient of.
@@ -40,9 +34,51 @@
 namespace functor {
 
 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
+
+// Note that we define this ourselves to avoid a dependency on gpuprim.
+struct Sum {
+  template <typename T>
+  __host__ __device__ T operator()(const T& a, const T& b) const {
+    return a + b;
+  }
+};
+
+struct Prod {
+  template <typename T>
+  __host__ __device__ T operator()(const T& a, const T& b) const {
+    return a * b;
+  }
+};
+
+// Note that we don't use gpuprim::Min/Max because they use operator<, which is
+// not implemented for AlignedVector types.
+struct Min {
+  template <typename T>
+  __host__ __device__ T operator()(const T& a, const T& b) const {
+    return min(a, b);
+  }
+};
+
+struct Max {
+  template <typename T>
+  __host__ __device__ T operator()(const T& a, const T& b) const {
+    return max(a, b);
+  }
+};
+
+template <typename ReduceOp, typename T>
+struct ReduceOpIsAssociative {};
+template <typename T>
+struct ReduceOpIsAssociative<functor::Sum, T> : std::is_integral<T> {};
+template <typename T>
+struct ReduceOpIsAssociative<functor::Prod, T> : std::is_integral<T> {};
+template <typename T>
+struct ReduceOpIsAssociative<functor::Max, T> : std::true_type {};
+template <typename T>
+struct ReduceOpIsAssociative<functor::Min, T> : std::true_type {};
+
 typedef Eigen::GpuDevice GPUDevice;
-// Functor for SegmentSumGPUOp & SegmentProdGPUOp & SegmentMaxGPUOp
-//             & SegmentMinGPUOp.
+// Functor for SegmentReductionGPUOp.
 // output_rows: the number of output segments (unique segment ids in
 //                'segment_ids').
 // segment_ids_shape: shape of 'segment_ids' tensor.
@@ -52,18 +88,18 @@
 // data: input data tensor.
 // output: output reshaped to {output_rows, output.size/output_rows}
 template <typename T, typename Index, typename InitialValueF,
-          typename ReductionF, typename AtomicReductionF>
+          typename EmptySegmentValueF, typename ReductionF>
 struct SegmentReductionFunctor {
   void operator()(OpKernelContext* ctx, const GPUDevice& d,
                   const Index output_rows, const TensorShape& segment_ids_shape,
-                  typename TTypes<Index>::ConstFlat segment_ids,
+                  bool is_mean, typename TTypes<Index>::ConstFlat segment_ids,
                   const Index data_size, const T* data,
                   typename TTypes<T, 2>::Tensor output);
   static constexpr bool atomic_reduction_is_associative =
-      AtomicReductionF::is_associative;
+      ReduceOpIsAssociative<ReductionF, T>::value;
 };
 
-#endif
+#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 
 template <typename Device, typename T, typename Index, typename InitialValueF,
           typename ReductionF>
@@ -74,80 +110,6 @@
                   typename TTypes<T, 2>::Tensor output);
 };
 
-#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
-
-// Atomic reduction functors for the gpu.
-template <typename T>
-struct AtomicSumOpGpu {
-  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void operator()(T* dest,
-                                                        const T& value) {
-    GpuAtomicAdd(dest, value);
-  }
-  static constexpr bool is_associative = std::is_integral<T>::value;
-};
-
-template <typename T>
-struct AtomicProdOpGpu {
-  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void operator()(T* dest,
-                                                        const T& value) {
-    GpuAtomicMul(dest, value);
-  }
-  static constexpr bool is_associative = std::is_integral<T>::value;
-};
-
-template <typename T>
-struct AtomicMaxOpGpu {
-  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void operator()(T* dest,
-                                                        const T& value) {
-    GpuAtomicMax(dest, value);
-  }
-  static constexpr bool is_associative = true;
-};
-
-template <typename T>
-struct AtomicMinOpGpu {
-  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void operator()(T* dest,
-                                                        const T& value) {
-    GpuAtomicMin(dest, value);
-  }
-  static constexpr bool is_associative = true;
-};
-
-// Non-atomic reduction functors for the gpu.
-template <typename T>
-struct NonAtomicSumOpGpu {
-  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void operator()(T* dest,
-                                                        const T& value) {
-    *dest += value;
-  }
-};
-
-template <typename T>
-struct NonAtomicProdOpGpu {
-  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void operator()(T* dest,
-                                                        const T& value) {
-    *dest *= value;
-  }
-};
-
-template <typename T>
-struct NonAtomicMaxOpGpu {
-  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void operator()(T* dest,
-                                                        const T& value) {
-    *dest = max(*dest, value);
-  }
-};
-
-template <typename T>
-struct NonAtomicMinOpGpu {
-  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void operator()(T* dest,
-                                                        const T& value) {
-    *dest = min(*dest, value);
-  }
-};
-
-#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
-
 // Initial value functors.
 template <typename T>
 struct Zero {
diff --git a/tensorflow/core/kernels/segment_reduction_ops_gpu.cu.h b/tensorflow/core/kernels/segment_reduction_ops_gpu.cu.h
index 19805df..2e75683 100644
--- a/tensorflow/core/kernels/segment_reduction_ops_gpu.cu.h
+++ b/tensorflow/core/kernels/segment_reduction_ops_gpu.cu.h
@@ -17,12 +17,7 @@
 
 #define EIGEN_USE_GPU
 
-// We need to include gpu_kernel_helper.h before segment_reduction_ops.h
-// See comment in segment_reduction_ops.h for more details.
-// clang-format off
-#include "tensorflow/core/util/gpu_kernel_helper.h"
-// clang-format on
-
+#include "tensorflow/core/framework/register_types.h"
 #include "tensorflow/core/kernels/gpu_prim.h"
 #include "tensorflow/core/kernels/gpu_prim_helpers.h"
 #include "tensorflow/core/kernels/segment_reduction_ops.h"
@@ -30,12 +25,48 @@
 #include "tensorflow/core/util/determinism.h"
 #include "tensorflow/core/util/env_var.h"
 #include "tensorflow/core/util/gpu_device_functions.h"
+#include "tensorflow/core/util/gpu_kernel_helper.h"
 #include "tensorflow/core/util/permutation_input_iterator.h"
 
 namespace tensorflow {
 
 using GPUDevice = Eigen::GpuDevice;
 
+// Non/Atomic reduction functors for the gpu.
+#define DEFINE_REDUCE_UPDATE_OP_GPU(name, func)                             \
+  struct name##OpGpu {                                                      \
+    template <typename T>                                                   \
+    EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void operator()(T* dest,          \
+                                                          const T& value) { \
+      func;                                                                 \
+    }                                                                       \
+  };
+DEFINE_REDUCE_UPDATE_OP_GPU(AtomicSum, GpuAtomicAdd(dest, value))
+DEFINE_REDUCE_UPDATE_OP_GPU(AtomicProd, GpuAtomicMul(dest, value))
+DEFINE_REDUCE_UPDATE_OP_GPU(AtomicMax, GpuAtomicMax(dest, value))
+DEFINE_REDUCE_UPDATE_OP_GPU(AtomicMin, GpuAtomicMin(dest, value))
+DEFINE_REDUCE_UPDATE_OP_GPU(NonAtomicSum, *dest += value)
+DEFINE_REDUCE_UPDATE_OP_GPU(NonAtomicProd, *dest *= value)
+DEFINE_REDUCE_UPDATE_OP_GPU(NonAtomicMax, *dest = max(*dest, value))
+DEFINE_REDUCE_UPDATE_OP_GPU(NonAtomicMin, *dest = min(*dest, value))
+#undef DEFINE_REDUCE_UPDATE_OP_GPU
+
+template <typename ReduceOp>
+struct ReduceUpdateOpFor {
+};
+
+#define DEFINE_REDUCE_UPDATE_OP_FOR(reduce_op, atomic, nonatomic) \
+  template <>                                                     \
+  struct ReduceUpdateOpFor<reduce_op> {                           \
+    using atomic_op = atomic;                                     \
+    using nonatomic_op = nonatomic;                               \
+  };
+DEFINE_REDUCE_UPDATE_OP_FOR(functor::Sum, AtomicSumOpGpu, NonAtomicSumOpGpu)
+DEFINE_REDUCE_UPDATE_OP_FOR(functor::Prod, AtomicProdOpGpu, NonAtomicProdOpGpu)
+DEFINE_REDUCE_UPDATE_OP_FOR(functor::Max, AtomicMaxOpGpu, NonAtomicMaxOpGpu)
+DEFINE_REDUCE_UPDATE_OP_FOR(functor::Min, AtomicMinOpGpu, NonAtomicMinOpGpu)
+#undef DEFINE_REDUCE_UPDATE_OP_FOR
+
 // SortedSegmentReductionFunctor kernel reduces input data just as
 // UnsortedSegmentReductionCustomKernel does except that input data
 // is partitioned along the outer reduction dimension. This is
@@ -108,6 +139,35 @@
   }
 }
 
+template <typename SegmentId, typename Index, typename T>
+__global__ void SegmentMeanNormalizeKernel(
+    SegmentId nsegments, Index ninner,
+    const Index* __restrict__ segment_offsets,  // [nsegments + 1]
+    T* __restrict__ output) {                   // [nsegments, ninner]
+  for (SegmentId seg : GpuGridRangeY(nsegments)) {
+    SegmentId segment_size = segment_offsets[seg + 1] - segment_offsets[seg];
+    segment_size = max(segment_size, Index(1));  // Avoid division by zero
+    T inv_norm = T(1) / static_cast<T>(segment_size);
+    for (Index i : GpuGridRangeX(ninner)) {
+      output[seg * ninner + i] *= inv_norm;
+    }
+  }
+}
+
+template <typename SegmentId, typename Index, typename T>
+Status LaunchSegmentMeanNormalizeKernel(
+    const GPUDevice& d, SegmentId nsegments, Index ninner,
+    const Index* __restrict__ segment_offsets,  // [nsegments + 1]
+    T* __restrict__ output) {                   // [nsegments, ninner]
+  Gpu2DLaunchConfig config = GetGpu2DLaunchConfig(
+      ninner, nsegments, d, SegmentMeanNormalizeKernel<SegmentId, Index, T>,
+      /*dynamic_shared_memory_size=*/0, /*block_size_limit=*/0);
+  return GpuLaunchKernel(SegmentMeanNormalizeKernel<SegmentId, Index, T>,
+                         config.block_count, config.thread_per_block, 0,
+                         d.stream(), nsegments, ninner, segment_offsets,
+                         output);
+}
+
 // UnsortedSegmentSumKernel processes 'input_total_size' elements.
 // Each element is mapped from input to output by a combination of its
 // 'segment_ids' mapping and 'inner_dim_size'.
@@ -623,37 +683,27 @@
 
 // Sum fp16 values using an fp32 accumulator to avoid numerical issues.
 template <>
-struct ReduceType<gpuprim::Sum, Eigen::half> {
+struct ReduceType<functor::Sum, Eigen::half> {
   using type = float;
 };
 
 namespace functor {
 
 template <typename T, typename Index, typename InitialValueF,
-          typename ReductionF, typename AtomicReductionF>
+          typename EmptySegmentValueF, typename ReductionF>
 void SegmentReductionFunctor<
-    T, Index, InitialValueF, ReductionF,
-    AtomicReductionF>::operator()(OpKernelContext* ctx, const GPUDevice& d,
-                                  const Index output_rows,
-                                  const TensorShape& segment_ids_shape,
-                                  typename TTypes<Index>::ConstFlat segment_ids,
-                                  const Index data_size, const T* data,
-                                  typename TTypes<T, 2>::Tensor output) {
+    T, Index, InitialValueF, EmptySegmentValueF,
+    ReductionF>::operator()(OpKernelContext* ctx, const GPUDevice& d,
+                            const Index output_rows,
+                            const TensorShape& segment_ids_shape, bool is_mean,
+                            typename TTypes<Index>::ConstFlat segment_ids,
+                            const Index data_size, const T* data,
+                            typename TTypes<T, 2>::Tensor output) {
   if (output.size() == 0) {
     return;
   }
 
-  // Set 'output' to initial value.
-  GpuLaunchConfig config = GetGpuLaunchConfig(output.size(), d);
-  const T InitialValue = InitialValueF()();
-  TF_CHECK_OK(GpuLaunchKernel(SetToValue<T>, config.block_count,
-                              config.thread_per_block, 0, d.stream(),
-                              output.size(), output.data(), InitialValue));
-  if (data_size == 0 || segment_ids_shape.num_elements() == 0) {
-    return;
-  }
-
-  // Launch kernel to compute sorted segment reduction.
+  // Launch kernel(s) to compute sorted segment reduction.
   // Notes:
   // *) 'input_total_size' is the total number of elements to process.
   // *) 'segment_ids.shape' is a prefix of data's shape.
@@ -661,30 +711,84 @@
   const Index input_total_size = data_size;
   const Index input_outer_dim_size = segment_ids.dimension(0);
   const Index input_inner_dim_size = input_total_size / input_outer_dim_size;
+  const Index num_segments = output.size() / input_inner_dim_size;
 
-  const int OuterDimTileSize = 8;
+  // TODO(benbarsdell): If there are no performance concerns with the new
+  // non-atomic kernels, remove this runtime check and only compile the old
+  // atomic kernels on Windows (as a workaround for the build failure issue).
+  if (UseAtomicSegmentReductions()) {
+    // Set 'output' to initial value.
+    GpuLaunchConfig config = GetGpuLaunchConfig(output.size(), d);
+    const T InitialValue = InitialValueF()();
+    TF_CHECK_OK(GpuLaunchKernel(SetToValue<T>, config.block_count,
+                                config.thread_per_block, 0, d.stream(),
+                                output.size(), output.data(), InitialValue));
+    if (data_size == 0 || segment_ids_shape.num_elements() == 0) {
+      return;
+    }
 
-  const Index input_outer_dim_num_stripe =
-      Eigen::divup(input_outer_dim_size, Index(OuterDimTileSize));
+    const int OuterDimTileSize = 8;
 
-  const Index total_stripe_count =
-      input_inner_dim_size * input_outer_dim_num_stripe;
+    const Index input_outer_dim_num_stripe =
+        Eigen::divup(input_outer_dim_size, Index(OuterDimTileSize));
 
-  config = GetGpuLaunchConfig(total_stripe_count, d);
-  TF_CHECK_OK(GpuLaunchKernel(
-      SortedSegmentReductionCustomKernel<T, Index, OuterDimTileSize, ReductionF,
-                                         AtomicReductionF>,
-      config.block_count, config.thread_per_block, 0, d.stream(),
-      input_outer_dim_size, input_inner_dim_size, output_rows,
-      segment_ids.data(), data, output.data(), total_stripe_count,
-      InitialValue));
+    const Index total_stripe_count =
+        input_inner_dim_size * input_outer_dim_num_stripe;
+
+    config = GetGpuLaunchConfig(total_stripe_count, d);
+    TF_CHECK_OK(GpuLaunchKernel(
+        SortedSegmentReductionCustomKernel<
+            T, Index, OuterDimTileSize,
+            typename ReduceUpdateOpFor<ReductionF>::nonatomic_op,
+            typename ReduceUpdateOpFor<ReductionF>::atomic_op>,
+        config.block_count, config.thread_per_block, 0, d.stream(),
+        input_outer_dim_size, input_inner_dim_size, output_rows,
+        segment_ids.data(), data, output.data(), total_stripe_count,
+        InitialValue));
+
+    if (is_mean) {
+      Tensor segment_offsets;
+      OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<Index>::value,
+                                             TensorShape({num_segments + 1}),
+                                             &segment_offsets));
+      Index* segment_offsets_ptr = segment_offsets.flat<Index>().data();
+      OP_REQUIRES_OK(ctx, LaunchSegmentOffsetsKernel(
+                              d, input_outer_dim_size, num_segments,
+                              segment_ids.data(), segment_offsets_ptr));
+
+      OP_REQUIRES_OK(ctx, LaunchSegmentMeanNormalizeKernel(
+                              d, num_segments, input_inner_dim_size,
+                              segment_offsets_ptr, output.data()));
+    }
+  } else {
+    // See comment in segment_reduction_ops_gpu_0.cu.cc regarding Windows CI
+    // build error.
+#if !defined(PLATFORM_WINDOWS)
+    using Treduce = typename ReduceType<ReductionF, T>::type;
+    OP_REQUIRES_OK(
+        ctx,
+        SegmentReduceGPU<Treduce>(
+            ctx, input_outer_dim_size, input_inner_dim_size, num_segments,
+            ReductionF(), InitialValueF()(), EmptySegmentValueF()(),
+            /*is_mean=*/is_mean, /*is_sqrtn=*/false, data, segment_ids.data(),
+            /*indices=*/static_cast<const Index*>(nullptr),
+            /*weights=*/static_cast<T*>(nullptr), output.data()));
+#else
+    // Note: Shouldn't reach here because UseAtomicSegmentReductions() always
+    // returns true on Windows.
+    OP_REQUIRES(
+        ctx, false,
+        errors::Unimplemented(
+            "Non-atomic segment reductions are not implemented on Windows."));
+#endif
+  }
 }
 
 template <typename T, typename Index, typename InitialValueF,
           typename ReductionF>
 struct UnsortedSegmentFunctor<GPUDevice, T, Index, InitialValueF, ReductionF> {
   void operator()(OpKernelContext* ctx, const TensorShape& segment_ids_shape,
-                  typename TTypes<Index>::ConstFlat segment_ids,
+                  typename TTypes<Index>::ConstFlat unsorted_segment_ids,
                   typename TTypes<T, 2>::ConstTensor data,
                   typename TTypes<T, 2>::Tensor output) {
     if (output.size() == 0) {
@@ -692,7 +796,9 @@
     }
 
     bool determinism_requirement_met =
-        ReductionF::is_associative || !OpDeterminismRequired() ||
+        !UseAtomicSegmentReductions() ||
+        ReduceOpIsAssociative<ReductionF, T>::value ||
+        !OpDeterminismRequired() ||
         DisableSegmentReductionOpDeterminismExceptions();
     OP_REQUIRES(
         ctx, determinism_requirement_met,
@@ -700,31 +806,84 @@
             "Deterministic GPU implementation of unsorted segment reduction op"
             " not available."));
 
-    // Set 'output' to initial value.
-    GPUDevice d = ctx->template eigen_device<GPUDevice>();
-    GpuLaunchConfig config = GetGpuLaunchConfig(output.size(), d);
-    TF_CHECK_OK(GpuLaunchKernel(
-        SetToValue<T>, config.block_count, config.thread_per_block, 0,
-        d.stream(), output.size(), output.data(), InitialValueF()()));
-    const int64_t data_size = data.size();
-    if (data_size == 0 || segment_ids_shape.num_elements() == 0) {
-      return;
-    }
-    // Launch kernel to compute unsorted segment reduction.
+    // Launch kernel(s) to compute unsorted segment reduction.
     // Notes:
     // *) 'data_size' is the total number of elements to process.
     // *) 'segment_ids.shape' is a prefix of data's shape.
     // *) 'input_outer_dim_size' is the total number of segments to process.
-    const int64_t input_outer_dim_size = segment_ids.dimension(0);
-    const int64_t input_inner_dim_size = data.dimension(1);
-    const int64_t output_outer_dim_size = output.dimension(0);
-    config = GetGpuLaunchConfig(data_size, d);
+    const Index input_outer_dim_size = unsorted_segment_ids.dimension(0);
+    const Index input_inner_dim_size = data.dimension(1);
+    const Index output_outer_dim_size = output.dimension(0);
+    const Index num_segments = output.size() / input_inner_dim_size;
 
-    TF_CHECK_OK(GpuLaunchKernel(
-        UnsortedSegmentCustomKernel<T, Index, ReductionF>, config.block_count,
-        config.thread_per_block, 0, d.stream(), input_outer_dim_size,
-        input_inner_dim_size, output_outer_dim_size, segment_ids.data(),
-        data.data(), output.data()));
+    // TODO(benbarsdell): If there are no performance concerns with the new
+    // non-atomic kernels, remove this runtime check and only compile the old
+    // atomic kernels on Windows (as a workaround for the build failure issue).
+    if (UseAtomicSegmentReductions()) {
+      // Set 'output' to initial value.
+      GPUDevice d = ctx->template eigen_device<GPUDevice>();
+      GpuLaunchConfig config = GetGpuLaunchConfig(output.size(), d);
+      TF_CHECK_OK(GpuLaunchKernel(
+          SetToValue<T>, config.block_count, config.thread_per_block, 0,
+          d.stream(), output.size(), output.data(), InitialValueF()()));
+      const int64_t data_size = data.size();
+      if (data_size == 0 || segment_ids_shape.num_elements() == 0) {
+        return;
+      }
+      config = GetGpuLaunchConfig(data_size, d);
+      TF_CHECK_OK(GpuLaunchKernel(
+          UnsortedSegmentCustomKernel<
+              T, Index, typename ReduceUpdateOpFor<ReductionF>::atomic_op>,
+          config.block_count, config.thread_per_block, 0, d.stream(),
+          input_outer_dim_size, input_inner_dim_size, output_outer_dim_size,
+          unsorted_segment_ids.data(), data.data(), output.data()));
+    } else {
+      // See comment in segment_reduction_ops_gpu_0.cu.cc regarding Windows CI
+      // build error.
+#if !defined(PLATFORM_WINDOWS)
+      // Allocate temporary space, sort segment_ids, and then call the
+      // sorted implementation.
+      Tensor segment_ids;
+      OP_REQUIRES_OK(
+          ctx, ctx->allocate_temp(
+                   DataTypeToEnum<Index>::value,
+                   TensorShape({static_cast<int64_t>(input_outer_dim_size)}),
+                   &segment_ids));
+      Index* segment_ids_ptr = segment_ids.flat<Index>().data();
+      Tensor sorted_indices;
+      OP_REQUIRES_OK(
+          ctx, ctx->allocate_temp(
+                   DataTypeToEnum<Index>::value,
+                   TensorShape({static_cast<int64_t>(input_outer_dim_size)}),
+                   &sorted_indices));
+      Index* sorted_indices_ptr = sorted_indices.flat<Index>().data();
+      // Note: We must sort using all bits here because unsorted_segment_ids
+      // may contain negative values.
+      OP_REQUIRES_OK(
+          ctx, GpuRadixSort(ctx, input_outer_dim_size,
+                            /*keys_in=*/unsorted_segment_ids.data(),
+                            /*keys_out=*/segment_ids_ptr,
+                            /*indices_in=*/static_cast<const Index*>(nullptr),
+                            /*indices_out=*/sorted_indices_ptr));
+      using Treduce = typename ReduceType<ReductionF, T>::type;
+      OP_REQUIRES_OK(
+          ctx,
+          SegmentReduceGPU<Treduce>(
+              ctx, input_outer_dim_size, input_inner_dim_size, num_segments,
+              ReductionF(), /*initial_value=*/InitialValueF()(),
+              /*empty_segment_value=*/InitialValueF()(), /*is_mean=*/false,
+              /*is_sqrtn=*/false, /*input=*/data.data(),
+              /*segment_ids=*/segment_ids_ptr, /*indices=*/sorted_indices_ptr,
+              /*weights=*/static_cast<T*>(nullptr), output.data()));
+#else
+      // Note: Shouldn't reach here because UseAtomicSegmentReductions() always
+      // returns true on Windows.
+      OP_REQUIRES(
+          ctx, false,
+          errors::Unimplemented("Non-atomic unsorted segment reductions "
+                                "are not implemented on Windows."));
+#endif
+    }
   }
 };
 
@@ -735,7 +894,7 @@
     typename TTypes<Index>::ConstVec indices,
     typename TTypes<SegmentId>::ConstVec segment_ids,
     typename TTypes<T, 2>::Tensor output) {
-  using ReduceOp = gpuprim::Sum;
+  using ReduceOp = functor::Sum;
   using Treduce = typename ReduceType<ReduceOp, T>::type;
   Index nouter = segment_ids.size();
   Index ninner = input.dimension(1);
diff --git a/tensorflow/core/kernels/segment_reduction_ops_gpu_0.cu.cc b/tensorflow/core/kernels/segment_reduction_ops_gpu_0.cu.cc
index f77bda1..f2af012 100644
--- a/tensorflow/core/kernels/segment_reduction_ops_gpu_0.cu.cc
+++ b/tensorflow/core/kernels/segment_reduction_ops_gpu_0.cu.cc
@@ -20,6 +20,22 @@
 
 namespace tensorflow {
 
+bool UseAtomicSegmentReductions() {
+  // See comment below regarding CI build error on Windows.
+#if !defined(PLATFORM_WINDOWS)
+  static bool cached_result = [] {
+    bool result = false;
+    TF_CHECK_OK(
+        tensorflow::ReadBoolFromEnvVar("TF_USE_ATOMIC_SEGMENT_REDUCTIONS",
+                                       /*default_val=*/false, &result));
+    return result;
+  }();
+  return cached_result;
+#else
+  return true;
+#endif
+}
+
 bool DisableSegmentReductionOpDeterminismExceptions() {
   static bool cached_disable = [] {
     bool disable = false;
@@ -33,36 +49,32 @@
 
 namespace functor {
 
-#define DEFINE_SORTED_GPU_SPECS_INDEX(T, Index)                           \
-  template struct SegmentReductionFunctor<T, Index, functor::Zero<T>,     \
-                                          functor::NonAtomicSumOpGpu<T>,  \
-                                          functor::AtomicSumOpGpu<T>>;    \
-  template struct SegmentReductionFunctor<T, Index, functor::One<T>,      \
-                                          functor::NonAtomicProdOpGpu<T>, \
-                                          functor::AtomicProdOpGpu<T>>;   \
-  template struct SegmentReductionFunctor<T, Index, functor::Highest<T>,  \
-                                          functor::NonAtomicMinOpGpu<T>,  \
-                                          functor::AtomicMinOpGpu<T>>;    \
-  template struct SegmentReductionFunctor<T, Index, functor::Lowest<T>,   \
-                                          functor::NonAtomicMaxOpGpu<T>,  \
-                                          functor::AtomicMaxOpGpu<T>>;
+#define DEFINE_SORTED_GPU_SPECS_INDEX(T, Index)                            \
+  template struct SegmentReductionFunctor<T, Index, functor::Zero<T>,      \
+                                          functor::Zero<T>, functor::Sum>; \
+  template struct SegmentReductionFunctor<T, Index, functor::One<T>,       \
+                                          functor::One<T>, functor::Prod>; \
+  template struct SegmentReductionFunctor<T, Index, functor::Highest<T>,   \
+                                          functor::Zero<T>, functor::Min>; \
+  template struct SegmentReductionFunctor<T, Index, functor::Lowest<T>,    \
+                                          functor::Zero<T>, functor::Max>;
 
 #define DEFINE_SORTED_GPU_SPECS(T) DEFINE_SORTED_GPU_SPECS_INDEX(T, int32);
 
 TF_CALL_GPU_NUMBER_TYPES(DEFINE_SORTED_GPU_SPECS);
 
 #define DEFINE_REAL_UNSORTED_GPU_SPECS_INDEX(T, Index)                         \
-  template struct UnsortedSegmentFunctor<                                      \
-      GPUDevice, T, Index, functor::Lowest<T>, functor::AtomicMaxOpGpu<T>>;    \
-  template struct UnsortedSegmentFunctor<                                      \
-      GPUDevice, T, Index, functor::Highest<T>, functor::AtomicMinOpGpu<T>>;   \
+  template struct UnsortedSegmentFunctor<GPUDevice, T, Index,                  \
+                                         functor::Lowest<T>, functor::Max>;    \
+  template struct UnsortedSegmentFunctor<GPUDevice, T, Index,                  \
+                                         functor::Highest<T>, functor::Min>;   \
   template struct UnsortedSegmentFunctor<GPUDevice, T, Index, functor::One<T>, \
-                                         functor::AtomicProdOpGpu<T>>;
+                                         functor::Prod>;
 
 // Sum is the only op that supports all input types currently.
-#define DEFINE_SUM_UNSORTED_GPU_SPECS_INDEX(T, Index) \
-  template struct UnsortedSegmentFunctor<             \
-      GPUDevice, T, Index, functor::Zero<T>, functor::AtomicSumOpGpu<T>>;
+#define DEFINE_SUM_UNSORTED_GPU_SPECS_INDEX(T, Index)         \
+  template struct UnsortedSegmentFunctor<GPUDevice, T, Index, \
+                                         functor::Zero<T>, functor::Sum>;
 
 #define DEFINE_REAL_GPU_SPECS(T) DEFINE_REAL_UNSORTED_GPU_SPECS_INDEX(T, int32);
 
diff --git a/tensorflow/core/kernels/segment_reduction_ops_gpu_1.cu.cc b/tensorflow/core/kernels/segment_reduction_ops_gpu_1.cu.cc
index a109d74..7705b1f 100644
--- a/tensorflow/core/kernels/segment_reduction_ops_gpu_1.cu.cc
+++ b/tensorflow/core/kernels/segment_reduction_ops_gpu_1.cu.cc
@@ -21,36 +21,32 @@
 namespace tensorflow {
 namespace functor {
 
-#define DEFINE_SORTED_GPU_SPECS_INDEX(T, Index)                           \
-  template struct SegmentReductionFunctor<T, Index, functor::Zero<T>,     \
-                                          functor::NonAtomicSumOpGpu<T>,  \
-                                          functor::AtomicSumOpGpu<T>>;    \
-  template struct SegmentReductionFunctor<T, Index, functor::One<T>,      \
-                                          functor::NonAtomicProdOpGpu<T>, \
-                                          functor::AtomicProdOpGpu<T>>;   \
-  template struct SegmentReductionFunctor<T, Index, functor::Highest<T>,  \
-                                          functor::NonAtomicMinOpGpu<T>,  \
-                                          functor::AtomicMinOpGpu<T>>;    \
-  template struct SegmentReductionFunctor<T, Index, functor::Lowest<T>,   \
-                                          functor::NonAtomicMaxOpGpu<T>,  \
-                                          functor::AtomicMaxOpGpu<T>>;
+#define DEFINE_SORTED_GPU_SPECS_INDEX(T, Index)                            \
+  template struct SegmentReductionFunctor<T, Index, functor::Zero<T>,      \
+                                          functor::Zero<T>, functor::Sum>; \
+  template struct SegmentReductionFunctor<T, Index, functor::One<T>,       \
+                                          functor::One<T>, functor::Prod>; \
+  template struct SegmentReductionFunctor<T, Index, functor::Highest<T>,   \
+                                          functor::Zero<T>, functor::Min>; \
+  template struct SegmentReductionFunctor<T, Index, functor::Lowest<T>,    \
+                                          functor::Zero<T>, functor::Max>;
 
 #define DEFINE_SORTED_GPU_SPECS(T) DEFINE_SORTED_GPU_SPECS_INDEX(T, int64_t);
 
 TF_CALL_GPU_NUMBER_TYPES(DEFINE_SORTED_GPU_SPECS);
 
 #define DEFINE_REAL_UNSORTED_GPU_SPECS_INDEX(T, Index)                         \
-  template struct UnsortedSegmentFunctor<                                      \
-      GPUDevice, T, Index, functor::Lowest<T>, functor::AtomicMaxOpGpu<T>>;    \
-  template struct UnsortedSegmentFunctor<                                      \
-      GPUDevice, T, Index, functor::Highest<T>, functor::AtomicMinOpGpu<T>>;   \
+  template struct UnsortedSegmentFunctor<GPUDevice, T, Index,                  \
+                                         functor::Lowest<T>, functor::Max>;    \
+  template struct UnsortedSegmentFunctor<GPUDevice, T, Index,                  \
+                                         functor::Highest<T>, functor::Min>;   \
   template struct UnsortedSegmentFunctor<GPUDevice, T, Index, functor::One<T>, \
-                                         functor::AtomicProdOpGpu<T>>;
+                                         functor::Prod>;
 
 // Sum is the only op that supports all input types currently.
-#define DEFINE_SUM_UNSORTED_GPU_SPECS_INDEX(T, Index) \
-  template struct UnsortedSegmentFunctor<             \
-      GPUDevice, T, Index, functor::Zero<T>, functor::AtomicSumOpGpu<T>>;
+#define DEFINE_SUM_UNSORTED_GPU_SPECS_INDEX(T, Index)         \
+  template struct UnsortedSegmentFunctor<GPUDevice, T, Index, \
+                                         functor::Zero<T>, functor::Sum>;
 
 #define DEFINE_REAL_GPU_SPECS(T) \
   DEFINE_REAL_UNSORTED_GPU_SPECS_INDEX(T, int64_t);
diff --git a/tensorflow/core/kernels/segment_reduction_ops_gpu_2.cu.cc b/tensorflow/core/kernels/segment_reduction_ops_gpu_2.cu.cc
index 4ff00e9..45bd46d 100644
--- a/tensorflow/core/kernels/segment_reduction_ops_gpu_2.cu.cc
+++ b/tensorflow/core/kernels/segment_reduction_ops_gpu_2.cu.cc
@@ -22,17 +22,17 @@
 namespace functor {
 
 #define DEFINE_REAL_UNSORTED_GPU_SPECS_INDEX(T, Index)                         \
-  template struct UnsortedSegmentFunctor<                                      \
-      GPUDevice, T, Index, functor::Lowest<T>, functor::AtomicMaxOpGpu<T>>;    \
-  template struct UnsortedSegmentFunctor<                                      \
-      GPUDevice, T, Index, functor::Highest<T>, functor::AtomicMinOpGpu<T>>;   \
+  template struct UnsortedSegmentFunctor<GPUDevice, T, Index,                  \
+                                         functor::Lowest<T>, functor::Max>;    \
+  template struct UnsortedSegmentFunctor<GPUDevice, T, Index,                  \
+                                         functor::Highest<T>, functor::Min>;   \
   template struct UnsortedSegmentFunctor<GPUDevice, T, Index, functor::One<T>, \
-                                         functor::AtomicProdOpGpu<T>>;
+                                         functor::Prod>;
 
 // Sum is the only op that supports all input types currently.
-#define DEFINE_SUM_UNSORTED_GPU_SPECS_INDEX(T, Index) \
-  template struct UnsortedSegmentFunctor<             \
-      GPUDevice, T, Index, functor::Zero<T>, functor::AtomicSumOpGpu<T>>;
+#define DEFINE_SUM_UNSORTED_GPU_SPECS_INDEX(T, Index)         \
+  template struct UnsortedSegmentFunctor<GPUDevice, T, Index, \
+                                         functor::Zero<T>, functor::Sum>;
 
 #define DEFINE_REAL_GPU_SPECS(T)                  \
   DEFINE_REAL_UNSORTED_GPU_SPECS_INDEX(T, int32); \
diff --git a/tensorflow/core/kernels/segment_reduction_ops_impl.h b/tensorflow/core/kernels/segment_reduction_ops_impl.h
index 6ad7bda..fb5e41e 100644
--- a/tensorflow/core/kernels/segment_reduction_ops_impl.h
+++ b/tensorflow/core/kernels/segment_reduction_ops_impl.h
@@ -222,7 +222,7 @@
 //     small. When to use the tiled version or the untiled version depends on
 //     many factors including data alignments, ratio of calculation to memory
 //     traffic and obviously, the problem sizes.
-template <class T, class Index, class SegmentReductionFunctor>
+template <class T, class Index, class SegmentReductionFunctor, bool IsMean>
 class SegmentReductionGPUOp : public AsyncOpKernel {
  public:
   explicit SegmentReductionGPUOp(OpKernelConstruction* context)
@@ -300,6 +300,7 @@
       // for the unsorted segment reduction ops) because the done callback
       // (required for OP_REQUIRES_ASYNC) is not available inside the functor.
       bool determinism_requirement_met =
+          !UseAtomicSegmentReductions() ||
           SegmentReductionFunctor::atomic_reduction_is_associative ||
           !OpDeterminismRequired() ||
           DisableSegmentReductionOpDeterminismExceptions();
@@ -314,8 +315,8 @@
       auto data_ptr = input.template flat<T>().data();
       auto segment_flat = segment_ids.flat<Index>();
       functor_(context, context->eigen_device<GPUDevice>(), output_rows,
-               segment_ids.shape(), segment_flat, input.NumElements(), data_ptr,
-               output_flat);
+               segment_ids.shape(), IsMean, segment_flat, input.NumElements(),
+               data_ptr, output_flat);
 
       done();
     };
diff --git a/tensorflow/core/kernels/segment_reduction_ops_impl_1.cc b/tensorflow/core/kernels/segment_reduction_ops_impl_1.cc
index b9bee23..0407be0 100644
--- a/tensorflow/core/kernels/segment_reduction_ops_impl_1.cc
+++ b/tensorflow/core/kernels/segment_reduction_ops_impl_1.cc
@@ -147,33 +147,37 @@
 #undef REGISTER_COMPLEX_CPU_KERNELS_ALL
 
 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
-#define REGISTER_GPU_KERNEL_SORTEDSEGMENT(                                   \
-    name, type, index_type, initial_value_functor, reduction_kernel_functor, \
-    atomic_reduction_kernel_functor)                                         \
-  REGISTER_KERNEL_BUILDER(                                                   \
-      Name(name)                                                             \
-          .Device(DEVICE_GPU)                                                \
-          .TypeConstraint<type>("T")                                         \
-          .TypeConstraint<index_type>("Tindices"),                           \
-      SegmentReductionGPUOp<                                                 \
-          type, index_type,                                                  \
-          functor::SegmentReductionFunctor<                                  \
-              type, index_type, initial_value_functor,                       \
-              reduction_kernel_functor, atomic_reduction_kernel_functor> >)
+#define REGISTER_GPU_KERNEL_SORTEDSEGMENT(                            \
+    name, type, index_type, initial_value_functor,                    \
+    empty_segment_value_functor, reduction_kernel_functor, is_mean)   \
+  REGISTER_KERNEL_BUILDER(                                            \
+      Name(name)                                                      \
+          .Device(DEVICE_GPU)                                         \
+          .TypeConstraint<type>("T")                                  \
+          .TypeConstraint<index_type>("Tindices"),                    \
+      SegmentReductionGPUOp<                                          \
+          type, index_type,                                           \
+          functor::SegmentReductionFunctor<                           \
+              type, index_type, initial_value_functor,                \
+              empty_segment_value_functor, reduction_kernel_functor>, \
+          is_mean>)
 
-#define REGISTER_GPU_SORTED_KERNELS(type, index_type)                     \
-  REGISTER_GPU_KERNEL_SORTEDSEGMENT(                                      \
-      "SegmentSum", type, index_type, functor::Zero<type>,                \
-      functor::NonAtomicSumOpGpu<type>, functor::AtomicSumOpGpu<type>);   \
-  REGISTER_GPU_KERNEL_SORTEDSEGMENT(                                      \
-      "SegmentProd", type, index_type, functor::One<type>,                \
-      functor::NonAtomicProdOpGpu<type>, functor::AtomicProdOpGpu<type>); \
-  REGISTER_GPU_KERNEL_SORTEDSEGMENT(                                      \
-      "SegmentMin", type, index_type, functor::Highest<type>,             \
-      functor::NonAtomicMinOpGpu<type>, functor::AtomicMinOpGpu<type>);   \
-  REGISTER_GPU_KERNEL_SORTEDSEGMENT(                                      \
-      "SegmentMax", type, index_type, functor::Lowest<type>,              \
-      functor::NonAtomicMaxOpGpu<type>, functor::AtomicMaxOpGpu<type>);
+#define REGISTER_GPU_SORTED_KERNELS(type, index_type)                         \
+  REGISTER_GPU_KERNEL_SORTEDSEGMENT("SegmentSum", type, index_type,           \
+                                    functor::Zero<type>, functor::Zero<type>, \
+                                    functor::Sum, /*is_mean=*/false);         \
+  REGISTER_GPU_KERNEL_SORTEDSEGMENT("SegmentMean", type, index_type,          \
+                                    functor::Zero<type>, functor::Zero<type>, \
+                                    functor::Sum, /*is_mean=*/true);          \
+  REGISTER_GPU_KERNEL_SORTEDSEGMENT("SegmentProd", type, index_type,          \
+                                    functor::One<type>, functor::One<type>,   \
+                                    functor::Prod, /*is_mean=*/false);        \
+  REGISTER_GPU_KERNEL_SORTEDSEGMENT(                                          \
+      "SegmentMin", type, index_type, functor::Highest<type>,                 \
+      functor::Zero<type>, functor::Min, /*is_mean=*/false);                  \
+  REGISTER_GPU_KERNEL_SORTEDSEGMENT(                                          \
+      "SegmentMax", type, index_type, functor::Lowest<type>,                  \
+      functor::Zero<type>, functor::Max, /*is_mean=*/false);
 
 #define REGISTER_GPU_SORTED_KERNELS_ALL(type) \
   REGISTER_GPU_SORTED_KERNELS(type, int32)
diff --git a/tensorflow/core/kernels/segment_reduction_ops_impl_2.cc b/tensorflow/core/kernels/segment_reduction_ops_impl_2.cc
index 0761a64..cb6c0cf 100644
--- a/tensorflow/core/kernels/segment_reduction_ops_impl_2.cc
+++ b/tensorflow/core/kernels/segment_reduction_ops_impl_2.cc
@@ -63,33 +63,37 @@
 #undef REGISTER_COMPLEX_CPU_KERNELS_ALL
 
 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
-#define REGISTER_GPU_KERNEL_SORTEDSEGMENT(                                   \
-    name, type, index_type, initial_value_functor, reduction_kernel_functor, \
-    atomic_reduction_kernel_functor)                                         \
-  REGISTER_KERNEL_BUILDER(                                                   \
-      Name(name)                                                             \
-          .Device(DEVICE_GPU)                                                \
-          .TypeConstraint<type>("T")                                         \
-          .TypeConstraint<index_type>("Tindices"),                           \
-      SegmentReductionGPUOp<                                                 \
-          type, index_type,                                                  \
-          functor::SegmentReductionFunctor<                                  \
-              type, index_type, initial_value_functor,                       \
-              reduction_kernel_functor, atomic_reduction_kernel_functor> >)
+#define REGISTER_GPU_KERNEL_SORTEDSEGMENT(                            \
+    name, type, index_type, initial_value_functor,                    \
+    empty_segment_value_functor, reduction_kernel_functor, is_mean)   \
+  REGISTER_KERNEL_BUILDER(                                            \
+      Name(name)                                                      \
+          .Device(DEVICE_GPU)                                         \
+          .TypeConstraint<type>("T")                                  \
+          .TypeConstraint<index_type>("Tindices"),                    \
+      SegmentReductionGPUOp<                                          \
+          type, index_type,                                           \
+          functor::SegmentReductionFunctor<                           \
+              type, index_type, initial_value_functor,                \
+              empty_segment_value_functor, reduction_kernel_functor>, \
+          is_mean>)
 
-#define REGISTER_GPU_SORTED_KERNELS(type, index_type)                     \
-  REGISTER_GPU_KERNEL_SORTEDSEGMENT(                                      \
-      "SegmentSum", type, index_type, functor::Zero<type>,                \
-      functor::NonAtomicSumOpGpu<type>, functor::AtomicSumOpGpu<type>);   \
-  REGISTER_GPU_KERNEL_SORTEDSEGMENT(                                      \
-      "SegmentProd", type, index_type, functor::One<type>,                \
-      functor::NonAtomicProdOpGpu<type>, functor::AtomicProdOpGpu<type>); \
-  REGISTER_GPU_KERNEL_SORTEDSEGMENT(                                      \
-      "SegmentMin", type, index_type, functor::Highest<type>,             \
-      functor::NonAtomicMinOpGpu<type>, functor::AtomicMinOpGpu<type>);   \
-  REGISTER_GPU_KERNEL_SORTEDSEGMENT(                                      \
-      "SegmentMax", type, index_type, functor::Lowest<type>,              \
-      functor::NonAtomicMaxOpGpu<type>, functor::AtomicMaxOpGpu<type>);
+#define REGISTER_GPU_SORTED_KERNELS(type, index_type)                         \
+  REGISTER_GPU_KERNEL_SORTEDSEGMENT("SegmentSum", type, index_type,           \
+                                    functor::Zero<type>, functor::Zero<type>, \
+                                    functor::Sum, /*is_mean=*/false);         \
+  REGISTER_GPU_KERNEL_SORTEDSEGMENT("SegmentMean", type, index_type,          \
+                                    functor::Zero<type>, functor::Zero<type>, \
+                                    functor::Sum, /*is_mean=*/true);          \
+  REGISTER_GPU_KERNEL_SORTEDSEGMENT("SegmentProd", type, index_type,          \
+                                    functor::One<type>, functor::One<type>,   \
+                                    functor::Prod, /*is_mean=*/false);        \
+  REGISTER_GPU_KERNEL_SORTEDSEGMENT(                                          \
+      "SegmentMin", type, index_type, functor::Highest<type>,                 \
+      functor::Zero<type>, functor::Min, /*is_mean=*/false);                  \
+  REGISTER_GPU_KERNEL_SORTEDSEGMENT(                                          \
+      "SegmentMax", type, index_type, functor::Lowest<type>,                  \
+      functor::Zero<type>, functor::Max, /*is_mean=*/false);
 
 #define REGISTER_GPU_SORTED_KERNELS_ALL(type) \
   REGISTER_GPU_SORTED_KERNELS(type, int64_t);
diff --git a/tensorflow/core/kernels/segment_reduction_ops_impl_3.cc b/tensorflow/core/kernels/segment_reduction_ops_impl_3.cc
index c809a5e..be5e94d 100644
--- a/tensorflow/core/kernels/segment_reduction_ops_impl_3.cc
+++ b/tensorflow/core/kernels/segment_reduction_ops_impl_3.cc
@@ -87,19 +87,15 @@
 // sum is the only op that supports all input types currently
 #define REGISTER_REAL_GPU_UNSORTED_KERNELS(type, index_type)                   \
   REGISTER_GPU_KERNEL_UNSORTEDSEGMENT("UnsortedSegmentMax", type, index_type,  \
-                                      functor::Lowest<type>,                   \
-                                      functor::AtomicMaxOpGpu<type>);          \
+                                      functor::Lowest<type>, functor::Max);    \
   REGISTER_GPU_KERNEL_UNSORTEDSEGMENT("UnsortedSegmentMin", type, index_type,  \
-                                      functor::Highest<type>,                  \
-                                      functor::AtomicMinOpGpu<type>);          \
+                                      functor::Highest<type>, functor::Min);   \
   REGISTER_GPU_KERNEL_UNSORTEDSEGMENT("UnsortedSegmentProd", type, index_type, \
-                                      functor::One<type>,                      \
-                                      functor::AtomicProdOpGpu<type>);
+                                      functor::One<type>, functor::Prod);
 
 #define REGISTER_SUM_GPU_UNSORTED_KERNELS(type, index_type)                   \
   REGISTER_GPU_KERNEL_UNSORTEDSEGMENT("UnsortedSegmentSum", type, index_type, \
-                                      functor::Zero<type>,                    \
-                                      functor::AtomicSumOpGpu<type>);
+                                      functor::Zero<type>, functor::Sum);
 
 #define REGISTER_REAL_GPU_UNSORTED_KERNELS_ALL(type) \
   REGISTER_REAL_GPU_UNSORTED_KERNELS(type, int32)
diff --git a/tensorflow/core/kernels/segment_reduction_ops_impl_4.cc b/tensorflow/core/kernels/segment_reduction_ops_impl_4.cc
index fb09881..7913aef 100644
--- a/tensorflow/core/kernels/segment_reduction_ops_impl_4.cc
+++ b/tensorflow/core/kernels/segment_reduction_ops_impl_4.cc
@@ -87,19 +87,15 @@
 // sum is the only op that supports all input types currently
 #define REGISTER_REAL_GPU_UNSORTED_KERNELS(type, index_type)                   \
   REGISTER_GPU_KERNEL_UNSORTEDSEGMENT("UnsortedSegmentMax", type, index_type,  \
-                                      functor::Lowest<type>,                   \
-                                      functor::AtomicMaxOpGpu<type>);          \
+                                      functor::Lowest<type>, functor::Max);    \
   REGISTER_GPU_KERNEL_UNSORTEDSEGMENT("UnsortedSegmentMin", type, index_type,  \
-                                      functor::Highest<type>,                  \
-                                      functor::AtomicMinOpGpu<type>);          \
+                                      functor::Highest<type>, functor::Min);   \
   REGISTER_GPU_KERNEL_UNSORTEDSEGMENT("UnsortedSegmentProd", type, index_type, \
-                                      functor::One<type>,                      \
-                                      functor::AtomicProdOpGpu<type>);
+                                      functor::One<type>, functor::Prod);
 
 #define REGISTER_SUM_GPU_UNSORTED_KERNELS(type, index_type)                   \
   REGISTER_GPU_KERNEL_UNSORTEDSEGMENT("UnsortedSegmentSum", type, index_type, \
-                                      functor::Zero<type>,                    \
-                                      functor::AtomicSumOpGpu<type>);
+                                      functor::Zero<type>, functor::Sum);
 
 #define REGISTER_REAL_GPU_UNSORTED_KERNELS_ALL(type) \
   REGISTER_REAL_GPU_UNSORTED_KERNELS(type, int64_t)
diff --git a/tensorflow/core/util/gpu_kernel_helper.h b/tensorflow/core/util/gpu_kernel_helper.h
index 8604d7c..a5ae09b 100644
--- a/tensorflow/core/util/gpu_kernel_helper.h
+++ b/tensorflow/core/util/gpu_kernel_helper.h
@@ -275,6 +275,19 @@
   DEFINE_BINARY_OPERATOR(/)
 #undef DEFINE_BINARY_OPERATOR
 
+#define DEFINE_BINARY_FUNCTION(func)                                        \
+  friend __host__ __device__ AlignedVector func(const AlignedVector& lhs,   \
+                                                const AlignedVector& rhs) { \
+    AlignedVector ret;                                                      \
+    UNROLL_ON_DEVICE for (int i = 0; i < kSize; ++i) {                      \
+      ret[i] = func(lhs[i], rhs[i]);                                        \
+    }                                                                       \
+    return ret;                                                             \
+  }
+  DEFINE_BINARY_FUNCTION(min)
+  DEFINE_BINARY_FUNCTION(max)
+#undef DEFINE_BINARY_FUNCTION
+
  private:
   value_type values_[N];
 };
diff --git a/tensorflow/python/kernel_tests/segment_reduction_ops_deterministic_test.py b/tensorflow/python/kernel_tests/segment_reduction_ops_deterministic_test.py
index dceea7d..c2b1f05 100644
--- a/tensorflow/python/kernel_tests/segment_reduction_ops_deterministic_test.py
+++ b/tensorflow/python/kernel_tests/segment_reduction_ops_deterministic_test.py
@@ -18,6 +18,8 @@
 from __future__ import division
 from __future__ import print_function
 
+import os
+
 from tensorflow.python.eager import backprop
 from tensorflow.python.framework import config
 from tensorflow.python.framework import constant_op
@@ -32,6 +34,10 @@
 from tensorflow.python.platform import test
 
 
+def UsingAtomicSegmentReductions():
+  return bool(int(os.getenv("TF_USE_ATOMIC_SEGMENT_REDUCTIONS", "0")))
+
+
 class SegmentReductionDeterminismExceptionsTest(test.TestCase):
   """Test d9m-unimplemented exceptions from the segment reduction ops.
 
@@ -63,7 +69,7 @@
         for data_type in [dtypes.float16, dtypes.float32, dtypes.float64]:
           with self.cached_session(force_gpu=True):
             data, segment_ids, _ = self._input(data_type, segment_ids_type)
-            if should_throw_for_float:
+            if UsingAtomicSegmentReductions() and should_throw_for_float:
               with self.assertRaisesRegex(
                   errors_impl.UnimplementedError,
                   "Deterministic GPU implementation of sorted segment " +
@@ -99,7 +105,8 @@
               continue
             data, segment_ids, num_segments = self._input(
                 data_type, segment_ids_type)
-            if (data_type != dtypes.int32) and should_throw_for_float:
+            if (UsingAtomicSegmentReductions() and (data_type != dtypes.int32)
+                and should_throw_for_float):
               with self.assertRaisesRegex(errors_impl.UnimplementedError,
                                           self._UNSORTED_ERROR_MESSAGE):
                 result = op(data, segment_ids, num_segments)
@@ -121,8 +128,12 @@
           with self.cached_session(force_gpu=True):
             data, segment_ids, num_segments = self._input(
                 data_type, segment_ids_type)
-            with self.assertRaisesRegex(errors_impl.UnimplementedError,
-                                        self._UNSORTED_ERROR_MESSAGE):
+            if UsingAtomicSegmentReductions():
+              with self.assertRaisesRegex(errors_impl.UnimplementedError,
+                                          self._UNSORTED_ERROR_MESSAGE):
+                result = op(data, segment_ids, num_segments)
+                self.evaluate(result)
+            else:
               result = op(data, segment_ids, num_segments)
               self.evaluate(result)
 
@@ -138,9 +149,13 @@
           values, indices, _ = self._input(data_type, segment_ids_type)
           sparse_value = indexed_slices.IndexedSlices(
               values, indices, dense_shape=values.shape)
-          with self.assertRaisesRegex(errors_impl.UnimplementedError,
-                                      self._UNSORTED_ERROR_MESSAGE):
-            # convert_to_tensor with IndexedSlices uses unsorted_segment_sum
+          if UsingAtomicSegmentReductions():
+            with self.assertRaisesRegex(errors_impl.UnimplementedError,
+                                        self._UNSORTED_ERROR_MESSAGE):
+              # convert_to_tensor with IndexedSlices uses unsorted_segment_sum
+              result = ops.convert_to_tensor(sparse_value)
+              self.evaluate(result)
+          else:
             result = ops.convert_to_tensor(sparse_value)
             self.evaluate(result)
 
@@ -158,9 +173,12 @@
             tape.watch(params)
             op_output = array_ops.gather(params, indices)
           gradient = tape.gradient(op_output, params)
-          with self.assertRaisesRegex(errors_impl.UnimplementedError,
-                                      self._UNSORTED_ERROR_MESSAGE):
-            # convert_to_tensor on IndexedSlices
+          if UsingAtomicSegmentReductions():
+            with self.assertRaisesRegex(errors_impl.UnimplementedError,
+                                        self._UNSORTED_ERROR_MESSAGE):
+              # convert_to_tensor on IndexedSlices
+              self.evaluate(params.assign(gradient))
+          else:
             self.evaluate(params.assign(gradient))