// Source blob: 85345621677c4b0a31a9b81336ae075a354a8dce
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_MLIR_XLA_IR_MLIR_HLO_BUILDER_H_
#define TENSORFLOW_COMPILER_MLIR_XLA_IR_MLIR_HLO_BUILDER_H_
#include <memory>
#include "absl/container/flat_hash_map.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/StringRef.h"
#include "mlir/IR/Builders.h" // from @llvm-project
#include "mlir/IR/Function.h" // from @llvm-project
#include "mlir/IR/Location.h" // from @llvm-project
#include "mlir/IR/Operation.h" // from @llvm-project
#include "mlir/IR/Value.h" // from @llvm-project
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/shape.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/stream_executor/lib/statusor.h"
namespace xla {
// Provides a way to construct xla_hlo dialect ops in MLIR using the XlaBuilder
// interface.
//
// Requires that all XlaOp arguments are either returned by one of the builder
// methods or constructed using the MakeXlaOp method of this builder.
//
// TODO(hinsu): Support more ops and utility functions to set special attributes
// like OpMetadata and Sharding.
class MlirHloBuilder : public XlaBuilder {
 public:
  // Constructs a builder for the given function. The function's name is passed
  // to the XlaBuilder base constructor and the location used for new ops
  // starts out as the builder's unknown location. New operations are added to
  // the beginning of the function, if it is non empty and has a block.
  //
  // Note: loc_ is initialized from builder_, so builder_ must remain declared
  // before loc_ in the member list below.
  explicit MlirHloBuilder(mlir::FuncOp func)
      : XlaBuilder(func.getName().str()),
        builder_(&func.getBody()),
        loc_(builder_.getUnknownLoc()) {}

  // TODO(hinsu): Add a constructor to build a new MLIR function from scratch
  // and override Build methods.

  // Not copyable.
  MlirHloBuilder(const MlirHloBuilder&) = delete;
  MlirHloBuilder& operator=(const MlirHloBuilder&) = delete;

  ~MlirHloBuilder() override;

  // Wraps the given MLIR value under an XlaOp instance. Note that all HLO
  // operations return exactly one result, therefore each op has an XlaOp
  // wrapping the result of the op.
  //
  // Returns an error if the HLO dialect doesn't support the type of the given
  // value.
  StatusOr<XlaOp> MakeXlaOp(mlir::Value val);

  // Returns the value corresponding to the given op.
  //
  // Requires that the op was created by this builder: the op's handle is
  // reinterpreted as the opaque pointer of an mlir::Value, so a handle that
  // did not originate from this builder does not yield a meaningful value.
  //
  // Const-qualified because it only inspects `op` and reads no builder state.
  mlir::Value GetValue(XlaOp op) const {
    void* ptr = reinterpret_cast<void*>(op.handle());
    return mlir::Value::getFromOpaquePointer(ptr);
  }

  // Sets location for newly built ops, until reset.
  void SetLocation(mlir::Location loc) { loc_ = loc; }

  // Updates the insertion point so that newly built ops are inserted before
  // the given op in order, until reset.
  void setInsertionPoint(mlir::Operation* op) {
    builder_.setInsertionPoint(op);
  }

  // Returns the shape of the given op.
  StatusOr<const Shape*> GetShapePtr(XlaOp op) const override;

 private:
  // XlaBuilder hooks overridden by this builder.
  XlaOp ConstantLiteral(const LiteralSlice& literal) override;

  StatusOr<XlaOp> ReshapeInternal(const Shape& shape, XlaOp operand,
                                  int64 inferred_dimension) override;

  StatusOr<XlaOp> InDimBroadcast(
      const Shape& shape, XlaOp operand,
      absl::Span<const int64> broadcast_dimensions) override;

  StatusOr<XlaOp> Compare(const Shape& shape, XlaOp lhs, XlaOp rhs,
                          ComparisonDirection direction) override;

  XlaOp BinaryOpNoBroadcast(HloOpcode binop, const Shape& shape, XlaOp lhs,
                            XlaOp rhs) override;

  StatusOr<XlaOp> AddOpWithShape(HloOpcode opcode, const Shape& shape,
                                 absl::Span<const XlaOp> operands) override;

  // Creates HLO dialect op and returns the result as an XlaOp.
  StatusOr<XlaOp> CreateOp(const std::string& op_name, const Shape& shape,
                           llvm::ArrayRef<XlaOp> operands,
                           llvm::ArrayRef<mlir::NamedAttribute> attributes);

  // MLIR builder whose insertion point is moved by setInsertionPoint.
  mlir::OpBuilder builder_;

  // Location changed by SetLocation. Must stay declared after builder_ (see
  // the constructor's initializer list).
  mlir::Location loc_;

  // Owned shapes keyed by an int64 handle.
  // NOTE(review): presumably keyed by XlaOp::handle() and backing
  // GetShapePtr; confirm against the .cc file.
  absl::flat_hash_map<int64, std::unique_ptr<Shape>> handle_to_shape_;
};
} // namespace xla
#endif // TENSORFLOW_COMPILER_MLIR_XLA_IR_MLIR_HLO_BUILDER_H_