// blob: ad58833e4d237aad366ac3e1ce412853598607e1 [file] [log] [blame]
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include <algorithm>
#include <ostream>
#include <set>
#include <unordered_set>
#include <utility>
#include "absl/algorithm/container.h"
#include "absl/container/inlined_vector.h"
#include "absl/memory/memory.h"
#include "absl/strings/ascii.h"
#include "absl/strings/escaping.h"
#include "absl/strings/numbers.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/protobuf_util.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h"
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/name_uniquer.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/flatmap.h"
#include "tensorflow/core/lib/gtl/flatset.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/platform/human_readable_json.h"
#include "tensorflow/core/platform/logging.h"
namespace xla {
using absl::CEscape;
using absl::StrAppend;
using absl::StrCat;
using absl::StrJoin;
// Reconstructs a single HloInstruction from its serialized proto form.
//
// `instruction_map` and `computation_map` translate the ids stored in the
// proto (operand ids, control predecessor ids, called computation ids) into
// the already-built objects. For the malformations checked below (empty
// opcode or name, missing shape, unknown ids, wrong operand / computation /
// dimension counts, missing required sub-messages) an error status is
// returned.
//
// After the opcode-specific construction, the name, metadata, backend config,
// unique id, control predecessors and (if present) sharding are restored from
// the proto for every opcode.
/* static */
StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
    const HloInstructionProto& proto,
    const tensorflow::gtl::FlatMap<int64, HloInstruction*>& instruction_map,
    const tensorflow::gtl::FlatMap<int64, HloComputation*>& computation_map) {
  TF_RET_CHECK(!proto.opcode().empty());
  TF_ASSIGN_OR_RETURN(HloOpcode opcode, StringToHloOpcode(proto.opcode()));
  TF_RET_CHECK(proto.has_shape());

  // Validate every referenced id before anything is constructed. The helper
  // lambdas below resolve ids with `at()`, which has no way to report a
  // missing key as a Status, so the check has to happen here for all opcodes.
  for (const int64 operand_id : proto.operand_ids()) {
    TF_RET_CHECK(ContainsKey(instruction_map, operand_id))
        << "No instruction with id " << operand_id;
  }
  // Fusion looks up its computation itself and reports a dedicated message.
  if (opcode != HloOpcode::kFusion) {
    for (const int64 computation_id : proto.called_computation_ids()) {
      TF_RET_CHECK(ContainsKey(computation_map, computation_id))
          << "No computation with id " << computation_id;
    }
  }

  std::unique_ptr<HloInstruction> instruction;
  // Returns the `index`-th operand of the proto as an instruction pointer.
  const auto operands = [&instruction_map, &proto](int index) {
    return instruction_map.at(proto.operand_ids(index));
  };
  // Returns all operands of the proto, in proto order.
  const auto all_operands = [&instruction_map, &proto]() {
    std::vector<HloInstruction*> result(proto.operand_ids_size());
    std::transform(proto.operand_ids().begin(), proto.operand_ids().end(),
                   result.begin(), [&instruction_map](int64 operand_id) {
                     return instruction_map.at(operand_id);
                   });
    return result;
  };
  // Returns the `index`-th called computation of the proto.
  const auto computations = [&computation_map, &proto](int index) {
    return computation_map.at(proto.called_computation_ids(index));
  };

  switch (opcode) {
    // Ops migrated to subclasses.
    case HloOpcode::kBatchNormTraining:
      TF_RET_CHECK(proto.operand_ids_size() == 3)
          << "BatchNormTraining instruction should have 3 operands but sees "
          << proto.operand_ids_size();
      instruction = CreateBatchNormTraining(
          proto.shape(), operands(0), operands(1), operands(2),
          proto.epsilon(), proto.feature_index());
      break;
    case HloOpcode::kBatchNormInference:
      TF_RET_CHECK(proto.operand_ids_size() == 5)
          << "BatchNormInference instruction should have 5 operands but sees "
          << proto.operand_ids_size();
      instruction = CreateBatchNormInference(
          proto.shape(), operands(0), operands(1), operands(2), operands(3),
          operands(4), proto.epsilon(), proto.feature_index());
      break;
    case HloOpcode::kBatchNormGrad:
      TF_RET_CHECK(proto.operand_ids_size() == 5)
          << "BatchNormGrad instruction should have 5 operands but sees "
          << proto.operand_ids_size();
      instruction = CreateBatchNormGrad(proto.shape(), operands(0), operands(1),
                                        operands(2), operands(3), operands(4),
                                        proto.epsilon(), proto.feature_index());
      break;
    case HloOpcode::kFft: {
      TF_RET_CHECK(proto.operand_ids_size() == 1)
          << "Fft instruction should have 1 operand but sees "
          << proto.operand_ids_size();
      std::vector<int64> fft_length(proto.fft_length().begin(),
                                    proto.fft_length().end());
      instruction = CreateFft(proto.shape(), operands(0), proto.fft_type(),
                              absl::Span<const int64>(fft_length));
      break;
    }
    case HloOpcode::kSend:
      TF_RET_CHECK(proto.operand_ids_size() == 2)
          << "Send instruction should have 2 operand but sees "
          << proto.operand_ids_size();
      instruction = CreateSend(operands(0), operands(1), proto.channel_id(),
                               proto.is_host_transfer());
      break;
    case HloOpcode::kSendDone:
      TF_RET_CHECK(proto.operand_ids_size() == 1)
          << "SendDone instruction should have 1 operand but sees "
          << proto.operand_ids_size();
      instruction = CreateSendDone(operands(0), proto.is_host_transfer());
      break;
    case HloOpcode::kRecv:
      TF_RET_CHECK(proto.operand_ids_size() == 1)
          << "Recv instruction should have 1 operand but sees "
          << proto.operand_ids_size();
      instruction = CreateRecv(proto.shape().tuple_shapes(0), operands(0),
                               proto.channel_id(), proto.is_host_transfer());
      break;
    case HloOpcode::kRecvDone:
      TF_RET_CHECK(proto.operand_ids_size() == 1)
          << "RecvDone instruction should have 1 operand but sees "
          << proto.operand_ids_size();
      instruction = CreateRecvDone(operands(0), proto.is_host_transfer());
      break;
    case HloOpcode::kReverse:
      TF_RET_CHECK(proto.operand_ids_size() == 1)
          << "Reverse instruction should have 1 operand but sees "
          << proto.operand_ids_size();
      instruction = CreateReverse(proto.shape(), operands(0),
                                  std::vector<int64>(proto.dimensions().begin(),
                                                     proto.dimensions().end()));
      break;
    case HloOpcode::kConcatenate:
      TF_RET_CHECK(proto.dimensions_size() == 1)
          << "Concatenate instruction should have 1 dimension but sees "
          << proto.dimensions_size();
      instruction = CreateConcatenate(proto.shape(), all_operands(),
                                      proto.dimensions(0));
      break;
    case HloOpcode::kReduce: {
      TF_RET_CHECK(proto.operand_ids_size() % 2 == 0)
          << "Reduce instruction should have an even number of operands but "
             "sees "
          << proto.operand_ids_size();
      TF_RET_CHECK(proto.called_computation_ids_size() == 1)
          << "Reduce instruction should have 1 called computation but sees "
          << proto.called_computation_ids_size();
      // The first half of the operands are the inputs, the second half the
      // matching init values.
      const auto reduce_operands = all_operands();
      const size_t half = reduce_operands.size() / 2;
      auto inputs = absl::MakeSpan(reduce_operands).subspan(0, half);
      // subspan(pos) runs to the end of the span; the second argument of
      // subspan is a length, not an end index.
      auto init_values = absl::MakeSpan(reduce_operands).subspan(half);
      instruction =
          CreateReduce(proto.shape(), inputs, init_values,
                       std::vector<int64>(proto.dimensions().begin(),
                                          proto.dimensions().end()),
                       computations(0));
      break;
    }
    case HloOpcode::kSort: {
      TF_RET_CHECK(proto.operand_ids_size() == 1 ||
                   proto.operand_ids_size() == 2)
          << "Sort instruction should have 1 or 2 operands but has "
          << proto.operand_ids_size();
      TF_RET_CHECK(proto.dimensions().size() == 1)
          << "Sort instruction should have 1 dimension";
      HloInstruction* keys = operands(0);
      // The values operand is optional.
      HloInstruction* values =
          proto.operand_ids_size() == 2 ? operands(1) : nullptr;
      instruction =
          CreateSort(proto.shape(), proto.dimensions(0), keys, values);
      break;
    }
    case HloOpcode::kTranspose:
      TF_RET_CHECK(proto.operand_ids_size() == 1)
          << "Transpose instruction should have 1 operand but sees "
          << proto.operand_ids_size();
      instruction =
          CreateTranspose(proto.shape(), operands(0),
                          std::vector<int64>(proto.dimensions().begin(),
                                             proto.dimensions().end()));
      break;
    case HloOpcode::kBroadcast:
      TF_RET_CHECK(proto.operand_ids_size() == 1)
          << "Broadcast instruction should have 1 operand but sees "
          << proto.operand_ids_size();
      instruction =
          CreateBroadcast(proto.shape(), operands(0),
                          std::vector<int64>(proto.dimensions().begin(),
                                             proto.dimensions().end()));
      break;
    case HloOpcode::kMap:
      TF_RET_CHECK(proto.called_computation_ids_size() == 1)
          << "Map instruction should have 1 called computation but sees "
          << proto.called_computation_ids_size();
      instruction = CreateMap(proto.shape(), all_operands(), computations(0));
      break;
    case HloOpcode::kSlice: {
      TF_RET_CHECK(proto.operand_ids_size() == 1)
          << "Slice instruction should have 1 operand but sees "
          << proto.operand_ids_size();
      std::vector<int64> slice_starts, slice_limits, slice_strides;
      for (const HloInstructionProto::SliceDimensions& slice_dimensions :
           proto.slice_dimensions()) {
        slice_starts.push_back(slice_dimensions.start());
        slice_limits.push_back(slice_dimensions.limit());
        slice_strides.push_back(slice_dimensions.stride());
      }
      instruction = CreateSlice(proto.shape(), operands(0), slice_starts,
                                slice_limits, slice_strides);
      break;
    }
    case HloOpcode::kConstant: {
      // TODO(b/110214922): Revert this to CHECK(proto.has_literal()).
      if (proto.has_literal()) {
        TF_ASSIGN_OR_RETURN(auto literal,
                            Literal::CreateFromProto(proto.literal()));
        instruction = CreateConstant(std::move(literal));
      } else {
        instruction = absl::make_unique<HloConstantInstruction>(proto.shape());
      }
      break;
    }
    case HloOpcode::kTrace: {
      TF_RET_CHECK(proto.operand_ids_size() == 1)
          << "Trace instruction should have 1 operand but sees "
          << proto.operand_ids_size();
      TF_RET_CHECK(proto.has_literal());
      TF_ASSIGN_OR_RETURN(auto literal,
                          Literal::CreateFromProto(proto.literal()));
      instruction = CreateTrace(literal.GetR1U8AsString(), operands(0));
      break;
    }
    case HloOpcode::kFusion: {
      // In the proto, fused computations are held exclusively within the
      // HloInstructionProto and do not appear as an HloComputationProto within
      // the HloModuleProto.
      TF_RET_CHECK(!proto.fusion_kind().empty());
      TF_ASSIGN_OR_RETURN(FusionKind fusion_kind,
                          StringToFusionKind(proto.fusion_kind()));
      // Find the fused computation and set its fusion instruction.
      TF_RET_CHECK(proto.called_computation_ids_size() == 1)
          << "Expect 1 called computation for fusion instruction but sees "
          << proto.called_computation_ids_size();
      const int64 fusion_id = proto.called_computation_ids(0);
      auto* fused_computation = FindPtrOrNull(computation_map, fusion_id);
      TF_RET_CHECK(fused_computation != nullptr)
          << "No fusion computation with id " << fusion_id;
      instruction = CreateFusion(proto.shape(), fusion_kind, all_operands(),
                                 fused_computation);
      break;
    }
    case HloOpcode::kRng:
      instruction =
          CreateRng(proto.shape(), proto.distribution(), all_operands());
      break;
    case HloOpcode::kParameter:
      instruction = CreateParameter(proto.parameter_number(), proto.shape(),
                                    proto.name());
      break;
    case HloOpcode::kGetTupleElement:
      TF_RET_CHECK(proto.operand_ids_size() == 1)
          << "GetTupleElement instruction should have 1 operand but sees "
          << proto.operand_ids_size();
      instruction = CreateGetTupleElement(proto.shape(), operands(0),
                                          proto.tuple_index());
      break;
    case HloOpcode::kReducePrecision:
      // operands(0) is read below, so the operand must be present.
      TF_RET_CHECK(proto.operand_ids_size() == 1)
          << "ReducePrecision instruction should have 1 operand but sees "
          << proto.operand_ids_size();
      instruction =
          CreateReducePrecision(proto.shape(), operands(0),
                                proto.exponent_bits(), proto.mantissa_bits());
      break;
    case HloOpcode::kInfeed: {
      const Shape& data_shape =
          ShapeUtil::GetTupleElementShape(proto.shape(), 0);
      TF_RET_CHECK(proto.operand_ids_size() == 1);
      instruction =
          CreateInfeed(data_shape, operands(0), proto.infeed_config());
      break;
    }
    case HloOpcode::kOutfeed:
      TF_RET_CHECK(proto.operand_ids_size() == 2);
      instruction = CreateOutfeed(proto.outfeed_shape(), operands(0),
                                  operands(1), proto.outfeed_config());
      break;
    case HloOpcode::kCrossReplicaSum: {
      TF_RET_CHECK(proto.called_computation_ids_size() == 1)
          << "CrossReplicaSum should have 1 called computation but sees "
          << proto.called_computation_ids_size();
      // Only a positive id in the proto is forwarded; otherwise the optional
      // stays empty.
      absl::optional<int64> all_reduce_id;
      if (proto.all_reduce_id() > 0) {
        all_reduce_id = proto.all_reduce_id();
      }
      instruction = CreateCrossReplicaSum(
          proto.shape(), all_operands(), computations(0),
          /*replica_groups=*/
          std::vector<ReplicaGroup>(proto.replica_groups().begin(),
                                    proto.replica_groups().end()),
          /*barrier=*/proto.cross_replica_sum_barrier(),
          /*all_reduce_id=*/all_reduce_id);
      break;
    }
    case HloOpcode::kAllToAll: {
      instruction = CreateAllToAll(
          proto.shape(), all_operands(),
          /*replica_groups=*/
          std::vector<ReplicaGroup>(proto.replica_groups().begin(),
                                    proto.replica_groups().end()));
      break;
    }
    case HloOpcode::kCollectivePermute: {
      // operands(0) is read below, so the operand must be present.
      TF_RET_CHECK(proto.operand_ids_size() == 1)
          << "CollectivePermute instruction should have 1 operand but sees "
          << proto.operand_ids_size();
      std::vector<std::pair<int64, int64>> source_target_pairs;
      source_target_pairs.reserve(proto.source_target_pairs_size());
      for (const auto& source_target : proto.source_target_pairs()) {
        source_target_pairs.emplace_back(source_target.source(),
                                         source_target.target());
      }
      instruction = CreateCollectivePermute(proto.shape(), operands(0),
                                            source_target_pairs);
      break;
    }
    case HloOpcode::kConvolution: {
      TF_RET_CHECK(proto.operand_ids_size() == 2)
          << "Convolution instruction should have 2 operands but sees "
          << proto.operand_ids_size();
      TF_RET_CHECK(proto.has_window());
      TF_RET_CHECK(proto.has_convolution_dimension_numbers());
      // Make the precision list match the operand count, filling any missing
      // entries with DEFAULT.
      PrecisionConfig precision_config = proto.precision_config();
      precision_config.mutable_operand_precision()->Resize(
          proto.operand_ids_size(), PrecisionConfig::DEFAULT);
      // A feature group count below 1 in the proto is treated as 1.
      instruction = CreateConvolve(
          proto.shape(), operands(0), operands(1),
          std::max<int64>(proto.feature_group_count(), 1), proto.window(),
          proto.convolution_dimension_numbers(), precision_config);
      break;
    }
    case HloOpcode::kReduceWindow:
      TF_RET_CHECK(proto.operand_ids_size() == 2)
          << "ReduceWindow instruction should have 2 operands but sees "
          << proto.operand_ids_size();
      TF_RET_CHECK(proto.called_computation_ids_size() == 1)
          << "ReduceWindow should have 1 called computation but sees "
          << proto.called_computation_ids_size();
      instruction = CreateReduceWindow(proto.shape(), operands(0), operands(1),
                                       proto.window(), computations(0));
      break;
    case HloOpcode::kSelectAndScatter:
      TF_RET_CHECK(proto.operand_ids_size() == 3)
          << "SelectAndScatter instruction should have 3 operands but sees "
          << proto.operand_ids_size();
      TF_RET_CHECK(proto.called_computation_ids_size() == 2)
          << "SelectAndScatter should have 2 called computations but sees "
          << proto.called_computation_ids_size();
      instruction = CreateSelectAndScatter(
          proto.shape(), operands(0), computations(0), proto.window(),
          operands(1), operands(2), computations(1));
      break;
    case HloOpcode::kCustomCall: {
      instruction = CreateCustomCall(proto.shape(), all_operands(),
                                     proto.custom_call_target());
      auto* custom_call =
          static_cast<HloCustomCallInstruction*>(instruction.get());
      if (proto.has_window()) {
        custom_call->set_window(proto.window());
      }
      if (proto.has_convolution_dimension_numbers()) {
        custom_call->set_convolution_dimension_numbers(
            proto.convolution_dimension_numbers());
      }
      // A feature group count below 1 in the proto is treated as 1, as in the
      // kConvolution case.
      custom_call->set_feature_group_count(
          std::max<int64>(proto.feature_group_count(), 1));
      break;
    }
    case HloOpcode::kPad:
      TF_RET_CHECK(proto.operand_ids_size() == 2)
          << "Pad instruction should have 2 operands but sees "
          << proto.operand_ids_size();
      TF_RET_CHECK(proto.has_padding_config());
      instruction = CreatePad(proto.shape(), operands(0), operands(1),
                              proto.padding_config());
      break;
    case HloOpcode::kDynamicSlice: {
      TF_RET_CHECK(proto.operand_ids_size() == 2)
          << "DynamicSlice instruction should have 2 operands but sees "
          << proto.operand_ids_size();
      std::vector<int64> slice_sizes(proto.dynamic_slice_sizes_size());
      absl::c_copy(proto.dynamic_slice_sizes(), slice_sizes.begin());
      instruction = CreateDynamicSlice(proto.shape(), operands(0), operands(1),
                                       slice_sizes);
      break;
    }
    case HloOpcode::kGather: {
      TF_RET_CHECK(proto.operand_ids_size() == 2)
          << "Gather instruction should have 2 operands but sees "
          << proto.operand_ids_size();
      TF_RET_CHECK(proto.has_gather_dimension_numbers())
          << "Gather instruction should have GatherDimensionNumbers set.";
      std::unique_ptr<GatherDimensionNumbers> gather_dimension_numbers =
          absl::make_unique<GatherDimensionNumbers>(
              proto.gather_dimension_numbers());
      std::vector<int64> gather_slice_sizes;
      for (int64 bound : proto.gather_slice_sizes()) {
        gather_slice_sizes.push_back(bound);
      }
      instruction = CreateGather(proto.shape(), operands(0), operands(1),
                                 *gather_dimension_numbers, gather_slice_sizes);
      break;
    }
    case HloOpcode::kScatter: {
      TF_RET_CHECK(proto.operand_ids_size() == 3)
          << "Scatter instruction should have 3 operands but sees "
          << proto.operand_ids_size();
      TF_RET_CHECK(proto.has_scatter_dimension_numbers())
          << "Scatter instruction should have ScatterDimensionNumbers set.";
      TF_RET_CHECK(proto.called_computation_ids_size() == 1)
          << "Scatter instruction should have 1 called computation but sees "
          << proto.called_computation_ids_size();
      auto scatter_dimension_numbers =
          absl::make_unique<ScatterDimensionNumbers>(
              proto.scatter_dimension_numbers());
      instruction =
          CreateScatter(proto.shape(), operands(0), operands(1), operands(2),
                        computations(0), *scatter_dimension_numbers);
      break;
    }
    case HloOpcode::kIota:
      // dimensions(0) is read below, so exactly one dimension must be present;
      // an empty list would otherwise be indexed out of bounds.
      TF_RET_CHECK(proto.dimensions_size() == 1)
          << "Iota instruction should have 1 dimension but sees "
          << proto.dimensions_size();
      instruction = CreateIota(proto.shape(), proto.dimensions(0));
      break;
    case HloOpcode::kDot: {
      TF_RET_CHECK(proto.has_dot_dimension_numbers())
          << "Dot instruction should have dot_dimension_numbers.";
      TF_RET_CHECK(proto.operand_ids_size() == 2)
          << "Dot instruction should have 2 operands but sees "
          << proto.operand_ids_size();
      // Make the precision list match the operand count, filling any missing
      // entries with DEFAULT.
      PrecisionConfig precision_config = proto.precision_config();
      precision_config.mutable_operand_precision()->Resize(
          proto.operand_ids_size(), PrecisionConfig::DEFAULT);
      instruction = absl::make_unique<HloDotInstruction>(
          proto.shape(), operands(0), operands(1),
          proto.dot_dimension_numbers(), precision_config);
      break;
    }
    case HloOpcode::kDomain:
      TF_RET_CHECK(proto.operand_ids_size() == 1)
          << "Domain instruction should have 1 operands but sees "
          << proto.operand_ids_size();
      instruction = absl::make_unique<HloDomainInstruction>(
          proto.shape(), operands(0), /*operand_side_metadata=*/nullptr,
          /*user_side_metadata=*/nullptr);
      break;
    default: {
      // Opcodes without a dedicated subclass: a plain HloInstruction with the
      // operands and called computations listed in the proto. All ids were
      // validated above (kFusion never reaches this branch).
      instruction = absl::WrapUnique(new HloInstruction(opcode, proto.shape()));
      for (const int64 operand_id : proto.operand_ids()) {
        instruction->AppendOperand(instruction_map.at(operand_id));
      }
      for (const int64 computation_id : proto.called_computation_ids()) {
        instruction->called_computations_.push_back(
            computation_map.at(computation_id));
      }
      TF_RET_CHECK(!proto.has_precision_config())
          << instruction->opcode() << proto.DebugString();
      TF_RET_CHECK(!proto.has_dot_dimension_numbers()) << instruction->opcode();
      break;
    }
  }

  // Restore control dependencies for every opcode, not only for the ones
  // handled by the default branch, so they survive a proto round trip.
  for (const int64 predecessor_id : proto.control_predecessor_ids()) {
    TF_RET_CHECK(ContainsKey(instruction_map, predecessor_id))
        << "No instruction with id " << predecessor_id;
    TF_RETURN_IF_ERROR(instruction_map.at(predecessor_id)
                           ->AddControlDependencyTo(instruction.get()));
  }

  TF_RET_CHECK(!proto.name().empty());
  instruction->SetAndSanitizeName(proto.name());
  instruction->metadata_ = proto.metadata();
  instruction->backend_config_ = proto.backend_config();
  instruction->unique_id_ = proto.id();
  if (proto.has_sharding()) {
    TF_ASSIGN_OR_RETURN(const auto& sharding,
                        HloSharding::FromProto(proto.sharding()));
    instruction->set_sharding(sharding);
  }
  return std::move(instruction);
}
// Builds an HloParameterInstruction from the given parameter number, shape
// and name.
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateParameter(
    int64 parameter_number, const Shape& shape, const string& name) {
  std::unique_ptr<HloInstruction> parameter =
      absl::make_unique<HloParameterInstruction>(parameter_number, shape,
                                                 name);
  return parameter;
}
// Builds an HloTraceInstruction from `tag` and `operand`.
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateTrace(
    const string& tag, HloInstruction* operand) {
  std::unique_ptr<HloInstruction> trace =
      absl::make_unique<HloTraceInstruction>(tag, operand);
  return trace;
}
// Builds an HloConstantInstruction that takes ownership of `literal` (the
// literal is moved into the new instruction).
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateConstant(
    Literal literal) {
  std::unique_ptr<HloInstruction> constant =
      absl::make_unique<HloConstantInstruction>(std::move(literal));
  return constant;
}
// Builds an HloIotaInstruction with the given shape and iota dimension.
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateIota(
    const Shape& shape, int64 iota_dimension) {
  std::unique_ptr<HloInstruction> iota =
      absl::make_unique<HloIotaInstruction>(shape, iota_dimension);
  return iota;
}
// Builds an HloGetTupleElementInstruction from `shape`, `operand` and the
// tuple `index`.
/* static */ std::unique_ptr<HloInstruction>
HloInstruction::CreateGetTupleElement(const Shape& shape,
                                      HloInstruction* operand, int64 index) {
  std::unique_ptr<HloInstruction> get_tuple_element =
      absl::make_unique<HloGetTupleElementInstruction>(shape, operand, index);
  return get_tuple_element;
}
// Builds an HloRngInstruction for `distribution`, passing `parameters`
// through as the instruction's operands.
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateRng(
    const Shape& shape, RandomDistribution distribution,
    absl::Span<HloInstruction* const> parameters) {
  std::unique_ptr<HloInstruction> rng =
      absl::make_unique<HloRngInstruction>(shape, distribution, parameters);
  return rng;
}
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateNary(
const Shape& shape, HloOpcode opcode,
absl::Span<HloInstruction* const> operands) {
if (opcode == HloOpcode::kCopy) {
// It is impossible to copy an opaque shape, we don't know how big it is.
CHECK(!ShapeUtil::IsOpaque(shape));
}
auto instruction = absl::WrapUnique(new HloInstruction(opcode, shape));
for (auto operand : operands) {
instruction->AppendOperand(operand);
}
return instruction;
}
// Builds a one-operand instruction via CreateNary. Only opcodes of unary
// instructions with no auxiliary fields are accepted; any other opcode is a
// fatal error.
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateUnary(
    const Shape& shape, HloOpcode opcode, HloInstruction* operand) {
  const bool is_supported_unary = [opcode]() {
    switch (opcode) {
      case HloOpcode::kAbs:
      case HloOpcode::kRoundNearestAfz:
      case HloOpcode::kBitcast:
      case HloOpcode::kCeil:
      case HloOpcode::kCopy:
      case HloOpcode::kCos:
      case HloOpcode::kClz:
      case HloOpcode::kExp:
      case HloOpcode::kExpm1:
      case HloOpcode::kFloor:
      case HloOpcode::kImag:
      case HloOpcode::kIsFinite:
      case HloOpcode::kLog:
      case HloOpcode::kLog1p:
      case HloOpcode::kNot:
      case HloOpcode::kNegate:
      case HloOpcode::kReal:
      case HloOpcode::kSign:
      case HloOpcode::kSin:
      case HloOpcode::kTanh:
        return true;
      default:
        return false;
    }
  }();
  if (!is_supported_unary) {
    LOG(FATAL) << "Invalid unary instruction opcode "
               << HloOpcodeString(opcode);
  }
  return CreateNary(shape, opcode, {operand});
}
// Builds a two-operand instruction via CreateNary. Only opcodes of binary
// instructions with no auxiliary fields are accepted; any other opcode is a
// fatal error.
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateBinary(
    const Shape& shape, HloOpcode opcode, HloInstruction* lhs,
    HloInstruction* rhs) {
  const bool is_supported_binary = [opcode]() {
    switch (opcode) {
      case HloOpcode::kAdd:
      case HloOpcode::kAtan2:
      case HloOpcode::kDivide:
      case HloOpcode::kComplex:
      case HloOpcode::kEq:
      case HloOpcode::kGe:
      case HloOpcode::kGt:
      case HloOpcode::kLe:
      case HloOpcode::kLt:
      case HloOpcode::kMaximum:
      case HloOpcode::kMinimum:
      case HloOpcode::kMultiply:
      case HloOpcode::kNe:
      case HloOpcode::kPower:
      case HloOpcode::kRemainder:
      case HloOpcode::kSubtract:
      case HloOpcode::kAnd:
      case HloOpcode::kOr:
      case HloOpcode::kXor:
      case HloOpcode::kShiftLeft:
      case HloOpcode::kShiftRightArithmetic:
      case HloOpcode::kShiftRightLogical:
        return true;
      default:
        return false;
    }
  }();
  if (!is_supported_binary) {
    LOG(FATAL) << "Invalid binary instruction opcode "
               << HloOpcodeString(opcode);
  }
  return CreateNary(shape, opcode, {lhs, rhs});
}
// Builds a three-operand instruction via CreateNary. Only opcodes of ternary
// instructions with no auxiliary fields are accepted; any other opcode is a
// fatal error.
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateTernary(
    const Shape& shape, HloOpcode opcode, HloInstruction* lhs,
    HloInstruction* rhs, HloInstruction* ehs) {
  const bool is_supported_ternary = [opcode]() {
    switch (opcode) {
      case HloOpcode::kClamp:
      case HloOpcode::kSelect:
      case HloOpcode::kTupleSelect:
        return true;
      default:
        return false;
    }
  }();
  if (!is_supported_ternary) {
    LOG(FATAL) << "Invalid ternary instruction opcode "
               << HloOpcodeString(opcode);
  }
  return CreateNary(shape, opcode, {lhs, rhs, ehs});
}
// Builds a variadic instruction via CreateNary. kTuple is the only opcode
// accepted here; anything else fails the CHECK.
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateVariadic(
    const Shape& shape, HloOpcode opcode,
    absl::Span<HloInstruction* const> operands) {
  CHECK_EQ(HloOpcode::kTuple, opcode);
  std::unique_ptr<HloInstruction> variadic =
      CreateNary(shape, opcode, operands);
  return variadic;
}
// Builds an HloMapInstruction from `shape`, `operands` and `map_computation`.
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateMap(
    const Shape& shape, absl::Span<HloInstruction* const> operands,
    HloComputation* map_computation) {
  std::unique_ptr<HloInstruction> map =
      absl::make_unique<HloMapInstruction>(shape, operands, map_computation);
  return map;
}
// Builds an HloConvolutionInstruction, forwarding all arguments unchanged to
// its constructor.
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateConvolve(
    const Shape& shape, HloInstruction* lhs, HloInstruction* rhs,
    int64 feature_group_count, const Window& window,
    const ConvolutionDimensionNumbers& dimension_numbers,
    const PrecisionConfig& precision_config) {
  std::unique_ptr<HloInstruction> convolution =
      absl::make_unique<HloConvolutionInstruction>(
          shape, lhs, rhs, feature_group_count, window, dimension_numbers,
          precision_config);
  return convolution;
}
// Builds an HloFftInstruction from `shape`, `operand`, `fft_type` and
// `fft_length`.
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateFft(
    const Shape& shape, HloInstruction* operand, FftType fft_type,
    absl::Span<const int64> fft_length) {
  std::unique_ptr<HloInstruction> fft =
      absl::make_unique<HloFftInstruction>(shape, operand, fft_type,
                                           fft_length);
  return fft;
}
// Builds an HloDotInstruction from `shape`, the two operands, the dot
// dimension numbers and the precision config.
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateDot(
    const Shape& shape, HloInstruction* lhs, HloInstruction* rhs,
    const DotDimensionNumbers& dimension_numbers,
    const PrecisionConfig& precision_config) {
  std::unique_ptr<HloInstruction> dot = absl::make_unique<HloDotInstruction>(
      shape, lhs, rhs, dimension_numbers, precision_config);
  return dot;
}
// Builds an HloReducePrecisionInstruction from `shape`, `operand` and the
// exponent / mantissa bit counts.
/* static */ std::unique_ptr<HloInstruction>
HloInstruction::CreateReducePrecision(const Shape& shape,
                                      HloInstruction* operand,
                                      const int exponent_bits,
                                      const int mantissa_bits) {
  std::unique_ptr<HloInstruction> reduce_precision =
      absl::make_unique<HloReducePrecisionInstruction>(
          shape, operand, exponent_bits, mantissa_bits);
  return reduce_precision;
}
/* static */ std::unique_ptr<HloInstruction>
HloInstruction::CreateCrossReplicaSum(
const Shape& shape, absl::Span<HloInstruction* const> operands,
HloComputation* reduce_computation,
const std::vector<ReplicaGroup>& replica_groups, absl::string_view barrier,
const absl::optional<int64>& all_reduce_id) {
return absl::make_unique<HloAllReduceInstruction>(
shape, operands, reduce_computation, replica_groups, barrier,
all_reduce_id);
}
// Builds an HloAllToAllInstruction from `shape`, `operands` and
// `replica_groups`.
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateAllToAll(
    const Shape& shape, absl::Span<HloInstruction* const> operands,
    const std::vector<ReplicaGroup>& replica_groups) {
  std::unique_ptr<HloInstruction> all_to_all =
      absl::make_unique<HloAllToAllInstruction>(shape, operands,
                                                replica_groups);
  return all_to_all;
}
// Builds an HloCollectivePermuteInstruction from `shape`, `operand` and the
// (source, target) pairs.
/* static */ std::unique_ptr<HloInstruction>
HloInstruction::CreateCollectivePermute(
    const Shape& shape, HloInstruction* operand,
    const std::vector<std::pair<int64, int64>>& source_target_pairs) {
  std::unique_ptr<HloInstruction> collective_permute =
      absl::make_unique<HloCollectivePermuteInstruction>(shape, operand,
                                                         source_target_pairs);
  return collective_permute;
}
// Builds an HloInfeedInstruction from `infeed_shape`, `token_operand` and
// `config`.
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateInfeed(
    const Shape& infeed_shape, HloInstruction* token_operand,
    const string& config) {
  std::unique_ptr<HloInstruction> infeed =
      absl::make_unique<HloInfeedInstruction>(infeed_shape, token_operand,
                                              config);
  return infeed;
}
// Builds an HloOutfeedInstruction from `outfeed_shape`, the data operand, the
// token operand and `outfeed_config`.
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateOutfeed(
    const Shape& outfeed_shape, HloInstruction* operand,
    HloInstruction* token_operand, absl::string_view outfeed_config) {
  std::unique_ptr<HloInstruction> outfeed =
      absl::make_unique<HloOutfeedInstruction>(outfeed_shape, operand,
                                               token_operand, outfeed_config);
  return outfeed;
}
// Builds an HloSendInstruction from `operand`, `token`, `channel_id` and the
// host-transfer flag.
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateSend(
    HloInstruction* operand, HloInstruction* token, int64 channel_id,
    bool is_host_transfer) {
  std::unique_ptr<HloInstruction> send = absl::make_unique<HloSendInstruction>(
      operand, token, channel_id, is_host_transfer);
  return send;
}
// Builds an HloSendDoneInstruction. `operand` must be an HloSendInstruction;
// the CHECK fails if the DynCast yields nullptr.
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateSendDone(
    HloInstruction* operand, bool is_host_transfer) {
  auto* const send_operand = DynCast<HloSendInstruction>(operand);
  CHECK(send_operand != nullptr)
      << "SendDone must take the context operand from Send";
  std::unique_ptr<HloInstruction> send_done =
      absl::make_unique<HloSendDoneInstruction>(send_operand,
                                                is_host_transfer);
  return send_done;
}
// Builds an HloRecvInstruction from `shape`, `token`, `channel_id` and the
// host-transfer flag.
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateRecv(
    const Shape& shape, HloInstruction* token, int64 channel_id,
    bool is_host_transfer) {
  std::unique_ptr<HloInstruction> recv = absl::make_unique<HloRecvInstruction>(
      shape, token, channel_id, is_host_transfer);
  return recv;
}
// Builds an HloRecvDoneInstruction. `operand` must be an HloRecvInstruction;
// the CHECK fails if the DynCast yields nullptr.
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateRecvDone(
    HloInstruction* operand, bool is_host_transfer) {
  auto* const recv_operand = DynCast<HloRecvInstruction>(operand);
  CHECK(recv_operand != nullptr)
      << "RecvDone must take the context operand from Recv";
  std::unique_ptr<HloInstruction> recv_done =
      absl::make_unique<HloRecvDoneInstruction>(recv_operand,
                                                is_host_transfer);
  return recv_done;
}
// Builds an HloReverseInstruction from `shape`, `operand` and `dimensions`.
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateReverse(
    const Shape& shape, HloInstruction* operand,
    absl::Span<const int64> dimensions) {
  std::unique_ptr<HloInstruction> reverse =
      absl::make_unique<HloReverseInstruction>(shape, operand, dimensions);
  return reverse;
}
// Builds a token-shaped kAfterAll instruction with `operands` appended in
// order. At least one operand is required (CHECK-enforced).
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateAfterAll(
    absl::Span<HloInstruction* const> operands) {
  CHECK(!operands.empty());
  std::unique_ptr<HloInstruction> after_all(
      new HloInstruction(HloOpcode::kAfterAll, ShapeUtil::MakeTokenShape()));
  for (HloInstruction* input : operands) {
    after_all->AppendOperand(input);
  }
  return after_all;
}
// Builds a token-shaped kAfterAll instruction with no operands.
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateToken() {
  std::unique_ptr<HloInstruction> token(
      new HloInstruction(HloOpcode::kAfterAll, ShapeUtil::MakeTokenShape()));
  return token;
}
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateWhile(
const Shape& shape, HloComputation* condition, HloComputation* body,
HloInstruction* init) {
auto instruction =
absl::WrapUnique(new HloInstruction(HloOpcode::kWhile, shape));
instruction->AppendOperand(init);
// Body comes before condition computation in the vector.
instruction->called_computations_.push_back(body);
instruction->called_computations_.push_back(condition);
return instruction;
}
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateConditional(
const Shape& shape, HloInstruction* pred,
HloInstruction* true_computation_arg, HloComputation* true_computation,
HloInstruction* false_computation_arg, HloComputation* false_computation) {
auto instruction =
absl::WrapUnique(new HloInstruction(HloOpcode::kConditional, shape));
instruction->AppendOperand(pred);
instruction->AppendOperand(true_computation_arg);
instruction->AppendOperand(false_computation_arg);
// In called_computations_, the index of true_computation must be 0 and that
// of false computation must be 1, as defined by kTrueComputationIndex and
// kFalseComputationIndex.
instruction->called_computations_.push_back(true_computation);
instruction->called_computations_.push_back(false_computation);
return instruction;
}
// Builds an HloSliceInstruction from `shape`, `operand` and the per-dimension
// start indices, limit indices and strides.
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateSlice(
    const Shape& shape, HloInstruction* operand,
    absl::Span<const int64> start_indices,
    absl::Span<const int64> limit_indices, absl::Span<const int64> strides) {
  std::unique_ptr<HloInstruction> slice =
      absl::make_unique<HloSliceInstruction>(shape, operand, start_indices,
                                             limit_indices, strides);
  return slice;
}
// Builds an HloDynamicSliceInstruction from `shape`, `operand`,
// `start_indices` and `slice_sizes`.
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateDynamicSlice(
    const Shape& shape, HloInstruction* operand, HloInstruction* start_indices,
    absl::Span<const int64> slice_sizes) {
  std::unique_ptr<HloInstruction> dynamic_slice =
      absl::make_unique<HloDynamicSliceInstruction>(shape, operand,
                                                    start_indices, slice_sizes);
  return dynamic_slice;
}
/* static */ std::unique_ptr<HloInstruction>
HloInstruction::CreateDynamicUpdateSlice(const Shape& shape,
HloInstruction* operand,
HloInstruction* update,
HloInstruction* start_indices) {
auto instruction = absl::WrapUnique(
new HloInstruction(HloOpcode::kDynamicUpdateSlice, shape));
instruction->AppendOperand(operand);
instruction->AppendOperand(update);
instruction->AppendOperand(start_indices);
return instruction;
}
// Builds an HloConcatenateInstruction from `shape`, `operands` and
// `dimension`.
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateConcatenate(
    const Shape& shape, absl::Span<HloInstruction* const> operands,
    int64 dimension) {
  std::unique_ptr<HloInstruction> concatenate =
      absl::make_unique<HloConcatenateInstruction>(shape, operands, dimension);
  return concatenate;
}
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateConvert(
const Shape& shape, HloInstruction* operand) {
auto instruction =
absl::WrapUnique(new HloInstruction(HloOpcode::kConvert, shape));
instruction->AppendOperand(operand);
return instruction;
}
/* static */ std::unique_ptr<HloInstruction>
HloInstruction::CreateBitcastConvert(const Shape& shape,
HloInstruction* operand) {
auto instruction =
absl::WrapUnique(new HloInstruction(HloOpcode::kBitcastConvert, shape));
instruction->AppendOperand(operand);
return instruction;
}
// Factory for a reduce instruction over a single operand. `operand` and
// `init_value` are passed to HloReduceInstruction as one two-element argument
// list, in that order.
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateReduce(
    const Shape& shape, HloInstruction* operand, HloInstruction* init_value,
    absl::Span<const int64> dimensions_to_reduce,
    HloComputation* reduce_computation) {
  // Plain `new` is used because the braced argument list cannot be forwarded
  // through make_unique.
  std::unique_ptr<HloInstruction> reduce(new HloReduceInstruction(
      shape, {operand, init_value}, dimensions_to_reduce, reduce_computation));
  return reduce;
}
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateReduce(
const Shape& shape, absl::Span<HloInstruction* const> operands,
absl::Span<HloInstruction* const> init_values,
absl::Span<const int64> dimensions_to_reduce,
HloComputation* reduce_computation) {
std::vector<HloInstruction*> all_args;
all_args.reserve(operands.size() * 2);
all_args.insert(all_args.end(), operands.begin(), operands.end());
all_args.insert(all_args.end(), init_values.begin(), init_values.end());
return absl::make_unique<HloReduceInstruction>(
shape, all_args, dimensions_to_reduce, reduce_computation);
}
// Factory for a reduce-window instruction. All arguments are forwarded
// unchanged to the HloReduceWindowInstruction constructor.
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateReduceWindow(
    const Shape& shape, HloInstruction* operand, HloInstruction* init_value,
    const Window& window, HloComputation* reduce_computation) {
  std::unique_ptr<HloInstruction> reduce_window =
      absl::make_unique<HloReduceWindowInstruction>(
          shape, operand, init_value, window, reduce_computation);
  return reduce_window;
}
// Factory for a batch-norm-training instruction. All arguments are forwarded
// unchanged to the HloBatchNormTrainingInstruction constructor.
/* static */ std::unique_ptr<HloInstruction>
HloInstruction::CreateBatchNormTraining(const Shape& shape,
                                        HloInstruction* operand,
                                        HloInstruction* scale,
                                        HloInstruction* offset, float epsilon,
                                        int64 feature_index) {
  std::unique_ptr<HloInstruction> batch_norm =
      absl::make_unique<HloBatchNormTrainingInstruction>(
          shape, operand, scale, offset, epsilon, feature_index);
  return batch_norm;
}
// Factory for a batch-norm-inference instruction. All arguments are forwarded
// unchanged to the HloBatchNormInferenceInstruction constructor.
/* static */ std::unique_ptr<HloInstruction>
HloInstruction::CreateBatchNormInference(
    const Shape& shape, HloInstruction* operand, HloInstruction* scale,
    HloInstruction* offset, HloInstruction* mean, HloInstruction* variance,
    float epsilon, int64 feature_index) {
  std::unique_ptr<HloInstruction> batch_norm =
      absl::make_unique<HloBatchNormInferenceInstruction>(
          shape, operand, scale, offset, mean, variance, epsilon,
          feature_index);
  return batch_norm;
}
// Factory for a batch-norm-grad instruction. All arguments are forwarded
// unchanged to the HloBatchNormGradInstruction constructor.
/* static */ std::unique_ptr<HloInstruction>
HloInstruction::CreateBatchNormGrad(const Shape& shape, HloInstruction* operand,
                                    HloInstruction* scale, HloInstruction* mean,
                                    HloInstruction* variance,
                                    HloInstruction* grad_output, float epsilon,
                                    int64 feature_index) {
  std::unique_ptr<HloInstruction> batch_norm_grad =
      absl::make_unique<HloBatchNormGradInstruction>(
          shape, operand, scale, mean, variance, grad_output, epsilon,
          feature_index);
  return batch_norm_grad;
}
// Factory for a select-and-scatter instruction. All arguments are forwarded
// unchanged to the HloSelectAndScatterInstruction constructor.
/* static */ std::unique_ptr<HloInstruction>
HloInstruction::CreateSelectAndScatter(
    const Shape& shape, HloInstruction* operand, HloComputation* select,
    const Window& window, HloInstruction* source, HloInstruction* init_value,
    HloComputation* scatter) {
  std::unique_ptr<HloInstruction> select_and_scatter =
      absl::make_unique<HloSelectAndScatterInstruction>(
          shape, operand, select, window, source, init_value, scatter);
  return select_and_scatter;
}
// Factory for a broadcast instruction. `operand` and `broadcast_dimensions`
// are forwarded unchanged to the HloBroadcastInstruction constructor.
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateBroadcast(
    const Shape& shape, HloInstruction* operand,
    absl::Span<const int64> broadcast_dimensions) {
  std::unique_ptr<HloInstruction> broadcast =
      absl::make_unique<HloBroadcastInstruction>(shape, operand,
                                                 broadcast_dimensions);
  return broadcast;
}
/* static */ std::unique_ptr<HloInstruction>
HloInstruction::CreateBroadcastSequence(
const Shape& output_shape, HloInstruction* operand,
const std::function<HloInstruction*(std::unique_ptr<HloInstruction>)>&
adder) {
CHECK(ShapeUtil::IsScalar(operand->shape()) ||
ShapeUtil::Rank(operand->shape()) == ShapeUtil::Rank(output_shape));
Shape broadcast_shape = ShapeUtil::ChangeElementType(
output_shape, operand->shape().element_type());
// Do explicit broadcast for scalar.
if (ShapeUtil::IsScalar(operand->shape())) {
auto broadcast =
HloInstruction::CreateBroadcast(broadcast_shape, operand, {});
broadcast->set_metadata(operand->metadata());
if (operand->has_sharding()) {
broadcast->set_sharding(operand->sharding());
}
return broadcast;
}
// Do explicit broadcast for degenerate broadcast.
std::vector<int64> broadcast_dimensions;
std::vector<int64> reshaped_dimensions;
for (int i = 0; i < ShapeUtil::Rank(operand->shape()); i++) {
if (operand->shape().dimensions(i) == output_shape.dimensions(i)) {
broadcast_dimensions.push_back(i);
reshaped_dimensions.push_back(operand->shape().dimensions(i));
} else {
CHECK_EQ(operand->shape().dimensions(i), 1)
<< "An explicit broadcast sequence requires the broadcasted "
"dimensions to be trivial; operand: "
<< operand->ToString() << "; output_shape: " << output_shape;
}
}
// Eliminate the size one dimensions.
HloInstruction* reshaped_operand = adder(HloInstruction::CreateReshape(
ShapeUtil::MakeShape(operand->shape().element_type(),
reshaped_dimensions),
operand));
reshaped_operand->set_metadata(operand->metadata());
if (operand->has_sharding()) {
reshaped_operand->set_sharding(operand->sharding());
}
// Broadcast 'reshape' up to the larger size.
auto broadcast = HloInstruction::CreateBroadcast(
broadcast_shape, reshaped_operand, broadcast_dimensions);
broadcast->set_metadata(operand->metadata());
if (operand->has_sharding()) {
broadcast->set_sharding(operand->sharding());
}
return broadcast;
}
// Factory for a pad instruction. All arguments are forwarded unchanged to the
// HloPadInstruction constructor.
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreatePad(
    const Shape& shape, HloInstruction* operand, HloInstruction* padding_value,
    const PaddingConfig& padding_config) {
  std::unique_ptr<HloInstruction> pad = absl::make_unique<HloPadInstruction>(
      shape, operand, padding_value, padding_config);
  return pad;
}
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateReshape(
const Shape& shape, HloInstruction* operand) {
CHECK_EQ(ShapeUtil::ElementsIn(shape),
ShapeUtil::ElementsIn(operand->shape()))
<< "shape: " << ShapeUtil::HumanString(shape)
<< " operand: " << ShapeUtil::HumanString(operand->shape());
auto instruction =
absl::WrapUnique(new HloInstruction(HloOpcode::kReshape, shape));
instruction->AppendOperand(operand);
return instruction;
}
// Factory for a transpose instruction. `operand` and `dimensions` are
// forwarded unchanged to the HloTransposeInstruction constructor.
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateTranspose(
    const Shape& shape, HloInstruction* operand,
    absl::Span<const int64> dimensions) {
  std::unique_ptr<HloInstruction> transpose =
      absl::make_unique<HloTransposeInstruction>(shape, operand, dimensions);
  return transpose;
}
// Factory for a sort instruction. `dimension`, `keys` and `values` are
// forwarded unchanged to the HloSortInstruction constructor.
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateSort(
    const Shape& shape, int64 dimension, HloInstruction* keys,
    HloInstruction* values) {
  std::unique_ptr<HloInstruction> sort =
      absl::make_unique<HloSortInstruction>(shape, dimension, keys, values);
  return sort;
}
// Factory for a fusion instruction built from `fused_root`. All arguments are
// forwarded unchanged to the HloFusionInstruction constructor.
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateFusion(
    const Shape& shape, FusionKind fusion_kind, HloInstruction* fused_root) {
  std::unique_ptr<HloInstruction> fusion =
      absl::make_unique<HloFusionInstruction>(shape, fusion_kind, fused_root);
  return fusion;
}
// Factory for a fusion instruction with explicit `operands` and an existing
// `fusion_computation`. All arguments are forwarded unchanged to the
// HloFusionInstruction constructor.
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateFusion(
    const Shape& shape, FusionKind fusion_kind,
    absl::Span<HloInstruction* const> operands,
    HloComputation* fusion_computation) {
  std::unique_ptr<HloInstruction> fusion =
      absl::make_unique<HloFusionInstruction>(shape, fusion_kind, operands,
                                              fusion_computation);
  return fusion;
}
// Applies the non-tuple `sharding` to this instruction. For a tuple-shaped
// instruction the sharding is first expanded into a tuple sharding over
// shape() via GetAsShapeTree. CHECK-fails if `sharding` is itself a tuple
// sharding.
void HloInstruction::set_single_sharding(const HloSharding& sharding) {
  CHECK(!sharding.IsTuple()) << sharding;
  if (!ShapeUtil::IsTuple(shape())) {
    set_sharding(sharding);
    return;
  }
  set_sharding(HloSharding::Tuple(sharding.GetAsShapeTree(shape())));
}
// Makes `derived_instruction` inherit this instruction's sharding and
// metadata. If this instruction has no sharding, any sharding already set on
// the derived instruction is cleared.
void HloInstruction::SetupDerivedInstruction(
    HloInstruction* derived_instruction) const {
  if (sharding_ == nullptr) {
    derived_instruction->clear_sharding();
  } else {
    derived_instruction->set_sharding(*sharding_);
  }
  derived_instruction->set_metadata(metadata_);
}
// Reports whether this instruction's own opcode has a side effect, without
// looking into any called computations.
bool HloInstruction::HasSideEffectNoRecurse() const {
  switch (opcode_) {
    // A cross-replica sum only counts as side-effecting when it carries an
    // all_reduce_id.
    case HloOpcode::kCrossReplicaSum:
      return all_reduce_id().has_value();
    // Opcodes that always have a side effect.
    case HloOpcode::kSend:
    case HloOpcode::kSendDone:
    case HloOpcode::kRecv:
    case HloOpcode::kRecvDone:
    case HloOpcode::kRng:
    case HloOpcode::kInfeed:
    case HloOpcode::kOutfeed:
    case HloOpcode::kTrace:
      return true;
    default:
      return false;
  }
}
bool HloInstruction::HasSideEffect() const {
if (HasSideEffectNoRecurse()) {
return true;
}
// Check if any of the called computations has a side effect.
for (const auto& computation : called_computations()) {
if (computation->HasSideEffect()) {
return true;
}
}
return false;
}
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateCall(
const Shape& shape, absl::Span<HloInstruction* const> operands,
HloComputation* computation) {
std::unique_ptr<HloInstruction> instruction =
absl::WrapUnique(new HloInstruction(HloOpcode::kCall, shape));
for (auto operand : operands) {
instruction->AppendOperand(operand);
}
instruction->called_computations_.push_back(computation);
return instruction;
}
// Factory for a custom-call instruction. `operands` and `custom_call_target`
// are forwarded unchanged to the HloCustomCallInstruction constructor.
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateCustomCall(
    const Shape& shape, absl::Span<HloInstruction* const> operands,
    absl::string_view custom_call_target) {
  std::unique_ptr<HloInstruction> custom_call =
      absl::make_unique<HloCustomCallInstruction>(shape, operands,
                                                  custom_call_target);
  return custom_call;
}
// Factory for a kTuple instruction over `elements`. The tuple shape is derived
// from the shapes of the elements, in order.
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateTuple(
    absl::Span<HloInstruction* const> elements) {
  std::vector<Shape> shapes;
  for (const HloInstruction* element : elements) {
    shapes.push_back(element->shape());
  }
  return CreateVariadic(ShapeUtil::MakeTupleShape(shapes), HloOpcode::kTuple,
                        elements);
}
// Factory for a gather instruction. All arguments are forwarded unchanged to
// the HloGatherInstruction constructor.
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateGather(
    const Shape& shape, HloInstruction* operand, HloInstruction* start_indices,
    const GatherDimensionNumbers& gather_dim_numbers,
    absl::Span<const int64> slice_sizes) {
  std::unique_ptr<HloInstruction> gather =
      absl::make_unique<HloGatherInstruction>(
          shape, operand, start_indices, gather_dim_numbers, slice_sizes);
  return gather;
}
// Factory for a scatter instruction. All arguments are forwarded unchanged to
// the HloScatterInstruction constructor.
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateScatter(
    const Shape& shape, HloInstruction* operand,
    HloInstruction* scatter_indices, HloInstruction* updates,
    HloComputation* update_computation,
    const ScatterDimensionNumbers& scatter_dim_numbers) {
  std::unique_ptr<HloInstruction> scatter =
      absl::make_unique<HloScatterInstruction>(
          shape, operand, scatter_indices, updates, update_computation,
          scatter_dim_numbers);
  return scatter;
}
// Factory for a domain instruction. Ownership of both metadata objects is
// moved into the HloDomainInstruction constructor.
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateDomain(
    const Shape& shape, HloInstruction* operand,
    std::unique_ptr<DomainMetadata> operand_side_metadata,
    std::unique_ptr<DomainMetadata> user_side_metadata) {
  std::unique_ptr<HloInstruction> domain =
      absl::make_unique<HloDomainInstruction>(
          shape, operand, std::move(operand_side_metadata),
          std::move(user_side_metadata));
  return domain;
}
// Creates a copy of this instruction that has result shape `shape` and uses
// `new_operands` as its operands.
//
// Opcodes that have been migrated to a dedicated subclass are delegated to
// CloneWithNewOperandsImpl. Every other opcode is rebuilt through its matching
// Create* factory, after CHECKing the operand count where it is fixed. The
// clone then receives this instruction's sharding and metadata (through
// SetupDerivedInstruction), its parent and its raw backend-config string.
//
// If `context` is non-null, the mapping from this instruction to the clone is
// recorded in it, and every called computation whose parent is not
// context->module() is replaced by a deep clone made in that module.
std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands(
const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const {
VLOG(3) << "CloneWithNewOperands:\n " << ToString();
VLOG(3) << " new operands:";
for (const HloInstruction* new_operand : new_operands) {
VLOG(3) << " %" << new_operand->name();
}
std::unique_ptr<HloInstruction> clone;
// Explicitly call the factory for the instruction type. This is more robust
// in the face of code changes than copying fields explicitly. This also
// properly sets the user fields of the operands.
// NOTE(review): the switch has no default case, so `clone` stays null for an
// opcode that is not listed; presumably the list is exhaustive -- confirm.
switch (opcode_) {
// Ops migrated to subclasses.
// TODO(b/80131774): Remove this switch when migration is complete.
case HloOpcode::kBatchNormTraining:
case HloOpcode::kBatchNormInference:
case HloOpcode::kBatchNormGrad:
case HloOpcode::kFft:
case HloOpcode::kSend:
case HloOpcode::kSendDone:
case HloOpcode::kRecv:
case HloOpcode::kRecvDone:
case HloOpcode::kReverse:
case HloOpcode::kConcatenate:
case HloOpcode::kReduce:
case HloOpcode::kTranspose:
case HloOpcode::kBroadcast:
case HloOpcode::kMap:
case HloOpcode::kSlice:
case HloOpcode::kConstant:
case HloOpcode::kTrace:
case HloOpcode::kFusion:
case HloOpcode::kRng:
case HloOpcode::kParameter:
case HloOpcode::kGetTupleElement:
case HloOpcode::kReducePrecision:
case HloOpcode::kCrossReplicaSum:
case HloOpcode::kAllToAll:
case HloOpcode::kCollectivePermute:
case HloOpcode::kInfeed:
case HloOpcode::kOutfeed:
case HloOpcode::kConvolution:
case HloOpcode::kCustomCall:
case HloOpcode::kReduceWindow:
case HloOpcode::kSelectAndScatter:
case HloOpcode::kPad:
case HloOpcode::kDynamicSlice:
case HloOpcode::kSort:
case HloOpcode::kGather:
case HloOpcode::kScatter:
case HloOpcode::kIota:
case HloOpcode::kDot:
case HloOpcode::kDomain:
clone = CloneWithNewOperandsImpl(shape, new_operands, context);
break;
// Unary ops.
case HloOpcode::kAbs:
case HloOpcode::kRoundNearestAfz:
case HloOpcode::kBitcast:
case HloOpcode::kCeil:
case HloOpcode::kClz:
case HloOpcode::kCopy:
case HloOpcode::kCos:
case HloOpcode::kExp:
case HloOpcode::kExpm1:
case HloOpcode::kImag:
case HloOpcode::kIsFinite:
case HloOpcode::kFloor:
case HloOpcode::kLog:
case HloOpcode::kLog1p:
case HloOpcode::kNot:
case HloOpcode::kNegate:
case HloOpcode::kReal:
case HloOpcode::kSign:
case HloOpcode::kSin:
case HloOpcode::kTanh:
CHECK_EQ(new_operands.size(), 1);
clone = CreateUnary(shape, opcode_, new_operands[0]);
break;
// Binary ops.
case HloOpcode::kAdd:
case HloOpcode::kAtan2:
case HloOpcode::kComplex:
case HloOpcode::kDivide:
case HloOpcode::kMultiply:
case HloOpcode::kSubtract:
case HloOpcode::kEq:
case HloOpcode::kGe:
case HloOpcode::kGt:
case HloOpcode::kLe:
case HloOpcode::kLt:
case HloOpcode::kNe:
case HloOpcode::kMaximum:
case HloOpcode::kMinimum:
case HloOpcode::kPower:
case HloOpcode::kRemainder:
case HloOpcode::kAnd:
case HloOpcode::kOr:
case HloOpcode::kXor:
case HloOpcode::kShiftLeft:
case HloOpcode::kShiftRightArithmetic:
case HloOpcode::kShiftRightLogical:
CHECK_EQ(new_operands.size(), 2);
clone = CreateBinary(shape, opcode_, new_operands[0], new_operands[1]);
break;
// Ternary ops.
case HloOpcode::kClamp:
case HloOpcode::kSelect:
case HloOpcode::kTupleSelect:
CHECK_EQ(new_operands.size(), 3);
clone = CreateTernary(shape, opcode_, new_operands[0], new_operands[1],
new_operands[2]);
break;
// Other supported ops.
case HloOpcode::kCall:
clone = CreateCall(shape, new_operands, to_apply());
break;
case HloOpcode::kConvert:
CHECK_EQ(new_operands.size(), 1);
clone = CreateConvert(shape, new_operands[0]);
break;
case HloOpcode::kBitcastConvert:
CHECK_EQ(new_operands.size(), 1);
clone = CreateBitcastConvert(shape, new_operands[0]);
break;
case HloOpcode::kReshape:
CHECK_EQ(new_operands.size(), 1);
clone = CreateReshape(shape, new_operands[0]);
break;
case HloOpcode::kDynamicUpdateSlice:
CHECK_EQ(new_operands.size(), 3);
clone = CreateDynamicUpdateSlice(shape, new_operands[0], new_operands[1],
new_operands[2]);
break;
case HloOpcode::kTuple:
// CreateTuple derives the shape from the element shapes, so the requested
// shape is written over it afterwards.
clone = CreateTuple(new_operands);
*clone->mutable_shape() = shape;
break;
case HloOpcode::kWhile:
CHECK_EQ(new_operands.size(), 1);
clone =
CreateWhile(shape, while_condition(), while_body(), new_operands[0]);
break;
case HloOpcode::kConditional:
CHECK_EQ(new_operands.size(), 3);
// Operand order: predicate, true-branch argument, false-branch argument.
clone = CreateConditional(shape, new_operands[0], new_operands[1],
true_computation(), new_operands[2],
false_computation());
break;
case HloOpcode::kAfterAll:
// With no operands there is nothing to order after; create a fresh token.
if (new_operands.empty()) {
clone = CreateToken();
} else {
clone = CreateAfterAll(new_operands);
}
break;
}
// SetupDerivedInstruction copies this instruction's sharding and metadata
// onto the clone.
SetupDerivedInstruction(clone.get());
clone->set_parent(parent_);
clone->set_raw_backend_config_string(backend_config_);
if (context != nullptr) {
context->MapInstruction(this, clone.get());
// Redirect callees that live outside the target module to deep clones made
// inside it; callees already in the target module are kept as they are.
clone->ReplaceCalledComputations([&](HloComputation* callee) {
return callee->parent() != context->module()
? context->module()->DeepCloneComputation(callee, context)
: callee;
});
}
return clone;
}
// Severs every def-use link that involves this instruction, so that neither
// its operands nor its users are left holding a pointer to it.
HloInstruction::~HloInstruction() {
  // Unregister from the operands. The same instruction may occupy several
  // operand slots, so RemoveUser is only called while this instruction is
  // still registered as a user of that operand.
  for (int64 slot = 0; slot < operand_count(); ++slot) {
    HloInstruction* producer = operands_[slot];
    if (producer != nullptr) {
      const bool still_registered =
          producer->user_set_.find(this) != producer->user_set_.end();
      if (still_registered) {
        producer->RemoveUser(this);
      }
      operands_[slot] = nullptr;
    }
  }
  // Update the users: set the corresponding operand slot to nullptr wherever
  // a user still refers to this instruction.
  for (HloInstruction* consumer : this->users()) {
    for (int i = 0; i < consumer->operand_count(); ++i) {
      if (consumer->operands_[i] != this) {
        continue;
      }
      consumer->operands_[i] = nullptr;
    }
  }
}
std::unique_ptr<HloInstruction> HloInstruction::Clone(
const string& suffix, HloCloneContext* context) const {
std::unique_ptr<HloInstruction> clone =
CloneWithNewOperands(shape_, operands_, context);
if (suffix.empty()) {
clone->name_ = name();
} else {
// If an instruction is cloned multiple times avoid names like
// foo.suffix.suffix.suffix. Instead of repeating the suffix add a numeric
// suffix. Specifically, the clone of foo.suffix is named foo.suffix2, the
// clone of foo.suffix2 is named foo.suffix3 and so on.
const string dot_suffix = "." + suffix;
size_t index = name().rfind(dot_suffix);
if (index == string::npos) {
// Existing name does not include ".suffix".
clone->name_ = name() + dot_suffix;
} else {
// Existing name includes ".suffix". Determine if substring after
// ".suffix" is numeric and should be replaced with an incremented number.
string after_suffix = name().substr(index + dot_suffix.size());
if (after_suffix.empty()) {
// Existing name ends in ".suffix". New name should end in ".suffix2".
clone->name_ = name() + "2";
} else {
// If names ends with .suffix[0-9]+ then replace with a suffix with the
// numeric value incremented.
int64 numeric_suffix;
if (absl::SimpleAtoi(after_suffix, &numeric_suffix)) {
clone->name_ =
StrCat(name().substr(0, index), dot_suffix, numeric_suffix + 1);
} else {
// Substring after ".suffix" is non-numeric.
clone->name_ = name() + dot_suffix;
}
}
}
}
return clone;
}
std::pair<const HloInstruction*, ShapeIndex>
HloInstruction::LatestNonGteAncestorAndIndex() const {
const HloInstruction* hlo = this;
ShapeIndex index;
while (hlo->opcode() == HloOpcode::kGetTupleElement) {
index.push_back(hlo->tuple_index());
hlo = hlo->operand(0);
}
// We built up index in the reverse order from what we want.
std::reverse(index.begin(), index.end());
return {hlo, index};
}
// Walks up through consecutive kGetTupleElement instructions, starting at this
// instruction, and returns the first instruction that is not a
// get-tuple-element (possibly this instruction itself).
const HloInstruction* HloInstruction::LatestNonGteAncestor() const {
  const HloInstruction* ancestor = this;
  for (; ancestor->opcode() == HloOpcode::kGetTupleElement;
       ancestor = ancestor->operand(0)) {
  }
  return ancestor;
}
// Returns the operand stored in slot `i`. No bounds or null check is done
// here, unlike mutable_operand() which CHECKs for null.
const HloInstruction* HloInstruction::operand(int64 i) const {
  const HloInstruction* const result = operands_[i];
  return result;
}
// Returns the operand stored in slot `i` for modification. CHECK-fails if the
// slot holds a null pointer.
HloInstruction* HloInstruction::mutable_operand(int64 i) {
  CHECK(operands_[i] != nullptr);
  HloInstruction* const result = operands_[i];
  return result;
}
// Returns the first operand slot that holds `target`. Aborts via LOG(FATAL) if
// `target` is not an operand of this instruction.
int64 HloInstruction::operand_index(const HloInstruction* target) const {
  int64 slot = 0;
  while (slot < operand_count()) {
    if (operand(slot) == target) {
      return slot;
    }
    ++slot;
  }
  LOG(FATAL) << "target was not an operand: " << target->ToString();
}
// Returns the operands with duplicates removed, keeping the order in which
// each operand first appears.
HloInstruction::InstructionVector HloInstruction::unique_operands() const {
  tensorflow::gtl::FlatSet<const HloInstruction*> visited;
  InstructionVector result;
  for (HloInstruction* candidate : operands()) {
    const bool first_occurrence = visited.insert(candidate).second;
    if (!first_occurrence) {
      continue;
    }
    result.push_back(candidate);
  }
  return result;
}
// Adds a control edge from this instruction to `instruction`, recording it in
// this instruction's successor list and in `instruction`'s predecessor list.
// Adding an edge that already exists is a no-op. Returns an error if the two
// instructions have different parents, or if the predecessor list already
// contains this instruction although the successor list did not.
Status HloInstruction::AddControlDependencyTo(HloInstruction* instruction) {
  TF_RET_CHECK(instruction->parent() == parent());
  const bool already_successor =
      std::find(control_successors_.begin(), control_successors_.end(),
                instruction) != control_successors_.end();
  if (already_successor) {
    return Status::OK();
  }
  control_successors_.push_back(instruction);
  // The two lists must mirror each other.
  TF_RET_CHECK(std::find(instruction->control_predecessors_.begin(),
                         instruction->control_predecessors_.end(),
                         this) == instruction->control_predecessors_.end());
  instruction->control_predecessors_.push_back(this);
  return Status::OK();
}
// Removes the control edge from this instruction to `instruction` from both
// sides. Returns an error if the instructions have different parents or if
// either erase reports a failure.
Status HloInstruction::RemoveControlDependencyTo(HloInstruction* instruction) {
  TF_RET_CHECK(instruction->parent() == parent());
  TF_RETURN_IF_ERROR(EraseElementFromVector(&control_successors_, instruction));
  // The result of the second erase is the overall result.
  return EraseElementFromVector(&instruction->control_predecessors_, this);
}
// Removes every control edge that starts or ends at this instruction. The
// reverse entry is first erased from each neighbour, then both local lists are
// cleared. The first erase failure is returned immediately.
Status HloInstruction::DropAllControlDeps() {
  for (HloInstruction* successor : control_successors_) {
    TF_RETURN_IF_ERROR(
        EraseElementFromVector(&successor->control_predecessors_, this));
  }
  for (HloInstruction* predecessor : control_predecessors_) {
    TF_RETURN_IF_ERROR(
        EraseElementFromVector(&predecessor->control_successors_, this));
  }
  control_successors_.clear();
  control_predecessors_.clear();
  return Status::OK();
}
// Gives this instruction the same control predecessors and successors as
// `inst`, in addition to any it already has. The first failure reported by
// AddControlDependencyTo is returned immediately.
Status HloInstruction::CopyAllControlDepsFrom(const HloInstruction* inst) {
  for (HloInstruction* predecessor : inst->control_predecessors()) {
    TF_RETURN_IF_ERROR(predecessor->AddControlDependencyTo(this));
  }
  for (HloInstruction* successor : inst->control_successors()) {
    TF_RETURN_IF_ERROR(AddControlDependencyTo(successor));
  }
  return Status::OK();
}
// Appends `operand` as the last operand of this instruction and registers this
// instruction as one of the operand's users.
void HloInstruction::AppendOperand(HloInstruction* operand) {
  operands_.emplace_back(operand);
  operand->AddUser(this);
}
// Removes the operands at the given positions and closes the gaps, keeping the
// relative order of the remaining operands.
//
// `ascending_indices` must be strictly ascending, non-negative and smaller
// than the current operand count; violations CHECK-fail. Only operands_ is
// modified: the user lists of the removed operands are not touched by this
// function.
void HloInstruction::RemoveOperandsAtAscendingIndices(
    absl::Span<const int> ascending_indices) {
  if (ascending_indices.empty()) {
    return;
  }
  const int operand_total = static_cast<int>(operands_.size());
  int next_index = 0;     // Next operand slot to examine.
  int removed_count = 0;  // Slots dropped so far, i.e. the shift distance.
  for (int to_remove : ascending_indices) {
    // Validate the index before any slot is read or written. `next_index` is
    // one past the previously removed slot, so this also rejects negative,
    // duplicate and descending indices, which would otherwise be counted as
    // removals of the wrong operands.
    CHECK_GE(to_remove, next_index)
        << "indices must be strictly ascending and non-negative";
    CHECK_LT(to_remove, operand_total);
    // Shift the kept operands in front of `to_remove` down over the gaps.
    while (next_index < to_remove) {
      operands_[next_index - removed_count] = operands_[next_index];
      ++next_index;
    }
    ++removed_count;
    ++next_index;
  }
  // Shift the tail behind the last removed slot.
  while (next_index < operand_total) {
    operands_[next_index - removed_count] = operands_[next_index];
    ++next_index;
  }
  CHECK_EQ(removed_count, ascending_indices.size());
  operands_.resize(operands_.size() - removed_count);
}
// Registers `user` as a user of this instruction. A user that is already
// registered is not added a second time.
void HloInstruction::AddUser(HloInstruction* user) {
  if (ContainsKey(user_set_, user)) {
    return;
  }
  // The set and the vector are kept in step with each other.
  user_set_.insert(user);
  users_.push_back(user);
}
bool HloInstruction::HasConstantOperand() const {
for (const HloInstruction* operand : operands_) {
if (operand->IsConstant()) {
return true;
}
}
return false;
}
// Opcode-specific part of the equality check for opcodes implemented directly
// by the HloInstruction base class.
//
// `other` is the instruction being compared against, and `eq_computations`
// decides whether two called computations count as equal. Nothing here
// compares opcodes, shapes or operands; presumably the caller has done that
// already -- verify against the call site.
//
// Returns true for opcodes with no extra state, the result of
// `eq_computations` for kCall, kConditional and kWhile, and false for
// kAfterAll. Aborts via LOG(FATAL) for opcodes that have a dedicated subclass.
bool HloInstruction::IdenticalSlowPath(
const HloInstruction& other,
const std::function<bool(const HloComputation*, const HloComputation*)>&
eq_computations) const {
// Perform opcode specific checks.
switch (opcode()) {
// The result of these instructions only depend upon their opcode and
// operands.
case HloOpcode::kAbs:
case HloOpcode::kAtan2:
case HloOpcode::kAdd:
case HloOpcode::kBitcast:
case HloOpcode::kBitcastConvert:
case HloOpcode::kCeil:
case HloOpcode::kClamp:
case HloOpcode::kClz:
case HloOpcode::kComplex:
case HloOpcode::kConvert:
case HloOpcode::kCopy:
case HloOpcode::kCos:
case HloOpcode::kDivide:
case HloOpcode::kDynamicUpdateSlice:
case HloOpcode::kEq:
case HloOpcode::kExp:
case HloOpcode::kExpm1:
case HloOpcode::kFloor:
case HloOpcode::kGe:
case HloOpcode::kGt:
case HloOpcode::kImag:
case HloOpcode::kIsFinite:
case HloOpcode::kLe:
case HloOpcode::kLog:
case HloOpcode::kLog1p:
case HloOpcode::kAnd:
case HloOpcode::kNot:
case HloOpcode::kOr:
case HloOpcode::kXor:
case HloOpcode::kLt:
case HloOpcode::kMaximum:
case HloOpcode::kMinimum:
case HloOpcode::kMultiply:
case HloOpcode::kNe:
case HloOpcode::kNegate:
case HloOpcode::kPower:
case HloOpcode::kReal:
case HloOpcode::kRemainder:
case HloOpcode::kReshape:
case HloOpcode::kRoundNearestAfz:
case HloOpcode::kSelect:
case HloOpcode::kShiftLeft:
case HloOpcode::kShiftRightArithmetic:
case HloOpcode::kShiftRightLogical:
case HloOpcode::kSign:
case HloOpcode::kSin:
case HloOpcode::kSubtract:
case HloOpcode::kTanh:
case HloOpcode::kTuple:
case HloOpcode::kTupleSelect:
return true;
// This opcode has complex or special behavior so just return false.
case HloOpcode::kAfterAll:
return false;
// Remaining instructions with special values.
case HloOpcode::kCall:
return eq_computations(to_apply(), other.to_apply());
case HloOpcode::kConditional:
return eq_computations(true_computation(), other.true_computation()) &&
eq_computations(false_computation(), other.false_computation());
case HloOpcode::kWhile: {
// Both the body and the condition have to match.
if (eq_computations(while_body(), other.while_body()) &&
eq_computations(while_condition(), other.while_condition())) {
return true;
}
return false;
}
// Ops migrated to subclasses should never come to this line.
// TODO(b/80131774): Remove this switch when migration is complete.
case HloOpcode::kBatchNormTraining:
case HloOpcode::kBatchNormInference:
case HloOpcode::kBatchNormGrad:
case HloOpcode::kFft:
case HloOpcode::kSend:
case HloOpcode::kSendDone:
case HloOpcode::kRecv:
case HloOpcode::kRecvDone:
case HloOpcode::kReverse:
case HloOpcode::kConcatenate:
case HloOpcode::kReduce:
case HloOpcode::kSort:
case HloOpcode::kTranspose:
case HloOpcode::kBroadcast:
case HloOpcode::kMap:
case HloOpcode::kSlice:
case HloOpcode::kConstant:
case HloOpcode::kIota:
case HloOpcode::kTrace:
case HloOpcode::kFusion:
case HloOpcode::kRng:
case HloOpcode::kParameter:
case HloOpcode::kGetTupleElement:
case HloOpcode::kReducePrecision:
case HloOpcode::kInfeed:
case HloOpcode::kOutfeed:
case HloOpcode::kCrossReplicaSum:
case HloOpcode::kAllToAll:
case HloOpcode::kCollectivePermute:
case HloOpcode::kConvolution:
case HloOpcode::kCustomCall:
case HloOpcode::kReduceWindow:
case HloOpcode::kSelectAndScatter:
case HloOpcode::kPad:
case HloOpcode::kDynamicSlice:
case HloOpcode::kGather:
case HloOpcode::kScatter:
case HloOpcode::kDot:
case HloOpcode::kDomain:
LOG(FATAL) << "Base class impl called for opcode with subclass: "
<< opcode();
}
// Reached only for an opcode value that matches none of the cases above.
return false;
}
// Unregisters `user` from this instruction. CHECK-fails if `user` is missing
// from either the user set or the user vector.
void HloInstruction::RemoveUser(HloInstruction* user) {
  const auto set_it = user_set_.find(user);
  CHECK(set_it != user_set_.end());
  user_set_.erase(set_it);
  // The vector lookup is linear in the number of users; the vector is kept
  // because it gives a stable iteration order and much faster traversal.
  const auto vec_it = absl::c_find(users_, user);
  CHECK(vec_it != users_.end());
  users_.erase(vec_it);
}
// Replaces every use of this instruction inside `user` with `new_producer`.
//
// Returns an error Status if `new_producer`'s shape is not compatible with
// this instruction's shape (ignoring floating-point precision), or if `user`
// does not have this instruction among its operands. Both conditions are
// checked before any bookkeeping is changed. `user` must also be registered as
// a user of this instruction, otherwise RemoveUser CHECK-fails.
//
// When `user` is a fusion instruction its operands are deduplicated
// afterwards, and an error from that step is returned.
Status HloInstruction::ReplaceUseWith(HloInstruction* user,
                                      HloInstruction* new_producer) {
  TF_RET_CHECK(
      ShapeUtil::CompatibleIgnoringFpPrecision(shape(), new_producer->shape()))
      << "this shape: " << ShapeUtil::HumanString(shape())
      << ", replacement shape: "
      << ShapeUtil::HumanString(new_producer->shape());
  VLOG(3) << "Replacing uses of " << name() << " in " << user->name()
          << " with " << new_producer->name();
  // The use must really exist. The count has to be strictly positive: a
  // ">= 0" comparison could never fail. Checking before RemoveUser keeps the
  // user bookkeeping untouched on the error path.
  TF_RET_CHECK(
      std::count(user->operands_.begin(), user->operands_.end(), this) > 0)
      << user->name() << " does not use " << name();
  RemoveUser(user);
  std::replace(user->operands_.begin(), user->operands_.end(), this,
               new_producer);
  new_producer->AddUser(user);
  if (user->opcode() == HloOpcode::kFusion) {
    TF_RETURN_IF_ERROR(
        Cast<HloFusionInstruction>(user)->DeduplicateFusionOperands());
  }
  return Status::OK();
}
// Replaces the operand in slot `operand_num` with `new_operand` and updates
// the user lists of the old and new operand.
//
// Returns an error if `operand_num` is out of range or if the shapes of the
// old and new operand are incompatible (ignoring floating-point precision).
// Replacing an operand with itself is a no-op.
Status HloInstruction::ReplaceOperandWith(int64 operand_num,
                                          HloInstruction* new_operand) {
  TF_RET_CHECK(operand_num >= 0);
  TF_RET_CHECK(operand_num < operand_count());
  HloInstruction* old_operand = mutable_operand(operand_num);
  if (old_operand == new_operand) {
    return Status::OK();
  }
  TF_RET_CHECK(ShapeUtil::CompatibleIgnoringFpPrecision(old_operand->shape(),
                                                        new_operand->shape()))
      << old_operand->shape() << " is not compatible with "
      << new_operand->shape();
  operands_[operand_num] = new_operand;
  VLOG(3) << "Replacing operand " << operand_num << " of " << name() << " with "
          << new_operand->name() << ", was " << old_operand->name();
  // The old operand may still sit in another slot; only unregister this
  // instruction as its user once no slot refers to it any more.
  const bool old_operand_still_used =
      std::find(operands_.begin(), operands_.end(), old_operand) !=
      operands_.end();
  if (!old_operand_still_used) {
    old_operand->RemoveUser(this);
  }
  new_operand->AddUser(this);
  return Status::OK();
}
// Redirects every user of this instruction to `new_producer`. If this
// instruction is the root of its parent computation, `new_producer` becomes
// the new root.
//
// `new_producer` may itself be a user of this instruction; that one use is
// left in place and `new_producer` stays registered as the only user. Errors
// from deduplicating the operands of fusion users are returned immediately.
Status HloInstruction::ReplaceAllUsesWith(HloInstruction* new_producer) {
  bool new_producer_is_user = false;
  for (HloInstruction* user : users()) {
    if (user == new_producer) {
      // This happens, for example, when an instruction is replaced with a
      // kCopy of itself. Rewriting this use would create a cycle in the graph,
      // so it is skipped.
      new_producer_is_user = true;
      continue;
    }
    std::replace(user->operands_.begin(), user->operands_.end(), this,
                 new_producer);
    new_producer->AddUser(user);
    if (user->opcode() == HloOpcode::kFusion) {
      TF_RETURN_IF_ERROR(
          Cast<HloFusionInstruction>(user)->DeduplicateFusionOperands());
    }
  }
  // Every use has moved away, except possibly the one held by new_producer.
  users_.clear();
  user_set_.clear();
  if (new_producer_is_user) {
    AddUser(new_producer);
  }
  const bool is_root = parent_ && parent_->root_instruction() == this;
  if (is_root) {
    parent_->set_root_instruction(new_producer);
  }
  return Status::OK();
}
// Returns the single computation applied by this instruction. Only valid for
// kCall, kMap, kReduceWindow, kReduce, kCrossReplicaSum and kScatter; any
// other opcode aborts via LOG(FATAL). CHECK-fails unless exactly one called
// computation is recorded.
HloComputation* HloInstruction::to_apply() const {
  switch (opcode_) {
    case HloOpcode::kCall:
    case HloOpcode::kMap:
    case HloOpcode::kReduceWindow:
    case HloOpcode::kReduce:
    case HloOpcode::kCrossReplicaSum:
    case HloOpcode::kScatter:
      break;
    default:
      LOG(FATAL) << "Invalid opcode for to_apply(): "
                 << HloOpcodeString(opcode());
  }
  CHECK_EQ(called_computations_.size(), 1);
  return called_computations_[0];
}
// Replaces the single computation applied by this instruction. Only valid for
// kCall, kMap, kReduceWindow, kReduce, kCrossReplicaSum and kScatter; any
// other opcode aborts via LOG(FATAL). CHECK-fails for fused instructions and
// unless exactly one called computation is recorded.
void HloInstruction::set_to_apply(HloComputation* computation) {
  // Changing the computation of a fused instruction is not allowed, so that
  // called_instructions never has to be recomputed for the whole fusion
  // instruction.
  CHECK(!IsFused());
  switch (opcode_) {
    case HloOpcode::kCall:
    case HloOpcode::kMap:
    case HloOpcode::kReduceWindow:
    case HloOpcode::kReduce:
    case HloOpcode::kCrossReplicaSum:
    case HloOpcode::kScatter:
      break;
    default:
      LOG(FATAL) << "Invalid opcode for to_apply(): "
                 << HloOpcodeString(opcode());
  }
  CHECK_EQ(called_computations_.size(), 1);
  called_computations_[0] = computation;
}
// Returns the condition computation of a kWhile instruction. CHECK-fails for
// any other opcode.
HloComputation* HloInstruction::while_condition() const {
  CHECK_EQ(HloOpcode::kWhile, opcode_);
  HloComputation* const condition =
      called_computations_[kConditionComputationIndex];
  return condition;
}
// Returns the body computation of a kWhile instruction. CHECK-fails for any
// other opcode.
HloComputation* HloInstruction::while_body() const {
  CHECK_EQ(HloOpcode::kWhile, opcode_);
  HloComputation* const body = called_computations_[kBodyComputationIndex];
  return body;
}
// Replaces the condition computation of a kWhile instruction. CHECK-fails for
// fused instructions and for any other opcode.
void HloInstruction::set_while_condition(HloComputation* computation) {
  // Changing the computation of a fused instruction is not allowed, so that
  // called_instructions never has to be recomputed for the whole fusion
  // instruction.
  CHECK(!IsFused());
  CHECK_EQ(HloOpcode::kWhile, opcode_);
  HloComputation*& condition_slot =
      called_computations_[kConditionComputationIndex];
  condition_slot = computation;
}
// Replaces the body computation of a kWhile instruction. CHECK-fails for fused
// instructions and for any other opcode.
void HloInstruction::set_while_body(HloComputation* computation) {
  // Changing the computation of a fused instruction is not allowed, so that
  // called_instructions never has to be recomputed for the whole fusion
  // instruction.
  CHECK(!IsFused());
  CHECK_EQ(HloOpcode::kWhile, opcode_);
  HloComputation*& body_slot = called_computations_[kBodyComputationIndex];
  body_slot = computation;
}
// Returns the true-branch computation of a kConditional instruction.
// CHECK-fails for any other opcode.
HloComputation* HloInstruction::true_computation() const {
  CHECK_EQ(HloOpcode::kConditional, opcode_);
  HloComputation* const true_branch =
      called_computations_[kTrueComputationIndex];
  return true_branch;
}
// Returns the false-branch computation of a kConditional instruction.
// CHECK-fails for any other opcode.
HloComputation* HloInstruction::false_computation() const {
  CHECK_EQ(HloOpcode::kConditional, opcode_);
  HloComputation* const false_branch =
      called_computations_[kFalseComputationIndex];
  return false_branch;
}
// Replaces the true-branch computation of a kConditional instruction.
// CHECK-fails for fused instructions and for any other opcode.
void HloInstruction::set_true_computation(HloComputation* true_computation) {
  // Changing the computation of a fused instruction is not allowed, so that
  // called_instructions never has to be recomputed for the whole fusion
  // instruction.
  CHECK(!IsFused());
  CHECK_EQ(HloOpcode::kConditional, opcode_);
  HloComputation*& true_slot = called_computations_[kTrueComputationIndex];
  true_slot = true_computation;
}
// Replaces the false-branch computation of a kConditional instruction.
// CHECK-fails for fused instructions and for any other opcode.
void HloInstruction::set_false_computation(HloComputation* false_computation) {
  // Changing the computation of a fused instruction is not allowed, so that
  // called_instructions never has to be recomputed for the whole fusion
  // instruction.
  CHECK(!IsFused());
  CHECK_EQ(HloOpcode::kConditional, opcode_);
  HloComputation*& false_slot = called_computations_[kFalseComputationIndex];
  false_slot = false_computation;
}
// Renders the instruction's signature as
// "(<operand shape>, <operand shape>, ...) -> <result shape>", using
// ShapeUtil::HumanString for each shape.
string HloInstruction::SignatureString() const {
  string operand_shapes;
  const char* separator = "";
  for (const HloInstruction* operand : operands_) {
    StrAppend(&operand_shapes, separator,
              ShapeUtil::HumanString(operand->shape()));
    separator = ", ";
  }
  return StrCat("(", operand_shapes, ") -> ", ShapeUtil::HumanString(shape()));
}
namespace {
// Returns `name`, prefixed with "%" when `options` requests percent-prefixed
// names.
string PrintName(const string& name, const HloPrintOptions& options) {
  const char* const prefix = options.print_percent() ? "%" : "";
  return StrCat(prefix, name);
}
}  // namespace
// Renders this instruction as text according to `options`, starting from a
// fresh, empty canonical-name map.
string HloInstruction::ToString(const HloPrintOptions& options) const {
  CanonicalNameMap fresh_name_map;
  return ToStringWithCanonicalNameMap(options, &fresh_name_map);
}
// Classifies this instruction as elementwise or not, based on its opcode.
//
// Unary and binary elementwise opcodes return true after CHECKing the operand
// count; kSelect and kClamp return true without an operand-count check.
// `operand_idx` is consulted only for kDynamicUpdateSlice, which is reported
// as elementwise only when an index is given and that index is 0. Every
// opcode that is not listed, including kTupleSelect, returns false.
bool HloInstruction::IsElementwiseImpl(
const absl::optional<int64>& operand_idx) const {
switch (opcode_) {
// Unary elementwise operations.
case HloOpcode::kAbs:
case HloOpcode::kRoundNearestAfz:
case HloOpcode::kCeil:
case HloOpcode::kClz:
case HloOpcode::kConvert:
case HloOpcode::kBitcastConvert:
case HloOpcode::kCopy:
case HloOpcode::kCos:
case HloOpcode::kExp:
case HloOpcode::kExpm1:
case HloOpcode::kFloor:
case HloOpcode::kImag:
case HloOpcode::kIsFinite:
case HloOpcode::kLog:
case HloOpcode::kLog1p:
case HloOpcode::kNot:
case HloOpcode::kNegate:
case HloOpcode::kReal:
case HloOpcode::kReducePrecision:
case HloOpcode::kSign:
case HloOpcode::kSin:
case HloOpcode::kTanh:
CHECK_EQ(1, operand_count());
return true;
// Binary elementwise operations, the same as in IsElementwiseBinary().
case HloOpcode::kAdd:
case HloOpcode::kAtan2:
case HloOpcode::kComplex:
case HloOpcode::kDivide:
case HloOpcode::kEq:
case HloOpcode::kGe:
case HloOpcode::kGt:
case HloOpcode::kLe:
case HloOpcode::kLt:
case HloOpcode::kMaximum:
case HloOpcode::kMinimum:
case HloOpcode::kMultiply:
case HloOpcode::kNe:
case HloOpcode::kPower:
case HloOpcode::kRemainder:
case HloOpcode::kSubtract:
case HloOpcode::kAnd:
case HloOpcode::kOr:
case HloOpcode::kXor:
case HloOpcode::kShiftLeft:
case HloOpcode::kShiftRightArithmetic:
case HloOpcode::kShiftRightLogical:
CHECK_EQ(2, operand_count());
return true;
// Ternary elementwise operations.
case HloOpcode::kSelect:
case HloOpcode::kClamp:
return true;
// Elementwise only with respect to operand 0, and only when the caller asks
// about a specific operand.
case HloOpcode::kDynamicUpdateSlice:
return operand_idx.has_value() && operand_idx.value() == 0;
default:
return false;
}
}
// Returns true if this is a kCrossReplicaSum whose all_reduce_id() is set.
bool HloInstruction::IsCrossModuleAllReduce() const {
  // all_reduce_id() is only queried for kCrossReplicaSum instructions.
  if (opcode() != HloOpcode::kCrossReplicaSum) {
    return false;
  }
  return static_cast<bool>(all_reduce_id());
}
// Prints this instruction as
//   [<name> = ]<shape> <opcode>(<operands>)[, <attr>...][, metadata={...}]
//   [, backend_config="..."]
// Names are taken from / registered in `canonical_name_map` when
// options.canonicalize_instruction_names() is set.
string HloInstruction::ToStringWithCanonicalNameMap(
    const HloPrintOptions& options,
    CanonicalNameMap* canonical_name_map) const {
  string text;

  // Instruction name prefix (e.g. "%foo = ").
  if (!options.canonicalize_instruction_names()) {
    StrAppend(&text, PrintName(name(), options), " = ");
  } else if (options.is_in_nested_computation()) {
    // Canonicalized names are printed only inside nested computations. For a
    // top-level HloInstruction::ToString() call with canonicalization enabled
    // no instruction name is printed at all.
    const string canonical_name = canonical_name_map->LookupOrInsert(name());
    StrAppend(&text, PrintName(canonical_name, options), " = ");
  }

  // Shape, opcode and operand list.
  StrAppend(&text, ShapeUtil::HumanStringWithLayout(shape()), " ",
            HloOpcodeString(opcode()), "(",
            OperandsToStringWithCanonicalNameMap(options, canonical_name_map),
            ")");

  // Extra attributes. Subcomputations referenced by the instruction are
  // printed as part of these.
  const std::vector<string> extra_attributes = ExtraAttributesToString(options);
  for (const string& attribute : extra_attributes) {
    StrAppend(&text, ", ", attribute);
  }

  // Metadata is printed only if at least one of op_type / op_name /
  // source_file is non-empty.
  if (options.print_metadata()) {
    const bool metadata_is_blank = metadata_.op_type().empty() &&
                                   metadata_.op_name().empty() &&
                                   metadata_.source_file().empty();
    if (!metadata_is_blank) {
      StrAppend(&text, ", metadata={", xla::OpMetadataToString(metadata_),
                "}");
    }
  }

  if (options.print_backend_config() && !backend_config_.empty()) {
    StrAppend(&text, ", backend_config=\"", CEscape(backend_config_), "\"");
  }
  return text;
}
// Prints the operand list using `options`, starting from an empty canonical
// name map.
string HloInstruction::OperandsToString(const HloPrintOptions& options) const {
  CanonicalNameMap fresh_name_map;
  return OperandsToStringWithCanonicalNameMap(options, &fresh_name_map);
}
// Prints the comma-separated operand list of this instruction.
//
// With options.compact_operands() at most kMaxOperandsToShowIfCompact
// operands are printed, followed by ", ...(+N)" for the N omitted ones. Each
// operand is rendered as "[<shape>] [<name>]"; a null operand is rendered as
// "null ".
string HloInstruction::OperandsToStringWithCanonicalNameMap(
    const HloPrintOptions& options,
    CanonicalNameMap* canonical_name_map) const {
  const int64 kMaxOperandsToShowIfCompact = 4;

  // View over the operands that will actually be printed.
  absl::Span<HloInstruction* const> shown(operands_);
  if (options.compact_operands() &&
      shown.size() > kMaxOperandsToShowIfCompact) {
    shown.remove_suffix(shown.size() - kMaxOperandsToShowIfCompact);
  }

  const auto format_operand = [&](string* out, HloInstruction* operand) {
    // Operands that are already gone are printed as `null`.
    if (operand == nullptr) {
      StrAppend(out, "null ");
      return;
    }
    std::vector<string> pieces;
    if (options.print_operand_shape()) {
      pieces.push_back(ShapeUtil::HumanStringWithLayout(operand->shape()));
    }
    // Canonical names are used only inside nested computations; in a
    // top-level HloInstruction::ToString() call the operand name is not part
    // of the canonical string.
    const bool use_canonical_name = options.canonicalize_instruction_names() &&
                                    options.is_in_nested_computation();
    if (use_canonical_name) {
      pieces.push_back(PrintName(
          canonical_name_map->LookupOrInsert(operand->name()), options));
    } else if (!options.compact_operands()) {
      pieces.push_back(PrintName(operand->name(), options));
    }
    StrAppend(out, StrJoin(pieces, " "));
  };

  string text = StrJoin(shown, ", ", format_operand);

  // Mention how many operands were left out, if any.
  const int64 hidden_count = operands_.size() - shown.size();
  if (hidden_count != 0) {
    StrAppend(&text, ", ...(+", hidden_count, ")");
  }
  return text;
}
std::vector<string> HloInstruction::ExtraAttributesToString(
const HloPrintOptions& options) const {
std::vector<string> extra = ExtraAttributesToStringImpl(options);
if (options.print_subcomputation_mode() ==
HloPrintOptions::PrintSubcomputationMode::kNameOnly) {
if (opcode() == HloOpcode::kWhile) {
extra.push_back(
StrCat("condition=", PrintName(while_condition()->name(), options)));
extra.push_back(
StrCat("body=", PrintName(while_body()->name(), options)));
} else if (opcode() == HloOpcode::kSelectAndScatter) {
extra.push_back(StrCat("select=", PrintName(select()->name(), options)));
extra.push_back(
StrCat("scatter=", PrintName(scatter()->name(), options)));
} else if (opcode() == HloOpcode::kConditional) {
extra.push_back(StrCat("true_computation=",
PrintName(true_computation()->name(), options)));
extra.push_back(StrCat("false_computation=",
PrintName(false_computation()->name(), options)));
} else if (opcode() == HloOpcode::kCall || opcode() == HloOpcode::kMap ||
opcode() == HloOpcode::kReduceWindow ||
opcode() == HloOpcode::kReduce ||
opcode() == HloOpcode::kCrossReplicaSum ||
opcode() == HloOpcode::kScatter) {
extra.push_back(
StrCat("to_apply=", PrintName(to_apply()->name(), options)));
} else if (!called_computations().empty()) {
extra.push_back(StrCat(
"calls=",
StrJoin(called_computations(), ", ",
[&](string* out, const HloComputation* computation) {
StrAppend(out, PrintName(computation->name(), options));
})));
}
} else if (options.print_subcomputation_mode() ==
HloPrintOptions::PrintSubcomputationMode::kFullBodies) {
HloPrintOptions new_options = options;
new_options.set_is_in_nested_computation(true);
switch (opcode()) {
case HloOpcode::kWhile:
extra.push_back(
StrCat("condition=\n", while_condition()->ToString(new_options)));
extra.push_back(StrCat("body=\n", while_body()->ToString(new_options)));
break;
case HloOpcode::kSelectAndScatter:
extra.push_back(StrCat("select=\n", select()->ToString(new_options)));
extra.push_back(StrCat("scatter=\n", scatter()->ToString(new_options)));
break;
case HloOpcode::kConditional:
extra.push_back(StrCat("true_computation=\n",
true_computation()->ToString(new_options)));
extra.push_back(StrCat("false_computation=\n",
false_computation()->ToString(new_options)));
break;
case HloOpcode::kCall:
case HloOpcode::kMap:
case HloOpcode::kReduceWindow:
case HloOpcode::kReduce:
case HloOpcode::kCrossReplicaSum:
case HloOpcode::kScatter:
extra.push_back(
StrCat("to_apply=\n", to_apply()->ToString(new_options)));
break;
default:
if (!called_computations().empty()) {
extra.push_back(StrCat(
"calls=\n",
StrJoin(called_computations(), ", ",
[&](string* out, const HloComputation* computation) {
StrAppend(out, computation->ToString(new_options));
})));
}
break;
}
}
if (has_sharding()) {
extra.push_back(StrCat("sharding=", sharding().ToString()));
}
if (options.print_control_dependencies() && !control_predecessors_.empty()) {
extra.push_back(StrCat("control-predecessors={",
StrJoin(control_predecessors_, ", ",
[&](string* out, HloInstruction* pre) {
StrAppend(out,
PrintName(pre->name(), options));
}),
"}"));
}
return extra;
}
// Prints "%<name> = <opcode>(%<operand name>, ...)" without shapes or
// attributes. Names are always prefixed with "%".
string HloInstruction::ToShortString() const {
  const auto append_operand_name = [](string* out, HloInstruction* operand) {
    StrAppend(out, "%", operand->name());
  };
  const string operand_names = StrJoin(operands_, ", ", append_operand_name);
  return StrCat("%", name(), " = ", HloOpcodeString(opcode()), "(",
                operand_names, ")");
}
// Serializes the fields held directly by HloInstruction into an
// HloInstructionProto: id, name, opcode, shape, operand ids, control
// predecessor ids, metadata, backend config, called computation ids (except
// for fusion instructions) and sharding. CHECK-fails if the instruction has
// no valid unique id.
HloInstructionProto HloInstruction::ToProto() const {
  CHECK(unique_id_ != -1)
      << "This instruction does not have a valid id. Please make sure the "
         "instruction is inside a module before dumping it.";

  HloInstructionProto result;
  result.set_id(unique_id_);
  result.set_name(name_);
  result.set_opcode(HloOpcodeString(opcode_));
  *result.mutable_shape() = shape_;

  for (const HloInstruction* operand_instruction : operands_) {
    result.add_operand_ids(operand_instruction->unique_id());
  }
  for (const HloInstruction* predecessor : control_predecessors_) {
    result.add_control_predecessor_ids(predecessor->unique_id());
  }

  *result.mutable_metadata() = metadata_;
  result.set_backend_config(backend_config_);

  // Called computation ids are not emitted here for fusion instructions.
  const bool is_fusion = opcode() == HloOpcode::kFusion;
  if (!is_fusion) {
    for (const HloComputation* callee : called_computations_) {
      result.add_called_computation_ids(callee->unique_id());
    }
  }

  if (has_sharding()) {
    *result.mutable_sharding() = sharding().ToProto();
  }
  return result;
}
// Returns a coarse category label for this instruction: "data formatting" for
// transpose/copy/reshape, "non-fusion elementwise" for other elementwise
// instructions, and the opcode string otherwise.
string HloInstruction::ToCategory() const {
  switch (opcode()) {
    case HloOpcode::kTranspose:
    case HloOpcode::kCopy:
    case HloOpcode::kReshape:
      return "data formatting";
    default:
      break;
  }
  if (IsElementwise()) {
    return "non-fusion elementwise";
  }
  return HloOpcodeString(opcode());
}
// Returns the instruction stored by set_tracing(), as held in
// trace_instruction_.
HloInstruction* HloInstruction::tracing() const {
  return trace_instruction_;
}
// Stores `trace_instruction` as the value subsequently returned by tracing().
void HloInstruction::set_tracing(HloInstruction* trace_instruction) {
  this->trace_instruction_ = trace_instruction;
}
// Returns true if this instruction's parent computation is a fusion
// computation. An instruction without a parent computation is reported as
// not fused.
bool HloInstruction::IsFused() const {
  // parent_ may be null (GetModule() guards against this as well), so check
  // it before dereferencing instead of crashing.
  return parent_ != nullptr && parent_->IsFusionComputation();
}
// Returns whether this instruction may be placed into a fusion: traced
// instructions, kDomain and kParameter are never fusible; any other
// instruction is fusible iff it has no side effect.
bool HloInstruction::IsFusible() const {
  // Instructions which are traced should not be fused.
  if (tracing() != nullptr) {
    return false;
  }
  // Some kinds of instructions don't make sense to fuse.
  const bool never_fusible_opcode =
      opcode_ == HloOpcode::kDomain || opcode_ == HloOpcode::kParameter;
  if (never_fusible_opcode) {
    return false;
  }
  // Side-effecting instructions cannot be fused.
  return !HasSideEffect();
}
// Constructs an instruction with the given opcode and shape. The instruction
// starts with unique id -1 and a name equal to the opcode string.
HloInstruction::HloInstruction(HloOpcode opcode, const Shape& shape)
    : unique_id_(-1),
      opcode_(opcode),
      shape_(shape),
      name_(HloOpcodeString(opcode)) {
  // Debug-only validation of the shape that was just stored.
  TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(shape_));
}
// Dispatches this instruction to the Handle* method of `visitor` that
// corresponds to opcode_ and returns the Status produced by that handler.
//
// The six comparison opcodes (kEq, kGe, kGt, kLe, kLt, kNe) share
// HandleCompare, and kRoundNearestAfz maps to HandleRound. kTrace has no
// handler: it breaks out of the switch and yields the InternalError at the
// bottom of the function.
template <typename HloInstructionPtr>
Status HloInstruction::Visit(DfsHloVisitorBase<HloInstructionPtr>* visitor) {
switch (opcode_) {
case HloOpcode::kAbs:
return visitor->HandleAbs(this);
case HloOpcode::kAtan2:
return visitor->HandleAtan2(this);
case HloOpcode::kRoundNearestAfz:
return visitor->HandleRound(this);
case HloOpcode::kBatchNormTraining:
return visitor->HandleBatchNormTraining(this);
case HloOpcode::kBatchNormInference:
return visitor->HandleBatchNormInference(this);
case HloOpcode::kBatchNormGrad:
return visitor->HandleBatchNormGrad(this);
case HloOpcode::kSign:
return visitor->HandleSign(this);
case HloOpcode::kConstant:
return visitor->HandleConstant(this);
case HloOpcode::kGetTupleElement:
return visitor->HandleGetTupleElement(this);
case HloOpcode::kParameter:
return visitor->HandleParameter(this);
// All comparison opcodes are routed to a single handler.
case HloOpcode::kEq:
case HloOpcode::kGe:
case HloOpcode::kGt:
case HloOpcode::kLe:
case HloOpcode::kLt:
case HloOpcode::kNe:
return visitor->HandleCompare(this);
case HloOpcode::kComplex:
return visitor->HandleComplex(this);
case HloOpcode::kAdd:
return visitor->HandleAdd(this);
case HloOpcode::kDivide:
return visitor->HandleDivide(this);
case HloOpcode::kSubtract:
return visitor->HandleSubtract(this);
case HloOpcode::kMaximum:
return visitor->HandleMaximum(this);
case HloOpcode::kMinimum:
return visitor->HandleMinimum(this);
case HloOpcode::kAnd:
return visitor->HandleAnd(this);
case HloOpcode::kOr:
return visitor->HandleOr(this);
case HloOpcode::kXor:
return visitor->HandleXor(this);
case HloOpcode::kShiftLeft:
return visitor->HandleShiftLeft(this);
case HloOpcode::kShiftRightArithmetic:
return visitor->HandleShiftRightArithmetic(this);
case HloOpcode::kShiftRightLogical:
return visitor->HandleShiftRightLogical(this);
case HloOpcode::kConcatenate:
return visitor->HandleConcatenate(this);
case HloOpcode::kConvert:
return visitor->HandleConvert(this);
case HloOpcode::kBitcastConvert:
return visitor->HandleBitcastConvert(this);
case HloOpcode::kCopy:
return visitor->HandleCopy(this);
case HloOpcode::kMultiply:
return visitor->HandleMultiply(this);
case HloOpcode::kDot:
return visitor->HandleDot(this);
case HloOpcode::kPower:
return visitor->HandlePower(this);
case HloOpcode::kRemainder:
return visitor->HandleRemainder(this);
case HloOpcode::kSelect:
return visitor->HandleSelect(this);
case HloOpcode::kTupleSelect:
return visitor->HandleTupleSelect(this);
case HloOpcode::kConvolution:
return visitor->HandleConvolution(this);
case HloOpcode::kFft:
return visitor->HandleFft(this);
case HloOpcode::kCrossReplicaSum:
return visitor->HandleCrossReplicaSum(this);
case HloOpcode::kAllToAll:
return visitor->HandleAllToAll(this);
case HloOpcode::kCollectivePermute:
return visitor->HandleCollectivePermute(this);
case HloOpcode::kTuple:
return visitor->HandleTuple(this);
case HloOpcode::kMap:
return visitor->HandleMap(this);
case HloOpcode::kClamp:
return visitor->HandleClamp(this);
case HloOpcode::kReduce:
return visitor->HandleReduce(this);
case HloOpcode::kReduceWindow:
return visitor->HandleReduceWindow(this);
case HloOpcode::kSelectAndScatter:
return visitor->HandleSelectAndScatter(this);
case HloOpcode::kNegate:
return visitor->HandleNegate(this);
case HloOpcode::kExp:
return visitor->HandleExp(this);
case HloOpcode::kExpm1:
return visitor->HandleExpm1(this);
case HloOpcode::kFloor:
return visitor->HandleFloor(this);
case HloOpcode::kCeil:
return visitor->HandleCeil(this);
case HloOpcode::kClz:
return visitor->HandleClz(this);
case HloOpcode::kLog:
return visitor->HandleLog(this);
case HloOpcode::kLog1p:
return visitor->HandleLog1p(this);
case HloOpcode::kTanh:
return visitor->HandleTanh(this);
case HloOpcode::kCos:
return visitor->HandleCos(this);
case HloOpcode::kSin:
return visitor->HandleSin(this);
case HloOpcode::kReal:
return visitor->HandleReal(this);
case HloOpcode::kImag:
return visitor->HandleImag(this);
case HloOpcode::kIsFinite:
return visitor->HandleIsFinite(this);
case HloOpcode::kNot:
return visitor->HandleNot(this);
case HloOpcode::kBitcast:
return visitor->HandleBitcast(this);
case HloOpcode::kBroadcast:
return visitor->HandleBroadcast(this);
case HloOpcode::kPad:
return visitor->HandlePad(this);
case HloOpcode::kReshape:
return visitor->HandleReshape(this);
case HloOpcode::kTranspose:
return visitor->HandleTranspose(this);
case HloOpcode::kReverse:
return visitor->HandleReverse(this);
case HloOpcode::kReducePrecision:
return visitor->HandleReducePrecision(this);
case HloOpcode::kSlice:
return visitor->HandleSlice(this);
case HloOpcode::kDynamicSlice:
return visitor->HandleDynamicSlice(this);
case HloOpcode::kDynamicUpdateSlice:
return visitor->HandleDynamicUpdateSlice(this);
case HloOpcode::kSort:
return visitor->HandleSort(this);
case HloOpcode::kInfeed:
return visitor->HandleInfeed(this);
case HloOpcode::kOutfeed:
return visitor->HandleOutfeed(this);
case HloOpcode::kRng:
return visitor->HandleRng(this);
case HloOpcode::kWhile:
return visitor->HandleWhile(this);
case HloOpcode::kFusion:
return visitor->HandleFusion(this);
case HloOpcode::kCall:
return visitor->HandleCall(this);
case HloOpcode::kConditional:
return visitor->HandleConditional(this);
case HloOpcode::kCustomCall:
return visitor->HandleCustomCall(this);
case HloOpcode::kRecv:
return visitor->HandleRecv(this);
case HloOpcode::kRecvDone:
return visitor->HandleRecvDone(this);
case HloOpcode::kSend:
return visitor->HandleSend(this);
case HloOpcode::kSendDone:
return visitor->HandleSendDone(this);
case HloOpcode::kGather:
return visitor->HandleGather(this);
case HloOpcode::kScatter:
return visitor->HandleScatter(this);
case HloOpcode::kDomain:
return visitor->HandleDomain(this);
case HloOpcode::kAfterAll:
return visitor->HandleAfterAll(this);
case HloOpcode::kIota:
return visitor->HandleIota(this);
// These opcodes are not handled here.
case HloOpcode::kTrace:
break;
}
// Reached for kTrace, which has no handler above.
return InternalError(
"Unhandled HloOpcode for DfsHloVisitor: %s. This should not happen - "
"please file a bug for XLA.",
HloOpcodeString(opcode_));
}
// Explicit instantiations for the mutable and const visitor types.
template Status HloInstruction::Visit(DfsHloVisitor* visitor);
template Status HloInstruction::Visit(ConstDfsHloVisitor* visitor);
using DFSStack = absl::InlinedVector<std::pair<int, HloInstruction*>, 16>;
// Push "child" onto the dfs_stack if not already visited. Returns false if a
// cycle was detected, and true otherwise.
template <typename Visitor>
inline bool PushDFSChild(Visitor* visitor, DFSStack* dfs_stack,
HloInstruction* child) {
CHECK(child != nullptr);
const int id = child->unique_id();
CHECK_GE(id, 0) << "instruction may not have a parent computation";
switch (visitor->GetVisitState(id)) {
case Visitor::kVisiting:
return false;
case Visitor::kVisited:
// Nothing to do
return true;
case Visitor::kNotVisited:
dfs_stack->push_back(std::make_pair(id, child));
return true;
}
}
// Comparison callback over DFS stack entries (<unique id, instruction>
// pairs); used to order the children pushed for a node.
using InternalCompareFunction =
std::function<bool(std::pair<int, const HloInstruction*>,
std::pair<int, const HloInstruction*>)>;
// Iterative post-order DFS starting at `root`.
//
// Each node is encountered on the stack in up to two phases:
//  * kNotVisited: the node is marked kVisiting, stays on the stack, and its
//    operands (and, unless `ignore_control_predecessors` is true, its control
//    predecessors) are pushed on top of it.
//  * kVisiting: all children have been processed; the node is popped and
//    Preprocess / Visit / Postprocess are invoked on it, with the state set to
//    kVisited after Visit succeeds.
// Entries found in kVisited state are simply popped.
//
// If `operand_order` is non-null, the children pushed for a node are sorted
// with it before being reversed. Returns FailedPrecondition if a child is
// found in kVisiting state (a cycle), or the first error returned by the
// visitor.
template <typename Visitor>
static Status PostOrderDFS(HloInstruction* root, Visitor* visitor,
const InternalCompareFunction* operand_order,
bool ignore_control_predecessors) {
visitor->ReserveVisitStates(root->GetModule()->NumUniqueInstructionIds());
// dfs_stack holds pairs of <HloInstruction*->unique_id(), HloInstruction*>.
//
// We need to keep track of both the id and the instruction because
// instructions can get deleted while they are on the stack, so we
// can't always use the (potentially dead) instruction object to grab
// its id.
DFSStack dfs_stack;
dfs_stack.emplace_back(root->unique_id(), root);
do {
DCHECK(!dfs_stack.empty());
int current_id = dfs_stack.back().first;
HloInstruction* current_node = dfs_stack.back().second;
CHECK_GE(current_id, 0) << current_id << ": " << current_node
<< ": instruction may not have parent computation";
typename Visitor::VisitState visit_state =
visitor->GetVisitState(current_id);
// Already fully processed: drop the stack entry.
if (visit_state == Visitor::kVisited) {
dfs_stack.pop_back();
VLOG(3) << "Not visiting HLO %" << current_node->name()
<< " as it was already visited.";
continue;
}
// Second encounter: children were pushed earlier, so visit the node now.
if (visit_state == Visitor::kVisiting) {
dfs_stack.pop_back();
TF_RETURN_IF_ERROR(visitor->Preprocess(current_node));
VLOG(2) << "Visiting HLO %" << current_node->name();
TF_RETURN_IF_ERROR(current_node->Visit(visitor));
visitor->SetVisitState(current_id, Visitor::kVisited);
TF_RETURN_IF_ERROR(visitor->Postprocess(current_node));
continue;
}
// First encounter: mark as in progress and push the children above it. The
// node itself stays on the stack until it is seen again in kVisiting state.
visitor->SetVisitState(current_id, Visitor::kVisiting);
// Remember where this node's children start so that only they are
// sorted/reversed below.
const size_t old_dfs_stack_size = dfs_stack.size();
for (HloInstruction* child : current_node->operands()) {
if (!TF_PREDICT_TRUE(PushDFSChild(visitor, &dfs_stack, child))) {
return FailedPrecondition(
"A cycle is detected while visiting instruction %s",
current_node->ToString());
}
}
if (!ignore_control_predecessors) {
for (HloInstruction* child : current_node->control_predecessors()) {
if (!TF_PREDICT_TRUE(PushDFSChild(visitor, &dfs_stack, child))) {
return FailedPrecondition(
"A cycle is detected while visiting instruction %s",
current_node->ToString());
}
}
}
if (operand_order != nullptr) {
std::sort(dfs_stack.begin() + old_dfs_stack_size, dfs_stack.end(),
*operand_order);
}
// This makes the traversal order the same as what you'd expect
// out of a recursive algorithm.
std::reverse(dfs_stack.begin() + old_dfs_stack_size, dfs_stack.end());
} while (!dfs_stack.empty());
return Status::OK();
}
// Runs a post-order DFS from this instruction, invoking `visitor` on every
// node reached. If `call_finish_visit` is true, visitor->FinishVisit(this) is
// called after a successful traversal. Control predecessors are traversed
// unless `ignore_control_predecessors` is true.
template <typename HloInstructionPtr>
Status HloInstruction::Accept(DfsHloVisitorBase<HloInstructionPtr>* visitor,
                              bool call_finish_visit,
                              bool ignore_control_predecessors) {
  VLOG(3) << "HloInstruction::Accept(%" << name() << ")";
  TF_RETURN_IF_ERROR(PostOrderDFS(this, visitor, /*operand_order=*/nullptr,
                                  ignore_control_predecessors));
  if (!call_finish_visit) {
    return Status::OK();
  }
  TF_RETURN_IF_ERROR(visitor->FinishVisit(this));
  return Status::OK();
}

