blob: 5c64a65ecbd8b24c86d4cd50cdc251aea9a2dcb4 [file] [log] [blame]
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_COMPILE_MLIR_UTIL_H_
#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_COMPILE_MLIR_UTIL_H_
#include <memory>
#include <vector>

#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/StringRef.h"
#include "mlir/IR/Module.h" // from @llvm-project
#include "mlir/Pass/Pass.h" // from @llvm-project
#include "tensorflow/compiler/tf2xla/xla_argument.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/protobuf/graph_debug_info.pb.h"
#include "tensorflow/stream_executor/lib/statusor.h"
namespace tensorflow {
// Lowers an MLIR module to XLA HLO inside an XlaComputation. The input module
// should only contain operations in the tf dialect; if it contains operations
// in another dialect (for example, the tf_executor dialect), an error is
// returned. The exception to this is tf_executor dialect ops that are
// optimized away through canonicalization.
//
// Operations in the tf dialect are lowered to XLA HLO through the following
// steps:
//   . Legalizes control flow operations.
//   . Decomposes compound resource operations so that the only remaining
//     operations on resource variables are resource reads/writes.
//   . Replaces resource reads/writes with function inputs/outputs and
//     eliminates the use of resource variables.
//   . Legalizes the operations to XLA HLO operations.
//   . Canonicalizes the XLA HLO operations.
//
// module_op: the module to lower. NOTE(review): the lowering steps above
//   presumably rewrite this module in place — confirm before reusing it.
// device_type: XLA JIT device to use for compilation such as "XLA_CPU_JIT",
//   "XLA_GPU_JIT" or "XLA_TPU_JIT".
// xla_computation: output; receives the lowered XLA HLO computation.
// use_tuple_args: when this is true, always create a tuple argument for the
//   entry computation.
// return_tuple: when this is true, always create a tuple result for the
//   entry computation.
// shape_representation_fn: when this is set, this shape representation
//   function will be used to determine argument and result shapes. Otherwise
//   the original shape will be used as is.
// custom_legalization_passes: passes to run before the default TF legalization
//   passes for backend-specific ops.
//
// Returns an error Status if the module cannot be lowered, e.g. because it
// still contains tf_executor dialect operations.
Status ConvertMLIRToXlaComputation(
    mlir::ModuleOp module_op, llvm::StringRef device_type,
    xla::XlaComputation* xla_computation, bool use_tuple_args,
    bool return_tuple,
    XlaHelpers::ShapeRepresentationFn shape_representation_fn = nullptr,
    std::vector<std::unique_ptr<mlir::Pass>> custom_legalization_passes = {});
// Compiles a serialized MLIR module into XLA HLO, generates all accompanying
// metadata and stores them in `compilation_result`.
//
// mlir_module_string: the serialized MLIR module to compile.
// arg_shapes: shapes of the module's entry arguments (presumably one entry per
//   argument, in order — confirm against the implementation).
// device_type, use_tuple_args, shape_representation_fn,
// custom_legalization_passes: presumably carry the same meaning as the
//   identically named parameters of ConvertMLIRToXlaComputation — confirm.
// compilation_result: output; receives the compiled HLO and its metadata.
//
// Returns an error Status if parsing or compilation fails.
Status CompileSerializedMlirToXlaHlo(
    llvm::StringRef mlir_module_string, llvm::ArrayRef<TensorShape> arg_shapes,
    llvm::StringRef device_type, bool use_tuple_args,
    XlaHelpers::ShapeRepresentationFn shape_representation_fn,
    XlaCompilationResult* compilation_result,
    std::vector<std::unique_ptr<mlir::Pass>> custom_legalization_passes = {});
// Same as CompileSerializedMlirToXlaHlo, but takes its input as a TensorFlow
// Graph instead of a serialized MLIR module.
// TODO(lyandy): Allow populating of targets/control outputs.
//
// graph: the TensorFlow graph to compile.
// args: descriptions of the computation's arguments.
// flib_def: function library definitions; presumably used to resolve
//   functions referenced by `graph` — confirm.
// debug_info: debug information for `graph`; presumably used to attach source
//   locations during import — confirm.
// device_type, use_tuple_args, shape_representation_fn,
// custom_legalization_passes: presumably carry the same meaning as the
//   identically named parameters of ConvertMLIRToXlaComputation — confirm.
// compilation_result: output; receives the compiled HLO and its metadata.
Status CompileGraphToXlaHlo(
    const Graph& graph, llvm::ArrayRef<const XlaArgument> args,
    llvm::StringRef device_type, bool use_tuple_args,
    const FunctionLibraryDefinition& flib_def, const GraphDebugInfo& debug_info,
    XlaHelpers::ShapeRepresentationFn shape_representation_fn,
    XlaCompilationResult* compilation_result,
    std::vector<std::unique_ptr<mlir::Pass>> custom_legalization_passes = {});
} // namespace tensorflow
#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_COMPILE_MLIR_UTIL_H_