Add optimized reduce mean.

PiperOrigin-RevId: 461584488
diff --git a/tensorflow/lite/kernels/internal/optimized/optimized_ops.h b/tensorflow/lite/kernels/internal/optimized/optimized_ops.h
index 880ac5d..d76fa11 100644
--- a/tensorflow/lite/kernels/internal/optimized/optimized_ops.h
+++ b/tensorflow/lite/kernels/internal/optimized/optimized_ops.h
@@ -94,7 +94,6 @@
 using reference_ops::LessEqual;
 using reference_ops::LessEqualWithScaling;
 using reference_ops::LessWithScaling;
-using reference_ops::Mean;
 using reference_ops::ProcessBroadcastShapes;
 using reference_ops::RankOneSelect;
 using reference_ops::Relu0To1;  // NOLINT
diff --git a/tensorflow/lite/kernels/internal/optimized/reduce.h b/tensorflow/lite/kernels/internal/optimized/reduce.h
index 664272c..a937069 100644
--- a/tensorflow/lite/kernels/internal/optimized/reduce.h
+++ b/tensorflow/lite/kernels/internal/optimized/reduce.h
@@ -28,6 +28,7 @@
 #include "tensorflow/lite/kernels/internal/reference/reduce.h"
 #include "tensorflow/lite/kernels/internal/runtime_shape.h"
 #include "tensorflow/lite/kernels/internal/types.h"
+#include "tensorflow/lite/kernels/kernel_util.h"
 
 namespace tflite {
 namespace optimized_ops {
@@ -257,48 +258,6 @@
   }
 }
 
-template <typename T, typename U>
-inline bool MeanGeneral(const T* input_data, const int* input_dims,
-                        const int input_num_dims, T* output_data,
-                        const int* output_dims, const int output_num_dims,
-                        const int* axis, const int num_axis_dimensions,
-                        bool keep_dims, int* temp_index, int* resolved_axis,
-                        U* temp_sum) {
-  return reference_ops::Mean(input_data, input_dims, input_num_dims,
-                             output_data, output_dims, output_num_dims, axis,
-                             num_axis_dimensions, keep_dims, temp_index,
-                             resolved_axis, temp_sum);
-}
-
-template <>
-inline bool MeanGeneral<float, float>(
-    const float* input_data, const int* input_dims, const int input_num_dims,
-    float* output_data, const int* output_dims, const int output_num_dims,
-    const int* axis, const int num_axis_dimensions, bool keep_dims,
-    int* temp_index, int* resolved_axis, float* temp_sum) {
-  // Handle reduce_mean for the last dimensions.
-  if (num_axis_dimensions == 1 && axis[0] == (input_num_dims - 1)) {
-    ruy::profiler::ScopeLabel label("MeanLastDim/Float");
-    int output_size = 1;
-    for (int i = 0; i < input_num_dims - 1; ++i) {
-      output_size *= input_dims[i];
-    }
-    const int last_input_dim = input_dims[axis[0]];
-
-    // TODO(b/152563685): Consider use eigen to cover more general cases.
-    const MatrixMap<const float> in_mat(input_data, last_input_dim,
-                                        output_size);
-    VectorMap<float> out(output_data, output_size, 1);
-    out = (in_mat.array().colwise().sum()) / static_cast<float>(last_input_dim);
-    return true;
-  }
-
-  return reference_ops::Mean(input_data, input_dims, input_num_dims,
-                             output_data, output_dims, output_num_dims, axis,
-                             num_axis_dimensions, keep_dims, temp_index,
-                             resolved_axis, temp_sum);
-}
-
 template <typename T>
 struct SumOp {
   inline T operator()(const T& a) const { return a; }
@@ -352,10 +311,7 @@
 template <typename T>
 void ReduceIsCopy(const T* input_data, const int* input_dims,
                   const int input_num_dims, T* output_data) {
-  int num_elems = 1;
-  for (int i = 0; i < input_num_dims; ++i) {
-    num_elems *= input_dims[i];
-  }
+  const int64_t num_elems = NumElements(input_dims, input_num_dims);
   memcpy(output_data, input_data, num_elems * sizeof(T));
 }
 
@@ -481,10 +437,18 @@
     return false;
   }
 
