Add `MlirXlaOpKernel`, which can implement `XlaOpKernels` using MLIR legalization.

An `XlaOpKernel` is used in the old bridge to lower a TF op to HLO. This CL moves `ReluOp` from its custom implementation to `MlirXlaOpKernel`.

`xla_compilation_cache.h` now exports `CreateGraph` since it's used by `MlirXlaOpKernel`.

PiperOrigin-RevId: 366409160
Change-Id: Ie18eb5fc21b0c73c7f527d31ce15b33c44218dfd
diff --git a/tensorflow/compiler/jit/xla_compilation_cache.cc b/tensorflow/compiler/jit/xla_compilation_cache.cc
index 112287b..d1bb09f 100644
--- a/tensorflow/compiler/jit/xla_compilation_cache.cc
+++ b/tensorflow/compiler/jit/xla_compilation_cache.cc
@@ -226,9 +226,7 @@
          execution_count < kMinExecutionsPerCompile * compile_count;
 }
 
-// Creates a simple graph using the specified op as the only op apart from the
-// arg and retval nodes.
-static xla::StatusOr<std::unique_ptr<Graph>> CreateGraph(
+xla::StatusOr<std::unique_ptr<Graph>> CreateGraph(
     const NodeDef& node_def, absl::Span<const XlaCompiler::Argument> args,
     absl::Span<const DataType> result_types) {
   // TODO(b/74182462): We implement this by creating a new dummy Graph including
diff --git a/tensorflow/compiler/jit/xla_compilation_cache.h b/tensorflow/compiler/jit/xla_compilation_cache.h
index c84bc6d..0601433 100644
--- a/tensorflow/compiler/jit/xla_compilation_cache.h
+++ b/tensorflow/compiler/jit/xla_compilation_cache.h
@@ -246,6 +246,12 @@
   TF_DISALLOW_COPY_AND_ASSIGN(XlaCompilationCache);
 };
 
+// Creates a single-node graph using the specified node_def as the only op apart
+// from the arg and retval nodes.
+xla::StatusOr<std::unique_ptr<Graph>> CreateGraph(
+    const NodeDef& node_def, absl::Span<const XlaCompiler::Argument> args,
+    absl::Span<const DataType> result_types);
+
 }  // namespace tensorflow
 
 #endif  // TENSORFLOW_COMPILER_JIT_XLA_COMPILATION_CACHE_H_
