[XLA] Handle Literal property and custom-call API version in HloInstruction::Identical.

Previously we ignored these properties, which would lead us to incorrectly CSE
custom-calls.

Also add tests for these and the other properties that might appear on a
custom-call instruction.

PiperOrigin-RevId: 438620723
diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD
index bfd1014..feb973b 100644
--- a/tensorflow/compiler/xla/service/BUILD
+++ b/tensorflow/compiler/xla/service/BUILD
@@ -4705,6 +4705,7 @@
         "//tensorflow/compiler/xla/tests:xla_internal_test_main",
         "//tensorflow/core:lib",
         "@com_google_absl//absl/memory",
+        "@com_google_absl//absl/strings",
     ],
 )
 
diff --git a/tensorflow/compiler/xla/service/hlo_cse_test.cc b/tensorflow/compiler/xla/service/hlo_cse_test.cc
index 716d4d2..4fb080e 100644
--- a/tensorflow/compiler/xla/service/hlo_cse_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_cse_test.cc
@@ -21,6 +21,7 @@
 #include <vector>
 
 #include "absl/memory/memory.h"
+#include "absl/strings/substitute.h"
 #include "tensorflow/compiler/xla/layout_util.h"
 #include "tensorflow/compiler/xla/literal.h"
 #include "tensorflow/compiler/xla/service/hlo_computation.h"
@@ -760,5 +761,142 @@
   EXPECT_FALSE(changed);
 }
 