-  if (!Reduce<T, U, CastSumOp<T, U>, CastSumOp<T, U>>(
-          input_data, normalized_dims, normalized_num_dims, resolved_axis,
-          num_resolved_axis, temp_sum, CastSumOp<T, U>(), CastSumOp<T, U>())) {
-    return false;
+  if (num_resolved_axis == 0) {
+    int count = NumElements(input_dims, input_num_dims);
+    for (int i = 0; i < count; ++i) {
+      temp_sum[i] = U(input_data[i]);
+    }
+  } else {
+    if (!Reduce<T, U, CastSumOp<T, U>, CastSumOp<T, U>>(
+            input_data, normalized_dims, normalized_num_dims, resolved_axis,
+            num_resolved_axis, temp_sum, CastSumOp<T, U>(),
+            CastSumOp<T, U>())) {
+      return false;
+    }
   }
 
   // Calculate mean by dividing output_data by num of aggregated element.
@@ -692,6 +656,123 @@
   return true;
 }
 
+template <typename T>
+inline void Mean(const tflite::MeanParams& op_params,
+                 const RuntimeShape& input_shape, const T* input_data,
+                 const RuntimeShape& output_shape, T* output_data) {
+  reference_ops::Mean(op_params, input_shape, input_data, output_shape,
+                      output_data);  // Delegates to the reference kernel.
+}
+
+// Computes the mean of the elements along the dimensions listed in `axis`.
+// First sums the elements over the resolved axes into `temp_sum`, then divides
+// each sum by the number of aggregated elements. `keep_dims` is not consulted.
+template <typename T, typename U>
+inline bool MeanGeneral(const T* input_data, const int* input_dims,
+                        const int input_num_dims, T* output_data,
+                        const int* output_dims, const int output_num_dims,
+                        const int* axis, const int num_axis_dimensions,
+                        bool keep_dims, int* normalized_dims,
+                        int* resolved_axis, U* temp_sum) {
+  ruy::profiler::ScopeLabel label("Mean");
+  // Resolve axis.
+  int num_resolved_axis = 0;
+  int normalized_num_dims = 0;
+  if (!reduce_utils::ResolveAxis(input_num_dims, axis, num_axis_dimensions,
+                                 resolved_axis, num_resolved_axis, input_dims,
+                                 normalized_dims, normalized_num_dims)) {
+    return false;
+  }
+  if (num_resolved_axis == 0) {
+    optimized_ops::ReduceIsCopy(input_data, input_dims, input_num_dims,
+                                output_data);
+    return true;
+  }
+  // Count the outputs; a zero extent cannot overflow and must not divide.
+  size_t num_outputs = 1;
+  for (int idx = 0; idx < output_num_dims; ++idx) {
+    size_t current = static_cast<size_t>(output_dims[idx]);
+    if (current != 0 &&
+        num_outputs > std::numeric_limits<size_t>::max() / current) {
+      return false;
+    }
+    num_outputs *= current;
+  }
+
+  if (!Reduce<T, U, CastSumOp<T, U>, CastSumOp<T, U>>(
+          input_data, normalized_dims, normalized_num_dims, resolved_axis,
+          num_resolved_axis, temp_sum, CastSumOp<T, U>(), CastSumOp<T, U>())) {
+    return false;
+  }
+
+  // Count the elements aggregated into each output, guarding size_t overflow.
+  size_t num_elements_in_axis = 1;
+  for (int idx = 0; idx < num_resolved_axis; ++idx) {
+    size_t current = static_cast<size_t>(normalized_dims[resolved_axis[idx]]);
+    if (num_elements_in_axis != 0 &&
+        current > std::numeric_limits<size_t>::max() / num_elements_in_axis) {
+      return false;
+    }
+    num_elements_in_axis *= current;
+  }
+
+  if (num_elements_in_axis > 0) {
+    for (size_t idx = 0; idx < num_outputs; ++idx) {
+      output_data[idx] =
+          static_cast<T>(temp_sum[idx] / static_cast<U>(num_elements_in_axis));
+    }
+  }
+  return true;
+}
+
+// Generic entry point: every non-specialized <T, U> pair uses MeanGeneral.
+template <typename T, typename U>
+inline bool Mean(const T* input_data, const int* input_dims,
+                 const int input_num_dims, T* output_data,
+                 const int* output_dims, const int output_num_dims,
+                 const int* axis, const int num_axis_dimensions, bool keep_dims,
+                 int* normalized_dims, int* resolved_axis, U* temp_sum) {
+  return MeanGeneral(input_data, input_dims, input_num_dims, output_data,
+                     output_dims, output_num_dims, axis, num_axis_dimensions,
+                     keep_dims, normalized_dims, resolved_axis, temp_sum);
+}
+
+// float specialization: when the only reduced axis is the innermost one of the
+// normalized shape, compute the mean with Eigen; otherwise use MeanGeneral.
+template <>
+inline bool Mean<float, float>(const float* input_data, const int* input_dims,
+                               const int input_num_dims, float* output_data,
+                               const int* output_dims,
+                               const int output_num_dims, const int* axis,
+                               const int num_axis_dimensions, bool keep_dims,
+                               int* normalized_dims, int* resolved_axis,
+                               float* temp_sum) {
+  int num_resolved_axis = 0;
+  int normalized_num_dims = 0;
+  if (!reduce_utils::ResolveAxis(input_num_dims, axis, num_axis_dimensions,
+                                 resolved_axis, num_resolved_axis, input_dims,
+                                 normalized_dims, normalized_num_dims)) {
+    return false;
+  }
+  // The fast path reads normalized_dims[0] and [1], so require a 2-D shape.
+  if (normalized_num_dims == 2 && num_resolved_axis == 1 &&
+      resolved_axis[0] == 1) {
+    ruy::profiler::ScopeLabel label("MeanLastDim/Float");
+    const int output_size = normalized_dims[0];
+    const int last_input_dim = normalized_dims[1];
+    // TODO(b/152563685): Consider use eigen to cover more general cases.
+    const MatrixMap<const float> in_mat(input_data, last_input_dim,
+                                        output_size);
+    VectorMap<float> out(output_data, output_size, 1);
+    out = (in_mat.array().colwise().sum()) / static_cast<float>(last_input_dim);
+    return true;
+  }
+
+  return MeanGeneral(input_data, input_dims, input_num_dims, output_data,
+                     output_dims, output_num_dims, axis, num_axis_dimensions,
+                     keep_dims, normalized_dims, resolved_axis, temp_sum);
+}
+
 // Computes the generic value (i.e., sum/max/min/prod) of elements across
 // dimensions given in axis. It needs to pass in init_value and reducer.
 template <typename T>