// Explicit instantiations for the mutable and const visitor types.
template Status HloInstruction::Accept(DfsHloVisitor*, bool, bool);
template Status HloInstruction::Accept(ConstDfsHloVisitor*, bool, bool);
// Like Accept(), but the children pushed for each node are ordered with
// `operand_order`. Control predecessors are always traversed. If
// `call_finish_visit` is true, visitor->FinishVisit(this) is called after a
// successful traversal.
Status HloInstruction::AcceptWithOperandOrder(
    DfsHloVisitor* visitor, const CompareFunction& operand_order,
    bool call_finish_visit) {
  VLOG(2) << "HloInstruction::AcceptWithOperandOrder(%" << name() << ")";

  // Adapts the client's comparator, which works on HloInstruction pointers,
  // to the <id, instruction> entries kept on the DFS stack. The ids are
  // ignored for ordering purposes.
  InternalCompareFunction stack_entry_order =
      [&operand_order](std::pair<int, const HloInstruction*> lhs,
                       std::pair<int, const HloInstruction*> rhs) {
        return operand_order(lhs.second, rhs.second);
      };

  TF_RETURN_IF_ERROR(PostOrderDFS(this, visitor, &stack_entry_order,
                                  /*ignore_control_predecessors=*/false));
  if (call_finish_visit) {
    VLOG(3) << "HloInstruction::AcceptWithOperandOrder BEFORE FINISH VISIT";
    TF_RETURN_IF_ERROR(visitor->FinishVisit(this));
    VLOG(3) << "HloInstruction::AcceptWithOperandOrder AFTER FINISH VISIT";
  }
  VLOG(2) << "HloInstruction::AcceptWithOperandOrder EXIT";
  return Status::OK();
}
namespace {

// Returns true if `order` is a topological sort of the instructions it
// contains: no instruction appears twice, and every operand of every listed
// instruction is itself listed at an earlier position.
bool OrderIsTopologicalSort(const std::vector<const HloInstruction*>& order) {
  // Record the position of each instruction, rejecting duplicates.
  std::unordered_map<const HloInstruction*, int> position_of;
  int next_position = 0;
  for (const HloInstruction* instruction : order) {
    const bool newly_inserted =
        position_of.insert({instruction, next_position}).second;
    if (!newly_inserted) {
      return false;
    }
    ++next_position;
  }

  // Defs must come before uses: each operand has to be present in the order
  // and placed strictly before its user.
  for (const HloInstruction* user : order) {
    const int user_position = position_of.at(user);
    for (const HloInstruction* operand : user->operands()) {
      const auto operand_entry = position_of.find(operand);
      if (operand_entry == position_of.end()) {
        return false;
      }
      if (operand_entry->second >= user_position) {
        return false;
      }
    }
  }
  return true;
}

}  // namespace
// Convenience overload: wraps `visitor_func` in a FunctionVisitor and runs
// Accept() with it.
Status HloInstruction::Accept(
    const std::function<Status(HloInstruction*)>& visitor_func) {
  FunctionVisitor function_visitor(visitor_func);
  return this->Accept(&function_visitor);
}
// Const convenience overload: wraps `visitor_func` in a ConstFunctionVisitor
// and runs Accept() with it.
Status HloInstruction::Accept(
    const std::function<Status(const HloInstruction*)>& visitor_func) const {
  ConstFunctionVisitor function_visitor(visitor_func);
  return this->Accept(&function_visitor);
}
// Visits the instructions reachable from this one in the sequence given by
// `order`, which must be a topological sort. Instructions in `order` that are
// not reachable from this instruction, or that the visitor already reports as
// visited, are skipped. Ends with visitor->FinishVisit(this).
Status HloInstruction::AcceptOrdered(
    DfsHloVisitor* visitor, const std::vector<const HloInstruction*>& order) {
  VLOG(2) << "HloInstruction::AcceptOrdered(%" << name() << ")";
  TF_RET_CHECK(OrderIsTopologicalSort(order));

  // Gather every instruction a regular traversal from `this` reaches.
  std::unordered_set<const HloInstruction*> reachable;
  TF_RETURN_IF_ERROR(this->Accept([&reachable](HloInstruction* instruction) {
    reachable.insert(instruction);
    return Status::OK();
  }));

  for (const HloInstruction* candidate : order) {
    // Skip instructions that are not predecessors of 'this'.
    if (reachable.count(candidate) == 0) {
      continue;
    }

    // The visitor can mark instructions as visited to skip particular
    // instructions.
    if (visitor->DidVisit(*candidate)) {
      VLOG(3) << "Not visiting HLO %" << candidate->name()
              << " as it was already visited.";
      continue;
    }

    // TODO(b/78350259): Eliminate const laundering.
    HloInstruction* mutable_instruction =
        const_cast<HloInstruction*>(candidate);

    TF_RETURN_IF_ERROR(visitor->Preprocess(mutable_instruction));
    VLOG(2) << "Visiting HLO %" << mutable_instruction->name();
    TF_RETURN_IF_ERROR(mutable_instruction->Visit(visitor));
    visitor->SetVisited(*mutable_instruction);
    TF_RETURN_IF_ERROR(visitor->Postprocess(mutable_instruction));
  }

  return visitor->FinishVisit(this);
}
// Returns the stored shape of this instruction. In debug builds the shape is
// validated on every access.
const Shape& HloInstruction::shape() const {
  TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(shape_));
  return this->shape_;
}
std::vector<int64> HloInstruction::OperandIndices(
const HloInstruction* operand) const {
std::vector<int64> result;
for (int64 i = 0; i < operand_count(); ++i) {
if (this->operand(i) == operand) {
result.push_back(i);
}
}
return result;
}
// Returns true if this instruction is elementwise and has exactly two
// operands.
bool HloInstruction::IsElementwiseBinary() const {
  if (!IsElementwise()) {
    return false;
  }
  return operand_count() == 2;
}
bool HloInstruction::IsElementwise() const {
return IsElementwiseImpl(absl::nullopt);
}
// Returns true if the dimensions of operand `operand_idx` differ from the
// dimensions of this instruction's shape. CHECK-fails if this instruction is
// not elementwise.
bool HloInstruction::ImplicitlyBroadcastsOperand(int64 operand_idx) const {
  CHECK(IsElementwise());
  const bool same_dimensions =
      ShapeUtil::SameDimensions(shape(), operand(operand_idx)->shape());
  return !same_dimensions;
}
// Returns whether this instruction is elementwise with respect to the operand
// at `operand_idx`.
bool HloInstruction::IsElementwiseOnOperand(int64 operand_idx) const {
  const absl::optional<int64> queried_operand(operand_idx);
  return IsElementwiseImpl(queried_operand);
}
// Helper for the HloOpcode::kFusion case of HloInstruction::OperandElementUse
// below: a memoized, recursive walk over the fused expression that determines
// how the elements of fusion parameter `i` are used.
class HloInstruction::FusionReusesParamElements {
 public:
  using UseKind = HloInstruction::UseKind;

