[XLA] Make HloCostAnalysis account for nested shapes to calculate bytes accessed in fusion, infeed and outfeed

PiperOrigin-RevId: 356785877
Change-Id: Ic482891d04b7dfad63be11d01c3ecc7dae69916a
diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc
index 4ed89c4..ecf3f8d 100644
--- a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc
+++ b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc
@@ -146,7 +146,8 @@
     const HloInstruction* hlo) const {
   int64 size = 0;
   bool seen_trivial_user = false;
-  CHECK(hlo->IsFused() && hlo->opcode() == HloOpcode::kParameter);
+  CHECK(hlo->IsFused() && (hlo->opcode() == HloOpcode::kParameter ||
+                           hlo->opcode() == HloOpcode::kGetTupleElement));
   for (const HloInstruction* user : hlo->users()) {
     switch (user->opcode()) {
       case HloOpcode::kFusion: {
@@ -335,11 +336,34 @@
   return Status::OK();
 }
 
-Status HloCostAnalysis::HandleInfeed(const HloInstruction*) {
+Status HloCostAnalysis::HandleInfeed(const HloInstruction* infeed) {
+  // Count nested infeed output tuples.
+  int64 size = 0;
+  for (const auto& indexed_shape : ShapeUtil::GetLeafShapes(infeed->shape())) {
+    size += GetShapeSize(indexed_shape.shape);
+    SetOutputBytesAccessed(indexed_shape.index,
+                           GetShapeSize(indexed_shape.shape));
+  }
+  SetOutputBytesAccessed(size);
+  current_properties_[kBytesAccessedKey] = size;
   return Status::OK();
 }
 
-Status HloCostAnalysis::HandleOutfeed(const HloInstruction*) {
+Status HloCostAnalysis::HandleOutfeed(const HloInstruction* outfeed) {
+  // Count nested outfeed operand tuples.
+  current_properties_[kBytesAccessedKey] = 0;
+  for (int64 i = 0; i < outfeed->operand_count(); ++i) {
+    const HloInstruction* operand = outfeed->operand(i);
+    int64 size = 0;
+    for (const auto& indexed_shape :
+         ShapeUtil::GetLeafShapes(operand->shape())) {
+      size += GetShapeSize(indexed_shape.shape);
+      SetOperandBytesAccessed(i, indexed_shape.index,
+                              GetShapeSize(indexed_shape.shape));
+    }
+    SetOperandBytesAccessed(i, size);
+    current_properties_[kBytesAccessedKey] += size;
+  }
   return Status::OK();
 }
 
@@ -872,9 +896,31 @@
 
   for (int64 i = 0; i < fusion->fused_parameters().size(); ++i) {
     const HloInstruction* operand = fusion->fused_parameter(i);
-    int64 size = FusionParameterReadBytes(operand);
-    current_properties_[kBytesAccessedKey] += size;
-    SetOperandBytesAccessed(i, size);
+    int64 operand_size = 0;
+    if (!fusion->shape().IsTuple()) {
+      operand_size = FusionParameterReadBytes(operand);
+    } else {
+      // If the fusion parameter is a tuple type, find the gte for the leaf
+      // shape and calculate the bytes accessed for those array types.
+      for (const auto& indexed_shape :
+           ShapeUtil::GetLeafShapes(operand->shape())) {
+        const HloInstruction* gte = operand;
+        for (int64 index : indexed_shape.index) {
+          for (const HloInstruction* user : gte->users()) {
+            if (user->opcode() == HloOpcode::kGetTupleElement &&
+                user->tuple_index() == index) {
+              gte = user;
+              break;
+            }
+          }
+        }
+        int64 size = FusionParameterReadBytes(gte);
+        operand_size += size;
+        SetOperandBytesAccessed(i, indexed_shape.index, size);
+      }
+    }
+    current_properties_[kBytesAccessedKey] += operand_size;
+    SetOperandBytesAccessed(i, operand_size);
   }
 
   return Status::OK();
diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc
index 748eb40..dd9cc41 100644
--- a/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc
@@ -693,10 +693,10 @@
   EXPECT_EQ(fusion_analysis.flop_count(), 16);
   EXPECT_EQ(fusion_analysis.transcendental_count(), 4);
   EXPECT_EQ(fusion_analysis.bytes_accessed(*fusion),
-            sizeof(float) * (3 + 5) * 2 * 2 + kPointerSize * 2);
+            sizeof(float) * (5 + 5) * 2 * 2);
 
   EXPECT_EQ(fusion_analysis.operand_bytes_accessed(*fusion, 0),
-            kPointerSize * 2);
+            sizeof(float) * 2 * 2 * 2);
   EXPECT_EQ(fusion_analysis.operand_bytes_accessed(*fusion, 1),
             sizeof(float) * 2 * 2);
   EXPECT_EQ(fusion_analysis.operand_bytes_accessed(*fusion, 2),
@@ -758,6 +758,78 @@
             sizeof(float) * 2 * 3 * 4 * 5);
 }
 
+TEST_F(FusionCostAnalysis, TupleBytesAccessed) {
+  absl::string_view hlo_string = R"(
+HloModule module, is_scheduled=true
+
+fused_computation {
+  param = (f32[2,2]{1,0}, f32[2,2]{1,0}) parameter(0)
+  gte0 = f32[2,2]{1,0} get-tuple-element(param), index=0
+  gte1 = f32[2,2]{1,0} get-tuple-element(param), index=1
+  add = f32[2,2]{1,0} add(gte0, gte1)
+  mul = f32[2,2]{1,0} multiply(gte0, gte1)
+  ROOT root = (f32[2,2]{1,0}, f32[2,2]{1,0}) tuple(add, mul)
+}
+
+ENTRY entry {
+  param0 = f32[2,2]{1,0} parameter(0)
+  param1 = f32[2,2]{1,0} parameter(1)
+  tuple = (f32[2,2]{1,0}, f32[2,2]{1,0}) tuple(param0, param1)
+  ROOT fusion = (f32[2,2]{1,0}, f32[2,2]{1,0}) fusion(tuple), kind=kLoop, calls=fused_computation
+}
+)";
+  TF_ASSERT_OK_AND_ASSIGN(auto module,
+                          ParseAndReturnVerifiedModule(hlo_string));
+
+  HloInstruction* fusion = module->entry_computation()->root_instruction();
+
+  HloCostAnalysis fusion_analysis(ShapeSize);
+  ASSERT_IS_OK(fusion->Accept(&fusion_analysis));
+
+  EXPECT_EQ(fusion_analysis.bytes_accessed(*fusion), sizeof(float) * 2 * 2 * 4);
+  EXPECT_EQ(fusion_analysis.operand_bytes_accessed(*fusion, 0),
+            sizeof(float) * 2 * 2 * 2);
+  EXPECT_EQ(fusion_analysis.output_bytes_accessed(*fusion),
+            sizeof(float) * 2 * 2 * 2);
+}
+
+TEST_F(FusionCostAnalysis, InfeedOutfeed) {
+  absl::string_view hlo_string = R"(
+HloModule module, is_scheduled=true
+
+ENTRY entry {
+  after-all = token[] after-all()
+  infeed = ((f32[2,3]{1,0}), token[]) infeed(after-all)
+  gte0 = (f32[2,3]{1,0}) get-tuple-element(infeed), index=0
+  gte1 = f32[2,3]{1,0} get-tuple-element(gte0), index=0
+  add = f32[2,3]{1,0} add(gte1, gte1)
+  tuple = (f32[2,3]{1,0}) tuple(add)
+  tok = token[] get-tuple-element(infeed), index=1
+  ROOT outfeed = token[] outfeed(tuple, tok)
+}
+)";
+  TF_ASSERT_OK_AND_ASSIGN(auto module,
+                          ParseAndReturnVerifiedModule(hlo_string));
+
+  HloInstruction* infeed =
+      module->entry_computation()->GetInstructionWithName("infeed");
+  HloInstruction* outfeed =
+      module->entry_computation()->GetInstructionWithName("outfeed");
+
+  HloCostAnalysis analysis(ShapeSize);
+  ASSERT_IS_OK(infeed->Accept(&analysis));
+  ASSERT_IS_OK(outfeed->Accept(&analysis));
+
+  EXPECT_EQ(analysis.bytes_accessed(*infeed), sizeof(float) * 2 * 3);
+  EXPECT_EQ(analysis.operand_bytes_accessed(*infeed, 0), 0);
+  EXPECT_EQ(analysis.output_bytes_accessed(*infeed), sizeof(float) * 2 * 3);
+
+  EXPECT_EQ(analysis.bytes_accessed(*outfeed), sizeof(float) * 2 * 3);
+  EXPECT_EQ(analysis.operand_bytes_accessed(*outfeed, 0),
+            sizeof(float) * 2 * 3);
+  EXPECT_EQ(analysis.output_bytes_accessed(*outfeed), 0);
+}
+
 TEST_F(HloCostAnalysisTest, TupleCost) {
   HloCostAnalysis analysis(ShapeSize);