Add HloCallableInstruction utilities to set the thread name of called computations recursively.

Also add a verifier check that a fusion must have the same thread name for its parent computation and all nested called computations; this doesn't necessarily hold for other callable HLO instructions.

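A usage sketch, mirroring the updated test in hlo_instruction_test.cc (`computation`, `dot`, and `reshape` come from that test's existing setup):

  // Name the parent computation's thread, then propagate the same name to
  // the fusion's called computations so the new verifier check passes.
  computation->SetThreadName("parallel_thread");
  HloInstruction* fusion = computation->CreateFusionInstruction(
      {dot, reshape}, HloInstruction::FusionKind::kLoop);
  fusion->set_called_computations_thread_name(
      "parallel_thread",
      /*skip_async_thread_name_overwrite=*/false);
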
PiperOrigin-RevId: 459684297
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc
index 78b573e..f9a0e11 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.cc
+++ b/tensorflow/compiler/xla/service/hlo_instruction.cc
@@ -4870,6 +4870,13 @@
   Cast<HloAsyncInstruction>(this)->set_async_thread_name(async_thread_name);
 }
 
+void HloInstruction::set_called_computations_thread_name(
+    const std::optional<std::string>& async_thread_name,
+    bool skip_async_thread_name_overwrite) {
+  Cast<HloCallableInstruction>(this)->RecursivelySetComputationsThreadName(
+      async_thread_name, skip_async_thread_name_overwrite);
+}
+
 bool HloInstruction::is_cross_program_prefetch() const {
   return Cast<HloCopyStartInstruction>(this)->is_cross_program_prefetch();
 }
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h
index 1d1110f..fd9b4da 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.h
+++ b/tensorflow/compiler/xla/service/hlo_instruction.h
@@ -2139,6 +2139,12 @@
   void set_async_thread_name(
       const std::optional<std::string>& async_thread_name);
 
+  // Delegates to
+  // HloCallableInstruction::RecursivelySetComputationsThreadName().
+  void set_called_computations_thread_name(
+      const std::optional<std::string>& async_thread_name,
+      bool skip_async_thread_name_overwrite);
+
   // Delegates to HloCopyStartInstruction::is_cross_program_prefetch().
   bool is_cross_program_prefetch() const;
 
diff --git a/tensorflow/compiler/xla/service/hlo_instruction_test.cc b/tensorflow/compiler/xla/service/hlo_instruction_test.cc
index 7ca3a7d..4ea40d2 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_instruction_test.cc
@@ -1798,8 +1798,13 @@
 
   auto module = CreateNewVerifiedModule();
   auto* computation = module->AddEntryComputation(builder.Build());
+  constexpr char kParallelThreadName[] = "parallel_thread";
+  computation->SetThreadName(kParallelThreadName);
   HloInstruction* fusion = computation->CreateFusionInstruction(
       {dot, reshape}, HloInstruction::FusionKind::kLoop);
+  fusion->set_called_computations_thread_name(
+      kParallelThreadName,
+      /*skip_async_thread_name_overwrite=*/false);
 
   const std::string expected_fusion =
       R"(f32[5,20]{1,0} fusion(f32[5,10]{1,0}, f32[20,10]{1,0}), kind=kLoop, calls=
@@ -1808,7 +1813,7 @@
   tmp_1 = f32[20,10]{1,0} parameter(1)
   tmp_2 = f32[10,20]{1,0} transpose(f32[20,10]{1,0} tmp_1), dimensions={1,0}
   ROOT tmp_3 = f32[5,20]{1,0} dot(f32[5,10]{1,0} tmp_0, f32[10,20]{1,0} tmp_2), lhs_contracting_dims={1}, rhs_contracting_dims={0}
-})";
+}, thread_name="parallel_thread")";
   EXPECT_EQ(fusion->ToString(options), expected_fusion);
 }
 
diff --git a/tensorflow/compiler/xla/service/hlo_instructions.cc b/tensorflow/compiler/xla/service/hlo_instructions.cc
index 19825a1..051baf1 100644
--- a/tensorflow/compiler/xla/service/hlo_instructions.cc
+++ b/tensorflow/compiler/xla/service/hlo_instructions.cc
@@ -80,6 +80,28 @@
           }),
       "}");
 }
+
+void SetThreadName(HloComputation* called_computation,
+                   const std::optional<std::string>& thread_name,
+                   bool skip_async_thread_name_overwrite) {
+  called_computation->SetThreadName(thread_name);
+  for (HloInstruction* instr : called_computation->instructions()) {
+    if (instr->IsAsynchronous()) {
+      if (!skip_async_thread_name_overwrite) {
+        // Set the async instruction's thread name; this also recursively
+        // sets the thread name of its wrapped async computations.
+        instr->set_async_thread_name(thread_name);
+      }
+      continue;
+    }
+    for (HloComputation* nested_called_computation :
+         instr->called_computations()) {
+      SetThreadName(nested_called_computation, thread_name,
+                    skip_async_thread_name_overwrite);
+    }
+  }
+}
+
 }  // namespace
 
 HloBatchNormInstruction::HloBatchNormInstruction(
@@ -328,21 +350,8 @@
 void HloAsyncInstruction::set_async_thread_name(
     const std::optional<std::string>& async_thread_name) {
   async_thread_name_ = async_thread_name;
-  // Recursively sets all called computation to have same thread name.
-  std::function<void(HloComputation*, std::optional<std::string>)>
-      set_computation_thread_name =
-          [&](HloComputation* called_computation,
-              std::optional<std::string> async_thread_name) {
-            called_computation->SetThreadName(async_thread_name);
-            for (HloInstruction* instr : called_computation->instructions()) {
-              for (HloComputation* nested_called_computation :
-                   instr->called_computations()) {
-                set_computation_thread_name(nested_called_computation,
-                                            async_thread_name);
-              }
-            }
-          };
-  set_computation_thread_name(async_wrapped_computation(), async_thread_name);
+  SetThreadName(async_wrapped_computation(), async_thread_name,
+                /*skip_async_thread_name_overwrite=*/false);
 }
 
 HloInstructionProto HloAsyncInstruction::ToProto() const {
@@ -1740,6 +1749,14 @@
   return clone;
 }
 
+void HloCallableInstruction::RecursivelySetComputationsThreadName(
+    std::optional<std::string> thread_name,
+    bool skip_async_thread_name_overwrite) {
+  for (HloComputation* comp : called_computations()) {
+    SetThreadName(comp, thread_name, skip_async_thread_name_overwrite);
+  }
+}
+
 HloFusionInstruction::HloFusionInstruction(const Shape& shape,
                                            FusionKind fusion_kind,
                                            HloInstruction* fused_root)
diff --git a/tensorflow/compiler/xla/service/hlo_instructions.h b/tensorflow/compiler/xla/service/hlo_instructions.h
index 309cfd6..fd2224c 100644
--- a/tensorflow/compiler/xla/service/hlo_instructions.h
+++ b/tensorflow/compiler/xla/service/hlo_instructions.h
@@ -983,6 +983,13 @@
 
   HloInstruction* called_computation_root() const;
 
+  // Recursively sets all nested called computations to have the thread name
+  // `thread_name`. If `skip_async_thread_name_overwrite` is true, skips
+  // overwriting the thread name of async instructions and their computations.
+  void RecursivelySetComputationsThreadName(
+      std::optional<std::string> thread_name,
+      bool skip_async_thread_name_overwrite);
+
  protected:
   // Returns the default called computation name.
   virtual std::string default_called_computation_name() const = 0;
diff --git a/tensorflow/compiler/xla/service/hlo_verifier.cc b/tensorflow/compiler/xla/service/hlo_verifier.cc
index 0915450..13b7746 100644
--- a/tensorflow/compiler/xla/service/hlo_verifier.cc
+++ b/tensorflow/compiler/xla/service/hlo_verifier.cc
@@ -24,6 +24,7 @@
 #include "absl/container/flat_hash_set.h"
 #include "absl/strings/str_cat.h"
 #include "absl/strings/str_join.h"
+#include "absl/strings/string_view.h"
 #include "tensorflow/compiler/xla/comparison_util.h"
 #include "tensorflow/compiler/xla/permutation_util.h"
 #include "tensorflow/compiler/xla/primitive_util.h"
@@ -120,6 +121,27 @@
   }
 }
 
+Status CheckNestedComputationThreadNameEqual(const HloComputation* comp,
+                                             bool skip_nested_async_op_check) {
+  std::optional<absl::string_view> thread_name = comp->thread_name();
+  for (const HloInstruction* instr : comp->instructions()) {
+    if (skip_nested_async_op_check && instr->IsAsynchronous()) {
+      continue;
+    }
+    for (const HloComputation* cmp : instr->called_computations()) {
+      if (cmp->thread_name() != thread_name) {
+        return InternalError(
+            "Nested computations expects same computation's thread name (%s vs "
+            "%s).",
+            thread_name ? absl::StrCat(*thread_name) : "none",
+            cmp->thread_name() ? absl::StrCat(*cmp->thread_name()) : "none");
+      }
+      TF_RETURN_IF_ERROR(CheckNestedComputationThreadNameEqual(
+          cmp, skip_nested_async_op_check));
+    }
+  }
+  return Status::OK();
+}
 }  // namespace
 
 Status ShapeVerifier::Preprocess(HloInstruction* hlo) {
@@ -1382,11 +1404,56 @@
   }
   return Status::OK();
 }
+
+Status CheckAsyncOpComputationThreadName(const HloInstruction* async_op) {
+  std::optional<absl::string_view> async_thread_name =
+      async_op->async_thread_name();
+  if (async_thread_name !=
+      async_op->async_wrapped_computation()->thread_name()) {
+    return InternalError(
+        "async-start expects same async thread name as wrapped computation's "
+        "thread name (%s vs %s).",
+        async_thread_name ? absl::StrCat(*async_thread_name) : "none",
+        async_op->async_wrapped_computation()->thread_name()
+            ? absl::StrCat(
+                  *async_op->async_wrapped_computation()->thread_name())
+            : "none");
+  }
+  return CheckNestedComputationThreadNameEqual(
+      async_op->async_wrapped_computation(),
+      /*skip_nested_async_op_check=*/false);
+}
+
+// TODO(b/229887502): apply CheckCallableInstructionThreadName to the
+// verification of all callable instructions.
+Status CheckCallableInstructionThreadName(const HloInstruction* instruction,
+                                          bool skip_nested_async_op_check) {
+  for (const HloComputation* computation : instruction->called_computations()) {
+    if (instruction->parent() != nullptr) {
+      if (instruction->parent()->thread_name() != computation->thread_name()) {
+        return InternalError(
+            "callable instruction %s expects parent computation thread name "
+            "same as called computation's thread name (%s vs %s).",
+            instruction->ToString(),
+            instruction->parent()->thread_name()
+                ? absl::StrCat(*instruction->parent()->thread_name())
+                : "none",
+            computation->thread_name()
+                ? absl::StrCat(*computation->thread_name())
+                : "none");
+      }
+    }
+    TF_RETURN_IF_ERROR(CheckNestedComputationThreadNameEqual(
+        computation, skip_nested_async_op_check));
+  }
+  return Status::OK();
+}
 }  // namespace
 
 Status ShapeVerifier::HandleAsyncStart(HloInstruction* async_start) {
   TF_RETURN_IF_ERROR(
       CheckAsyncOpComputationShapes(async_start, async_start->shape()));
+  TF_RETURN_IF_ERROR(CheckAsyncOpComputationThreadName(async_start));
   const Shape& param_shape = async_start->shape().tuple_shapes(0);
   for (int i = 0; i < async_start->operand_count(); ++i) {
     if (param_shape.tuple_shapes(i) != async_start->operand(i)->shape()) {
@@ -1402,6 +1469,7 @@
 }
 
 Status ShapeVerifier::HandleAsyncUpdate(HloInstruction* async_update) {
+  TF_RETURN_IF_ERROR(CheckAsyncOpComputationThreadName(async_update));
   if (async_update->operand(0)->shape() != async_update->shape()) {
     return InternalError(
         "The %s expects the shape of operand and output to match (%s vs %s).",
@@ -1415,6 +1483,7 @@
 }
 
 Status ShapeVerifier::HandleAsyncDone(HloInstruction* async_done) {
+  TF_RETURN_IF_ERROR(CheckAsyncOpComputationThreadName(async_done));
   TF_RETURN_IF_ERROR(CheckAsyncOpComputationShapes(
       async_done, async_done->operand(0)->shape()));
   const Shape& root_shape = async_done->operand(0)->shape().tuple_shapes(1);
@@ -2294,6 +2363,8 @@
   Status DefaultAction(HloInstruction*) override { return OkStatus(); }
 
   Status HandleFusion(HloInstruction* fusion) override {
+    TF_RETURN_IF_ERROR(CheckCallableInstructionThreadName(
+        fusion, /*skip_nested_async_op_check=*/false));
     return CheckFusionInstruction(fusion);
   }
 
diff --git a/tensorflow/compiler/xla/service/hlo_verifier_test.cc b/tensorflow/compiler/xla/service/hlo_verifier_test.cc
index 21e4b83..c00b39b 100644
--- a/tensorflow/compiler/xla/service/hlo_verifier_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_verifier_test.cc
@@ -1863,6 +1863,56 @@
               HasSubstr("Fused computation shape"));
 }
 
+TEST_F(HloVerifierTest, FusionThreadVerifier) {
+  const char* const kModuleStr = R"(
+  HloModule test
+
+  fused_computation {
+    ROOT p0 = f32[8,12] parameter(0)
+  }, thread_name="parallel_thread"
+
+  ENTRY entry {
+    p0 = f32[8,12] parameter(0)
+    ROOT out = f32[8,12] fusion(p0), kind=kInput, calls=fused_computation
+  }
+  )";
+  TF_ASSERT_OK_AND_ASSIGN(auto module,
+                          ParseAndReturnUnverifiedModule(kModuleStr));
+  EXPECT_THAT(verifier().Run(module.get()).status().error_message(),
+              HasSubstr("expects parent computation thread name same as called "
+                        "computation's thread name"));
+}
+
+TEST_F(HloVerifierTest, FusionNestedComputationThreadVerifier) {
+  const char* const kModuleStr = R"(
+  HloModule test
+
+  add {
+    lhs = f32[] parameter(0)
+    rhs = f32[] parameter(1)
+    ROOT add = f32[] add(lhs, rhs)
+  }, thread_name="parallel_thread"
+
+  fused_computation {
+    p0 = f32[8,12] parameter(0)
+    p1 = f32[8,12] parameter(1)
+    crs0 = f32[8,12] all-reduce(p1), replica_groups={}, to_apply=add
+    ROOT result = add(p0, crs0)
+  }
+
+  ENTRY entry {
+    p0 = f32[8,12] parameter(0)
+    p1 = f32[8,12] parameter(1)
+    ROOT out = f32[8,12] fusion(p0, p1), kind=kInput, calls=fused_computation
+  }
+  )";
+  TF_ASSERT_OK_AND_ASSIGN(auto module,
+                          ParseAndReturnUnverifiedModule(kModuleStr));
+  EXPECT_THAT(
+      verifier().Run(module.get()).status().error_message(),
+      HasSubstr("Nested computations expects same computation's thread name"));
+}
+
 TEST_F(HloVerifierTest, AllReduceVerifier) {
   const char* const kModuleStr = R"(
   HloModule test