  // We could rather iterate backwards through fused_instructions_ here, as it
  // is in reverse postorder, and compute whether each fused instruction reuses
  // the value of this parameter, which would save stack space but not allow us
  // to finish early if we find a reuse.
  static UseKind Compute(int64 i, const HloInstruction& hlo) {
    tensorflow::gtl::FlatMap<const HloInstruction*, UseKind> memo;
    return ComputeInternal(i, hlo, &memo);
  }

 private:
  // Returns how `hlo` uses the elements of parameter number `i`, combining
  // the per-operand results with Plus(). Results are memoized in `cache`.
  static UseKind ComputeInternal(
      int64 i, const HloInstruction& hlo,
      tensorflow::gtl::FlatMap<const HloInstruction*, UseKind>* cache) {
    // The queried parameter itself is a plain use.
    if (auto param = DynCast<HloParameterInstruction>(&hlo)) {
      if (param->parameter_number() == i) {
        return UseKind::kUse;
      }
    }

    auto insertion = cache->emplace(&hlo, UseKind{});
    if (!insertion.second) {
      // Already memoized (or currently being computed further up the stack).
      return insertion.first->second;
    }

    UseKind accumulated = insertion.first->second;
    const int64 num_operands = hlo.operands_.size();
    for (int64 j = 0; j < num_operands; ++j) {
      // The recursive call may insert into the cache and thereby invalidate
      // iterators, so no iterator is held across it.
      accumulated =
          Plus(accumulated, std::min(hlo.OperandElementUse(j),
                                     ComputeInternal(i, *hlo.operand(j),
                                                     cache)));
      // Look the entry up again and publish the running value after every
      // operand. This code is not hot enough to warrant avoiding the lookup.
      cache->find(&hlo)->second = accumulated;
    }
    return accumulated;
  }

