Updated the reduction code so that full reductions now return a tensor of rank 0.
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorReduction.h b/unsupported/Eigen/CXX11/src/Tensor/TensorReduction.h
index 1d22843..1d534f8 100644
--- a/unsupported/Eigen/CXX11/src/Tensor/TensorReduction.h
+++ b/unsupported/Eigen/CXX11/src/Tensor/TensorReduction.h
@@ -64,10 +64,10 @@
}
};
-template <> struct DimInitializer<Sizes<1> > {
+template <> struct DimInitializer<Sizes<0> > {
template <typename InputDims, typename Index, size_t Rank> EIGEN_DEVICE_FUNC
static void run(const InputDims& input_dims, const array<bool, Rank>&,
- Sizes<1>*, array<Index, Rank>* reduced_dims) {
+ Sizes<0>*, array<Index, Rank>* reduced_dims) {
const int NumInputDims = internal::array_size<InputDims>::value;
for (int i = 0; i < NumInputDims; ++i) {
(*reduced_dims)[i] = input_dims[i];
@@ -136,6 +136,12 @@
}
}
};
+template <typename Self, typename Op>
+struct GenericDimReducer<-1, Self, Op> {
+ static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void reduce(const Self&, typename Self::Index, Op&, typename Self::CoeffReturnType*) {
+ eigen_assert(false && "should never be called");
+ }
+};
template <typename Self, typename Op, bool Vectorizable = (Self::InputPacketAccess & Op::PacketAccess)>
struct InnerMostDimReducer {
@@ -192,6 +198,12 @@
}
}
};
+template <typename Self, typename Op>
+struct InnerMostDimPreserver<-1, Self, Op, true> {
+ static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void reduce(const Self&, typename Self::Index, Op&, typename Self::PacketReturnType*) {
+ eigen_assert(false && "should never be called");
+ }
+};
// Default full reducer
template <typename Self, typename Op, typename Device, bool Vectorizable = (Self::InputPacketAccess & Op::PacketAccess)>
@@ -550,8 +562,8 @@
typedef typename TensorEvaluator<ArgType, Device>::Dimensions InputDimensions;
static const int NumInputDims = internal::array_size<InputDimensions>::value;
static const int NumReducedDims = internal::array_size<Dims>::value;
- static const int NumOutputDims = (NumInputDims==NumReducedDims) ? 1 : NumInputDims - NumReducedDims;
- typedef typename internal::conditional<NumInputDims==NumReducedDims, Sizes<1>, DSizes<Index, NumOutputDims> >::type Dimensions;
+ static const int NumOutputDims = NumInputDims - NumReducedDims;
+ typedef typename internal::conditional<NumOutputDims==0, Sizes<0>, DSizes<Index, NumOutputDims> >::type Dimensions;
typedef typename XprType::Scalar Scalar;
typedef TensorEvaluator<const TensorReductionOp<Op, Dims, ArgType>, Device> Self;
static const bool InputPacketAccess = TensorEvaluator<ArgType, Device>::PacketAccess;
@@ -565,7 +577,7 @@
static const bool ReducingInnerMostDims = internal::are_inner_most_dims<Dims, NumInputDims, Layout>::value;
static const bool PreservingInnerMostDims = internal::preserve_inner_most_dims<Dims, NumInputDims, Layout>::value;
- static const bool RunningFullReduction = (NumInputDims==NumReducedDims);
+ static const bool RunningFullReduction = (NumOutputDims==0);
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorEvaluator(const XprType& op, const Device& device)
: m_impl(op.expression(), device), m_reducer(op.reducer()), m_result(NULL), m_device(device)
@@ -589,51 +601,54 @@
internal::DimInitializer<Dimensions>::run(input_dims, reduced, &m_dimensions, &m_reducedDims);
// Precompute output strides.
- if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) {
- m_outputStrides[0] = 1;
- for (int i = 1; i < NumOutputDims; ++i) {
- m_outputStrides[i] = m_outputStrides[i - 1] * m_dimensions[i - 1];
- }
- } else {
- m_outputStrides[NumOutputDims - 1] = 1;
- for (int i = NumOutputDims - 2; i >= 0; --i) {
- m_outputStrides[i] = m_outputStrides[i + 1] * m_dimensions[i + 1];
- }
- }
-
- // Precompute input strides.
- array<Index, NumInputDims> input_strides;
- if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) {
- input_strides[0] = 1;
- for (int i = 1; i < NumInputDims; ++i) {
- input_strides[i] = input_strides[i-1] * input_dims[i-1];
- }
- } else {
- input_strides[NumInputDims - 1] = 1;
- for (int i = NumInputDims - 2; i >= 0; --i) {
- input_strides[i] = input_strides[i + 1] * input_dims[i + 1];
- }
- }
-
- int outputIndex = 0;
- int reduceIndex = 0;
- for (int i = 0; i < NumInputDims; ++i) {
- if (reduced[i]) {
- m_reducedStrides[reduceIndex] = input_strides[i];
- ++reduceIndex;
+ if (NumOutputDims > 0) {
+ if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) {
+ m_outputStrides[0] = 1;
+ for (int i = 1; i < NumOutputDims; ++i) {
+ m_outputStrides[i] = m_outputStrides[i - 1] * m_dimensions[i - 1];
+ }
} else {
- m_preservedStrides[outputIndex] = input_strides[i];
- ++outputIndex;
+ m_outputStrides[NumOutputDims - 1] = 1;
+ for (int i = NumOutputDims - 2; i >= 0; --i) {
+ m_outputStrides[i] = m_outputStrides[i + 1] * m_dimensions[i + 1];
+ }
+ }
+ }
+
+ // Precompute input strides.
+ if (NumInputDims > 0) {
+ array<Index, NumInputDims> input_strides;
+ if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) {
+ input_strides[0] = 1;
+ for (int i = 1; i < NumInputDims; ++i) {
+ input_strides[i] = input_strides[i-1] * input_dims[i-1];
+ }
+ } else {
+ input_strides[NumInputDims - 1] = 1;
+ for (int i = NumInputDims - 2; i >= 0; --i) {
+ input_strides[i] = input_strides[i + 1] * input_dims[i + 1];
+ }
+ }
+
+ int outputIndex = 0;
+ int reduceIndex = 0;
+ for (int i = 0; i < NumInputDims; ++i) {
+ if (reduced[i]) {
+ m_reducedStrides[reduceIndex] = input_strides[i];
+ ++reduceIndex;
+ } else {
+ m_preservedStrides[outputIndex] = input_strides[i];
+ ++outputIndex;
+ }
}
}
// Special case for full reductions
- if (NumInputDims == NumReducedDims) {
- eigen_assert(m_dimensions[0] == 1);
+ if (NumOutputDims == 0) {
m_preservedStrides[0] = internal::array_prod(input_dims);
}
}
-
+
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Dimensions& dimensions() const { return m_dimensions; }
typedef typename internal::remove_const<typename XprType::CoeffReturnType>::type CoeffReturnType;
@@ -674,9 +689,9 @@
return *m_result;
}
Op reducer(m_reducer);
- if (ReducingInnerMostDims) {
+ if (ReducingInnerMostDims || RunningFullReduction) {
const Index num_values_to_reduce =
- (static_cast<int>(Layout) == static_cast<int>(ColMajor)) ? m_preservedStrides[0] : m_preservedStrides[NumOutputDims - 1];
+ (static_cast<int>(Layout) == static_cast<int>(ColMajor)) ? m_preservedStrides[0] : m_preservedStrides[NumPreservedStrides - 1];
return internal::InnerMostDimReducer<Self, Op>::reduce(*this, firstInput(index),
num_values_to_reduce, reducer);
} else {
@@ -697,7 +712,7 @@
EIGEN_ALIGN_MAX typename internal::remove_const<CoeffReturnType>::type values[packetSize];
if (ReducingInnerMostDims) {
const Index num_values_to_reduce =
- (static_cast<int>(Layout) == static_cast<int>(ColMajor)) ? m_preservedStrides[0] : m_preservedStrides[NumOutputDims - 1];
+ (static_cast<int>(Layout) == static_cast<int>(ColMajor)) ? m_preservedStrides[0] : m_preservedStrides[NumPreservedStrides - 1];
const Index firstIndex = firstInput(index);
for (Index i = 0; i < packetSize; ++i) {
Op reducer(m_reducer);
@@ -748,7 +763,7 @@
if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) {
return index * m_preservedStrides[0];
} else {
- return index * m_preservedStrides[NumOutputDims - 1];
+ return index * m_preservedStrides[NumPreservedStrides - 1];
}
}
// TBD: optimize the case where we preserve the innermost dimensions.
@@ -774,10 +789,10 @@
index -= idx * m_outputStrides[i];
}
if (PreservingInnerMostDims) {
- eigen_assert(m_preservedStrides[NumOutputDims - 1] == 1);
+ eigen_assert(m_preservedStrides[NumPreservedStrides - 1] == 1);
startInput += index;
} else {
- startInput += index * m_preservedStrides[NumOutputDims - 1];
+ startInput += index * m_preservedStrides[NumPreservedStrides - 1];
}
}
return startInput;
@@ -789,7 +804,8 @@
array<Index, NumOutputDims> m_outputStrides;
// Subset of strides of the input tensor for the non-reduced dimensions.
// Indexed by output dimensions.
- array<Index, NumOutputDims> m_preservedStrides;
+ static const int NumPreservedStrides = max_n_1<NumOutputDims>::size;
+ array<Index, NumPreservedStrides> m_preservedStrides;
// Subset of strides of the input tensor for the reduced dimensions.
// Indexed by reduced dimensions.