Revert commit that introduced a performance regression
https://github.com/tensorflow/tensorflow/commit/cfec367771d9795e5b4f6d3cd6173fc5f6b57158
PiperOrigin-RevId: 361471567
Change-Id: I19aedc352ad71d69e5a082cad5053b5894668614
diff --git a/tensorflow/compiler/jit/xla_compilation_cache.cc b/tensorflow/compiler/jit/xla_compilation_cache.cc
index f2079c8..657d90d 100644
--- a/tensorflow/compiler/jit/xla_compilation_cache.cc
+++ b/tensorflow/compiler/jit/xla_compilation_cache.cc
@@ -213,7 +213,9 @@
execution_count < kMinExecutionsPerCompile * compile_count;
}
-xla::StatusOr<std::unique_ptr<Graph>> CreateGraph(
+// Creates a simple graph using the specified op as the only op apart from the
+// arg and retval nodes.
+static xla::StatusOr<std::unique_ptr<Graph>> CreateGraph(
const NodeDef& node_def, absl::Span<const XlaCompiler::Argument> args,
absl::Span<const DataType> result_types) {
// TODO(b/74182462): We implement this by creating a new dummy Graph including
diff --git a/tensorflow/compiler/jit/xla_compilation_cache.h b/tensorflow/compiler/jit/xla_compilation_cache.h
index 83f38b0..cd58cf3 100644
--- a/tensorflow/compiler/jit/xla_compilation_cache.h
+++ b/tensorflow/compiler/jit/xla_compilation_cache.h
@@ -196,12 +196,6 @@
TF_DISALLOW_COPY_AND_ASSIGN(XlaCompilationCache);
};
-// Creates a single-node graph using the specified node_def as the only op apart
-// from the arg and retval nodes.
-xla::StatusOr<std::unique_ptr<Graph>> CreateGraph(
- const NodeDef& node_def, absl::Span<const XlaCompiler::Argument> args,
- absl::Span<const DataType> result_types);
-
} // namespace tensorflow
#endif // TENSORFLOW_COMPILER_JIT_XLA_COMPILATION_CACHE_H_
diff --git a/tensorflow/compiler/tf2xla/BUILD b/tensorflow/compiler/tf2xla/BUILD
index 0500828..e877fce 100644
--- a/tensorflow/compiler/tf2xla/BUILD
+++ b/tensorflow/compiler/tf2xla/BUILD
@@ -1126,18 +1126,6 @@
)
cc_library(
- name = "mlir_xla_op_kernel",
- srcs = ["mlir_xla_op_kernel.cc"],
- hdrs = ["mlir_xla_op_kernel.h"],
- deps = [
- ":xla_compiler",
- "//tensorflow/compiler/jit:xla_compilation_cache",
- "//tensorflow/compiler/mlir:array_container_utils",
- "//tensorflow/compiler/mlir/tensorflow:compile_mlir_util_no_tf_dialect_passes",
- ],
-)
-
-cc_library(
name = "resource_util",
srcs = ["resource_util.cc"],
hdrs = ["resource_util.h"],
diff --git a/tensorflow/compiler/tf2xla/kernels/BUILD b/tensorflow/compiler/tf2xla/kernels/BUILD
index 5e55152..56eb5b3 100644
--- a/tensorflow/compiler/tf2xla/kernels/BUILD
+++ b/tensorflow/compiler/tf2xla/kernels/BUILD
@@ -150,7 +150,6 @@
"//tensorflow/compiler/jit:xla_activity_listener",
"//tensorflow/compiler/jit:xla_activity_proto_cc",
"//tensorflow/compiler/tf2xla:common",
- "//tensorflow/compiler/tf2xla:mlir_xla_op_kernel",
"//tensorflow/compiler/tf2xla:xla_compilation_device",
"//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/compiler/tf2xla:xla_context",
diff --git a/tensorflow/compiler/tf2xla/kernels/relu_op.cc b/tensorflow/compiler/tf2xla/kernels/relu_op.cc
index 36d4cca..3b53bac 100644
--- a/tensorflow/compiler/tf2xla/kernels/relu_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/relu_op.cc
@@ -17,7 +17,6 @@
#include "tensorflow/compiler/tf2xla/kernels/relu_op.h"
-#include "tensorflow/compiler/tf2xla/mlir_xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
@@ -36,7 +35,15 @@
namespace tensorflow {
namespace {
-REGISTER_XLA_OP(Name("Relu"), MlirXlaOpKernel);
+class ReluOp : public XlaOpKernel {
+ public:
+ explicit ReluOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
+ // Computes the max of the scalar input x and 0.
+ void Compile(XlaOpKernelContext* ctx) override {
+ ctx->SetOutput(0, xla::Relu(ctx->Input(0)));
+ }
+};
+REGISTER_XLA_OP(Name("Relu"), ReluOp);
class Relu6Op : public XlaOpKernel {
public:
diff --git a/tensorflow/compiler/tf2xla/mlir_xla_op_kernel.cc b/tensorflow/compiler/tf2xla/mlir_xla_op_kernel.cc
deleted file mode 100644
index d7947d0..0000000
--- a/tensorflow/compiler/tf2xla/mlir_xla_op_kernel.cc
+++ /dev/null
@@ -1,109 +0,0 @@
-/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "tensorflow/compiler/tf2xla/mlir_xla_op_kernel.h"
-
-#include "tensorflow/compiler/jit/xla_compilation_cache.h"
-#include "tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.h"
-#include "tensorflow/compiler/mlir/utils/array_container_utils.h"
-
-namespace tensorflow {
-
-namespace {
-
-Status ContextToXlaArgs(XlaOpKernelContext* ctx,
- std::vector<XlaCompiler::Argument>& xla_args) {
- int num_inputs = ctx->num_inputs();
- xla_args.reserve(num_inputs);
- for (int i = 0; i < num_inputs; ++i) {
- // TODO(b/180448676): If the input `XlaExpression` kind is `kConstant`, then
- // create a constant `XlaArgument`.
- // TODO(b/180448774): Handle kResource and kTensorList.
- XlaExpression::Kind ctx_kind_i = ctx->InputExpression(i).kind();
- if (ctx_kind_i != XlaExpression::Kind::kXlaOp &&
- ctx_kind_i != XlaExpression::Kind::kConstant)
- return tensorflow::errors::InvalidArgument(
- absl::StrCat("Input ", i, " to an MlirXlaOpKernel is invalid: ",
- ctx->InputExpression(i).HumanString()));
- XlaCompiler::Argument arg;
- arg.kind = XlaCompiler::Argument::kParameter;
- arg.type = ctx->input_type(i);
- arg.shape = ctx->InputXlaShape(i).ValueOrDie();
- arg.name = absl::StrCat("_arg", i);
- xla_args.push_back(arg);
- }
- return Status::OK();
-}
-
-} // namespace
-
-Status MlirXlaOpKernel::ConstructXlaOp(XlaOpKernelContext* ctx) {
- // Create input XlaArguments.
- std::vector<XlaCompiler::Argument> xla_args;
- TF_RETURN_IF_ERROR(ContextToXlaArgs(ctx, xla_args));
-
- // Create input XlaOps.
- llvm::SmallVector<xla::XlaOp, 4> xla_params(ctx->num_inputs());
- for (int i = 0, end = ctx->num_inputs(); i < end; ++i) {
- xla_params[i] = ctx->Input(i);
- }
-
- // Create outputs.
- std::vector<DataType> result_dtypes(ctx->num_outputs());
- for (int i = 0, end = result_dtypes.size(); i < end; ++i) {
- result_dtypes[i] = ctx->expected_output_dtype(i);
- }
-
- // When there are no data-flow outputs from the node, the node is used as a
- // control output by the graph to TensorflowDialect importer.
- std::vector<std::string> control_rets;
- if (result_dtypes.empty()) {
- control_rets.push_back(def().name());
- }
-
- // Get the context's device.
- auto device = dynamic_cast<Device*>(ctx->op_kernel_context()->device());
- if (!device) {
- return tensorflow::errors::InvalidArgument(
- "Expected the XlaOpKernelContext argument's device to have type "
- "Device.");
- }
-
- // Create a graph that wraps the kernel.
- TF_ASSIGN_OR_RETURN(auto graph, CreateGraph(def(), xla_args, result_dtypes));
-
- // Compile the graph to HLO.
- GraphDebugInfo debug_info;
- std::vector<xla::XlaOp> returns(1);
- TF_RETURN_IF_ERROR(BuildHloFromGraph(
- *graph, *ctx->builder(), xla_params, returns,
- mlir::SpanToArrayRef<XlaCompiler::Argument>(xla_args), control_rets,
- device->device_type(),
- *ctx->function_library()->GetFunctionLibraryDefinition(), debug_info,
- {}));
-
- // Set context outputs.
- for (int i = 0, end = returns.size(); i < end; ++i) {
- ctx->SetOutput(i, returns[i]);
- }
-
- return Status::OK();
-}
-
-void MlirXlaOpKernel::Compile(XlaOpKernelContext* ctx) {
- OP_REQUIRES_OK(ctx, ConstructXlaOp(ctx));
-}
-
-} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/mlir_xla_op_kernel.h b/tensorflow/compiler/tf2xla/mlir_xla_op_kernel.h
deleted file mode 100644
index 278cc53..0000000
--- a/tensorflow/compiler/tf2xla/mlir_xla_op_kernel.h
+++ /dev/null
@@ -1,36 +0,0 @@
-/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef TENSORFLOW_COMPILER_TF2XLA_MLIR_XLA_OP_KERNEL_H_
-#define TENSORFLOW_COMPILER_TF2XLA_MLIR_XLA_OP_KERNEL_H_
-
-#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
-
-namespace tensorflow {
-
-// An XlaOpKernel that's implemented by lowering using MLIR TensorFlow to HLO
-// legalization.
-class MlirXlaOpKernel : public XlaOpKernel {
- public:
- explicit MlirXlaOpKernel(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
-
- private:
- void Compile(XlaOpKernelContext* ctx) override;
- Status ConstructXlaOp(XlaOpKernelContext* ctx);
-};
-
-} // namespace tensorflow
-
-#endif // TENSORFLOW_COMPILER_TF2XLA_MLIR_XLA_OP_KERNEL_H_