  // Fold operation for UseKinds.
  static UseKind Plus(UseKind a, UseKind b) {
    if (a == UseKind::kNoUse) {
      return b;
    }
    if (b == UseKind::kNoUse) {
      return a;
    }
    const bool any_reuse = a == UseKind::kReuse || b == UseKind::kReuse;
    const bool any_permuting = a == UseKind::kUsePermutingElements ||
                               b == UseKind::kUsePermutingElements;
    if (any_reuse || any_permuting) {
      return UseKind::kReuse;
    }
    CHECK(a == UseKind::kUse && b == UseKind::kUse);
    return UseKind::kUse;
  }
};
// Returns how this instruction uses the elements of its operand number `i`
// (plain use, permuting use, or reuse), decided per opcode.
HloInstruction::UseKind HloInstruction::OperandElementUse(int64 i) const {
  switch (opcode_) {
    case HloOpcode::kBitcast:
    case HloOpcode::kConcatenate:
    case HloOpcode::kReshape:
    case HloOpcode::kReverse:
    case HloOpcode::kSlice:
    case HloOpcode::kTranspose:
      return UseKind::kUsePermutingElements;

    case HloOpcode::kPad:
      // Pad reuses the padding value but not the padded array elements.
      if (i > 0) {
        return UseKind::kReuse;
      }
      return UseKind::kUsePermutingElements;

    case HloOpcode::kReduce:
      // Reduce reuses the init values but not the operand array elements.
      if (i >= Cast<HloReduceInstruction>(this)->input_count()) {
        return UseKind::kReuse;
      }
      return UseKind::kUsePermutingElements;

    case HloOpcode::kFusion:
      // Uses the memoizing, recursive computation defined above.
      return FusionReusesParamElements::Compute(i, *fused_expression_root());

    case HloOpcode::kDot:
      // Dot operations with inputs [A,B] * [B,1] do not re-use
      // elements on their left operand.
      // Dot operations with inputs [1,A] * [A,B] do not re-use
      // elements on their right operand.
      if (shape().dimensions_size() == 2) {
        const bool lhs_not_reused = i == 0 && shape().dimensions(1) == 1;
        const bool rhs_not_reused = i == 1 && shape().dimensions(0) == 1;
        if (lhs_not_reused || rhs_not_reused) {
          return UseKind::kUse;
        }
      }
      return UseKind::kReuse;

    case HloOpcode::kDynamicUpdateSlice:
      // Dynamic-update-slice reuses only operand 2 (start_indices).
      return (i == 0 || i == 1) ? UseKind::kUse : UseKind::kReuse;

    default:
      // Elementwise instructions use each element once unless the operand is
      // implicitly broadcast; everything else is treated as reuse.
      if (!IsElementwise()) {
        return UseKind::kReuse;
      }
      return ImplicitlyBroadcastsOperand(i) ? UseKind::kReuse : UseKind::kUse;
  }
}
// For a kReshape instruction, forwards to
// ShapeUtil::InsertedOrDeleted1SizedDimensions on the operand shape and the
// result shape. For any other opcode returns (false, {}, {}).
std::tuple<bool, std::vector<int64>, std::vector<int64>>
HloInstruction::ReshapeMerelyInsertsOrDeletes1SizedDimensions() const {
  if (opcode_ == HloOpcode::kReshape) {
    return ShapeUtil::InsertedOrDeleted1SizedDimensions(operand(0)->shape_,
                                                        shape_);
  }
  return std::make_tuple(false, std::vector<int64>(), std::vector<int64>());
}
// Returns the textual name of a fusion kind ("kLoop", "kInput", "kOutput" or
// "kCustom"). These are the same strings StringToFusionKind() accepts.
string ToString(HloInstruction::FusionKind kind) {
  switch (kind) {
    case HloInstruction::FusionKind::kLoop:
      return "kLoop";
    case HloInstruction::FusionKind::kInput:
      return "kInput";
    case HloInstruction::FusionKind::kOutput:
      return "kOutput";
    case HloInstruction::FusionKind::kCustom:
      return "kCustom";
  }
  // All enumerators return above. Fail loudly rather than flowing off the end
  // of a value-returning function (undefined behavior) if `kind` holds an
  // out-of-range value. The value is printed as an integer because streaming
  // a FusionKind would call back into this function.
  LOG(FATAL) << "Unknown fusion kind: " << static_cast<int>(kind);
}
// Parses the textual name of a fusion kind ("kLoop", "kInput", "kOutput",
// "kCustom"). The match is exact and case-sensitive; any other string yields
// an InvalidArgument error.
StatusOr<HloInstruction::FusionKind> StringToFusionKind(
    const string& kind_name) {
  const struct {
    const char* name;
    HloInstruction::FusionKind kind;
  } known_kinds[] = {
      {"kLoop", HloInstruction::FusionKind::kLoop},
      {"kInput", HloInstruction::FusionKind::kInput},
      {"kOutput", HloInstruction::FusionKind::kOutput},
      {"kCustom", HloInstruction::FusionKind::kCustom},
  };
  for (const auto& entry : known_kinds) {
    if (kind_name == entry.name) {
      return entry.kind;
    }
  }
  return InvalidArgument("Unknown fusion kind: %s", kind_name);
}
// Renders a padding config as "<low>_<high>[_<interior>]" per dimension,
// joined with "x". The interior component is printed for every dimension if
// any dimension has non-zero interior padding, and for none otherwise.
string PaddingConfigToString(const PaddingConfig& padding) {
  bool any_interior_padding = false;
  for (const PaddingConfig::PaddingConfigDimension& dim :
       padding.dimensions()) {
    if (dim.interior_padding() != 0) {
      any_interior_padding = true;
      break;
    }
  }

  const auto format_dimension =
      [any_interior_padding](string* out,
                             const PaddingConfig::PaddingConfigDimension& dim) {
        StrAppend(out, dim.edge_padding_low(), "_", dim.edge_padding_high());
        if (any_interior_padding) {
          StrAppend(out, "_", dim.interior_padding());
        }
      };
  return StrJoin(padding.dimensions(), "x", format_dimension);
}
// Renders the set fields of `metadata` as space-separated key=value pairs.
// String fields (op_type, op_name, source_file) are C-escaped and quoted and
// are omitted when empty; source_line is omitted when 0.
string OpMetadataToString(const OpMetadata& metadata) {
  std::vector<string> fields;

  const bool has_op_type = !metadata.op_type().empty();
  if (has_op_type) {
    fields.push_back(StrCat("op_type=\"", CEscape(metadata.op_type()), "\""));
  }

  const bool has_op_name = !metadata.op_name().empty();
  if (has_op_name) {
    fields.push_back(StrCat("op_name=\"", CEscape(metadata.op_name()), "\""));
  }

  const bool has_source_file = !metadata.source_file().empty();
  if (has_source_file) {
    fields.push_back(
        StrCat("source_file=\"", CEscape(metadata.source_file()), "\""));
  }

  const bool has_source_line = metadata.source_line() != 0;
  if (has_source_line) {
    fields.push_back(StrCat("source_line=", metadata.source_line()));
  }

  return StrJoin(fields, " ");
}
// Returns the lower-cased proto enum name of `distribution`.
string RandomDistributionToString(const RandomDistribution& distribution) {
  const string enum_name = RandomDistribution_Name(distribution);
  return absl::AsciiStrToLower(enum_name);
}
// Returns the lower-cased proto enum name of `precision`.
string PrecisionToString(const PrecisionConfig::Precision& precision) {
  const string enum_name = PrecisionConfig::Precision_Name(precision);
  return absl::AsciiStrToLower(enum_name);
}
// Renders convolution dimension numbers as "<lhs>_<rhs>-><output>", where
// each part lists one symbol per logical dimension: "b"/"f" (batch/feature)
// for the input and output, "i"/"o" (input/output feature) for the kernel,
// and the spatial dimension's index for spatial dimensions.
string ConvolutionDimensionNumbersToString(
    const ConvolutionDimensionNumbers& dnums) {
  // lhs_symbols[i] is the symbol of the logical dimension i for the lhs
  // operand. E.g. if batch has dimension number 2, then lhs_symbols[2] == "b".
  const int64 num_input_spatial = dnums.input_spatial_dimensions().size();
  std::vector<string> lhs_symbols(2 + num_input_spatial);
  lhs_symbols[dnums.input_batch_dimension()] = "b";
  lhs_symbols[dnums.input_feature_dimension()] = "f";
  for (int64 s = 0; s < num_input_spatial; ++s) {
    lhs_symbols[dnums.input_spatial_dimensions(s)] = StrCat(s);
  }

  const int64 num_kernel_spatial = dnums.kernel_spatial_dimensions().size();
  std::vector<string> rhs_symbols(2 + num_kernel_spatial);
  rhs_symbols[dnums.kernel_input_feature_dimension()] = "i";
  rhs_symbols[dnums.kernel_output_feature_dimension()] = "o";
  for (int64 s = 0; s < num_kernel_spatial; ++s) {
    rhs_symbols[dnums.kernel_spatial_dimensions(s)] = StrCat(s);
  }

  const int64 num_output_spatial = dnums.output_spatial_dimensions().size();
  std::vector<string> output_symbols(2 + num_output_spatial);
  output_symbols[dnums.output_batch_dimension()] = "b";
  output_symbols[dnums.output_feature_dimension()] = "f";
  for (int64 s = 0; s < num_output_spatial; ++s) {
    output_symbols[dnums.output_spatial_dimensions(s)] = StrCat(s);
  }

  const string lhs_text = StrJoin(lhs_symbols, "");
  const string rhs_text = StrJoin(rhs_symbols, "");
  const string output_text = StrJoin(output_symbols, "");
  return StrCat(lhs_text, "_", rhs_text, "->", output_text);
}
// Parses a random distribution name (case-insensitively) into its enum value.
// The accepted names are those produced by RandomDistributionToString() for
// every valid enum value. Unknown names yield an InvalidArgument error.
StatusOr<RandomDistribution> StringToRandomDistribution(const string& name) {
  // Built once on first use and intentionally never freed.
  static std::unordered_map<string, RandomDistribution>* name_to_value = [] {
    auto* table = new std::unordered_map<string, RandomDistribution>;
    for (int i = 0; i < RandomDistribution_ARRAYSIZE; i++) {
      if (!RandomDistribution_IsValid(i)) {
        continue;
      }
      const auto value = static_cast<RandomDistribution>(i);
      (*table)[RandomDistributionToString(value)] = value;
    }
    return table;
  }();

  const auto entry = name_to_value->find(absl::AsciiStrToLower(name));
  if (entry != name_to_value->end()) {
    return entry->second;
  }
  return InvalidArgument("Unknown distribution");
}
// Parses a precision name (case-insensitively) into its
// PrecisionConfig::Precision value. The accepted names are those produced by
// PrecisionToString() for every valid enum value. Unknown names yield an
// InvalidArgument error that mentions the offending name.
StatusOr<PrecisionConfig::Precision> StringToPrecision(const string& name) {
  // Built once on first use and intentionally never freed.
  static std::unordered_map<string, PrecisionConfig::Precision>* map = [] {
    auto* map = new std::unordered_map<string, PrecisionConfig::Precision>;
    for (int i = 0; i < PrecisionConfig::Precision_ARRAYSIZE; i++) {
      if (PrecisionConfig::Precision_IsValid(i)) {
        auto value = static_cast<PrecisionConfig::Precision>(i);
        (*map)[PrecisionToString(value)] = value;
      }
    }
    return map;
  }();
  auto found = map->find(absl::AsciiStrToLower(name));
  if (found == map->end()) {
    // This used to report "Unknown distribution", copied from
    // StringToRandomDistribution(); report what actually failed to parse.
    return InvalidArgument("Unknown precision: %s", name);
  }
  return found->second;
}
// Streams the textual name of a fusion kind (see ToString(FusionKind)).
std::ostream& operator<<(std::ostream& os, HloInstruction::FusionKind kind) {
  os << ToString(kind);
  return os;
}
// Strict ordering over instruction pointers.
//
// A null pointer orders before every non-null pointer and is never less than
// another null. Two non-null instructions are ordered first by their module's
// unique_id() (when they belong to modules) and then by their own
// unique_id(). CHECK-fails if exactly one of the two has a module.
bool HloPtrComparator::operator()(const HloInstruction* const& lhs,
                                  const HloInstruction* const& rhs) const {
  // Nothing compares less than nullptr.
  if (rhs == nullptr) return false;
  if (lhs == nullptr) return true;

  const auto lhs_module = lhs->GetModule();
  const auto rhs_module = rhs->GetModule();
  CHECK((lhs_module == nullptr && rhs_module == nullptr) ||
        (lhs_module != nullptr && rhs_module != nullptr));

  if (lhs_module != nullptr) {
    const auto lhs_module_id = lhs_module->unique_id();
    const auto rhs_module_id = rhs_module->unique_id();
    if (lhs_module_id != rhs_module_id) {
      return lhs_module_id < rhs_module_id;
    }
  }
  // Same module (or no module at all): fall back to the instruction ids.
  return lhs->unique_id() < rhs->unique_id();
}
// Returns true for every transpose, and for a reshape when the first element
// of ReshapeMerelyInsertsOrDeletes1SizedDimensions() is true. All other
// opcodes return false.
bool HloInstruction::CouldBeBitcast() const {
  if (opcode_ == HloOpcode::kTranspose) {
    return true;
  }
  if (opcode_ == HloOpcode::kReshape) {
    return std::get<0>(ReshapeMerelyInsertsOrDeletes1SizedDimensions());
  }
  return false;
}
// Fills `proto` from the stored backend config string.
//
// `proto` is always cleared first. An empty config string yields the empty
// proto and OK; otherwise the string is parsed as human-readable JSON and the
// parse status is returned.
Status HloInstruction::GetBackendConfigInternal(
    tensorflow::protobuf::Message* proto) const {
  proto->Clear();
  // Empty string does not parse as valid JSON, but it's a valid backend config,
  // corresponding to the empty proto.
  return backend_config_.empty()
             ? Status::OK()
             : tensorflow::HumanReadableJsonToProto(backend_config_, proto);
}
// Serializes `proto` with BackendConfigToRawString() and stores the result in
// backend_config_. If serialization fails, TF_ASSIGN_OR_RETURN returns that
// error from this function; otherwise OK is returned.
Status HloInstruction::set_backend_config(
const tensorflow::protobuf::Message& proto) {
TF_ASSIGN_OR_RETURN(backend_config_, BackendConfigToRawString(proto));
return Status::OK();
}
// Converts `proto` to its human-readable JSON string, or returns the
// conversion error.
/* static */ StatusOr<string> HloInstruction::BackendConfigToRawString(
    const tensorflow::protobuf::Message& proto) {
  string json;
  TF_RETURN_IF_ERROR(tensorflow::ProtoToHumanReadableJson(proto, &json));
  return json;
}
// Returns the precision config of a convolution or dot instruction.
// Any other instruction kind hits LOG(FATAL).
const PrecisionConfig& HloInstruction::precision_config() const {
  const auto* as_convolution = DynCast<HloConvolutionInstruction>(this);
  if (as_convolution != nullptr) {
    return as_convolution->precision_config();
  }
  const auto* as_dot = DynCast<HloDotInstruction>(this);
  if (as_dot != nullptr) {
    return as_dot->precision_config();
  }
  LOG(FATAL) << "Unimplemented method.";
}
// Returns the parent of this instruction's parent_ (the HloModule), or
// nullptr when parent_ is not set.
HloModule* HloInstruction::GetModule() const {
  return parent_ != nullptr ? parent_->parent() : nullptr;
}
// Replaces this instruction's name with the unique name that `name_uniquer`
// derives from the current name.
void HloInstruction::UniquifyName(NameUniquer* name_uniquer) {
  // The previous body also built a `parent_str` local that was never read;
  // it has been dropped.
  name_ = name_uniquer->GetUniqueName(name_);
}
// Stores a copy of `outer_dimension_partitions` on this instruction,
// replacing any previously stored value.
void HloInstruction::set_outer_dimension_partitions(
const std::vector<int64>& outer_dimension_partitions) {
outer_dimension_partitions_ = outer_dimension_partitions;
}
// TODO(b/80131774): Remove these temporary methods after transition.
// Forwards to HloBatchNormInstruction::feature_index() after casting this
// instruction to the batch-norm subclass.
int64 HloInstruction::feature_index() const {
  const auto* batch_norm = Cast<HloBatchNormInstruction>(this);
  return batch_norm->feature_index();
}
// Forwards to HloBatchNormInstruction::epsilon() after casting this
// instruction to the batch-norm subclass.
float HloInstruction::epsilon() const {
  const auto* batch_norm = Cast<HloBatchNormInstruction>(this);
  return batch_norm->epsilon();
}
// Forwards to HloFftInstruction::fft_type() after casting this instruction
// to the FFT subclass.
FftType HloInstruction::fft_type() const {
  const auto* fft = Cast<HloFftInstruction>(this);
  return fft->fft_type();
}
// Forwards to HloFftInstruction::fft_length() after casting this instruction
// to the FFT subclass.
const std::vector<int64>& HloInstruction::fft_length() const {
  const auto* fft = Cast<HloFftInstruction>(this);
  return fft->fft_length();
}
// Forwards to HloSendRecvInstruction::channel_id() after casting this
// instruction to the send/recv subclass.
int64 HloInstruction::channel_id() const {
  const auto* send_recv = Cast<HloSendRecvInstruction>(this);
  return send_recv->channel_id();
}
// Forwards to HloConcatenateInstruction::concatenate_dimension() after
// casting this instruction to the concatenate subclass.
int64 HloInstruction::concatenate_dimension() const {
  const auto* concatenate = Cast<HloConcatenateInstruction>(this);
  return concatenate->concatenate_dimension();
}
// Returns false when this is not an HloTransposeInstruction; otherwise
// returns that instruction's IsRank2Transpose().
bool HloInstruction::IsRank2Transpose() const {
  if (const auto* transpose = DynCast<HloTransposeInstruction>(this)) {
    return transpose->IsRank2Transpose();
  }
  return false;
}
// Forwards to HloSliceInstruction::slice_starts(dimension) after casting
// this instruction to the slice subclass.
int64 HloInstruction::slice_starts(int64 dimension) const {
  const auto* slice = Cast<HloSliceInstruction>(this);
  return slice->slice_starts(dimension);
}
// Forwards to HloSliceInstruction::slice_starts() after casting this
// instruction to the slice subclass.
const std::vector<int64>& HloInstruction::slice_starts() const {
  const auto* slice = Cast<HloSliceInstruction>(this);
  return slice->slice_starts();
}
// Forwards to HloSliceInstruction::slice_limits(dimension) after casting
// this instruction to the slice subclass.
int64 HloInstruction::slice_limits(int64 dimension) const {
  const auto* slice = Cast<HloSliceInstruction>(this);
  return slice->slice_limits(dimension);
}
// Forwards to HloSliceInstruction::slice_limits() after casting this
// instruction to the slice subclass.
const std::vector<int64>& HloInstruction::slice_limits() const {
  const auto* slice = Cast<HloSliceInstruction>(this);
  return slice->slice_limits();
}
// Forwards to HloSliceInstruction::slice_strides(dimension) after casting
// this instruction to the slice subclass.
int64 HloInstruction::slice_strides(int64 dimension) const {
  const auto* slice = Cast<HloSliceInstruction>(this);
  return slice->slice_strides(dimension);
}
// Forwards to HloSliceInstruction::slice_strides() after casting this
// instruction to the slice subclass.
const std::vector<int64>& HloInstruction::slice_strides() const {
  const auto* slice = Cast<HloSliceInstruction>(this);
  return slice->slice_strides();
}
// Forwards to HloSliceInstruction::IsInPlaceSlice() after casting this
// instruction to the slice subclass.
bool HloInstruction::IsInPlaceSlice() const {
  const auto* slice = Cast<HloSliceInstruction>(this);
  return slice->IsInPlaceSlice();
}
// Forwards to HloConstantInstruction::literal() after casting this
// instruction to the constant subclass.
const Literal& HloInstruction::literal() const {
  const auto* constant = Cast<HloConstantInstruction>(this);
  return constant->literal();
}
// Returns true iff this instruction is an HloConstantInstruction, as
// determined by DynCast.
bool HloInstruction::IsConstant() const {
  const auto* constant = DynCast<HloConstantInstruction>(this);
  return constant != nullptr;
}
// Forwards to HloConstantInstruction::RelayoutConstant() after casting this
// instruction to the constant subclass.
void HloInstruction::RelayoutConstant(const Layout& new_layout,
                                      const ShapeIndex& shape_index) {
  auto* constant = Cast<HloConstantInstruction>(this);
  constant->RelayoutConstant(new_layout, shape_index);
}
// Forwards to HloTraceInstruction::TracingTag() after casting this
// instruction to the trace subclass.
string HloInstruction::TracingTag() const {
  const auto* trace = Cast<HloTraceInstruction>(this);
  return trace->TracingTag();
}
// Forwards to HloFusionInstruction::AddFusionOperand() after casting this
// instruction to the fusion subclass.
HloInstruction* HloInstruction::AddFusionOperand(HloInstruction* new_operand) {
  auto* fusion = Cast<HloFusionInstruction>(this);
  return fusion->AddFusionOperand(new_operand);
}
// Delegates to HloFusionInstruction::MergeFusionInstruction. Both this
// instruction and `instruction_to_merge` are cast to HloFusionInstruction.
void HloInstruction::MergeFusionInstruction(
    HloInstruction* instruction_to_merge) {
  auto* fusion = Cast<HloFusionInstruction>(this);
  auto* fusion_to_merge = Cast<HloFusionInstruction>(instruction_to_merge);
  fusion->MergeFusionInstruction(fusion_to_merge);
}
// Delegates to HloFusionInstruction::MergeFusionInstructionIntoMultiOutput.
// Both this instruction and `instruction_to_merge` are cast to
// HloFusionInstruction.
void HloInstruction::MergeFusionInstructionIntoMultiOutput(
    HloInstruction* instruction_to_merge) {
  auto* fusion = Cast<HloFusionInstruction>(this);
  auto* fusion_to_merge = Cast<HloFusionInstruction>(instruction_to_merge);
  fusion->MergeFusionInstructionIntoMultiOutput(fusion_to_merge);
}
// Forwards to HloFusionInstruction::FuseInstruction() after casting this
// instruction to the fusion subclass.
HloInstruction* HloInstruction::FuseInstruction(
    HloInstruction* instruction_to_fuse) {
  auto* fusion = Cast<HloFusionInstruction>(this);
  return fusion->FuseInstruction(instruction_to_fuse);
}
// Forwards to HloFusionInstruction::FuseInstructionIntoMultiOutput() after
// casting this instruction to the fusion subclass.
HloInstruction* HloInstruction::FuseInstructionIntoMultiOutput(
    HloInstruction* instruction_to_fuse) {
  auto* fusion = Cast<HloFusionInstruction>(this);
  return fusion->FuseInstructionIntoMultiOutput(instruction_to_fuse);
}
// Forwards to HloFusionInstruction::fused_instructions_computation() after
// casting this instruction to the fusion subclass.
HloComputation* HloInstruction::fused_instructions_computation() const {
  const auto* fusion = Cast<HloFusionInstruction>(this);
  return fusion->fused_instructions_computation();
}
// Forwards to HloFusionInstruction::fused_expression_root() after casting
// this instruction to the fusion subclass.
HloInstruction* HloInstruction::fused_expression_root() const {
  const auto* fusion = Cast<HloFusionInstruction>(this);
  return fusion->fused_expression_root();
}
// Const overload: forwards to HloFusionInstruction::fused_instructions()
// after casting this instruction to the fusion subclass.
const tensorflow::gtl::iterator_range<UnwrappingIterator<
    std::list<std::unique_ptr<HloInstruction>>::const_iterator>>
