blob: 99fc64f40ba53b81061b0627622b9663be886b39 [file] [log] [blame]
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_MLIR_XLA_FUNCTION_IMPORTER_H_
#define TENSORFLOW_COMPILER_MLIR_XLA_FUNCTION_IMPORTER_H_
#include <cstdint>
#include <unordered_map>
#include <utility>
#include <vector>

#include "absl/types/optional.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
#include "mlir/IR/Attributes.h" // from @llvm-project
#include "mlir/IR/Builders.h" // from @llvm-project
#include "mlir/IR/BuiltinOps.h" // from @llvm-project
#include "mlir/IR/BuiltinTypes.h" // from @llvm-project
#include "mlir/IR/MLIRContext.h" // from @llvm-project
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h"
#include "tensorflow/compiler/xla/comparison_util.h"
#include "tensorflow/compiler/xla/status.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/types.h"
namespace xla {

class HloModule;
class HloComputation;
class HloInstruction;
class Shape;

// Helper class for importing HloComputations into MLIR.
//
// The constructor is private, so instances can only be created from within
// the class itself (presumably by the static entry points below — confirm in
// the .cc). An instance carries the target module, a caller-supplied builder,
// a caller-owned computation-to-function map and its own
// instruction-to-value map.
class HloFunctionImporter {
 public:
  // Imports `computation` as a function in `module`. This also imports any
  // computations referred to by instructions in `computation`.
  //
  // `function_map` is owned by the caller and maps HLO computations to the
  // MLIR functions created for them. NOTE(review): presumably consulted so a
  // computation referenced several times is imported only once; confirm in
  // the .cc.
  static Status ImportAsFunc(const xla::HloComputation& computation,
                             mlir::ModuleOp module,
                             std::unordered_map<const xla::HloComputation*,
                                                mlir::FuncOp>* function_map,
                             mlir::Builder* builder);

  // Imports the given HLO computation into the specified region.
  static Status ImportAsRegion(const xla::HloComputation& computation,
                               mlir::Region* region, mlir::Builder* builder);

  // Imports the instructions of `computation` at the place specified by
  // `builder`. `arguments` contains values for all parameters of
  // `computation`. NOTE(review): the returned value is presumably the one
  // produced for the computation's root instruction; confirm in the .cc.
  static StatusOr<mlir::Value> ImportInstructions(
      const xla::HloComputation& computation,
      const llvm::SmallVectorImpl<mlir::Value>& arguments,
      mlir::OpBuilder* builder);

  // Records the layout information carried by `shape` on `op`.
  // NOTE(review): the exact encoding (attribute name/format) is defined in
  // the .cc; confirm there before relying on it.
  static void SetLayoutForMlir(mlir::Operation* op, const Shape& shape);

  // Converts replica groups to a named MLIR attribute.
  //
  // TODO(timshen): move this to attribute_importer.h.
  static mlir::NamedAttribute ConvertReplicaGroups(
      const std::vector<ReplicaGroup>& replica_groups, mlir::Builder builder);

 private:
  // Binds the importer to `module`'s MLIRContext and loads the Standard and
  // MHLO dialects into that context.
  //
  // `function_map` and `builder` are stored as raw, non-owning pointers, so
  // they must remain valid for as long as the importer is in use.
  HloFunctionImporter(mlir::ModuleOp module,
                      std::unordered_map<const xla::HloComputation*,
                                         mlir::FuncOp>* function_map,
                      mlir::Builder* builder)
      : context_(module.getContext()),
        module_(module),
        builder_(builder),
        function_map_(function_map) {
    context_->loadDialect<mlir::StandardOpsDialect>();
    context_->loadDialect<mlir::mhlo::MhloDialect>();
  }

  // Imports `computation` as a new function, unless it has already been
  // imported.
  StatusOr<mlir::FuncOp> ImportAsFunc(const xla::HloComputation& computation);

  // Imports `computation` into the specified region.
  Status ImportAsRegion(const xla::HloComputation& computation,
                        mlir::Region* region);

  // Imports the instructions of `computation` into `block`.
  // Assumes that `block` already has the correct arguments populated.
  Status ImportInstructions(const xla::HloComputation& computation,
                            mlir::Block* block);

