Add MlirHloBuilder op implementations
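
Adds MLIR-emitting implementations for CreateToken, InfeedWithToken,
OutfeedWithToken, ConcatInDim, GetTupleElement, Slice, Pad and Tuple by
overriding the new *Internal hooks on XlaBuilder, and adds a constructor
that takes an mlir::OpBuilder and mlir::Location.

Rough usage sketch (mirrors the new mlir_hlo_builder_test.cc; illustrative
only):

    // Register the HLO dialect and create a module to emit ops into.
    mlir::registerDialect<mlir::xla_hlo::XlaHloDialect>();
    mlir::MLIRContext context;
    mlir::OwningModuleRef module =
        mlir::ModuleOp::create(mlir::UnknownLoc::get(&context));
    mlir::OpBuilder builder(&module->getBodyRegion());

    // Emits xla_hlo ops instead of HloInstructionProtos.
    xla::MlirHloBuilder xla_builder("example", builder, module->getLoc());
    xla::XlaOp token = xla::CreateToken(&xla_builder);
    xla::XlaOp infeed = xla::InfeedWithToken(
        token, xla::ShapeUtil::MakeShape(xla::F32, {4, 8}), /*config=*/"");
    TF_CHECK_OK(xla_builder.GetCurrentStatus());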

PiperOrigin-RevId: 307472994
Change-Id: Ifbca316f653f44469cebd3aa5a507e8ccabf5001
diff --git a/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.cc b/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.cc
index 739f19e..cfa8c1b 100644
--- a/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.cc
+++ b/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.cc
@@ -165,6 +165,90 @@
                   /*attributes=*/{});
 }
 
+XlaOp MlirHloBuilder::CreateToken() {
+  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
+    return MakeXlaOp(builder_.create<mlir::xla_hlo::CreateTokenOp>(
+        loc_, mlir::xla_hlo::TokenType::get(builder_.getContext())));
+  });
+}
+
+StatusOr<XlaOp> MlirHloBuilder::InfeedWithTokenInternal(
+    const Shape& infeed_instruction_shape, XlaOp token, const string& config) {
+  TF_ASSIGN_OR_RETURN(mlir::Type result_type,
+                      ConvertShapeToType<mlir::RankedTensorType>(
+                          infeed_instruction_shape, builder_));
+  return MakeXlaOp(builder_.create<mlir::xla_hlo::InfeedOp>(
+      loc_, result_type, GetValue(token),
+      /*infeed_config=*/config));
+}
+
+StatusOr<XlaOp> MlirHloBuilder::OutfeedWithTokenInternal(
+    XlaOp operand, XlaOp token, const Shape& shape_with_layout,
+    const string& outfeed_config) {
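+  // shape_with_layout is not used here: the operand's MLIR tensor type
+  // already carries the outfeed shape.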
+  auto token_type = mlir::xla_hlo::TokenType::get(builder_.getContext());
+  return MakeXlaOp(builder_.create<mlir::xla_hlo::OutfeedOp>(
+      loc_, token_type, GetValue(operand), GetValue(token), outfeed_config));
+}
+
+StatusOr<XlaOp> MlirHloBuilder::ConcatInDimInternal(
+    const Shape& shape, absl::Span<const XlaOp> operands, int64 dimension) {
+  TF_ASSIGN_OR_RETURN(
+      mlir::Type result_type,
+      ConvertShapeToType<mlir::RankedTensorType>(shape, builder_));
+  auto mlir_operands = GetValues(operands);
+  return MakeXlaOp(builder_.create<mlir::xla_hlo::ConcatenateOp>(
+      loc_, result_type, mlir_operands, builder_.getI64IntegerAttr(dimension)));
+}
+
+StatusOr<XlaOp> MlirHloBuilder::GetTupleElementInternal(const Shape& shape,
+                                                        XlaOp tuple_data,
+                                                        int64 index) {
+  TF_ASSIGN_OR_RETURN(
+      mlir::Type result_type,
+      ConvertShapeToType<mlir::RankedTensorType>(shape, builder_));
+  return MakeXlaOp(builder_.create<mlir::xla_hlo::GetTupleElementOp>(
+      loc_, result_type, GetValue(tuple_data),
+      builder_.getI32IntegerAttr(index)));
+}
+
+StatusOr<XlaOp> MlirHloBuilder::SliceInternal(
+    const Shape& shape, XlaOp operand, absl::Span<const int64> start_indices,
+    absl::Span<const int64> limit_indices, absl::Span<const int64> strides) {
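+  // No explicit result type is passed; SliceOp's builder infers it, so the
+  // precomputed `shape` is unused here.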
+  return MakeXlaOp(builder_.create<mlir::xla_hlo::SliceOp>(
+      loc_, GetValue(operand), GetI64ElementsAttr(start_indices, &builder_),
+      GetI64ElementsAttr(limit_indices, &builder_),
+      GetI64ElementsAttr(strides, &builder_)));
+}
+
+StatusOr<XlaOp> MlirHloBuilder::PadInternal(
+    const Shape& shape, XlaOp operand, XlaOp padding_value,
+    const PaddingConfig& padding_config) {
+  TF_ASSIGN_OR_RETURN(
+      mlir::Type result_type,
+      ConvertShapeToType<mlir::RankedTensorType>(shape, builder_));
+  std::vector<int64> low;
+  std::vector<int64> high;
+  std::vector<int64> internal;
+  for (auto& dimension : padding_config.dimensions()) {
+    low.push_back(dimension.edge_padding_low());
+    high.push_back(dimension.edge_padding_high());
+    internal.push_back(dimension.interior_padding());
+  }
+  return MakeXlaOp(builder_.create<mlir::xla_hlo::PadOp>(
+      loc_, result_type, GetValue(operand), GetValue(padding_value),
+      GetI64ElementsAttr(low, &builder_), GetI64ElementsAttr(high, &builder_),
+      GetI64ElementsAttr(internal, &builder_)));
+}
+
+StatusOr<XlaOp> MlirHloBuilder::TupleInternal(
+    const Shape& shape, absl::Span<const XlaOp> elements) {
+  mlir::SmallVector<mlir::Value, 4> operands;
+  for (auto& element : elements) {
+    operands.push_back(GetValue(element));
+  }
+  return MakeXlaOp(builder_.create<mlir::xla_hlo::TupleOp>(loc_, operands));
+}
+
 StatusOr<XlaOp> MlirHloBuilder::CreateOp(
     const std::string& op_name, const Shape& shape,
     llvm::ArrayRef<XlaOp> operands,
diff --git a/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.h b/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.h
index 95dafbd..c0ef645 100644
--- a/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.h
+++ b/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.h
@@ -54,6 +54,9 @@
   // TODO(hinsu): Add a constructor to build a new MLIR function from scratch
   // and override Build methods.
 
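+  // Constructs a builder that emits ops via the given MLIR OpBuilder at the
+  // given location, instead of creating HloInstructionProtos.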
+  MlirHloBuilder(std::string name, mlir::OpBuilder builder, mlir::Location loc)
+      : XlaBuilder(name), builder_(builder), loc_(loc) {}
+
   MlirHloBuilder(const MlirHloBuilder&) = delete;
   MlirHloBuilder& operator=(const MlirHloBuilder&) = delete;
 
@@ -75,6 +78,17 @@
     return mlir::Value::getFromOpaquePointer(ptr);
   }
 
+  // Returns MLIR values corresponding to the given XLA ops.
+  //
+  // Requires that the ops were created by this builder.
+  std::vector<mlir::Value> GetValues(absl::Span<const XlaOp> ops) {
+    std::vector<mlir::Value> values;
+    for (auto xla_op : ops) {
+      values.push_back(GetValue(xla_op));
+    }
+    return values;
+  }
+
   // Sets location for newly built ops, until reset.
   void SetLocation(mlir::Location loc) { loc_ = loc; }
 
@@ -120,6 +134,34 @@
   StatusOr<XlaOp> AddOpWithShape(HloOpcode opcode, const Shape& shape,
                                  absl::Span<const XlaOp> operands) override;
 
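+  // Each of the overrides below creates the corresponding xla_hlo op instead
+  // of an HloInstructionProto.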
+  XlaOp CreateToken() override;
+
+  StatusOr<XlaOp> InfeedWithTokenInternal(const Shape& infeed_instruction_shape,
+                                          XlaOp token,
+                                          const string& config) override;
+  StatusOr<XlaOp> OutfeedWithTokenInternal(
+      XlaOp operand, XlaOp token, const Shape& shape_with_layout,
+      const string& outfeed_config) override;
+
+  StatusOr<XlaOp> ConcatInDimInternal(const Shape& shape,
+                                      absl::Span<const XlaOp> operands,
+                                      int64 dimension) override;
+
+  StatusOr<XlaOp> GetTupleElementInternal(const Shape& shape, XlaOp tuple_data,
+                                          int64 index) override;
+
+  StatusOr<XlaOp> SliceInternal(const Shape& shape, XlaOp operand,
+                                absl::Span<const int64> start_indices,
+                                absl::Span<const int64> limit_indices,
+                                absl::Span<const int64> strides) override;
+
+  StatusOr<XlaOp> PadInternal(const Shape& shape, XlaOp operand,
+                              XlaOp padding_value,
+                              const PaddingConfig& padding_config) override;
+
+  StatusOr<XlaOp> TupleInternal(const Shape& shape,
+                                absl::Span<const XlaOp> elements) override;
+
   // Creates HLO dialect op and returns the result as an XlaOp.
   StatusOr<XlaOp> CreateOp(const std::string& op_name, const Shape& shape,
                            llvm::ArrayRef<XlaOp> operands,
diff --git a/tensorflow/compiler/mlir/xla/tests/BUILD b/tensorflow/compiler/mlir/xla/tests/BUILD
index 989b846..ad69383 100644
--- a/tensorflow/compiler/mlir/xla/tests/BUILD
+++ b/tensorflow/compiler/mlir/xla/tests/BUILD
@@ -1,4 +1,5 @@
 load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests")
+load("//tensorflow:tensorflow.bzl", "tf_cc_test")
 
 package(licenses = ["notice"])
 
@@ -18,3 +19,18 @@
         "@llvm-project//llvm:FileCheck",
     ],
 )
+
+tf_cc_test(
+    name = "mlir_hlo_builder_test",
+    srcs = ["mlir_hlo_builder_test.cc"],
+    deps = [
+        "//tensorflow/compiler/mlir/xla:hlo",
+        "//tensorflow/compiler/mlir/xla:mlir_hlo_builder",
+        "//tensorflow/compiler/xla:shape_util",
+        "//tensorflow/core:test",
+        "//tensorflow/core:test_main",
+        "//tensorflow/core:testlib",
+        "@llvm-project//llvm:support",
+        "@llvm-project//mlir:IR",
+    ],
+)
diff --git a/tensorflow/compiler/mlir/xla/tests/mlir_hlo_builder_test.cc b/tensorflow/compiler/mlir/xla/tests/mlir_hlo_builder_test.cc
new file mode 100644
index 0000000..54791e1
--- /dev/null
+++ b/tensorflow/compiler/mlir/xla/tests/mlir_hlo_builder_test.cc
@@ -0,0 +1,179 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.h"
+
+#include <string>
+
+#include "llvm/Support/raw_ostream.h"
+#include "mlir/IR/Attributes.h"  // from @llvm-project
+#include "mlir/IR/Builders.h"  // from @llvm-project
+#include "mlir/IR/Dialect.h"  // from @llvm-project
+#include "mlir/IR/Location.h"  // from @llvm-project
+#include "mlir/IR/MLIRContext.h"  // from @llvm-project
+#include "mlir/IR/Module.h"  // from @llvm-project
+#include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace xla {
+
+namespace {
+
+static void ExpectHasSubstr(absl::string_view s, absl::string_view expected) {
+  EXPECT_TRUE(absl::StrContains(s, expected))
+      << s << " does not contain " << expected;
+}
+
+class XlaBuilderTest : public ::testing::Test {
+ protected:
+  XlaBuilderTest()
+      : name_(SetupTest()),
+        context_(),
+        module_(mlir::ModuleOp::create(mlir::UnknownLoc::get(&context_))),
+        builder_(&module_->getBodyRegion()),
+        xla_builder_(name_, builder_, module_->getLoc()) {}
+
+  string SetupTest() {
+    mlir::registerDialect<mlir::xla_hlo::XlaHloDialect>();
+    return ::testing::UnitTest::GetInstance()->current_test_info()->name();
+  }
+
+  // Returns the MLIR op string representation of the given XlaOp.
+  string GetMlirOpString(XlaOp xla_op) {
+    string str;
+    llvm::raw_string_ostream ostream{str};
+    xla_builder_.GetValue(xla_op).print(ostream);
+    ostream.flush();
+    return str;
+  }
+
+  string name_;
+  mlir::MLIRContext context_;
+  mlir::OwningModuleRef module_;
+  mlir::OpBuilder builder_;
+  MlirHloBuilder xla_builder_;
+};
+
+TEST_F(XlaBuilderTest, CreateToken) {
+  auto token = CreateToken(&xla_builder_);
+  auto str = GetMlirOpString(token);
+
+  TF_ASSERT_OK(xla_builder_.GetCurrentStatus());
+
+  ExpectHasSubstr(str,
+                  R"("xla_hlo.create_token"() : () -> !xla_hlo.token)");
+}
+
+TEST_F(XlaBuilderTest, Infeed) {
+  auto token = CreateToken(&xla_builder_);
+  auto infeed = InfeedWithToken(token, ShapeUtil::MakeShape(F32, {4, 8}), "");
+
+  TF_ASSERT_OK(xla_builder_.GetCurrentStatus());
+  ExpectHasSubstr(
+      GetMlirOpString(infeed),
+      R"("xla_hlo.infeed"(%0) {infeed_config = ""} : (!xla_hlo.token) -> tuple<tensor<4x8xf32>, !xla_hlo.token>)");
+}
+
+TEST_F(XlaBuilderTest, Outfeed) {
+  auto outfeed_shape = ShapeUtil::MakeShape(F32, {4, 8});
+  auto data = ConstantLiteral(
+      &xla_builder_,
+      LiteralUtil::CreateFromDimensions(F32, outfeed_shape.dimensions()));
+  auto token = CreateToken(&xla_builder_);
+  auto outfeed = OutfeedWithToken(data, token, outfeed_shape, "");
+
+  TF_ASSERT_OK(xla_builder_.GetCurrentStatus());
+  ExpectHasSubstr(
+      GetMlirOpString(outfeed),
+      R"("xla_hlo.outfeed"(%0, %1) {outfeed_config = ""} : (tensor<4x8xf32>, !xla_hlo.token) -> !xla_hlo.token)");
+}
+
+TEST_F(XlaBuilderTest, ConcatInDim) {
+  auto data0 = ConstantLiteral(
+      &xla_builder_, LiteralUtil::CreateFromDimensions(F32, {2, 4, 5}));
+  auto data1 = ConstantLiteral(
+      &xla_builder_, LiteralUtil::CreateFromDimensions(F32, {2, 6, 5}));
+  auto concat = ConcatInDim(&xla_builder_, {data0, data1}, 1);
+
+  TF_ASSERT_OK(xla_builder_.GetCurrentStatus());
+  ExpectHasSubstr(
+      GetMlirOpString(concat),
+      R"("xla_hlo.concatenate"(%0, %1) {dimension = 1 : i64} : (tensor<2x4x5xf32>, tensor<2x6x5xf32>) -> tensor<2x10x5xf32>)");
+}
+
+TEST_F(XlaBuilderTest, Tuple) {
+  auto data0 = ConstantLiteral(&xla_builder_,
+                               LiteralUtil::CreateFromDimensions(F32, {3, 7}));
+  auto data1 = ConstantLiteral(&xla_builder_,
+                               LiteralUtil::CreateFromDimensions(F32, {}));
+  auto tuple = Tuple(&xla_builder_, {data0, data1});
+
+  TF_ASSERT_OK(xla_builder_.GetCurrentStatus());
+  ExpectHasSubstr(
+      GetMlirOpString(tuple),
+      R"("xla_hlo.tuple"(%0, %1) : (tensor<3x7xf32>, tensor<f32>) -> tuple<tensor<3x7xf32>, tensor<f32>>)");
+}
+
+TEST_F(XlaBuilderTest, GetTupleElement) {
+  auto data0 = ConstantLiteral(&xla_builder_,
+                               LiteralUtil::CreateFromDimensions(F32, {3, 7}));
+  auto data1 = ConstantLiteral(&xla_builder_,
+                               LiteralUtil::CreateFromDimensions(F32, {}));
+  auto tuple_data = Tuple(&xla_builder_, {data0, data1});
+  auto gte = GetTupleElement(tuple_data, 1);
+
+  TF_ASSERT_OK(xla_builder_.GetCurrentStatus());
+  ExpectHasSubstr(
+      GetMlirOpString(gte),
+      R"("xla_hlo.get_tuple_element"(%2) {index = 1 : i32} : (tuple<tensor<3x7xf32>, tensor<f32>>) -> tensor<f32>)");
+}
+
+TEST_F(XlaBuilderTest, Slice) {
+  auto data = ConstantLiteral(&xla_builder_,
+                              LiteralUtil::CreateFromDimensions(F32, {3, 7}));
+  auto slice = Slice(data, {0, 1}, {2, 5}, {1, 1});
+
+  TF_ASSERT_OK(xla_builder_.GetCurrentStatus());
+  ExpectHasSubstr(
+      GetMlirOpString(slice),
+      R"("xla_hlo.slice"(%0) {limit_indices = dense<[2, 5]> : tensor<2xi64>, start_indices = dense<[0, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<3x7xf32>) -> tensor<2x4xf32>)");
+}
+
+TEST_F(XlaBuilderTest, Pad) {
+  auto data = ConstantLiteral(&xla_builder_,
+                              LiteralUtil::CreateFromDimensions(F32, {3, 7}));
+  auto zero = ConstantLiteral(&xla_builder_, LiteralUtil::Zero(F32));
+
+  PaddingConfig padding_config;
+  auto* dims0 = padding_config.add_dimensions();
+  dims0->set_edge_padding_low(1);
+  dims0->set_interior_padding(0);
+  dims0->set_edge_padding_high(2);
+  auto* dims1 = padding_config.add_dimensions();
+  dims1->set_edge_padding_low(3);
+  dims1->set_interior_padding(1);
+  dims1->set_edge_padding_high(0);
+  auto pad = Pad(data, zero, padding_config);
+
+  TF_ASSERT_OK(xla_builder_.GetCurrentStatus());
+  ExpectHasSubstr(
+      GetMlirOpString(pad),
+      R"("xla_hlo.pad"(%0, %1) {edge_padding_high = dense<[2, 0]> : tensor<2xi64>, edge_padding_low = dense<[1, 3]> : tensor<2xi64>, interior_padding = dense<[0, 1]> : tensor<2xi64>} : (tensor<3x7xf32>, tensor<f32>) -> tensor<6x16xf32>)");
+}
+
+}  // namespace
+}  // namespace xla
diff --git a/tensorflow/compiler/xla/client/xla_builder.cc b/tensorflow/compiler/xla/client/xla_builder.cc
index 2a69023..ea93880 100644
--- a/tensorflow/compiler/xla/client/xla_builder.cc
+++ b/tensorflow/compiler/xla/client/xla_builder.cc
@@ -822,23 +822,29 @@
                         absl::Span<const int64> limit_indices,
                         absl::Span<const int64> strides) {
   return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
-    HloInstructionProto instr;
     TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
     TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferSliceShape(
                                          *operand_shape, start_indices,
                                          limit_indices, strides));
-    *instr.mutable_shape() = shape.ToProto();
-    for (int i = 0; i < start_indices.size(); i++) {
-      auto* slice_config = instr.add_slice_dimensions();
-      slice_config->set_start(start_indices[i]);
-      slice_config->set_limit(limit_indices[i]);
-      slice_config->set_stride(strides[i]);
-    }
-
-    return AddInstruction(std::move(instr), HloOpcode::kSlice, {operand});
+    return SliceInternal(shape, operand, start_indices, limit_indices, strides);
   });
 }
 
+StatusOr<XlaOp> XlaBuilder::SliceInternal(const Shape& shape, XlaOp operand,
+                                          absl::Span<const int64> start_indices,
+                                          absl::Span<const int64> limit_indices,
+                                          absl::Span<const int64> strides) {
+  HloInstructionProto instr;
+  *instr.mutable_shape() = shape.ToProto();
+  for (int i = 0; i < start_indices.size(); i++) {
+    auto* slice_config = instr.add_slice_dimensions();
+    slice_config->set_start(start_indices[i]);
+    slice_config->set_limit(limit_indices[i]);
+    slice_config->set_stride(strides[i]);
+  }
+  return AddInstruction(std::move(instr), HloOpcode::kSlice, {operand});
+}
+
 XlaOp XlaBuilder::SliceInDim(XlaOp operand, int64 start_index,
                              int64 limit_index, int64 stride, int64 dimno) {
   return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
@@ -952,41 +958,49 @@
 XlaOp XlaBuilder::ConcatInDim(absl::Span<const XlaOp> operands,
                               int64 dimension) {
   return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
-    HloInstructionProto instr;
-
     std::vector<const Shape*> operand_shape_ptrs;
     TF_ASSIGN_OR_RETURN(const auto& operand_shapes, GetOperandShapes(operands));
     absl::c_transform(operand_shapes, std::back_inserter(operand_shape_ptrs),
                       [](const Shape& shape) { return &shape; });
     TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferConcatOpShape(
                                          operand_shape_ptrs, dimension));
-    *instr.mutable_shape() = shape.ToProto();
-
-    instr.add_dimensions(dimension);
-
-    return AddInstruction(std::move(instr), HloOpcode::kConcatenate, operands);
+    return ConcatInDimInternal(shape, operands, dimension);
   });
 }
 
+StatusOr<XlaOp> XlaBuilder::ConcatInDimInternal(
+    const Shape& shape, absl::Span<const XlaOp> operands, int64 dimension) {
+  HloInstructionProto instr;
+  *instr.mutable_shape() = shape.ToProto();
+
+  instr.add_dimensions(dimension);
+
+  return AddInstruction(std::move(instr), HloOpcode::kConcatenate, operands);
+}
+
 XlaOp XlaBuilder::Pad(XlaOp operand, XlaOp padding_value,
                       const PaddingConfig& padding_config) {
   return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
-    HloInstructionProto instr;
-
     TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
     TF_ASSIGN_OR_RETURN(const Shape* padding_value_shape,
                         GetShapePtr(padding_value));
     TF_ASSIGN_OR_RETURN(
         Shape shape, ShapeInference::InferPadShape(
                          *operand_shape, *padding_value_shape, padding_config));
-    *instr.mutable_shape() = shape.ToProto();
-    *instr.mutable_padding_config() = padding_config;
-
-    return AddInstruction(std::move(instr), HloOpcode::kPad,
-                          {operand, padding_value});
+    return PadInternal(shape, operand, padding_value, padding_config);
   });
 }
 
+StatusOr<XlaOp> XlaBuilder::PadInternal(const Shape& shape, XlaOp operand,
+                                        XlaOp padding_value,
+                                        const PaddingConfig& padding_config) {
+  HloInstructionProto instr;
+  *instr.mutable_shape() = shape.ToProto();
+  *instr.mutable_padding_config() = padding_config;
+  return AddInstruction(std::move(instr), HloOpcode::kPad,
+                        {operand, padding_value});
+}
+
 XlaOp XlaBuilder::Reshape(XlaOp operand, absl::Span<const int64> dimensions,
                           absl::Span<const int64> new_sizes,
                           int64 inferred_dimension) {
@@ -1080,7 +1094,6 @@
 
 XlaOp XlaBuilder::Tuple(absl::Span<const XlaOp> elements) {
   return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
-    HloInstructionProto instr;
     std::vector<const Shape*> operand_shape_ptrs;
     TF_ASSIGN_OR_RETURN(const auto& operand_shapes, GetOperandShapes(elements));
     absl::c_transform(operand_shapes, std::back_inserter(operand_shape_ptrs),
@@ -1088,14 +1101,19 @@
     TF_ASSIGN_OR_RETURN(const Shape shape,
                         ShapeInference::InferVariadicOpShape(
                             HloOpcode::kTuple, operand_shape_ptrs));
-    *instr.mutable_shape() = shape.ToProto();
-    return AddInstruction(std::move(instr), HloOpcode::kTuple, elements);
+    return TupleInternal(shape, elements);
   });
 }
 
+StatusOr<XlaOp> XlaBuilder::TupleInternal(const Shape& shape,
+                                          absl::Span<const XlaOp> elements) {
+  HloInstructionProto instr;
+  *instr.mutable_shape() = shape.ToProto();
+  return AddInstruction(std::move(instr), HloOpcode::kTuple, elements);
+}
+
 XlaOp XlaBuilder::GetTupleElement(XlaOp tuple_data, int64 index) {
   return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
-    HloInstructionProto instr;
     TF_ASSIGN_OR_RETURN(const Shape* tuple_shape, GetShapePtr(tuple_data));
     if (!tuple_shape->IsTuple()) {
       return InvalidArgument(
@@ -1107,16 +1125,22 @@
           "GetTupleElement() index (%d) out of range for tuple shape %s", index,
           ShapeUtil::HumanString(*tuple_shape));
     }
-    *instr.mutable_shape() =
-        ShapeUtil::GetTupleElementShape(*tuple_shape, index).ToProto();
-
-    instr.set_tuple_index(index);
-
-    return AddInstruction(std::move(instr), HloOpcode::kGetTupleElement,
-                          {tuple_data});
+    return GetTupleElementInternal(
+        ShapeUtil::GetTupleElementShape(*tuple_shape, index), tuple_data,
+        index);
   });
 }
 
+StatusOr<XlaOp> XlaBuilder::GetTupleElementInternal(const Shape& shape,
+                                                    XlaOp tuple_data,
+                                                    int64 index) {
+  HloInstructionProto instr;
+  *instr.mutable_shape() = shape.ToProto();
+  instr.set_tuple_index(index);
+  return AddInstruction(std::move(instr), HloOpcode::kGetTupleElement,
+                        {tuple_data});
+}
+
 XlaOp XlaBuilder::Dot(XlaOp lhs, XlaOp rhs,
                       const PrecisionConfig* precision_config) {
   return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
@@ -1407,14 +1431,11 @@
 XlaOp XlaBuilder::InfeedWithToken(XlaOp token, const Shape& shape,
                                   const string& config) {
   return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
-    HloInstructionProto instr;
     if (!LayoutUtil::HasLayout(shape)) {
       return InvalidArgument("Given shape to Infeed must have a layout");
     }
     const Shape infeed_instruction_shape =
         ShapeUtil::MakeTupleShape({shape, ShapeUtil::MakeTokenShape()});
-    *instr.mutable_shape() = infeed_instruction_shape.ToProto();
-    instr.set_infeed_config(config);
 
     if (shape.IsArray() && sharding() &&
         sharding()->type() == OpSharding::OTHER) {
@@ -1427,11 +1448,18 @@
       return InvalidArgument(
           "Replicated sharding is not yet supported for infeeds");
     }
-
-    return AddInstruction(std::move(instr), HloOpcode::kInfeed, {token});
+    return InfeedWithTokenInternal(infeed_instruction_shape, token, config);
   });
 }
 
+StatusOr<XlaOp> XlaBuilder::InfeedWithTokenInternal(
+    const Shape& infeed_instruction_shape, XlaOp token, const string& config) {
+  HloInstructionProto instr;
+  *instr.mutable_shape() = infeed_instruction_shape.ToProto();
+  instr.set_infeed_config(config);
+  return AddInstruction(std::move(instr), HloOpcode::kInfeed, {token});
+}
+
 void XlaBuilder::Outfeed(XlaOp operand, const Shape& shape_with_layout,
                          const string& outfeed_config) {
   ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
@@ -1488,10 +1516,6 @@
                                    const Shape& shape_with_layout,
                                    const string& outfeed_config) {
   return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
-    HloInstructionProto instr;
-
-    *instr.mutable_shape() = ShapeUtil::MakeTokenShape().ToProto();
-
     // Check and set outfeed shape.
     if (!LayoutUtil::HasLayout(shape_with_layout)) {
       return InvalidArgument("Given shape to Outfeed must have a layout");
@@ -1503,15 +1527,22 @@
           ShapeUtil::HumanStringWithLayout(shape_with_layout),
           ShapeUtil::HumanStringWithLayout(*operand_shape));
     }
-    *instr.mutable_outfeed_shape() = shape_with_layout.ToProto();
-
-    instr.set_outfeed_config(outfeed_config);
-
-    return AddInstruction(std::move(instr), HloOpcode::kOutfeed,
-                          {operand, token});
+    return OutfeedWithTokenInternal(operand, token, shape_with_layout,
+                                    outfeed_config);
   });
 }
 
+StatusOr<XlaOp> XlaBuilder::OutfeedWithTokenInternal(
+    XlaOp operand, XlaOp token, const Shape& shape_with_layout,
+    const string& outfeed_config) {
+  HloInstructionProto instr;
+  *instr.mutable_shape() = ShapeUtil::MakeTokenShape().ToProto();
+  *instr.mutable_outfeed_shape() = shape_with_layout.ToProto();
+  instr.set_outfeed_config(outfeed_config);
+  return AddInstruction(std::move(instr), HloOpcode::kOutfeed,
+                        {operand, token});
+}
+
 XlaOp XlaBuilder::CreateToken() {
   return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
     HloInstructionProto instr;
diff --git a/tensorflow/compiler/xla/client/xla_builder.h b/tensorflow/compiler/xla/client/xla_builder.h
index f320fee..4eba598 100644
--- a/tensorflow/compiler/xla/client/xla_builder.h
+++ b/tensorflow/compiler/xla/client/xla_builder.h
@@ -364,6 +364,10 @@
   Status SetInstructionFrontendAttribute(XlaOp op, string attribute,
                                          string value);
 
+  // Returns shapes for the operands.
+  StatusOr<std::vector<Shape>> GetOperandShapes(
+      absl::Span<const XlaOp> operands) const;
+
  private:
   // Build helper which takes the id of the root operation.
   StatusOr<XlaComputation> Build(int64 root_id, bool remove_dynamic_dimensions);
@@ -391,6 +395,10 @@
   XlaOp Pad(XlaOp operand, XlaOp padding_value,
             const PaddingConfig& padding_config);
 
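+  // The *Internal methods (e.g. PadInternal, SliceInternal, TupleInternal)
+  // are called by the corresponding public methods after shape inference.
+  // The default implementations build HloInstructionProtos; subclasses such
+  // as MlirHloBuilder override them to emit a different representation.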
+  virtual StatusOr<XlaOp> PadInternal(const Shape& shape, XlaOp operand,
+                                      XlaOp padding_value,
+                                      const PaddingConfig& padding_config);
+
   XlaOp Reshape(XlaOp operand, absl::Span<const int64> dimensions,
                 absl::Span<const int64> new_sizes,
                 int64 inferred_dimension = -1);
@@ -406,9 +414,12 @@
   XlaOp Slice(XlaOp operand, absl::Span<const int64> start_indices,
               absl::Span<const int64> limit_indices,
               absl::Span<const int64> strides);
-
-  XlaOp SliceInDim(XlaOp operand, int64 start_index, int64 limit_index,
-                   int64 stride, int64 dimno);
+  virtual StatusOr<XlaOp> SliceInternal(const Shape& shape, XlaOp operand,
+                                        absl::Span<const int64> start_indices,
+                                        absl::Span<const int64> limit_indices,
+                                        absl::Span<const int64> strides);
+  virtual XlaOp SliceInDim(XlaOp operand, int64 start_index, int64 limit_index,
+                           int64 stride, int64 dimno);
 
   ABSL_DEPRECATED("Use span-of-indices form instead")
   XlaOp DynamicSlice(XlaOp operand, XlaOp start_indices,
@@ -422,14 +433,22 @@
                            absl::Span<const XlaOp> start_indices);
 
   XlaOp ConcatInDim(absl::Span<const XlaOp> operands, int64 dimension);
+  virtual StatusOr<XlaOp> ConcatInDimInternal(const Shape& shape,
+                                              absl::Span<const XlaOp> operands,
+                                              int64 dimension);
 
   void Trace(const string& tag, XlaOp operand);
 
   XlaOp Select(XlaOp pred, XlaOp on_true, XlaOp on_false);
 
   XlaOp Tuple(absl::Span<const XlaOp> elements);
+  virtual StatusOr<XlaOp> TupleInternal(const Shape& shape,
+                                        absl::Span<const XlaOp> elements);
 
   XlaOp GetTupleElement(XlaOp tuple_data, int64 index);
+  virtual StatusOr<XlaOp> GetTupleElementInternal(const Shape& shape,
+                                                  XlaOp tuple_data,
+                                                  int64 index);
 
   XlaOp Dot(XlaOp lhs, XlaOp rhs,
             const PrecisionConfig* precision_config = nullptr);
@@ -476,15 +495,18 @@
             absl::Span<const int64> fft_length);
 
   XlaOp Infeed(const Shape& shape, const string& config = "");
-  XlaOp InfeedWithToken(XlaOp token, const Shape& shape,
-                        const string& config = "");
+  XlaOp InfeedWithToken(XlaOp token, const Shape& shape, const string& config);
+  virtual StatusOr<XlaOp> InfeedWithTokenInternal(
+      const Shape& infeed_instruction_shape, XlaOp token, const string& config);
 
   void Outfeed(XlaOp operand, const Shape& shape_with_layout,
                const string& outfeed_config);
   XlaOp OutfeedWithToken(XlaOp operand, XlaOp token,
                          const Shape& shape_with_layout,
                          const string& outfeed_config);
-
+  virtual StatusOr<XlaOp> OutfeedWithTokenInternal(
+      XlaOp operand, XlaOp token, const Shape& shape_with_layout,
+      const string& outfeed_config);
   XlaOp Call(const XlaComputation& computation,
              absl::Span<const XlaOp> operands);
 
@@ -624,7 +646,7 @@
   XlaOp RecvFromHost(XlaOp token, const Shape& shape,
                      const ChannelHandle& handle);
 
-  XlaOp CreateToken();
+  virtual XlaOp CreateToken();
 
   XlaOp AfterAll(absl::Span<const XlaOp> tokens);
 
@@ -701,10 +723,6 @@
   // Returns the (inferred) result for the program shape using the given root.
   StatusOr<ProgramShape> GetProgramShape(int64 root_id) const;
 
-  // Returns shapes for the operands.
-  StatusOr<std::vector<Shape>> GetOperandShapes(
-      absl::Span<const XlaOp> operands) const;
-
   // A visitor which checks whether an operation is a compile-time constant,
   // meaning that it doesn't depend on any parameters, or on any stateful
   // operation such as `RngNormal` or `Infeed`. The visitor walks the