diff --git a/tensorflow/lite/kernels/kernel_util.h b/tensorflow/lite/kernels/kernel_util.h
index ed3a566..0687442 100644
--- a/tensorflow/lite/kernels/kernel_util.h
+++ b/tensorflow/lite/kernels/kernel_util.h
@@ -177,6 +177,14 @@
   return NumElements(t->dims);
 }
 
+// Returns the product of the first `num_dims` entries of `dims` as a 64-bit
+// value; an empty (or non-positive) `num_dims` yields 1.
+inline int64_t NumElements(const int* dims, int num_dims) {
+  int64_t product = 1;
+  for (int d = 0; d < num_dims; ++d) product *= dims[d];
+  return product;
+}
+
 // Determines whether tensor is constant.
 // TODO(b/138199592): Introduce new query which checks for constant OR
 // persistent-read-only, which would be useful for most tensor kernels that
diff --git a/tensorflow/lite/kernels/reduce.cc b/tensorflow/lite/kernels/reduce.cc
index ea0eebf..33ef06e 100644
--- a/tensorflow/lite/kernels/reduce.cc
+++ b/tensorflow/lite/kernels/reduce.cc
@@ -373,66 +373,76 @@
   }
 }
 
+template <typename T, typename U>
+TfLiteStatus Mean(TfLiteContext* context, const OpContext* op_context,
+                  int* temp_index, int* resolved_axis, U* temp_sum,
+                  KernelType kernel_type) {
+  const int num_axis = static_cast<int>(NumElements(op_context->axis));
+  auto args = std::tuple(  // Forwarded unchanged to the selected kernel.
+      GetTensorData<T>(op_context->input), &op_context->input->dims->data[0],
+      op_context->input->dims->size, GetTensorData<T>(op_context->output),
+      &op_context->output->dims->data[0], op_context->output->dims->size,
+      GetTensorData<int>(op_context->axis), num_axis,
+      op_context->params->keep_dims, temp_index, resolved_axis, temp_sum);
+  if (kernel_type == kReference) {
+    TF_LITE_ENSURE(context, std::apply(reference_ops::Mean<T, U>, args));
+    return kTfLiteOk;
+  }
+  TF_LITE_ENSURE(context, std::apply(optimized_ops::Mean<T, U>, args));
+  return kTfLiteOk;
+}
+
+// Runs the reference or optimized QuantizedMeanOrSum kernel on `op_context`.
+template <typename T>
+TfLiteStatus QuantizedMeanOrSum(TfLiteContext* context,
+                                const OpContext* op_context, int* temp_index,
+                                int* resolved_axis, int* temp_sum,
+                                KernelType kernel_type, bool compute_sum) {
+  const int num_axis = static_cast<int>(NumElements(op_context->axis));
+  auto args = std::tuple(
+      GetTensorData<T>(op_context->input), op_context->input->params.zero_point,
+      op_context->input->params.scale, &op_context->input->dims->data[0],
+      op_context->input->dims->size, GetTensorData<T>(op_context->output),
+      op_context->output->params.zero_point, op_context->output->params.scale,
+      &op_context->output->dims->data[0], op_context->output->dims->size,
+      GetTensorData<int>(op_context->axis), num_axis,
+      op_context->params->keep_dims, temp_index, resolved_axis, temp_sum,
+      compute_sum);
+  if (kernel_type == kReference) {
+    TF_LITE_ENSURE(
+        context,
+        std::apply(reference_ops::QuantizedMeanOrSum<T, int32_t>, args));
+    return kTfLiteOk;
+  }
+  TF_LITE_ENSURE(
+      context, std::apply(optimized_ops::QuantizedMeanOrSum<T, int32_t>, args));
+  return kTfLiteOk;
+}
+
 template <typename integer_type>
