Dtype compliance: native_layer_norm
Reviewed By: SS-JIA
Differential Revision: D48371008
fbshipit-source-id: cbfb97ce1d29931fd27eec27a24a4eab8857dc50
diff --git a/kernels/portable/cpu/op_native_layer_norm.cpp b/kernels/portable/cpu/op_native_layer_norm.cpp
index 81b14a5..0a37b5e 100644
--- a/kernels/portable/cpu/op_native_layer_norm.cpp
+++ b/kernels/portable/cpu/op_native_layer_norm.cpp
@@ -6,6 +6,7 @@
* LICENSE file in the root directory of this source tree.
*/
+#include <executorch/kernels/portable/cpu/util/normalization_ops_util.h>
#include <executorch/kernels/portable/cpu/vec_ops.h>
#include <executorch/runtime/kernel/kernel_includes.h>
#include <cmath>
@@ -18,19 +19,22 @@
using Tensor = exec_aten::Tensor;
namespace {
+
template <typename CTYPE>
void layer_norm(
const Tensor& input,
const Tensor& weight,
const Tensor& bias,
CTYPE eps,
- Tensor& output,
+ Tensor& out,
Tensor& mean,
Tensor& rstd) {
const CTYPE* input_data = input.const_data_ptr<CTYPE>();
- CTYPE* output_data = output.mutable_data_ptr<CTYPE>();
const CTYPE* weight_data = weight.const_data_ptr<CTYPE>();
const CTYPE* bias_data = bias.const_data_ptr<CTYPE>();
+ CTYPE* out_data = out.mutable_data_ptr<CTYPE>();
+ CTYPE* mean_data = mean.mutable_data_ptr<CTYPE>();
+ CTYPE* rstd_data = rstd.mutable_data_ptr<CTYPE>();
size_t dim = input.size(input.dim() - 1);
@@ -38,7 +42,7 @@
for (int i = 0; i < leading_dim; ++i) {
const CTYPE* x = input_data + i * dim;
- CTYPE* y = output_data + i * dim;
+ CTYPE* y = out_data + i * dim;
// compute E[X] and Var[x] = E[x^2] - E[x]^2
CTYPE sum = reduce_add(x, dim);
@@ -51,13 +55,12 @@
for (int j = 0; j < dim; ++j) {
y[j] = (x[j] - mean_value) / std * weight_data[j] + bias_data[j];
}
- }
- // Assign NAN to mean and rstd. They are not used in seen examples.
- // Use NAN to make the error more obvious in case they are used.
- mean.mutable_data_ptr<CTYPE>()[0] = NAN;
- rstd.mutable_data_ptr<CTYPE>()[0] = NAN;
+ mean_data[i] = mean_value;
+ rstd_data[i] = 1.0 / std;
+ }
}
+
} // namespace
// native_layer_norm.out(Tensor input, int[] normalized_shape, Tensor? weight,
@@ -75,54 +78,39 @@
Tensor& out,
Tensor& mean_out,
Tensor& rstd_out) {
- ET_CHECK_MSG(
- normalized_shape.size() == 1,
- "normalize_shape.size() must be 1 but saw %zd",
- normalized_shape.size());
- ET_CHECK_MSG(weight.has_value(), "Missing weight tensor");
- ET_CHECK_MSG(
- input.scalar_type() == out.scalar_type(),
- "out and input must have the same type.");
- ET_CHECK_MSG(
- input.dim() == out.dim(),
- "out and input must have the same number of dimensions");
- ET_CHECK_MSG(
- input.scalar_type() == mean_out.scalar_type(),
- "mean_out and input must have the same type.");
- ET_CHECK_MSG(
- input.scalar_type() == rstd_out.scalar_type(),
- "rstd_out and input must have the same type.");
+ std::tuple<Tensor&, Tensor&, Tensor&> ret_val(out, mean_out, rstd_out);
+
+ ET_KERNEL_CHECK(
+ ctx,
+ check_layer_norm_args(
+ input, normalized_shape, weight, bias, out, mean_out, rstd_out),
+ InvalidArgument,
+ ret_val);
if (input.sizes() == out.sizes()) {
- ET_CHECK_MSG(
+ ET_KERNEL_CHECK(
+ ctx,
normalized_shape[0] == input.sizes()[input.dim() - 1],
- "Normalized shape value must match the size of input.");
+ InvalidArgument,
+ ret_val);
} else {
// If we need to resize out to support dynamic input shapes, we can't count
// on normalized_shape matching the shape of the input or output. But we
// don't need to modify normalized_shape because it's not used in this
// function besides some checks
- torch::executor::Error err = resize_tensor(out, input.sizes());
- ET_CHECK_MSG(
- err == torch::executor::Error::Ok,
- "Failed to resize out Tensor in native_layer_norm_out");
+ ET_KERNEL_CHECK(
+ ctx,
+ resize_tensor(out, input.sizes()) == Error::Ok,
+ InvalidArgument,
+ ret_val);
}
-// helper for generating the cases for different data types
-#define LAYER_NORM(ctype, dtype) \
- case ScalarType::dtype: \
- layer_norm<ctype>( \
- input, weight.value(), bias.value(), eps, out, mean_out, rstd_out); \
- break;
+ ET_SWITCH_FLOAT_TYPES(input.scalar_type(), ctx, __func__, CTYPE, [&]() {
+ layer_norm<CTYPE>(
+ input, weight.value(), bias.value(), eps, out, mean_out, rstd_out);
+ });
- switch (input.scalar_type()) {
- // TODO support bfloat16
- ET_FORALL_FLOAT_TYPES(LAYER_NORM)
- default:
- ET_CHECK_MSG(false, "Unhandled dtype %hhd", input.scalar_type());
- }
-#undef LAYER_NORM
- return {out, mean_out, rstd_out};
+ return ret_val;
}
} // namespace native
diff --git a/kernels/portable/cpu/targets.bzl b/kernels/portable/cpu/targets.bzl
index f784252..4117c67 100644
--- a/kernels/portable/cpu/targets.bzl
+++ b/kernels/portable/cpu/targets.bzl
@@ -527,6 +527,7 @@
name = "op_native_layer_norm",
deps = [
":vec_ops",
+ "//executorch/kernels/portable/cpu/util:normalization_ops_util",
],
),
op_target(
diff --git a/kernels/portable/cpu/util/normalization_ops_util.cpp b/kernels/portable/cpu/util/normalization_ops_util.cpp
index abe61e7..95c9860 100644
--- a/kernels/portable/cpu/util/normalization_ops_util.cpp
+++ b/kernels/portable/cpu/util/normalization_ops_util.cpp
@@ -55,5 +55,28 @@
return true;
}
+bool check_layer_norm_args(
+ const Tensor& input,
+ IntArrayRef normalized_shape,
+ const exec_aten::optional<Tensor>& weight,
+ const exec_aten::optional<Tensor>& bias,
+ Tensor& out,
+ Tensor& mean_out,
+ Tensor& rstd_out) {
+ ET_LOG_AND_RETURN_IF_FALSE(normalized_shape.size() == 1);
+ ET_LOG_AND_RETURN_IF_FALSE(weight.has_value());
+ if (weight.has_value()) {
+ ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_dtype(input, weight.value()));
+ }
+ if (bias.has_value()) {
+ ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_dtype(input, bias.value()));
+ }
+ ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_dtype(input, out));
+ ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_dtype(input, mean_out));
+ ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_dtype(input, rstd_out));
+ ET_LOG_AND_RETURN_IF_FALSE(input.dim() == out.dim());
+ return true;
+}
+
} // namespace executor
} // namespace torch
diff --git a/kernels/portable/cpu/util/normalization_ops_util.h b/kernels/portable/cpu/util/normalization_ops_util.h
index 0034ec4..b5853ab 100644
--- a/kernels/portable/cpu/util/normalization_ops_util.h
+++ b/kernels/portable/cpu/util/normalization_ops_util.h
@@ -23,5 +23,14 @@
double eps,
Tensor& out);
+bool check_layer_norm_args(
+ const Tensor& input,
+ IntArrayRef normalized_shape,
+ const exec_aten::optional<Tensor>& weight,
+ const exec_aten::optional<Tensor>& bias,
+ Tensor& out,
+ Tensor& mean_out,
+ Tensor& rstd_out);
+
} // namespace executor
} // namespace torch