Add MlirHloBuilder op implementations
PiperOrigin-RevId: 307472994
Change-Id: Ifbca316f653f44469cebd3aa5a507e8ccabf5001
diff --git a/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.cc b/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.cc
index 739f19e..cfa8c1b 100644
--- a/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.cc
+++ b/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.cc
@@ -165,6 +165,90 @@
/*attributes=*/{});
}
+XlaOp MlirHloBuilder::CreateToken() {
+ return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
+ return MakeXlaOp(builder_.create<mlir::xla_hlo::CreateTokenOp>(
+ loc_, mlir::xla_hlo::TokenType::get(builder_.getContext())));
+ });
+}
+
+StatusOr<XlaOp> MlirHloBuilder::InfeedWithTokenInternal(
+ const Shape& infeed_instruction_shape, XlaOp token, const string& config) {
+ TF_ASSIGN_OR_RETURN(mlir::Type result_type,
+ ConvertShapeToType<mlir::RankedTensorType>(
+ infeed_instruction_shape, builder_));
+ return MakeXlaOp(builder_.create<mlir::xla_hlo::InfeedOp>(
+ loc_, result_type, GetValue(token),
+ /*infeed_config=*/config));
+}
+
+StatusOr<XlaOp> MlirHloBuilder::OutfeedWithTokenInternal(
+ XlaOp operand, XlaOp token, const Shape& shape_with_layout,
+ const string& outfeed_config) {
+ auto token_type = mlir::xla_hlo::TokenType::get(builder_.getContext());
+ return MakeXlaOp(builder_.create<mlir::xla_hlo::OutfeedOp>(
+ loc_, token_type, GetValue(operand), GetValue(token), outfeed_config));
+}
+
+StatusOr<XlaOp> MlirHloBuilder::ConcatInDimInternal(
+ const Shape& shape, absl::Span<const XlaOp> operands, int64 dimension) {
+ TF_ASSIGN_OR_RETURN(
+ mlir::Type result_type,
+ ConvertShapeToType<mlir::RankedTensorType>(shape, builder_));
+ auto mlir_operands = GetValues(operands);
+ return MakeXlaOp(builder_.create<mlir::xla_hlo::ConcatenateOp>(
+ loc_, result_type, mlir_operands, builder_.getI64IntegerAttr(dimension)));
+}
+
+StatusOr<XlaOp> MlirHloBuilder::GetTupleElementInternal(const Shape& shape,
+ XlaOp tuple_data,
+ int64 index) {
+ TF_ASSIGN_OR_RETURN(
+ mlir::Type result_type,
+ ConvertShapeToType<mlir::RankedTensorType>(shape, builder_));
+ return MakeXlaOp(builder_.create<mlir::xla_hlo::GetTupleElementOp>(
+ loc_, result_type, GetValue(tuple_data),
+ builder_.getI32IntegerAttr(index)));
+}
+
+StatusOr<XlaOp> MlirHloBuilder::SliceInternal(
+ const Shape& shape, XlaOp operand, absl::Span<const int64> start_indices,
+ absl::Span<const int64> limit_indices, absl::Span<const int64> strides) {
+ return MakeXlaOp(builder_.create<mlir::xla_hlo::SliceOp>(
+ loc_, GetValue(operand), GetI64ElementsAttr(start_indices, &builder_),
+ GetI64ElementsAttr(limit_indices, &builder_),
+ GetI64ElementsAttr(strides, &builder_)));
+}
+
+StatusOr<XlaOp> MlirHloBuilder::PadInternal(
+ const Shape& shape, XlaOp operand, XlaOp padding_value,
+ const PaddingConfig& padding_config) {
+ TF_ASSIGN_OR_RETURN(
+ mlir::Type result_type,
+ ConvertShapeToType<mlir::RankedTensorType>(shape, builder_));
+ std::vector<int64> low;
+ std::vector<int64> high;
+ std::vector<int64> internal;
+ for (auto& dimension : padding_config.dimensions()) {
+ low.push_back(dimension.edge_padding_low());
+ high.push_back(dimension.edge_padding_high());
+ internal.push_back(dimension.interior_padding());
+ }
+ return MakeXlaOp(builder_.create<mlir::xla_hlo::PadOp>(
+ loc_, result_type, GetValue(operand), GetValue(padding_value),
+ GetI64ElementsAttr(low, &builder_), GetI64ElementsAttr(high, &builder_),
+ GetI64ElementsAttr(internal, &builder_)));
+}
+
+StatusOr<XlaOp> MlirHloBuilder::TupleInternal(
+ const Shape& shape, absl::Span<const XlaOp> elements) {
+ mlir::SmallVector<mlir::Value, 4> operands;
+ for (auto& element : elements) {
+ operands.push_back(GetValue(element));
+ }
+ return MakeXlaOp(builder_.create<mlir::xla_hlo::TupleOp>(loc_, operands));
+}
+
StatusOr<XlaOp> MlirHloBuilder::CreateOp(
const std::string& op_name, const Shape& shape,
llvm::ArrayRef<XlaOp> operands,
diff --git a/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.h b/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.h
index 95dafbd..c0ef645 100644
--- a/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.h
+++ b/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.h
@@ -54,6 +54,9 @@
// TODO(hinsu): Add a constructor to build a new MLIR function from scratch
// and override Build methods.
+ MlirHloBuilder(std::string name, mlir::OpBuilder builder, mlir::Location loc)
+ : XlaBuilder(name), builder_(builder), loc_(loc) {}
+
MlirHloBuilder(const MlirHloBuilder&) = delete;
MlirHloBuilder& operator=(const MlirHloBuilder&) = delete;
@@ -75,6 +78,17 @@
return mlir::Value::getFromOpaquePointer(ptr);
}
+ // Returns MLIR values corresponding to the given XLA ops.
+ //
+ // Requires that the ops were created by this builder.
+ std::vector<mlir::Value> GetValues(absl::Span<const XlaOp> ops) {
+ std::vector<mlir::Value> values;
+ for (auto xla_op : ops) {
+ values.push_back(GetValue(xla_op));
+ }
+ return values;
+ }
+
// Sets location for newly built ops, until reset.
void SetLocation(mlir::Location loc) { loc_ = loc; }
@@ -120,6 +134,34 @@
StatusOr<XlaOp> AddOpWithShape(HloOpcode opcode, const Shape& shape,
absl::Span<const XlaOp> operands) override;
+ XlaOp CreateToken() override;
+
+ StatusOr<XlaOp> InfeedWithTokenInternal(const Shape& infeed_instruction_shape,
+ XlaOp token,
+ const string& config) override;
+ StatusOr<XlaOp> OutfeedWithTokenInternal(
+ XlaOp operand, XlaOp token, const Shape& shape_with_layout,
+ const string& outfeed_config) override;
+
+ StatusOr<XlaOp> ConcatInDimInternal(const Shape& shape,
+ absl::Span<const XlaOp> operands,
+ int64 dimension) override;
+
+ StatusOr<XlaOp> GetTupleElementInternal(const Shape& shape, XlaOp tuple_data,
+ int64 index) override;
+
+ StatusOr<XlaOp> SliceInternal(const Shape& shape, XlaOp operand,
+ absl::Span<const int64> start_indices,
+ absl::Span<const int64> limit_indices,
+ absl::Span<const int64> strides) override;
+
+ StatusOr<XlaOp> PadInternal(const Shape& shape, XlaOp operand,
+ XlaOp padding_value,
+ const PaddingConfig& padding_config) override;
+
+ StatusOr<XlaOp> TupleInternal(const Shape& shape,
+ absl::Span<const XlaOp> elements) override;
+
// Creates HLO dialect op and returns the result as an XlaOp.
StatusOr<XlaOp> CreateOp(const std::string& op_name, const Shape& shape,
llvm::ArrayRef<XlaOp> operands,
diff --git a/tensorflow/compiler/mlir/xla/tests/BUILD b/tensorflow/compiler/mlir/xla/tests/BUILD
index 989b846..ad69383 100644
--- a/tensorflow/compiler/mlir/xla/tests/BUILD
+++ b/tensorflow/compiler/mlir/xla/tests/BUILD
@@ -1,4 +1,5 @@
load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests")
+load("//tensorflow:tensorflow.bzl", "tf_cc_test")
package(licenses = ["notice"])
@@ -18,3 +19,18 @@
"@llvm-project//llvm:FileCheck",
],
)
+
+tf_cc_test(
+ name = "mlir_hlo_builder_test",
+ srcs = ["mlir_hlo_builder_test.cc"],
+ deps = [
+ "//tensorflow/compiler/mlir/xla:hlo",
+ "//tensorflow/compiler/mlir/xla:mlir_hlo_builder",
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ "//tensorflow/core:testlib",
+ "@llvm-project//llvm:support",
+ "@llvm-project//mlir:IR",
+ ],
+)
diff --git a/tensorflow/compiler/mlir/xla/tests/mlir_hlo_builder_test.cc b/tensorflow/compiler/mlir/xla/tests/mlir_hlo_builder_test.cc
new file mode 100644
index 0000000..54791e1
--- /dev/null
+++ b/tensorflow/compiler/mlir/xla/tests/mlir_hlo_builder_test.cc
@@ -0,0 +1,179 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.h"
+
+#include <string>
+
+#include "llvm/Support/raw_ostream.h"
+#include "mlir/IR/Attributes.h" // from @llvm-project
+#include "mlir/IR/Builders.h" // from @llvm-project
+#include "mlir/IR/Dialect.h" // from @llvm-project
+#include "mlir/IR/Location.h" // from @llvm-project
+#include "mlir/IR/MLIRContext.h" // from @llvm-project
+#include "mlir/IR/Module.h" // from @llvm-project
+#include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace xla {
+
+namespace {
+
+static void ExpectHasSubstr(absl::string_view s, absl::string_view expected) {
+ EXPECT_TRUE(absl::StrContains(s, expected))
+ << s << " does not contain " << expected;
+}
+
+class XlaBuilderTest : public ::testing::Test {
+ protected:
+ XlaBuilderTest()
+ : name_(SetupTest()),
+ context_(),
+ module_(mlir::ModuleOp::create(mlir::UnknownLoc::get(&context_))),
+ builder_(&module_->getBodyRegion()),
+ xla_builder_(name_, builder_, module_->getLoc()) {}
+
+ string SetupTest() {
+ mlir::registerDialect<mlir::xla_hlo::XlaHloDialect>();
+ return ::testing::UnitTest::GetInstance()->current_test_info()->name();
+ }
+
+ // Retuns the MLIR op string representation of the given XlaOp.
+ string GetMlirOpString(XlaOp xla_op) {
+ string str;
+ llvm::raw_string_ostream ostream{str};
+ xla_builder_.GetValue(xla_op).print(ostream);
+ ostream.flush();
+ return str;
+ }
+
+ string name_;
+ mlir::MLIRContext context_;
+ mlir::OwningModuleRef module_;
+ mlir::OpBuilder builder_;
+ MlirHloBuilder xla_builder_;
+};
+
+TEST_F(XlaBuilderTest, CreateToken) {
+ auto token = CreateToken(&xla_builder_);
+ auto str = GetMlirOpString(token);
+
+ TF_ASSERT_OK(xla_builder_.GetCurrentStatus());
+
+ ExpectHasSubstr(GetMlirOpString(token),
+ R"("xla_hlo.create_token"() : () -> !xla_hlo.token)");
+}
+
+TEST_F(XlaBuilderTest, Infeed) {
+ auto token = CreateToken(&xla_builder_);
+ auto infeed = InfeedWithToken(token, ShapeUtil::MakeShape(F32, {4, 8}), "");
+
+ TF_ASSERT_OK(xla_builder_.GetCurrentStatus());
+ ExpectHasSubstr(
+ GetMlirOpString(infeed),
+ R"("xla_hlo.infeed"(%0) {infeed_config = ""} : (!xla_hlo.token) -> tuple<tensor<4x8xf32>, !xla_hlo.token>)");
+}
+
+TEST_F(XlaBuilderTest, Outfeed) {
+ auto outfeed_shape = ShapeUtil::MakeShape(F32, {4, 8});
+ auto data = ConstantLiteral(
+ &xla_builder_,
+ LiteralUtil::CreateFromDimensions(F32, outfeed_shape.dimensions()));
+ auto token = CreateToken(&xla_builder_);
+ auto outfeed = OutfeedWithToken(data, token, outfeed_shape, "");
+
+ TF_ASSERT_OK(xla_builder_.GetCurrentStatus());
+ ExpectHasSubstr(
+ GetMlirOpString(outfeed),
+ R"("xla_hlo.outfeed"(%0, %1) {outfeed_config = ""} : (tensor<4x8xf32>, !xla_hlo.token) -> !xla_hlo.token)");
+}
+
+TEST_F(XlaBuilderTest, ConcatInDim) {
+ auto data0 = ConstantLiteral(
+ &xla_builder_, LiteralUtil::CreateFromDimensions(F32, {2, 4, 5}));
+ auto data1 = ConstantLiteral(
+ &xla_builder_, LiteralUtil::CreateFromDimensions(F32, {2, 6, 5}));
+ auto concat = ConcatInDim(&xla_builder_, {data0, data1}, 1);
+
+ TF_ASSERT_OK(xla_builder_.GetCurrentStatus());
+ ExpectHasSubstr(
+ GetMlirOpString(concat),
+ R"("xla_hlo.concatenate"(%0, %1) {dimension = 1 : i64} : (tensor<2x4x5xf32>, tensor<2x6x5xf32>) -> tensor<2x10x5xf32>)");
+}
+
+TEST_F(XlaBuilderTest, Tuple) {
+ auto data0 = ConstantLiteral(&xla_builder_,
+ LiteralUtil::CreateFromDimensions(F32, {3, 7}));
+ auto data1 = ConstantLiteral(&xla_builder_,
+ LiteralUtil::CreateFromDimensions(F32, {}));
+ auto tuple = Tuple(&xla_builder_, {data0, data1});
+
+ TF_ASSERT_OK(xla_builder_.GetCurrentStatus());
+ ExpectHasSubstr(
+ GetMlirOpString(tuple),
+ R"("xla_hlo.tuple"(%0, %1) : (tensor<3x7xf32>, tensor<f32>) -> tuple<tensor<3x7xf32>, tensor<f32>>)");
+}
+
+TEST_F(XlaBuilderTest, GetTupleElement) {
+ auto data0 = ConstantLiteral(&xla_builder_,
+ LiteralUtil::CreateFromDimensions(F32, {3, 7}));
+ auto data1 = ConstantLiteral(&xla_builder_,
+ LiteralUtil::CreateFromDimensions(F32, {}));
+ auto tuple_data = Tuple(&xla_builder_, {data0, data1});
+ auto gte = GetTupleElement(tuple_data, 1);
+
+ TF_ASSERT_OK(xla_builder_.GetCurrentStatus());
+ ExpectHasSubstr(
+ GetMlirOpString(gte),
+ R"("xla_hlo.get_tuple_element"(%2) {index = 1 : i32} : (tuple<tensor<3x7xf32>, tensor<f32>>) -> tensor<f32>)");
+}
+
+TEST_F(XlaBuilderTest, Slice) {
+ auto data = ConstantLiteral(&xla_builder_,
+ LiteralUtil::CreateFromDimensions(F32, {3, 7}));
+ auto slice = Slice(data, {0, 1}, {2, 5}, {1, 1});
+
+ TF_ASSERT_OK(xla_builder_.GetCurrentStatus());
+ ExpectHasSubstr(
+ GetMlirOpString(slice),
+ R"("xla_hlo.slice"(%0) {limit_indices = dense<[2, 5]> : tensor<2xi64>, start_indices = dense<[0, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<3x7xf32>) -> tensor<2x4xf32>)");
+}
+
+TEST_F(XlaBuilderTest, Pad) {
+ auto data = ConstantLiteral(&xla_builder_,
+ LiteralUtil::CreateFromDimensions(F32, {3, 7}));
+ auto zero = ConstantLiteral(&xla_builder_, LiteralUtil::Zero(F32));
+
+ PaddingConfig padding_config;
+ auto* dims0 = padding_config.add_dimensions();
+ dims0->set_edge_padding_low(1);
+ dims0->set_interior_padding(0);
+ dims0->set_edge_padding_high(2);
+ auto* dims1 = padding_config.add_dimensions();
+ dims1->set_edge_padding_low(3);
+ dims1->set_interior_padding(1);
+ dims1->set_edge_padding_high(0);
+ auto pad = Pad(data, zero, padding_config);
+
+ TF_ASSERT_OK(xla_builder_.GetCurrentStatus());
+ ExpectHasSubstr(
+ GetMlirOpString(pad),
+ R"("xla_hlo.pad"(%0, %1) {edge_padding_high = dense<[2, 0]> : tensor<2xi64>, edge_padding_low = dense<[1, 3]> : tensor<2xi64>, interior_padding = dense<[0, 1]> : tensor<2xi64>} : (tensor<3x7xf32>, tensor<f32>) -> tensor<6x16xf32>)");
+}
+
+} // namespace
+} // namespace xla
diff --git a/tensorflow/compiler/xla/client/xla_builder.cc b/tensorflow/compiler/xla/client/xla_builder.cc
index 2a69023..ea93880 100644
--- a/tensorflow/compiler/xla/client/xla_builder.cc
+++ b/tensorflow/compiler/xla/client/xla_builder.cc
@@ -822,23 +822,29 @@
absl::Span<const int64> limit_indices,
absl::Span<const int64> strides) {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
- HloInstructionProto instr;
TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferSliceShape(
*operand_shape, start_indices,
limit_indices, strides));
- *instr.mutable_shape() = shape.ToProto();
- for (int i = 0; i < start_indices.size(); i++) {
- auto* slice_config = instr.add_slice_dimensions();
- slice_config->set_start(start_indices[i]);
- slice_config->set_limit(limit_indices[i]);
- slice_config->set_stride(strides[i]);
- }
-
- return AddInstruction(std::move(instr), HloOpcode::kSlice, {operand});
+ return SliceInternal(shape, operand, start_indices, limit_indices, strides);
});
}
+StatusOr<XlaOp> XlaBuilder::SliceInternal(const Shape& shape, XlaOp operand,
+ absl::Span<const int64> start_indices,
+ absl::Span<const int64> limit_indices,
+ absl::Span<const int64> strides) {
+ HloInstructionProto instr;
+ *instr.mutable_shape() = shape.ToProto();
+ for (int i = 0; i < start_indices.size(); i++) {
+ auto* slice_config = instr.add_slice_dimensions();
+ slice_config->set_start(start_indices[i]);
+ slice_config->set_limit(limit_indices[i]);
+ slice_config->set_stride(strides[i]);
+ }
+ return AddInstruction(std::move(instr), HloOpcode::kSlice, {operand});
+}
+
XlaOp XlaBuilder::SliceInDim(XlaOp operand, int64 start_index,
int64 limit_index, int64 stride, int64 dimno) {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
@@ -952,41 +958,49 @@
XlaOp XlaBuilder::ConcatInDim(absl::Span<const XlaOp> operands,
int64 dimension) {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
- HloInstructionProto instr;
-
std::vector<const Shape*> operand_shape_ptrs;
TF_ASSIGN_OR_RETURN(const auto& operand_shapes, GetOperandShapes(operands));
absl::c_transform(operand_shapes, std::back_inserter(operand_shape_ptrs),
[](const Shape& shape) { return &shape; });
TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferConcatOpShape(
operand_shape_ptrs, dimension));
- *instr.mutable_shape() = shape.ToProto();
-
- instr.add_dimensions(dimension);
-
- return AddInstruction(std::move(instr), HloOpcode::kConcatenate, operands);
+ return ConcatInDimInternal(shape, operands, dimension);
});
}
+StatusOr<XlaOp> XlaBuilder::ConcatInDimInternal(
+ const Shape& shape, absl::Span<const XlaOp> operands, int64 dimension) {
+ HloInstructionProto instr;
+ *instr.mutable_shape() = shape.ToProto();
+
+ instr.add_dimensions(dimension);
+
+ return AddInstruction(std::move(instr), HloOpcode::kConcatenate, operands);
+}
+
XlaOp XlaBuilder::Pad(XlaOp operand, XlaOp padding_value,
const PaddingConfig& padding_config) {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
- HloInstructionProto instr;
-
TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
TF_ASSIGN_OR_RETURN(const Shape* padding_value_shape,
GetShapePtr(padding_value));
TF_ASSIGN_OR_RETURN(
Shape shape, ShapeInference::InferPadShape(
*operand_shape, *padding_value_shape, padding_config));
- *instr.mutable_shape() = shape.ToProto();
- *instr.mutable_padding_config() = padding_config;
-
- return AddInstruction(std::move(instr), HloOpcode::kPad,
- {operand, padding_value});
+ return PadInternal(shape, operand, padding_value, padding_config);
});
}
+StatusOr<XlaOp> XlaBuilder::PadInternal(const Shape& shape, XlaOp operand,
+ XlaOp padding_value,
+ const PaddingConfig& padding_config) {
+ HloInstructionProto instr;
+ *instr.mutable_shape() = shape.ToProto();
+ *instr.mutable_padding_config() = padding_config;
+ return AddInstruction(std::move(instr), HloOpcode::kPad,
+ {operand, padding_value});
+}
+
XlaOp XlaBuilder::Reshape(XlaOp operand, absl::Span<const int64> dimensions,
absl::Span<const int64> new_sizes,
int64 inferred_dimension) {
@@ -1080,7 +1094,6 @@
XlaOp XlaBuilder::Tuple(absl::Span<const XlaOp> elements) {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
- HloInstructionProto instr;
std::vector<const Shape*> operand_shape_ptrs;
TF_ASSIGN_OR_RETURN(const auto& operand_shapes, GetOperandShapes(elements));
absl::c_transform(operand_shapes, std::back_inserter(operand_shape_ptrs),
@@ -1088,14 +1101,19 @@
TF_ASSIGN_OR_RETURN(const Shape shape,
ShapeInference::InferVariadicOpShape(
HloOpcode::kTuple, operand_shape_ptrs));
- *instr.mutable_shape() = shape.ToProto();
- return AddInstruction(std::move(instr), HloOpcode::kTuple, elements);
+ return TupleInternal(shape, elements);
});
}
+StatusOr<XlaOp> XlaBuilder::TupleInternal(const Shape& shape,
+ absl::Span<const XlaOp> elements) {
+ HloInstructionProto instr;
+ *instr.mutable_shape() = shape.ToProto();
+ return AddInstruction(std::move(instr), HloOpcode::kTuple, elements);
+}
+
XlaOp XlaBuilder::GetTupleElement(XlaOp tuple_data, int64 index) {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
- HloInstructionProto instr;
TF_ASSIGN_OR_RETURN(const Shape* tuple_shape, GetShapePtr(tuple_data));
if (!tuple_shape->IsTuple()) {
return InvalidArgument(
@@ -1107,16 +1125,22 @@
"GetTupleElement() index (%d) out of range for tuple shape %s", index,
ShapeUtil::HumanString(*tuple_shape));
}
- *instr.mutable_shape() =
- ShapeUtil::GetTupleElementShape(*tuple_shape, index).ToProto();
-
- instr.set_tuple_index(index);
-
- return AddInstruction(std::move(instr), HloOpcode::kGetTupleElement,
- {tuple_data});
+ return GetTupleElementInternal(
+ ShapeUtil::GetTupleElementShape(*tuple_shape, index), tuple_data,
+ index);
});
}
+StatusOr<XlaOp> XlaBuilder::GetTupleElementInternal(const Shape& shape,
+ XlaOp tuple_data,
+ int64 index) {
+ HloInstructionProto instr;
+ *instr.mutable_shape() = shape.ToProto();
+ instr.set_tuple_index(index);
+ return AddInstruction(std::move(instr), HloOpcode::kGetTupleElement,
+ {tuple_data});
+}
+
XlaOp XlaBuilder::Dot(XlaOp lhs, XlaOp rhs,
const PrecisionConfig* precision_config) {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
@@ -1407,14 +1431,11 @@
XlaOp XlaBuilder::InfeedWithToken(XlaOp token, const Shape& shape,
const string& config) {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
- HloInstructionProto instr;
if (!LayoutUtil::HasLayout(shape)) {
return InvalidArgument("Given shape to Infeed must have a layout");
}
const Shape infeed_instruction_shape =
ShapeUtil::MakeTupleShape({shape, ShapeUtil::MakeTokenShape()});
- *instr.mutable_shape() = infeed_instruction_shape.ToProto();
- instr.set_infeed_config(config);
if (shape.IsArray() && sharding() &&
sharding()->type() == OpSharding::OTHER) {
@@ -1427,11 +1448,18 @@
return InvalidArgument(
"Replicated sharding is not yet supported for infeeds");
}
-
- return AddInstruction(std::move(instr), HloOpcode::kInfeed, {token});
+ return InfeedWithTokenInternal(infeed_instruction_shape, token, config);
});
}
+StatusOr<XlaOp> XlaBuilder::InfeedWithTokenInternal(
+ const Shape& infeed_instruction_shape, XlaOp token, const string& config) {
+ HloInstructionProto instr;
+ *instr.mutable_shape() = infeed_instruction_shape.ToProto();
+ instr.set_infeed_config(config);
+ return AddInstruction(std::move(instr), HloOpcode::kInfeed, {token});
+}
+
void XlaBuilder::Outfeed(XlaOp operand, const Shape& shape_with_layout,
const string& outfeed_config) {
ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
@@ -1488,10 +1516,6 @@
const Shape& shape_with_layout,
const string& outfeed_config) {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
- HloInstructionProto instr;
-
- *instr.mutable_shape() = ShapeUtil::MakeTokenShape().ToProto();
-
// Check and set outfeed shape.
if (!LayoutUtil::HasLayout(shape_with_layout)) {
return InvalidArgument("Given shape to Outfeed must have a layout");
@@ -1503,15 +1527,22 @@
ShapeUtil::HumanStringWithLayout(shape_with_layout),
ShapeUtil::HumanStringWithLayout(*operand_shape));
}
- *instr.mutable_outfeed_shape() = shape_with_layout.ToProto();
-
- instr.set_outfeed_config(outfeed_config);
-
- return AddInstruction(std::move(instr), HloOpcode::kOutfeed,
- {operand, token});
+ return OutfeedWithTokenInternal(operand, token, shape_with_layout,
+ outfeed_config);
});
}
+StatusOr<XlaOp> XlaBuilder::OutfeedWithTokenInternal(
+ XlaOp operand, XlaOp token, const Shape& shape_with_layout,
+ const string& outfeed_config) {
+ HloInstructionProto instr;
+ *instr.mutable_shape() = ShapeUtil::MakeTokenShape().ToProto();
+ *instr.mutable_outfeed_shape() = shape_with_layout.ToProto();
+ instr.set_outfeed_config(outfeed_config);
+ return AddInstruction(std::move(instr), HloOpcode::kOutfeed,
+ {operand, token});
+}
+
XlaOp XlaBuilder::CreateToken() {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
HloInstructionProto instr;
diff --git a/tensorflow/compiler/xla/client/xla_builder.h b/tensorflow/compiler/xla/client/xla_builder.h
index f320fee..4eba598 100644
--- a/tensorflow/compiler/xla/client/xla_builder.h
+++ b/tensorflow/compiler/xla/client/xla_builder.h
@@ -364,6 +364,10 @@
Status SetInstructionFrontendAttribute(XlaOp op, string attribute,
string value);
+ // Returns shapes for the operands.
+ StatusOr<std::vector<Shape>> GetOperandShapes(
+ absl::Span<const XlaOp> operands) const;
+
private:
// Build helper which takes the id of the root operation..
StatusOr<XlaComputation> Build(int64 root_id, bool remove_dynamic_dimensions);
@@ -391,6 +395,10 @@
XlaOp Pad(XlaOp operand, XlaOp padding_value,
const PaddingConfig& padding_config);
+ virtual StatusOr<XlaOp> PadInternal(const Shape& shape, XlaOp operand,
+ XlaOp padding_value,
+ const PaddingConfig& padding_config);
+
XlaOp Reshape(XlaOp operand, absl::Span<const int64> dimensions,
absl::Span<const int64> new_sizes,
int64 inferred_dimension = -1);
@@ -406,9 +414,12 @@
XlaOp Slice(XlaOp operand, absl::Span<const int64> start_indices,
absl::Span<const int64> limit_indices,
absl::Span<const int64> strides);
-
- XlaOp SliceInDim(XlaOp operand, int64 start_index, int64 limit_index,
- int64 stride, int64 dimno);
+ virtual StatusOr<XlaOp> SliceInternal(const Shape& shape, XlaOp operand,
+ absl::Span<const int64> start_indices,
+ absl::Span<const int64> limit_indices,
+ absl::Span<const int64> strides);
+ virtual XlaOp SliceInDim(XlaOp operand, int64 start_index, int64 limit_index,
+ int64 stride, int64 dimno);
ABSL_DEPRECATED("Use span-of-indices form instead")
XlaOp DynamicSlice(XlaOp operand, XlaOp start_indices,
@@ -422,14 +433,22 @@
absl::Span<const XlaOp> start_indices);
XlaOp ConcatInDim(absl::Span<const XlaOp> operands, int64 dimension);
+ virtual StatusOr<XlaOp> ConcatInDimInternal(const Shape& shape,
+ absl::Span<const XlaOp> operands,
+ int64 dimension);
void Trace(const string& tag, XlaOp operand);
XlaOp Select(XlaOp pred, XlaOp on_true, XlaOp on_false);
XlaOp Tuple(absl::Span<const XlaOp> elements);
+ virtual StatusOr<XlaOp> TupleInternal(const Shape& shape,
+ absl::Span<const XlaOp> elements);
XlaOp GetTupleElement(XlaOp tuple_data, int64 index);
+ virtual StatusOr<XlaOp> GetTupleElementInternal(const Shape& shape,
+ XlaOp tuple_data,
+ int64 index);
XlaOp Dot(XlaOp lhs, XlaOp rhs,
const PrecisionConfig* precision_config = nullptr);
@@ -476,15 +495,18 @@
absl::Span<const int64> fft_length);
XlaOp Infeed(const Shape& shape, const string& config = "");
- XlaOp InfeedWithToken(XlaOp token, const Shape& shape,
- const string& config = "");
+ XlaOp InfeedWithToken(XlaOp token, const Shape& shape, const string& config);
+ virtual StatusOr<XlaOp> InfeedWithTokenInternal(
+ const Shape& infeed_instruction_shape, XlaOp token, const string& config);
void Outfeed(XlaOp operand, const Shape& shape_with_layout,
const string& outfeed_config);
XlaOp OutfeedWithToken(XlaOp operand, XlaOp token,
const Shape& shape_with_layout,
const string& outfeed_config);
-
+ virtual StatusOr<XlaOp> OutfeedWithTokenInternal(
+ XlaOp operand, XlaOp token, const Shape& shape_with_layout,
+ const string& outfeed_config);
XlaOp Call(const XlaComputation& computation,
absl::Span<const XlaOp> operands);
@@ -624,7 +646,7 @@
XlaOp RecvFromHost(XlaOp token, const Shape& shape,
const ChannelHandle& handle);
- XlaOp CreateToken();
+ virtual XlaOp CreateToken();
XlaOp AfterAll(absl::Span<const XlaOp> tokens);
@@ -701,10 +723,6 @@
// Returns the (inferred) result for the program shape using the given root.
StatusOr<ProgramShape> GetProgramShape(int64 root_id) const;
- // Returns shapes for the operands.
- StatusOr<std::vector<Shape>> GetOperandShapes(
- absl::Span<const XlaOp> operands) const;
-
// A visitor which checks whether an operation is a compile-time constant,
// meaning that it doesn't depend on any parameters, or on any stateful
// operation such as `RngNormal` or `Infeed`. The visitor walks the