-TfLiteStatus EvalMeanReferenceOps(TfLiteContext* context,
-                                  const OpContext& op_context, int num_axis,
-                                  OpData* data, TfLiteTensor* temp_index,
-                                  TfLiteTensor* resolved_axis,
-                                  TfLiteTensor* temp_sum) {
+TfLiteStatus EvalIntegerMean(TfLiteContext* context,
+                             const OpContext& op_context, int num_axis,
+                             OpData* data, TfLiteTensor* temp_index,
+                             TfLiteTensor* resolved_axis,
+                             TfLiteTensor* temp_sum,
+                             TfLiteTensor* normalized_dims,
+                             KernelType kernel_type) {
   tflite::MeanParams op_params;
   op_params.axis_count = num_axis;
   ResolveAxis(GetTensorData<int>(op_context.axis), num_axis, &op_params);
   const TfLiteTensor* input = op_context.input;
 
-  // TODO(b/139102329): Handle all the cases in the combined reference
-  // method.
-  if (op_context.params->keep_dims && NumDimensions(input) == 4 &&
-      op_params.axis_count == 2 &&
-      ((op_params.axis[0] == 1 && op_params.axis[1] == 2) ||
-       (op_params.axis[0] == 2 && op_params.axis[1] == 1))) {
-    if (std::is_same<integer_type, uint8_t>::value) {
-      reference_ops::Mean(op_params, GetTensorShape(op_context.input),
-                          GetTensorData<uint8_t>(op_context.input),
-                          op_context.input->params.zero_point,
-                          op_context.input->params.scale,
-                          GetTensorShape(op_context.output),
-                          GetTensorData<uint8_t>(op_context.output),
-                          op_context.output->params.zero_point,
-                          op_context.output->params.scale);
-    } else {
-      reference_integer_ops::Mean(
-          op_params, data->multiplier, data->shift, GetTensorShape(input),
-          GetTensorData<integer_type>(input),
-          op_context.input->params.zero_point,
-          GetTensorShape(op_context.output),
-          GetTensorData<integer_type>(op_context.output),
-          op_context.output->params.zero_point);
-    }
-  } else if (input->params.zero_point == op_context.output->params.zero_point &&
-             input->params.scale == op_context.output->params.scale) {
-    TF_LITE_ENSURE(
-        context,
-        reference_ops::Mean(
-            GetTensorData<integer_type>(input), input->dims->data,
-            input->dims->size, GetTensorData<integer_type>(op_context.output),
-            op_context.output->dims->data, op_context.output->dims->size,
-            GetTensorData<int>(op_context.axis), num_axis,
-            op_context.params->keep_dims, GetTensorData<int>(temp_index),
-            GetTensorData<int>(resolved_axis), GetTensorData<int>(temp_sum)));
+  if (input->params.zero_point == op_context.output->params.zero_point &&
+      input->params.scale == op_context.output->params.scale) {
+    TF_LITE_ENSURE_OK(context, (Mean<integer_type, int>(
+        context, &op_context, GetTensorData<int>(temp_index),
+        GetTensorData<int>(resolved_axis), GetTensorData<int>(temp_sum),
+        kernel_type)));
   } else {
-    TF_LITE_ENSURE(
-        context,
-        reference_ops::QuantizedMeanOrSum<>(
-            GetTensorData<integer_type>(input), input->params.zero_point,
-            input->params.scale, input->dims->data, input->dims->size,
-            GetTensorData<integer_type>(op_context.output),
-            op_context.output->params.zero_point,
-            op_context.output->params.scale, op_context.output->dims->data,
-            op_context.output->dims->size, GetTensorData<int>(op_context.axis),
-            num_axis, op_context.params->keep_dims,
-            GetTensorData<int>(temp_index), GetTensorData<int>(resolved_axis),
-            GetTensorData<int>(temp_sum),
-            /*compute_sum=*/false));
+    TF_LITE_ENSURE_OK(context, QuantizedMeanOrSum<integer_type>(
+        context, &op_context, GetTensorData<int>(temp_index),
+        GetTensorData<int>(resolved_axis), GetTensorData<int32_t>(temp_sum),
+        kernel_type, /*compute_sum=*/false));
   }
   return kTfLiteOk;
 }
@@ -496,6 +506,13 @@
     TF_LITE_ENSURE_OK(context, ResizeOutputTensor(context, &op_context));
     TF_LITE_ENSURE_OK(context, ResizeTempAccum(context, &op_context, temp_sum));
   }
+  TfLiteTensor* normalized_dims;
+  TF_LITE_ENSURE_OK(
+      context, GetTemporarySafe(context, node, /*index=*/3, &normalized_dims));
+  if (IsDynamicTensor(normalized_dims)) {
+    TF_LITE_ENSURE_OK(context,
+                      ResizeTempDims(context, &op_context, normalized_dims));
+  }
 
   // Return early when input is empty.
   const TfLiteTensor* input = op_context.input;
@@ -551,80 +568,40 @@
     }
   }
 
