blob: 96cd7d771248a021ee5ba22f7f059f4977106064 [file] [log] [blame]
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifdef INTEL_MKL
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "dnnl.hpp"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/util/mkl_util.h"
#include "tensorflow/core/util/tensor_format.h"
using CPUDevice = Eigen::ThreadPoolDevice;
using dnnl::layer_normalization_forward;
using dnnl::normalization_flags;
using dnnl::prop_kind;
namespace tensorflow {
template <typename Device, typename T>
class MklLayerNormOp : public OpKernel {
public:
explicit MklLayerNormOp(OpKernelConstruction* context) : OpKernel(context) {
OP_REQUIRES_OK(context, context->GetAttr("epsilon", &epsilon_));
}
void Compute(OpKernelContext* ctx) override {
try {
const Tensor& src_tensor = MklGetInput(ctx, kSrcIndex);
const Tensor& scale_tensor = MklGetInput(ctx, kScaleIndex);
const Tensor& shift_tensor = MklGetInput(ctx, kShiftIndex);
OP_REQUIRES(ctx, src_tensor.dims() == 2 || src_tensor.dims() == 3,
errors::InvalidArgument("input must be 2D or 3D",
src_tensor.shape().DebugString()));
OP_REQUIRES(ctx, scale_tensor.dims() == 1,
errors::InvalidArgument("scale must be 1D tensor",
scale_tensor.shape().DebugString()));
OP_REQUIRES(ctx, shift_tensor.dims() == 1,
errors::InvalidArgument("offset must be 1D tensor",
shift_tensor.shape().DebugString()));
size_t num_elements_scale = scale_tensor.dim_size(0);
size_t num_elements_shift = shift_tensor.dim_size(0);
OP_REQUIRES(
ctx, num_elements_scale == num_elements_shift,
errors::InvalidArgument("Number of elements in scale and shift",
"tensors are not same."));
auto cpu_engine = engine(engine::kind::cpu, 0);
auto engine_stream = stream(cpu_engine);
memory::dims src_dims = TFShapeToMklDnnDims(src_tensor.shape());
auto src_md =
memory::desc(src_dims, MklDnnType<T>(),
(src_dims.size() == 3) ? memory::format_tag::tnc
: memory::format_tag::nc);
void* src_buf =
static_cast<void*>(const_cast<T*>(src_tensor.flat<T>().data()));
auto src_mem = memory(src_md, cpu_engine, src_buf);
// oneDNN requires scale-shift as a combined array in float32 type.
memory::dims scale_shift_dims = {
2, static_cast<dnnl_dim_t>(num_elements_scale)};
auto scale_shift_md =
memory::desc(static_cast<memory::dims>(scale_shift_dims),
MklDnnType<float>(), memory::format_tag::nc);
Tensor scale_shift_tensor;
int64_t tensor_shape = scale_shift_md.get_size() / sizeof(float);
OP_REQUIRES_OK(
ctx, ctx->allocate_temp(DataTypeToEnum<float>::v(), {tensor_shape},
&scale_shift_tensor));
void* scale_shift_buf =
static_cast<void*>(scale_shift_tensor.flat<float>().data());
auto scale_shift_mem =
memory(scale_shift_md, cpu_engine, scale_shift_buf);
// Copy of reorder scale and shift tensor data into scale_shift_tensor.
void* scale_buf_src =
static_cast<void*>(const_cast<T*>(scale_tensor.flat<T>().data()));
auto scale_mem_src = memory({{static_cast<ptrdiff_t>(num_elements_scale)},
MklDnnType<T>(),
memory::format_tag::x},
cpu_engine, scale_buf_src);
void* scale_buf_dst = scale_shift_buf;
auto scale_mem_dst = memory({{static_cast<ptrdiff_t>(num_elements_scale)},
MklDnnType<float>(),
memory::format_tag::x},
cpu_engine, scale_buf_dst);
auto scale_reorder_prim = reorder(scale_mem_src, scale_mem_dst);
std::unordered_map<int, memory> scale_reorder_args;
scale_reorder_args.insert({DNNL_ARG_FROM, scale_mem_src});
scale_reorder_args.insert({DNNL_ARG_TO, scale_mem_dst});
scale_reorder_prim.execute(engine_stream, scale_reorder_args);
void* shift_buf_src =
static_cast<void*>(const_cast<T*>(shift_tensor.flat<T>().data()));
auto shift_mem_src = memory({{static_cast<ptrdiff_t>(num_elements_shift)},
MklDnnType<T>(),
memory::format_tag::x},
cpu_engine, shift_buf_src);
void* shift_buf_dst = static_cast<char*>(scale_shift_buf) +
sizeof(float) * num_elements_scale;
auto shift_mem_dst = memory({{static_cast<ptrdiff_t>(num_elements_shift)},
MklDnnType<float>(),
memory::format_tag::x},
cpu_engine, shift_buf_dst);
auto shift_reorder_prim = reorder(shift_mem_src, shift_mem_dst);
std::unordered_map<int, memory> shift_reorder_args;
shift_reorder_args.insert({DNNL_ARG_FROM, shift_mem_src});
shift_reorder_args.insert({DNNL_ARG_TO, shift_mem_dst});
shift_reorder_prim.execute(engine_stream, shift_reorder_args);
// Create layer_normalization primitive
auto lnorm_desc = layer_normalization_forward::desc(
prop_kind::forward_inference, src_md, epsilon_,
normalization_flags::use_scale_shift);
auto lnorm_pd =
layer_normalization_forward::primitive_desc(lnorm_desc, cpu_engine);
auto lnorm_prim = layer_normalization_forward(lnorm_pd);
// mean and variance memory
auto mean_mem = memory(lnorm_pd.mean_desc(), cpu_engine);
auto variance_mem = memory(lnorm_pd.variance_desc(), cpu_engine);
// dst memory
Tensor* output_tensor = nullptr;
OP_REQUIRES_OK(ctx, ctx->forward_input_or_allocate_output(
{0}, 0, src_tensor.shape(), &output_tensor));
void* dst_buf =
static_cast<void*>(const_cast<T*>(output_tensor->flat<T>().data()));
auto dst_mem = memory(src_md, cpu_engine, dst_buf);
std::unordered_map<int, memory> lnorm_args;
lnorm_args.insert({DNNL_ARG_SRC, src_mem});
lnorm_args.insert({DNNL_ARG_MEAN, mean_mem});
lnorm_args.insert({DNNL_ARG_VARIANCE, variance_mem});
lnorm_args.insert({DNNL_ARG_SCALE_SHIFT, scale_shift_mem});
lnorm_args.insert({DNNL_ARG_DST, dst_mem});
lnorm_prim.execute(engine_stream, lnorm_args);
} catch (dnnl::error& e) {
string error_msg = "Status: " + std::to_string(e.status) +
", message: " + string(e.message) + ", in file " +
string(__FILE__) + ":" + std::to_string(__LINE__);
OP_REQUIRES_OK(
ctx, errors::Aborted("Operation received an exception:", error_msg));
}
}
private:
float epsilon_;
const int kSrcIndex = 0;
const int kScaleIndex = 1;
const int kShiftIndex = 2;
};
// Registers MklLayerNormOp as the CPU kernel of `_MklLayerNorm` for one
// element type `type`, bound to the op's "T" type attribute.
#define REGISTER_MKL_LAYER_NORM_CPU(type)                                 \
  REGISTER_KERNEL_BUILDER(                                                \
      Name("_MklLayerNorm").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
      MklLayerNormOp<CPUDevice, type>)

REGISTER_MKL_LAYER_NORM_CPU(float);
REGISTER_MKL_LAYER_NORM_CPU(bfloat16);

#undef REGISTER_MKL_LAYER_NORM_CPU
} // namespace tensorflow
#endif // INTEL_MKL