blob: f74b931c6c5353fee23e06c494dce1845945b6ee [file] [log] [blame]
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_MLIR_XLA_FUNCTION_IMPORTER_H_
#define TENSORFLOW_COMPILER_MLIR_XLA_FUNCTION_IMPORTER_H_
#include <unordered_map>
#include "absl/types/optional.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" // from @llvm-project
#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
#include "mlir/IR/Attributes.h" // from @llvm-project
#include "mlir/IR/Builders.h" // from @llvm-project
#include "mlir/IR/BuiltinOps.h" // from @llvm-project
#include "mlir/IR/BuiltinTypes.h" // from @llvm-project
#include "mlir/IR/MLIRContext.h" // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h"
#include "tensorflow/compiler/xla/comparison_util.h"
#include "tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "tensorflow/compiler/xla/status.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/types.h"
namespace xla {
class HloModule;
class HloComputation;
class HloInstruction;
class Shape;
// HLO bounded dynamic shapes can be converted to either MLIR dynamic shapes
// (which lose the bound information) or casted to static shape using the
// bounds.
enum class DynamicShapeHandlingMode { kDynamic, kConvertToStatic };
// Helper class for importing HloComputations.
class HloFunctionImporter {
public:
// Imports the given computation as a function in the given module. This also
// imports any computations referred by instructions in this computation.
static Status ImportAsFunc(
const xla::HloComputation& computation, mlir::ModuleOp module,
std::unordered_map<const xla::HloComputation*, mlir::func::FuncOp>*
function_map,
mlir::Builder* builder, bool is_main);
// Imports the given hlo computation to the specified region. If
// 'flatten_region_arg_tuple' is true, then flatten the tuple-typed region
// argument(s) and return value(s).
static Status ImportAsRegion(const xla::HloComputation& computation,
mlir::Region* region, mlir::Builder* builder,
bool flatten_region_arg_tuple = false);
// Imports the given computation to the given place specified by `builder`.
// `arguments` contains values for all parameters.
static StatusOr<mlir::Value> ImportInstructions(
const xla::HloComputation& computation,
const llvm::SmallVectorImpl<mlir::Value>& arguments,
mlir::OpBuilder* builder);
static StatusOr<mlir::Operation*> ImportInstruction(
const xla::HloInstruction* instr,
const llvm::SmallVectorImpl<mlir::Value>& operands,
mlir::OpBuilder* builder,
DynamicShapeHandlingMode mode = DynamicShapeHandlingMode::kDynamic);
static void SetLayoutForMlir(mlir::Operation* op, const Shape& shape,
llvm::StringRef attr_name);
// TODO(b/179166199): move this to attribute_importer.h.
// Converts XLA instruction source target pairs to MLIR attribute.
static mlir::NamedAttribute ConvertSourceTargetPairs(
const std::vector<std::pair<int64_t, int64_t>>& source_target_pairs,
mlir::Builder* builder);
// TODO(b/179166199): move this to attribute_importer.h.
// Converts replica groups to attribute
static mlir::NamedAttribute ConvertReplicaGroups(
absl::Span<const ReplicaGroup> replica_groups, mlir::Builder* builder);
// For mlir::IfOp or mlir::CaseOp, replace the uses of their region's block
// arguments with 'implicit_operands'. Here | implicit_operands | == sum of
// the number of arguments in all the regions in IfOp or CaseOp.
void ReplaceBlockArgumentsWithImplicitOperands(
mlir::Operation* op, llvm::ArrayRef<mlir::Value> implicit_operands);
// Create a TupleOp using the results of 'op' if 'type' is a mlir::TupleType.
// Otherwise, return 'op'.
mlir::Operation* CreateTupleFromOpResults(mlir::OpBuilder* func_builder,
mlir::Location loc,
mlir::Operation* op,
mlir::Type type);
// FlattenTupleType flattens the types in (nested) tuple-type 'type' and
// stores them in 'types'.
static void FlattenTupleType(
mlir::Type type, llvm::SmallVectorImpl<mlir::Type>& flattened_types);
// FlattenTupleValue flattens the values in (nested) tuple-typed 'value' and
// stores them in 'flattened_values'.
static void FlattenTupleValue(
mlir::OpBuilder* func_builder, mlir::Location loc, mlir::Value value,
llvm::SmallVectorImpl<mlir::Value>& flattened_values);
// CreateTupleValue creates a root TupleOp of (nested) tuple-type 'type' using
// the non-tuple-typed values in 'flatten_values'.
//
// e.g., Given 'flatten_values': [V1, V2, V3] &'type': tuple<T1,tuple<T1,T2>>,
// The function returns %t2 such that:
// %t1 = mhlo.tuple(V2,V3) : (T2,T3) -> tuple<T2,T3>
// %t2 = mhlo.tuple(V1,%t1): (T1,tuple<T2,T3>) -> tuple<T1,tuple<T1,T2>>
//
// Note: 1. FlattenTupleValue and CreateTupleValue is a pair of functions to
// resp. flatten and create tuples in the exact same order.
// 2. `flatten_values`, initially storing the flattened values, will be
// mutated to a 0-length array by the end of function invocation.
static mlir::Value CreateTupleValue(
mlir::OpBuilder* func_builder, mlir::Location loc,
llvm::MutableArrayRef<mlir::Value>& flatten_values, mlir::Type type);
private:
HloFunctionImporter(mlir::ModuleOp module,
std::unordered_map<const xla::HloComputation*,
mlir::func::FuncOp>* function_map,
mlir::Builder* builder)
: context_(module.getContext()),
module_(module),
builder_(builder),
function_map_(function_map) {
context_->loadDialect<mlir::arith::ArithmeticDialect>();
context_->loadDialect<mlir::func::FuncDialect>();
context_->loadDialect<mlir::mhlo::MhloDialect>();
}
// Imports the given computation as a new function, if it hasn't been already
// imported.
StatusOr<mlir::func::FuncOp> ImportAsFunc(
const xla::HloComputation& computation, bool is_main);
// Imports the given computation in the specified region.
tensorflow::Status ImportAsRegion(const HloComputation& computation,
mlir::Region* region,
bool flatten_region_arg_tuple = false);
// Imports instructions from the given computation in the specified block.
// Assumes that the block already has correct arguments populated.
tensorflow::Status ImportInstructions(const HloComputation& computation,
mlir::Block* block,
bool flatten_region_arg_tuple);
StatusOr<mlir::Value> ImportInstructionsImpl(
const xla::HloComputation& computation,
const llvm::SmallVectorImpl<mlir::Value>& arguments,
mlir::OpBuilder* builder);
// Imports an instruction.
StatusOr<mlir::Operation*> ImportInstructionWithLayout(
const xla::HloInstruction* instruction,
const llvm::SmallVectorImpl<mlir::Value>& operands,
mlir::OpBuilder* func_builder,
DynamicShapeHandlingMode mode = DynamicShapeHandlingMode::kDynamic);
StatusOr<mlir::Operation*> ImportInstructionImpl(
const HloInstruction* instruction,
const llvm::SmallVectorImpl<mlir::Value>& operands,
mlir::OpBuilder* func_builder,
DynamicShapeHandlingMode mode = DynamicShapeHandlingMode::kDynamic);
// Gets the MLIR operand values from an HLO Instruction.
StatusOr<llvm::SmallVector<mlir::Value, 4>> GetOperands(
const xla::HloInstruction* instruction);
// Converts xla Tensor type to the corresponding MLIR type.
StatusOr<mlir::RankedTensorType> ConvertTensorType(const xla::Shape& shape);
// Converts an XLA shape/layout to the corresponding MLIR layout, in
// flattened_attr, while flattening the tuple layout.
Status ConvertShapeToMlirLayout(
const xla::Shape& shape,
llvm::SmallVectorImpl<mlir::Attribute>& flattened_attr);
// Returns the output type of an HloInstruction.
StatusOr<mlir::Type> GetReturnType(const xla::HloInstruction* instruction);
// Takes a list of HloInstructions and generates the list of types used for
// input, bypassing tuples to subsets.
Status GetMlirTypes(const std::vector<xla::HloInstruction*>& instructions,
llvm::SmallVectorImpl<mlir::Type>* types);
// Returns the Mlir Value for the corresponding HloInstruction.
StatusOr<mlir::Value> GetMlirValue(const xla::HloInstruction* instruction);
// Converts an XLA ComparisonDirection to the corresponding MLIR attribute.
mlir::NamedAttribute ConvertComparisonDirection(
ComparisonDirection direction);
// Converts an XLA Comparison::Type to the corresponding MLIR attribute.
mlir::NamedAttribute ConvertComparisonType(Comparison::Type type);
// Converts the dimensions of an HLO instruction into an MLIR attribute.
mlir::DenseIntElementsAttr ConvertDimensions(
absl::Span<const int64_t> op_dimensions);
// Converts Array ref to an DenseIntElementsAttr.
mlir::DenseIntElementsAttr Convert(llvm::ArrayRef<int64_t> elements);
// Converts Array ref of bools to a DenseIntElementsAttr of I1 type.
mlir::DenseIntElementsAttr Convert(llvm::ArrayRef<bool> elements);
// Converts Array ref to padding attribute. Input is a flattened list of
// padding low and padding high for each of the spatial dimensions.
mlir::NamedAttribute ConvertPadding(llvm::ArrayRef<int64_t> padding);
// Converts channel id to attribute
mlir::NamedAttribute ConvertChannelHandle(std::optional<int64_t> channel_id);
// Converts channel handle to attribute
mlir::NamedAttribute ConvertChannelHandle(const xla::ChannelHandle& channel);
mlir::MLIRContext* context_;
mlir::ModuleOp module_;
mlir::Builder* builder_;
// Mapping from HloComputation to the created MLIR function.
std::unordered_map<const xla::HloComputation*, mlir::func::FuncOp>*
function_map_;
// Mapping from HloInstructions to the associative MLIR values.
std::unordered_map<const xla::HloInstruction*, mlir::Value>
instruction_value_map_;
};
} // namespace xla
#endif // TENSORFLOW_COMPILER_MLIR_XLA_FUNCTION_IMPORTER_H_