-  // From here, it uses the reference implementations.
-  // TODO(b/139102329): Clean up the function signatures to merge the variations
-  // and handle the specialized cases in the combined reference implementations
-  // per each op.
   switch (op_context.input->type) {
-    case kTfLiteFloat32: {
-      tflite::MeanParams op_params;
-      op_params.axis_count = num_axis;
-      ResolveAxis(GetTensorData<int>(op_context.axis), num_axis, &op_params);
-      const TfLiteTensor* input = op_context.input;
-      // TODO(b/139102329): Handle the below special case in the combined
-      // reference method.
-      // Defer to specialized implementation for 4D Mean across axes 1 & 2.
-      if (op_context.params->keep_dims && NumDimensions(input) == 4 &&
-          op_params.axis_count == 2 &&
-          ((op_params.axis[0] == 1 && op_params.axis[1] == 2) ||
-           (op_params.axis[0] == 2 && op_params.axis[1] == 1))) {
-        reference_ops::Mean(op_params, input_shape, GetTensorData<float>(input),
-                            GetTensorShape(op_context.output),
-                            GetTensorData<float>(op_context.output));
-      } else {
-        TF_LITE_ENSURE(
-            context,
-            optimized_ops::MeanGeneral(
-                GetTensorData<float>(op_context.input),
-                op_context.input->dims->data, op_context.input->dims->size,
-                GetTensorData<float>(op_context.output),
-                op_context.output->dims->data, op_context.output->dims->size,
-                GetTensorData<int>(op_context.axis), num_axis,
-                op_context.params->keep_dims, GetTensorData<int>(temp_index),
-                GetTensorData<int>(resolved_axis),
-                GetTensorData<float>(temp_sum)));
-      }
-    } break;
+    case kTfLiteFloat32:
+      TF_LITE_ENSURE_OK(context, (Mean<float, float>(context, &op_context,
+          GetTensorData<int>(temp_index), GetTensorData<int>(resolved_axis),
+          GetTensorData<float>(temp_sum), kernel_type)));
+      break;
     case kTfLiteInt32:
-      TF_LITE_ENSURE(
-          context,
-          reference_ops::Mean(
-              GetTensorData<int>(op_context.input),
-              op_context.input->dims->data, op_context.input->dims->size,
-              GetTensorData<int>(op_context.output),
-              op_context.output->dims->data, op_context.output->dims->size,
-              GetTensorData<int>(op_context.axis), num_axis,
-              op_context.params->keep_dims, GetTensorData<int>(temp_index),
-              GetTensorData<int>(resolved_axis),
-              GetTensorData<int64_t>(temp_sum)));
+      TF_LITE_ENSURE_OK(context, (Mean<int, int64_t>(context, &op_context,
+          GetTensorData<int>(temp_index), GetTensorData<int>(resolved_axis),
+          GetTensorData<int64_t>(temp_sum), kernel_type)));
       break;
     case kTfLiteInt64:
-      TF_LITE_ENSURE(
-          context,
-          reference_ops::Mean(
-              GetTensorData<int64_t>(op_context.input),
-              op_context.input->dims->data, op_context.input->dims->size,
-              GetTensorData<int64_t>(op_context.output),
-              op_context.output->dims->data, op_context.output->dims->size,
-              GetTensorData<int>(op_context.axis), num_axis,
-              op_context.params->keep_dims, GetTensorData<int>(temp_index),
-              GetTensorData<int>(resolved_axis),
-              GetTensorData<int64_t>(temp_sum)));
+      TF_LITE_ENSURE_OK(context, (Mean<int64_t, int64_t>(
+          context, &op_context, GetTensorData<int>(temp_index),
+          GetTensorData<int>(resolved_axis), GetTensorData<int64_t>(temp_sum),
+          kernel_type)));
       break;
     case kTfLiteInt8: {
-      TF_LITE_ENSURE_OK(context, EvalMeanReferenceOps<int8_t>(
-                                     context, op_context, num_axis, data,
-                                     temp_index, resolved_axis, temp_sum));
+      TF_LITE_ENSURE_OK(
+          context, EvalIntegerMean<int8_t>(context, op_context, num_axis, data,
+                                           temp_index, resolved_axis, temp_sum,
+                                           normalized_dims, kernel_type));
     } break;
     case kTfLiteInt16: {
-      TF_LITE_ENSURE_OK(context, EvalMeanReferenceOps<int16_t>(
-                                     context, op_context, num_axis, data,
-                                     temp_index, resolved_axis, temp_sum));
+      TF_LITE_ENSURE_OK(
+          context, EvalIntegerMean<int16_t>(context, op_context, num_axis, data,
+                                            temp_index, resolved_axis, temp_sum,
+                                            normalized_dims, kernel_type));
     } break;
     case kTfLiteUInt8: {
-      TF_LITE_ENSURE_OK(context, EvalMeanReferenceOps<uint8_t>(
-                                     context, op_context, num_axis, data,
-                                     temp_index, resolved_axis, temp_sum));
+      TF_LITE_ENSURE_OK(
+          context, EvalIntegerMean<uint8_t>(context, op_context, num_axis, data,
+                                            temp_index, resolved_axis, temp_sum,
+                                            normalized_dims, kernel_type));
     } break;
     default:
       return kTfLiteError;
@@ -681,10 +658,7 @@
   eval_data.input_data = input_data;
   eval_data.output = init_value;
 
-  int num_elems = 1;
-  for (int i = 0; i < input_num_dims; ++i) {
-    num_elems *= input_dims[i];
-  }
+  int num_elems = NumElements(input_dims, input_num_dims);
 
   // Fetch backend context and number of threads.
   CpuBackendContext* cpu_backend_context =
@@ -873,33 +847,6 @@
   }
 }
 