HloInstruction::fused_instructions() const {
  const auto* fusion = Cast<HloFusionInstruction>(this);
  return fusion->fused_instructions();
}
// Non-const overload: forwards to HloFusionInstruction::fused_instructions()
// after casting this instruction to the fusion subclass.
const tensorflow::gtl::iterator_range<
    UnwrappingIterator<std::list<std::unique_ptr<HloInstruction>>::iterator>>
HloInstruction::fused_instructions() {
  auto* fusion = Cast<HloFusionInstruction>(this);
  return fusion->fused_instructions();
}
// Forwards to HloFusionInstruction::fused_instruction_count() after casting
// this instruction to the fusion subclass.
int64 HloInstruction::fused_instruction_count() const {
  const auto* fusion = Cast<HloFusionInstruction>(this);
  return fusion->fused_instruction_count();
}
// Forwards to HloFusionInstruction::fused_parameter(parameter_number) after
// casting this instruction to the fusion subclass.
HloInstruction* HloInstruction::fused_parameter(int64 parameter_number) const {
  const auto* fusion = Cast<HloFusionInstruction>(this);
  return fusion->fused_parameter(parameter_number);
}
// Forwards to HloFusionInstruction::fused_parameters() after casting this
// instruction to the fusion subclass.
const std::vector<HloInstruction*>& HloInstruction::fused_parameters() const {
  const auto* fusion = Cast<HloFusionInstruction>(this);
  return fusion->fused_parameters();
}
const bool HloInstruction::IsMultiOutputFusion() const {
const HloFusionInstruction* fusion = DynCast<HloFusionInstruction>(this);
return fusion != nullptr && fusion->IsMultiOutputFusion();
}
// Forwards to HloFusionInstruction::fusion_kind() after casting this
// instruction to the fusion subclass.
HloInstruction::FusionKind HloInstruction::fusion_kind() const {
  const auto* fusion = Cast<HloFusionInstruction>(this);
  return fusion->fusion_kind();
}
// Forwards to HloFusionInstruction::set_fusion_kind() after casting this
// instruction to the fusion subclass.
void HloInstruction::set_fusion_kind(FusionKind kind) {
  auto* fusion = Cast<HloFusionInstruction>(this);
  fusion->set_fusion_kind(kind);
}
// Forwards to HloRngInstruction::random_distribution() after casting this
// instruction to the RNG subclass.
RandomDistribution HloInstruction::random_distribution() const {
  const auto* rng = Cast<HloRngInstruction>(this);
  return rng->random_distribution();
}
// Forwards to HloParameterInstruction::parameter_number() after casting this
// instruction to the parameter subclass.
int64 HloInstruction::parameter_number() const {
  const auto* parameter = Cast<HloParameterInstruction>(this);
  return parameter->parameter_number();
}
// Forwards to HloGetTupleElementInstruction::tuple_index() after casting this
// instruction to the get-tuple-element subclass.
int64 HloInstruction::tuple_index() const {
  const auto* get_tuple_element = Cast<HloGetTupleElementInstruction>(this);
  return get_tuple_element->tuple_index();
}
// Forwards to HloReducePrecisionInstruction::exponent_bits() after casting
// this instruction to the reduce-precision subclass.
int32 HloInstruction::exponent_bits() const {
  const auto* reduce_precision = Cast<HloReducePrecisionInstruction>(this);
  return reduce_precision->exponent_bits();
}
// Forwards to HloReducePrecisionInstruction::mantissa_bits() after casting
// this instruction to the reduce-precision subclass.
int32 HloInstruction::mantissa_bits() const {
  const auto* reduce_precision = Cast<HloReducePrecisionInstruction>(this);
  return reduce_precision->mantissa_bits();
}
// Forwards to HloInfeedInstruction::infeed_config() after casting this
// instruction to the infeed subclass. The result is returned by value.
string HloInstruction::infeed_config() const {
  const auto* infeed = Cast<HloInfeedInstruction>(this);
  return infeed->infeed_config();
}
// Forwards to HloInfeedInstruction::set_infeed_config() after casting this
// instruction to the infeed subclass.
void HloInstruction::set_infeed_config(const string& config) {
  auto* infeed = Cast<HloInfeedInstruction>(this);
  infeed->set_infeed_config(config);
}
// Forwards to HloOutfeedInstruction::outfeed_shape() after casting this
// instruction to the outfeed subclass.
const Shape& HloInstruction::outfeed_shape() const {
  const auto* outfeed = Cast<HloOutfeedInstruction>(this);
  return outfeed->outfeed_shape();
}
// Forwards to HloOutfeedInstruction::outfeed_config() after casting this
// instruction to the outfeed subclass.
const string& HloInstruction::outfeed_config() const {
  const auto* outfeed = Cast<HloOutfeedInstruction>(this);
  return outfeed->outfeed_config();
}
// Forwards to HloCollectiveInstruction::replica_groups() after casting this
// instruction to the collective subclass.
const std::vector<ReplicaGroup>& HloInstruction::replica_groups() const {
  const auto* collective = Cast<HloCollectiveInstruction>(this);
  return collective->replica_groups();
}
// Forwards to HloCollectivePermuteInstruction::source_target_pairs() after
// casting this instruction to the collective-permute subclass.
const std::vector<std::pair<int64, int64>>&
HloInstruction::source_target_pairs() const {
  const auto* collective_permute = Cast<HloCollectivePermuteInstruction>(this);
  return collective_permute->source_target_pairs();
}
// Forwards to HloAllReduceInstruction::cross_replica_sum_barrier() after
// casting this instruction to the all-reduce subclass. The result is returned
// by value.
string HloInstruction::cross_replica_sum_barrier() const {
  const auto* all_reduce = Cast<HloAllReduceInstruction>(this);
  return all_reduce->cross_replica_sum_barrier();
}
// Forwards to HloAllReduceInstruction::set_cross_replica_sum_barrier() after
// casting this instruction to the all-reduce subclass.
void HloInstruction::set_cross_replica_sum_barrier(const string& barrier) {
  auto* all_reduce = Cast<HloAllReduceInstruction>(this);
  all_reduce->set_cross_replica_sum_barrier(barrier);
}
// Forwards to HloAllReduceInstruction::all_reduce_id() after casting this
// instruction to the all-reduce subclass.
absl::optional<int64> HloInstruction::all_reduce_id() const {
  const auto* all_reduce = Cast<HloAllReduceInstruction>(this);
  return all_reduce->all_reduce_id();
}
// Returns the convolution dimension numbers of a convolution or custom-call
// instruction. Any other instruction kind hits LOG(FATAL).
const ConvolutionDimensionNumbers&
HloInstruction::convolution_dimension_numbers() const {
  const auto* as_convolution = DynCast<HloConvolutionInstruction>(this);
  if (as_convolution != nullptr) {
    return as_convolution->convolution_dimension_numbers();
  }
  const auto* as_custom_call = DynCast<HloCustomCallInstruction>(this);
  if (as_custom_call != nullptr) {
    return as_custom_call->convolution_dimension_numbers();
  }
  LOG(FATAL) << "Unimplemented method.";
}
// Sets the convolution dimension numbers on a convolution or custom-call
// instruction. Any other instruction kind hits LOG(FATAL).
void HloInstruction::set_convolution_dimension_numbers(
    const ConvolutionDimensionNumbers& dnums) {
  if (auto* as_convolution = DynCast<HloConvolutionInstruction>(this)) {
    as_convolution->set_convolution_dimension_numbers(dnums);
    return;
  }
  if (auto* as_custom_call = DynCast<HloCustomCallInstruction>(this)) {
    as_custom_call->set_convolution_dimension_numbers(dnums);
    return;
  }
  LOG(FATAL) << "Unimplemented method.";
}
// Returns the feature group count of a convolution instruction; for anything
// else, casts to HloCustomCallInstruction and returns its count.
int64 HloInstruction::feature_group_count() const {
  const auto* convolution = DynCast<HloConvolutionInstruction>(this);
  if (convolution == nullptr) {
    return Cast<HloCustomCallInstruction>(this)->feature_group_count();
  }
  return convolution->feature_group_count();
}
// Forwards to HloCustomCallInstruction::set_feature_group_count() after
// casting this instruction to the custom-call subclass.
// NOTE(review): the feature_group_count() getter also accepts convolution
// instructions, but this setter only handles custom calls — confirm whether
// that asymmetry is intended.
void HloInstruction::set_feature_group_count(int64 feature_group_count) {
  auto* custom_call = Cast<HloCustomCallInstruction>(this);
  custom_call->set_feature_group_count(feature_group_count);
}
// Forwards to HloSelectAndScatterInstruction::select() after casting this
// instruction to the select-and-scatter subclass.
HloComputation* HloInstruction::select() const {
  const auto* select_and_scatter = Cast<HloSelectAndScatterInstruction>(this);
  return select_and_scatter->select();
}
// Forwards to HloSelectAndScatterInstruction::scatter() after casting this
// instruction to the select-and-scatter subclass.
HloComputation* HloInstruction::scatter() const {
  const auto* select_and_scatter = Cast<HloSelectAndScatterInstruction>(this);
  return select_and_scatter->scatter();
}
// Forwards to HloSelectAndScatterInstruction::set_select() after casting this
// instruction to the select-and-scatter subclass.
void HloInstruction::set_select(HloComputation* computation) {
  auto* select_and_scatter = Cast<HloSelectAndScatterInstruction>(this);
  select_and_scatter->set_select(computation);
}
// Forwards to HloSelectAndScatterInstruction::set_scatter() after casting
// this instruction to the select-and-scatter subclass.
void HloInstruction::set_scatter(HloComputation* computation) {
  auto* select_and_scatter = Cast<HloSelectAndScatterInstruction>(this);
  select_and_scatter->set_scatter(computation);
}
// Forwards to HloCustomCallInstruction::custom_call_target() after casting
// this instruction to the custom-call subclass.
const string& HloInstruction::custom_call_target() const {
  const auto* custom_call = Cast<HloCustomCallInstruction>(this);
  return custom_call->custom_call_target();
}
// Forwards to HloPadInstruction::padding_config() after casting this
// instruction to the pad subclass.
const PaddingConfig& HloInstruction::padding_config() const {
  const auto* pad = Cast<HloPadInstruction>(this);
  return pad->padding_config();
}
// Forwards to HloDynamicSliceInstruction::slice_sizes(dimension) after
// casting this instruction to the dynamic-slice subclass.
int64 HloInstruction::slice_sizes(int64 dimension) const {
  const auto* dynamic_slice = Cast<HloDynamicSliceInstruction>(this);
  return dynamic_slice->slice_sizes(dimension);
}
// Forwards to HloDynamicSliceInstruction::dynamic_slice_sizes() after casting
// this instruction to the dynamic-slice subclass.
const std::vector<int64>& HloInstruction::dynamic_slice_sizes() const {
  const auto* dynamic_slice = Cast<HloDynamicSliceInstruction>(this);
  return dynamic_slice->dynamic_slice_sizes();
}
// Forwards to HloGatherInstruction::gather_dimension_numbers() after casting
// this instruction to the gather subclass.
const GatherDimensionNumbers& HloInstruction::gather_dimension_numbers() const {
  const auto* gather = Cast<HloGatherInstruction>(this);
  return gather->gather_dimension_numbers();
}
// Forwards to HloGatherInstruction::gather_slice_sizes() after casting this
// instruction to the gather subclass.
absl::Span<const int64> HloInstruction::gather_slice_sizes() const {
  const auto* gather = Cast<HloGatherInstruction>(this);
  return gather->gather_slice_sizes();
}
// Forwards to HloScatterInstruction::scatter_dimension_numbers() after
// casting this instruction to the scatter subclass.
const ScatterDimensionNumbers& HloInstruction::scatter_dimension_numbers()
    const {
  const auto* scatter_instruction = Cast<HloScatterInstruction>(this);
  return scatter_instruction->scatter_dimension_numbers();
}
// Forwards to HloDotInstruction::dot_dimension_numbers() after casting this
// instruction to the dot subclass.
const DotDimensionNumbers& HloInstruction::dot_dimension_numbers() const {
  const auto* dot = Cast<HloDotInstruction>(this);
  return dot->dot_dimension_numbers();
}
// Forwards to HloDomainInstruction::operand_side_metadata() after casting
// this instruction to the domain subclass.
const DomainMetadata& HloInstruction::operand_side_metadata() const {
  const auto* domain = Cast<HloDomainInstruction>(this);
  return domain->operand_side_metadata();
}
// Forwards to HloDomainInstruction::user_side_metadata() after casting this
// instruction to the domain subclass.
const DomainMetadata& HloInstruction::user_side_metadata() const {
  const auto* domain = Cast<HloDomainInstruction>(this);
  return domain->user_side_metadata();
}
} // namespace xla