Register a MlirPassthroughOp with TensorFlow

This operation wraps an arbitrary MLIR computation expressed as a module with a
main() function. This operation does not have an associated kernel and is not
intended to be executed in a regular TensorFlow session. Instead it is intended
to be used for testing or for special case where a user intends to pass custom
MLIR computation through a TensorFlow graph with the intent of having custom
tooling processing it downstream (when targeting a different environment, like
TensorFlow lite for example).

PiperOrigin-RevId: 268071038
diff --git a/tensorflow/compiler/mlir/tensorflow/BUILD b/tensorflow/compiler/mlir/tensorflow/BUILD
index 383d4c4..df24b3e 100644
--- a/tensorflow/compiler/mlir/tensorflow/BUILD
+++ b/tensorflow/compiler/mlir/tensorflow/BUILD
@@ -1,5 +1,5 @@
 load("@local_config_mlir//:tblgen.bzl", "gentbl")
-load("//tensorflow:tensorflow.bzl", "tf_cc_test", "tf_native_cc_binary")
+load("//tensorflow:tensorflow.bzl", "tf_cc_test", "tf_gen_op_wrapper_py", "tf_native_cc_binary")
 
 package(
     default_visibility = [":friends"],
@@ -138,6 +138,7 @@
     includes = ["include"],
     deps = [
         ":error_util",
+        ":mlir_passthrough_op",
         ":tensorflow_canonicalize_inc_gen",
         ":tensorflow_device_ops_inc_gen",
         ":tensorflow_executor_inc_gen",
@@ -708,3 +709,18 @@
         "//tensorflow/stream_executor/lib",
     ],
 )
+
+cc_library(
+    name = "mlir_passthrough_op",
+    srcs = ["ops/mlir_passthrough_op.cc"],
+    deps = [
+        "//tensorflow/core:framework",
+    ],
+    alwayslink = 1,
+)
+
+tf_gen_op_wrapper_py(
+    name = "gen_mlir_passthrough_op_py",
+    out = "gen_mlir_passthrough_op.py",
+    deps = [":mlir_passthrough_op"],
+)
diff --git a/tensorflow/compiler/mlir/tensorflow/ops/mlir_passthrough_op.cc b/tensorflow/compiler/mlir/tensorflow/ops/mlir_passthrough_op.cc
new file mode 100644
index 0000000..de9d910
--- /dev/null
+++ b/tensorflow/compiler/mlir/tensorflow/ops/mlir_passthrough_op.cc
@@ -0,0 +1,55 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/framework/op.h"
+
+namespace tensorflow {
+
+REGISTER_OP("MlirPassthroughOp")
+    .Attr("mlir_module: string")
+    .Attr("Tinputs : list(type) >= 0")
+    .Input("inputs: Tinputs")
+    .Attr("Toutputs : list(type) >= 0")
+    .Output("outputs: Toutputs")
+    .Doc(R"doc(
+This operation wraps an arbitrary MLIR computation expressed as a module with a
+main() function. This operation does not have an associated kernel and is not
+intended to be executed in a regular TensorFlow session. Instead it is intended
+to be used for testing or for special case where a user intends to pass custom
+MLIR computation through a TensorFlow graph with the intent of having custom
+tooling processing it downstream (when targeting a different environment, like
+TensorFlow lite for example).
+Example usage:
+
+```
+import tensorflow as tf
+from tensorflow.compiler.mlir.tensorflow.gen_mlir_passthrough_op import mlir_passthrough_op
+
+mlir_module = '''
+func @main(%arg0 : tensor<10xf32>, %arg1 : tensor<10xf32>) -> tensor<10x10xf32> {
+   %add = "magic.op"(%arg0, %arg1) : (tensor<10xf32>, tensor<10xf32>) -> tensor<10x10xf32>
+   return %ret : tensor<10x10xf32>
+}
+'''
+
+@tf.function
+def foo(x, y):
+  return = mlir_passthrough_op([x, y], mlir_module, Toutputs=[tf.float32])
+
+graph_def = foo.get_concrete_function(tf.TensorSpec([10], tf.float32), tf.TensorSpec([10], tf.float32)).graph.as_graph_def()
+```
+)doc");
+
+}  // namespace tensorflow
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/mlir_passthrough_op.pbtxt b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/mlir_passthrough_op.pbtxt
new file mode 100644
index 0000000..1df903d
--- /dev/null
+++ b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/mlir_passthrough_op.pbtxt
@@ -0,0 +1,101 @@
+# RUN: tf-mlir-translate -graphdef-to-mlir %s | FileCheck %s
+
+# CHECK:"tf.MlirPassthroughOp"
+# CHECK: mlir_module = "\0Afunc @main(%arg0 : tensor<10xf32>, %arg1 : tensor<10xf32>) -> tensor<10x10xf32> {\0A   %add = \22tf.Add\22(%arg0, %arg1) : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xf32>\0A   %ret = \22magic.op\22(%add, %add) : (tensor<10xf32>, tensor<10xf32>) -> tensor<10x10xf32>\0A   return %ret : tensor<10x10xf32>\0A}\0A", name = "MlirPassthroughOp"} : (tensor<10xf32>, tensor<10xf32>) -> tensor<*xf32>
+
+node {
+  name: "x"
+  op: "Placeholder"
+  attr {
+    key: "_user_specified_name"
+    value {
+      s: "x"
+    }
+  }
+  attr {
+    key: "dtype"
+    value {
+      type: DT_FLOAT
+    }
+  }
+  attr {
+    key: "shape"
+    value {
+      shape {
+        dim {
+          size: 10
+        }
+      }
+    }
+  }
+}
+node {
+  name: "y"
+  op: "Placeholder"
+  attr {
+    key: "_user_specified_name"
+    value {
+      s: "y"
+    }
+  }
+  attr {
+    key: "dtype"
+    value {
+      type: DT_FLOAT
+    }
+  }
+  attr {
+    key: "shape"
+    value {
+      shape {
+        dim {
+          size: 10
+        }
+      }
+    }
+  }
+}
+node {
+  name: "MlirPassthroughOp"
+  op: "MlirPassthroughOp"
+  input: "x"
+  input: "y"
+  attr {
+    key: "Tinputs"
+    value {
+      list {
+        type: DT_FLOAT
+        type: DT_FLOAT
+      }
+    }
+  }
+  attr {
+    key: "Toutputs"
+    value {
+      list {
+        type: DT_FLOAT
+      }
+    }
+  }
+  attr {
+    key: "mlir_module"
+    value {
+      s: "\nfunc @main(%arg0 : tensor<10xf32>, %arg1 : tensor<10xf32>) -> tensor<10x10xf32> {\n   %add = \"tf.Add\"(%arg0, %arg1) : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xf32>\n   %ret = \"magic.op\"(%add, %add) : (tensor<10xf32>, tensor<10xf32>) -> tensor<10x10xf32>\n   return %ret : tensor<10x10xf32>\n}\n"
+    }
+  }
+}
+node {
+  name: "Identity"
+  op: "Identity"
+  input: "MlirPassthroughOp"
+  attr {
+    key: "T"
+    value {
+      type: DT_FLOAT
+    }
+  }
+}
+versions {
+  producer: 148
+}
+