-template <typename T>
-TfLiteStatus QuantizedMeanOrSum(TfLiteContext* context, OpContext* op_context,
-                                int* temp_index, int* resolved_axis,
-                                int* temp_sum, KernelType kernel_type,
-                                bool compute_sum) {
-  int num_axis = static_cast<int>(NumElements(op_context->axis));
-  auto args = std::tuple(
-      GetTensorData<T>(op_context->input), op_context->input->params.zero_point,
-      op_context->input->params.scale, &op_context->input->dims->data[0],
-      op_context->input->dims->size, GetTensorData<T>(op_context->output),
-      op_context->output->params.zero_point, op_context->output->params.scale,
-      &op_context->output->dims->data[0], op_context->output->dims->size,
-      GetTensorData<int>(op_context->axis), num_axis,
-      op_context->params->keep_dims, temp_index, resolved_axis, temp_sum,
-      compute_sum);
-  if (kernel_type == kReference) {
-    TF_LITE_ENSURE(
-        context,
-        std::apply(reference_ops::QuantizedMeanOrSum<T, int32_t>, args));
-  } else {
-    TF_LITE_ENSURE(
-        context,
-        std::apply(optimized_ops::QuantizedMeanOrSum<T, int32_t>, args));
-  }
-  return kTfLiteOk;
-}
-
 template <KernelType kernel_type>
 TfLiteStatus EvalSum(TfLiteContext* context, TfLiteNode* node) {
   OpContext op_context(context, node);
@@ -1153,13 +1100,7 @@
   return &r;
 }
 
