Register a MlirPassthroughOp with TensorFlow
This operation wraps an arbitrary MLIR computation expressed as a module with a
main() function. It does not have an associated kernel and is not intended to be
executed in a regular TensorFlow session. Instead, it is intended for testing or
for special cases where a user wants to pass a custom MLIR computation through a
TensorFlow graph so that custom tooling can process it downstream (for example
when targeting a different environment, like TensorFlow Lite).
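For downstream tooling, the wrapped module travels as the `mlir_module` string
attribute on the node, so it can be pulled straight out of a serialized
GraphDef. A minimal sketch of such extraction (the `extract_mlir_modules`
helper is illustrative and not part of this change; `graph_def` is assumed to
be a GraphDef such as the one produced in the op's docstring example):

```python
def extract_mlir_modules(graph_def):
  """Returns the MLIR module string carried by each MlirPassthroughOp node."""
  modules = {}
  for node in graph_def.node:
    if node.op == "MlirPassthroughOp":
      # The module is stored verbatim in the `mlir_module` string attr.
      modules[node.name] = node.attr["mlir_module"].s.decode("utf-8")
  return modules
```

Custom tooling can then re-parse these strings with its own MLIR dialects
before lowering to the target environment.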
PiperOrigin-RevId: 268071038
diff --git a/tensorflow/compiler/mlir/tensorflow/BUILD b/tensorflow/compiler/mlir/tensorflow/BUILD
index 383d4c4..df24b3e 100644
--- a/tensorflow/compiler/mlir/tensorflow/BUILD
+++ b/tensorflow/compiler/mlir/tensorflow/BUILD
@@ -1,5 +1,5 @@
load("@local_config_mlir//:tblgen.bzl", "gentbl")
-load("//tensorflow:tensorflow.bzl", "tf_cc_test", "tf_native_cc_binary")
+load("//tensorflow:tensorflow.bzl", "tf_cc_test", "tf_gen_op_wrapper_py", "tf_native_cc_binary")
package(
default_visibility = [":friends"],
@@ -138,6 +138,7 @@
includes = ["include"],
deps = [
":error_util",
+ ":mlir_passthrough_op",
":tensorflow_canonicalize_inc_gen",
":tensorflow_device_ops_inc_gen",
":tensorflow_executor_inc_gen",
@@ -708,3 +709,18 @@
"//tensorflow/stream_executor/lib",
],
)
+
+cc_library(
+ name = "mlir_passthrough_op",
+ srcs = ["ops/mlir_passthrough_op.cc"],
+ deps = [
+ "//tensorflow/core:framework",
+ ],
+ alwayslink = 1,
+)
+
+tf_gen_op_wrapper_py(
+ name = "gen_mlir_passthrough_op_py",
+ out = "gen_mlir_passthrough_op.py",
+ deps = [":mlir_passthrough_op"],
+)
diff --git a/tensorflow/compiler/mlir/tensorflow/ops/mlir_passthrough_op.cc b/tensorflow/compiler/mlir/tensorflow/ops/mlir_passthrough_op.cc
new file mode 100644
index 0000000..de9d910
--- /dev/null
+++ b/tensorflow/compiler/mlir/tensorflow/ops/mlir_passthrough_op.cc
@@ -0,0 +1,55 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/framework/op.h"
+
+namespace tensorflow {
+
+REGISTER_OP("MlirPassthroughOp")
+ .Attr("mlir_module: string")
+ .Attr("Tinputs : list(type) >= 0")
+ .Input("inputs: Tinputs")
+ .Attr("Toutputs : list(type) >= 0")
+ .Output("outputs: Toutputs")
+ .Doc(R"doc(
+This operation wraps an arbitrary MLIR computation expressed as a module with a
+main() function. It does not have an associated kernel and is not intended to be
+executed in a regular TensorFlow session. Instead, it is intended for testing or
+for special cases where a user wants to pass a custom MLIR computation through a
+TensorFlow graph so that custom tooling can process it downstream (for example
+when targeting a different environment, like TensorFlow Lite).
+
+Example usage:
+
+```
+import tensorflow as tf
+from tensorflow.compiler.mlir.tensorflow.gen_mlir_passthrough_op import mlir_passthrough_op
+
+mlir_module = '''
+func @main(%arg0 : tensor<10xf32>, %arg1 : tensor<10xf32>) -> tensor<10x10xf32> {
+ %ret = "magic.op"(%arg0, %arg1) : (tensor<10xf32>, tensor<10xf32>) -> tensor<10x10xf32>
+ return %ret : tensor<10x10xf32>
+}
+'''
+
+@tf.function
+def foo(x, y):
+  return mlir_passthrough_op([x, y], mlir_module, Toutputs=[tf.float32])
+
+graph_def = foo.get_concrete_function(tf.TensorSpec([10], tf.float32), tf.TensorSpec([10], tf.float32)).graph.as_graph_def()
+```
+)doc");
+
+} // namespace tensorflow
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/mlir_passthrough_op.pbtxt b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/mlir_passthrough_op.pbtxt
new file mode 100644
index 0000000..1df903d
--- /dev/null
+++ b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/mlir_passthrough_op.pbtxt
@@ -0,0 +1,101 @@
+# RUN: tf-mlir-translate -graphdef-to-mlir %s | FileCheck %s
+
+# CHECK:"tf.MlirPassthroughOp"
+# CHECK: mlir_module = "\0Afunc @main(%arg0 : tensor<10xf32>, %arg1 : tensor<10xf32>) -> tensor<10x10xf32> {\0A %add = \22tf.Add\22(%arg0, %arg1) : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xf32>\0A %ret = \22magic.op\22(%add, %add) : (tensor<10xf32>, tensor<10xf32>) -> tensor<10x10xf32>\0A return %ret : tensor<10x10xf32>\0A}\0A", name = "MlirPassthroughOp"} : (tensor<10xf32>, tensor<10xf32>) -> tensor<*xf32>
+
+node {
+ name: "x"
+ op: "Placeholder"
+ attr {
+ key: "_user_specified_name"
+ value {
+ s: "x"
+ }
+ }
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "shape"
+ value {
+ shape {
+ dim {
+ size: 10
+ }
+ }
+ }
+ }
+}
+node {
+ name: "y"
+ op: "Placeholder"
+ attr {
+ key: "_user_specified_name"
+ value {
+ s: "y"
+ }
+ }
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "shape"
+ value {
+ shape {
+ dim {
+ size: 10
+ }
+ }
+ }
+ }
+}
+node {
+ name: "MlirPassthroughOp"
+ op: "MlirPassthroughOp"
+ input: "x"
+ input: "y"
+ attr {
+ key: "Tinputs"
+ value {
+ list {
+ type: DT_FLOAT
+ type: DT_FLOAT
+ }
+ }
+ }
+ attr {
+ key: "Toutputs"
+ value {
+ list {
+ type: DT_FLOAT
+ }
+ }
+ }
+ attr {
+ key: "mlir_module"
+ value {
+ s: "\nfunc @main(%arg0 : tensor<10xf32>, %arg1 : tensor<10xf32>) -> tensor<10x10xf32> {\n %add = \"tf.Add\"(%arg0, %arg1) : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xf32>\n %ret = \"magic.op\"(%add, %add) : (tensor<10xf32>, tensor<10xf32>) -> tensor<10x10xf32>\n return %ret : tensor<10x10xf32>\n}\n"
+ }
+ }
+}
+node {
+ name: "Identity"
+ op: "Identity"
+ input: "MlirPassthroughOp"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+}
+versions {
+ producer: 148
+}
+