[XLA] Make HloCostAnalysis account for nested shapes to calculate bytes accessed in fusion, infeed and outfeed
PiperOrigin-RevId: 356785877
Change-Id: Ic482891d04b7dfad63be11d01c3ecc7dae69916a
diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc
index 4ed89c4..ecf3f8d 100644
--- a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc
+++ b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc
@@ -146,7 +146,8 @@
const HloInstruction* hlo) const {
int64 size = 0;
bool seen_trivial_user = false;
- CHECK(hlo->IsFused() && hlo->opcode() == HloOpcode::kParameter);
+ CHECK(hlo->IsFused() && (hlo->opcode() == HloOpcode::kParameter ||
+ hlo->opcode() == HloOpcode::kGetTupleElement));
for (const HloInstruction* user : hlo->users()) {
switch (user->opcode()) {
case HloOpcode::kFusion: {
@@ -335,11 +336,34 @@
return Status::OK();
}
-Status HloCostAnalysis::HandleInfeed(const HloInstruction*) {
+Status HloCostAnalysis::HandleInfeed(const HloInstruction* infeed) {
+ // Accumulate every leaf of the (possibly nested) infeed result shape.
+ int64 total_bytes = 0;
+ for (const auto& leaf : ShapeUtil::GetLeafShapes(infeed->shape())) {
+ const int64 leaf_bytes = GetShapeSize(leaf.shape);
+ SetOutputBytesAccessed(leaf.index, leaf_bytes);
+ total_bytes += leaf_bytes;
+ }
+ SetOutputBytesAccessed(total_bytes);
+ current_properties_[kBytesAccessedKey] = total_bytes;
return Status::OK();
}
-Status HloCostAnalysis::HandleOutfeed(const HloInstruction*) {
+Status HloCostAnalysis::HandleOutfeed(const HloInstruction* outfeed) {
+ // Accumulate every leaf of each (possibly nested) outfeed operand shape,
+ // recording per-leaf, per-operand and overall byte counts.
+ current_properties_[kBytesAccessedKey] = 0;
+ for (int64 op = 0; op < outfeed->operand_count(); ++op) {
+ const auto& operand_shape = outfeed->operand(op)->shape();
+ int64 operand_bytes = 0;
+ for (const auto& leaf : ShapeUtil::GetLeafShapes(operand_shape)) {
+ const int64 leaf_bytes = GetShapeSize(leaf.shape);
+ SetOperandBytesAccessed(op, leaf.index, leaf_bytes);
+ operand_bytes += leaf_bytes;
+ }
+ SetOperandBytesAccessed(op, operand_bytes);
+ current_properties_[kBytesAccessedKey] += operand_bytes;
+ }
return Status::OK();
}
@@ -872,9 +896,31 @@
for (int64 i = 0; i < fusion->fused_parameters().size(); ++i) {
const HloInstruction* operand = fusion->fused_parameter(i);
- int64 size = FusionParameterReadBytes(operand);
- current_properties_[kBytesAccessedKey] += size;
- SetOperandBytesAccessed(i, size);
+ int64 operand_size = 0;
+ if (!operand->shape().IsTuple()) {
+ operand_size = FusionParameterReadBytes(operand);
+ } else {
+ // The parameter (not the fusion result) is a tuple: walk the gte chain to
+ // each leaf array and count the bytes read from that leaf.
+ for (const auto& indexed_shape :
+ ShapeUtil::GetLeafShapes(operand->shape())) {
+ const HloInstruction* gte = operand;
+ for (int64 index : indexed_shape.index) {
+ for (const HloInstruction* user : gte->users()) {
+ if (user->opcode() == HloOpcode::kGetTupleElement &&
+ user->tuple_index() == index) {
+ gte = user;
+ break;
+ }
+ }
+ }
+ int64 size = FusionParameterReadBytes(gte);
+ operand_size += size;
+ SetOperandBytesAccessed(i, indexed_shape.index, size);
+ }
+ }
+ current_properties_[kBytesAccessedKey] += operand_size;
+ SetOperandBytesAccessed(i, operand_size);
}
return Status::OK();
diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc
index 748eb40..dd9cc41 100644
--- a/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc
@@ -693,10 +693,10 @@
EXPECT_EQ(fusion_analysis.flop_count(), 16);
EXPECT_EQ(fusion_analysis.transcendental_count(), 4);
EXPECT_EQ(fusion_analysis.bytes_accessed(*fusion),
- sizeof(float) * (3 + 5) * 2 * 2 + kPointerSize * 2);
+ sizeof(float) * (5 + 5) * 2 * 2);
EXPECT_EQ(fusion_analysis.operand_bytes_accessed(*fusion, 0),
- kPointerSize * 2);
+ sizeof(float) * 2 * 2 * 2);
EXPECT_EQ(fusion_analysis.operand_bytes_accessed(*fusion, 1),
sizeof(float) * 2 * 2);
EXPECT_EQ(fusion_analysis.operand_bytes_accessed(*fusion, 2),
@@ -758,6 +758,78 @@
sizeof(float) * 2 * 3 * 4 * 5);
}
+TEST_F(FusionCostAnalysis, TupleBytesAccessed) {
+ absl::string_view hlo_string = R"(
+HloModule module, is_scheduled=true
+
+fused_computation {
+ param = (f32[2,2]{1,0}, f32[2,2]{1,0}) parameter(0)
+ gte0 = f32[2,2]{1,0} get-tuple-element(param), index=0
+ gte1 = f32[2,2]{1,0} get-tuple-element(param), index=1
+ add = f32[2,2]{1,0} add(gte0, gte1)
+ mul = f32[2,2]{1,0} multiply(gte0, gte1)
+ ROOT root = (f32[2,2]{1,0}, f32[2,2]{1,0}) tuple(add, mul)
+}
+
+ENTRY entry {
+ param0 = f32[2,2]{1,0} parameter(0)
+ param1 = f32[2,2]{1,0} parameter(1)
+ tuple = (f32[2,2]{1,0}, f32[2,2]{1,0}) tuple(param0, param1)
+ ROOT fusion = (f32[2,2]{1,0}, f32[2,2]{1,0}) fusion(tuple), kind=kLoop, calls=fused_computation
+}
+)";
+ TF_ASSERT_OK_AND_ASSIGN(auto module,
+ ParseAndReturnVerifiedModule(hlo_string));
+ // The entry root is the fusion that takes a tuple and returns a tuple.
+ HloInstruction* fusion = module->entry_computation()->root_instruction();
+ // Visit the fusion with an HloCostAnalysis built from ShapeSize.
+ HloCostAnalysis fusion_analysis(ShapeSize);
+ ASSERT_IS_OK(fusion->Accept(&fusion_analysis));
+ // Expect two f32[2,2] leaves as operand bytes and two as output bytes.
+ EXPECT_EQ(fusion_analysis.bytes_accessed(*fusion), sizeof(float) * 2 * 2 * 4);
+ EXPECT_EQ(fusion_analysis.operand_bytes_accessed(*fusion, 0),
+ sizeof(float) * 2 * 2 * 2);
+ EXPECT_EQ(fusion_analysis.output_bytes_accessed(*fusion),
+ sizeof(float) * 2 * 2 * 2);
+}
+
+TEST_F(FusionCostAnalysis, InfeedOutfeed) {
+ absl::string_view hlo_string = R"(
+HloModule module, is_scheduled=true
+
+ENTRY entry {
+ after-all = token[] after-all()
+ infeed = ((f32[2,3]{1,0}), token[]) infeed(after-all)
+ gte0 = (f32[2,3]{1,0}) get-tuple-element(infeed), index=0
+ gte1 = f32[2,3]{1,0} get-tuple-element(gte0), index=0
+ add = f32[2,3]{1,0} add(gte1, gte1)
+ tuple = (f32[2,3]{1,0}) tuple(add)
+ tok = token[] get-tuple-element(infeed), index=1
+ ROOT outfeed = token[] outfeed(tuple, tok)
+}
+)";
+ TF_ASSERT_OK_AND_ASSIGN(auto module,
+ ParseAndReturnVerifiedModule(hlo_string));
+ // Look up the infeed and outfeed instructions by name in the entry.
+ HloInstruction* infeed =
+ module->entry_computation()->GetInstructionWithName("infeed");
+ HloInstruction* outfeed =
+ module->entry_computation()->GetInstructionWithName("outfeed");
+ // Visit both instructions with the same HloCostAnalysis instance.
+ HloCostAnalysis analysis(ShapeSize);
+ ASSERT_IS_OK(infeed->Accept(&analysis));
+ ASSERT_IS_OK(outfeed->Accept(&analysis));
+ // Infeed: one f32[2,3] worth of output bytes and zero bytes for operand 0.
+ EXPECT_EQ(analysis.bytes_accessed(*infeed), sizeof(float) * 2 * 3);
+ EXPECT_EQ(analysis.operand_bytes_accessed(*infeed, 0), 0);
+ EXPECT_EQ(analysis.output_bytes_accessed(*infeed), sizeof(float) * 2 * 3);
+ // Outfeed: one f32[2,3] worth of bytes for operand 0 and zero output bytes.
+ EXPECT_EQ(analysis.bytes_accessed(*outfeed), sizeof(float) * 2 * 3);
+ EXPECT_EQ(analysis.operand_bytes_accessed(*outfeed, 0),
+ sizeof(float) * 2 * 3);
+ EXPECT_EQ(analysis.output_bytes_accessed(*outfeed), 0);
+}
+
TEST_F(HloCostAnalysisTest, TupleCost) {
HloCostAnalysis analysis(ShapeSize);