[XLA/GPU] Polish coding styles.
diff --git a/tensorflow/compiler/xla/service/copy_insertion_test.cc b/tensorflow/compiler/xla/service/copy_insertion_test.cc
index 3c920a0..39d8a9e 100644
--- a/tensorflow/compiler/xla/service/copy_insertion_test.cc
+++ b/tensorflow/compiler/xla/service/copy_insertion_test.cc
@@ -2912,8 +2912,6 @@
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseAndReturnVerifiedModule(hlo_string));
- // Set up the aliasing manually which normally would be set by
- // alias_passthrough_params pass.
ASSERT_IS_OK(module->input_output_alias_config().SetUpAlias(
/*output_index=*/{0},
/*param_number=*/0,
diff --git a/tensorflow/compiler/xla/service/gpu/horizontal_loop_fusion.cc b/tensorflow/compiler/xla/service/gpu/horizontal_loop_fusion.cc
index 36bdf1d..da04f4f 100644
--- a/tensorflow/compiler/xla/service/gpu/horizontal_loop_fusion.cc
+++ b/tensorflow/compiler/xla/service/gpu/horizontal_loop_fusion.cc
@@ -200,17 +200,12 @@
bool AnyOpndIsParamSharedAmongFusions(
const HloInstruction* instr,
const absl::flat_hash_set<HloInstruction*>& fusion_instrs) {
- for (const HloInstruction* opnd : instr->operands()) {
- if (opnd->opcode() != HloOpcode::kParameter) {
- continue;
- }
- for (const HloInstruction* user : opnd->users()) {
- if (user != instr && fusion_instrs.contains(user)) {
- return true;
- }
- }
- }
- return false;
+ return absl::c_any_of(instr->operands(), [&](const HloInstruction* opnd) {
+ return opnd->opcode() == HloOpcode::kParameter &&
+ absl::c_any_of(opnd->users(), [&](const HloInstruction* user) {
+ return user != instr && fusion_instrs.contains(user);
+ });
+ });
}
void HorizontalLoopFusionImpl::FusionCandidates::Initialize(
diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc
index 549579d..88776125 100644
--- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc
+++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc
@@ -142,15 +142,21 @@
});
}
-// Returns whether the concat is used in an elementwise manner.
-// A concat followed by slices is considered (effectively) elementwise if the
-// slices combinedly is a reverse function of the concat.
-// prev_concat_opnd_idx: previously seen concat and its operand index.
-// prev_concat_opnd_idx.first is nullptr if no previously seen concat.
-bool ConcatIsEffectivelyElementwise(
+struct ConcatUsageInfo {
+ // Pointer to a previously seen concat. nullptr if no previously seen concat.
+ const HloInstruction* prev_concat;
+ // The opnd id of the seen concat.
+ int64 concat_opnd_idx;
+ // The slice that recovers the opnd in the concat outputs.
+ const HloInstruction* slice_to_recover_opnd;
+};
+
+// Returns an optional concat usage info to denote whether the concat is used in
+// an elementwise manner. A concat followed by slices is considered effectively
+// elementwise if the slices combinedly is a reverse function of the concat.
+absl::optional<ConcatUsageInfo> ConcatIsEffectivelyElementwise(
const HloInstruction& concat, const HloInstruction& operand,
- std::pair<const HloInstruction*, int64>* prev_concat_opnd_idx,
- const HloInstruction** slice_to_recover_opnd) {
+ const ConcatUsageInfo& info) {
// First, check if this concat is in the below pattern. Also, we check
// that the slices combinedly are in effect a reverse function of the concat.
//
@@ -160,16 +166,14 @@
// Slice Slice
//
std::vector<HloInstruction*> users = concat.users();
- if (!absl::c_all_of(users, [](const HloInstruction* i) {
- return Is1dSliceWithoutStrides(i);
- })) {
+ if (!absl::c_all_of(users, Is1dSliceWithoutStrides)) {
// Limit our supported cases to 1 dimensional slices.
- return false;
+ return absl::optional<ConcatUsageInfo>();
}
// Verify that each operand to the concat is reversed by a slice.
if (users.size() != concat.operand_count() ||
concat.operand_count() != concat.unique_operands().size()) {
- return false;
+ return absl::optional<ConcatUsageInfo>();
}
absl::c_sort(users, [](const HloInstruction* a, const HloInstruction* b) {
return a->slice_starts().at(0) < b->slice_starts().at(0);
@@ -180,33 +184,32 @@
int64 slice_size = u->slice_limits().at(0) - u->slice_starts().at(0);
if (u->slice_starts().at(0) != prev_limit ||
slice_size != ShapeUtil::ElementsIn(concat.operand(i)->shape())) {
- return false;
+ return absl::optional<ConcatUsageInfo>();
}
prev_limit = u->slice_limits().at(0);
}
// If we have seen other concats, make sure they are identical. Multiple
// concats exist because horizontal fusion inserts one concat for each output
- // of the fusion candidates. Check all concats and operand ids are the same
- // to make sure that the compute iteration space is the same.
+ // of the fusion candidates. Check that all concats and operand ids are the
+ // same to know that the "transitive use closure" will be computed in the same
+ // iteration space.
int64 operand_idx = concat.operand_index(&operand);
- *slice_to_recover_opnd = users.at(operand_idx);
- if (prev_concat_opnd_idx->first == nullptr) {
- prev_concat_opnd_idx->first = &concat;
- prev_concat_opnd_idx->second = operand_idx;
- } else {
- bool is_concat_identical = prev_concat_opnd_idx->first->Identical(
+ if (info.prev_concat != nullptr) {
+ bool is_concat_identical = info.prev_concat->Identical(
concat,
/*eq_operands=*/[](const HloInstruction*, const HloInstruction*) {
// Operands don't need to be the same.
return true;
});
- if (!is_concat_identical || prev_concat_opnd_idx->second != operand_idx) {
- return false;
+ if (!is_concat_identical || info.concat_opnd_idx != operand_idx) {
+ return absl::optional<ConcatUsageInfo>();
}
}
- return true;
+ const HloInstruction* slice_to_recover_opnd = users.at(operand_idx);
+ return absl::optional<ConcatUsageInfo>(
+ ConcatUsageInfo{&concat, operand_idx, slice_to_recover_opnd});
}
// Returns whether we can prove the transitive uses of `param` are in effect
@@ -224,7 +227,7 @@
absl::flat_hash_set<const HloInstruction*> visited;
absl::InlinedVector<const HloInstruction*, 4> stack;
stack.push_back(param);
- std::pair<const HloInstruction*, int64> prev_concat_opnd_idx(nullptr, 0);
+ ConcatUsageInfo concat_usage_info{nullptr, 0, nullptr};
bool is_output_reachable = false;
while (!stack.empty()) {
const HloInstruction* current = stack.back();
@@ -247,18 +250,19 @@
}
break;
case HloOpcode::kConcatenate: {
- const HloInstruction* slice_to_recover_opnd = nullptr;
- if (!ConcatIsEffectivelyElementwise(*user, *current,
- &prev_concat_opnd_idx,
- &slice_to_recover_opnd)) {
+ absl::optional<ConcatUsageInfo> optional_concat_info =
+ ConcatIsEffectivelyElementwise(*user, *current,
+ concat_usage_info);
+ if (!optional_concat_info) {
return false;
}
+ concat_usage_info = *optional_concat_info;
// Early continue as we only want to traverse through the slice that
// recovers the operand. It is guaranteed that the operand to the
// concat and the slice have the same iteration space. Insert the
// slice instead of the concat.
- CHECK(!visited.contains(slice_to_recover_opnd));
- stack.push_back(slice_to_recover_opnd);
+ CHECK(!visited.contains(concat_usage_info.slice_to_recover_opnd));
+ stack.push_back(concat_usage_info.slice_to_recover_opnd);
continue;
}
default: