Revert a commit that introduced a performance regression

https://github.com/tensorflow/tensorflow/commit/cfec367771d9795e5b4f6d3cd6173fc5f6b57158

PiperOrigin-RevId: 361471567
Change-Id: I19aedc352ad71d69e5a082cad5053b5894668614
diff --git a/tensorflow/compiler/jit/xla_compilation_cache.cc b/tensorflow/compiler/jit/xla_compilation_cache.cc
index f2079c8..657d90d 100644
--- a/tensorflow/compiler/jit/xla_compilation_cache.cc
+++ b/tensorflow/compiler/jit/xla_compilation_cache.cc
@@ -213,7 +213,9 @@
          execution_count < kMinExecutionsPerCompile * compile_count;
 }
 
-xla::StatusOr<std::unique_ptr<Graph>> CreateGraph(
+// Creates a simple graph using the specified op as the only op apart from the
+// arg and retval nodes.
+static xla::StatusOr<std::unique_ptr<Graph>> CreateGraph(
     const NodeDef& node_def, absl::Span<const XlaCompiler::Argument> args,
     absl::Span<const DataType> result_types) {
   // TODO(b/74182462): We implement this by creating a new dummy Graph including
diff --git a/tensorflow/compiler/jit/xla_compilation_cache.h b/tensorflow/compiler/jit/xla_compilation_cache.h
index 83f38b0..cd58cf3 100644
--- a/tensorflow/compiler/jit/xla_compilation_cache.h
+++ b/tensorflow/compiler/jit/xla_compilation_cache.h
@@ -196,12 +196,6 @@
   TF_DISALLOW_COPY_AND_ASSIGN(XlaCompilationCache);
 };
 
-// Creates a single-node graph using the specified node_def as the only op apart
-// from the arg and retval nodes.
-xla::StatusOr<std::unique_ptr<Graph>> CreateGraph(
-    const NodeDef& node_def, absl::Span<const XlaCompiler::Argument> args,
-    absl::Span<const DataType> result_types);
-
 }  // namespace tensorflow
 
 #endif  // TENSORFLOW_COMPILER_JIT_XLA_COMPILATION_CACHE_H_
