blob: de9d910108ea442ba402b6b5fd200c39005309ce [file] [log] [blame]
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/op.h"
namespace tensorflow {

// Registers "MlirPassthroughOp": an op with no kernel that carries an MLIR
// module (the `mlir_module` string attr) through a TensorFlow graph so that
// downstream tooling can consume it.
//
//   mlir_module: string attr holding the MLIR module; the documentation
//                example passes MLIR textual assembly with a main() function.
//   Tinputs:     types of `inputs` (may be empty).
//   Toutputs:    types of `outputs` (may be empty).
//
// The output shapes depend entirely on the wrapped MLIR computation and cannot
// be derived from the op signature, so every output is declared as having an
// unknown shape.
//
// NOTE(review): OpDefBuilder::Doc() appears to treat lines that begin with
// `identifier:` as per-argument docs (and rejects unknown names), so avoid
// starting description lines that way when editing the text below.
REGISTER_OP("MlirPassthroughOp")
    .Attr("mlir_module: string")
    .Attr("Tinputs : list(type) >= 0")
    .Input("inputs: Tinputs")
    .Attr("Toutputs : list(type) >= 0")
    .Output("outputs: Toutputs")
    .SetShapeFn(shape_inference::UnknownShape)
    .Doc(R"doc(
Wraps an arbitrary MLIR computation expressed as a module with a main() function.

This operation does not have an associated kernel and is not intended to be
executed in a regular TensorFlow session. Instead, it is intended to be used for
testing, or for the special case where a user wants to pass a custom MLIR
computation through a TensorFlow graph so that custom tooling can process it
downstream (for example, when targeting a different environment such as
TensorFlow Lite).

In the example below, the two tensors passed as `inputs` correspond to the two
arguments of main(), and `Toutputs` lists the element type of its single
result.

Example usage:

```
import tensorflow as tf
from tensorflow.compiler.mlir.tensorflow.gen_mlir_passthrough_op import mlir_passthrough_op

mlir_module = '''
func @main(%arg0 : tensor<10xf32>, %arg1 : tensor<10xf32>) -> tensor<10x10xf32> {
  %ret = "magic.op"(%arg0, %arg1) : (tensor<10xf32>, tensor<10xf32>) -> tensor<10x10xf32>
  return %ret : tensor<10x10xf32>
}
'''

@tf.function
def foo(x, y):
  return mlir_passthrough_op([x, y], mlir_module, Toutputs=[tf.float32])

graph_def = foo.get_concrete_function(tf.TensorSpec([10], tf.float32), tf.TensorSpec([10], tf.float32)).graph.as_graph_def()
```
)doc");

}  // namespace tensorflow