Use new GPU kernel for [un]sorted segment reductions
- Replaces the old atomics-based kernels with calls to SegmentReduceGPU
(the same kernel already used for sparse segment reductions). The new
kernels are used by default; the old kernels can be re-enabled by
setting the environment variable TF_USE_ATOMIC_SEGMENT_REDUCTIONS=1.
On Windows, the old kernels are always used due to a build issue with
the new kernel.
- This improves performance and guarantees that these ops are
deterministic. The hope is that the old kernels can eventually be
removed completely.
- Also adds a GPU kernel registration for SegmentMean, which didn't
previously exist.
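
A minimal usage sketch in Python, assuming a build that includes this change
(the ops below are standard TensorFlow APIs; the environment variable's effect
is as described above and its value is read once and cached):

import os

# To re-enable the old atomics-based kernels, set this before the first
# segment reduction op runs. Leaving it unset (the default) selects the new
# SegmentReduceGPU-based path.
# os.environ["TF_USE_ATOMIC_SEGMENT_REDUCTIONS"] = "1"

import tensorflow as tf

data = tf.constant([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
segment_ids = tf.constant([0, 0, 1])  # sorted, as required by the Segment* ops

print(tf.math.segment_sum(data, segment_ids))   # sorted segment reduction
print(tf.math.segment_mean(data, segment_ids))  # now has a GPU kernel

# Unsorted variants take arbitrary ids plus an explicit segment count.
print(tf.math.unsorted_segment_sum(data, tf.constant([1, 0, 1]), num_segments=2))
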
diff --git a/tensorflow/core/kernels/segment_reduction_ops.h b/tensorflow/core/kernels/segment_reduction_ops.h
index 8faf8ac..49035bf 100644
--- a/tensorflow/core/kernels/segment_reduction_ops.h
+++ b/tensorflow/core/kernels/segment_reduction_ops.h
@@ -16,13 +16,6 @@
#ifndef TENSORFLOW_CORE_KERNELS_SEGMENT_REDUCTION_OPS_H_
#define TENSORFLOW_CORE_KERNELS_SEGMENT_REDUCTION_OPS_H_
-// This file requires the following include because it uses GpuAtomicMax:
-// #include "tensorflow/core/util/gpu_kernel_helper.h"
-
-// Unfortunately we can't add the #include, since it breaks compilation for
-// non-GPU targets. This only breaks in clang, because it's more strict for
-// template code and GpuAtomicMax is used in template context.
-
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
@@ -32,6 +25,7 @@
class OpKernelContext;
+bool UseAtomicSegmentReductions();
bool DisableSegmentReductionOpDeterminismExceptions();
// Type of SparseSegmentReduction operation to perform gradient of.
@@ -40,9 +34,51 @@
namespace functor {
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
+
+// Note that we define this ourselves to avoid a dependency on gpuprim.
+struct Sum {
+ template <typename T>
+ __host__ __device__ T operator()(const T& a, const T& b) const {
+ return a + b;
+ }
+};
+
+struct Prod {
+ template <typename T>
+ __host__ __device__ T operator()(const T& a, const T& b) const {
+ return a * b;
+ }
+};
+
+// Note that we don't use gpuprim::Min/Max because they use operator<, which is
+// not implemented for AlignedVector types.
+struct Min {
+ template <typename T>
+ __host__ __device__ T operator()(const T& a, const T& b) const {
+ return min(a, b);
+ }
+};
+
+struct Max {
+ template <typename T>
+ __host__ __device__ T operator()(const T& a, const T& b) const {
+ return max(a, b);
+ }
+};
+
+template <typename ReduceOp, typename T>
+struct ReduceOpIsAssociative {};
+template <typename T>
+struct ReduceOpIsAssociative<functor::Sum, T> : std::is_integral<T> {};
+template <typename T>
+struct ReduceOpIsAssociative<functor::Prod, T> : std::is_integral<T> {};
+template <typename T>
+struct ReduceOpIsAssociative<functor::Max, T> : std::true_type {};
+template <typename T>
+struct ReduceOpIsAssociative<functor::Min, T> : std::true_type {};
+
typedef Eigen::GpuDevice GPUDevice;
-// Functor for SegmentSumGPUOp & SegmentProdGPUOp & SegmentMaxGPUOp
-// & SegmentMinGPUOp.
+// Functor for SegmentReductionGPUOp.
// output_rows: the number of output segments (unique segment ids in
// 'segment_ids').
// segment_ids_shape: shape of 'segment_ids' tensor.
@@ -52,18 +88,18 @@
// data: input data tensor.
// output: output reshaped to {output_rows, output.size/output_rows}
template <typename T, typename Index, typename InitialValueF,
- typename ReductionF, typename AtomicReductionF>
+ typename EmptySegmentValueF, typename ReductionF>
struct SegmentReductionFunctor {
void operator()(OpKernelContext* ctx, const GPUDevice& d,
const Index output_rows, const TensorShape& segment_ids_shape,
- typename TTypes<Index>::ConstFlat segment_ids,
+ bool is_mean, typename TTypes<Index>::ConstFlat segment_ids,
const Index data_size, const T* data,
typename TTypes<T, 2>::Tensor output);
static constexpr bool atomic_reduction_is_associative =
- AtomicReductionF::is_associative;
+ ReduceOpIsAssociative<ReductionF, T>::value;
};
-#endif
+#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
template <typename Device, typename T, typename Index, typename InitialValueF,
typename ReductionF>
@@ -74,80 +110,6 @@
typename TTypes<T, 2>::Tensor output);
};
-#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
-
-// Atomic reduction functors for the gpu.
-template <typename T>
-struct AtomicSumOpGpu {
- EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void operator()(T* dest,
- const T& value) {
- GpuAtomicAdd(dest, value);
- }
- static constexpr bool is_associative = std::is_integral<T>::value;
-};
-
-template <typename T>
-struct AtomicProdOpGpu {
- EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void operator()(T* dest,
- const T& value) {
- GpuAtomicMul(dest, value);
- }
- static constexpr bool is_associative = std::is_integral<T>::value;
-};
-
-template <typename T>
-struct AtomicMaxOpGpu {
- EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void operator()(T* dest,
- const T& value) {
- GpuAtomicMax(dest, value);
- }
- static constexpr bool is_associative = true;
-};
-
-template <typename T>
-struct AtomicMinOpGpu {
- EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void operator()(T* dest,
- const T& value) {
- GpuAtomicMin(dest, value);
- }
- static constexpr bool is_associative = true;
-};
-
-// Non-atomic reduction functors for the gpu.
-template <typename T>
-struct NonAtomicSumOpGpu {
- EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void operator()(T* dest,
- const T& value) {
- *dest += value;
- }
-};
-
-template <typename T>
-struct NonAtomicProdOpGpu {
- EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void operator()(T* dest,
- const T& value) {
- *dest *= value;
- }
-};
-
-template <typename T>
-struct NonAtomicMaxOpGpu {
- EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void operator()(T* dest,
- const T& value) {
- *dest = max(*dest, value);
- }
-};
-
-template <typename T>
-struct NonAtomicMinOpGpu {
- EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void operator()(T* dest,
- const T& value) {
- *dest = min(*dest, value);
- }
-};
-
-#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
-
// Initial value functors.
template <typename T>
struct Zero {
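
Regarding ReduceOpIsAssociative above: Sum and Prod are marked associative only
for integral types because floating-point addition and multiplication are
order-dependent, which is what makes the atomic kernels nondeterministic for
floats. A quick illustration in plain Python (values chosen only for
demonstration):

a, b, c = 0.1, 0.2, 0.3
print((a + b) + c)  # 0.6000000000000001
print(a + (b + c))  # 0.6 -- a different reduction order gives a different sum

# Integer addition and min/max are associative, so accumulation order cannot
# change the result and the atomic kernels stay deterministic for them.
print((1 + 2) + 3 == 1 + (2 + 3))              # True
print(max(max(a, b), c) == max(a, max(b, c)))  # True
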
diff --git a/tensorflow/core/kernels/segment_reduction_ops_gpu.cu.h b/tensorflow/core/kernels/segment_reduction_ops_gpu.cu.h
index 19805df..2e75683 100644
--- a/tensorflow/core/kernels/segment_reduction_ops_gpu.cu.h
+++ b/tensorflow/core/kernels/segment_reduction_ops_gpu.cu.h
@@ -17,12 +17,7 @@
#define EIGEN_USE_GPU
-// We need to include gpu_kernel_helper.h before segment_reduction_ops.h
-// See comment in segment_reduction_ops.h for more details.
-// clang-format off
-#include "tensorflow/core/util/gpu_kernel_helper.h"
-// clang-format on
-
+#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/kernels/gpu_prim.h"
#include "tensorflow/core/kernels/gpu_prim_helpers.h"
#include "tensorflow/core/kernels/segment_reduction_ops.h"
@@ -30,12 +25,48 @@
#include "tensorflow/core/util/determinism.h"
#include "tensorflow/core/util/env_var.h"
#include "tensorflow/core/util/gpu_device_functions.h"
+#include "tensorflow/core/util/gpu_kernel_helper.h"
#include "tensorflow/core/util/permutation_input_iterator.h"
namespace tensorflow {
using GPUDevice = Eigen::GpuDevice;
+// Non/Atomic reduction functors for the gpu.
+#define DEFINE_REDUCE_UPDATE_OP_GPU(name, func) \
+ struct name##OpGpu { \
+ template <typename T> \
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void operator()(T* dest, \
+ const T& value) { \
+ func; \
+ } \
+ };
+DEFINE_REDUCE_UPDATE_OP_GPU(AtomicSum, GpuAtomicAdd(dest, value))
+DEFINE_REDUCE_UPDATE_OP_GPU(AtomicProd, GpuAtomicMul(dest, value))
+DEFINE_REDUCE_UPDATE_OP_GPU(AtomicMax, GpuAtomicMax(dest, value))
+DEFINE_REDUCE_UPDATE_OP_GPU(AtomicMin, GpuAtomicMin(dest, value))
+DEFINE_REDUCE_UPDATE_OP_GPU(NonAtomicSum, *dest += value)
+DEFINE_REDUCE_UPDATE_OP_GPU(NonAtomicProd, *dest *= value)
+DEFINE_REDUCE_UPDATE_OP_GPU(NonAtomicMax, *dest = max(*dest, value))
+DEFINE_REDUCE_UPDATE_OP_GPU(NonAtomicMin, *dest = min(*dest, value))
+#undef DEFINE_REDUCE_UPDATE_OP_GPU
+
+template <typename ReduceOp>
+struct ReduceUpdateOpFor {
+};
+
+#define DEFINE_REDUCE_UPDATE_OP_FOR(reduce_op, atomic, nonatomic) \
+ template <> \
+ struct ReduceUpdateOpFor<reduce_op> { \
+ using atomic_op = atomic; \
+ using nonatomic_op = nonatomic; \
+ };
+DEFINE_REDUCE_UPDATE_OP_FOR(functor::Sum, AtomicSumOpGpu, NonAtomicSumOpGpu)
+DEFINE_REDUCE_UPDATE_OP_FOR(functor::Prod, AtomicProdOpGpu, NonAtomicProdOpGpu)
+DEFINE_REDUCE_UPDATE_OP_FOR(functor::Max, AtomicMaxOpGpu, NonAtomicMaxOpGpu)
+DEFINE_REDUCE_UPDATE_OP_FOR(functor::Min, AtomicMinOpGpu, NonAtomicMinOpGpu)
+#undef DEFINE_REDUCE_UPDATE_OP_FOR
+
// SortedSegmentReductionFunctor kernel reduces input data just as
// UnsortedSegmentReductionCustomKernel does except that input data
// is partitioned along the outer reduction dimension. This is
@@ -108,6 +139,35 @@
}
}
+template <typename SegmentId, typename Index, typename T>
+__global__ void SegmentMeanNormalizeKernel(
+ SegmentId nsegments, Index ninner,
+ const Index* __restrict__ segment_offsets, // [nsegments + 1]
+ T* __restrict__ output) { // [nsegments, ninner]
+  for (SegmentId seg : GpuGridRangeY(nsegments)) {
+ SegmentId segment_size = segment_offsets[seg + 1] - segment_offsets[seg];
+ segment_size = max(segment_size, Index(1)); // Avoid division by zero
+ T inv_norm = T(1) / static_cast<T>(segment_size);
+    for (Index i : GpuGridRangeX(ninner)) {
+ output[seg * ninner + i] *= inv_norm;
+ }
+ }
+}
+
+template <typename SegmentId, typename Index, typename T>
+Status LaunchSegmentMeanNormalizeKernel(
+ const GPUDevice& d, SegmentId nsegments, Index ninner,
+ const Index* __restrict__ segment_offsets, // [nsegments + 1]
+ T* __restrict__ output) { // [nsegments, ninner]
+ Gpu2DLaunchConfig config = GetGpu2DLaunchConfig(
+ ninner, nsegments, d, SegmentMeanNormalizeKernel<SegmentId, Index, T>,
+ /*dynamic_shared_memory_size=*/0, /*block_size_limit=*/0);
+ return GpuLaunchKernel(SegmentMeanNormalizeKernel<SegmentId, Index, T>,
+ config.block_count, config.thread_per_block, 0,
+ d.stream(), nsegments, ninner, segment_offsets,
+ output);
+}
+
// UnsortedSegmentSumKernel processes 'input_total_size' elements.
// Each element is mapped from input to output by a combination of its
// 'segment_ids' mapping and 'inner_dim_size'.
@@ -623,37 +683,27 @@
// Sum fp16 values using an fp32 accumulator to avoid numerical issues.
template <>
-struct ReduceType<gpuprim::Sum, Eigen::half> {
+struct ReduceType<functor::Sum, Eigen::half> {
using type = float;
};
namespace functor {
template <typename T, typename Index, typename InitialValueF,
- typename ReductionF, typename AtomicReductionF>
+ typename EmptySegmentValueF, typename ReductionF>
void SegmentReductionFunctor<
- T, Index, InitialValueF, ReductionF,
- AtomicReductionF>::operator()(OpKernelContext* ctx, const GPUDevice& d,
- const Index output_rows,
- const TensorShape& segment_ids_shape,
- typename TTypes<Index>::ConstFlat segment_ids,
- const Index data_size, const T* data,
- typename TTypes<T, 2>::Tensor output) {
+ T, Index, InitialValueF, EmptySegmentValueF,
+ ReductionF>::operator()(OpKernelContext* ctx, const GPUDevice& d,
+ const Index output_rows,
+ const TensorShape& segment_ids_shape, bool is_mean,
+ typename TTypes<Index>::ConstFlat segment_ids,
+ const Index data_size, const T* data,
+ typename TTypes<T, 2>::Tensor output) {
if (output.size() == 0) {
return;
}
- // Set 'output' to initial value.
- GpuLaunchConfig config = GetGpuLaunchConfig(output.size(), d);
- const T InitialValue = InitialValueF()();
- TF_CHECK_OK(GpuLaunchKernel(SetToValue<T>, config.block_count,
- config.thread_per_block, 0, d.stream(),
- output.size(), output.data(), InitialValue));
- if (data_size == 0 || segment_ids_shape.num_elements() == 0) {
- return;
- }
-
- // Launch kernel to compute sorted segment reduction.
+ // Launch kernel(s) to compute sorted segment reduction.
// Notes:
// *) 'input_total_size' is the total number of elements to process.
// *) 'segment_ids.shape' is a prefix of data's shape.
@@ -661,30 +711,84 @@
const Index input_total_size = data_size;
const Index input_outer_dim_size = segment_ids.dimension(0);
const Index input_inner_dim_size = input_total_size / input_outer_dim_size;
+ const Index num_segments = output.size() / input_inner_dim_size;
- const int OuterDimTileSize = 8;
+ // TODO(benbarsdell): If there are no performance concerns with the new
+ // non-atomic kernels, remove this runtime check and only compile the old
+ // atomic kernels on Windows (as a workaround for the build failure issue).
+ if (UseAtomicSegmentReductions()) {
+ // Set 'output' to initial value.
+ GpuLaunchConfig config = GetGpuLaunchConfig(output.size(), d);
+ const T InitialValue = InitialValueF()();
+ TF_CHECK_OK(GpuLaunchKernel(SetToValue<T>, config.block_count,
+ config.thread_per_block, 0, d.stream(),
+ output.size(), output.data(), InitialValue));
+ if (data_size == 0 || segment_ids_shape.num_elements() == 0) {
+ return;
+ }
- const Index input_outer_dim_num_stripe =
- Eigen::divup(input_outer_dim_size, Index(OuterDimTileSize));
+ const int OuterDimTileSize = 8;
- const Index total_stripe_count =
- input_inner_dim_size * input_outer_dim_num_stripe;
+ const Index input_outer_dim_num_stripe =
+ Eigen::divup(input_outer_dim_size, Index(OuterDimTileSize));
- config = GetGpuLaunchConfig(total_stripe_count, d);
- TF_CHECK_OK(GpuLaunchKernel(
- SortedSegmentReductionCustomKernel<T, Index, OuterDimTileSize, ReductionF,
- AtomicReductionF>,
- config.block_count, config.thread_per_block, 0, d.stream(),
- input_outer_dim_size, input_inner_dim_size, output_rows,
- segment_ids.data(), data, output.data(), total_stripe_count,
- InitialValue));
+ const Index total_stripe_count =
+ input_inner_dim_size * input_outer_dim_num_stripe;
+
+ config = GetGpuLaunchConfig(total_stripe_count, d);
+ TF_CHECK_OK(GpuLaunchKernel(
+ SortedSegmentReductionCustomKernel<
+ T, Index, OuterDimTileSize,
+ typename ReduceUpdateOpFor<ReductionF>::nonatomic_op,
+ typename ReduceUpdateOpFor<ReductionF>::atomic_op>,
+ config.block_count, config.thread_per_block, 0, d.stream(),
+ input_outer_dim_size, input_inner_dim_size, output_rows,
+ segment_ids.data(), data, output.data(), total_stripe_count,
+ InitialValue));
+
+ if (is_mean) {
+ Tensor segment_offsets;
+ OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<Index>::value,
+ TensorShape({num_segments + 1}),
+ &segment_offsets));
+ Index* segment_offsets_ptr = segment_offsets.flat<Index>().data();
+ OP_REQUIRES_OK(ctx, LaunchSegmentOffsetsKernel(
+ d, input_outer_dim_size, num_segments,
+ segment_ids.data(), segment_offsets_ptr));
+
+ OP_REQUIRES_OK(ctx, LaunchSegmentMeanNormalizeKernel(
+ d, num_segments, input_inner_dim_size,
+ segment_offsets_ptr, output.data()));
+ }
+ } else {
+ // See comment in segment_reduction_ops_gpu_0.cu.cc regarding Windows CI
+ // build error.
+#if !defined(PLATFORM_WINDOWS)
+ using Treduce = typename ReduceType<ReductionF, T>::type;
+ OP_REQUIRES_OK(
+ ctx,
+ SegmentReduceGPU<Treduce>(
+ ctx, input_outer_dim_size, input_inner_dim_size, num_segments,
+ ReductionF(), InitialValueF()(), EmptySegmentValueF()(),
+ /*is_mean=*/is_mean, /*is_sqrtn=*/false, data, segment_ids.data(),
+ /*indices=*/static_cast<const Index*>(nullptr),
+ /*weights=*/static_cast<T*>(nullptr), output.data()));
+#else
+ // Note: Shouldn't reach here because UseAtomicSegmentReductions() always
+ // returns true on Windows.
+ OP_REQUIRES(
+ ctx, false,
+ errors::Unimplemented(
+ "Non-atomic segment reductions are not implemented on Windows."));
+#endif
+ }
}
template <typename T, typename Index, typename InitialValueF,
typename ReductionF>
struct UnsortedSegmentFunctor<GPUDevice, T, Index, InitialValueF, ReductionF> {
void operator()(OpKernelContext* ctx, const TensorShape& segment_ids_shape,
- typename TTypes<Index>::ConstFlat segment_ids,
+ typename TTypes<Index>::ConstFlat unsorted_segment_ids,
typename TTypes<T, 2>::ConstTensor data,
typename TTypes<T, 2>::Tensor output) {
if (output.size() == 0) {
@@ -692,7 +796,9 @@
}
bool determinism_requirement_met =
- ReductionF::is_associative || !OpDeterminismRequired() ||
+ !UseAtomicSegmentReductions() ||
+ ReduceOpIsAssociative<ReductionF, T>::value ||
+ !OpDeterminismRequired() ||
DisableSegmentReductionOpDeterminismExceptions();
OP_REQUIRES(
ctx, determinism_requirement_met,
@@ -700,31 +806,84 @@
"Deterministic GPU implementation of unsorted segment reduction op"
" not available."));
- // Set 'output' to initial value.
- GPUDevice d = ctx->template eigen_device<GPUDevice>();
- GpuLaunchConfig config = GetGpuLaunchConfig(output.size(), d);
- TF_CHECK_OK(GpuLaunchKernel(
- SetToValue<T>, config.block_count, config.thread_per_block, 0,
- d.stream(), output.size(), output.data(), InitialValueF()()));
- const int64_t data_size = data.size();
- if (data_size == 0 || segment_ids_shape.num_elements() == 0) {
- return;
- }
- // Launch kernel to compute unsorted segment reduction.
+ // Launch kernel(s) to compute unsorted segment reduction.
// Notes:
// *) 'data_size' is the total number of elements to process.
// *) 'segment_ids.shape' is a prefix of data's shape.
// *) 'input_outer_dim_size' is the total number of segments to process.
- const int64_t input_outer_dim_size = segment_ids.dimension(0);
- const int64_t input_inner_dim_size = data.dimension(1);
- const int64_t output_outer_dim_size = output.dimension(0);
- config = GetGpuLaunchConfig(data_size, d);
+ const Index input_outer_dim_size = unsorted_segment_ids.dimension(0);
+ const Index input_inner_dim_size = data.dimension(1);
+ const Index output_outer_dim_size = output.dimension(0);
+ const Index num_segments = output.size() / input_inner_dim_size;
- TF_CHECK_OK(GpuLaunchKernel(
- UnsortedSegmentCustomKernel<T, Index, ReductionF>, config.block_count,
- config.thread_per_block, 0, d.stream(), input_outer_dim_size,
- input_inner_dim_size, output_outer_dim_size, segment_ids.data(),
- data.data(), output.data()));
+ // TODO(benbarsdell): If there are no performance concerns with the new
+ // non-atomic kernels, remove this runtime check and only compile the old
+ // atomic kernels on Windows (as a workaround for the build failure issue).
+ if (UseAtomicSegmentReductions()) {
+ // Set 'output' to initial value.
+ GPUDevice d = ctx->template eigen_device<GPUDevice>();
+ GpuLaunchConfig config = GetGpuLaunchConfig(output.size(), d);
+ TF_CHECK_OK(GpuLaunchKernel(
+ SetToValue<T>, config.block_count, config.thread_per_block, 0,
+ d.stream(), output.size(), output.data(), InitialValueF()()));
+ const int64_t data_size = data.size();
+ if (data_size == 0 || segment_ids_shape.num_elements() == 0) {
+ return;
+ }
+ config = GetGpuLaunchConfig(data_size, d);
+ TF_CHECK_OK(GpuLaunchKernel(
+ UnsortedSegmentCustomKernel<
+ T, Index, typename ReduceUpdateOpFor<ReductionF>::atomic_op>,
+ config.block_count, config.thread_per_block, 0, d.stream(),
+ input_outer_dim_size, input_inner_dim_size, output_outer_dim_size,
+ unsorted_segment_ids.data(), data.data(), output.data()));
+ } else {
+ // See comment in segment_reduction_ops_gpu_0.cu.cc regarding Windows CI
+ // build error.
+#if !defined(PLATFORM_WINDOWS)
+      // Allocate temporary space and sort segment_ids, then call the sorted
+      // implementation.
+ Tensor segment_ids;
+ OP_REQUIRES_OK(
+ ctx, ctx->allocate_temp(
+ DataTypeToEnum<Index>::value,
+ TensorShape({static_cast<int64_t>(input_outer_dim_size)}),
+ &segment_ids));
+ Index* segment_ids_ptr = segment_ids.flat<Index>().data();
+ Tensor sorted_indices;
+ OP_REQUIRES_OK(
+ ctx, ctx->allocate_temp(
+ DataTypeToEnum<Index>::value,
+ TensorShape({static_cast<int64_t>(input_outer_dim_size)}),
+ &sorted_indices));
+ Index* sorted_indices_ptr = sorted_indices.flat<Index>().data();
+ // Note: We must sort using all bits here because unsorted_segment_ids
+ // may contain negative values.
+ OP_REQUIRES_OK(
+ ctx, GpuRadixSort(ctx, input_outer_dim_size,
+ /*keys_in=*/unsorted_segment_ids.data(),
+ /*keys_out=*/segment_ids_ptr,
+ /*indices_in=*/static_cast<const Index*>(nullptr),
+ /*indices_out=*/sorted_indices_ptr));
+ using Treduce = typename ReduceType<ReductionF, T>::type;
+ OP_REQUIRES_OK(
+ ctx,
+ SegmentReduceGPU<Treduce>(
+ ctx, input_outer_dim_size, input_inner_dim_size, num_segments,
+ ReductionF(), /*initial_value=*/InitialValueF()(),
+ /*empty_segment_value=*/InitialValueF()(), /*is_mean=*/false,
+ /*is_sqrtn=*/false, /*input=*/data.data(),
+ /*segment_ids=*/segment_ids_ptr, /*indices=*/sorted_indices_ptr,
+ /*weights=*/static_cast<T*>(nullptr), output.data()));
+#else
+ // Note: Shouldn't reach here because UseAtomicSegmentReductions() always
+ // returns true on Windows.
+ OP_REQUIRES(
+ ctx, false,
+ errors::Unimplemented("Non-atomic unsorted segment reductions "
+ "are not implemented on Windows."));
+#endif
+ }
}
};
@@ -735,7 +894,7 @@
typename TTypes<Index>::ConstVec indices,
typename TTypes<SegmentId>::ConstVec segment_ids,
typename TTypes<T, 2>::Tensor output) {
- using ReduceOp = gpuprim::Sum;
+ using ReduceOp = functor::Sum;
using Treduce = typename ReduceType<ReduceOp, T>::type;
Index nouter = segment_ids.size();
Index ninner = input.dimension(1);
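
To summarize the new non-atomic path above: unsorted ids are sorted with
GpuRadixSort (keeping the permutation), SegmentReduceGPU reduces the permuted
rows per segment, and for means the result is divided by each segment's size
clamped to at least one (via SegmentMeanNormalizeKernel in the atomic path, or
is_mean inside SegmentReduceGPU in the new path). A hedged NumPy sketch of the
same computation, with illustrative names only (data assumed 2-D [outer, inner]):

import numpy as np

def segment_reduce_via_sort(data, unsorted_segment_ids, num_segments,
                            is_mean=False):
  """Illustrative CPU analogue of the sort-then-reduce GPU path (sum/mean)."""
  data = np.asarray(data, dtype=np.float64)
  ids = np.asarray(unsorted_segment_ids)
  # Counterpart of GpuRadixSort: sorted keys plus the permutation indices.
  order = np.argsort(ids, kind="stable")
  sorted_ids = ids[order]
  out = np.zeros((num_segments, data.shape[1]), dtype=data.dtype)
  counts = np.zeros(num_segments, dtype=np.int64)
  # Counterpart of SegmentReduceGPU consuming (segment_ids, indices).
  for row, seg in zip(order, sorted_ids):
    if 0 <= seg < num_segments:  # negative ids are dropped; bound check is a sketch guard
      out[seg] += data[row]
      counts[seg] += 1
  if is_mean:
    # Counterpart of the mean normalization: divide by max(count, 1).
    out /= np.maximum(counts, 1)[:, None]
  return out

x = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
print(segment_reduce_via_sort(x, [1, 0, 1], num_segments=2, is_mean=True))
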
diff --git a/tensorflow/core/kernels/segment_reduction_ops_gpu_0.cu.cc b/tensorflow/core/kernels/segment_reduction_ops_gpu_0.cu.cc
index f77bda1..f2af012 100644
--- a/tensorflow/core/kernels/segment_reduction_ops_gpu_0.cu.cc
+++ b/tensorflow/core/kernels/segment_reduction_ops_gpu_0.cu.cc
@@ -20,6 +20,22 @@
namespace tensorflow {
+bool UseAtomicSegmentReductions() {
+ // See comment below regarding CI build error on Windows.
+#if !defined(PLATFORM_WINDOWS)
+ static bool cached_result = [] {
+ bool result = false;
+ TF_CHECK_OK(
+ tensorflow::ReadBoolFromEnvVar("TF_USE_ATOMIC_SEGMENT_REDUCTIONS",
+ /*default_val=*/false, &result));
+ return result;
+ }();
+ return cached_result;
+#else
+ return true;
+#endif
+}
+
bool DisableSegmentReductionOpDeterminismExceptions() {
static bool cached_disable = [] {
bool disable = false;
@@ -33,36 +49,32 @@
namespace functor {
-#define DEFINE_SORTED_GPU_SPECS_INDEX(T, Index) \
- template struct SegmentReductionFunctor<T, Index, functor::Zero<T>, \
- functor::NonAtomicSumOpGpu<T>, \
- functor::AtomicSumOpGpu<T>>; \
- template struct SegmentReductionFunctor<T, Index, functor::One<T>, \
- functor::NonAtomicProdOpGpu<T>, \
- functor::AtomicProdOpGpu<T>>; \
- template struct SegmentReductionFunctor<T, Index, functor::Highest<T>, \
- functor::NonAtomicMinOpGpu<T>, \
- functor::AtomicMinOpGpu<T>>; \
- template struct SegmentReductionFunctor<T, Index, functor::Lowest<T>, \
- functor::NonAtomicMaxOpGpu<T>, \
- functor::AtomicMaxOpGpu<T>>;
+#define DEFINE_SORTED_GPU_SPECS_INDEX(T, Index) \
+ template struct SegmentReductionFunctor<T, Index, functor::Zero<T>, \
+ functor::Zero<T>, functor::Sum>; \
+ template struct SegmentReductionFunctor<T, Index, functor::One<T>, \
+ functor::One<T>, functor::Prod>; \
+ template struct SegmentReductionFunctor<T, Index, functor::Highest<T>, \
+ functor::Zero<T>, functor::Min>; \
+ template struct SegmentReductionFunctor<T, Index, functor::Lowest<T>, \
+ functor::Zero<T>, functor::Max>;
#define DEFINE_SORTED_GPU_SPECS(T) DEFINE_SORTED_GPU_SPECS_INDEX(T, int32);
TF_CALL_GPU_NUMBER_TYPES(DEFINE_SORTED_GPU_SPECS);
#define DEFINE_REAL_UNSORTED_GPU_SPECS_INDEX(T, Index) \
- template struct UnsortedSegmentFunctor< \
- GPUDevice, T, Index, functor::Lowest<T>, functor::AtomicMaxOpGpu<T>>; \
- template struct UnsortedSegmentFunctor< \
- GPUDevice, T, Index, functor::Highest<T>, functor::AtomicMinOpGpu<T>>; \
+ template struct UnsortedSegmentFunctor<GPUDevice, T, Index, \
+ functor::Lowest<T>, functor::Max>; \
+ template struct UnsortedSegmentFunctor<GPUDevice, T, Index, \
+ functor::Highest<T>, functor::Min>; \
template struct UnsortedSegmentFunctor<GPUDevice, T, Index, functor::One<T>, \
- functor::AtomicProdOpGpu<T>>;
+ functor::Prod>;
// Sum is the only op that supports all input types currently.
-#define DEFINE_SUM_UNSORTED_GPU_SPECS_INDEX(T, Index) \
- template struct UnsortedSegmentFunctor< \
- GPUDevice, T, Index, functor::Zero<T>, functor::AtomicSumOpGpu<T>>;
+#define DEFINE_SUM_UNSORTED_GPU_SPECS_INDEX(T, Index) \
+ template struct UnsortedSegmentFunctor<GPUDevice, T, Index, \
+ functor::Zero<T>, functor::Sum>;
#define DEFINE_REAL_GPU_SPECS(T) DEFINE_REAL_UNSORTED_GPU_SPECS_INDEX(T, int32);
diff --git a/tensorflow/core/kernels/segment_reduction_ops_gpu_1.cu.cc b/tensorflow/core/kernels/segment_reduction_ops_gpu_1.cu.cc
index a109d74..7705b1f 100644
--- a/tensorflow/core/kernels/segment_reduction_ops_gpu_1.cu.cc
+++ b/tensorflow/core/kernels/segment_reduction_ops_gpu_1.cu.cc
@@ -21,36 +21,32 @@
namespace tensorflow {
namespace functor {
-#define DEFINE_SORTED_GPU_SPECS_INDEX(T, Index) \
- template struct SegmentReductionFunctor<T, Index, functor::Zero<T>, \
- functor::NonAtomicSumOpGpu<T>, \
- functor::AtomicSumOpGpu<T>>; \
- template struct SegmentReductionFunctor<T, Index, functor::One<T>, \
- functor::NonAtomicProdOpGpu<T>, \
- functor::AtomicProdOpGpu<T>>; \
- template struct SegmentReductionFunctor<T, Index, functor::Highest<T>, \
- functor::NonAtomicMinOpGpu<T>, \
- functor::AtomicMinOpGpu<T>>; \
- template struct SegmentReductionFunctor<T, Index, functor::Lowest<T>, \
- functor::NonAtomicMaxOpGpu<T>, \
- functor::AtomicMaxOpGpu<T>>;
+#define DEFINE_SORTED_GPU_SPECS_INDEX(T, Index) \
+ template struct SegmentReductionFunctor<T, Index, functor::Zero<T>, \
+ functor::Zero<T>, functor::Sum>; \
+ template struct SegmentReductionFunctor<T, Index, functor::One<T>, \
+ functor::One<T>, functor::Prod>; \
+ template struct SegmentReductionFunctor<T, Index, functor::Highest<T>, \
+ functor::Zero<T>, functor::Min>; \
+ template struct SegmentReductionFunctor<T, Index, functor::Lowest<T>, \
+ functor::Zero<T>, functor::Max>;
#define DEFINE_SORTED_GPU_SPECS(T) DEFINE_SORTED_GPU_SPECS_INDEX(T, int64_t);
TF_CALL_GPU_NUMBER_TYPES(DEFINE_SORTED_GPU_SPECS);
#define DEFINE_REAL_UNSORTED_GPU_SPECS_INDEX(T, Index) \
- template struct UnsortedSegmentFunctor< \
- GPUDevice, T, Index, functor::Lowest<T>, functor::AtomicMaxOpGpu<T>>; \
- template struct UnsortedSegmentFunctor< \
- GPUDevice, T, Index, functor::Highest<T>, functor::AtomicMinOpGpu<T>>; \
+ template struct UnsortedSegmentFunctor<GPUDevice, T, Index, \
+ functor::Lowest<T>, functor::Max>; \
+ template struct UnsortedSegmentFunctor<GPUDevice, T, Index, \
+ functor::Highest<T>, functor::Min>; \
template struct UnsortedSegmentFunctor<GPUDevice, T, Index, functor::One<T>, \
- functor::AtomicProdOpGpu<T>>;
+ functor::Prod>;
// Sum is the only op that supports all input types currently.
-#define DEFINE_SUM_UNSORTED_GPU_SPECS_INDEX(T, Index) \
- template struct UnsortedSegmentFunctor< \
- GPUDevice, T, Index, functor::Zero<T>, functor::AtomicSumOpGpu<T>>;
+#define DEFINE_SUM_UNSORTED_GPU_SPECS_INDEX(T, Index) \
+ template struct UnsortedSegmentFunctor<GPUDevice, T, Index, \
+ functor::Zero<T>, functor::Sum>;
#define DEFINE_REAL_GPU_SPECS(T) \
DEFINE_REAL_UNSORTED_GPU_SPECS_INDEX(T, int64_t);
diff --git a/tensorflow/core/kernels/segment_reduction_ops_gpu_2.cu.cc b/tensorflow/core/kernels/segment_reduction_ops_gpu_2.cu.cc
index 4ff00e9..45bd46d 100644
--- a/tensorflow/core/kernels/segment_reduction_ops_gpu_2.cu.cc
+++ b/tensorflow/core/kernels/segment_reduction_ops_gpu_2.cu.cc
@@ -22,17 +22,17 @@
namespace functor {
#define DEFINE_REAL_UNSORTED_GPU_SPECS_INDEX(T, Index) \
- template struct UnsortedSegmentFunctor< \
- GPUDevice, T, Index, functor::Lowest<T>, functor::AtomicMaxOpGpu<T>>; \
- template struct UnsortedSegmentFunctor< \
- GPUDevice, T, Index, functor::Highest<T>, functor::AtomicMinOpGpu<T>>; \
+ template struct UnsortedSegmentFunctor<GPUDevice, T, Index, \
+ functor::Lowest<T>, functor::Max>; \
+ template struct UnsortedSegmentFunctor<GPUDevice, T, Index, \
+ functor::Highest<T>, functor::Min>; \
template struct UnsortedSegmentFunctor<GPUDevice, T, Index, functor::One<T>, \
- functor::AtomicProdOpGpu<T>>;
+ functor::Prod>;
// Sum is the only op that supports all input types currently.
-#define DEFINE_SUM_UNSORTED_GPU_SPECS_INDEX(T, Index) \
- template struct UnsortedSegmentFunctor< \
- GPUDevice, T, Index, functor::Zero<T>, functor::AtomicSumOpGpu<T>>;
+#define DEFINE_SUM_UNSORTED_GPU_SPECS_INDEX(T, Index) \
+ template struct UnsortedSegmentFunctor<GPUDevice, T, Index, \
+ functor::Zero<T>, functor::Sum>;
#define DEFINE_REAL_GPU_SPECS(T) \
DEFINE_REAL_UNSORTED_GPU_SPECS_INDEX(T, int32); \
diff --git a/tensorflow/core/kernels/segment_reduction_ops_impl.h b/tensorflow/core/kernels/segment_reduction_ops_impl.h
index 6ad7bda..fb5e41e 100644
--- a/tensorflow/core/kernels/segment_reduction_ops_impl.h
+++ b/tensorflow/core/kernels/segment_reduction_ops_impl.h
@@ -222,7 +222,7 @@
// small. When to use the tiled version or the untiled version depends on
// many factors including data alignments, ratio of calculation to memory
// traffic and obviously, the problem sizes.
-template <class T, class Index, class SegmentReductionFunctor>
+template <class T, class Index, class SegmentReductionFunctor, bool IsMean>
class SegmentReductionGPUOp : public AsyncOpKernel {
public:
explicit SegmentReductionGPUOp(OpKernelConstruction* context)
@@ -300,6 +300,7 @@
// for the unsorted segment reduction ops) because the done callback
// (required for OP_REQUIRES_ASYNC) is not available inside the functor.
bool determinism_requirement_met =
+ !UseAtomicSegmentReductions() ||
SegmentReductionFunctor::atomic_reduction_is_associative ||
!OpDeterminismRequired() ||
DisableSegmentReductionOpDeterminismExceptions();
@@ -314,8 +315,8 @@
auto data_ptr = input.template flat<T>().data();
auto segment_flat = segment_ids.flat<Index>();
functor_(context, context->eigen_device<GPUDevice>(), output_rows,
- segment_ids.shape(), segment_flat, input.NumElements(), data_ptr,
- output_flat);
+ segment_ids.shape(), IsMean, segment_flat, input.NumElements(),
+ data_ptr, output_flat);
done();
};
diff --git a/tensorflow/core/kernels/segment_reduction_ops_impl_1.cc b/tensorflow/core/kernels/segment_reduction_ops_impl_1.cc
index b9bee23..0407be0 100644
--- a/tensorflow/core/kernels/segment_reduction_ops_impl_1.cc
+++ b/tensorflow/core/kernels/segment_reduction_ops_impl_1.cc
@@ -147,33 +147,37 @@
#undef REGISTER_COMPLEX_CPU_KERNELS_ALL
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
-#define REGISTER_GPU_KERNEL_SORTEDSEGMENT( \
- name, type, index_type, initial_value_functor, reduction_kernel_functor, \
- atomic_reduction_kernel_functor) \
- REGISTER_KERNEL_BUILDER( \
- Name(name) \
- .Device(DEVICE_GPU) \
- .TypeConstraint<type>("T") \
- .TypeConstraint<index_type>("Tindices"), \
- SegmentReductionGPUOp< \
- type, index_type, \
- functor::SegmentReductionFunctor< \
- type, index_type, initial_value_functor, \
- reduction_kernel_functor, atomic_reduction_kernel_functor> >)
+#define REGISTER_GPU_KERNEL_SORTEDSEGMENT( \
+ name, type, index_type, initial_value_functor, \
+ empty_segment_value_functor, reduction_kernel_functor, is_mean) \
+ REGISTER_KERNEL_BUILDER( \
+ Name(name) \
+ .Device(DEVICE_GPU) \
+ .TypeConstraint<type>("T") \
+ .TypeConstraint<index_type>("Tindices"), \
+ SegmentReductionGPUOp< \
+ type, index_type, \
+ functor::SegmentReductionFunctor< \
+ type, index_type, initial_value_functor, \
+ empty_segment_value_functor, reduction_kernel_functor>, \
+ is_mean>)
-#define REGISTER_GPU_SORTED_KERNELS(type, index_type) \
- REGISTER_GPU_KERNEL_SORTEDSEGMENT( \
- "SegmentSum", type, index_type, functor::Zero<type>, \
- functor::NonAtomicSumOpGpu<type>, functor::AtomicSumOpGpu<type>); \
- REGISTER_GPU_KERNEL_SORTEDSEGMENT( \
- "SegmentProd", type, index_type, functor::One<type>, \
- functor::NonAtomicProdOpGpu<type>, functor::AtomicProdOpGpu<type>); \
- REGISTER_GPU_KERNEL_SORTEDSEGMENT( \
- "SegmentMin", type, index_type, functor::Highest<type>, \
- functor::NonAtomicMinOpGpu<type>, functor::AtomicMinOpGpu<type>); \
- REGISTER_GPU_KERNEL_SORTEDSEGMENT( \
- "SegmentMax", type, index_type, functor::Lowest<type>, \
- functor::NonAtomicMaxOpGpu<type>, functor::AtomicMaxOpGpu<type>);
+#define REGISTER_GPU_SORTED_KERNELS(type, index_type) \
+ REGISTER_GPU_KERNEL_SORTEDSEGMENT("SegmentSum", type, index_type, \
+ functor::Zero<type>, functor::Zero<type>, \
+ functor::Sum, /*is_mean=*/false); \
+ REGISTER_GPU_KERNEL_SORTEDSEGMENT("SegmentMean", type, index_type, \
+ functor::Zero<type>, functor::Zero<type>, \
+ functor::Sum, /*is_mean=*/true); \
+ REGISTER_GPU_KERNEL_SORTEDSEGMENT("SegmentProd", type, index_type, \
+ functor::One<type>, functor::One<type>, \
+ functor::Prod, /*is_mean=*/false); \
+ REGISTER_GPU_KERNEL_SORTEDSEGMENT( \
+ "SegmentMin", type, index_type, functor::Highest<type>, \
+ functor::Zero<type>, functor::Min, /*is_mean=*/false); \
+ REGISTER_GPU_KERNEL_SORTEDSEGMENT( \
+ "SegmentMax", type, index_type, functor::Lowest<type>, \
+ functor::Zero<type>, functor::Max, /*is_mean=*/false);
#define REGISTER_GPU_SORTED_KERNELS_ALL(type) \
REGISTER_GPU_SORTED_KERNELS(type, int32)
diff --git a/tensorflow/core/kernels/segment_reduction_ops_impl_2.cc b/tensorflow/core/kernels/segment_reduction_ops_impl_2.cc
index 0761a64..cb6c0cf 100644
--- a/tensorflow/core/kernels/segment_reduction_ops_impl_2.cc
+++ b/tensorflow/core/kernels/segment_reduction_ops_impl_2.cc
@@ -63,33 +63,37 @@
#undef REGISTER_COMPLEX_CPU_KERNELS_ALL
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
-#define REGISTER_GPU_KERNEL_SORTEDSEGMENT( \
- name, type, index_type, initial_value_functor, reduction_kernel_functor, \
- atomic_reduction_kernel_functor) \
- REGISTER_KERNEL_BUILDER( \
- Name(name) \
- .Device(DEVICE_GPU) \
- .TypeConstraint<type>("T") \
- .TypeConstraint<index_type>("Tindices"), \
- SegmentReductionGPUOp< \
- type, index_type, \
- functor::SegmentReductionFunctor< \
- type, index_type, initial_value_functor, \
- reduction_kernel_functor, atomic_reduction_kernel_functor> >)
+#define REGISTER_GPU_KERNEL_SORTEDSEGMENT( \
+ name, type, index_type, initial_value_functor, \
+ empty_segment_value_functor, reduction_kernel_functor, is_mean) \
+ REGISTER_KERNEL_BUILDER( \
+ Name(name) \
+ .Device(DEVICE_GPU) \
+ .TypeConstraint<type>("T") \
+ .TypeConstraint<index_type>("Tindices"), \
+ SegmentReductionGPUOp< \
+ type, index_type, \
+ functor::SegmentReductionFunctor< \
+ type, index_type, initial_value_functor, \
+ empty_segment_value_functor, reduction_kernel_functor>, \
+ is_mean>)
-#define REGISTER_GPU_SORTED_KERNELS(type, index_type) \
- REGISTER_GPU_KERNEL_SORTEDSEGMENT( \
- "SegmentSum", type, index_type, functor::Zero<type>, \
- functor::NonAtomicSumOpGpu<type>, functor::AtomicSumOpGpu<type>); \
- REGISTER_GPU_KERNEL_SORTEDSEGMENT( \
- "SegmentProd", type, index_type, functor::One<type>, \
- functor::NonAtomicProdOpGpu<type>, functor::AtomicProdOpGpu<type>); \
- REGISTER_GPU_KERNEL_SORTEDSEGMENT( \
- "SegmentMin", type, index_type, functor::Highest<type>, \
- functor::NonAtomicMinOpGpu<type>, functor::AtomicMinOpGpu<type>); \
- REGISTER_GPU_KERNEL_SORTEDSEGMENT( \
- "SegmentMax", type, index_type, functor::Lowest<type>, \
- functor::NonAtomicMaxOpGpu<type>, functor::AtomicMaxOpGpu<type>);
+#define REGISTER_GPU_SORTED_KERNELS(type, index_type) \
+ REGISTER_GPU_KERNEL_SORTEDSEGMENT("SegmentSum", type, index_type, \
+ functor::Zero<type>, functor::Zero<type>, \
+ functor::Sum, /*is_mean=*/false); \
+ REGISTER_GPU_KERNEL_SORTEDSEGMENT("SegmentMean", type, index_type, \
+ functor::Zero<type>, functor::Zero<type>, \
+ functor::Sum, /*is_mean=*/true); \
+ REGISTER_GPU_KERNEL_SORTEDSEGMENT("SegmentProd", type, index_type, \
+ functor::One<type>, functor::One<type>, \
+ functor::Prod, /*is_mean=*/false); \
+ REGISTER_GPU_KERNEL_SORTEDSEGMENT( \
+ "SegmentMin", type, index_type, functor::Highest<type>, \
+ functor::Zero<type>, functor::Min, /*is_mean=*/false); \
+ REGISTER_GPU_KERNEL_SORTEDSEGMENT( \
+ "SegmentMax", type, index_type, functor::Lowest<type>, \
+ functor::Zero<type>, functor::Max, /*is_mean=*/false);
#define REGISTER_GPU_SORTED_KERNELS_ALL(type) \
REGISTER_GPU_SORTED_KERNELS(type, int64_t);
diff --git a/tensorflow/core/kernels/segment_reduction_ops_impl_3.cc b/tensorflow/core/kernels/segment_reduction_ops_impl_3.cc
index c809a5e..be5e94d 100644
--- a/tensorflow/core/kernels/segment_reduction_ops_impl_3.cc
+++ b/tensorflow/core/kernels/segment_reduction_ops_impl_3.cc
@@ -87,19 +87,15 @@
// sum is the only op that supports all input types currently
#define REGISTER_REAL_GPU_UNSORTED_KERNELS(type, index_type) \
REGISTER_GPU_KERNEL_UNSORTEDSEGMENT("UnsortedSegmentMax", type, index_type, \
- functor::Lowest<type>, \
- functor::AtomicMaxOpGpu<type>); \
+ functor::Lowest<type>, functor::Max); \
REGISTER_GPU_KERNEL_UNSORTEDSEGMENT("UnsortedSegmentMin", type, index_type, \
- functor::Highest<type>, \
- functor::AtomicMinOpGpu<type>); \
+ functor::Highest<type>, functor::Min); \
REGISTER_GPU_KERNEL_UNSORTEDSEGMENT("UnsortedSegmentProd", type, index_type, \
- functor::One<type>, \
- functor::AtomicProdOpGpu<type>);
+ functor::One<type>, functor::Prod);
#define REGISTER_SUM_GPU_UNSORTED_KERNELS(type, index_type) \
REGISTER_GPU_KERNEL_UNSORTEDSEGMENT("UnsortedSegmentSum", type, index_type, \
- functor::Zero<type>, \
- functor::AtomicSumOpGpu<type>);
+ functor::Zero<type>, functor::Sum);
#define REGISTER_REAL_GPU_UNSORTED_KERNELS_ALL(type) \
REGISTER_REAL_GPU_UNSORTED_KERNELS(type, int32)
diff --git a/tensorflow/core/kernels/segment_reduction_ops_impl_4.cc b/tensorflow/core/kernels/segment_reduction_ops_impl_4.cc
index fb09881..7913aef 100644
--- a/tensorflow/core/kernels/segment_reduction_ops_impl_4.cc
+++ b/tensorflow/core/kernels/segment_reduction_ops_impl_4.cc
@@ -87,19 +87,15 @@
// sum is the only op that supports all input types currently
#define REGISTER_REAL_GPU_UNSORTED_KERNELS(type, index_type) \
REGISTER_GPU_KERNEL_UNSORTEDSEGMENT("UnsortedSegmentMax", type, index_type, \
- functor::Lowest<type>, \
- functor::AtomicMaxOpGpu<type>); \
+ functor::Lowest<type>, functor::Max); \
REGISTER_GPU_KERNEL_UNSORTEDSEGMENT("UnsortedSegmentMin", type, index_type, \
- functor::Highest<type>, \
- functor::AtomicMinOpGpu<type>); \
+ functor::Highest<type>, functor::Min); \
REGISTER_GPU_KERNEL_UNSORTEDSEGMENT("UnsortedSegmentProd", type, index_type, \
- functor::One<type>, \
- functor::AtomicProdOpGpu<type>);
+ functor::One<type>, functor::Prod);
#define REGISTER_SUM_GPU_UNSORTED_KERNELS(type, index_type) \
REGISTER_GPU_KERNEL_UNSORTEDSEGMENT("UnsortedSegmentSum", type, index_type, \
- functor::Zero<type>, \
- functor::AtomicSumOpGpu<type>);
+ functor::Zero<type>, functor::Sum);
#define REGISTER_REAL_GPU_UNSORTED_KERNELS_ALL(type) \
REGISTER_REAL_GPU_UNSORTED_KERNELS(type, int64_t)
diff --git a/tensorflow/core/util/gpu_kernel_helper.h b/tensorflow/core/util/gpu_kernel_helper.h
index 8604d7c..a5ae09b 100644
--- a/tensorflow/core/util/gpu_kernel_helper.h
+++ b/tensorflow/core/util/gpu_kernel_helper.h
@@ -275,6 +275,19 @@
DEFINE_BINARY_OPERATOR(/)
#undef DEFINE_BINARY_OPERATOR
+#define DEFINE_BINARY_FUNCTION(func) \
+ friend __host__ __device__ AlignedVector func(const AlignedVector& lhs, \
+ const AlignedVector& rhs) { \
+ AlignedVector ret; \
+ UNROLL_ON_DEVICE for (int i = 0; i < kSize; ++i) { \
+ ret[i] = func(lhs[i], rhs[i]); \
+ } \
+ return ret; \
+ }
+ DEFINE_BINARY_FUNCTION(min)
+ DEFINE_BINARY_FUNCTION(max)
+#undef DEFINE_BINARY_FUNCTION
+
private:
value_type values_[N];
};
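
The min/max friends added above matter because the new functor::Min/Max call
min(a, b)/max(a, b) elementwise, whereas gpuprim::Min/Max go through operator<,
which has no sensible meaning for the vectorized AlignedVector type. A rough
NumPy analogue of that distinction (arrays standing in for AlignedVector):

import numpy as np

a = np.array([1.0, 4.0])
b = np.array([3.0, 2.0])

# Elementwise min/max, the semantics the new AlignedVector friends provide.
print(np.minimum(a, b))  # [1. 2.]
print(np.maximum(a, b))  # [3. 4.]

# A comparison-based min (like gpuprim::Min via operator<) has to pick one
# whole vector, which is not well defined; NumPy refuses rather than guess.
try:
  print(min(a, b))
except ValueError as err:
  print("comparison-based min rejected:", err)
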
diff --git a/tensorflow/python/kernel_tests/segment_reduction_ops_deterministic_test.py b/tensorflow/python/kernel_tests/segment_reduction_ops_deterministic_test.py
index dceea7d..c2b1f05 100644
--- a/tensorflow/python/kernel_tests/segment_reduction_ops_deterministic_test.py
+++ b/tensorflow/python/kernel_tests/segment_reduction_ops_deterministic_test.py
@@ -18,6 +18,8 @@
from __future__ import division
from __future__ import print_function
+import os
+
from tensorflow.python.eager import backprop
from tensorflow.python.framework import config
from tensorflow.python.framework import constant_op
@@ -32,6 +34,10 @@
from tensorflow.python.platform import test
+def UsingAtomicSegmentReductions():
+ return bool(int(os.getenv("TF_USE_ATOMIC_SEGMENT_REDUCTIONS", "0")))
+
+
class SegmentReductionDeterminismExceptionsTest(test.TestCase):
"""Test d9m-unimplemented exceptions from the segment reduction ops.
@@ -63,7 +69,7 @@
for data_type in [dtypes.float16, dtypes.float32, dtypes.float64]:
with self.cached_session(force_gpu=True):
data, segment_ids, _ = self._input(data_type, segment_ids_type)
- if should_throw_for_float:
+ if UsingAtomicSegmentReductions() and should_throw_for_float:
with self.assertRaisesRegex(
errors_impl.UnimplementedError,
"Deterministic GPU implementation of sorted segment " +
@@ -99,7 +105,8 @@
continue
data, segment_ids, num_segments = self._input(
data_type, segment_ids_type)
- if (data_type != dtypes.int32) and should_throw_for_float:
+ if (UsingAtomicSegmentReductions() and (data_type != dtypes.int32)
+ and should_throw_for_float):
with self.assertRaisesRegex(errors_impl.UnimplementedError,
self._UNSORTED_ERROR_MESSAGE):
result = op(data, segment_ids, num_segments)
@@ -121,8 +128,12 @@
with self.cached_session(force_gpu=True):
data, segment_ids, num_segments = self._input(
data_type, segment_ids_type)
- with self.assertRaisesRegex(errors_impl.UnimplementedError,
- self._UNSORTED_ERROR_MESSAGE):
+ if UsingAtomicSegmentReductions():
+ with self.assertRaisesRegex(errors_impl.UnimplementedError,
+ self._UNSORTED_ERROR_MESSAGE):
+ result = op(data, segment_ids, num_segments)
+ self.evaluate(result)
+ else:
result = op(data, segment_ids, num_segments)
self.evaluate(result)
@@ -138,9 +149,13 @@
values, indices, _ = self._input(data_type, segment_ids_type)
sparse_value = indexed_slices.IndexedSlices(
values, indices, dense_shape=values.shape)
- with self.assertRaisesRegex(errors_impl.UnimplementedError,
- self._UNSORTED_ERROR_MESSAGE):
- # convert_to_tensor with IndexedSlices uses unsorted_segment_sum
+ if UsingAtomicSegmentReductions():
+ with self.assertRaisesRegex(errors_impl.UnimplementedError,
+ self._UNSORTED_ERROR_MESSAGE):
+ # convert_to_tensor with IndexedSlices uses unsorted_segment_sum
+ result = ops.convert_to_tensor(sparse_value)
+ self.evaluate(result)
+ else:
result = ops.convert_to_tensor(sparse_value)
self.evaluate(result)
@@ -158,9 +173,12 @@
tape.watch(params)
op_output = array_ops.gather(params, indices)
gradient = tape.gradient(op_output, params)
- with self.assertRaisesRegex(errors_impl.UnimplementedError,
- self._UNSORTED_ERROR_MESSAGE):
- # convert_to_tensor on IndexedSlices
+ if UsingAtomicSegmentReductions():
+ with self.assertRaisesRegex(errors_impl.UnimplementedError,
+ self._UNSORTED_ERROR_MESSAGE):
+ # convert_to_tensor on IndexedSlices
+ self.evaluate(params.assign(gradient))
+ else:
self.evaluate(params.assign(gradient))