diff --git a/tensorflow/compiler/tf2xla/BUILD b/tensorflow/compiler/tf2xla/BUILD
index 0500828..e877fce 100644
--- a/tensorflow/compiler/tf2xla/BUILD
+++ b/tensorflow/compiler/tf2xla/BUILD
@@ -1126,18 +1126,6 @@
 )
 
 cc_library(
-    name = "mlir_xla_op_kernel",
-    srcs = ["mlir_xla_op_kernel.cc"],
-    hdrs = ["mlir_xla_op_kernel.h"],
-    deps = [
-        ":xla_compiler",
-        "//tensorflow/compiler/jit:xla_compilation_cache",
-        "//tensorflow/compiler/mlir:array_container_utils",
-        "//tensorflow/compiler/mlir/tensorflow:compile_mlir_util_no_tf_dialect_passes",
-    ],
-)
-
-cc_library(
     name = "resource_util",
     srcs = ["resource_util.cc"],
     hdrs = ["resource_util.h"],
diff --git a/tensorflow/compiler/tf2xla/kernels/BUILD b/tensorflow/compiler/tf2xla/kernels/BUILD
index 5e55152..56eb5b3 100644
--- a/tensorflow/compiler/tf2xla/kernels/BUILD
+++ b/tensorflow/compiler/tf2xla/kernels/BUILD
@@ -150,7 +150,6 @@
         "//tensorflow/compiler/jit:xla_activity_listener",
         "//tensorflow/compiler/jit:xla_activity_proto_cc",
         "//tensorflow/compiler/tf2xla:common",
-        "//tensorflow/compiler/tf2xla:mlir_xla_op_kernel",
         "//tensorflow/compiler/tf2xla:xla_compilation_device",
         "//tensorflow/compiler/tf2xla:xla_compiler",
         "//tensorflow/compiler/tf2xla:xla_context",
diff --git a/tensorflow/compiler/tf2xla/kernels/relu_op.cc b/tensorflow/compiler/tf2xla/kernels/relu_op.cc
index 36d4cca..3b53bac 100644
--- a/tensorflow/compiler/tf2xla/kernels/relu_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/relu_op.cc
@@ -17,7 +17,6 @@
 
 #include "tensorflow/compiler/tf2xla/kernels/relu_op.h"
 
-#include "tensorflow/compiler/tf2xla/mlir_xla_op_kernel.h"
 #include "tensorflow/compiler/tf2xla/xla_helpers.h"
 #include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
 #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
@@ -36,7 +35,15 @@
 namespace tensorflow {
 namespace {
 
-REGISTER_XLA_OP(Name("Relu"), MlirXlaOpKernel);
+class ReluOp : public XlaOpKernel {
+ public:
+  explicit ReluOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
+  // Computes the max of the scalar input x and 0.
+  void Compile(XlaOpKernelContext* ctx) override {
+    ctx->SetOutput(0, xla::Relu(ctx->Input(0)));
+  }
+};
+REGISTER_XLA_OP(Name("Relu"), ReluOp);
 
 class Relu6Op : public XlaOpKernel {
  public:
diff --git a/tensorflow/compiler/tf2xla/mlir_xla_op_kernel.cc b/tensorflow/compiler/tf2xla/mlir_xla_op_kernel.cc
deleted file mode 100644
index d7947d0..0000000
--- a/tensorflow/compiler/tf2xla/mlir_xla_op_kernel.cc
+++ /dev/null
@@ -1,109 +0,0 @@
-/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "tensorflow/compiler/tf2xla/mlir_xla_op_kernel.h"
-
-#include "tensorflow/compiler/jit/xla_compilation_cache.h"
-#include "tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.h"
-#include "tensorflow/compiler/mlir/utils/array_container_utils.h"
-
-namespace tensorflow {
-
-namespace {
-
-Status ContextToXlaArgs(XlaOpKernelContext* ctx,
-                        std::vector<XlaCompiler::Argument>& xla_args) {
-  int num_inputs = ctx->num_inputs();
-  xla_args.reserve(num_inputs);
-  for (int i = 0; i < num_inputs; ++i) {
-    // TODO(b/180448676): If the input `XlaExpression` kind is `kConstant`, then
-    // create a constant `XlaArgument`.
-    // TODO(b/180448774): Handle kResource and kTensorList.
-    XlaExpression::Kind ctx_kind_i = ctx->InputExpression(i).kind();
-    if (ctx_kind_i != XlaExpression::Kind::kXlaOp &&
-        ctx_kind_i != XlaExpression::Kind::kConstant)
-      return tensorflow::errors::InvalidArgument(
-          absl::StrCat("Input ", i, " to an MlirXlaOpKernel is invalid: ",
-                       ctx->InputExpression(i).HumanString()));
-    XlaCompiler::Argument arg;
-    arg.kind = XlaCompiler::Argument::kParameter;
-    arg.type = ctx->input_type(i);
-    arg.shape = ctx->InputXlaShape(i).ValueOrDie();
-    arg.name = absl::StrCat("_arg", i);
-    xla_args.push_back(arg);
-  }
-  return Status::OK();
-}
-
-}  // namespace
-
-Status MlirXlaOpKernel::ConstructXlaOp(XlaOpKernelContext* ctx) {
-  // Create input XlaArguments.
-  std::vector<XlaCompiler::Argument> xla_args;
-  TF_RETURN_IF_ERROR(ContextToXlaArgs(ctx, xla_args));
-
-  // Create input XlaOps.
-  llvm::SmallVector<xla::XlaOp, 4> xla_params(ctx->num_inputs());
-  for (int i = 0, end = ctx->num_inputs(); i < end; ++i) {
-    xla_params[i] = ctx->Input(i);
-  }
-
-  // Create outputs.
-  std::vector<DataType> result_dtypes(ctx->num_outputs());
-  for (int i = 0, end = result_dtypes.size(); i < end; ++i) {
-    result_dtypes[i] = ctx->expected_output_dtype(i);
-  }
-
-  // When there are no data-flow outputs from the node, the node is used as a
-  // control output by the graph to TensorflowDialect importer.
-  std::vector<std::string> control_rets;
-  if (result_dtypes.empty()) {
-    control_rets.push_back(def().name());
-  }
-
-  // Get the context's device.
-  auto device = dynamic_cast<Device*>(ctx->op_kernel_context()->device());
-  if (!device) {
-    return tensorflow::errors::InvalidArgument(
-        "Expected the XlaOpKernelContext argument's device to have type "
-        "Device.");
-  }
-
-  // Create a graph that wraps the kernel.
-  TF_ASSIGN_OR_RETURN(auto graph, CreateGraph(def(), xla_args, result_dtypes));
-
-  // Compile the graph to HLO.
-  GraphDebugInfo debug_info;
-  std::vector<xla::XlaOp> returns(1);
-  TF_RETURN_IF_ERROR(BuildHloFromGraph(
-      *graph, *ctx->builder(), xla_params, returns,
-      mlir::SpanToArrayRef<XlaCompiler::Argument>(xla_args), control_rets,
-      device->device_type(),
-      *ctx->function_library()->GetFunctionLibraryDefinition(), debug_info,
-      {}));
-
-  // Set context outputs.
-  for (int i = 0, end = returns.size(); i < end; ++i) {
-    ctx->SetOutput(i, returns[i]);
-  }
-
-  return Status::OK();
-}
-
-void MlirXlaOpKernel::Compile(XlaOpKernelContext* ctx) {
-  OP_REQUIRES_OK(ctx, ConstructXlaOp(ctx));
-}
-
-}  // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/mlir_xla_op_kernel.h b/tensorflow/compiler/tf2xla/mlir_xla_op_kernel.h
deleted file mode 100644
index 278cc53..0000000
--- a/tensorflow/compiler/tf2xla/mlir_xla_op_kernel.h
+++ /dev/null
@@ -1,36 +0,0 @@
-/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef TENSORFLOW_COMPILER_TF2XLA_MLIR_XLA_OP_KERNEL_H_
-#define TENSORFLOW_COMPILER_TF2XLA_MLIR_XLA_OP_KERNEL_H_
-
-#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
-
-namespace tensorflow {
-
-// An XlaOpKernel that's implemented by lowering using MLIR TensorFlow to HLO
-// legalization.
-class MlirXlaOpKernel : public XlaOpKernel {
- public:
-  explicit MlirXlaOpKernel(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
-
- private:
-  void Compile(XlaOpKernelContext* ctx) override;
-  Status ConstructXlaOp(XlaOpKernelContext* ctx);
-};
-
-}  // namespace tensorflow
-
-#endif  // TENSORFLOW_COMPILER_TF2XLA_MLIR_XLA_OP_KERNEL_H_