[XLA] Handle Literal property and custom-call API version in HloInstruction::Identical.
Previously the literal comparison was inverted and the API version was ignored,
which would lead us to incorrectly CSE custom-calls.
Also add tests for these and the other properties that might appear on a
custom-call instruction.
PiperOrigin-RevId: 438620723
diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD
index bfd1014..feb973b 100644
--- a/tensorflow/compiler/xla/service/BUILD
+++ b/tensorflow/compiler/xla/service/BUILD
@@ -4705,6 +4705,7 @@
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:lib",
"@com_google_absl//absl/memory",
+ "@com_google_absl//absl/strings",
],
)
diff --git a/tensorflow/compiler/xla/service/hlo_cse_test.cc b/tensorflow/compiler/xla/service/hlo_cse_test.cc
index 716d4d2..4fb080e 100644
--- a/tensorflow/compiler/xla/service/hlo_cse_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_cse_test.cc
@@ -21,6 +21,7 @@
#include <vector>
#include "absl/memory/memory.h"
+#include "absl/strings/substitute.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
@@ -760,5 +761,142 @@
EXPECT_FALSE(changed);
}
+// Parameterized fixture for the custom-call CSE test below. Each parameter is
+// a tuple (op1, op2, should_cse): the HLO text of two instructions plus a flag
+// saying whether the CSE pass is expected to merge them.
+class HloCseCustomCallTest
+ : public HloCseTest,
+ public ::testing::WithParamInterface<std::tuple<
+ std::string /*op1*/, std::string /*op2*/, bool /*should_cse*/>> {};
+
+// Builds a module containing two copies of `op1` and one copy of `op2`, runs
+// CSE on it, and checks that the two copies of `op1` are always merged while
+// `op2` is merged with them only when `should_cse` is true.
+TEST_P(HloCseCustomCallTest, DoIt) {
+  // Bind to the stored parameter tuple instead of copying the HLO strings.
+  const auto& param = GetParam();
+  const std::string& op1 = std::get<0>(param);
+  const std::string& op2 = std::get<1>(param);
+  const bool should_cse = std::get<2>(param);
+
+  // $0 is substituted with op1 (twice) and $1 with op2.
+  const char* const hlo_string_tmpl = R"(
+ HloModule m
+ ENTRY entry {
+ p0 = f32[1,1,1] parameter(0)
+
+ op0 = $0
+ op1 = $0
+ op2 = $1
+ ROOT root = tuple(op0, op1, op2)
+ }
+ )";
+  const std::string hlo_string =
+      absl::Substitute(hlo_string_tmpl, op1, op2);
+  SCOPED_TRACE(absl::StrCat("Module before CSE:\n", hlo_string));
+
+  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(hlo_string));
+  HloCSE cse(/*is_layout_sensitive=*/false);
+  TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloPass(&cse, m.get()));
+
+  SCOPED_TRACE(absl::StrCat("Module after CSE:\n", m->ToString()));
+  // We always CSE op0 and op1, which are identical, so the pass must report a
+  // change regardless of what happens to op2.
+  EXPECT_TRUE(changed);
+  HloInstruction* root = m->entry_computation()->root_instruction();
+  EXPECT_EQ(root->operand(0), root->operand(1))
+      << "Identical ops should be CSE'ed";
+  if (should_cse) {
+    EXPECT_EQ(root->operand(0), root->operand(2)) << "Ops should be CSE'ed";
+  } else {
+    EXPECT_NE(root->operand(0), root->operand(2)) << "Ops should not be CSE'ed";
+  }
+}
+
+// Returns the (op1, op2, should_cse) cases used to instantiate
+// HloCseCustomCallTest.
+static std::vector<
+    std::tuple<std::string /*op1*/, std::string /*op2*/, bool /*should_cse*/>>
+CustomCallTests() {
+  using Case = std::tuple<std::string, std::string, bool>;
+
+  // Produces a case whose two custom-calls share the same operand, result
+  // shape and target and differ only in the trailing attribute text. Such
+  // pairs are always expected NOT to be CSE'ed.
+  auto differing_attrs = [](absl::string_view attrs_a,
+                            absl::string_view attrs_b) {
+    absl::string_view common =
+        "f32[] custom-call(p0), custom_call_target=\"foo\", ";
+    return std::make_tuple(absl::StrCat(common, attrs_a),
+                           absl::StrCat(common, attrs_b), false);
+  };
+
+  std::vector<Case> cases = {
+      // metadata shouldn't prevent CSE
+      {
+          "f32[] custom-call(p0), custom_call_target=\"foo\"",
+          "f32[] custom-call(p0), custom_call_target=\"foo\", "
+          "metadata={op_name=\"bar\"}",
+          true,
+      },
+      // Different operand lists.
+      {
+          "f32[] custom-call(p0), custom_call_target=\"foo\"",
+          "f32[] custom-call(p0, p0), custom_call_target=\"foo\"",
+          false,
+      },
+      // Different result shapes.
+      {
+          "f32[1] custom-call(p0), custom_call_target=\"foo\"",
+          "f32[2] custom-call(p0), custom_call_target=\"foo\"",
+          false,
+      },
+      // Different custom-call targets.
+      {
+          "f32[] custom-call(p0), custom_call_target=\"foo\"",
+          "f32[] custom-call(p0), custom_call_target=\"bar\"",
+          false,
+      },
+
+      // Pairs that differ in exactly one custom-call attribute.
+      differing_attrs("window={size=1}", "window={size=2}"),
+      differing_attrs("dim_labels=b0f_0oi->b0f", "dim_labels=b0f_0oi->bf0"),
+      differing_attrs("backend_config=\"foo\"", "backend_config=\"bar\""),
+      differing_attrs("literal=s32[] 0", "literal=s32[] 1"),
+      differing_attrs("literal=s32[] 0", "literal=f32[] 0"),
+      differing_attrs("operand_precision={high,default}",
+                      "operand_precision={high, high}"),
+      differing_attrs("api_version=API_VERSION_STATUS_RETURNING",
+                      "api_version=API_VERSION_ORIGINAL"),
+      differing_attrs("feature_group_count=0", "feature_group_count=1"),
+  };
+  return cases;
+}
+
+// Runs HloCseCustomCallTest once for every case returned by CustomCallTests().
+INSTANTIATE_TEST_SUITE_P(HloCseCustomCallTestSuite, HloCseCustomCallTest,
+ ::testing::ValuesIn(CustomCallTests()));
+
+// Two custom-calls that differ only in their called_computations list
+// ({comp} vs. {comp, comp}) must not be merged, so CSE should leave the module
+// unchanged.
+TEST_F(HloCseTest, CustomCallCalledComputations) {
+  const char* const hlo_string = R"(
+ HloModule m
+
+ comp {
+ lhs = f32[] parameter(0)
+ rhs = f32[] parameter(1)
+ ROOT maximum = f32[] maximum(lhs, rhs)
+ }
+
+ ENTRY entry {
+ p0 = f32[] parameter(0)
+
+ op0 = f32[] custom-call(p0), custom_call_target="foo", called_computations={comp}
+ op1 = f32[] custom-call(p0), custom_call_target="foo", called_computations={comp, comp}
+ ROOT root = tuple(op0, op1)
+ }
+ )";
+
+  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(hlo_string));
+  HloCSE cse(/*is_layout_sensitive=*/false);
+  TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloPass(&cse, m.get()));
+
+  SCOPED_TRACE(absl::StrCat("Module after CSE:\n", m->ToString()));
+  EXPECT_FALSE(changed);
+}
+
+// Two otherwise identical custom-calls that are both marked
+// custom_call_has_side_effect=true must not be merged, so CSE should leave the
+// module unchanged.
+TEST_F(HloCseTest, CustomCallSideEffects) {
+  const char* const hlo_string = R"(
+ HloModule m
+
+ ENTRY entry {
+ p0 = f32[] parameter(0)
+
+ op0 = f32[] custom-call(p0), custom_call_target="foo", custom_call_has_side_effect=true
+ op1 = f32[] custom-call(p0), custom_call_target="foo", custom_call_has_side_effect=true
+ ROOT root = tuple(op0, op1)
+ }
+ )";
+
+  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(hlo_string));
+  HloCSE cse(/*is_layout_sensitive=*/false);
+  TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloPass(&cse, m.get()));
+
+  SCOPED_TRACE(absl::StrCat("Module after CSE:\n", m->ToString()));
+  EXPECT_FALSE(changed);
+}
+
} // namespace
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_instructions.cc b/tensorflow/compiler/xla/service/hlo_instructions.cc
index e60dc69..f25c4b3 100644
--- a/tensorflow/compiler/xla/service/hlo_instructions.cc
+++ b/tensorflow/compiler/xla/service/hlo_instructions.cc
@@ -2760,12 +2760,14 @@
if (custom_call_schedule_ != casted_other.custom_call_schedule()) {
return false;
}
- if (HasLiteral() == casted_other.HasLiteral()) {
- if (HasLiteral() && literal() == casted_other.literal()) {
- return false;
- }
- } else {
- return true;
+ if (HasLiteral() != casted_other.HasLiteral()) {
+ return false;
+ }
+ if (HasLiteral() && literal() != casted_other.literal()) {
+ return false;
+ }
+ if (api_version_ != casted_other.api_version_) {
+ return false;
}
// Note: backend_config comparison is done in Identical, which is the
// intended/exposed way to compare computations, and so not repeated here.