+class HloCseCustomCallTest
+    : public HloCseTest,
+      public ::testing::WithParamInterface<std::tuple<
+          std::string /*op1*/, std::string /*op2*/, bool /*should_cse*/>> {};
+
+TEST_P(HloCseCustomCallTest, DoIt) {
+  std::string op1 = std::get<0>(GetParam());
+  std::string op2 = std::get<1>(GetParam());
+  bool should_cse = std::get<2>(GetParam());
+
+  const char* const hlo_string_tmpl = R"(
+    HloModule m
+    ENTRY entry {
+      p0 = f32[1,1,1] parameter(0)
+
+      op0 = $0
+      op1 = $0
+      op2 = $1
+      ROOT root = tuple(op0, op1, op2)
+    }
+  )";
+  std::string hlo_string = absl::Substitute(hlo_string_tmpl, op1, op2);
+  SCOPED_TRACE(absl::StrCat("Module before CSE:\n", hlo_string));
+
+  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(hlo_string));
+  HloCSE cse(/*is_layout_sensitive=*/false);
+  TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloPass(&cse, m.get()));
+
+  SCOPED_TRACE(absl::StrCat("Module after CSE:\n", m->ToString()));
+  EXPECT_EQ(changed, true);  // we always CSE op0 and op1, which are identical.
+  HloInstruction* root = m->entry_computation()->root_instruction();
+  EXPECT_EQ(root->operand(0), root->operand(1))
+      << "Identical ops should be CSE'ed";
+  if (should_cse) {
+    EXPECT_EQ(root->operand(0), root->operand(2)) << "Ops should be CSE'ed";
+  } else {
+    EXPECT_NE(root->operand(0), root->operand(2)) << "Ops should not be CSE'ed";
+  }
+}
+
+static std::vector<
+    std::tuple<std::string /*op1*/, std::string /*op2*/, bool /*should_cse*/>>
+CustomCallTests() {
+  auto build = [](absl::string_view args1, absl::string_view args2) {
+    absl::string_view prefix =
+        "f32[] custom-call(p0), custom_call_target=\"foo\", ";
+    return std::make_tuple(absl::StrCat(prefix, args1),
+                           absl::StrCat(prefix, args2), false);
+  };
+  return {
+      {
+          // metadata shouldn't prevent CSE
+          "f32[] custom-call(p0), custom_call_target=\"foo\"",
+          "f32[] custom-call(p0), custom_call_target=\"foo\", "
+          "metadata={op_name=\"bar\"}",
+          true,
+      },
+      {
+          "f32[] custom-call(p0), custom_call_target=\"foo\"",
+          "f32[] custom-call(p0, p0), custom_call_target=\"foo\"",
+          false,
+      },
+      {
+          "f32[1] custom-call(p0), custom_call_target=\"foo\"",
+          "f32[2] custom-call(p0), custom_call_target=\"foo\"",
+          false,
+      },
+      {
+          "f32[] custom-call(p0), custom_call_target=\"foo\"",
+          "f32[] custom-call(p0), custom_call_target=\"bar\"",
+          false,
+      },
+
+      build("window={size=1}", "window={size=2}"),
+      build("dim_labels=b0f_0oi->b0f", "dim_labels=b0f_0oi->bf0"),
+      build("backend_config=\"foo\"", "backend_config=\"bar\""),
+      build("literal=s32[] 0", "literal=s32[] 1"),
+      build("literal=s32[] 0", "literal=f32[] 0"),
+      build("operand_precision={high,default}",
+            "operand_precision={high, high}"),
+      build("api_version=API_VERSION_STATUS_RETURNING",
+            "api_version=API_VERSION_ORIGINAL"),
+      build("feature_group_count=0", "feature_group_count=1"),
+  };
+}
+
+INSTANTIATE_TEST_SUITE_P(HloCseCustomCallTestSuite, HloCseCustomCallTest,
+                         ::testing::ValuesIn(CustomCallTests()));
+
+TEST_F(HloCseTest, CustomCallCalledComputations) {
+  const char* const hlo_string = R"(
+    HloModule m
+
+    comp {
+      lhs = f32[] parameter(0)
+      rhs = f32[] parameter(1)
+      ROOT maximum = f32[] maximum(lhs, rhs)
+    }
+
+    ENTRY entry {
+      p0 = f32[] parameter(0)
+
+      op0 = f32[] custom-call(p0), custom_call_target="foo", called_computations={comp}
+      op1 = f32[] custom-call(p0), custom_call_target="foo", called_computations={comp, comp}
+      ROOT root = tuple(op0, op1)
+    }
+  )";
+
+  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(hlo_string));
+  HloCSE cse(/*is_layout_sensitive=*/false);
+  TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloPass(&cse, m.get()));
+
+  SCOPED_TRACE(absl::StrCat("Module after CSE:\n", m->ToString()));
+  EXPECT_EQ(changed, false);
+}
+
+TEST_F(HloCseTest, CustomCallSideEffects) {
+  const char* const hlo_string = R"(
+    HloModule m
+
+    ENTRY entry {
+      p0 = f32[] parameter(0)
+
+      op0 = f32[] custom-call(p0), custom_call_target="foo", custom_call_has_side_effect=true
+      op1 = f32[] custom-call(p0), custom_call_target="foo", custom_call_has_side_effect=true
+      ROOT root = tuple(op0, op1)
+    }
+  )";
+
+  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(hlo_string));
+  HloCSE cse(/*is_layout_sensitive=*/false);
+  TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloPass(&cse, m.get()));
+
+  SCOPED_TRACE(absl::StrCat("Module after CSE:\n", m->ToString()));
+  EXPECT_EQ(changed, false);
+}
+
 }  // namespace
 }  // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_instructions.cc b/tensorflow/compiler/xla/service/hlo_instructions.cc
index e60dc69..f25c4b3 100644
--- a/tensorflow/compiler/xla/service/hlo_instructions.cc
+++ b/tensorflow/compiler/xla/service/hlo_instructions.cc
@@ -2760,12 +2760,14 @@
   if (custom_call_schedule_ != casted_other.custom_call_schedule()) {
     return false;
   }
-  if (HasLiteral() == casted_other.HasLiteral()) {
-    if (HasLiteral() && literal() == casted_other.literal()) {
-      return false;
-    }
-  } else {
-    return true;
+  if (HasLiteral() != casted_other.HasLiteral()) {
+    return false;
+  }
+  if (HasLiteral() && literal() != casted_other.literal()) {
+    return false;
+  }
+  if (api_version_ != casted_other.api_version_) {
+    return false;
   }
   // Note: backend_config comparison is done in Identical, which is the
   // intended/exposed way to compare computations, and so not repeated here.