Dtype compliance: native_layer_norm
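
Summary:
Makes op_native_layer_norm dtype-compliant and non-fatal on bad arguments:
the fatal ET_CHECK_MSG calls are replaced with ET_KERNEL_CHECK failures that
return the output tuple, the argument validation moves into a shared
check_layer_norm_args() helper in normalization_ops_util, the dtype dispatch
switches from a hand-rolled switch/macro to ET_SWITCH_FLOAT_TYPES, and the
kernel now writes real per-row statistics into mean_out and rstd_out instead
of NAN placeholders.

As a reference for the new mean/rstd behavior, here is a minimal standalone
sketch of the per-row math the kernel computes (plain C++ with illustrative
buffer names, not the ExecuTorch Tensor API):

  #include <cmath>
  #include <cstddef>

  // Per-row layer norm over the last dimension, mirroring the kernel:
  // mean = E[x], var = E[x^2] - E[x]^2, rstd = 1 / sqrt(var + eps).
  void layer_norm_ref(
      const float* x, const float* w, const float* b,
      size_t rows, size_t dim, float eps,
      float* y, float* mean, float* rstd) {
    for (size_t i = 0; i < rows; ++i) {
      const float* xi = x + i * dim;
      float* yi = y + i * dim;
      float sum = 0.0f, sq_sum = 0.0f;
      for (size_t j = 0; j < dim; ++j) {
        sum += xi[j];
        sq_sum += xi[j] * xi[j];
      }
      const float m = sum / dim;
      const float s = std::sqrt(sq_sum / dim - m * m + eps);
      for (size_t j = 0; j < dim; ++j) {
        yi[j] = (xi[j] - m) / s * w[j] + b[j];
      }
      mean[i] = m;         // previously set to NAN
      rstd[i] = 1.0f / s;  // reciprocal of the eps-stabilized std
    }
  }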

Reviewed By: SS-JIA

Differential Revision: D48371008

fbshipit-source-id: cbfb97ce1d29931fd27eec27a24a4eab8857dc50
diff --git a/kernels/portable/cpu/op_native_layer_norm.cpp b/kernels/portable/cpu/op_native_layer_norm.cpp
index 81b14a5..0a37b5e 100644
--- a/kernels/portable/cpu/op_native_layer_norm.cpp
+++ b/kernels/portable/cpu/op_native_layer_norm.cpp
@@ -6,6 +6,7 @@
  * LICENSE file in the root directory of this source tree.
  */
 
+#include <executorch/kernels/portable/cpu/util/normalization_ops_util.h>
 #include <executorch/kernels/portable/cpu/vec_ops.h>
 #include <executorch/runtime/kernel/kernel_includes.h>
 #include <cmath>
@@ -18,19 +19,22 @@
 using Tensor = exec_aten::Tensor;
 
 namespace {
+
 template <typename CTYPE>
 void layer_norm(
     const Tensor& input,
     const Tensor& weight,
     const Tensor& bias,
     CTYPE eps,
-    Tensor& output,
+    Tensor& out,
     Tensor& mean,
     Tensor& rstd) {
   const CTYPE* input_data = input.const_data_ptr<CTYPE>();
-  CTYPE* output_data = output.mutable_data_ptr<CTYPE>();
   const CTYPE* weight_data = weight.const_data_ptr<CTYPE>();
   const CTYPE* bias_data = bias.const_data_ptr<CTYPE>();
+  CTYPE* out_data = out.mutable_data_ptr<CTYPE>();
+  CTYPE* mean_data = mean.mutable_data_ptr<CTYPE>();
+  CTYPE* rstd_data = rstd.mutable_data_ptr<CTYPE>();
 
   size_t dim = input.size(input.dim() - 1);
 
@@ -38,7 +42,7 @@
 
   for (int i = 0; i < leading_dim; ++i) {
     const CTYPE* x = input_data + i * dim;
-    CTYPE* y = output_data + i * dim;
+    CTYPE* y = out_data + i * dim;
 
     // compute E[X] and Var[x] = E[x^2] - E[x]^2
     CTYPE sum = reduce_add(x, dim);
@@ -51,13 +55,12 @@
     for (int j = 0; j < dim; ++j) {
       y[j] = (x[j] - mean_value) / std * weight_data[j] + bias_data[j];
     }
-  }
 
-  // Assign NAN to mean and rstd. They are not used in seen examples.
-  // Use NAN to make the error more obvious in case they are used.
-  mean.mutable_data_ptr<CTYPE>()[0] = NAN;
-  rstd.mutable_data_ptr<CTYPE>()[0] = NAN;
+    mean_data[i] = mean_value;
+    rstd_data[i] = 1.0 / std;
+  }
 }
+
 } // namespace
 
 // native_layer_norm.out(Tensor input, int[] normalized_shape, Tensor? weight,
@@ -75,54 +78,39 @@
     Tensor& out,
     Tensor& mean_out,
     Tensor& rstd_out) {
-  ET_CHECK_MSG(
-      normalized_shape.size() == 1,
-      "normalize_shape.size() must be 1 but saw %zd",
-      normalized_shape.size());
-  ET_CHECK_MSG(weight.has_value(), "Missing weight tensor");
-  ET_CHECK_MSG(
-      input.scalar_type() == out.scalar_type(),
-      "out and input must have the same type.");
-  ET_CHECK_MSG(
-      input.dim() == out.dim(),
-      "out and input must have the same number of dimensions");
-  ET_CHECK_MSG(
-      input.scalar_type() == mean_out.scalar_type(),
-      "mean_out and input must have the same type.");
-  ET_CHECK_MSG(
-      input.scalar_type() == rstd_out.scalar_type(),
-      "rstd_out and input must have the same type.");
+  std::tuple<Tensor&, Tensor&, Tensor&> ret_val(out, mean_out, rstd_out);
+
+  ET_KERNEL_CHECK(
+      ctx,
+      check_layer_norm_args(
+          input, normalized_shape, weight, bias, out, mean_out, rstd_out),
+      InvalidArgument,
+      ret_val);
 
   if (input.sizes() == out.sizes()) {
-    ET_CHECK_MSG(
+    ET_KERNEL_CHECK(
+        ctx,
         normalized_shape[0] == input.sizes()[input.dim() - 1],
-        "Normalized shape value must match the size of input.");
+        InvalidArgument,
+        ret_val);
   } else {
     // If we need to resize out to support dynamic input shapes, we can't count
     // on normalized_shape matching the shape of the input or output. But we
     // don't need to modify normalized_shape because it's not used in this
     // function besides some checks
-    torch::executor::Error err = resize_tensor(out, input.sizes());
-    ET_CHECK_MSG(
-        err == torch::executor::Error::Ok,
-        "Failed to resize out Tensor in native_layer_norm_out");
+    ET_KERNEL_CHECK(
+        ctx,
+        resize_tensor(out, input.sizes()) == Error::Ok,
+        InvalidArgument,
+        ret_val);
   }
 
-// helper for generating the cases for different data types
-#define LAYER_NORM(ctype, dtype)                                            \
-  case ScalarType::dtype:                                                   \
-    layer_norm<ctype>(                                                      \
-        input, weight.value(), bias.value(), eps, out, mean_out, rstd_out); \
-    break;
+  ET_SWITCH_FLOAT_TYPES(input.scalar_type(), ctx, __func__, CTYPE, [&]() {
+    layer_norm<CTYPE>(
+        input, weight.value(), bias.value(), eps, out, mean_out, rstd_out);
+  });
 
-  switch (input.scalar_type()) {
-    // TODO support bfloat16
-    ET_FORALL_FLOAT_TYPES(LAYER_NORM)
-    default:
-      ET_CHECK_MSG(false, "Unhandled dtype %hhd", input.scalar_type());
-  }
-#undef LAYER_NORM
-  return {out, mean_out, rstd_out};
+  return ret_val;
 }
 
 } // namespace native
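
Note on the dispatch change above: ET_SWITCH_FLOAT_TYPES replaces the old
LAYER_NORM switch/macro, covering the same float dtypes along with the
unhandled-dtype error path that the old default case implemented by hand. A
rough standalone analogue of this dtype-dispatch pattern (hypothetical
type_tag helper, not ExecuTorch's macro machinery):

  #include <cstdint>

  enum class ScalarType : int8_t { Float, Double };

  template <typename T>
  struct type_tag {
    using type = T;
  };

  // Invoke fn once with a tag carrying the C++ type matching the runtime
  // dtype; the real macro additionally flags unhandled dtypes as errors.
  template <typename Fn>
  void switch_float_types(ScalarType st, Fn&& fn) {
    switch (st) {
      case ScalarType::Float:
        fn(type_tag<float>{});
        break;
      case ScalarType::Double:
        fn(type_tag<double>{});
        break;
    }
  }

  // Usage, mirroring the kernel's dispatch site:
  //   switch_float_types(dtype, [&](auto tag) {
  //     using CTYPE = typename decltype(tag)::type;
  //     layer_norm<CTYPE>(input, weight, bias, eps, out, mean, rstd);
  //   });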
diff --git a/kernels/portable/cpu/targets.bzl b/kernels/portable/cpu/targets.bzl
index f784252..4117c67 100644
--- a/kernels/portable/cpu/targets.bzl
+++ b/kernels/portable/cpu/targets.bzl
@@ -527,6 +527,7 @@
         name = "op_native_layer_norm",
         deps = [
             ":vec_ops",
+            "//executorch/kernels/portable/cpu/util:normalization_ops_util",
         ],
     ),
     op_target(
diff --git a/kernels/portable/cpu/util/normalization_ops_util.cpp b/kernels/portable/cpu/util/normalization_ops_util.cpp
index abe61e7..95c9860 100644
--- a/kernels/portable/cpu/util/normalization_ops_util.cpp
+++ b/kernels/portable/cpu/util/normalization_ops_util.cpp
@@ -55,5 +55,28 @@
   return true;
 }
 
+bool check_layer_norm_args(
+    const Tensor& input,
+    IntArrayRef normalized_shape,
+    const exec_aten::optional<Tensor>& weight,
+    const exec_aten::optional<Tensor>& bias,
+    Tensor& out,
+    Tensor& mean_out,
+    Tensor& rstd_out) {
+  ET_LOG_AND_RETURN_IF_FALSE(normalized_shape.size() == 1);
+  ET_LOG_AND_RETURN_IF_FALSE(weight.has_value());
+  if (weight.has_value()) {
+    ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_dtype(input, weight.value()));
+  }
+  if (bias.has_value()) {
+    ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_dtype(input, bias.value()));
+  }
+  ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_dtype(input, out));
+  ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_dtype(input, mean_out));
+  ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_dtype(input, rstd_out));
+  ET_LOG_AND_RETURN_IF_FALSE(input.dim() == out.dim());
+  return true;
+}
+
 } // namespace executor
 } // namespace torch
diff --git a/kernels/portable/cpu/util/normalization_ops_util.h b/kernels/portable/cpu/util/normalization_ops_util.h
index 0034ec4..b5853ab 100644
--- a/kernels/portable/cpu/util/normalization_ops_util.h
+++ b/kernels/portable/cpu/util/normalization_ops_util.h
@@ -23,5 +23,14 @@
     double eps,
     Tensor& out);
 
+bool check_layer_norm_args(
+    const Tensor& input,
+    IntArrayRef normalized_shape,
+    const exec_aten::optional<Tensor>& weight,
+    const exec_aten::optional<Tensor>& bias,
+    Tensor& out,
+    Tensor& mean_out,
+    Tensor& rstd_out);
+
 } // namespace executor
 } // namespace torch
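
A rough standalone sketch of the early-return validation style used by
check_layer_norm_args above (a hand-written LOG_AND_RETURN_IF_FALSE macro and
an illustrative TensorInfo struct standing in for ET_LOG_AND_RETURN_IF_FALSE
and the real Tensor type):

  #include <cstddef>
  #include <cstdio>

  // Mimics ET_LOG_AND_RETURN_IF_FALSE: log the failed condition, bail out.
  #define LOG_AND_RETURN_IF_FALSE(cond)                     \
    do {                                                    \
      if (!(cond)) {                                        \
        std::fprintf(stderr, "Check failed: %s\n", #cond);  \
        return false;                                       \
      }                                                     \
    } while (0)

  struct TensorInfo {
    int dtype;
    int dim;
  };

  // Same shape as check_layer_norm_args: each dtype/rank mismatch is logged
  // once and becomes a `false` return, which the kernel then turns into an
  // InvalidArgument failure via ET_KERNEL_CHECK instead of aborting.
  bool check_layer_norm_args_ref(
      const TensorInfo& input,
      const TensorInfo* weight, // optional: may be null
      const TensorInfo* bias,   // optional: may be null
      const TensorInfo& out,
      const TensorInfo& mean_out,
      const TensorInfo& rstd_out,
      size_t normalized_shape_size) {
    LOG_AND_RETURN_IF_FALSE(normalized_shape_size == 1);
    LOG_AND_RETURN_IF_FALSE(weight != nullptr);
    LOG_AND_RETURN_IF_FALSE(weight->dtype == input.dtype);
    if (bias != nullptr) {
      LOG_AND_RETURN_IF_FALSE(bias->dtype == input.dtype);
    }
    LOG_AND_RETURN_IF_FALSE(out.dtype == input.dtype);
    LOG_AND_RETURN_IF_FALSE(mean_out.dtype == input.dtype);
    LOG_AND_RETURN_IF_FALSE(rstd_out.dtype == input.dtype);
    LOG_AND_RETURN_IF_FALSE(input.dim == out.dim);
    return true;
  }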