diff --git a/tensorflow/compiler/tf2xla/BUILD b/tensorflow/compiler/tf2xla/BUILD
index e00598d..edc89aa 100644
--- a/tensorflow/compiler/tf2xla/BUILD
+++ b/tensorflow/compiler/tf2xla/BUILD
@@ -1127,6 +1127,18 @@
 )
 
 cc_library(
+    name = "mlir_xla_op_kernel",
+    srcs = ["mlir_xla_op_kernel.cc"],
+    hdrs = ["mlir_xla_op_kernel.h"],
+    deps = [
+        ":xla_compiler",
+        "//tensorflow/compiler/jit:xla_compilation_cache",
+        "//tensorflow/compiler/mlir:array_container_utils",
+        "//tensorflow/compiler/mlir/tensorflow:compile_mlir_util_no_tf_dialect_passes",
+    ],
+)
+
+cc_library(
     name = "resource_util",
     srcs = ["resource_util.cc"],
     hdrs = ["resource_util.h"],
diff --git a/tensorflow/compiler/tf2xla/kernels/BUILD b/tensorflow/compiler/tf2xla/kernels/BUILD
index 518704e..5179404 100644
--- a/tensorflow/compiler/tf2xla/kernels/BUILD
+++ b/tensorflow/compiler/tf2xla/kernels/BUILD
@@ -151,6 +151,7 @@
         "//tensorflow/compiler/jit:xla_activity_listener",
         "//tensorflow/compiler/jit:xla_activity_proto_cc",
         "//tensorflow/compiler/tf2xla:common",
+        "//tensorflow/compiler/tf2xla:mlir_xla_op_kernel",
         "//tensorflow/compiler/tf2xla:xla_compilation_device",
         "//tensorflow/compiler/tf2xla:xla_compiler",
         "//tensorflow/compiler/tf2xla:xla_context",
diff --git a/tensorflow/compiler/tf2xla/kernels/relu_op.cc b/tensorflow/compiler/tf2xla/kernels/relu_op.cc
index 3b53bac..36d4cca 100644
--- a/tensorflow/compiler/tf2xla/kernels/relu_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/relu_op.cc
@@ -17,6 +17,7 @@
 
 #include "tensorflow/compiler/tf2xla/kernels/relu_op.h"
 
+#include "tensorflow/compiler/tf2xla/mlir_xla_op_kernel.h"
 #include "tensorflow/compiler/tf2xla/xla_helpers.h"
 #include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
 #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
@@ -35,15 +36,7 @@
 namespace tensorflow {
 namespace {
 
-class ReluOp : public XlaOpKernel {
- public:
-  explicit ReluOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
-  // Computes the max of the scalar input x and 0.
-  void Compile(XlaOpKernelContext* ctx) override {
-    ctx->SetOutput(0, xla::Relu(ctx->Input(0)));
-  }
-};
-REGISTER_XLA_OP(Name("Relu"), ReluOp);
+REGISTER_XLA_OP(Name("Relu"), MlirXlaOpKernel);
 
 class Relu6Op : public XlaOpKernel {
  public:
diff --git a/tensorflow/compiler/tf2xla/mlir_xla_op_kernel.cc b/tensorflow/compiler/tf2xla/mlir_xla_op_kernel.cc
new file mode 100644
index 0000000..d7947d0
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/mlir_xla_op_kernel.cc
@@ -0,0 +1,109 @@
+/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/tf2xla/mlir_xla_op_kernel.h"
+
+#include "tensorflow/compiler/jit/xla_compilation_cache.h"
+#include "tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.h"
+#include "tensorflow/compiler/mlir/utils/array_container_utils.h"
+
+namespace tensorflow {
+
+namespace {
+
+Status ContextToXlaArgs(XlaOpKernelContext* ctx,
+                        std::vector<XlaCompiler::Argument>& xla_args) {
+  int num_inputs = ctx->num_inputs();
+  xla_args.reserve(num_inputs);
+  for (int i = 0; i < num_inputs; ++i) {
+    // TODO(b/180448676): If the input `XlaExpression` kind is `kConstant`, then
+    // create a constant `XlaArgument`.
+    // TODO(b/180448774): Handle kResource and kTensorList.
+    XlaExpression::Kind ctx_kind_i = ctx->InputExpression(i).kind();
+    if (ctx_kind_i != XlaExpression::Kind::kXlaOp &&
+        ctx_kind_i != XlaExpression::Kind::kConstant)
+      return tensorflow::errors::InvalidArgument(
+          absl::StrCat("Input ", i, " to an MlirXlaOpKernel is invalid: ",
+                       ctx->InputExpression(i).HumanString()));
+    XlaCompiler::Argument arg;
+    arg.kind = XlaCompiler::Argument::kParameter;
+    arg.type = ctx->input_type(i);
+    arg.shape = ctx->InputXlaShape(i).ValueOrDie();
+    arg.name = absl::StrCat("_arg", i);
+    xla_args.push_back(arg);
+  }
+  return Status::OK();
+}
+
+}  // namespace
+
+Status MlirXlaOpKernel::ConstructXlaOp(XlaOpKernelContext* ctx) {
+  // Create input XlaArguments.
+  std::vector<XlaCompiler::Argument> xla_args;
+  TF_RETURN_IF_ERROR(ContextToXlaArgs(ctx, xla_args));
+
+  // Create input XlaOps.
+  llvm::SmallVector<xla::XlaOp, 4> xla_params(ctx->num_inputs());
+  for (int i = 0, end = ctx->num_inputs(); i < end; ++i) {
+    xla_params[i] = ctx->Input(i);
+  }
+
+  // Create outputs.
+  std::vector<DataType> result_dtypes(ctx->num_outputs());
+  for (int i = 0, end = result_dtypes.size(); i < end; ++i) {
+    result_dtypes[i] = ctx->expected_output_dtype(i);
+  }
+
+  // When there are no data-flow outputs from the node, the node is used as a
+  // control output by the Graph-to-TensorFlow-dialect importer.
+  std::vector<std::string> control_rets;
+  if (result_dtypes.empty()) {
+    control_rets.push_back(def().name());
+  }
+
+  // Get the context's device.
+  auto device = dynamic_cast<Device*>(ctx->op_kernel_context()->device());
+  if (!device) {
+    return tensorflow::errors::InvalidArgument(
+        "Expected the XlaOpKernelContext argument's device to have type "
+        "Device.");
+  }
+
+  // Create a graph that wraps the kernel.
+  TF_ASSIGN_OR_RETURN(auto graph, CreateGraph(def(), xla_args, result_dtypes));
+
+  // Compile the graph to HLO.
+  GraphDebugInfo debug_info;
+  std::vector<xla::XlaOp> returns(1);
+  TF_RETURN_IF_ERROR(BuildHloFromGraph(
+      *graph, *ctx->builder(), xla_params, returns,
+      mlir::SpanToArrayRef<XlaCompiler::Argument>(xla_args), control_rets,
+      device->device_type(),
+      *ctx->function_library()->GetFunctionLibraryDefinition(), debug_info,
+      {}));
+
+  // Set context outputs.
+  for (int i = 0, end = returns.size(); i < end; ++i) {
+    ctx->SetOutput(i, returns[i]);
+  }
+
+  return Status::OK();
+}
+
+void MlirXlaOpKernel::Compile(XlaOpKernelContext* ctx) {
+  OP_REQUIRES_OK(ctx, ConstructXlaOp(ctx));
+}
+
+}  // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/mlir_xla_op_kernel.h b/tensorflow/compiler/tf2xla/mlir_xla_op_kernel.h
new file mode 100644
index 0000000..278cc53
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/mlir_xla_op_kernel.h
@@ -0,0 +1,36 @@
+/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_TF2XLA_MLIR_XLA_OP_KERNEL_H_
+#define TENSORFLOW_COMPILER_TF2XLA_MLIR_XLA_OP_KERNEL_H_
+
+#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
+
+namespace tensorflow {
+
+// An XlaOpKernel that's implemented by lowering using MLIR TensorFlow to HLO
+// legalization.
+class MlirXlaOpKernel : public XlaOpKernel {
+ public:
+  explicit MlirXlaOpKernel(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
+
+ private:
+  void Compile(XlaOpKernelContext* ctx) override;
+  Status ConstructXlaOp(XlaOpKernelContext* ctx);
+};
+
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_COMPILER_TF2XLA_MLIR_XLA_OP_KERNEL_H_