Add optimized reduce mean.
PiperOrigin-RevId: 461584488
diff --git a/tensorflow/lite/kernels/internal/optimized/optimized_ops.h b/tensorflow/lite/kernels/internal/optimized/optimized_ops.h
index 880ac5d..d76fa11 100644
--- a/tensorflow/lite/kernels/internal/optimized/optimized_ops.h
+++ b/tensorflow/lite/kernels/internal/optimized/optimized_ops.h
@@ -94,7 +94,6 @@
using reference_ops::LessEqual;
using reference_ops::LessEqualWithScaling;
using reference_ops::LessWithScaling;
-using reference_ops::Mean;
using reference_ops::ProcessBroadcastShapes;
using reference_ops::RankOneSelect;
using reference_ops::Relu0To1; // NOLINT
diff --git a/tensorflow/lite/kernels/internal/optimized/reduce.h b/tensorflow/lite/kernels/internal/optimized/reduce.h
index 664272c..a937069 100644
--- a/tensorflow/lite/kernels/internal/optimized/reduce.h
+++ b/tensorflow/lite/kernels/internal/optimized/reduce.h
@@ -28,6 +28,7 @@
#include "tensorflow/lite/kernels/internal/reference/reduce.h"
#include "tensorflow/lite/kernels/internal/runtime_shape.h"
#include "tensorflow/lite/kernels/internal/types.h"
+#include "tensorflow/lite/kernels/kernel_util.h"
namespace tflite {
namespace optimized_ops {
@@ -257,48 +258,6 @@
}
}
-template <typename T, typename U>
-inline bool MeanGeneral(const T* input_data, const int* input_dims,
- const int input_num_dims, T* output_data,
- const int* output_dims, const int output_num_dims,
- const int* axis, const int num_axis_dimensions,
- bool keep_dims, int* temp_index, int* resolved_axis,
- U* temp_sum) {
- return reference_ops::Mean(input_data, input_dims, input_num_dims,
- output_data, output_dims, output_num_dims, axis,
- num_axis_dimensions, keep_dims, temp_index,
- resolved_axis, temp_sum);
-}
-
-template <>
-inline bool MeanGeneral<float, float>(
- const float* input_data, const int* input_dims, const int input_num_dims,
- float* output_data, const int* output_dims, const int output_num_dims,
- const int* axis, const int num_axis_dimensions, bool keep_dims,
- int* temp_index, int* resolved_axis, float* temp_sum) {
- // Handle reduce_mean for the last dimensions.
- if (num_axis_dimensions == 1 && axis[0] == (input_num_dims - 1)) {
- ruy::profiler::ScopeLabel label("MeanLastDim/Float");
- int output_size = 1;
- for (int i = 0; i < input_num_dims - 1; ++i) {
- output_size *= input_dims[i];
- }
- const int last_input_dim = input_dims[axis[0]];
-
- // TODO(b/152563685): Consider use eigen to cover more general cases.
- const MatrixMap<const float> in_mat(input_data, last_input_dim,
- output_size);
- VectorMap<float> out(output_data, output_size, 1);
- out = (in_mat.array().colwise().sum()) / static_cast<float>(last_input_dim);
- return true;
- }
-
- return reference_ops::Mean(input_data, input_dims, input_num_dims,
- output_data, output_dims, output_num_dims, axis,
- num_axis_dimensions, keep_dims, temp_index,
- resolved_axis, temp_sum);
-}
-
template <typename T>
struct SumOp {
inline T operator()(const T& a) const { return a; }
@@ -352,10 +311,7 @@
template <typename T>
void ReduceIsCopy(const T* input_data, const int* input_dims,
const int input_num_dims, T* output_data) {
- int num_elems = 1;
- for (int i = 0; i < input_num_dims; ++i) {
- num_elems *= input_dims[i];
- }
+ int num_elems = NumElements(input_dims, input_num_dims);
memcpy(output_data, input_data, num_elems * sizeof(T));
}
@@ -481,10 +437,18 @@
return false;
}
- if (!Reduce<T, U, CastSumOp<T, U>, CastSumOp<T, U>>(
- input_data, normalized_dims, normalized_num_dims, resolved_axis,
- num_resolved_axis, temp_sum, CastSumOp<T, U>(), CastSumOp<T, U>())) {
- return false;
+ if (num_resolved_axis == 0) {
+ int count = NumElements(input_dims, input_num_dims);
+ for (int i = 0; i < count; ++i) {
+ temp_sum[i] = U(input_data[i]);
+ }
+ } else {
+ if (!Reduce<T, U, CastSumOp<T, U>, CastSumOp<T, U>>(
+ input_data, normalized_dims, normalized_num_dims, resolved_axis,
+ num_resolved_axis, temp_sum, CastSumOp<T, U>(),
+ CastSumOp<T, U>())) {
+ return false;
+ }
}
// Calculate mean by dividing output_data by num of aggregated element.
@@ -692,6 +656,123 @@
return true;
}
+template <typename T>
+inline void Mean(const tflite::MeanParams& op_params,
+ const RuntimeShape& input_shape, const T* input_data,
+ const RuntimeShape& output_shape, T* output_data) {
+ return reference_ops::Mean(op_params, input_shape, input_data, output_shape,
+ output_data);
+}
+
+// Computes the mean of elements across dimensions given in axis.
+// It does so in two stages, first calculates the sum of elements along the axis
+// then divides it by the number of element in axis.
+template <typename T, typename U>
+inline bool MeanGeneral(const T* input_data, const int* input_dims,
+ const int input_num_dims, T* output_data,
+ const int* output_dims, const int output_num_dims,
+ const int* axis, const int num_axis_dimensions,
+ bool keep_dims, int* normalized_dims,
+ int* resolved_axis, U* temp_sum) {
+ ruy::profiler::ScopeLabel label("Mean");
+ // Resolve axis.
+ int num_resolved_axis = 0;
+ int normalized_num_dims = 0;
+ if (!reduce_utils::ResolveAxis(input_num_dims, axis, num_axis_dimensions,
+ resolved_axis, num_resolved_axis, input_dims,
+ normalized_dims, normalized_num_dims)) {
+ return false;
+ }
+ if (num_resolved_axis == 0) {
+ optimized_ops::ReduceIsCopy(input_data, input_dims, input_num_dims,
+ output_data);
+ return true;
+ }
+ // Reset output data.
+ size_t num_outputs = 1;
+ for (int idx = 0; idx < output_num_dims; ++idx) {
+ size_t current = static_cast<size_t>(output_dims[idx]);
+ // Overflow prevention.
+ if (num_outputs > std::numeric_limits<size_t>::max() / current) {
+ return false;
+ }
+ num_outputs *= current;
+ }
+
+ if (!Reduce<T, U, CastSumOp<T, U>, CastSumOp<T, U>>(
+ input_data, normalized_dims, normalized_num_dims, resolved_axis,
+ num_resolved_axis, temp_sum, CastSumOp<T, U>(), CastSumOp<T, U>())) {
+ return false;
+ }
+
+ // Calculate mean by dividing output_data by num of aggregated element.
+ size_t num_elements_in_axis = 1;
+ for (int idx = 0; idx < num_resolved_axis; ++idx) {
+ size_t current = static_cast<size_t>(normalized_dims[resolved_axis[idx]]);
+ // Overflow prevention.
+ if (current > (std::numeric_limits<size_t>::max() / num_elements_in_axis)) {
+ return false;
+ }
+ num_elements_in_axis *= current;
+ }
+
+ if (num_elements_in_axis > 0) {
+ for (size_t idx = 0; idx < num_outputs; ++idx) {
+ output_data[idx] =
+ static_cast<T>(temp_sum[idx] / static_cast<U>(num_elements_in_axis));
+ }
+ }
+ return true;
+}
+
+template <typename T, typename U>
+inline bool Mean(const T* input_data, const int* input_dims,
+ const int input_num_dims, T* output_data,
+ const int* output_dims, const int output_num_dims,
+ const int* axis, const int num_axis_dimensions, bool keep_dims,
+ int* normalized_dims, int* resolved_axis, U* temp_sum) {
+ return MeanGeneral(input_data, input_dims, input_num_dims, output_data,
+ output_dims, output_num_dims, axis, num_axis_dimensions,
+ false, normalized_dims, resolved_axis, temp_sum);
+}
+
+// Use Eigen when Mean is calculated over the last dimension only of a float
+// tensor.
+template <>
+inline bool Mean<float, float>(const float* input_data, const int* input_dims,
+ const int input_num_dims, float* output_data,
+ const int* output_dims,
+ const int output_num_dims, const int* axis,
+ const int num_axis_dimensions, bool keep_dims,
+ int* normalized_dims, int* resolved_axis,
+ float* temp_sum) {
+ // Handle reduce_mean for the last dimensions.
+ int num_resolved_axis = 0;
+ int normalized_num_dims = 0;
+ if (!reduce_utils::ResolveAxis(input_num_dims, axis, num_axis_dimensions,
+ resolved_axis, num_resolved_axis, input_dims,
+ normalized_dims, normalized_num_dims)) {
+ return false;
+ }
+ if (normalized_num_dims > 1 && num_resolved_axis == 1 &&
+ resolved_axis[0] == (normalized_num_dims - 1)) {
+ ruy::profiler::ScopeLabel label("MeanLastDim/Float");
+ int output_size = normalized_dims[0];
+ const int last_input_dim = normalized_dims[1];
+
+ // TODO(b/152563685): Consider use eigen to cover more general cases.
+ const MatrixMap<const float> in_mat(input_data, last_input_dim,
+ output_size);
+ VectorMap<float> out(output_data, output_size, 1);
+ out = (in_mat.array().colwise().sum()) / static_cast<float>(last_input_dim);
+ return true;
+ }
+
+ return MeanGeneral(input_data, input_dims, input_num_dims, output_data,
+ output_dims, output_num_dims, axis, num_axis_dimensions,
+ false, normalized_dims, resolved_axis, temp_sum);
+}
+
// Computes the generic value (i.e., sum/max/min/prod) of elements across
// dimensions given in axis. It needs to pass in init_value and reducer.
template <typename T>
diff --git a/tensorflow/lite/kernels/kernel_util.h b/tensorflow/lite/kernels/kernel_util.h
index ed3a566..0687442 100644
--- a/tensorflow/lite/kernels/kernel_util.h
+++ b/tensorflow/lite/kernels/kernel_util.h
@@ -177,6 +177,14 @@
return NumElements(t->dims);
}
+inline int64_t NumElements(const int* dims, int num_dims) {
+ int64_t count = 1;
+ for (int i = 0; i < num_dims; ++i) {
+ count *= dims[i];
+ }
+ return count;
+}
+
// Determines whether tensor is constant.
// TODO(b/138199592): Introduce new query which checks for constant OR
// persistent-read-only, which would be useful for most tensor kernels that
diff --git a/tensorflow/lite/kernels/reduce.cc b/tensorflow/lite/kernels/reduce.cc
index ea0eebf..33ef06e 100644
--- a/tensorflow/lite/kernels/reduce.cc
+++ b/tensorflow/lite/kernels/reduce.cc
@@ -373,66 +373,76 @@
}
}
+template <typename T, typename U>
+TfLiteStatus Mean(TfLiteContext* context, const OpContext* op_context,
+ int* temp_index, int* resolved_axis, U* temp_sum,
+ KernelType kernel_type) {
+ int num_axis = static_cast<int>(NumElements(op_context->axis));
+ auto args = std::tuple(
+ GetTensorData<T>(op_context->input), &op_context->input->dims->data[0],
+ op_context->input->dims->size, GetTensorData<T>(op_context->output),
+ &op_context->output->dims->data[0], op_context->output->dims->size,
+ GetTensorData<int>(op_context->axis), num_axis,
+ op_context->params->keep_dims, temp_index, resolved_axis, temp_sum);
+ if (kernel_type == kReference) {
+ TF_LITE_ENSURE(context, std::apply(reference_ops::Mean<T, U>, args));
+ } else {
+ TF_LITE_ENSURE(context, std::apply(optimized_ops::Mean<T, U>, args));
+ }
+ return kTfLiteOk;
+}
+
+template <typename T>
+TfLiteStatus QuantizedMeanOrSum(TfLiteContext* context,
+ const OpContext* op_context, int* temp_index,
+ int* resolved_axis, int* temp_sum,
+ KernelType kernel_type, bool compute_sum) {
+ int num_axis = static_cast<int>(NumElements(op_context->axis));
+ auto args = std::tuple(
+ GetTensorData<T>(op_context->input), op_context->input->params.zero_point,
+ op_context->input->params.scale, &op_context->input->dims->data[0],
+ op_context->input->dims->size, GetTensorData<T>(op_context->output),
+ op_context->output->params.zero_point, op_context->output->params.scale,
+ &op_context->output->dims->data[0], op_context->output->dims->size,
+ GetTensorData<int>(op_context->axis), num_axis,
+ op_context->params->keep_dims, temp_index, resolved_axis, temp_sum,
+ compute_sum);
+ if (kernel_type == kReference) {
+ TF_LITE_ENSURE(
+ context,
+ std::apply(reference_ops::QuantizedMeanOrSum<T, int32_t>, args));
+ } else {
+ TF_LITE_ENSURE(
+ context,
+ std::apply(optimized_ops::QuantizedMeanOrSum<T, int32_t>, args));
+ }
+ return kTfLiteOk;
+}
+
template <typename integer_type>
-TfLiteStatus EvalMeanReferenceOps(TfLiteContext* context,
- const OpContext& op_context, int num_axis,
- OpData* data, TfLiteTensor* temp_index,
- TfLiteTensor* resolved_axis,
- TfLiteTensor* temp_sum) {
+TfLiteStatus EvalIntegerMean(TfLiteContext* context,
+ const OpContext& op_context, int num_axis,
+ OpData* data, TfLiteTensor* temp_index,
+ TfLiteTensor* resolved_axis,
+ TfLiteTensor* temp_sum,
+ TfLiteTensor* normalized_dims,
+ KernelType kernel_type) {
tflite::MeanParams op_params;
op_params.axis_count = num_axis;
ResolveAxis(GetTensorData<int>(op_context.axis), num_axis, &op_params);
const TfLiteTensor* input = op_context.input;
- // TODO(b/139102329): Handle all the cases in the combined reference
- // method.
- if (op_context.params->keep_dims && NumDimensions(input) == 4 &&
- op_params.axis_count == 2 &&
- ((op_params.axis[0] == 1 && op_params.axis[1] == 2) ||
- (op_params.axis[0] == 2 && op_params.axis[1] == 1))) {
- if (std::is_same<integer_type, uint8_t>::value) {
- reference_ops::Mean(op_params, GetTensorShape(op_context.input),
- GetTensorData<uint8_t>(op_context.input),
- op_context.input->params.zero_point,
- op_context.input->params.scale,
- GetTensorShape(op_context.output),
- GetTensorData<uint8_t>(op_context.output),
- op_context.output->params.zero_point,
- op_context.output->params.scale);
- } else {
- reference_integer_ops::Mean(
- op_params, data->multiplier, data->shift, GetTensorShape(input),
- GetTensorData<integer_type>(input),
- op_context.input->params.zero_point,
- GetTensorShape(op_context.output),
- GetTensorData<integer_type>(op_context.output),
- op_context.output->params.zero_point);
- }
- } else if (input->params.zero_point == op_context.output->params.zero_point &&
- input->params.scale == op_context.output->params.scale) {
- TF_LITE_ENSURE(
- context,
- reference_ops::Mean(
- GetTensorData<integer_type>(input), input->dims->data,
- input->dims->size, GetTensorData<integer_type>(op_context.output),
- op_context.output->dims->data, op_context.output->dims->size,
- GetTensorData<int>(op_context.axis), num_axis,
- op_context.params->keep_dims, GetTensorData<int>(temp_index),
- GetTensorData<int>(resolved_axis), GetTensorData<int>(temp_sum)));
+ if (input->params.zero_point == op_context.output->params.zero_point &&
+ input->params.scale == op_context.output->params.scale) {
+ Mean<integer_type, int>(context, &op_context,
+ GetTensorData<int>(temp_index),
+ GetTensorData<int>(resolved_axis),
+ GetTensorData<int>(temp_sum), kernel_type);
} else {
- TF_LITE_ENSURE(
- context,
- reference_ops::QuantizedMeanOrSum<>(
- GetTensorData<integer_type>(input), input->params.zero_point,
- input->params.scale, input->dims->data, input->dims->size,
- GetTensorData<integer_type>(op_context.output),
- op_context.output->params.zero_point,
- op_context.output->params.scale, op_context.output->dims->data,
- op_context.output->dims->size, GetTensorData<int>(op_context.axis),
- num_axis, op_context.params->keep_dims,
- GetTensorData<int>(temp_index), GetTensorData<int>(resolved_axis),
- GetTensorData<int>(temp_sum),
- /*compute_sum=*/false));
+ QuantizedMeanOrSum<integer_type>(
+ context, &op_context, GetTensorData<int>(temp_index),
+ GetTensorData<int>(resolved_axis), GetTensorData<int32_t>(temp_sum),
+ kernel_type, /*compute_sum=*/false);
}
return kTfLiteOk;
}
@@ -496,6 +506,13 @@
TF_LITE_ENSURE_OK(context, ResizeOutputTensor(context, &op_context));
TF_LITE_ENSURE_OK(context, ResizeTempAccum(context, &op_context, temp_sum));
}
+ TfLiteTensor* normalized_dims;
+ TF_LITE_ENSURE_OK(
+ context, GetTemporarySafe(context, node, /*index=*/3, &normalized_dims));
+ if (IsDynamicTensor(normalized_dims)) {
+ TF_LITE_ENSURE_OK(context,
+ ResizeTempDims(context, &op_context, normalized_dims));
+ }
// Return early when input is empty.
const TfLiteTensor* input = op_context.input;
@@ -551,80 +568,40 @@
}
}
- // From here, it uses the reference implementations.
- // TODO(b/139102329): Clean up the function signatures to merge the variations
- // and handle the specialized cases in the combined reference implementations
- // per each op.
switch (op_context.input->type) {
- case kTfLiteFloat32: {
- tflite::MeanParams op_params;
- op_params.axis_count = num_axis;
- ResolveAxis(GetTensorData<int>(op_context.axis), num_axis, &op_params);
- const TfLiteTensor* input = op_context.input;
- // TODO(b/139102329): Handle the below special case in the combined
- // reference method.
- // Defer to specialized implementation for 4D Mean across axes 1 & 2.
- if (op_context.params->keep_dims && NumDimensions(input) == 4 &&
- op_params.axis_count == 2 &&
- ((op_params.axis[0] == 1 && op_params.axis[1] == 2) ||
- (op_params.axis[0] == 2 && op_params.axis[1] == 1))) {
- reference_ops::Mean(op_params, input_shape, GetTensorData<float>(input),
- GetTensorShape(op_context.output),
- GetTensorData<float>(op_context.output));
- } else {
- TF_LITE_ENSURE(
- context,
- optimized_ops::MeanGeneral(
- GetTensorData<float>(op_context.input),
- op_context.input->dims->data, op_context.input->dims->size,
- GetTensorData<float>(op_context.output),
- op_context.output->dims->data, op_context.output->dims->size,
- GetTensorData<int>(op_context.axis), num_axis,
- op_context.params->keep_dims, GetTensorData<int>(temp_index),
- GetTensorData<int>(resolved_axis),
- GetTensorData<float>(temp_sum)));
- }
- } break;
+ case kTfLiteFloat32:
+ Mean<float, float>(context, &op_context, GetTensorData<int>(temp_index),
+ GetTensorData<int>(resolved_axis),
+ GetTensorData<float>(temp_sum), kernel_type);
+ break;
case kTfLiteInt32:
- TF_LITE_ENSURE(
- context,
- reference_ops::Mean(
- GetTensorData<int>(op_context.input),
- op_context.input->dims->data, op_context.input->dims->size,
- GetTensorData<int>(op_context.output),
- op_context.output->dims->data, op_context.output->dims->size,
- GetTensorData<int>(op_context.axis), num_axis,
- op_context.params->keep_dims, GetTensorData<int>(temp_index),
- GetTensorData<int>(resolved_axis),
- GetTensorData<int64_t>(temp_sum)));
+ Mean<int, int64_t>(context, &op_context, GetTensorData<int>(temp_index),
+ GetTensorData<int>(resolved_axis),
+ GetTensorData<int64_t>(temp_sum), kernel_type);
break;
case kTfLiteInt64:
- TF_LITE_ENSURE(
- context,
- reference_ops::Mean(
- GetTensorData<int64_t>(op_context.input),
- op_context.input->dims->data, op_context.input->dims->size,
- GetTensorData<int64_t>(op_context.output),
- op_context.output->dims->data, op_context.output->dims->size,
- GetTensorData<int>(op_context.axis), num_axis,
- op_context.params->keep_dims, GetTensorData<int>(temp_index),
- GetTensorData<int>(resolved_axis),
- GetTensorData<int64_t>(temp_sum)));
+ Mean<int64_t, int64_t>(context, &op_context,
+ GetTensorData<int>(temp_index),
+ GetTensorData<int>(resolved_axis),
+ GetTensorData<int64_t>(temp_sum), kernel_type);
break;
case kTfLiteInt8: {
- TF_LITE_ENSURE_OK(context, EvalMeanReferenceOps<int8_t>(
- context, op_context, num_axis, data,
- temp_index, resolved_axis, temp_sum));
+ TF_LITE_ENSURE_OK(
+ context, EvalIntegerMean<int8_t>(context, op_context, num_axis, data,
+ temp_index, resolved_axis, temp_sum,
+ normalized_dims, kernel_type));
} break;
case kTfLiteInt16: {
- TF_LITE_ENSURE_OK(context, EvalMeanReferenceOps<int16_t>(
- context, op_context, num_axis, data,
- temp_index, resolved_axis, temp_sum));
+ TF_LITE_ENSURE_OK(
+ context, EvalIntegerMean<int16_t>(context, op_context, num_axis, data,
+ temp_index, resolved_axis, temp_sum,
+ normalized_dims, kernel_type));
} break;
case kTfLiteUInt8: {
- TF_LITE_ENSURE_OK(context, EvalMeanReferenceOps<uint8_t>(
- context, op_context, num_axis, data,
- temp_index, resolved_axis, temp_sum));
+ TF_LITE_ENSURE_OK(
+ context, EvalIntegerMean<uint8_t>(context, op_context, num_axis, data,
+ temp_index, resolved_axis, temp_sum,
+ normalized_dims, kernel_type));
} break;
default:
return kTfLiteError;
@@ -681,10 +658,7 @@
eval_data.input_data = input_data;
eval_data.output = init_value;
- int num_elems = 1;
- for (int i = 0; i < input_num_dims; ++i) {
- num_elems *= input_dims[i];
- }
+ int num_elems = NumElements(input_dims, input_num_dims);
// Fetch backend context and number of threads.
CpuBackendContext* cpu_backend_context =
@@ -873,33 +847,6 @@
}
}
-template <typename T>
-TfLiteStatus QuantizedMeanOrSum(TfLiteContext* context, OpContext* op_context,
- int* temp_index, int* resolved_axis,
- int* temp_sum, KernelType kernel_type,
- bool compute_sum) {
- int num_axis = static_cast<int>(NumElements(op_context->axis));
- auto args = std::tuple(
- GetTensorData<T>(op_context->input), op_context->input->params.zero_point,
- op_context->input->params.scale, &op_context->input->dims->data[0],
- op_context->input->dims->size, GetTensorData<T>(op_context->output),
- op_context->output->params.zero_point, op_context->output->params.scale,
- &op_context->output->dims->data[0], op_context->output->dims->size,
- GetTensorData<int>(op_context->axis), num_axis,
- op_context->params->keep_dims, temp_index, resolved_axis, temp_sum,
- compute_sum);
- if (kernel_type == kReference) {
- TF_LITE_ENSURE(
- context,
- std::apply(reference_ops::QuantizedMeanOrSum<T, int32_t>, args));
- } else {
- TF_LITE_ENSURE(
- context,
- std::apply(optimized_ops::QuantizedMeanOrSum<T, int32_t>, args));
- }
- return kTfLiteOk;
-}
-
template <KernelType kernel_type>
TfLiteStatus EvalSum(TfLiteContext* context, TfLiteNode* node) {
OpContext op_context(context, node);
@@ -1153,13 +1100,7 @@
return &r;
}
-TfLiteRegistration* Register_MEAN() {
-#ifdef USE_NEON
- return Register_MEAN_OPT();
-#else
- return Register_MEAN_REF();
-#endif
-}
+TfLiteRegistration* Register_MEAN() { return Register_MEAN_OPT(); }
TfLiteRegistration* Register_SUM() { return Register_SUM_OPT(); }
TfLiteRegistration* Register_REDUCE_PROD() {
diff --git a/tensorflow/lite/kernels/reduce_test.cc b/tensorflow/lite/kernels/reduce_test.cc
index fd03651..e9f5fca 100644
--- a/tensorflow/lite/kernels/reduce_test.cc
+++ b/tensorflow/lite/kernels/reduce_test.cc
@@ -309,6 +309,18 @@
EXPECT_THAT(m.GetOutput<float>(), ElementsAreArray(ArrayFloatNear({3.})));
}
+TEST(ConstFloatMeanOpTest, UseOptimzedFloatMean) {
+ std::vector<float> data = {0.1, 0.2, 0.3, 0.4, 0.1, 0.2,
+ 0.3, 0.4, 0.1, 0.2, 0.3, 0.4};
+ MeanOpConstModel m({TensorType_FLOAT32, {2, 3, 2}}, {TensorType_FLOAT32, {2}},
+ {2}, {1, 2}, false);
+ m.SetInput(data);
+ ASSERT_EQ(m.Invoke(), kTfLiteOk);
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2}));
+ EXPECT_THAT(m.GetOutput<float>(),
+ ElementsAreArray(ArrayFloatNear({0.216667, 0.283333})));
+}
+
TEST(DynamicFloatMeanOpTest, NotKeepDims) {
std::vector<float> data = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0,
9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0,