[XLA] Avoid materializing large literals in creation utils.
PiperOrigin-RevId: 292976673
Change-Id: I2c1fa4bb96434f085f293dd5d4ff26d67a4236b5
diff --git a/tensorflow/compiler/xla/service/hlo_creation_utils.cc b/tensorflow/compiler/xla/service/hlo_creation_utils.cc
index c151fcb..f64434a 100644
--- a/tensorflow/compiler/xla/service/hlo_creation_utils.cc
+++ b/tensorflow/compiler/xla/service/hlo_creation_utils.cc
@@ -496,6 +496,22 @@
/*result_shape_bounds=*/broadcast_dimensions);
}
+// Recursively adds a placeholder ("dummy") instruction of `shape` to builder `b`
+// and returns it: array shapes become a broadcast of zero, tuples a tuple of dummies.
+HloInstruction* CreateDummyOp(HloComputation::Builder* b, const Shape& shape) {
+ if (shape.IsArray()) {
+ auto zero = b->AddInstruction(HloInstruction::CreateConstant(
+ LiteralUtil::Zero(shape.element_type())));  // zero constant built from the element type alone
+ return b->AddInstruction(HloInstruction::CreateBroadcast(shape, zero, {}));  // empty broadcast-dimensions list
+ }
+ CHECK(shape.IsTuple());  // only array and tuple shapes are handled; any other shape fails this CHECK
+ std::vector<HloInstruction*> sub_instructions;
+ for (const Shape& subshape : shape.tuple_shapes()) {
+ sub_instructions.push_back(CreateDummyOp(b, subshape));  // one dummy per tuple element, in element order
+ }
+ return b->AddInstruction(HloInstruction::CreateTuple(sub_instructions));
+}
+
StatusOr<std::unique_ptr<HloComputation>> CreateComputationWithSignature(
absl::Span<const Shape* const> domain, const Shape& range,
absl::string_view name) {
@@ -508,12 +524,9 @@
}
// We can't change the root type of a computation once it is created so create
- // a dummy root instruction to give the computation the right root shape. In
- // the future we may want to use a (recursive) broadcast here to avoid
- // creating large constants.
- b.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateFromShape(range)));
-
+ // a dummy root instruction to give the computation the right root shape. Use
+ // a (recursive) broadcast here to avoid creating large constants.
+ CreateDummyOp(&b, range);
return b.Build();
}