blob: 8290b44fa8470ecebd6f9be79e9db82e08f56a12 [file] [log] [blame]
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/spmd/canonicalize_all_gather_for_cse.h"
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_query.h"
namespace xla {
namespace {
// Returns true if `inst` is a bitcast or reshape that only adds degenerate
// (size-1) dimensions to the shape of its operand, like going from [X,Y] to
// [1,X,Y,1]. Returns false for every other opcode.
bool IsAddingOnlyDegenerateDimensions(const HloInstruction* inst) {
  if (inst->opcode() != HloOpcode::kBitcast &&
      inst->opcode() != HloOpcode::kReshape) {
    return false;
  }
  const Shape& in_shape = inst->operand(0)->shape();
  const Shape& out_shape = inst->shape();
  // The caller in this file replaces the chain it looked through with plain
  // reshapes built from the innermost operand's element type. A bitcast that
  // reinterprets the element type must therefore not be looked through.
  if (in_shape.element_type() != out_shape.element_type()) {
    return false;
  }
  // Same element count, and every input dimension is reported as unmodified
  // by the reshape: whatever was added can only be size-1 dimensions.
  return ShapeUtil::ElementsIn(in_shape) == ShapeUtil::ElementsIn(out_shape) &&
         ShapeUtil::DimensionsUnmodifiedByReshape(in_shape, out_shape).size() ==
             in_shape.rank();
}
} // namespace
// Canonicalizes the all-gathers of `comp` so that all-gathers of the same data
// end up with the same form.
//
// For each all-gather on dimension 0 whose operand is reached through a chain
// of bitcasts/reshapes that only add degenerate dimensions, the all-gather is
// replaced by:
//   reshape(data -> [1, dims...]) -> all-gather(dim 0) -> reshape(original shape)
// where `data` is the instruction at the bottom of that chain.
//
// Returns true if at least one all-gather was rewritten, or the error status
// of the first failing replace/remove call.
StatusOr<bool> CanonicalizeAllGatherForCSE::RunOnComputation(
    HloComputation* comp) {
  bool changed = false;
  // Walk a copy of the instruction list taken up front; the rewrite below adds
  // and removes instructions in `comp`.
  std::vector<HloInstruction*> ordered_hlos = comp->MakeInstructionPostOrder();
  for (HloInstruction* hlo : ordered_hlos) {
    HloAllGatherInstruction* ag = DynCast<HloAllGatherInstruction>(hlo);
    // Only supporting AllGather on dimension 0 as it's the only case currently
    // happening and additional cases needs more complexity.
    if (!ag || ag->all_gather_dimension() != 0) {
      continue;
    }
    HloInstruction* real_data = ag->mutable_operand(0);
    const int64 ag_dim = ag->all_gather_dimension();
    const Shape& out_shape = ag->shape();
    const Shape& in_shape = ag->operand(0)->shape();
    // A zero-sized gather dimension would make the participant computation
    // below divide by zero. Leave such all-gathers untouched.
    if (in_shape.dimensions(ag_dim) == 0) {
      continue;
    }
    CHECK_EQ(out_shape.dimensions(ag_dim) % in_shape.dimensions(ag_dim), 0);
    // Number of shards concatenated along the gather dimension.
    const int64 all_gather_participants =
        out_shape.dimensions(ag_dim) / in_shape.dimensions(ag_dim);
    // Look through bitcasts/bitcast-like reshapes that only add degenerate
    // dimensions until reaching the instruction that produces the data.
    while (IsAddingOnlyDegenerateDimensions(real_data)) {
      real_data = real_data->mutable_operand(0);
    }
    // Nothing was looked through: the all-gather already consumes the data
    // directly, so there is nothing to canonicalize.
    if (real_data == ag->operand(0)) {
      continue;
    }
    // Canonical all-gather input: the data with one leading size-1 dimension.
    std::vector<int64> new_dimensions;
    new_dimensions.reserve(real_data->shape().dimensions_size() + 1);
    new_dimensions.push_back(1);
    new_dimensions.insert(new_dimensions.end(),
                          real_data->shape().dimensions().begin(),
                          real_data->shape().dimensions().end());
    // Adding specialized all-gather dimension.
    HloInstruction* ag_input =
        comp->AddInstruction(HloInstruction::CreateReshape(
            ShapeUtil::MakeShape(real_data->shape().element_type(),
                                 new_dimensions),
            real_data));
    // The gathered result has one slice per participant along dimension 0.
    new_dimensions[0] = all_gather_participants;
    // The replacement gets a fresh channel id only if the original had one.
    absl::optional<int64> new_channel_id =
        ag->channel_id() ? absl::make_optional(this->NextChannelId())
                         : absl::nullopt;
    HloInstruction* new_ag =
        comp->AddInstruction(HloInstruction::CreateAllGather(
            ShapeUtil::MakeShape(real_data->shape().element_type(),
                                 new_dimensions),
            ag_input, /*all_gather_dimension=*/0, ag->replica_groups(),
            ag->constrain_layout(), new_channel_id,
            ag->use_global_device_ids()));
    // Reshape back to the shape the users of the original all-gather expect.
    HloInstruction* new_formatting = comp->AddInstruction(
        HloInstruction::CreateReshape(ag->shape(), new_ag));
    TF_RETURN_IF_ERROR(ag->ReplaceAllUsesWith(new_formatting));
    TF_RETURN_IF_ERROR(comp->RemoveInstructionAndUnusedOperands(ag));
    changed = true;
  }
  return changed;
}
// Pass entry point: seeds the channel-id counter from `module`, then runs the
// canonicalization on every computation of the module. Returns true if any
// computation was modified; the first error from a computation is propagated.
StatusOr<bool> CanonicalizeAllGatherForCSE::Run(HloModule* module) {
  // Must be set before any computation is processed: rewritten all-gathers
  // draw their new channel ids from this counter.
  next_channel_id_ = hlo_query::NextChannelId(*module);
  bool any_rewritten = false;
  for (HloComputation* computation : module->computations()) {
    TF_ASSIGN_OR_RETURN(bool rewritten, RunOnComputation(computation));
    if (rewritten) {
      any_rewritten = true;
    }
  }
  return any_rewritten;
}
} // namespace xla