Added `MlirXlaOpKernel`, which can be used to implement `XlaOpKernels` via MLIR legalization.
An `XlaOpKernel` is used in the old bridge to lower a TF op to HLO. This CL moves `ReluOp` from its custom implementation to `MlirXlaOpKernel`.
`xla_compilation_cache.h` now exports `CreateGraph` since it's used by `MlirXlaOpKernel`.
PiperOrigin-RevId: 366409160
Change-Id: Ie18eb5fc21b0c73c7f527d31ce15b33c44218dfd
diff --git a/tensorflow/compiler/jit/xla_compilation_cache.cc b/tensorflow/compiler/jit/xla_compilation_cache.cc
index 112287b..d1bb09f 100644
--- a/tensorflow/compiler/jit/xla_compilation_cache.cc
+++ b/tensorflow/compiler/jit/xla_compilation_cache.cc
@@ -226,9 +226,7 @@
execution_count < kMinExecutionsPerCompile * compile_count;
}
-// Creates a simple graph using the specified op as the only op apart from the
-// arg and retval nodes.
-static xla::StatusOr<std::unique_ptr<Graph>> CreateGraph(
+xla::StatusOr<std::unique_ptr<Graph>> CreateGraph(
const NodeDef& node_def, absl::Span<const XlaCompiler::Argument> args,
absl::Span<const DataType> result_types) {
// TODO(b/74182462): We implement this by creating a new dummy Graph including
diff --git a/tensorflow/compiler/jit/xla_compilation_cache.h b/tensorflow/compiler/jit/xla_compilation_cache.h
index c84bc6d..0601433 100644
--- a/tensorflow/compiler/jit/xla_compilation_cache.h
+++ b/tensorflow/compiler/jit/xla_compilation_cache.h
@@ -246,6 +246,12 @@
TF_DISALLOW_COPY_AND_ASSIGN(XlaCompilationCache);
};
+// Creates a single-node graph using the specified node_def as the only op apart
+// from the arg and retval nodes.
+xla::StatusOr<std::unique_ptr<Graph>> CreateGraph(
+ const NodeDef& node_def, absl::Span<const XlaCompiler::Argument> args,
+ absl::Span<const DataType> result_types);
+
} // namespace tensorflow
#endif // TENSORFLOW_COMPILER_JIT_XLA_COMPILATION_CACHE_H_
diff --git a/tensorflow/compiler/tf2xla/BUILD b/tensorflow/compiler/tf2xla/BUILD
index e00598d..edc89aa 100644
--- a/tensorflow/compiler/tf2xla/BUILD
+++ b/tensorflow/compiler/tf2xla/BUILD
@@ -1127,6 +1127,18 @@
)
cc_library(
+ name = "mlir_xla_op_kernel",
+ srcs = ["mlir_xla_op_kernel.cc"],
+ hdrs = ["mlir_xla_op_kernel.h"],
+ deps = [
+ ":xla_compiler",
+ "//tensorflow/compiler/jit:xla_compilation_cache",
+ "//tensorflow/compiler/mlir:array_container_utils",
+ "//tensorflow/compiler/mlir/tensorflow:compile_mlir_util_no_tf_dialect_passes",
+ ],
+)
+
+cc_library(
name = "resource_util",
srcs = ["resource_util.cc"],
hdrs = ["resource_util.h"],
diff --git a/tensorflow/compiler/tf2xla/kernels/BUILD b/tensorflow/compiler/tf2xla/kernels/BUILD
index 518704e..5179404 100644
--- a/tensorflow/compiler/tf2xla/kernels/BUILD
+++ b/tensorflow/compiler/tf2xla/kernels/BUILD
@@ -151,6 +151,7 @@
"//tensorflow/compiler/jit:xla_activity_listener",
"//tensorflow/compiler/jit:xla_activity_proto_cc",
"//tensorflow/compiler/tf2xla:common",
+ "//tensorflow/compiler/tf2xla:mlir_xla_op_kernel",
"//tensorflow/compiler/tf2xla:xla_compilation_device",
"//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/compiler/tf2xla:xla_context",
diff --git a/tensorflow/compiler/tf2xla/kernels/relu_op.cc b/tensorflow/compiler/tf2xla/kernels/relu_op.cc
index 3b53bac..36d4cca 100644
--- a/tensorflow/compiler/tf2xla/kernels/relu_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/relu_op.cc
@@ -17,6 +17,7 @@
#include "tensorflow/compiler/tf2xla/kernels/relu_op.h"
+#include "tensorflow/compiler/tf2xla/mlir_xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
@@ -35,15 +36,7 @@
namespace tensorflow {
namespace {
-class ReluOp : public XlaOpKernel {
- public:
- explicit ReluOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
- // Computes the max of the scalar input x and 0.
- void Compile(XlaOpKernelContext* ctx) override {
- ctx->SetOutput(0, xla::Relu(ctx->Input(0)));
- }
-};
-REGISTER_XLA_OP(Name("Relu"), ReluOp);
+REGISTER_XLA_OP(Name("Relu"), MlirXlaOpKernel);
class Relu6Op : public XlaOpKernel {
public:
diff --git a/tensorflow/compiler/tf2xla/mlir_xla_op_kernel.cc b/tensorflow/compiler/tf2xla/mlir_xla_op_kernel.cc
new file mode 100644
index 0000000..d7947d0
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/mlir_xla_op_kernel.cc
@@ -0,0 +1,109 @@
+/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/tf2xla/mlir_xla_op_kernel.h"
+
+#include "tensorflow/compiler/jit/xla_compilation_cache.h"
+#include "tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.h"
+#include "tensorflow/compiler/mlir/utils/array_container_utils.h"
+
+namespace tensorflow {
+
+namespace {
+
+Status ContextToXlaArgs(XlaOpKernelContext* ctx,
+ std::vector<XlaCompiler::Argument>& xla_args) {
+ int num_inputs = ctx->num_inputs();
+ xla_args.reserve(num_inputs);
+ for (int i = 0; i < num_inputs; ++i) {
+ // TODO(b/180448676): If the input `XlaExpression` kind is `kConstant`, then
+ // create a constant `XlaArgument`.
+ // TODO(b/180448774): Handle kResource and kTensorList.
+ XlaExpression::Kind ctx_kind_i = ctx->InputExpression(i).kind();
+ if (ctx_kind_i != XlaExpression::Kind::kXlaOp &&
+ ctx_kind_i != XlaExpression::Kind::kConstant)
+ return tensorflow::errors::InvalidArgument(
+ absl::StrCat("Input ", i, " to an MlirXlaOpKernel is invalid: ",
+ ctx->InputExpression(i).HumanString()));
+ XlaCompiler::Argument arg;
+ arg.kind = XlaCompiler::Argument::kParameter;
+ arg.type = ctx->input_type(i);
+ arg.shape = ctx->InputXlaShape(i).ValueOrDie();
+ arg.name = absl::StrCat("_arg", i);
+ xla_args.push_back(arg);
+ }
+ return Status::OK();
+}
+
+} // namespace
+
+Status MlirXlaOpKernel::ConstructXlaOp(XlaOpKernelContext* ctx) {
+ // Create input XlaArguments.
+ std::vector<XlaCompiler::Argument> xla_args;
+ TF_RETURN_IF_ERROR(ContextToXlaArgs(ctx, xla_args));
+
+ // Create input XlaOps.
+ llvm::SmallVector<xla::XlaOp, 4> xla_params(ctx->num_inputs());
+ for (int i = 0, end = ctx->num_inputs(); i < end; ++i) {
+ xla_params[i] = ctx->Input(i);
+ }
+
+ // Create outputs.
+ std::vector<DataType> result_dtypes(ctx->num_outputs());
+ for (int i = 0, end = result_dtypes.size(); i < end; ++i) {
+ result_dtypes[i] = ctx->expected_output_dtype(i);
+ }
+
+ // When there are no data-flow outputs from the node, the node is used as a
+ // control output by the graph to TensorflowDialect importer.
+ std::vector<std::string> control_rets;
+ if (result_dtypes.empty()) {
+ control_rets.push_back(def().name());
+ }
+
+ // Get the context's device.
+ auto device = dynamic_cast<Device*>(ctx->op_kernel_context()->device());
+ if (!device) {
+ return tensorflow::errors::InvalidArgument(
+ "Expected the XlaOpKernelContext argument's device to have type "
+ "Device.");
+ }
+
+ // Create a graph that wraps the kernel.
+ TF_ASSIGN_OR_RETURN(auto graph, CreateGraph(def(), xla_args, result_dtypes));
+
+ // Compile the graph to HLO.
+ GraphDebugInfo debug_info;
+ std::vector<xla::XlaOp> returns(1);
+ TF_RETURN_IF_ERROR(BuildHloFromGraph(
+ *graph, *ctx->builder(), xla_params, returns,
+ mlir::SpanToArrayRef<XlaCompiler::Argument>(xla_args), control_rets,
+ device->device_type(),
+ *ctx->function_library()->GetFunctionLibraryDefinition(), debug_info,
+ {}));
+
+ // Set context outputs.
+ for (int i = 0, end = returns.size(); i < end; ++i) {
+ ctx->SetOutput(i, returns[i]);
+ }
+
+ return Status::OK();
+}
+
+void MlirXlaOpKernel::Compile(XlaOpKernelContext* ctx) {
+ OP_REQUIRES_OK(ctx, ConstructXlaOp(ctx));
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/mlir_xla_op_kernel.h b/tensorflow/compiler/tf2xla/mlir_xla_op_kernel.h
new file mode 100644
index 0000000..278cc53
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/mlir_xla_op_kernel.h
@@ -0,0 +1,36 @@
+/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_TF2XLA_MLIR_XLA_OP_KERNEL_H_
+#define TENSORFLOW_COMPILER_TF2XLA_MLIR_XLA_OP_KERNEL_H_
+
+#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
+
+namespace tensorflow {
+
+// An XlaOpKernel that's implemented by lowering using MLIR TensorFlow to HLO
+// legalization.
+class MlirXlaOpKernel : public XlaOpKernel {
+ public:
+ explicit MlirXlaOpKernel(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
+
+ private:
+ void Compile(XlaOpKernelContext* ctx) override;
+ Status ConstructXlaOp(XlaOpKernelContext* ctx);
+};
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_COMPILER_TF2XLA_MLIR_XLA_OP_KERNEL_H_