-TfLiteRegistration* Register_MEAN() {
-#ifdef USE_NEON
-  return Register_MEAN_OPT();
-#else
-  return Register_MEAN_REF();
-#endif
-}
+TfLiteRegistration* Register_MEAN() { return Register_MEAN_OPT(); }
 
 TfLiteRegistration* Register_SUM() { return Register_SUM_OPT(); }
 TfLiteRegistration* Register_REDUCE_PROD() {
diff --git a/tensorflow/lite/kernels/reduce_test.cc b/tensorflow/lite/kernels/reduce_test.cc
index fd03651..e9f5fca 100644
--- a/tensorflow/lite/kernels/reduce_test.cc
+++ b/tensorflow/lite/kernels/reduce_test.cc
@@ -309,6 +309,18 @@
   EXPECT_THAT(m.GetOutput<float>(), ElementsAreArray(ArrayFloatNear({3.})));
 }
 
+TEST(ConstFloatMeanOpTest, UseOptimizedFloatMean) {
+  std::vector<float> data = {0.1, 0.2, 0.3, 0.4, 0.1, 0.2,
+                             0.3, 0.4, 0.1, 0.2, 0.3, 0.4};
+  MeanOpConstModel m({TensorType_FLOAT32, {2, 3, 2}}, {TensorType_FLOAT32, {2}},
+                     {2}, {1, 2}, false);
+  m.SetInput(data);
+  ASSERT_EQ(m.Invoke(), kTfLiteOk);
+  EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2}));
+  EXPECT_THAT(m.GetOutput<float>(),
+              ElementsAreArray(ArrayFloatNear({0.216667, 0.283333})));
+}
+
 TEST(DynamicFloatMeanOpTest, NotKeepDims) {
   std::vector<float> data = {1.0,  2.0,  3.0,  4.0,  5.0,  6.0,  7.0,  8.0,
                              9.0,  10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0,