  // Worker behind the instruction-import entry points.
  // NOTE(review): inferred from the name and the matching signature of the
  // static ImportInstructions; confirm in the .cc.
  StatusOr<mlir::Value> ImportInstructionsImpl(
      const xla::HloComputation& computation,
      const llvm::SmallVectorImpl<mlir::Value>& arguments,
      mlir::OpBuilder* builder);

  // Imports a single instruction using `func_builder`.
  StatusOr<mlir::Operation*> ImportInstruction(xla::HloInstruction* instruction,
                                               mlir::OpBuilder* func_builder);

  // Worker behind ImportInstruction.
  // NOTE(review): inferred from the name; confirm in the .cc.
  StatusOr<mlir::Operation*> ImportInstructionImpl(
      xla::HloInstruction* instruction, mlir::OpBuilder* func_builder);

  // Gets the MLIR operand values of an HLO instruction.
  StatusOr<llvm::SmallVector<mlir::Value, 4>> GetOperands(
      xla::HloInstruction* instruction);

  // Converts an XLA tensor shape to the corresponding MLIR ranked tensor type.
  StatusOr<mlir::RankedTensorType> ConvertTensorType(const xla::Shape& shape);

  // Returns the MLIR type of the output of an HloInstruction.
  StatusOr<mlir::Type> GetReturnType(xla::HloInstruction* instruction);

  // Computes the MLIR types used as inputs for `instructions` and stores them
  // in `types`. NOTE(review): tuple-shaped instructions receive special
  // handling (the original comment says tuples are "bypassed to subsets");
  // see the .cc for the exact behavior.
  Status GetMlirTypes(const std::vector<xla::HloInstruction*>& instructions,
                      llvm::SmallVectorImpl<mlir::Type>* types);

  // Returns the MLIR value for the corresponding HloInstruction.
  StatusOr<mlir::Value> GetMlirValue(xla::HloInstruction* instruction);

  // Converts an XLA ComparisonDirection to the corresponding MLIR attribute.
  mlir::NamedAttribute ConvertComparisonDirection(
      ComparisonDirection direction);

  // Converts an XLA Comparison::Type to the corresponding MLIR attribute.
  mlir::NamedAttribute ConvertComparisonType(Comparison::Type type);

  // Converts the dimensions of an HLO instruction into an MLIR attribute.
  mlir::DenseIntElementsAttr ConvertDimensions(
      llvm::ArrayRef<tensorflow::int64> op_dimensions);

  // Converts an ArrayRef of integers to a DenseIntElementsAttr.
  mlir::DenseIntElementsAttr Convert(llvm::ArrayRef<int64_t> elements);

  // Converts an ArrayRef to a padding attribute. The input is a flattened list
  // of (padding low, padding high) pairs, one pair per spatial dimension.
  mlir::NamedAttribute ConvertPadding(llvm::ArrayRef<int64_t> padding);

  // Converts an optional channel id to an attribute.
  mlir::NamedAttribute ConvertChannelHandle(
      absl::optional<tensorflow::int64> channel_id);

  // Converts a channel handle to an attribute.
  mlir::NamedAttribute ConvertChannelHandle(const xla::ChannelHandle& channel);

  // Converts XLA instruction source-target pairs to an MLIR attribute.
  mlir::NamedAttribute ConvertSourceTargetPairs(
      const std::vector<std::pair<tensorflow::int64, tensorflow::int64>>&
          source_target_pairs);

  // Context of the module being imported into; the constructor loads the
  // Standard and MHLO dialects into it.
  mlir::MLIRContext* context_;
  // Module that receives the imported functions.
  mlir::ModuleOp module_;
  // Builder supplied by the caller; not owned.
  mlir::Builder* builder_;
  // Mapping from HloComputation to the created MLIR function; owned by the
  // caller that constructed this importer.
  std::unordered_map<const xla::HloComputation*, mlir::FuncOp>* function_map_;
  // Mapping from HloInstructions to the associated MLIR values.
  std::unordered_map<xla::HloInstruction*, mlir::Value> instruction_value_map_;
};

}  // namespace xla
#endif // TENSORFLOW_COMPILER_MLIR_XLA_FUNCTION_IMPORTER_H_