[XLA] Avoid materializing large literals in creation utils.

PiperOrigin-RevId: 292976673
Change-Id: I2c1fa4bb96434f085f293dd5d4ff26d67a4236b5
diff --git a/tensorflow/compiler/xla/service/hlo_creation_utils.cc b/tensorflow/compiler/xla/service/hlo_creation_utils.cc
index c151fcb..f64434a 100644
--- a/tensorflow/compiler/xla/service/hlo_creation_utils.cc
+++ b/tensorflow/compiler/xla/service/hlo_creation_utils.cc
@@ -496,6 +496,22 @@
                           /*result_shape_bounds=*/broadcast_dimensions);
 }
 
+// Recursively creates a dummy op given a shape. Leaf nodes are broadcasted zero
+// while internal nodes are tuples.
+HloInstruction* CreateDummyOp(HloComputation::Builder* b, const Shape& shape) {
+  if (shape.IsArray()) {
+    auto zero = b->AddInstruction(HloInstruction::CreateConstant(
+        LiteralUtil::Zero(shape.element_type())));
+    return b->AddInstruction(HloInstruction::CreateBroadcast(shape, zero, {}));
+  }
+  CHECK(shape.IsTuple());
+  std::vector<HloInstruction*> sub_instructions;
+  for (const Shape& subshape : shape.tuple_shapes()) {
+    sub_instructions.push_back(CreateDummyOp(b, subshape));
+  }
+  return b->AddInstruction(HloInstruction::CreateTuple(sub_instructions));
+}
+
 StatusOr<std::unique_ptr<HloComputation>> CreateComputationWithSignature(
     absl::Span<const Shape* const> domain, const Shape& range,
     absl::string_view name) {
@@ -508,12 +524,9 @@
   }
 
   // We can't change the root type of a computation once it is created so create
-  // a dummy root instruction to give the computation the right root shape.  In
-  // the future we may want to use a (recursive) broadcast here to avoid
-  // creating large constants.
-  b.AddInstruction(
-      HloInstruction::CreateConstant(Literal::CreateFromShape(range)));
-
+  // a dummy root instruction to give the computation the right root shape.  Use
+  // a (recursive) broadcast here to avoid creating large constants.
+  CreateDummyOp(&b, range);
   return b.Build();
 }